⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 smo.java

📁 MacroWeka扩展了著名数据挖掘工具weka
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
      if ((y2 == -1) && (a2 == C2)) {
	m_I2.insert(i2);
      } else {
	m_I2.delete(i2);
      }
      if ((y2 == 1) && (a2 == C2)) {
	m_I3.insert(i2);
      } else {
	m_I3.delete(i2);
      }
      if ((y2 == -1) && (a2 == 0)) {
	m_I4.insert(i2);
      } else {
	m_I4.delete(i2);
      }
      
      // Update weight vector to reflect change a1 and a2, if linear SVM
      if (!m_useRBF && m_exponent == 1.0) {
	Instance inst1 = m_data.instance(i1);
	for (int p1 = 0; p1 < inst1.numValues(); p1++) {
	  if (inst1.index(p1) != m_data.classIndex()) {
	    m_weights[inst1.index(p1)] += 
	      y1 * (a1 - alph1) * inst1.valueSparse(p1);
	  }
	}
	Instance inst2 = m_data.instance(i2);
	for (int p2 = 0; p2 < inst2.numValues(); p2++) {
	  if (inst2.index(p2) != m_data.classIndex()) {
	    m_weights[inst2.index(p2)] += 
	      y2 * (a2 - alph2) * inst2.valueSparse(p2);
	  }
	}
      }
      
      // Update error cache using new Lagrange multipliers
      for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) {
	if ((j != i1) && (j != i2)) {
	  m_errors[j] += 
	    y1 * (a1 - alph1) * m_kernel.eval(i1, j, m_data.instance(i1)) + 
	    y2 * (a2 - alph2) * m_kernel.eval(i2, j, m_data.instance(i2));
	}
      }
      
      // Update error cache for i1 and i2
      m_errors[i1] += y1 * (a1 - alph1) * k11 + y2 * (a2 - alph2) * k12;
      m_errors[i2] += y1 * (a1 - alph1) * k12 + y2 * (a2 - alph2) * k22;
      
      // Update array with Lagrange multipliers
      m_alpha[i1] = a1;
      m_alpha[i2] = a2;
      
      // Update thresholds
      m_bLow = -Double.MAX_VALUE; m_bUp = Double.MAX_VALUE;
      m_iLow = -1; m_iUp = -1;
      for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) {
	if (m_errors[j] < m_bUp) {
	  m_bUp = m_errors[j]; m_iUp = j;
	}
	if (m_errors[j] > m_bLow) {
	  m_bLow = m_errors[j]; m_iLow = j;
	}
      }
      if (!m_I0.contains(i1)) {
	if (m_I3.contains(i1) || m_I4.contains(i1)) {
	  if (m_errors[i1] > m_bLow) {
	    m_bLow = m_errors[i1]; m_iLow = i1;
	  } 
	} else {
	  if (m_errors[i1] < m_bUp) {
	    m_bUp = m_errors[i1]; m_iUp = i1;
	  }
	}
      }
      if (!m_I0.contains(i2)) {
	if (m_I3.contains(i2) || m_I4.contains(i2)) {
	  if (m_errors[i2] > m_bLow) {
	    m_bLow = m_errors[i2]; m_iLow = i2;
	  }
	} else {
	  if (m_errors[i2] < m_bUp) {
	    m_bUp = m_errors[i2]; m_iUp = i2;
	  }
	}
      }
      if ((m_iLow == -1) || (m_iUp == -1)) {
	throw new Exception("This should never happen!");
      }

      // Made some progress.
      return true;
    }
  
    /**
     * Diagnostic sanity check of the solution of the quadratic programming
     * problem. Nothing is returned or modified; findings are printed to
     * standard error.
     *
     * <p>First prints the sum of class * alpha over all positive alphas. Then,
     * for every training instance, it computes the margin class * output and
     * reports a violation when:
     * <ol>
     *   <li>alpha == 0 but the margin is smaller than 1;</li>
     *   <li>0 &lt; alpha &lt; C * weight but the margin is not equal to 1;</li>
     *   <li>alpha == C * weight but the margin is greater than 1.</li>
     * </ol>
     * Comparisons use the tolerant {@code Utils.eq/gr/sm} helpers.
     *
     * @exception Exception if computing the SVM output fails
     */
    protected void checkClassifier() throws Exception {

      // Only strictly positive multipliers contribute (NaN is skipped as well).
      double weightedSum = 0;
      for (int k = 0; k < m_alpha.length; k++) {
        if (!(m_alpha[k] > 0)) {
          continue;
        }
        weightedSum += m_class[k] * m_alpha[k];
      }
      System.err.println("Sum of y(i) * alpha(i): " + weightedSum);

      for (int k = 0; k < m_alpha.length; k++) {
        Instance current = m_data.instance(k);
        double margin = m_class[k] * SVMOutput(k, current);
        double alpha = m_alpha[k];
        // Upper bound of the box constraint for this instance.
        double upper = m_C * current.weight();

        // The three checks are deliberately independent: 1 and 3 can both
        // apply when the upper bound is (numerically) zero.
        if (Utils.eq(alpha, 0) && Utils.sm(margin, 1)) {
          System.err.println("KKT condition 1 violated: " + margin);
        }
        if (Utils.gr(alpha, 0) && Utils.sm(alpha, upper)
            && !Utils.eq(margin, 1)) {
          System.err.println("KKT condition 2 violated: " + margin);
        }
        if (Utils.eq(alpha, upper) && Utils.gr(margin, 1)) {
          System.err.println("KKT condition 3 violated: " + margin);
        }
      }
    }
  }

  /** Filter type: normalize the training data (the default filter type). */
  public static final int FILTER_NORMALIZE = 0;
  /** Filter type: standardize the training data. */
  public static final int FILTER_STANDARDIZE = 1;
  /** Filter type: apply neither normalization nor standardization. */
  public static final int FILTER_NONE = 2;
  /** Tags describing the available filter types (for option handling/GUIs). */
  public static final Tag [] TAGS_FILTER = {
    new Tag(FILTER_NORMALIZE, "Normalize training data"),
    new Tag(FILTER_STANDARDIZE, "Standardize training data"),
    new Tag(FILTER_NONE, "No normalization/standardization"),
  };

  /**
   * The pairwise binary classifiers: entry [i][j] (with i < j) separates
   * class i from class j. Entries with i >= j are unused.
   */
  protected BinarySMO[][] m_classifiers = null;

  /**
   * The exponent for the polynomial kernel. When it is 1.0 and the RBF
   * kernel is not used, the machine is linear and an explicit weight
   * vector is maintained during training.
   */
  protected double m_exponent = 1.0;
 
  /** Use lower-order terms (polynomial kernel)? */
  protected boolean m_lowerOrder = false;
  
  /** Gamma for the RBF kernel. */
  protected double m_gamma = 0.01;
  
  /**
   * The complexity parameter C. Each Lagrange multiplier is bounded above
   * by C times the weight of its instance.
   */
  protected double m_C = 1.0;
  
  /** Epsilon for rounding. */
  protected double m_eps = 1.0e-12;
  
  /** Tolerance for accuracy of result. */
  protected double m_tol = 1.0e-3;

  /**
   * Which preprocessing filter to apply: FILTER_NORMALIZE,
   * FILTER_STANDARDIZE or FILTER_NONE.
   */
  protected int m_filterType = FILTER_NORMALIZE;
  
  /** Feature-space normalization? */
  protected boolean m_featureSpaceNormalization = false;
  
  /** Use RBF kernel? (default: false, i.e. polynomial kernel) */
  protected boolean m_useRBF = false;
  
  /** The size of the cache (a prime number) */
  protected int m_cacheSize = 250007;

  /**
   * The filter used to make attributes numeric; null when all non-class
   * attributes are already numeric.
   */
  protected NominalToBinary m_NominalToBinary;

  /**
   * The filter used to standardize/normalize all values; null when
   * FILTER_NONE is selected.
   */
  protected Filter m_Filter = null;

  /**
   * The filter used to get rid of missing values; null when checks are
   * turned off.
   */
  protected ReplaceMissingValues m_Missing;

  /**
   * Only numeric attributes in the dataset? Always true when checks are
   * turned off.
   */
  protected boolean m_onlyNumeric;

  /** The class index from the training data */
  protected int m_classIndex = -1;

  /** The class attribute */
  protected Attribute m_classAttribute;

  /** Turn off all checks and conversions? Turning them off assumes
      that data is purely numeric, doesn't contain any missing values,
      and has a nominal class. Turning them off also means that
      no header information will be stored if the machine is linear. 
      Finally, it also assumes that no instance has a weight equal to 0.*/
  protected boolean m_checksTurnedOff;

  /** Precision constant for updating sets */
  protected static double m_Del = 1000 * Double.MIN_VALUE;

  /**
   * Whether logistic models are fit to the SVM outputs so that
   * probability estimates (rather than vote counts) are produced.
   */
  protected boolean m_fitLogisticModels = false;

  /** The number of folds for the internal cross-validation */
  protected int m_numFolds = -1;

  /**
   * The random number seed, used to shuffle each pairwise training set
   * and passed on to the binary classifiers.
   */
  protected int m_randomSeed = 1;

  /**
   * Disables the sanity checks and conversions applied to the data, such as
   * string/numeric-class checks, removal of instances with a missing class or
   * non-positive weight, missing-value replacement and nominal-to-binary
   * conversion. Use with caution: the data must then already be purely
   * numeric and free of missing values.
   */
  public void turnChecksOff() {
    this.m_checksTurnedOff = true;
  }

  /**
   * Re-enables the sanity checks and conversions applied to the data
   * (missing values, nominal attributes, zero-weight instances, etc.).
   */
  public void turnChecksOn() {
    this.m_checksTurnedOff = false;
  }

  /**
   * Method for building the classifier. Implements a one-against-one
   * wrapper for multi-class problems: one BinarySMO is trained for every
   * pair of classes (i, j) with i &lt; j.
   *
   * <p>Unless checks are turned off, the training data is first copied and
   * cleaned. String attributes and numeric classes are rejected, and
   * instances with a missing class or a non-positive weight are removed.
   * Missing values are then replaced and, if needed, nominal attributes are
   * binarized. Afterwards the data is normalized, standardized or left as
   * is, according to the filter type.
   *
   * @param insts the set of training instances
   * @exception Exception if the classifier can't be built successfully
   */
  public void buildClassifier(Instances insts) throws Exception {

    if (!m_checksTurnedOff) {
      if (insts.checkForStringAttributes()) {
        throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
      }
      if (insts.classAttribute().isNumeric()) {
        // Note the trailing space: the two literals are concatenated.
        throw new UnsupportedClassTypeException("SMO can't handle a numeric class! Use "
                                                + "SMOreg for performing regression.");
      }
      // Work on a copy so the caller's dataset is not modified.
      insts = new Instances(insts);
      insts.deleteWithMissingClass();
      if (insts.numInstances() == 0) {
        throw new Exception("No training instances without a missing class!");
      }

      // Remove all instances whose weight is not positive. This MUST be done
      // because condition (8) of Keerthi's paper is made with the assertion
      // Ci > 0 (see equation (3a)).
      Instances positiveWeight = new Instances(insts, insts.numInstances());
      for (int i = 0; i < insts.numInstances(); i++) {
        if (insts.instance(i).weight() > 0) {
          positiveWeight.add(insts.instance(i));
        }
      }
      if (positiveWeight.numInstances() == 0) {
        throw new Exception("No training instances left after removing "
                            + "instances with either a non-positive weight or a missing class!");
      }
      insts = positiveWeight;
    }

    // Detect whether nominal attributes need to be binarized, and replace
    // missing values. With checks turned off the data is assumed to be
    // purely numeric and complete, so both steps are skipped.
    m_onlyNumeric = true;
    if (!m_checksTurnedOff) {
      for (int i = 0; i < insts.numAttributes(); i++) {
        if (i != insts.classIndex() && !insts.attribute(i).isNumeric()) {
          m_onlyNumeric = false;
          break;
        }
      }
      m_Missing = new ReplaceMissingValues();
      m_Missing.setInputFormat(insts);
      insts = Filter.useFilter(insts, m_Missing);
    } else {
      m_Missing = null;
    }

    if (!m_onlyNumeric) {
      m_NominalToBinary = new NominalToBinary();
      m_NominalToBinary.setInputFormat(insts);
      insts = Filter.useFilter(insts, m_NominalToBinary);
    } else {
      m_NominalToBinary = null;
    }

    // Optional scaling of all values; the fitted filter is kept so test
    // instances can be transformed identically.
    if (m_filterType == FILTER_STANDARDIZE) {
      m_Filter = new Standardize();
      m_Filter.setInputFormat(insts);
      insts = Filter.useFilter(insts, m_Filter);
    } else if (m_filterType == FILTER_NORMALIZE) {
      m_Filter = new Normalize();
      m_Filter.setInputFormat(insts);
      insts = Filter.useFilter(insts, m_Filter);
    } else {
      m_Filter = null;
    }

    m_classIndex = insts.classIndex();
    m_classAttribute = insts.classAttribute();

    int numClasses = insts.numClasses();

    // Generate subsets representing each class
    Instances[] subsets = new Instances[numClasses];
    for (int i = 0; i < numClasses; i++) {
      subsets[i] = new Instances(insts, insts.numInstances());
    }
    for (int j = 0; j < insts.numInstances(); j++) {
      Instance inst = insts.instance(j);
      subsets[(int) inst.classValue()].add(inst);
    }
    for (int i = 0; i < numClasses; i++) {
      subsets[i].compactify();
    }

    // Build one binary classifier per pair of classes, each on the union of
    // the two class subsets, shuffled with a shared seeded generator.
    Random rand = new Random(m_randomSeed);
    m_classifiers = new BinarySMO[numClasses][numClasses];
    for (int i = 0; i < numClasses; i++) {
      for (int j = i + 1; j < numClasses; j++) {
        m_classifiers[i][j] = new BinarySMO();
        Instances pair = new Instances(insts, insts.numInstances());
        for (int k = 0; k < subsets[i].numInstances(); k++) {
          pair.add(subsets[i].instance(k));
        }
        for (int k = 0; k < subsets[j].numInstances(); k++) {
          pair.add(subsets[j].instance(k));
        }
        pair.compactify();
        pair.randomize(rand);
        m_classifiers[i][j].buildClassifier(pair, i, j,
                                            m_fitLogisticModels,
                                            m_numFolds, m_randomSeed);
      }
    }
  }

  /**
   * Estimates class probabilities for the given instance.
   *
   * <p>The instance is first passed through the same preprocessing filters
   * that were fitted on the training data.
   * <ul>
   *   <li>Without logistic models, each pairwise classifier (i, j) casts one
   *   vote: for class j if its SVM output is positive, for class i
   *   otherwise. The normalized vote counts are returned.</li>
   *   <li>With logistic models and two classes, the single logistic model's
   *   estimate is returned directly.</li>
   *   <li>With logistic models and more than two classes, the pairwise
   *   estimates are combined by pairwise coupling.</li>
   * </ul>
   * Pairwise classifiers that have neither alphas nor sparse weights do not
   * contribute.
   *
   * @param inst the instance to classify
   * @return the estimated class distribution
   * @exception Exception if the distribution can't be computed
   */
  public double[] distributionForInstance(Instance inst) throws Exception {

    // Replay the training-time preprocessing chain, in the same order
    if (!m_checksTurnedOff) {
      m_Missing.input(inst);
      m_Missing.batchFinished();
      inst = m_Missing.output();
    }
    if (!m_onlyNumeric) {
      m_NominalToBinary.input(inst);
      m_NominalToBinary.batchFinished();
      inst = m_NominalToBinary.output();
    }
    if (m_Filter != null) {
      m_Filter.input(inst);
      m_Filter.batchFinished();
      inst = m_Filter.output();
    }

    final int numClasses = inst.numClasses();

    if (m_fitLogisticModels) {
      // Pairwise coupling is only needed for more than two classes
      if (numClasses == 2) {
        BinarySMO only = m_classifiers[0][1];
        return only.m_logistic.distributionForInstance(logisticInput(only, inst));
      }
      double[][] r = new double[numClasses][numClasses];
      double[][] n = new double[numClasses][numClasses];
      for (int a = 0; a < numClasses; a++) {
        for (int b = a + 1; b < numClasses; b++) {
          BinarySMO smo = m_classifiers[a][b];
          if (smo.m_alpha == null && smo.m_sparseWeights == null) {
            continue;
          }
          r[a][b] = smo.m_logistic.distributionForInstance(logisticInput(smo, inst))[0];
          n[a][b] = smo.m_sumOfWeights;
        }
      }
      return pairwiseCoupling(n, r);
    }

    // Plain one-against-one voting
    double[] votes = new double[numClasses];
    for (int a = 0; a < numClasses; a++) {
      for (int b = a + 1; b < numClasses; b++) {
        BinarySMO smo = m_classifiers[a][b];
        if (smo.m_alpha == null && smo.m_sparseWeights == null) {
          continue;
        }
        if (smo.SVMOutput(-1, inst) > 0) {
          votes[b] += 1;
        } else {
          votes[a] += 1;
        }
      }
    }
    Utils.normalize(votes);
    return votes;
  }

  /**
   * Builds the input for a pairwise logistic model: an instance with weight
   * 1 whose values are the raw SVM output of the given binary classifier,
   * followed by a missing class value.
   *
   * @param smo the pairwise classifier whose output is wrapped
   * @param inst the (already filtered) instance to evaluate
   * @return the instance to feed to the classifier's logistic model
   * @exception Exception if computing the SVM output fails
   */
  private Instance logisticInput(BinarySMO smo, Instance inst) throws Exception {
    double[] vals = new double[2];
    vals[0] = smo.SVMOutput(-1, inst);
    vals[1] = Instance.missingValue();
    return new Instance(1, vals);
  }

  /**
   * Implements pairwise coupling.
   *
   * @param n the sum of weights used to train each model
   * @param r the probability estimate from each model
   * @return the coupled estimates
   */
  public double[] pairwiseCoupling(double[][] n, double[][] r) {

    // Initialize p and u array
    double[] p = new double[r.length];
    for (int i =0; i < p.length; i++) {

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -