📄 classificationcriteriontest.java
字号:
/*
 *  YALE - Yet Another Learning Environment
 *  Copyright (C) 2002, 2003
 *      Simon Fischer, Ralf Klinkenberg, Ingo Mierswa,
 *      Katharina Morik, Oliver Ritthoff
 *      Artificial Intelligence Unit
 *      Computer Science Department
 *      University of Dortmund
 *      44221 Dortmund,  Germany
 *  email: yale@ls8.cs.uni-dortmund.de
 *  web:   http://yale.cs.uni-dortmund.de/
 *
 *  This program is free software; you can redistribute it and/or
 *  modify it under the terms of the GNU General Public License as
 *  published by the Free Software Foundation; either version 2 of the
 *  License, or (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful, but
 *  WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 *  General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 *  USA.
 */
package edu.udo.cs.yale.operator.performance.test;

import edu.udo.cs.yale.operator.performance.*;
import edu.udo.cs.yale.example.test.*;
import edu.udo.cs.yale.example.*;
import edu.udo.cs.yale.tools.att.*;
import java.util.*;

/**
 * Tests classification criteria.
 *
 * <p>Covers three things: the values that {@code PerformanceEvaluator}
 * produces for every {@code UniversalClassificationCriterion} type on a small
 * hand-labelled example set, cloning of count-based criteria, and averaging of
 * count-based criteria via {@code buildAverage}.
 *
 * @version $Id: ClassificationCriterionTest.java,v 1.4 2003/07/24 09:52:53 fischer Exp $
 */
public class ClassificationCriterionTest extends CriterionTestCase {

    /**
     * Evaluates all criterion types on twelve yes/no examples and compares
     * the results with hand-computed values.
     *
     * <p>Label / prediction pairs, in reading order (label first):
     * <pre>
     *   yy nn yn ny yy nn yy yn yn nn ny yy
     * </pre>
     * which gives 4x yy, 3x nn, 3x yn and 2x ny.
     *
     * @throws Exception if building the example table or the evaluation fails
     */
    public void testClassificationError() throws Exception {
        final double tolerance = 0.00000001;

        Attribute labelAttribute = ExampleTestTools.attributeYesNo();
        labelAttribute.setIndex(0);

        List attributes = new LinkedList();
        attributes.add(labelAttribute);

        // True labels, one single-column row per example.
        String[][] labelRows = {
            {"yes"}, {"no"}, {"yes"}, {"no"},
            {"yes"}, {"no"}, {"yes"}, {"yes"},
            {"yes"}, {"no"}, {"no"}, {"yes"}
        };
        MemoryExampleTable table =
            new MemoryExampleTable(attributes,
                                   ExampleTestTools.createDataRowReader(
                                       new DataRowFactory(DataRowFactory.TYPE_DOUBLE_ARRAY),
                                       new Attribute[] {labelAttribute},
                                       labelRows));

        AttributeSet specialAttributes = new AttributeSet();
        specialAttributes.setSpecialAttribute("label", labelAttribute);

        ExampleSet eSet = table.createExampleSet(specialAttributes);
        eSet.createPredictedLabel();

        // Predicted labels, aligned index by index with labelRows above.
        String[] predictions = {
            "yes", "no", "no", "yes",
            "yes", "no", "yes", "no",
            "no", "no", "yes", "yes"
        };
        ExampleReader reader = eSet.getExampleReader();
        for (int row = 0; row < predictions.length; row++) {
            Example example = reader.next();
            example.setPredictedLabel(predictions[row]);
        }

        // One criterion per known type; the type index doubles as the
        // position used with pv.get(...) below.
        PerformanceVector pv = new PerformanceVector();
        for (int type = 0; type < UniversalClassificationCriterion.NAME.length; type++) {
            pv.addCriterion(new UniversalClassificationCriterion(type));
        }
        PerformanceEvaluator.evaluate(null, eSet, pv, false);

        assertEquals("accuracy", 7.0 / 12.0,
                     pv.get(UniversalClassificationCriterion.ACCURACY).getValue(), tolerance);
        assertEquals("classification_error", 5.0 / 12.0,
                     pv.get(UniversalClassificationCriterion.CLASS_ERROR).getValue(), tolerance);
        assertEquals("precision", 4.0 / 6.0,
                     pv.get(UniversalClassificationCriterion.PRECISION).getValue(), tolerance);
        assertEquals("recall", 4.0 / 7.0,
                     pv.get(UniversalClassificationCriterion.RECALL).getValue(), tolerance);
        // NOTE(review): with the counts above, 3/5 equals TN / (TN + FP).
        // The definition of FALLOUT lives in UniversalClassificationCriterion,
        // which is not visible from this file -- confirm this expectation
        // matches the intended definition.
        assertEquals("fallout", 3.0 / 5.0,
                     pv.get(UniversalClassificationCriterion.FALLOUT).getValue(), tolerance);
        assertEquals("true_pos", 4,
                     pv.get(UniversalClassificationCriterion.TRUE_POSITIVE).getValue(), tolerance);
        assertEquals("true_neg", 3,
                     pv.get(UniversalClassificationCriterion.TRUE_NEGATIVE).getValue(), tolerance);
        assertEquals("false_pos", 2,
                     pv.get(UniversalClassificationCriterion.FALSE_POSITIVE).getValue(), tolerance);
        assertEquals("false_neg", 3,
                     pv.get(UniversalClassificationCriterion.FALSE_NEGATIVE).getValue(), tolerance);
    }

