⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 smo.java

📁 MacroWeka扩展了著名数据挖掘工具weka
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
      p[i] = 1.0 / (double)p.length;
    }
    double[][] u = new double[r.length][r.length];
    for (int i = 0; i < r.length; i++) {
      for (int j = i + 1; j < r.length; j++) {
	u[i][j] = 0.5;
      }
    }

    // firstSum doesn't change
    double[] firstSum = new double[p.length];
    for (int i = 0; i < p.length; i++) {
      for (int j = i + 1; j < p.length; j++) {
	firstSum[i] += n[i][j] * r[i][j];
	firstSum[j] += n[i][j] * (1 - r[i][j]);
      }
    }

    // Iterate until convergence
    boolean changed;
    do {
      changed = false;
      double[] secondSum = new double[p.length];
      for (int i = 0; i < p.length; i++) {
	for (int j = i + 1; j < p.length; j++) {
	  secondSum[i] += n[i][j] * u[i][j];
	  secondSum[j] += n[i][j] * (1 - u[i][j]);
	}
      }
      for (int i = 0; i < p.length; i++) {
	if ((firstSum[i] == 0) || (secondSum[i] == 0)) {
	  if (p[i] > 0) {
	    changed = true;
	  }
	  p[i] = 0;
	} else {
	  double factor = firstSum[i] / secondSum[i];
	  double pOld = p[i];
	  p[i] *= factor;
	  if (Math.abs(pOld - p[i]) > 1.0e-3) {
	    changed = true;
	  }
	}
      }
      Utils.normalize(p);
      for (int i = 0; i < r.length; i++) {
	for (int j = i + 1; j < r.length; j++) {
	  u[i][j] = p[i] / (p[i] + p[j]);
	}
      }
    } while (changed);
    return p;
  }

  /**
   * Collects the pairwise votes of the binary machines for an instance.
   * The instance is first run through the m_Missing filter (unless checks
   * are turned off), the m_NominalToBinary filter (unless the data is
   * flagged as only numeric) and m_Filter (if one is set). Each class
   * pair (i, j) with i &lt; j then casts exactly one vote: a strictly
   * positive SVM output counts for class j, any other output for class i.
   *
   * @param inst the instance
   * @return array of votes, one counter per class value
   * @exception Exception if something goes wrong
   */
  public int[] obtainVotes(Instance inst) throws Exception {

    Instance filtered = inst;

    // Skipped when checks are turned off
    if (!m_checksTurnedOff) {
      m_Missing.input(filtered);
      m_Missing.batchFinished();
      filtered = m_Missing.output();
    }

    // Skipped when the data is flagged as only numeric
    if (!m_onlyNumeric) {
      m_NominalToBinary.input(filtered);
      m_NominalToBinary.batchFinished();
      filtered = m_NominalToBinary.output();
    }

    // Optional additional filter
    if (m_Filter != null) {
      m_Filter.input(filtered);
      m_Filter.batchFinished();
      filtered = m_Filter.output();
    }

    final int numClasses = filtered.numClasses();
    final int[] tally = new int[numClasses];
    for (int first = 0; first < numClasses; first++) {
      for (int second = first + 1; second < numClasses; second++) {
        // Positive output favours the second class of the pair
        boolean secondWins =
          m_classifiers[first][second].SVMOutput(-1, filtered) > 0;
        tally[secondWins ? second : first]++;
      }
    }
    return tally;
  }

  /**
   * Returns the weights in sparse format. Entry [i][j] refers to the
   * sparse weight array of the binary machine for class pair (i, j);
   * only entries with i &lt; j are filled, all others stay null.
   *
   * @return the sparse weights of every pairwise machine
   */
  public double [][][] sparseWeights() {

    final int classCount = m_classAttribute.numValues();
    final double [][][] weights = new double[classCount][classCount][];

    for (int first = 0; first < classCount; first++) {
      double [][] row = weights[first];
      for (int second = first + 1; second < classCount; second++) {
        row[second] = m_classifiers[first][second].m_sparseWeights;
      }
    }

    return weights;
  }
  
  /**
   * Returns the indices in sparse format. Entry [i][j] refers to the
   * sparse index array of the binary machine for class pair (i, j);
   * only entries with i &lt; j are filled, all others stay null.
   *
   * @return the sparse indices of every pairwise machine
   */
  public int [][][] sparseIndices() {

    final int classCount = m_classAttribute.numValues();
    final int [][][] indices = new int[classCount][classCount][];

    for (int first = 0; first < classCount; first++) {
      int [][] row = indices[first];
      for (int second = first + 1; second < classCount; second++) {
        row[second] = m_classifiers[first][second].m_sparseIndices;
      }
    }

    return indices;
  }
  
  /**
   * Returns the bias of each binary SMO. Entry [i][j] holds the m_b
   * value of the machine for class pair (i, j); only entries with
   * i &lt; j are filled, all others keep the array default of 0.
   *
   * @return the bias of every pairwise machine
   */
  public double [][] bias() {

    final int classCount = m_classAttribute.numValues();
    final double [][] biases = new double[classCount][classCount];

    for (int first = 0; first < classCount; first++) {
      double [] row = biases[first];
      for (int second = first + 1; second < classCount; second++) {
        row[second] = m_classifiers[first][second].m_b;
      }
    }

    return biases;
  }
  
  /**
   * Returns the number of values of the class attribute.
   *
   * @return the value count reported by the class attribute
   */
  public int numClassAttributeValues() {

    final int valueCount = m_classAttribute.numValues();
    return valueCount;
  }
  
  /**
   * Returns the names of the class attribute's values, in index order.
   *
   * @return a new array holding one name per class value
   */
  public String [] classAttributeNames() {

    String [] names = new String[m_classAttribute.numValues()];

    int index = 0;
    while (index < names.length) {
      names[index] = m_classAttribute.value(index);
      index++;
    }

    return names;
  }
  
  /**
   * Returns the attribute names. Entry [i][j] lists the names of all
   * attributes in the data held by the binary machine for class pair
   * (i, j); only entries with i &lt; j are filled, all others stay null.
   *
   * @return the attribute names of every pairwise machine
   */
  public String [][][] attributeNames() {

    final int classCount = m_classAttribute.numValues();
    final String [][][] names = new String[classCount][classCount][];

    for (int first = 0; first < classCount; first++) {
      for (int second = first + 1; second < classCount; second++) {
        String [] pairNames =
          new String[m_classifiers[first][second].m_data.numAttributes()];
        for (int att = 0; att < pairNames.length; att++) {
          pairNames[att] =
            m_classifiers[first][second].m_data.attribute(att).name();
        }
        names[first][second] = pairNames;
      }
    }
    return names;
  }
  
  /**
   * Returns an enumeration describing the available options.
   * The synopsis of each option states the argument type that
   * setOptions actually parses: -N, -A, -V and -W take integers,
   * the remaining valued options take doubles.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(13);

    newVector.addElement(new Option("\tThe complexity constant C. (default 1)",
                                    "C", 1, "-C <double>"));
    newVector.addElement(new Option("\tThe exponent for the "
                                    + "polynomial kernel. (default 1)",
                                    "E", 1, "-E <double>"));
    newVector.addElement(new Option("\tGamma for the RBF kernel. (default 0.01)",
                                    "G", 1, "-G <double>"));
    // -N consumes one integer argument, so the synopsis has to show it
    newVector.addElement(new Option("\tWhether to 0=normalize/1=standardize/2=neither. " +
                                    "(default 0=normalize)",
                                    "N", 1, "-N <int>"));
    newVector.addElement(new Option("\tFeature-space normalization (only for\n" +
                                    "\tnon-linear polynomial kernels).",
                                    "F", 0, "-F"));
    newVector.addElement(new Option("\tUse lower-order terms (only for non-linear\n" +
                                    "\tpolynomial kernels).",
                                    "O", 0, "-O"));
    newVector.addElement(new Option("\tUse RBF kernel. " +
                                    "(default poly)",
                                    "R", 0, "-R"));
    newVector.addElement(new Option("\tThe size of the kernel cache. " +
                                    "(default 250007, use 0 for full cache)",
                                    "A", 1, "-A <int>"));
    newVector.addElement(new Option("\tThe tolerance parameter. " +
                                    "(default 1.0e-3)",
                                    "L", 1, "-L <double>"));
    newVector.addElement(new Option("\tThe epsilon for round-off error. " +
                                    "(default 1.0e-12)",
                                    "P", 1, "-P <double>"));
    newVector.addElement(new Option("\tFit logistic models to SVM outputs. ",
                                    "M", 0, "-M"));
    // -V and -W are parsed with Integer.parseInt, not as doubles
    newVector.addElement(new Option("\tThe number of folds for the internal\n" +
                                    "\tcross-validation. " +
                                    "(default -1, use training data)",
                                    "V", 1, "-V <int>"));
    newVector.addElement(new Option("\tThe random number seed. " +
                                    "(default 1)",
                                    "W", 1, "-W <int>"));

    return newVector.elements();
  }

  /**
   * Parses a given list of options. A valued option that is not
   * supplied is reset to its default. Valid options are:<p>
   *
   * -C num <br>
   * The complexity constant C. (default 1)<p>
   *
   * -E num <br>
   * The exponent for the polynomial kernel. (default 1) <p>
   *
   * -G num <br>
   * Gamma for the RBF kernel. (default 0.01) <p>
   *
   * -N &lt;0|1|2&gt; <br>
   * Whether to 0=normalize/1=standardize/2=neither. (default 0=normalize)<p>
   *
   * -F <br>
   * Feature-space normalization (only for non-linear polynomial kernels). <p>
   *
   * -O <br>
   * Use lower-order terms (only for non-linear polynomial kernels). <p>
   *
   * -R <br>
   * Use RBF kernel (default poly). <p>
   *
   * -A num <br>
   * Sets the size of the kernel cache. Should be a prime
   * number. (default 250007, use 0 for full cache) <p>
   *
   * -L num <br>
   * Sets the tolerance parameter. (default 1.0e-3)<p>
   *
   * -P num <br>
   * Sets the epsilon for round-off error. (default 1.0e-12)<p>
   *
   * -M <br>
   * Fit logistic models to SVM outputs.<p>
   *
   * -V num <br>
   * Number of folds for cross-validation used to generate data
   * for logistic models. (default -1, use training data)<p>
   *
   * -W num <br>
   * Random number seed. (default 1)<p>
   *
   * -F and -O are rejected with an exception when combined with -R or
   * with an exponent of exactly 1.0. Options are processed in a fixed
   * order, so settings handled before a failure have already been applied.
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    // Each value is fetched and parsed before the next option is touched.
    String cValue = Utils.getOption('C', options);
    m_C = (cValue.length() > 0) ? Double.parseDouble(cValue) : 1.0;

    String eValue = Utils.getOption('E', options);
    m_exponent = (eValue.length() > 0) ? Double.parseDouble(eValue) : 1.0;

    String gValue = Utils.getOption('G', options);
    m_gamma = (gValue.length() > 0) ? Double.parseDouble(gValue) : 0.01;

    String aValue = Utils.getOption('A', options);
    m_cacheSize = (aValue.length() > 0) ? Integer.parseInt(aValue) : 250007;

    String lValue = Utils.getOption('L', options);
    m_tol = (lValue.length() > 0) ? Double.parseDouble(lValue) : 1.0e-3;

    String pValue = Utils.getOption('P', options);
    m_eps = (pValue.length() > 0) ? Double.parseDouble(pValue) : 1.0e-12;

    m_useRBF = Utils.getFlag('R', options);

    String nValue = Utils.getOption('N', options);
    if (nValue.length() > 0) {
      setFilterType(new SelectedTag(Integer.parseInt(nValue), TAGS_FILTER));
    } else {
      setFilterType(new SelectedTag(FILTER_NORMALIZE, TAGS_FILTER));
    }

    // Feature-space normalization needs a non-linear polynomial kernel
    m_featureSpaceNormalization = Utils.getFlag('F', options);
    if (m_featureSpaceNormalization) {
      if (m_useRBF) {
        throw new Exception("RBF machine doesn't require feature-space normalization.");
      }
      if (m_exponent == 1.0) {
        throw new Exception("Can't use feature-space normalization with linear machine.");
      }
    }

    // Lower-order terms need a non-linear polynomial kernel as well
    m_lowerOrder = Utils.getFlag('O', options);
    if (m_lowerOrder) {
      if (m_useRBF) {
        throw new Exception("Can't use lower-order terms with RBF machine.");
      }
      if (m_exponent == 1.0) {
        throw new Exception("Can't use lower-order terms with linear machine.");
      }
    }

    m_fitLogisticModels = Utils.getFlag('M', options);

    String vValue = Utils.getOption('V', options);
    m_numFolds = (vValue.length() > 0) ? Integer.parseInt(vValue) : -1;

    String wValue = Utils.getOption('W', options);
    m_randomSeed = (wValue.length() > 0) ? Integer.parseInt(wValue) : 1;
  }

  /**
   * Gets the current settings of the classifier.
   * The result always has at least 21 entries; unused trailing entries
   * are empty strings. It grows to 22 entries only when all four flags
   * (-F, -O, -R, -M) are active, a state that can arise after setOptions
   * has failed part-way through and that previously overflowed the array.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String [] getOptions() {

    // Nine valued options (-C -E -G -A -L -P -N -V -W) always take
    // 18 slots; every active flag needs one more.
    int needed = 18;
    if (m_featureSpaceNormalization) {
      needed++;
    }
    if (m_lowerOrder) {
      needed++;
    }
    if (m_useRBF) {
      needed++;
    }
    if (m_fitLogisticModels) {
      needed++;
    }

    // Never shrink below the historical length of 21 so that callers
    // see the same array shape as before whenever that length suffices.
    String [] options = new String [Math.max(21, needed)];
    int current = 0;

    options[current++] = "-C"; options[current++] = "" + m_C;
    options[current++] = "-E"; options[current++] = "" + m_exponent;
    options[current++] = "-G"; options[current++] = "" + m_gamma;
    options[current++] = "-A"; options[current++] = "" + m_cacheSize;
    options[current++] = "-L"; options[current++] = "" + m_tol;
    options[current++] = "-P"; options[current++] = "" + m_eps;
    options[current++] = "-N"; options[current++] = "" + m_filterType;
    if (m_featureSpaceNormalization) {
      options[current++] = "-F";
    }
    if (m_lowerOrder) {
      options[current++] = "-O";
    }
    if (m_useRBF) {
      options[current++] = "-R";
    }
    if (m_fitLogisticModels) {
      options[current++] = "-M";
    }
    options[current++] = "-V"; options[current++] = "" + m_numFolds;
    options[current++] = "-W"; options[current++] = "" + m_randomSeed;

    // Pad the remainder with empty strings
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }
     
  /**
   * Returns the tip text for the exponent property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String exponentTipText() {

    final String tip = "The exponent for the polynomial kernel.";
    return tip;
  }
  
  /**
   * Gets the exponent currently stored for the polynomial kernel.
   *
   * @return Value of exponent.
   */
  public double getExponent() {

    final double exponent = m_exponent;
    return exponent;
  }
  
  /**
   * Set the value of exponent. If linear kernel
   * is used, rescaling and lower-order terms are
   * turned off.
   *
   * @param v  Value to assign to exponent.

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -