📄 SMO.java
    }
    for (int i = 0; i < insts.numClasses(); i++) {
      subsets[i].compactify();
    }

    // Build the binary classifiers
    Random rand = new Random(m_randomSeed);
    m_classifiers = new BinarySMO[insts.numClasses()][insts.numClasses()];
    for (int i = 0; i < insts.numClasses(); i++) {
      for (int j = i + 1; j < insts.numClasses(); j++) {
        m_classifiers[i][j] = new BinarySMO();
        m_classifiers[i][j].setKernel(Kernel.makeCopy(getKernel()));
        Instances data = new Instances(insts, insts.numInstances());
        for (int k = 0; k < subsets[i].numInstances(); k++) {
          data.add(subsets[i].instance(k));
        }
        for (int k = 0; k < subsets[j].numInstances(); k++) {
          data.add(subsets[j].instance(k));
        }
        data.compactify();
        data.randomize(rand);
        m_classifiers[i][j].buildClassifier(data, i, j, m_fitLogisticModels,
                                            m_numFolds, m_randomSeed);
      }
    }
  }

  /**
   * Estimates class probabilities for given instance.
   *
   * @param inst the instance to compute the probabilities for
   * @throws Exception in case of an error
   */
  public double[] distributionForInstance(Instance inst) throws Exception {

    // Filter instance
    if (!m_checksTurnedOff) {
      m_Missing.input(inst);
      m_Missing.batchFinished();
      inst = m_Missing.output();
    }

    if (m_NominalToBinary != null) {
      m_NominalToBinary.input(inst);
      m_NominalToBinary.batchFinished();
      inst = m_NominalToBinary.output();
    }

    if (m_Filter != null) {
      m_Filter.input(inst);
      m_Filter.batchFinished();
      inst = m_Filter.output();
    }

    if (!m_fitLogisticModels) {
      double[] result = new double[inst.numClasses()];
      for (int i = 0; i < inst.numClasses(); i++) {
        for (int j = i + 1; j < inst.numClasses(); j++) {
          if ((m_classifiers[i][j].m_alpha != null) ||
              (m_classifiers[i][j].m_sparseWeights != null)) {
            double output = m_classifiers[i][j].SVMOutput(-1, inst);
            if (output > 0) {
              result[j] += 1;
            } else {
              result[i] += 1;
            }
          }
        }
      }
      Utils.normalize(result);
      return result;
    } else {

      // We only need to do pairwise coupling if there are more
      // than two classes.
      if (inst.numClasses() == 2) {
        double[] newInst = new double[2];
        newInst[0] = m_classifiers[0][1].SVMOutput(-1, inst);
        newInst[1] = Instance.missingValue();
        return m_classifiers[0][1].m_logistic.
          distributionForInstance(new Instance(1, newInst));
      }
      double[][] r = new double[inst.numClasses()][inst.numClasses()];
      double[][] n = new double[inst.numClasses()][inst.numClasses()];
      for (int i = 0; i < inst.numClasses(); i++) {
        for (int j = i + 1; j < inst.numClasses(); j++) {
          if ((m_classifiers[i][j].m_alpha != null) ||
              (m_classifiers[i][j].m_sparseWeights != null)) {
            double[] newInst = new double[2];
            newInst[0] = m_classifiers[i][j].SVMOutput(-1, inst);
            newInst[1] = Instance.missingValue();
            r[i][j] = m_classifiers[i][j].m_logistic.
              distributionForInstance(new Instance(1, newInst))[0];
            n[i][j] = m_classifiers[i][j].m_sumOfWeights;
          }
        }
      }
      return pairwiseCoupling(n, r);
    }
  }
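  // ------------------------------------------------------------------
  // Usage sketch (illustrative; not part of the original WEKA source).
  // Shows how distributionForInstance() above might be called once the
  // classifier has been built. The file name passed in, the method
  // name exampleDistribution, and the choice of the last attribute as
  // the class are assumptions made only for this demo.
  // ------------------------------------------------------------------
  public static void exampleDistribution(String arffFile) throws Exception {

    // Load a dataset and declare the last attribute to be the class
    Instances data = new Instances(
        new java.io.BufferedReader(new java.io.FileReader(arffFile)));
    data.setClassIndex(data.numAttributes() - 1);

    // Build the multi-class SMO; fitting logistic models enables the
    // pairwise-coupling branch of distributionForInstance()
    SMO smo = new SMO();
    smo.setBuildLogisticModels(true);
    smo.buildClassifier(data);

    // Print the estimated class distribution for the first instance
    double[] dist = smo.distributionForInstance(data.instance(0));
    for (int i = 0; i < dist.length; i++) {
      System.out.println(data.classAttribute().value(i) + ": " + dist[i]);
    }
  }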
  /**
   * Implements pairwise coupling.
   *
   * @param n the sum of weights used to train each model
   * @param r the probability estimate from each model
   * @return the coupled estimates
   */
  public double[] pairwiseCoupling(double[][] n, double[][] r) {

    // Initialize p and u array
    double[] p = new double[r.length];
    for (int i = 0; i < p.length; i++) {
      p[i] = 1.0 / (double)p.length;
    }
    double[][] u = new double[r.length][r.length];
    for (int i = 0; i < r.length; i++) {
      for (int j = i + 1; j < r.length; j++) {
        u[i][j] = 0.5;
      }
    }

    // firstSum doesn't change
    double[] firstSum = new double[p.length];
    for (int i = 0; i < p.length; i++) {
      for (int j = i + 1; j < p.length; j++) {
        firstSum[i] += n[i][j] * r[i][j];
        firstSum[j] += n[i][j] * (1 - r[i][j]);
      }
    }

    // Iterate until convergence
    boolean changed;
    do {
      changed = false;
      double[] secondSum = new double[p.length];
      for (int i = 0; i < p.length; i++) {
        for (int j = i + 1; j < p.length; j++) {
          secondSum[i] += n[i][j] * u[i][j];
          secondSum[j] += n[i][j] * (1 - u[i][j]);
        }
      }
      for (int i = 0; i < p.length; i++) {
        if ((firstSum[i] == 0) || (secondSum[i] == 0)) {
          if (p[i] > 0) {
            changed = true;
          }
          p[i] = 0;
        } else {
          double factor = firstSum[i] / secondSum[i];
          double pOld = p[i];
          p[i] *= factor;
          if (Math.abs(pOld - p[i]) > 1.0e-3) {
            changed = true;
          }
        }
      }
      Utils.normalize(p);
      for (int i = 0; i < r.length; i++) {
        for (int j = i + 1; j < r.length; j++) {
          u[i][j] = p[i] / (p[i] + p[j]);
        }
      }
    } while (changed);
    return p;
  }

  /**
   * Returns an array of votes for the given instance.
   *
   * @param inst the instance
   * @return array of votes
   * @throws Exception if something goes wrong
   */
  public int[] obtainVotes(Instance inst) throws Exception {

    // Filter instance
    if (!m_checksTurnedOff) {
      m_Missing.input(inst);
      m_Missing.batchFinished();
      inst = m_Missing.output();
    }

    if (m_NominalToBinary != null) {
      m_NominalToBinary.input(inst);
      m_NominalToBinary.batchFinished();
      inst = m_NominalToBinary.output();
    }

    if (m_Filter != null) {
      m_Filter.input(inst);
      m_Filter.batchFinished();
      inst = m_Filter.output();
    }

    int[] votes = new int[inst.numClasses()];
    for (int i = 0; i < inst.numClasses(); i++) {
      for (int j = i + 1; j < inst.numClasses(); j++) {
        double output = m_classifiers[i][j].SVMOutput(-1, inst);
        if (output > 0) {
          votes[j] += 1;
        } else {
          votes[i] += 1;
        }
      }
    }
    return votes;
  }

  /**
   * Returns the weights in sparse format.
   */
  public double[][][] sparseWeights() {

    int numValues = m_classAttribute.numValues();
    double[][][] sparseWeights = new double[numValues][numValues][];
    for (int i = 0; i < numValues; i++) {
      for (int j = i + 1; j < numValues; j++) {
        sparseWeights[i][j] = m_classifiers[i][j].m_sparseWeights;
      }
    }
    return sparseWeights;
  }

  /**
   * Returns the indices in sparse format.
   */
  public int[][][] sparseIndices() {

    int numValues = m_classAttribute.numValues();
    int[][][] sparseIndices = new int[numValues][numValues][];
    for (int i = 0; i < numValues; i++) {
      for (int j = i + 1; j < numValues; j++) {
        sparseIndices[i][j] = m_classifiers[i][j].m_sparseIndices;
      }
    }
    return sparseIndices;
  }

  /**
   * Returns the bias of each binary SMO.
   */
  public double[][] bias() {

    int numValues = m_classAttribute.numValues();
    double[][] bias = new double[numValues][numValues];
    for (int i = 0; i < numValues; i++) {
      for (int j = i + 1; j < numValues; j++) {
        bias[i][j] = m_classifiers[i][j].m_b;
      }
    }
    return bias;
  }

  /**
   * Returns the number of values of the class attribute.
   */
  public int numClassAttributeValues() {
    return m_classAttribute.numValues();
  }
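  // ------------------------------------------------------------------
  // Usage sketch (illustrative; not part of the original WEKA source).
  // A toy call to pairwiseCoupling() for three classes; the r values
  // (pairwise probability estimates) and n values (pairwise training
  // weights) are made-up numbers chosen only to show the expected
  // layout, where just the upper triangle (j > i) is filled in.
  // ------------------------------------------------------------------
  public void examplePairwiseCoupling() {

    double[][] r = new double[3][3];
    double[][] n = new double[3][3];

    r[0][1] = 0.9;   // estimate of P(class 0) from the 0-vs-1 machine
    r[0][2] = 0.8;   // estimate of P(class 0) from the 0-vs-2 machine
    r[1][2] = 0.4;   // estimate of P(class 1) from the 1-vs-2 machine

    n[0][1] = n[0][2] = n[1][2] = 100;  // sum of training weights

    // Couple the pairwise estimates into one normalized distribution
    double[] p = pairwiseCoupling(n, r);
    System.out.println(java.util.Arrays.toString(p));
  }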
  /**
   * Returns the names of the class attributes.
   */
  public String[] classAttributeNames() {

    int numValues = m_classAttribute.numValues();
    String[] classAttributeNames = new String[numValues];
    for (int i = 0; i < numValues; i++) {
      classAttributeNames[i] = m_classAttribute.value(i);
    }
    return classAttributeNames;
  }

  /**
   * Returns the attribute names.
   */
  public String[][][] attributeNames() {

    int numValues = m_classAttribute.numValues();
    String[][][] attributeNames = new String[numValues][numValues][];
    for (int i = 0; i < numValues; i++) {
      for (int j = i + 1; j < numValues; j++) {
        int numAttributes = m_classifiers[i][j].m_data.numAttributes();
        String[] attrNames = new String[numAttributes];
        for (int k = 0; k < numAttributes; k++) {
          attrNames[k] = m_classifiers[i][j].m_data.attribute(k).name();
        }
        attributeNames[i][j] = attrNames;
      }
    }
    return attributeNames;
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {

    Vector result = new Vector();

    Enumeration enm = super.listOptions();
    while (enm.hasMoreElements())
      result.addElement(enm.nextElement());

    result.addElement(new Option(
        "\tTurns off all checks - use with caution!\n"
        + "\tTurning them off assumes that data is purely numeric, doesn't\n"
        + "\tcontain any missing values, and has a nominal class. Turning them\n"
        + "\toff also means that no header information will be stored if the\n"
        + "\tmachine is linear. Finally, it also assumes that no instance has\n"
        + "\ta weight equal to 0.\n"
        + "\t(default: checks on)",
        "no-checks", 0, "-no-checks"));

    result.addElement(new Option(
        "\tThe complexity constant C. (default 1)",
        "C", 1, "-C <double>"));

    result.addElement(new Option(
        "\tWhether to 0=normalize/1=standardize/2=neither. "
        + "(default 0=normalize)",
        "N", 1, "-N"));

    result.addElement(new Option(
        "\tThe tolerance parameter. (default 1.0e-3)",
        "L", 1, "-L <double>"));

    result.addElement(new Option(
        "\tThe epsilon for round-off error. (default 1.0e-12)",
        "P", 1, "-P <double>"));

    result.addElement(new Option(
        "\tFit logistic models to SVM outputs. ",
        "M", 0, "-M"));

    result.addElement(new Option(
        "\tThe number of folds for the internal\n"
        + "\tcross-validation. (default -1, use training data)",
        "V", 1, "-V <double>"));

    result.addElement(new Option(
        "\tThe random number seed. (default 1)",
        "W", 1, "-W <double>"));

    result.addElement(new Option(
        "\tThe Kernel to use.\n"
        + "\t(default: weka.classifiers.functions.supportVector.PolyKernel)",
        "K", 1, "-K <classname and parameters>"));

    result.addElement(new Option(
        "", "", 0, "\nOptions specific to kernel "
        + getKernel().getClass().getName() + ":"));

    enm = ((OptionHandler) getKernel()).listOptions();
    while (enm.hasMoreElements())
      result.addElement(enm.nextElement());

    return result.elements();
  }
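  // ------------------------------------------------------------------
  // Usage sketch (illustrative; not part of the original WEKA source).
  // Prints the weight vector of each pairwise binary machine via the
  // accessors above. This only yields output for a linear kernel,
  // since the sparse weight vectors are stored just in that case; the
  // method name examplePrintMachines is an assumption for this demo.
  // ------------------------------------------------------------------
  public void examplePrintMachines() {

    String[] classNames = classAttributeNames();
    String[][][] attrNames = attributeNames();
    double[][][] weights = sparseWeights();
    int[][][] indices = sparseIndices();
    double[][] b = bias();

    for (int i = 0; i < classNames.length; i++) {
      for (int j = i + 1; j < classNames.length; j++) {
        if (weights[i][j] == null) {
          continue;  // non-linear machine: no explicit weight vector
        }
        System.out.println(classNames[i] + " vs. " + classNames[j] + ":");
        for (int k = 0; k < weights[i][j].length; k++) {
          System.out.println("  " + weights[i][j][k] + " * "
              + attrNames[i][j][indices[i][j][k]]);
        }
        System.out.println("  bias: " + b[i][j]);
      }
    }
  }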
  /**
   * Parses a given list of options. <p/>
   *
   * <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -D
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console</pre>
   *
   * <pre> -no-checks
   *  Turns off all checks - use with caution!
   *  Turning them off assumes that data is purely numeric, doesn't
   *  contain any missing values, and has a nominal class. Turning them
   *  off also means that no header information will be stored if the
   *  machine is linear. Finally, it also assumes that no instance has
   *  a weight equal to 0.
   *  (default: checks on)</pre>
   *
   * <pre> -C <double>
   *  The complexity constant C. (default 1)</pre>
   *
   * <pre> -N
   *  Whether to 0=normalize/1=standardize/2=neither.
   *  (default 0=normalize)</pre>
   *
   * <pre> -L <double>
   *  The tolerance parameter. (default 1.0e-3)</pre>
   *
   * <pre> -P <double>
   *  The epsilon for round-off error. (default 1.0e-12)</pre>
   *
   * <pre> -M
   *  Fit logistic models to SVM outputs. </pre>
   *
   * <pre> -V <double>
   *  The number of folds for the internal
   *  cross-validation. (default -1, use training data)</pre>
   *
   * <pre> -W <double>
   *  The random number seed. (default 1)</pre>
   *
   * <pre> -K <classname and parameters>
   *  The Kernel to use.
   *  (default: weka.classifiers.functions.supportVector.PolyKernel)</pre>
   *
   * <pre>
   * Options specific to kernel weka.classifiers.functions.supportVector.PolyKernel:
   * </pre>
   *
   * <pre> -D
   *  Enables debugging output (if available) to be printed.
   *  (default: off)</pre>
   *
   * <pre> -no-checks
   *  Turns off all checks - use with caution!
   *  (default: checks on)</pre>
   *
   * <pre> -C <num>
   *  The size of the cache (a prime number).
   *  (default: 250007)</pre>
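  // ------------------------------------------------------------------
  // Usage sketch (illustrative; not part of the original WEKA source).
  // The options documented above can also be supplied programmatically;
  // the particular option string below is an arbitrary example that
  // restates the defaults and turns on logistic models (-M).
  // ------------------------------------------------------------------
  public static SMO exampleConfiguredSMO() throws Exception {
    SMO smo = new SMO();
    smo.setOptions(Utils.splitOptions(
        "-C 1.0 -N 0 -L 1.0e-3 -P 1.0e-12 -M -V -1 -W 1"));
    return smo;
  }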