    /**
     * Passes one criterion of each confusion-cell type (true/false
     * positive/negative), all built on the same counter matrix, to
     * {@code cloneTest}. {@code cloneTest} is not defined in this class;
     * presumably it is inherited from {@code CriterionTestCase}.
     */
    public void testUCCClone() {
        int[][] counter = {{3, 5}, {4, 6}};
        UniversalClassificationCriterion[] criteria = confusionCellCriteria(counter);
        for (int i = 0; i < criteria.length; i++) {
            cloneTest("", criteria[i]);
        }
    }

    /**
     * Checks {@code buildAverage} on count-based criteria: after averaging a
     * criterion built on {@code first} with one built on {@code second}, its
     * value must equal the value of a criterion of the same type built on the
     * element-wise sum of both counter matrices.
     */
    public void testUCCAverage() {
        final double tolerance = 0.0000001;

        int[][] first = {{3, 5}, {4, 6}};
        int[][] second = {{5, 8}, {2, 9}};
        int[][] sum = {{8, 13}, {6, 15}}; // first + second, element by element

        UniversalClassificationCriterion[] ucc1 = confusionCellCriteria(first);
        UniversalClassificationCriterion[] ucc2 = confusionCellCriteria(second);
        UniversalClassificationCriterion[] expected = confusionCellCriteria(sum);

        for (int i = 0; i < ucc1.length; i++) {
            UniversalClassificationCriterion averaged = ucc1[i];
            averaged.buildAverage(ucc2[i]);
            assertEquals(averaged.getName(), expected[i].getValue(), averaged.getValue(), tolerance);
        }
    }

    /**
     * Builds four criteria on the given counter matrix, in the fixed order
     * TRUE_POSITIVE, TRUE_NEGATIVE, FALSE_POSITIVE, FALSE_NEGATIVE. The same
     * {@code counter} reference is handed to every constructor.
     */
    private static UniversalClassificationCriterion[] confusionCellCriteria(int[][] counter) {
        int[] cellTypes = {
            UniversalClassificationCriterion.TRUE_POSITIVE,
            UniversalClassificationCriterion.TRUE_NEGATIVE,
            UniversalClassificationCriterion.FALSE_POSITIVE,
            UniversalClassificationCriterion.FALSE_NEGATIVE
        };
        UniversalClassificationCriterion[] criteria =
            new UniversalClassificationCriterion[cellTypes.length];
        for (int i = 0; i < cellTypes.length; i++) {
            criteria[i] = new UniversalClassificationCriterion(cellTypes[i], counter);
        }
        return criteria;
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -