// IBk.java (excerpt) — k-nearest-neighbours (instance-based) classifier.
*
* @return Value of WindowSize.
*/
public int getWindowSize() {
  // A window size of 0 means the training pool is unbounded.
  return this.m_WindowSize;
}
/**
 * Sets the maximum number of instances allowed in the training pool.
 * Adding instances beyond this limit causes the oldest instances to be
 * dropped. A value of 0 means the pool size is unlimited.
 *
 * @param newWindowSize Value to assign to WindowSize.
 */
public void setWindowSize(int newWindowSize) {
  this.m_WindowSize = newWindowSize;
}
/**
 * Returns the tip text for the distance-weighting property.
 *
 * @return tip text suitable for displaying in the explorer/experimenter gui
 */
public String distanceWeightingTipText() {
  final String tip = "Gets the distance weighting method used.";
  return tip;
}
/**
 * Gets the distance weighting method used. The returned tag's ID is one
 * of WEIGHT_NONE, WEIGHT_INVERSE, or WEIGHT_SIMILARITY.
 *
 * @return the distance weighting method used.
 */
public SelectedTag getDistanceWeighting() {
  final int method = this.m_DistanceWeighting;
  return new SelectedTag(method, TAGS_WEIGHTING);
}
/**
 * Sets the distance weighting method used. Values other than
 * WEIGHT_NONE, WEIGHT_INVERSE, or WEIGHT_SIMILARITY will be ignored.
 *
 * @param newMethod the distance weighting method to use
 */
public void setDistanceWeighting(SelectedTag newMethod) {
// Only accept tags belonging to this classifier's weighting tag set;
// anything else is silently ignored, per the contract above.
if (newMethod.getTags() == TAGS_WEIGHTING) {
m_DistanceWeighting = newMethod.getSelectedTag().getID();
}
}
/**
 * Returns the tip text for the mean-squared-error property.
 *
 * @return tip text suitable for displaying in the explorer/experimenter gui
 */
public String meanSquaredTipText() {
  return "Whether the mean squared error is used rather than mean absolute error when doing cross-validation for regression problems.";
}
/**
 * Gets whether the mean squared error (rather than the mean absolute
 * error) is used when cross-validating for regression problems.
 *
 * @return true if so.
 */
public boolean getMeanSquared() {
  return this.m_MeanSquared;
}
/**
 * Sets whether the mean squared error (rather than the mean absolute
 * error) is used when cross-validating for regression problems.
 *
 * @param newMeanSquared true if so.
 */
public void setMeanSquared(boolean newMeanSquared) {
  this.m_MeanSquared = newMeanSquared;
}
/**
 * Returns the tip text for the cross-validate property.
 *
 * @return tip text suitable for displaying in the explorer/experimenter gui
 */
public String crossValidateTipText() {
  return "Whether hold-one-out cross-validation will be used to select the best k value.";
}
/**
 * Gets whether hold-one-out cross-validation will be used to select
 * the best k value.
 *
 * @return true if cross-validation will be used.
 */
public boolean getCrossValidate() {
  return this.m_CrossValidate;
}
/**
 * Sets whether hold-one-out cross-validation will be used to select
 * the best k value.
 *
 * @param newCrossValidate true if cross-validation should be used.
 */
public void setCrossValidate(boolean newCrossValidate) {
  this.m_CrossValidate = newCrossValidate;
}
/**
 * Returns the number of training instances the classifier currently holds.
 *
 * @return the size of the current training pool
 */
public int getNumTraining() {
  return this.m_Train.numInstances();
}
/**
 * Returns an attribute's minimum observed value.
 *
 * @param index the attribute's index
 * @return the minimum value seen for that attribute
 * @exception Exception if no minima have been recorded (normalization off
 *            or classifier not yet built)
 */
public double getAttributeMin(int index) throws Exception {
  if (this.m_Min != null) {
    return this.m_Min[index];
  }
  throw new Exception("Minimum value for attribute not available!");
}
/**
 * Returns an attribute's maximum observed value.
 *
 * @param index the attribute's index
 * @return the maximum value seen for that attribute
 * @exception Exception if no maxima have been recorded (normalization off
 *            or classifier not yet built)
 */
public double getAttributeMax(int index) throws Exception {
  if (this.m_Max != null) {
    return this.m_Max[index];
  }
  throw new Exception("Maximum value for attribute not available!");
}
/**
 * Returns the tip text for the no-normalization property.
 *
 * @return tip text suitable for displaying in the explorer/experimenter gui
 */
public String noNormalizationTipText() {
  final String tip = "Whether attribute normalization is turned off.";
  return tip;
}
/**
 * Gets whether attribute normalization is turned off.
 *
 * @return Value of DontNormalize.
 */
public boolean getNoNormalization() {
  return this.m_DontNormalize;
}
/**
 * Sets whether attribute normalization is turned off.
 *
 * @param v Value to assign to DontNormalize.
 */
public void setNoNormalization(boolean v) {
  this.m_DontNormalize = v;
}
/**
 * Generates the classifier: snapshots the training data, drops instances
 * with a missing class, trims the pool to the configured window size, and
 * (unless normalization is off) records per-attribute min/max values.
 *
 * @param instances set of instances serving as training data
 * @exception Exception if the classifier has not been generated successfully
 */
public void buildClassifier(Instances instances) throws Exception {

  if (instances.classIndex() < 0) {
    throw new Exception("No class attribute assigned to instances");
  }
  if (instances.checkForStringAttributes()) {
    throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
  }
  try {
    m_NumClasses = instances.numClasses();
    m_ClassType = instances.classAttribute().type();
  } catch (Exception ex) {
    throw new Error("This should never be reached");
  }

  // Throw away training instances with missing class
  m_Train = new Instances(instances, 0, instances.numInstances());
  m_Train.deleteWithMissingClass();

  // Throw away initial instances until within the specified window size.
  // BUGFIX: test m_Train's size (after missing-class removal), not the raw
  // input's size — otherwise the copy offset below can go negative when
  // enough missing-class rows were deleted.
  if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
    m_Train = new Instances(m_Train,
                            m_Train.numInstances() - m_WindowSize,
                            m_WindowSize);
  }

  // Calculate the minimum and maximum values
  if (m_DontNormalize) {
    m_Min = null;
    m_Max = null;
  } else {
    m_Min = new double[m_Train.numAttributes()];
    m_Max = new double[m_Train.numAttributes()];
    for (int i = 0; i < m_Train.numAttributes(); i++) {
      m_Min[i] = m_Max[i] = Double.NaN;  // NaN marks "no value seen yet"
    }
    // BUGFIX: was "emerateInstances()" (typo) — does not compile.
    Enumeration em = m_Train.enumerateInstances();
    while (em.hasMoreElements()) {
      updateMinMax((Instance) em.nextElement());
    }
  }

  // Compute the number of attributes that contribute to each prediction
  m_NumAttributesUsed = 0.0;
  for (int i = 0; i < m_Train.numAttributes(); i++) {
    if ((i != m_Train.classIndex())
        && (m_Train.attribute(i).isNominal() || m_Train.attribute(i).isNumeric())) {
      m_NumAttributesUsed += 1.0;
    }
  }

  // Invalidate any currently cross-validation selected k
  m_kNNValid = false;
}
/**
 * Adds the supplied instance to the training set, updating the recorded
 * min/max values (when normalizing) and evicting the oldest instances
 * if the window size is exceeded.
 *
 * @param instance the instance to add
 * @exception Exception if instance could not be incorporated successfully
 */
public void updateClassifier(Instance instance) throws Exception {

  if (m_Train.equalHeaders(instance.dataset()) == false) {
    throw new Exception("Incompatible instance types");
  }
  // Instances without a class value carry no training signal — skip them.
  if (instance.classIsMissing()) {
    return;
  }
  if (!m_DontNormalize) {
    updateMinMax(instance);
  }
  m_Train.add(instance);
  // The cross-validated k choice is stale once the pool changes.
  m_kNNValid = false;
  // Drop the oldest instances until we are back within the window.
  // (Redundant outer "if" around the identical while-condition removed.)
  if (m_WindowSize > 0) {
    while (m_Train.numInstances() > m_WindowSize) {
      m_Train.delete(0);
    }
  }
}
/**
 * Calculates the class membership probabilities for the given test instance.
 *
 * @param instance the instance to be classified
 * @return predicted class probability distribution
 * @exception Exception if an error occurred during the prediction
 */
public double[] distributionForInstance(Instance instance) throws Exception {

  if (m_Train.numInstances() == 0) {
    throw new Exception("No training instances!");
  }

  // Trim the pool back to the window, oldest first; any cached k is stale.
  if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
    m_kNNValid = false;
    do {
      m_Train.delete(0);
    } while (m_Train.numInstances() > m_WindowSize);
  }

  // Select k by cross validation
  if (!m_kNNValid && m_CrossValidate && (m_kNNUpper >= 1)) {
    crossValidate();
  }

  if (!m_DontNormalize) {
    updateMinMax(instance);
  }

  final NeighborList neighbors = findNeighbors(instance);
  return makeDistribution(neighbors);
}
/**
 * Returns an enumeration describing the available options.
 *
 * @return an enumeration of all the available options.
 */
public Enumeration listOptions() {

  // Exactly seven options are registered below (capacity was 8).
  Vector newVector = new Vector(7);

  newVector.addElement(new Option(
      "\tWeight neighbours by the inverse of their distance\n"
      + "\t(use when k > 1)",
      "I", 0, "-I"));
  newVector.addElement(new Option(
      "\tWeight neighbours by 1 - their distance\n"
      + "\t(use when k > 1)",
      "F", 0, "-F"));
  newVector.addElement(new Option(
      "\tNumber of nearest neighbours (k) used in classification.\n"
      + "\t(Default = 1)",
      "K", 1, "-K <number of neighbors>"));
  newVector.addElement(new Option(
      "\tMinimise mean squared error rather than mean absolute\n"
      + "\terror when using -X option with numeric prediction.",
      "E", 0, "-E"));
  newVector.addElement(new Option(
      "\tMaximum number of training instances maintained.\n"
      + "\tTraining instances are dropped FIFO. (Default = no window)",
      "W", 1, "-W <window size>"));
  newVector.addElement(new Option(
      "\tSelect the number of nearest neighbours between 1\n"
      + "\tand the k value specified using hold-one-out evaluation\n"
      + "\ton the training data (use when k > 1)",
      "X", 0, "-X"));
  newVector.addElement(new Option(
      "\tDon't normalize the data.\n",
      "N", 0, "-N"));

  return newVector.elements();
}
/**
 * Parses a given list of options. Valid options are:<p>
 *
 * -K num <br>
 * Set the number of nearest neighbors to use in prediction
 * (default 1) <p>
 *
 * -W num <br>
 * Set a fixed window size for incremental train/testing. As
 * new training instances are added, oldest instances are removed
 * to maintain the number of training instances at this size.
 * (default no window) <p>
 *
 * -I <br>
 * Neighbors will be weighted by the inverse of their distance
 * when voting. (default equal weighting) <p>
 *
 * -F <br>
 * Neighbors will be weighted by their similarity when voting.
 * (default equal weighting) <p>
 *
 * -X <br>
 * Select the number of neighbors to use by hold-one-out cross
 * validation, with an upper limit given by the -K option. <p>
 *
 * -E <br>
 * When k is selected by cross-validation for numeric class attributes,
 * minimize mean-squared error. (default mean absolute error) <p>
 *
 * -N <br>
 * Turns off attribute normalization. <p>
 *
 * @param options the list of options as an array of strings
 * @exception Exception if an option is not supported
 */
public void setOptions(String[] options) throws Exception {
String knnString = Utils.getOption('K', options);
if (knnString.length() != 0) {
setKNN(Integer.parseInt(knnString));
} else {
setKNN(1);
}
String windowString = Utils.getOption('W', options);
if (windowString.length() != 0) {
setWindowSize(Integer.parseInt(windowString));
} else {
setWindowSize(0);
}
// -I and -F are mutually exclusive; -I wins when both are given.
if (Utils.getFlag('I', options)) {
setDistanceWeighting(new SelectedTag(WEIGHT_INVERSE, TAGS_WEIGHTING));
} else if (Utils.getFlag('F', options)) {
setDistanceWeighting(new SelectedTag(WEIGHT_SIMILARITY, TAGS_WEIGHTING));
} else {
setDistanceWeighting(new SelectedTag(WEIGHT_NONE, TAGS_WEIGHTING));
}
setCrossValidate(Utils.getFlag('X', options));
setMeanSquared(Utils.getFlag('E', options));
setNoNormalization(Utils.getFlag('N', options));
Utils.checkForRemainingOptions(options);
}
/**
* Gets the current settings of IBk.
*
* @return an array of strings suitable for passing to setOptions()
*/
public String [] getOptions() {
String [] options = new String [11];
int current = 0;
options[current++] = "-K"; options[current++] = "" + getKNN();
options[current++] = "-W"; options[current++] = "" + m_WindowSize;
if (getCrossValidate()) {
options[current++] = "-X";
}
if (getMeanSquared()) {
options[current++] = "-E";
// NOTE(review): the remainder of getOptions() (and of the file) is missing
// from this extract; code-viewer UI artifacts (shortcut-key help text) that
// had been pasted here were removed.