// IBk.java — k-nearest-neighbours (IBk) instance-based classifier (excerpt).
options[current++] = "-I";
} else if (m_DistanceWeighting == WEIGHT_SIMILARITY) {
options[current++] = "-F";
}
if (m_DontNormalize) {
options[current++] = "-N";
}
while (current < options.length) {
options[current++] = "";
}
return options;
}
/**
 * Returns a textual description of this classifier, including the
 * selected number of neighbours, the distance-weighting mode and any
 * training-window restriction.
 *
 * @return a description of this classifier as a string.
 */
public String toString() {
  if (m_Train == null) {
    return "IBk: No model built yet.";
  }
  // Make sure k reflects the cross-validated choice before reporting it.
  if (!m_kNNValid && m_CrossValidate) {
    crossValidate();
  }
  StringBuffer text = new StringBuffer();
  text.append("IB1 instance-based classifier\n");
  text.append("using ").append(m_kNN);
  if (m_DistanceWeighting == WEIGHT_INVERSE) {
    text.append(" inverse-distance-weighted");
  } else if (m_DistanceWeighting == WEIGHT_SIMILARITY) {
    text.append(" similarity-weighted");
  }
  text.append(" nearest neighbour(s) for classification\n");
  if (m_WindowSize != 0) {
    text.append("using a maximum of ")
        .append(m_WindowSize)
        .append(" (windowed) training instances\n");
  }
  return text.toString();
}
/**
 * Initialises all scheme variables to their default values:
 * k = 1, unlimited window, no distance weighting, no hold-one-out
 * selection of k, absolute error for numeric classes, normalization on.
 */
protected void init() {
  // Defaults chosen to reproduce plain IB1 behaviour.
  m_WindowSize = 0;
  m_DistanceWeighting = WEIGHT_NONE;
  m_CrossValidate = false;
  m_MeanSquared = false;
  m_DontNormalize = false;
  setKNN(1);
}
/**
 * Calculates the (normalized Euclidean) distance between two instances,
 * walking both sparse representations in parallel and skipping the
 * class attribute.
 *
 * @param first the first instance
 * @param second the second instance
 * @return the distance between the two given instances
 */
protected double distance(Instance first, Instance second) {
double distance = 0;
int firstI, secondI;
// p1/p2 index the stored (possibly sparse) values of each instance;
// firstI/secondI are the attribute indices those positions refer to.
for (int p1 = 0, p2 = 0;
p1 < first.numValues() || p2 < second.numValues();) {
// An exhausted instance is given a sentinel index one past the last
// attribute, so the other instance's remaining values are consumed.
if (p1 >= first.numValues()) {
firstI = m_Train.numAttributes();
} else {
firstI = first.index(p1);
}
if (p2 >= second.numValues()) {
secondI = m_Train.numAttributes();
} else {
secondI = second.index(p2);
}
// The class attribute does not contribute to the distance.
if (firstI == m_Train.classIndex()) {
p1++; continue;
}
if (secondI == m_Train.classIndex()) {
p2++; continue;
}
double diff;
if (firstI == secondI) {
// Both instances store a value for this attribute.
diff = difference(firstI,
first.valueSparse(p1),
second.valueSparse(p2));
p1++; p2++;
} else if (firstI > secondI) {
// Attribute missing from 'first' in sparse format: treat as 0.
diff = difference(secondI,
0, second.valueSparse(p2));
p2++;
} else {
// Attribute missing from 'second' in sparse format: treat as 0.
diff = difference(firstI,
first.valueSparse(p1), 0);
p1++;
}
distance += diff * diff;
}
// Root-mean-square over the attributes actually used.
return Math.sqrt(distance / m_NumAttributesUsed);
}
/**
 * Computes the difference between two given attribute values,
 * applying the IBk missing-value conventions.
 *
 * @param index the index of the attribute the values belong to
 * @param val1 the first value
 * @param val2 the second value
 * @return the (normalized) difference; attributes that are neither
 * nominal nor numeric contribute 0
 */
protected double difference(int index, double val1, double val2) {
switch (m_Train.attribute(index).type()) {
case Attribute.NOMINAL:
// If attribute is nominal: maximal difference (1) unless both
// values are present and equal.
if (Instance.isMissingValue(val1) ||
Instance.isMissingValue(val2) ||
((int)val1 != (int)val2)) {
return 1;
} else {
return 0;
}
case Attribute.NUMERIC:
// If attribute is numeric
if (Instance.isMissingValue(val1) ||
Instance.isMissingValue(val2)) {
if (Instance.isMissingValue(val1) &&
Instance.isMissingValue(val2)) {
// Both missing: assume maximal difference.
return 1;
} else {
// One missing: assume the worst-case difference from the
// known (normalized) value, i.e. at least 0.5.
double diff;
if (Instance.isMissingValue(val2)) {
diff = norm(val1, index);
} else {
diff = norm(val2, index);
}
if (diff < 0.5) {
diff = 1.0 - diff;
}
return diff;
}
} else {
return norm(val1, index) - norm(val2, index);
}
default:
return 0;
}
}
/**
 * Normalizes a given value of a numeric attribute into [0, 1] using
 * the observed training range.
 *
 * @param x the value to be normalized
 * @param i the attribute's index
 * @return the normalized value, or {@code x} unchanged when
 * normalization is disabled, or 0 when no range is available
 */
protected double norm(double x, int i) {
  if (m_DontNormalize) {
    return x;
  }
  // No observed range yet, or a degenerate (constant) attribute.
  if (Double.isNaN(m_Min[i]) || Utils.eq(m_Max[i], m_Min[i])) {
    return 0;
  }
  return (x - m_Min[i]) / (m_Max[i] - m_Min[i]);
}
/**
 * Updates the minimum and maximum values for all the attributes
 * based on a new instance.
 *
 * @param instance the new instance
 */
protected void updateMinMax(Instance instance) {
  for (int j = 0; j < m_Train.numAttributes(); j++) {
    if (instance.isMissing(j)) {
      continue;  // missing values carry no range information
    }
    double value = instance.value(j);
    if (Double.isNaN(m_Min[j])) {
      // First observed value for this attribute seeds both bounds.
      m_Min[j] = value;
      m_Max[j] = value;
    } else if (value < m_Min[j]) {
      m_Min[j] = value;
    } else if (value > m_Max[j]) {
      m_Max[j] = value;
    }
  }
}
/**
 * Builds the list of the k nearest neighbours to the given test
 * instance, scanning every training instance.
 *
 * @param instance the instance to search for neighbours of
 * @return a sorted list of neighbours
 */
protected NeighborList findNeighbors(Instance instance) {
  NeighborList neighborlist = new NeighborList(m_kNN);
  Enumeration enu = m_Train.enumerateInstances();
  int seen = 0;
  while (enu.hasMoreElements()) {
    Instance trainInstance = (Instance) enu.nextElement();
    // Skip the identical object, for hold-one-out cross-validation.
    if (instance == trainInstance) {
      continue;
    }
    double dist = distance(instance, trainInstance);
    // Insert while the list is still filling, or when the candidate
    // is no farther than the current worst neighbour.
    boolean candidate = neighborlist.isEmpty()
        || seen < m_kNN
        || dist <= neighborlist.m_Last.m_Distance;
    if (candidate) {
      neighborlist.insertSorted(dist, trainInstance);
    }
    seen++;
  }
  return neighborlist;
}
/**
 * Turns the list of nearest neighbours into a probability distribution
 * (nominal class) or a single weighted prediction (numeric class,
 * stored in slot 0).
 *
 * @param neighborlist the list of nearest neighboring instances
 * @return the probability distribution
 * @throws Exception if the distribution can't be computed
 */
protected double [] makeDistribution(NeighborList neighborlist)
throws Exception {
double total = 0, weight;
double [] distribution = new double [m_NumClasses];
// Set up a correction to the estimator: a Laplace-style prior of
// 1/numInstances per class so no class gets probability zero.
if (m_ClassType == Attribute.NOMINAL) {
for(int i = 0; i < m_NumClasses; i++) {
distribution[i] = 1.0 / Math.max(1,m_Train.numInstances());
}
total = (double)m_NumClasses / Math.max(1,m_Train.numInstances());
}
if (!neighborlist.isEmpty()) {
// Collect class counts, weighting each neighbour by its distance
// (per m_DistanceWeighting) times its instance weight.
NeighborNode current = neighborlist.m_First;
while (current != null) {
switch (m_DistanceWeighting) {
case WEIGHT_INVERSE:
weight = 1.0 / (current.m_Distance + 0.001); // to avoid div by zero
break;
case WEIGHT_SIMILARITY:
weight = 1.0 - current.m_Distance;
break;
default: // WEIGHT_NONE:
weight = 1.0;
break;
}
weight *= current.m_Instance.weight();
try {
switch (m_ClassType) {
case Attribute.NOMINAL:
distribution[(int)current.m_Instance.classValue()] += weight;
break;
case Attribute.NUMERIC:
// Numeric prediction is a weighted mean; slot 0 accumulates
// the weighted sum, normalized by 'total' below.
distribution[0] += current.m_Instance.classValue() * weight;
break;
}
} catch (Exception ex) {
throw new Error("Data has no class attribute!");
}
total += weight;
current = current.m_Next;
}
}
// Normalise distribution (also converts the numeric weighted sum
// into a weighted mean).
if (total > 0) {
Utils.normalize(distribution, total);
}
return distribution;
}
/**
 * Select the best value for k by hold-one-out cross-validation.
 * If the class attribute is nominal, classification error is
 * minimised. If the class attribute is numeric, mean absolute
 * error is minimised (or mean squared error when m_MeanSquared is set).
 *
 * Neighbours are found once per instance with k = m_kNNUpper, then the
 * list is pruned downwards so every k in [1, m_kNNUpper] is evaluated
 * from a single search. Sets m_kNN to the best k and m_kNNValid to true.
 */
protected void crossValidate() {
try {
// performanceStats[i] accumulates error for k = i+1 neighbours;
// performanceStatsSq holds the squared-error variant.
double [] performanceStats = new double [m_kNNUpper];
double [] performanceStatsSq = new double [m_kNNUpper];
for(int i = 0; i < m_kNNUpper; i++) {
performanceStats[i] = 0;
performanceStatsSq[i] = 0;
}
m_kNN = m_kNNUpper;
Instance instance;
NeighborList neighborlist;
for(int i = 0; i < m_Train.numInstances(); i++) {
if (m_Debug && (i % 50 == 0)) {
System.err.print("Cross validating "
+ i + "/" + m_Train.numInstances() + "\r");
}
instance = m_Train.instance(i);
neighborlist = findNeighbors(instance);
// Evaluate from the largest k down, pruning the list each step.
for(int j = m_kNNUpper - 1; j >= 0; j--) {
// Update the performance stats
double [] distribution = makeDistribution(neighborlist);
// NOTE(review): for a numeric class this maxIndex result is
// immediately overwritten below — harmless but wasted work.
double thisPrediction = Utils.maxIndex(distribution);
if (m_Train.classAttribute().isNumeric()) {
thisPrediction = distribution[0];
double err = thisPrediction - instance.classValue();
performanceStatsSq[j] += err * err; // Squared error
performanceStats[j] += Math.abs(err); // Absolute error
} else {
if (thisPrediction != instance.classValue()) {
performanceStats[j] ++; // Classification error
}
}
if (j >= 1) {
neighborlist.pruneToK(j);
}
}
}
// Display the results of the cross-validation
for(int i = 0; i < m_kNNUpper; i++) {
if (m_Debug) {
System.err.print("Hold-one-out performance of " + (i + 1)
+ " neighbors " );
}
if (m_Train.classAttribute().isNumeric()) {
if (m_Debug) {
if (m_MeanSquared) {
System.err.println("(RMSE) = "
+ Math.sqrt(performanceStatsSq[i]
/ m_Train.numInstances()));
} else {
System.err.println("(MAE) = "
+ performanceStats[i]
/ m_Train.numInstances());
}
}
} else {
if (m_Debug) {
System.err.println("(%ERR) = "
+ 100.0 * performanceStats[i]
/ m_Train.numInstances());
}
}
}
// Check through the performance stats and select the best
// k value (or the lowest k if more than one best)
double [] searchStats = performanceStats;
if (m_Train.classAttribute().isNumeric() && m_MeanSquared) {
searchStats = performanceStatsSq;
}
double bestPerformance = Double.NaN;
int bestK = 1;
for(int i = 0; i < m_kNNUpper; i++) {
// Strict '>' keeps the smallest k among ties.
if (Double.isNaN(bestPerformance)
|| (bestPerformance > searchStats[i])) {
bestPerformance = searchStats[i];
bestK = i + 1;
}
}
m_kNN = bestK;
if (m_Debug) {
System.err.println("Selected k = " + bestK);
}
m_kNNValid = true;
} catch (Exception ex) {
throw new Error("Couldn't optimize by cross-validation: "
+ex.getMessage());
}
}
/**
 * Main method for testing this class.
 *
 * @param argv should contain command line options (see setOptions)
 */
public static void main(String [] argv) {
  try {
    String results = Evaluation.evaluateModel(new IBk(), argv);
    System.out.println(results);
  } catch (Exception e) {
    e.printStackTrace();
    System.err.println(e.getMessage());
  }
}
}
// (End of file. Code-viewer paste residue removed.)