// LWL.java — Locally Weighted Learning classifier.
// NOTE(review): this file was recovered from a code-viewer web page; the page's UI text has been commented out.
if (instances.checkForStringAttributes()) {
throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
}
// Throw away training instances with missing class
m_Train = new Instances(instances, 0, instances.numInstances());
m_Train.deleteWithMissingClass();
// Calculate the minimum and maximum values
m_Min = new double [m_Train.numAttributes()];
m_Max = new double [m_Train.numAttributes()];
for (int i = 0; i < m_Train.numAttributes(); i++) {
m_Min[i] = m_Max[i] = Double.NaN;
}
for (int i = 0; i < m_Train.numInstances(); i++) {
updateMinMax(m_Train.instance(i));
}
}
/**
 * Adds the supplied instance to the training set, updating the
 * per-attribute minimum/maximum statistics used for normalization.
 * Instances with a missing class value are silently ignored.
 *
 * @param instance the instance to add
 * @exception Exception if the instance's header is incompatible
 * with the training data
 */
public void updateClassifier(Instance instance) throws Exception {
    // Idiom fix: use logical negation instead of "== false".
    if (!m_Train.equalHeaders(instance.dataset())) {
        throw new Exception("Incompatible instance types");
    }
    if (!instance.classIsMissing()) {
        // Keep min/max in sync before the instance joins the pool.
        updateMinMax(instance);
        m_Train.add(instance);
    }
}
/**
 * Calculates the class membership probabilities for the given test instance
 * by building a fresh base classifier on a distance-weighted copy of the
 * training data (lazy, locally weighted learning).
 *
 * @param instance the instance to be classified
 * @return predicted class probability distribution
 * @exception Exception if distribution can't be computed successfully
 */
public double[] distributionForInstance(Instance instance) throws Exception {
if (m_Train.numInstances() == 0) {
throw new Exception("No training instances!");
}
// The test instance also contributes to the normalization range.
updateMinMax(instance);
// Get the distances to each training instance
double [] distance = new double [m_Train.numInstances()];
for (int i = 0; i < m_Train.numInstances(); i++) {
distance[i] = distance(instance, m_Train.instance(i));
}
// sortKey[i] is the index of the i-th nearest training instance.
int [] sortKey = Utils.sort(distance);
if (m_Debug) {
System.out.println("Instance Distances");
for (int i = 0; i < distance.length; i++) {
System.out.println("" + distance[sortKey[i]]);
}
}
// Determine the bandwidth: distance to the k-th nearest neighbour
// (or to the farthest one when all neighbours are used).
int k = sortKey.length - 1;
if (!m_UseAllK && (m_kNN < k)) {
k = m_kNN;
}
double bandwidth = distance[sortKey[k]];
// Check for bandwidth zero: widen to the first strictly positive
// distance so the rescaling below does not divide by zero.
if (bandwidth <= 0) {
for (int i = k + 1; i < sortKey.length; i++) {
if (distance[sortKey[i]] > bandwidth) {
bandwidth = distance[sortKey[i]];
break;
}
}
if (bandwidth <= 0) {
throw new Exception("All training instances coincide with test instance!");
}
}
// Rescale the distances by the bandwidth
for (int i = 0; i < distance.length; i++) {
distance[i] = distance[i] / bandwidth;
}
// Pass the distances through a weighting kernel (distances become weights)
for (int i = 0; i < distance.length; i++) {
switch (m_WeightKernel) {
case LINEAR:
// 1.0001 keeps the k-th neighbour's weight just above zero.
distance[i] = Math.max(1.0001 - distance[i], 0);
break;
case INVERSE:
distance[i] = 1.0 / (1.0 + distance[i]);
break;
case GAUSS:
distance[i] = Math.exp(-distance[i] * distance[i]);
break;
}
}
if (m_Debug) {
System.out.println("Instance Weights");
for (int i = 0; i < distance.length; i++) {
System.out.println("" + distance[i]);
}
}
// Set the weights on a copy of the training data, visiting instances
// nearest-first and stopping at the first negligible weight (weights
// are non-increasing in sort order for all three kernels).
Instances weightedTrain = new Instances(m_Train, 0);
double sumOfWeights = 0, newSumOfWeights = 0;
for (int i = 0; i < distance.length; i++) {
double weight = distance[sortKey[i]];
if (weight < 1e-20) {
break;
}
Instance newInst = (Instance) m_Train.instance(sortKey[i]).copy();
sumOfWeights += newInst.weight();
newSumOfWeights += newInst.weight() * weight;
newInst.setWeight(newInst.weight() * weight);
weightedTrain.add(newInst);
}
if (m_Debug) {
System.out.println("Kept " + weightedTrain.numInstances() + " out of "
+ m_Train.numInstances() + " instances");
}
// Rescale weights so their total matches the original total weight
// of the instances that were kept.
for (int i = 0; i < weightedTrain.numInstances(); i++) {
Instance newInst = weightedTrain.instance(i);
newInst.setWeight(newInst.weight() * sumOfWeights / newSumOfWeights);
}
// Create a weighted classifier
m_Classifier.buildClassifier(weightedTrain);
if (m_Debug) {
System.out.println("Classifying test instance: " + instance);
System.out.println("Built base classifier:\n"
+ m_Classifier.toString());
}
// Return the classifier's predictions
return m_Classifier.distributionForInstance(instance);
}
/**
 * Returns a description of this classifier.
 *
 * @return a description of this classifier as a string.
 */
public String toString() {
    if (m_Train == null) {
        return "Locally weighted learning: No model built yet.";
    }
    // Accumulate the description in a buffer rather than via repeated
    // string concatenation.
    StringBuffer text = new StringBuffer();
    text.append("Locally weighted learning\n");
    text.append("===========================\n");
    text.append("Using classifier: ").append(m_Classifier.getClass().getName()).append("\n");
    switch (m_WeightKernel) {
    case LINEAR:
        text.append("Using linear weighting kernels\n");
        break;
    case INVERSE:
        text.append("Using inverse-distance weighting kernels\n");
        break;
    case GAUSS:
        text.append("Using gaussian weighting kernels\n");
        break;
    }
    text.append("Using ").append(m_UseAllK ? "all" : "" + m_kNN).append(" neighbours");
    return text.toString();
}
/**
 * Calculates the Euclidean distance between two instances over their
 * normalized non-class attribute values, iterating both instances'
 * (possibly sparse) stored values in parallel.
 *
 * @param first the first instance
 * @param second the second instance
 * @return the distance between the two given instances
 */
protected double distance(Instance first, Instance second) {
double distance = 0;
int firstI, secondI;
// Two-pointer walk over the sparse value lists; attribute indices are
// assumed to be stored in ascending order. An exhausted list is treated
// as pointing at a sentinel index past the last attribute.
for (int p1 = 0, p2 = 0;
p1 < first.numValues() || p2 < second.numValues();) {
if (p1 >= first.numValues()) {
firstI = m_Train.numAttributes();
} else {
firstI = first.index(p1);
}
if (p2 >= second.numValues()) {
secondI = m_Train.numAttributes();
} else {
secondI = second.index(p2);
}
// The class attribute never contributes to the distance.
if (firstI == m_Train.classIndex()) {
p1++; continue;
}
if (secondI == m_Train.classIndex()) {
p2++; continue;
}
double diff;
if (firstI == secondI) {
// Both instances store this attribute explicitly.
diff = difference(firstI,
first.valueSparse(p1),
second.valueSparse(p2));
p1++; p2++;
} else if (firstI > secondI) {
// Attribute missing from the sparse list of "first": implicit 0.
diff = difference(secondI,
0, second.valueSparse(p2));
p2++;
} else {
// Attribute missing from the sparse list of "second": implicit 0.
diff = difference(firstI,
first.valueSparse(p1), 0);
p1++;
}
distance += diff * diff;
}
distance = Math.sqrt(distance);
return distance;
}
/**
 * Computes the difference between two given attribute values.
 * Nominal attributes differ by 1 unless both values are present and
 * equal; numeric attributes differ by their normalized gap, with
 * missing values penalized (both missing counts as maximally different).
 *
 * @param index the attribute's index
 * @param val1 the first value
 * @param val2 the second value
 * @return the (possibly signed) difference between the two values
 */
private double difference(int index, double val1, double val2) {
    int attType = m_Train.attribute(index).type();
    boolean missing1 = Instance.isMissingValue(val1);
    boolean missing2 = Instance.isMissingValue(val2);

    if (attType == Attribute.NOMINAL) {
        // Any missing value, or differing categories, counts as a full unit.
        return (missing1 || missing2 || (int) val1 != (int) val2) ? 1 : 0;
    }

    if (attType == Attribute.NUMERIC) {
        if (missing1 && missing2) {
            // Nothing known about either value: assume maximal difference.
            return 1;
        }
        if (missing1 || missing2) {
            // One value known: take the larger of the gaps to either
            // end of the normalized [0, 1] range.
            double diff = norm(missing2 ? val1 : val2, index);
            return (diff < 0.5) ? 1.0 - diff : diff;
        }
        // Both present: signed difference of the normalized values.
        return norm(val1, index) - norm(val2, index);
    }

    // Other attribute types contribute nothing.
    return 0;
}
/**
 * Normalizes a given value of a numeric attribute into [0, 1]
 * using the observed minimum and maximum for that attribute.
 *
 * @param x the value to be normalized
 * @param i the attribute's index
 * @return the normalized value, or 0 if no range is known yet
 */
private double norm(double x, int i) {
    double min = m_Min[i];
    double max = m_Max[i];
    // No observed values yet, or a degenerate range: fall back to 0.
    if (Double.isNaN(min) || Utils.eq(max, min)) {
        return 0;
    }
    return (x - min) / (max - min);
}
/**
 * Updates the minimum and maximum values for all the attributes
 * based on a new instance. A NaN minimum marks an attribute with
 * no observed values yet.
 *
 * @param instance the new instance
 */
private void updateMinMax(Instance instance) {
    for (int j = 0; j < m_Train.numAttributes(); j++) {
        if (instance.isMissing(j)) {
            continue;  // missing values carry no range information
        }
        double value = instance.value(j);
        if (Double.isNaN(m_Min[j])) {
            // First observed value initializes both bounds.
            m_Min[j] = value;
            m_Max[j] = value;
        } else {
            // The two updates are mutually exclusive (min <= max).
            if (value < m_Min[j]) {
                m_Min[j] = value;
            }
            if (value > m_Max[j]) {
                m_Max[j] = value;
            }
        }
    }
}
/**
 * Main method for testing this class.
 *
 * @param argv the options
 */
public static void main(String [] argv) {
    try {
        String evaluation = Evaluation.evaluateModel(new LWL(), argv);
        System.out.println(evaluation);
    } catch (Exception e) {
        e.printStackTrace();
        System.err.println(e.getMessage());
    }
}
}
// (Removed: keyboard-shortcut help text from the code-viewer web page this
// file was copied from — it was not part of the original Java source.)