📄 svmattributeeval.java
字号:
*/
public String toleranceParameterTipText() {
  // Tip text describing the tolerance (T) parameter property.
  final String tip = "T tolerance parameter to pass to the SVM";
  return tip;
}
/**
 * Supplies the GUI tip text for the complexity (C) parameter property.
 *
 * @return a short description of the property, suitable for display
 *         in the GUI
 */
public String complexityParameterTipText() {
  final String tip = "C complexity parameter to pass to the SVM";
  return tip;
}
/**
 * Supplies the GUI tip text for the filter type property.
 *
 * @return a short description of the property, suitable for display
 *         in the GUI
 */
public String filterTypeTipText() {
  final String tip = "filtering used by the SVM";
  return tip;
}
//________________________________________________________________________
/**
 * Sets the constant number of attributes eliminated in each iteration.
 * The value is stored as given; buildEvaluator raises anything below 1
 * to 1 before it is used.
 *
 * @param cRate the constant rate of attribute elimination per iteration
 */
public void setAttsToEliminatePerIteration(int cRate) {
  this.m_numToEliminate = cRate;
}
/**
 * Reports the constant number of attributes eliminated in each
 * iteration.
 *
 * @return the constant rate of attribute elimination per iteration
 */
public int getAttsToEliminatePerIteration() {
  return this.m_numToEliminate;
}
/**
 * Sets the percentage of attributes eliminated in each iteration.
 * The value is stored as given; buildEvaluator limits it to the range
 * 0..100 before it is used, and a value of 0 means only the constant
 * elimination rate applies.
 *
 * @param pRate percent of attributes to eliminate per iteration
 */
public void setPercentToEliminatePerIteration(int pRate) {
  this.m_percentToEliminate = pRate;
}
/**
 * Reports the percentage of attributes eliminated in each iteration.
 *
 * @return the percentage rate of attribute elimination per iteration
 */
public int getPercentToEliminatePerIteration() {
  return this.m_percentToEliminate;
}
/**
 * Sets the threshold below which percentage elimination reverts to
 * constant elimination. The value is stored as given; buildEvaluator
 * limits it to the range 0..(number of attributes - 1) before use.
 *
 * @param pThresh the number of remaining attributes at which
 *                percentage-based elimination stops
 */
public void setPercentThreshold(int pThresh) {
  this.m_percentThreshold = pThresh;
}
/**
 * Reports the threshold below which percentage elimination reverts to
 * constant elimination.
 *
 * @return the threshold below which percentage elimination stops
 */
public int getPercentThreshold() {
  return this.m_percentThreshold;
}
/**
 * Sets the P (epsilon) value that is handed to each SMO instance via
 * its setEpsilon method.
 *
 * @param svmP the value of P
 */
public void setEpsilonParameter(double svmP) {
  this.m_smoPParameter = svmP;
}
/**
 * Reports the P (epsilon) value used with SMO.
 *
 * @return the value of P
 */
public double getEpsilonParameter() {
  return this.m_smoPParameter;
}
/**
 * Sets the T (tolerance) value that is handed to each SMO instance via
 * its setToleranceParameter method.
 *
 * @param svmT the value of T
 */
public void setToleranceParameter(double svmT) {
  this.m_smoTParameter = svmT;
}
/**
 * Reports the T (tolerance) value used with SMO.
 *
 * @return the value of T
 */
public double getToleranceParameter() {
  return this.m_smoTParameter;
}
/**
 * Sets the C (complexity) value that is handed to each SMO instance
 * via its setC method.
 *
 * @param svmC the value of C
 */
public void setComplexityParameter(double svmC) {
  this.m_smoCParameter = svmC;
}
/**
 * Reports the C (complexity) value used with SMO.
 *
 * @return the value of C
 */
public double getComplexityParameter() {
  return this.m_smoCParameter;
}
/**
 * Chooses the filtering mode to pass to SMO. The request is ignored
 * unless the supplied tag was built from SMO.TAGS_FILTER (checked by
 * reference identity of the tag array).
 *
 * @param newType the new filtering mode
 */
public void setFilterType(SelectedTag newType) {
  // Selections drawn from any other tag set leave the current mode untouched.
  if (newType.getTags() != SMO.TAGS_FILTER) {
    return;
  }
  m_smoFilterType = newType.getSelectedTag().getID();
}
/**
 * Reports the filtering mode passed to SMO, wrapped in a fresh
 * SelectedTag over SMO.TAGS_FILTER.
 *
 * @return the filtering mode
 */
public SelectedTag getFilterType() {
  SelectedTag currentMode = new SelectedTag(m_smoFilterType, SMO.TAGS_FILTER);
  return currentMode;
}
//________________________________________________________________________
/**
 * Initializes the evaluator.
 *
 * Validates the data, clamps the elimination settings to usable ranges,
 * obtains one SVM-based attribute ranking per class value (one-vs-all;
 * a single ranking when the class has two or fewer values) and merges
 * the rankings into one score per attribute. The best attribute gets a
 * score equal to the number of non-class attributes, the next one less,
 * and so on; the class attribute keeps a score of 0.
 *
 * @param data set of instances serving as training data
 * @exception Exception if the data contains string attributes, the
 *            class is not nominal, a non-class nominal attribute is not
 *            binary, or the evaluator has not been generated
 *            successfully
 */
public void buildEvaluator(Instances data) throws Exception {
  if (data.checkForStringAttributes()) {
    throw new Exception("Can't handle string attributes!");
  }
  if (!data.classAttribute().isNominal()) {
    throw new Exception("Class must be nominal!");
  }
  for (int i = 0; i < data.numAttributes(); i++) {
    if (data.attribute(i).isNominal()
        && (data.attribute(i).numValues() != 2)
        && !(i == data.classIndex())) {
      throw new Exception("All nominal attributes must be binary!");
    }
  }
  System.out.println("Class attribute: " + data.attribute(data.classIndex()).name());

  // Check settings: force each one into the range the ranking loop can use.
  m_numToEliminate = (m_numToEliminate > 1) ? m_numToEliminate : 1;
  m_percentToEliminate = (m_percentToEliminate < 100) ? m_percentToEliminate : 100;
  m_percentToEliminate = (m_percentToEliminate > 0) ? m_percentToEliminate : 0;
  m_percentThreshold = (m_percentThreshold < data.numAttributes())
      ? m_percentThreshold : data.numAttributes() - 1;
  m_percentThreshold = (m_percentThreshold > 0) ? m_percentThreshold : 0;

  // Get ranked attributes for each class separately, one-vs-all. With at
  // most two class values a single ranking (for value 0) is enough.
  int numAttr = data.numAttributes() - 1;
  int numRankings = (data.numClasses() > 2) ? data.numClasses() : 1;
  int[][] attScoresByClass = new int[numRankings][];
  for (int i = 0; i < numRankings; i++) {
    attScoresByClass[i] = rankBySVM(i, data);
  }

  // Cycle through the class-specific ranked lists position by position,
  // taking the next attribute from each list and giving it the next-best
  // overall score unless it has been scored already. The seen-mask makes
  // the duplicate check constant time.
  double[] scores = new double[data.numAttributes()];
  boolean[] alreadyRanked = new boolean[data.numAttributes()];
  double nextScore = (double) numAttr;
  for (int i = 0; i < numAttr; i++) {
    for (int j = 0; j < numRankings; j++) {
      int att = attScoresByClass[j][i];
      if (!alreadyRanked[att]) {
        alreadyRanked[att] = true;
        scores[att] = nextScore;
        nextScore = nextScore - 1.0;
      }
    }
  }
  // Publish only once the full score table has been computed.
  m_attScores = scores;
}
/**
 * Get SVM-ranked attribute indexes (best to worst) selected for
 * the class attribute indexed by classInd (one-vs-all).
 *
 * The class is first turned into a two-valued indicator for the given
 * class value. A linear SMO model is then trained repeatedly; after each
 * training run the attributes with the smallest squared weights are
 * removed and written into the ranking from the back, so the attributes
 * that survive longest end up at the front.
 *
 * @param classInd index of the class value to separate from all others
 * @param data the training data; it is not modified, filtering works on
 *             copies
 * @return the original indices of all non-class attributes, ordered
 *         from most to least important
 * @throws RuntimeException if filtering or SVM training fails; the
 *         original exception is attached as the cause. A partially
 *         filled ranking is never returned, because its unfilled slots
 *         would all read as attribute 0.
 */
private int[] rankBySVM(int classInd, Instances data) {
  // Holds a mapping into the original array of attribute indices
  int[] origIndices = new int[data.numAttributes()];
  for (int i = 0; i < origIndices.length; i++) {
    origIndices[i] = i;
  }
  // Count down of number of attributes remaining
  int numAttrLeft = data.numAttributes() - 1;
  // Ranked attribute indices for this class, one vs. all (highest->lowest)
  int[] attRanks = new int[numAttrLeft];
  try {
    MakeIndicator filter = new MakeIndicator();
    filter.setAttributeIndex("" + (data.classIndex() + 1));
    filter.setNumeric(false);
    filter.setValueIndex(classInd);
    filter.setInputFormat(data);
    Instances trainCopy = Filter.useFilter(data, filter);
    double pctToElim = ((double) m_percentToEliminate) / 100.0;
    while (numAttrLeft > 0) {
      int numToElim = 0;
      if (pctToElim > 0) {
        // The percentage is taken of all columns still in trainCopy,
        // which includes the class attribute.
        numToElim = (int) (trainCopy.numAttributes() * pctToElim);
        numToElim = (numToElim > 1) ? numToElim : 1;
        if (numAttrLeft - numToElim <= m_percentThreshold) {
          // Stop at the threshold and use the constant rate from now on.
          pctToElim = 0;
          numToElim = numAttrLeft - m_percentThreshold;
        }
      }
      if (numToElim <= 0) {
        // Either percentage elimination is off, or the threshold has
        // already been reached and there is nothing left to remove by
        // percentage. Use the constant rate right away instead of
        // training an SVM only to eliminate nothing.
        numToElim = (numAttrLeft >= m_numToEliminate) ? m_numToEliminate : numAttrLeft;
      }
      // Build the linear SVM with the configured parameters
      SMO smo = new SMO();
      // SMO seems to get stuck if data not normalised when few attributes remain
      // smo.setNormalizeData(numAttrLeft < 40);
      smo.setFilterType(new SelectedTag(m_smoFilterType, SMO.TAGS_FILTER));
      smo.setEpsilon(m_smoPParameter);
      smo.setToleranceParameter(m_smoTParameter);
      smo.setC(m_smoCParameter);
      smo.buildClassifier(trainCopy);
      // Expand the sparse weights into one squared weight per column.
      // NOTE(review): [0][1] is presumably the single machine separating
      // the two indicator values - confirm against the SMO API.
      double[] weightsSparse = smo.sparseWeights()[0][1];
      int[] indicesSparse = smo.sparseIndices()[0][1];
      double[] weights = new double[trainCopy.numAttributes()];
      for (int j = 0; j < weightsSparse.length; j++) {
        weights[indicesSparse[j]] = weightsSparse[j] * weightsSparse[j];
      }
      // The class column must never be picked for elimination.
      weights[trainCopy.classIndex()] = Double.MAX_VALUE;
      // Repeatedly take the attribute with the minimum weight^2.
      int minWeightIndex;
      int[] featArray = new int[numToElim];
      boolean[] eliminated = new boolean[origIndices.length];
      for (int j = 0; j < numToElim; j++) {
        minWeightIndex = Utils.minIndex(weights);
        attRanks[--numAttrLeft] = origIndices[minWeightIndex];
        featArray[j] = minWeightIndex;
        eliminated[minWeightIndex] = true;
        // Mask it so the next pass finds the next-smallest weight.
        weights[minWeightIndex] = Double.MAX_VALUE;
      }
      // Delete the worst attributes.
      weka.filters.unsupervised.attribute.Remove delTransform =
        new weka.filters.unsupervised.attribute.Remove();
      delTransform.setInvertSelection(false);
      delTransform.setAttributeIndicesArray(featArray);
      delTransform.setInputFormat(trainCopy);
      trainCopy = Filter.useFilter(trainCopy, delTransform);
      // Update the array of remaining attribute indices
      int[] temp = new int[origIndices.length - numToElim];
      int k = 0;
      for (int j = 0; j < origIndices.length; j++) {
        if (!eliminated[j]) {
          temp[k++] = origIndices[j];
        }
      }
      origIndices = temp;
    }
  } catch (Exception e) {
    // Fail loudly: the ranking is incomplete at this point.
    throw new RuntimeException(
        "SVM attribute ranking failed for class value index " + classInd, e);
  }
  return attRanks;
}
/**
 * Discards the attribute scores of any previous build, which puts the
 * evaluator back into its "not built" state.
 */
protected void resetOptions() {
  this.m_attScores = null;
}
/**
 * Evaluates an attribute by returning the rank of the square of its coefficient in a
 * linear support vector machine.
 *
 * @param attribute the index of the attribute to be evaluated
 * @return the score stored for the attribute by buildEvaluator
 * @exception Exception if the attribute could not be evaluated, in
 *            particular if the evaluator has not been built yet
 */
public double evaluateAttribute(int attribute) throws Exception {
  // Without this guard an unbuilt evaluator fails with a bare
  // NullPointerException that says nothing about the cause.
  if (m_attScores == null) {
    throw new Exception("SVM feature evaluator has not been built yet");
  }
  return m_attScores[attribute];
}
/**
 * Describes the evaluator in one tab-indented, newline-terminated line
 * that also says whether it has been built yet.
 *
 * @return description as a string
 */
public String toString() {
  String description = (m_attScores != null)
      ? "\tSVM feature evaluator"
      : "\tSVM feature evaluator has not been built yet";
  return description + "\n";
}
/**
 * Command-line entry point for testing this class: runs attribute
 * selection with a fresh evaluator and prints the result. On failure
 * the stack trace and then the exception message are printed.
 *
 * @param args the options
 */
public static void main(String[] args) {
  try {
    SVMAttributeEval evaluator = new SVMAttributeEval();
    System.out.println(AttributeSelection.SelectAttributes(evaluator, args));
  } catch (Exception ex) {
    ex.printStackTrace();
    System.out.println(ex.getMessage());
  }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -