📄 semisupem.java
字号:
/** * Generates the classifier. * * @param instances set of instances serving as training data * @exception Exception if the classifier has not been generated successfully */ public void buildClassifier(Instances data) throws Exception { if (data.checkForStringAttributes()) { throw new UnsupportedAttributeTypeException("Cannot handle string attributes!"); } if (data.classAttribute().isNumeric()) { throw new UnsupportedClassTypeException("Can't handle a numeric class!"); } if (m_Classifier == null) { throw new Exception("A base classifier has not been specified!"); } m_LabeledInstances = data; // Add "hard" soft-labeled instances of labeled data to the data for EM m_AllInstances = new SoftClassifiedInstances(data); Random m_Random = new Random(m_rseed); // Make random soft-labeled instances for unlabeled data m_UnlabeledInstances = new SoftClassifiedInstances(m_UnlabeledData, m_Random); if (m_Lambda != 1.0) weightInstances(m_UnlabeledInstances, m_Lambda); // Add the unlabeled data to the complete data set m_AllInstances.addInstances(m_UnlabeledInstances); initModel(); if (m_verbose) { System.out.println("Labeled Data Classes: "); Enumeration enumInsts = m_LabeledInstances.enumerateInstances(); while (enumInsts.hasMoreElements()) { Instance instance = (Instance) enumInsts.nextElement(); System.out.print(m_AllInstances.classAttribute().value((int)instance.classValue()) + " "); } System.out.println("\nNum Unlabeled: " + m_UnlabeledInstances.numInstances() ); // System.out.println("Labeled data: " + m_LabeledInstances); // System.out.println("Unlabeled data: " + m_UnlabeledInstances); } if (m_UnlabeledInstances.numInstances() != 0) iterate(); } /** Weighted all given instances with given weight */ protected void weightInstances (Instances insts, double weight) { Enumeration enumInsts = insts.enumerateInstances(); while (enumInsts.hasMoreElements()) { Instance instance = (Instance) enumInsts.nextElement(); instance.setWeight(weight); } } /** Intialize model using appropriate set of data */ protected void initModel() throws Exception { SoftClassifiedInstances seedInstances = new SoftClassifiedInstances(m_LabeledInstances); if (m_seedUnseenClasses && m_UnlabeledInstances.numInstances() != 0) { List unseenClasses = unseenClasses(seedInstances); if (!unseenClasses.isEmpty()) { if (m_verbose) System.out.println("Unseen classes: " + unseenClasses); // Add a seed instance for the unseen classes that is soft labeled equally // in all unkown classes. Instance farthest = farthestInstance(m_UnlabeledInstances, seedInstances); softLabelClasses((SoftClassifiedInstance)farthest, unseenClasses); if (m_verbose) System.out.println("Seeded Instance: " + classDistributionString((SoftClassifiedInstance)farthest)); seedInstances.add(farthest); } } m_Classifier.buildClassifier(seedInstances); } /** Return a list of class values for which there are no * instances in insts */ protected ArrayList unseenClasses(Instances insts) { int[] classCounts = new int[insts.numClasses()]; Enumeration enumInsts = insts.enumerateInstances(); while (enumInsts.hasMoreElements()) { Instance inst = (Instance) enumInsts.nextElement(); classCounts[(int)inst.classValue()]++; } ArrayList result = new ArrayList(); for (int i = 0; i < insts.numClasses(); i++) { if (classCounts[i] == 0) { result.add(new Integer(i)); } } return result; } /** Return the instance in candidateInsts that is farthest from any instance * in insts */ protected Instance farthestInstance(Instances candidateInsts, Instances insts) { double maxDist = Double.NEGATIVE_INFINITY; Instance farthestInst = null; double dist; setMinMax(m_AllInstances); Enumeration enumInsts = candidateInsts.enumerateInstances(); while (enumInsts.hasMoreElements()) { Instance candidate = (Instance) enumInsts.nextElement(); dist = minimumDistance(candidate, insts); if (dist > maxDist) { maxDist = dist; farthestInst = candidate; } } return farthestInst; } /** Return the distance from inst to the closest instance in insts */ protected double minimumDistance(Instance inst, Instances insts) { double minDist = Double.POSITIVE_INFINITY; double dist; Enumeration enumInsts = insts.enumerateInstances(); while (enumInsts.hasMoreElements()) { Instance X = (Instance) enumInsts.nextElement(); dist = distance(inst, X); if (dist < minDist) { minDist = dist; } } return minDist; } /** Soft label inst as being equally likely to be in an of the given classes */ protected void softLabelClasses(SoftClassifiedInstance inst, List classes) throws Exception { double prob = 1.0/classes.size(); double[] dist = new double[((Instance)inst).dataset().numClasses()]; for (int i = 0; i < classes.size(); i++) { dist[((Integer)classes.get(i)).intValue()] = prob; } inst.setClassDistribution(dist); } /** Run EM iterations until likelihood stops increasing significantly or max iterations exhausted */ protected void iterate() throws Exception { double logLikelihood, oldLogLikelihood; logLikelihood = 0; oldLogLikelihood = 0; for (int i = 0; i < m_max_iterations; i++) { // if (m_verbose) { // System.out.println(m_Classifier); // } oldLogLikelihood = logLikelihood; logLikelihood = eStep(); if (m_verbose) { System.out.println("\nIteration " + i + ": LogLikelihood = " + logLikelihood + "\n\n"); } if ( (i > 0) && ((logLikelihood - oldLogLikelihood) < m_minLogLikelihoodIncr)) break; mStep(); } } protected double eStep() throws Exception { double logLikelihood = 0; double classifiedCorrect = 0; double[] dist; Enumeration enumInsts = m_UnlabeledInstances.enumerateInstances(); while (enumInsts.hasMoreElements()) { Instance instance = (Instance) enumInsts.nextElement(); dist = m_Classifier.unNormalizedDistributionForInstance(instance); // instance.setClassDistribution(dist); // System.out.println("Instance:" + instance + " Dist: " + classDistributionString(instance)); logLikelihood += logSum(dist); NaiveBayesSimple.normalizeLogs(dist); // System.out.println("Norm Dist: " + classDistributionString((SoftClassifiedInstance)instance)); ((SoftClassifiedInstance)instance).setClassDistribution(dist); if (m_verbose) { // System.out.println(classDistributionString(instance)); if (Utils.maxIndex(dist) == (int)instance.classValue()) { classifiedCorrect++; } } } if (m_verbose) { System.out.println("\nAccuracy on Unlabeled: " + classifiedCorrect/ m_UnlabeledInstances.numInstances()); } enumInsts = m_LabeledInstances.enumerateInstances(); while (enumInsts.hasMoreElements()) { Instance instance = (Instance) enumInsts.nextElement(); dist = m_Classifier.unNormalizedDistributionForInstance(instance); logLikelihood += logSum(dist); } return logLikelihood/m_AllInstances.numInstances(); } /** Sums log of probabilities using special method for summing in log space */ public double logSum(double[] logProbs) { double sum = 0; double max = logProbs[Utils.maxIndex(logProbs)]; for (int i = 0; i < logProbs.length; i++) { sum += Math.exp(logProbs[i] - max); } return max + Math.log(sum); } protected String classDistributionString(SoftClassifiedInstance inst) { double[] dist = inst.getClassDistribution(); StringBuffer text = new StringBuffer(); Attribute classAtt = m_AllInstances.classAttribute(); text.append(classAtt.value((int)((Instance)inst).classValue()) + " | "); for (int i = 0; i < m_AllInstances.numClasses(); i++) { text.append(classAtt.value(i) + ":" + dist[i] + " "); } return text.toString(); } protected void mStep() throws Exception { m_Classifier.buildClassifier(m_AllInstances); } /** * Calculates the class membership probabilities for the given test instance. * * @param instance the instance to be classified * @return predicted class probability distribution * @exception Exception if distribution can't be computed */ public double[] distributionForInstance(Instance instance) throws Exception { double[] dist = m_Classifier.unNormalizedDistributionForInstance(instance); NaiveBayesSimple.normalizeLogs(dist); return dist; } /** * Calculates the distance between two instances * * @param first the first instance * @param second the second instance * @return the distance between the two given instances */ protected double distance(Instance first, Instance second) { double diff, distance = 0; Instances dataset = first.dataset(); for(int i = 0; i < dataset.numAttributes(); i++) { if (i == dataset.classIndex()) { continue; } if (dataset.attribute(i).isNominal()) { // If attribute is nominal if (first.isMissing(i) || second.isMissing(i) || ((int)first.value(i) != (int)second.value(i))) { distance += 1; } } else { // If attribute is numeric if (first.isMissing(i) || second.isMissing(i)){ if (first.isMissing(i) && second.isMissing(i)) { diff = 1; } else { if (second.isMissing(i)) { diff = norm(first.value(i), i); } else { diff = norm(second.value(i), i); } if (diff < 0.5) { diff = 1.0 - diff; } } } else { diff = norm(first.value(i), i) - norm(second.value(i), i); } distance += diff * diff; } } return distance; } /** * Normalizes a given value of a numeric attribute. * * @param x the value to be normalized * @param i the attribute's index */ protected double norm(double x,int i) { if (Double.isNaN(m_MinArray[i]) || Utils.eq(m_MaxArray[i], m_MinArray[i])) { return 0; } else { return (x - m_MinArray[i]) / (m_MaxArray[i] - m_MinArray[i]); } } /** Compute and store min max values for each numeric feature */ protected void setMinMax(Instances insts) { m_MinArray = new double [insts.numAttributes()]; m_MaxArray = new double [insts.numAttributes()]; for (int i = 0; i < insts.numAttributes(); i++) { m_MinArray[i] = m_MaxArray[i] = Double.NaN; } Enumeration enum = insts.enumerateInstances(); while (enum.hasMoreElements()) { updateMinMax((Instance) enum.nextElement()); } } /** * Updates the minimum and maximum values for all the attributes * based on a new instance. * * @param instance the new instance */ protected void updateMinMax(Instance instance) { Instances dataset = instance.dataset(); for (int j = 0;j < dataset.numAttributes(); j++) { if ((dataset.attribute(j).isNumeric()) && (!instance.isMissing(j))) { if (Double.isNaN(m_MinArray[j])) { m_MinArray[j] = instance.value(j); m_MaxArray[j] = instance.value(j); } else { if (instance.value(j) < m_MinArray[j]) { m_MinArray[j] = instance.value(j); } else { if (instance.value(j) > m_MaxArray[j]) { m_MaxArray[j] = instance.value(j); } } } } } } /** * Main method for testing this class. * * @param argv the options */ // public static void main(String [] argv) { // try { // NaiveBayesSimpleSoft baseClassifier = new NaiveBayesSimpleSoft(); // baseClassifier.setMinStdDev(.15); // Instances instances = new Instances(new BufferedReader(new FileReader(argv[0]))); // instances.setClassIndex(instances.numAttributes() - 1); // SemiSupEM emClassifier = new SemiSupEM(); // emClassifier.resetOptions(); // emClassifier.setClassifier(baseClassifier); // emClassifier.setDebug(true); // // emClassifier.setUnlabeledSeeding(true); // Random random = new Random(); // instances.randomize(random); // int numLabeled = Integer.parseInt(argv[1]); // Instances labeledInsts = new Instances(instances, 0, numLabeled); // Instances unlabeledInsts = new Instances(instances, numLabeled, (instances.numInstances() - numLabeled)); // emClassifier.setUnlabeled(unlabeledInsts); // emClassifier.buildClassifier(labeledInsts); // } catch (Exception e) { // System.err.println(e.getMessage()); // } // } public static void main(String [] argv) { try { Instances instances = new Instances(new BufferedReader(new FileReader(argv[0]))); instances.setClassIndex(instances.numAttributes() - 1); Random random = new Random(Integer.parseInt(argv[2])); instances.randomize(random); int numLabeled = Integer.parseInt(argv[1]); Instances labeledInsts = new Instances(instances, 0, numLabeled); Instances unlabeledInsts = new Instances(instances, numLabeled, (instances.numInstances() - numLabeled)); SemiSupEM emClassifier = new SemiSupEM(); emClassifier.setOptions(argv); emClassifier.setUnlabeled(unlabeledInsts); emClassifier.buildClassifier(labeledInsts); } catch (Exception e) { System.err.println(e.getMessage()); } }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -