⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 smo.java

📁 wekaUT 是 University of Texas at Austin 开发的、基于 Weka 的半监督学习（semi-supervised learning）分类器
💻 JAVA
📖 第 1 页 / 共 4 页
字号:
      } else {	m_I3.delete(i1);      }      if ((y1 == -1) && (a1 == 0)) {	m_I4.insert(i1);      } else {	m_I4.delete(i1);      }      if (a2 > 0) {	m_supportVectors.insert(i2);      } else {	m_supportVectors.delete(i2);      }      if ((a2 > 0) && (a2 < C2)) {	m_I0.insert(i2);      } else {	m_I0.delete(i2);      }      if ((y2 == 1) && (a2 == 0)) {	m_I1.insert(i2);      } else {	m_I1.delete(i2);      }      if ((y2 == -1) && (a2 == C2)) {	m_I2.insert(i2);      } else {	m_I2.delete(i2);      }      if ((y2 == 1) && (a2 == C2)) {	m_I3.insert(i2);      } else {	m_I3.delete(i2);      }      if ((y2 == -1) && (a2 == 0)) {	m_I4.insert(i2);      } else {	m_I4.delete(i2);      }            // Update weight vector to reflect change a1 and a2, if linear SVM      if (!m_useRBF && m_exponent == 1.0) {	Instance inst1 = m_data.instance(i1);	for (int p1 = 0; p1 < inst1.numValues(); p1++) {	  if (inst1.index(p1) != m_data.classIndex()) {	    m_weights[inst1.index(p1)] += 	      y1 * (a1 - alph1) * inst1.valueSparse(p1);	  }	}	Instance inst2 = m_data.instance(i2);	for (int p2 = 0; p2 < inst2.numValues(); p2++) {	  if (inst2.index(p2) != m_data.classIndex()) {	    m_weights[inst2.index(p2)] += 	      y2 * (a2 - alph2) * inst2.valueSparse(p2);	  }	}      }            // Update error cache using new Lagrange multipliers      for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) {	if ((j != i1) && (j != i2)) {	  m_errors[j] += 	    y1 * (a1 - alph1) * m_kernel.eval(i1, j, m_data.instance(i1)) + 	    y2 * (a2 - alph2) * m_kernel.eval(i2, j, m_data.instance(i2));	}      }            // Update error cache for i1 and i2      m_errors[i1] += y1 * (a1 - alph1) * k11 + y2 * (a2 - alph2) * k12;      m_errors[i2] += y1 * (a1 - alph1) * k12 + y2 * (a2 - alph2) * k22;            // Update array with Lagrange multipliers      m_alpha[i1] = a1;      m_alpha[i2] = a2;            // Update thresholds      m_bLow = -Double.MAX_VALUE; m_bUp = Double.MAX_VALUE;      m_iLow = -1; m_iUp = -1;      
for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) {
        // Over the non-bound set m_I0: m_bUp/m_iUp track the smallest
        // cached error, m_bLow/m_iLow the largest.
        if (m_errors[j] < m_bUp) {
          m_bUp = m_errors[j]; m_iUp = j;
        }
        if (m_errors[j] > m_bLow) {
          m_bLow = m_errors[j]; m_iLow = j;
        }
      }
      // i1 and i2 are outside m_I0 only when their alpha sits at a bound
      // (0 or C). Members of m_I3 or m_I4 may only raise m_bLow; all other
      // bound indices may only lower m_bUp.
      if (!m_I0.contains(i1)) {
        if (m_I3.contains(i1) || m_I4.contains(i1)) {
          if (m_errors[i1] > m_bLow) {
            m_bLow = m_errors[i1]; m_iLow = i1;
          } 
        } else {
          if (m_errors[i1] < m_bUp) {
            m_bUp = m_errors[i1]; m_iUp = i1;
          }
        }
      }
      if (!m_I0.contains(i2)) {
        if (m_I3.contains(i2) || m_I4.contains(i2)) {
          if (m_errors[i2] > m_bLow) {
            m_bLow = m_errors[i2]; m_iLow = i2;
          }
        } else {
          if (m_errors[i2] < m_bUp) {
            m_bUp = m_errors[i2]; m_iUp = i2;
          }
        }
      }
      // Both thresholds must have been set by one of the scans above.
      if ((m_iLow == -1) || (m_iUp == -1)) {
        throw new Exception("This should never happen!");
      }
      // Made some progress.
      return true;
    }

    /**
     * Quick and dirty check whether the quadratic programming problem is solved.
     *
     * <p>Debugging aid only: it changes no state and reports to System.err.
     * It first prints the sum of y(i) * alpha(i) over all positive alphas
     * (presumably expected to be close to 0 at a solution). It then prints
     * every instance whose margin y(i) * output(i) violates the condition for
     * its alpha, with comparisons done by Utils.eq / Utils.sm / Utils.gr:
     * <ul>
     *   <li>condition 1: alpha == 0 requires margin not smaller than 1;</li>
     *   <li>condition 2: 0 &lt; alpha &lt; C * weight requires margin == 1;</li>
     *   <li>condition 3: alpha == C * weight requires margin not greater than 1.</li>
     * </ul>
     *
     * @exception Exception if computing an SVM output fails
     */
    private void checkClassifier() throws Exception {
      double sum = 0;
      for (int i = 0; i < m_alpha.length; i++) {
        if (m_alpha[i] > 0) {
          sum += m_class[i] * m_alpha[i];
        }
      }
      System.err.println("Sum of y(i) * alpha(i): " + sum);
      for (int i = 0; i < m_alpha.length; i++) {
        double output = SVMOutput(i, m_data.instance(i));
        if (Utils.eq(m_alpha[i], 0)) {
          if (Utils.sm(m_class[i] * output, 1)) {
            System.err.println("KKT condition 1 violated: " + m_class[i] * output);
          }
        } 
        if (Utils.gr(m_alpha[i], 0) && 
            Utils.sm(m_alpha[i], m_C * m_data.instance(i).weight())) {
          if (!Utils.eq(m_class[i] * output, 1)) {
            System.err.println("KKT condition 2 violated: " + m_class[i] * output);
          }
        } 
        if (Utils.eq(m_alpha[i], m_C * m_data.instance(i).weight())) {
          if (Utils.gr(m_class[i] * output, 1)) {
            System.err.println("KKT condition 3 violated: " + m_class[i] * output);
          }
        } 
      }
    }
  } // end of the binary SMO inner class (presumably BinarySMO)

  /** The filter to apply to the training data: the values m_filterType may take. */
  public static final int FILTER_NORMALIZE = 0;
  public static final int FILTER_STANDARDIZE
= 1;
  /** Leave the training data unscaled. */
  public static final int FILTER_NONE = 2;
  /** Pairs each filter constant with a human-readable description. */
  public static final Tag [] TAGS_FILTER = {
    new Tag(FILTER_NORMALIZE, "Normalize training data"),
    new Tag(FILTER_STANDARDIZE, "Standardize training data"),
    new Tag(FILTER_NONE, "No normalization/standardization"),
  };

  /** The binary classifier(s): entry [i][j] with {@code i < j} separates
      class i from class j; filled by buildClassifier. */
  private BinarySMO[][] m_classifiers = null;

  /** The exponent for the polynomial kernel. Together with m_useRBF == false,
      an exponent of 1.0 makes the machine linear, in which case an explicit
      weight vector is maintained during optimization. */
  private double m_exponent = 1.0;

  /** Gamma for the RBF kernel. */
  private double m_gamma = 0.01;

  /** The complexity parameter C; each alpha is bounded by m_C times the
      weight of its instance. */
  private double m_C = 1.0;

  /** Epsilon for rounding. */
  private double m_eps = 1.0e-12;

  /** Tolerance for accuracy of result. */
  private double m_tol = 1.0e-3;

  /** Whether to normalize/standardize/neither: FILTER_NORMALIZE,
      FILTER_STANDARDIZE, or anything else for no scaling. */
  private int m_filterType = FILTER_NORMALIZE;

  /** Feature-space normalization? */
  private boolean m_featureSpaceNormalization = false;

  /** Use lower-order terms? */
  private boolean m_lowerOrder = false;

  /** Use RBF kernel? (default: poly) */
  private boolean m_useRBF = false;

  /** The size of the cache (a prime number) */
  private int m_cacheSize = 1000003;

  /** The filter used to make attributes numeric; null when the training
      data has only numeric attributes. */
  private NominalToBinary m_NominalToBinary;

  /** The filter used to standardize/normalize all values; null unless
      m_filterType selects standardization or normalization. */
  private Filter m_Filter = null;

  /** The filter used to get rid of missing values; null when checks are
      turned off. */
  private ReplaceMissingValues m_Missing;

  /** Only numeric attributes in the dataset? Always true when checks are
      turned off. */
  private boolean m_onlyNumeric;

  /** The class index from the (filtered) training data */
  private int m_classIndex = -1;

  /** The class attribute of the (filtered) training data */
  private Attribute m_classAttribute;

  /** Turn off all checks and conversions? Turning them off assumes
      that data is purely numeric, doesn't contain any missing values,
      and has a nominal class. Turning them off also means that
      no header information will be stored if the machine is linear.
      Toggled via turnChecksOff() / turnChecksOn().
*/  private boolean m_checksTurnedOff;  /** Precision constant for updating sets */  private static double m_Del = 1000 * Double.MIN_VALUE;  /** Whether logistic models are to be fit */  private boolean m_fitLogisticModels = false;  /** The number of folds for the internal cross-validation */  private int m_numFolds = -1;  /** The random number seed for the internal cross-validation */  private int m_randomSeed = 1;  /**   * Turns off checks for missing values, etc. Use with caution.   */  public void turnChecksOff() {    m_checksTurnedOff = true;  }  /**   * Turns on checks for missing values, etc.   */  public void turnChecksOn() {    m_checksTurnedOff = false;  }  /**   * Method for building the classifier. Implements a one-against-one   * wrapper for multi-class problems.   *   * @param insts the set of training instances   * @exception Exception if the classifier can't be built successfully   */  public void buildClassifier(Instances insts) throws Exception {    if (!m_checksTurnedOff) {      if (insts.checkForStringAttributes()) {	throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");      }      if (insts.classAttribute().isNumeric()) {	throw new UnsupportedClassTypeException("SMO can't handle a numeric class!");      }      insts = new Instances(insts);      insts.deleteWithMissingClass();      if (insts.numInstances() == 0) {	throw new Exception("No training instances without a missing class!");      }    }    m_onlyNumeric = true;    if (!m_checksTurnedOff) {      for (int i = 0; i < insts.numAttributes(); i++) {	if (i != insts.classIndex()) {	  if (!insts.attribute(i).isNumeric()) {	    m_onlyNumeric = false;	    break;	  }	}      }    }    if (!m_checksTurnedOff) {      m_Missing = new ReplaceMissingValues();      m_Missing.setInputFormat(insts);      insts = Filter.useFilter(insts, m_Missing);     } else {      m_Missing = null;    }    if (!m_onlyNumeric) {      m_NominalToBinary = new NominalToBinary();      
m_NominalToBinary.setInputFormat(insts);      insts = Filter.useFilter(insts, m_NominalToBinary);    } else {      m_NominalToBinary = null;    }    if (m_filterType == FILTER_STANDARDIZE) {      m_Filter = new Standardize();      m_Filter.setInputFormat(insts);      insts = Filter.useFilter(insts, m_Filter);     } else if (m_filterType == FILTER_NORMALIZE) {      m_Filter = new Normalize();      m_Filter.setInputFormat(insts);      insts = Filter.useFilter(insts, m_Filter);     } else {      m_Filter = null;    }    m_classIndex = insts.classIndex();    m_classAttribute = insts.classAttribute();    // Generate subsets representing each class    Instances[] subsets = new Instances[insts.numClasses()];    for (int i = 0; i < insts.numClasses(); i++) {      subsets[i] = new Instances(insts, insts.numInstances());    }    for (int j = 0; j < insts.numInstances(); j++) {      Instance inst = insts.instance(j);      subsets[(int)inst.classValue()].add(inst);    }    for (int i = 0; i < insts.numClasses(); i++) {      subsets[i].compactify();    }    // Build the binary classifiers    m_classifiers = new BinarySMO[insts.numClasses()][insts.numClasses()];    for (int i = 0; i < insts.numClasses(); i++) {      for (int j = i + 1; j < insts.numClasses(); j++) {	m_classifiers[i][j] = new BinarySMO();	Instances data = new Instances(insts, insts.numInstances());	for (int k = 0; k < subsets[i].numInstances(); k++) {	  data.add(subsets[i].instance(k));	}	for (int k = 0; k < subsets[j].numInstances(); k++) {	  data.add(subsets[j].instance(k));	}	data.compactify();	m_classifiers[i][j].buildClassifier(data, i, j, 					    m_fitLogisticModels,					    m_numFolds, m_randomSeed);      }    }  }  /**   * Estimates class probabilities for given instance.   
*/  public double[] distributionForInstance(Instance inst) throws Exception {    // Filter instance    if (!m_checksTurnedOff) {      m_Missing.input(inst);      m_Missing.batchFinished();      inst = m_Missing.output();    }    if (!m_onlyNumeric) {      m_NominalToBinary.input(inst);      m_NominalToBinary.batchFinished();      inst = m_NominalToBinary.output();    }        if (m_Filter != null) {      m_Filter.input(inst);      m_Filter.batchFinished();      inst = m_Filter.output();    }        if (!m_fitLogisticModels) {      double[] result = new double[inst.numClasses()];      for (int i = 0; i < inst.numClasses(); i++) {	for (int j = i + 1; j < inst.numClasses(); j++) {	  double output = m_classifiers[i][j].SVMOutput(-1, inst);	  if (output > 0) {	    result[j] += 1;	  } else {	    result[i] += 1;	  }	}       }      Utils.normalize(result);      return result;    } else {      // We only need to do pairwise coupling if there are more      // then two classes.      if (inst.numClasses() == 2) {	double[] newInst = new double[2];	newInst[0] = m_classifiers[0][1].SVMOutput(-1, inst);	newInst[1] = Instance.missingValue();	return m_classifiers[0][1].m_logistic.	  distributionForInstance(new Instance(1, newInst));      }      double[][] r = new double[inst.numClasses()][inst.numClasses()];      double[][] n = new double[inst.numClasses()][inst.numClasses()];      for (int i = 0; i < inst.numClasses(); i++) {	for (int j = i + 1; j < inst.numClasses(); j++) {	  double[] newInst = new double[2];	  newInst[0] = m_classifiers[i][j].SVMOutput(-1, inst);	  newInst[1] = Instance.missingValue();	  r[i][j] = m_classifiers[i][j].m_logistic.	    distributionForInstance(new Instance(1, newInst))[0];	  n[i][j] = m_classifiers[i][j].m_sumOfWeights;	}      }      return pairwiseCoupling(n, r);    }  }  /**   * Implements pairwise coupling.   
*   * @param n the sum of weights used to train each model   * @param r the probability estimate from each model   * @return the coupled estimates   */  public double[] pairwiseCoupling(double[][] n, double[][] r) {    // Initialize p and u array    double[] p = new double[r.length];    for (int i =0; i < p.length; i++) {      p[i] = 1.0 / (double)p.length;    }    double[][] u = new double[r.length][r.length];    for (int i = 0; i < r.length; i++) {      for (int j = i + 1; j < r.length; j++) {	u[i][j] = 0.5;      }    }    // firstSum doesn't change    double[] firstSum = new double[p.length];    for (int i = 0; i < p.length; i++) {      for (int j = i + 1; j < p.length; j++) {	firstSum[i] += n[i][j] * r[i][j];	firstSum[j] += n[i][j] * (1 - r[i][j]);      }    }    // Iterate until convergence    boolean changed;    do {      changed = false;      double[] secondSum = new double[p.length];      for (int i = 0; i < p.length; i++) {	for (int j = i + 1; j < p.length; j++) {	  secondSum[i] += n[i][j] * u[i][j];	  secondSum[j] += n[i][j] * (1 - u[i][j]);	}      }      for (int i = 0; i < p.length; i++) {	if ((firstSum[i] == 0) || (secondSum[i] == 0)) {	  if (p[i] > 0) {	    changed = true;	  }	  p[i] = 0;	} else {	  double factor = firstSum[i] / secondSum[i];	  double pOld = p[i];	  p[i] *= factor;	  if (Math.abs(pOld - p[i]) > 1.0e-3) {	    changed = true;	  }	}      }      Utils.normalize(p);      for (int i = 0; i < r.length; i++) {	for (int j = i + 1; j < r.length; j++) {	  u[i][j] = p[i] / (p[i] + p[j]);	}      }    } while (changed);    return p;  }  /**   * Returns an array of votes for the given instance.   
* <p>The instance is first passed through whichever of the filters fitted
   * by buildClassifier are in use (missing values, nominal-to-binary,
   * normalize/standardize). Each pairwise machine m_classifiers[i][j] then
   * votes for class j if its SVM output is positive and for class i
   * otherwise.
   *
   * @param inst the instance
   * @return array of votes, one entry per class
   * @exception Exception if something goes wrong
   */
  public int[] obtainVotes(Instance inst) throws Exception {

    // Filter instance the same way the training data was filtered
    if (!m_checksTurnedOff) {
      m_Missing.input(inst);
      m_Missing.batchFinished();
      inst = m_Missing.output();
    }

    if (!m_onlyNumeric) {
      m_NominalToBinary.input(inst);
      m_NominalToBinary.batchFinished();
      inst = m_NominalToBinary.output();
    }
    
    if (m_Filter != null) {
      m_Filter.input(inst);
      m_Filter.batchFinished();
      inst = m_Filter.output();
    }

    int[] votes = new int[inst.numClasses()];
    for (int i = 0; i < inst.numClasses(); i++) {
      for (int j = i + 1; j < inst.numClasses(); j++) {
        double output = m_classifiers[i][j].SVMOutput(-1, inst);
        if (output > 0) {
          votes[j] += 1;
        } else {
          votes[i] += 1;
        }
      }
    }
    return votes;
  }

  /**
   * Returns the coefficients in sparse format as a two-element FastVector:
   * element 0 holds the sparse weights and element 1 the sparse indices of
   * the single machine m_classifiers[0][1]. Throws an exception
   * if there is more than one machine (more than two classes) or if the
   * machine has no sparse weight vector (presumably: it is not linear).
   * NOTE(review): with a single class m_classifiers is 1x1, so the access
   * to m_classifiers[0][1] would fail with an index error; confirm intent.
   */
  public FastVector weights() throws Exception {
    
    if (m_classifiers.length > 2) {
      throw new Exception("More than one machine has been built.");
    }
    if (m_classifiers[0][1].m_sparseWeights == null) {
      throw new Exception("No weight vector available.");
    }

    FastVector vec = new FastVector(2);
    vec.addElement(m_classifiers[0][1].m_sparseWeights);
    vec.addElement(m_classifiers[0][1].m_sparseIndices);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -