📄 smo.java
字号:
p[i] = 1.0 / (double)p.length;
}
double[][] u = new double[r.length][r.length];
for (int i = 0; i < r.length; i++) {
for (int j = i + 1; j < r.length; j++) {
u[i][j] = 0.5;
}
}
// firstSum doesn't change
double[] firstSum = new double[p.length];
for (int i = 0; i < p.length; i++) {
for (int j = i + 1; j < p.length; j++) {
firstSum[i] += n[i][j] * r[i][j];
firstSum[j] += n[i][j] * (1 - r[i][j]);
}
}
// Iterate until convergence
boolean changed;
do {
changed = false;
double[] secondSum = new double[p.length];
for (int i = 0; i < p.length; i++) {
for (int j = i + 1; j < p.length; j++) {
secondSum[i] += n[i][j] * u[i][j];
secondSum[j] += n[i][j] * (1 - u[i][j]);
}
}
for (int i = 0; i < p.length; i++) {
if ((firstSum[i] == 0) || (secondSum[i] == 0)) {
if (p[i] > 0) {
changed = true;
}
p[i] = 0;
} else {
double factor = firstSum[i] / secondSum[i];
double pOld = p[i];
p[i] *= factor;
if (Math.abs(pOld - p[i]) > 1.0e-3) {
changed = true;
}
}
}
Utils.normalize(p);
for (int i = 0; i < r.length; i++) {
for (int j = i + 1; j < r.length; j++) {
u[i][j] = p[i] / (p[i] + p[j]);
}
}
} while (changed);
return p;
}
/**
 * Returns an array of votes for the given instance.
 * Each pairwise classifier (i, j) casts one vote for whichever
 * of the two classes it predicts.
 *
 * @param inst the instance
 * @return array of votes, one entry per class
 * @exception Exception if something goes wrong
 */
public int[] obtainVotes(Instance inst) throws Exception {

  // Run the instance through the same preprocessing filters
  // that were applied to the training data.
  if (!m_checksTurnedOff) {
    m_Missing.input(inst);
    m_Missing.batchFinished();
    inst = m_Missing.output();
  }
  if (!m_onlyNumeric) {
    m_NominalToBinary.input(inst);
    m_NominalToBinary.batchFinished();
    inst = m_NominalToBinary.output();
  }
  if (m_Filter != null) {
    m_Filter.input(inst);
    m_Filter.batchFinished();
    inst = m_Filter.output();
  }

  // Collect one vote from every pairwise classifier.
  int numClasses = inst.numClasses();
  int[] votes = new int[numClasses];
  for (int first = 0; first < numClasses; first++) {
    for (int second = first + 1; second < numClasses; second++) {
      double out = m_classifiers[first][second].SVMOutput(-1, inst);
      // Positive output favours the second class of the pair.
      votes[(out > 0) ? second : first]++;
    }
  }
  return votes;
}
/**
 * Returns the weights in sparse format, one array per
 * pairwise classifier in the upper triangle (row &lt; col).
 */
public double [][][] sparseWeights() {

  int numClasses = m_classAttribute.numValues();
  double [][][] result = new double[numClasses][numClasses][];

  // Only cells above the diagonal hold a trained binary classifier.
  for (int row = 0; row < numClasses; row++) {
    for (int col = row + 1; col < numClasses; col++) {
      result[row][col] = m_classifiers[row][col].m_sparseWeights;
    }
  }
  return result;
}
/**
 * Returns the indices in sparse format, one array per
 * pairwise classifier in the upper triangle (row &lt; col).
 */
public int [][][] sparseIndices() {

  int numClasses = m_classAttribute.numValues();
  int [][][] result = new int[numClasses][numClasses][];

  // Only cells above the diagonal hold a trained binary classifier.
  for (int row = 0; row < numClasses; row++) {
    for (int col = row + 1; col < numClasses; col++) {
      result[row][col] = m_classifiers[row][col].m_sparseIndices;
    }
  }
  return result;
}
/**
 * Returns the bias of each binary SMO, indexed by the
 * class pair (upper triangle only).
 */
public double [][] bias() {

  int n = m_classAttribute.numValues();
  double [][] result = new double[n][n];

  for (int i = 0; i < n; i++) {
    // Walk the upper triangle of the classifier matrix.
    for (int j = n - 1; j > i; j--) {
      result[i][j] = m_classifiers[i][j].m_b;
    }
  }
  return result;
}
/**
 * Returns the number of values of the class attribute.
 *
 * @return the number of class values
 */
public int numClassAttributeValues() {
return m_classAttribute.numValues();
}
/**
 * Returns the names of the class attribute's values.
 *
 * @return one name per class value
 */
public String [] classAttributeNames() {

  int count = m_classAttribute.numValues();
  String [] names = new String[count];
  for (int index = count - 1; index >= 0; index--) {
    names[index] = m_classAttribute.value(index);
  }
  return names;
}
/**
 * Returns the attribute names, one array per pairwise
 * classifier's training data (upper triangle only).
 */
public String [][][] attributeNames() {

  int numClasses = m_classAttribute.numValues();
  String [][][] result = new String[numClasses][numClasses][];

  for (int i = 0; i < numClasses; i++) {
    for (int j = i + 1; j < numClasses; j++) {
      // Copy out the attribute names of this pair's dataset.
      int count = m_classifiers[i][j].m_data.numAttributes();
      String [] names = new String[count];
      for (int a = 0; a < count; a++) {
        names[a] = m_classifiers[i][j].m_data.attribute(a).name();
      }
      result[i][j] = names;
    }
  }
  return result;
}
/**
 * Returns an enumeration describing the available options.
 *
 * @return an enumeration of all the available options.
 */
public Enumeration listOptions() {

  Vector options = new Vector(13);

  options.addElement(new Option(
      "\tThe complexity constant C. (default 1)",
      "C", 1, "-C <double>"));
  options.addElement(new Option(
      "\tThe exponent for the polynomial kernel. (default 1)",
      "E", 1, "-E <double>"));
  options.addElement(new Option(
      "\tGamma for the RBF kernel. (default 0.01)",
      "G", 1, "-G <double>"));
  options.addElement(new Option(
      "\tWhether to 0=normalize/1=standardize/2=neither. (default 0=normalize)",
      "N", 1, "-N"));
  options.addElement(new Option(
      "\tFeature-space normalization (only for\n\tnon-linear polynomial kernels).",
      "F", 0, "-F"));
  options.addElement(new Option(
      "\tUse lower-order terms (only for non-linear\n\tpolynomial kernels).",
      "O", 0, "-O"));
  options.addElement(new Option(
      "\tUse RBF kernel. (default poly)",
      "R", 0, "-R"));
  options.addElement(new Option(
      "\tThe size of the kernel cache. (default 250007, use 0 for full cache)",
      "A", 1, "-A <int>"));
  options.addElement(new Option(
      "\tThe tolerance parameter. (default 1.0e-3)",
      "L", 1, "-L <double>"));
  options.addElement(new Option(
      "\tThe epsilon for round-off error. (default 1.0e-12)",
      "P", 1, "-P <double>"));
  options.addElement(new Option(
      "\tFit logistic models to SVM outputs. ",
      "M", 0, "-M"));
  options.addElement(new Option(
      "\tThe number of folds for the internal\n\tcross-validation. (default -1, use training data)",
      "V", 1, "-V <double>"));
  options.addElement(new Option(
      "\tThe random number seed. (default 1)",
      "W", 1, "-W <double>"));

  return options.elements();
}
/**
 * Parses a given list of options. Valid options are:<p>
 *
 * -C num <br>
 * The complexity constant C. (default 1)<p>
 *
 * -E num <br>
 * The exponent for the polynomial kernel. (default 1) <p>
 *
 * -G num <br>
 * Gamma for the RBF kernel. (default 0.01) <p>
 *
 * -N <0|1|2> <br>
 * Whether to 0=normalize/1=standardize/2=neither. (default 0=normalize)<p>
 *
 * -F <br>
 * Feature-space normalization (only for non-linear polynomial kernels). <p>
 *
 * -O <br>
 * Use lower-order terms (only for non-linear polynomial kernels). <p>
 *
 * -R <br>
 * Use RBF kernel (default poly). <p>
 *
 * -A num <br> Sets the size of the kernel cache. Should be a prime
 * number. (default 250007, use 0 for full cache) <p>
 *
 * -L num <br>
 * Sets the tolerance parameter. (default 1.0e-3)<p>
 *
 * -P num <br>
 * Sets the epsilon for round-off error. (default 1.0e-12)<p>
 *
 * -M <br>
 * Fit logistic models to SVM outputs.<p>
 *
 * -V num <br>
 * Number of folds for cross-validation used to generate data
 * for logistic models. (default -1, use training data)
 *
 * -W num <br>
 * Random number seed. (default 1)
 *
 * @param options the list of options as an array of strings
 * @exception Exception if an option is not supported
 */
public void setOptions(String[] options) throws Exception {

  // Numeric options: fall back to the documented default when absent.
  // Double.parseDouble replaces the deprecated "new Double(...)" idiom
  // with identical parsing semantics and no boxing.
  String complexityString = Utils.getOption('C', options);
  if (complexityString.length() != 0) {
    m_C = Double.parseDouble(complexityString);
  } else {
    m_C = 1.0;
  }
  String exponentsString = Utils.getOption('E', options);
  if (exponentsString.length() != 0) {
    m_exponent = Double.parseDouble(exponentsString);
  } else {
    m_exponent = 1.0;
  }
  String gammaString = Utils.getOption('G', options);
  if (gammaString.length() != 0) {
    m_gamma = Double.parseDouble(gammaString);
  } else {
    m_gamma = 0.01;
  }
  String cacheString = Utils.getOption('A', options);
  if (cacheString.length() != 0) {
    m_cacheSize = Integer.parseInt(cacheString);
  } else {
    m_cacheSize = 250007;
  }
  String toleranceString = Utils.getOption('L', options);
  if (toleranceString.length() != 0) {
    m_tol = Double.parseDouble(toleranceString);
  } else {
    m_tol = 1.0e-3;
  }
  String epsilonString = Utils.getOption('P', options);
  if (epsilonString.length() != 0) {
    m_eps = Double.parseDouble(epsilonString);
  } else {
    m_eps = 1.0e-12;
  }

  m_useRBF = Utils.getFlag('R', options);
  String nString = Utils.getOption('N', options);
  if (nString.length() != 0) {
    setFilterType(new SelectedTag(Integer.parseInt(nString), TAGS_FILTER));
  } else {
    setFilterType(new SelectedTag(FILTER_NORMALIZE, TAGS_FILTER));
  }

  // Validate kernel-option combinations: feature-space normalization
  // and lower-order terms only make sense for non-linear poly kernels.
  m_featureSpaceNormalization = Utils.getFlag('F', options);
  if ((m_useRBF) && (m_featureSpaceNormalization)) {
    throw new Exception("RBF machine doesn't require feature-space normalization.");
  }
  if ((m_exponent == 1.0) && (m_featureSpaceNormalization)) {
    throw new Exception("Can't use feature-space normalization with linear machine.");
  }
  m_lowerOrder = Utils.getFlag('O', options);
  if ((m_useRBF) && (m_lowerOrder)) {
    throw new Exception("Can't use lower-order terms with RBF machine.");
  }
  if ((m_exponent == 1.0) && (m_lowerOrder)) {
    throw new Exception("Can't use lower-order terms with linear machine.");
  }

  // Logistic-model fitting and its cross-validation settings.
  m_fitLogisticModels = Utils.getFlag('M', options);
  String foldsString = Utils.getOption('V', options);
  if (foldsString.length() != 0) {
    m_numFolds = Integer.parseInt(foldsString);
  } else {
    m_numFolds = -1;
  }
  String randomSeedString = Utils.getOption('W', options);
  if (randomSeedString.length() != 0) {
    m_randomSeed = Integer.parseInt(randomSeedString);
  } else {
    m_randomSeed = 1;
  }
}
/**
 * Gets the current settings of the classifier.
 *
 * @return an array of strings suitable for passing to setOptions
 */
public String [] getOptions() {

  String [] options = new String [21];
  int pos = 0;

  options[pos++] = "-C"; options[pos++] = String.valueOf(m_C);
  options[pos++] = "-E"; options[pos++] = String.valueOf(m_exponent);
  options[pos++] = "-G"; options[pos++] = String.valueOf(m_gamma);
  options[pos++] = "-A"; options[pos++] = String.valueOf(m_cacheSize);
  options[pos++] = "-L"; options[pos++] = String.valueOf(m_tol);
  options[pos++] = "-P"; options[pos++] = String.valueOf(m_eps);
  options[pos++] = "-N"; options[pos++] = String.valueOf(m_filterType);

  if (m_featureSpaceNormalization) {
    options[pos++] = "-F";
  }
  if (m_lowerOrder) {
    options[pos++] = "-O";
  }
  if (m_useRBF) {
    options[pos++] = "-R";
  }
  if (m_fitLogisticModels) {
    options[pos++] = "-M";
  }
  options[pos++] = "-V"; options[pos++] = String.valueOf(m_numFolds);
  options[pos++] = "-W"; options[pos++] = String.valueOf(m_randomSeed);

  // Pad the remainder of the fixed-size array with empty strings.
  for (; pos < options.length; pos++) {
    options[pos] = "";
  }
  return options;
}
/**
 * Returns the tip text for this property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String exponentTipText() {
return "The exponent for the polynomial kernel.";
}
/**
 * Get the value of the exponent used by the polynomial kernel.
 *
 * @return Value of exponent.
 */
public double getExponent() {
return m_exponent;
}
/**
* Set the value of exponent. If linear kernel
* is used, rescaling and lower-order terms are
* turned off.
*
* @param v Value to assign to exponent.
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -