// lwl.java
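
// Fragment of the LWL (Locally Weighted Learning) classifier. Judging from the
// final return statement, this is the tail of distributionForInstance(Instance):
// each training instance is weighted by a kernel of its distance to the test
// instance, and the base classifier is rebuilt on the locally weighted copy.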
    for (int i = 0, insCount = 0; i < m_Train.numInstances(); i++) {
      switch (m_WeightKernel) {
        case LINEAR:
        case EPANECHNIKOV:
        case TRICUBE:
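          // For kernels that are zero beyond the bandwidth, only the k nearest
          // neighbours matter: keep the k smallest distances in a bounded
          // max-heap and use the current maximum as an early-abort cutoff for
          // the remaining distance computations.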
          if (insCount < k) {
            distance[i] = distance(instance, m_Train.instance(i));
            h.put(i, distance[i]);
          }
          else {
            MyHeapElement temp = h.peek();
            distance[i] = distance(instance, m_Train.instance(i),
                                   temp.distance);
            if (distance[i] < temp.distance) {
              h.get();
              h.put(i, distance[i]);
            }
          }
          break;
        default:
          distance[i] = distance(instance, m_Train.instance(i));
          break;
      }
      insCount++;
    }
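
    // Utils.sort returns an index array: sortKey[0] is the index of the
    // nearest training instance, sortKey[k-1] that of the k-th nearest.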
    int[] sortKey = Utils.sort(distance);

    if (m_Debug) {
      System.out.println("Instance Distances");
      for (int i = 0; i < sortKey.length; i++) {
        System.out.println("" + distance[sortKey[i]]);
      }
    }

    // Determine the bandwidth
    double bandwidth = distance[sortKey[k - 1]];

    // Check for bandwidth zero
    if (bandwidth <= 0) {
      for (int i = k; i < sortKey.length; i++) {
        if (distance[sortKey[i]] > bandwidth) {
          bandwidth = distance[sortKey[i]];
          break;
        }
      }
      if (bandwidth <= 0) {
        throw new Exception("All training instances coincide with test " +
                            "instance!");
      }
    }

    // Rescale the distances by the bandwidth
    for (int i = 0; i < distance.length; i++) {
      distance[i] = distance[i] / bandwidth;
    }

    // Pass the distances through a weighting kernel
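    // (the 1.0001 constant keeps the weight of the k-th nearest neighbour,
    // whose rescaled distance is exactly 1, strictly above zero)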
    for (int i = 0; i < distance.length; i++) {
      switch (m_WeightKernel) {
        case LINEAR:
          distance[i] = Math.max(1.0001 - distance[i], 0);
          break;
        case EPANECHNIKOV:
          if (distance[i] <= 1)
            distance[i] = 3 / 4D * (1.0001 - distance[i] * distance[i]);
          else
            distance[i] = 0;
          break;
        case TRICUBE:
          if (distance[i] <= 1)
            distance[i] = Math.pow((1.0001 - Math.pow(distance[i], 3)), 3);
          else
            distance[i] = 0;
          break;
        case CONSTANT:
          //System.err.println("using constant kernel");
          if (distance[i] <= 1)
            distance[i] = 1;
          else
            distance[i] = 0;
          break;
        case INVERSE:
          distance[i] = 1.0 / (1.0 + distance[i]);
          break;
        case GAUSS:
          distance[i] = Math.exp(-distance[i] * distance[i]);
          break;
      }
    }

    if (m_Debug) {
      System.out.println("Instance Weights");
      for (int i = 0; i < sortKey.length; i++) {
        System.out.println("" + distance[sortKey[i]]);
      }
    }

    // Set the weights on a copy of the training data
    Instances weightedTrain = new Instances(m_Train, 0);
    double sumOfWeights = 0, newSumOfWeights = 0;
    for (int i = 0; i < sortKey.length; i++) {
      double weight = distance[sortKey[i]];
      if (weight < 1e-20) {
        break;
      }
      Instance newInst = (Instance) m_Train.instance(sortKey[i]).copy();
      sumOfWeights += newInst.weight();
      newSumOfWeights += newInst.weight() * weight;
      newInst.setWeight(newInst.weight() * weight);
      weightedTrain.add(newInst);
    }

    if (m_Debug) {
      System.out.println("Kept " + weightedTrain.numInstances() + " out of "
                         + m_Train.numInstances() + " instances");
    }

    // Rescale weights
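    // (so that the total weight of the retained instances equals the sum of
    // their original weights)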
    for (int i = 0; i < weightedTrain.numInstances(); i++) {
      Instance newInst = weightedTrain.instance(i);
      newInst.setWeight(newInst.weight() * sumOfWeights / newSumOfWeights);
    }

    // Create a weighted classifier
    m_Classifier.buildClassifier(weightedTrain);

    if (m_Debug) {
      System.out.println("Classifying test instance: " + instance);
      System.out.println("Built base classifier:\n"
                         + m_Classifier.toString());
    }

    // Return the classifier's predictions
    return m_Classifier.distributionForInstance(instance);
  }

  /**
   * Returns a description of this classifier.
   *
   * @return a description of this classifier as a string.
   */
  public String toString() {
    if (m_Train == null) {
      return "Locally weighted learning: No model built yet.";
    }
    String result = "Locally weighted learning\n"
        + "===========================\n";
result += "Using classifier: " + m_Classifier.getClass().getName() + "\n";
switch (m_WeightKernel) {
case LINEAR:
result += "Using linear weighting kernels\n";
break;
case EPANECHNIKOV:
result += "Using epanechnikov weighting kernels\n";
break;
case TRICUBE:
result += "Using tricube weighting kernels\n";
break;
case INVERSE:
result += "Using inverse-distance weighting kernels\n";
break;
case GAUSS:
result += "Using gaussian weighting kernels\n";
break;
case CONSTANT:
result += "Using constant weighting kernels\n";
break;
}
result += "Using " + (m_UseAllK ? "all" : "" + m_kNN) + " neighbours";
return result;
}

  /**
   * Calculates the distance between two instances.
   *
   * @param first the first instance
   * @param second the second instance
   * @return the Euclidean distance between the two given instances
   * @throws Exception if the distance cannot be computed
   */
  private double distance(Instance first, Instance second) throws Exception {
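    // The cutoff is squared inside euclideanDistance(), so pass
    // sqrt(Double.MAX_VALUE) to avoid overflow when no real cutoff is wanted.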
    return distance(first, second, Math.sqrt(Double.MAX_VALUE));
  }

  /**
   * Calculates the distance between two instances.
   *
   * @param first the first instance
   * @param second the second instance
   * @param cutOffValue if the distance is going to exceed this value, the
   *          remaining calculations are skipped and Double.MAX_VALUE is
   *          returned
   * @return the Euclidean distance between the two given instances
   * @throws Exception if the distance cannot be computed
   */
  private double distance(Instance first, Instance second, double cutOffValue)
      throws Exception {
    return euclideanDistance(first, second, cutOffValue);
  }

  /**
   * Calculates the Euclidean distance between two instances.
   *
   * @param first the first instance
   * @param second the second instance
   * @param cutOffValue if the distance is going to exceed this value, the
   *          remaining calculations are skipped and Double.MAX_VALUE is
   *          returned
   * @return the Euclidean distance between the two given instances
   */
  private double euclideanDistance(Instance first, Instance second,
                                   double cutOffValue) {
    double distance = 0;
    int firstI, secondI;
    cutOffValue = cutOffValue * cutOffValue;
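    // Walk the attribute lists of both instances in parallel. For sparse
    // instances, attributes that are not stored are treated as having value 0;
    // m_Train.numAttributes() acts as an "exhausted" sentinel index.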
    for (int p1 = 0, p2 = 0;
         p1 < first.numValues() || p2 < second.numValues();) {
      if (p1 >= first.numValues()) {
        firstI = m_Train.numAttributes();
      } else {
        firstI = first.index(p1);
      }
      if (p2 >= second.numValues()) {
        secondI = m_Train.numAttributes();
      } else {
        secondI = second.index(p2);
      }
      if (firstI == m_Train.classIndex()) {
        p1++; continue;
      }
      if (secondI == m_Train.classIndex()) {
        p2++; continue;
      }
      double diff;
      if (firstI == secondI) {
        diff = difference(firstI,
                          first.valueSparse(p1),
                          second.valueSparse(p2));
        p1++; p2++;
      } else if (firstI > secondI) {
        diff = difference(secondI,
                          0, second.valueSparse(p2));
        p2++;
      } else {
        diff = difference(firstI,
                          first.valueSparse(p1), 0);
        p1++;
      }
      distance += diff * diff;
      if (distance > cutOffValue)
        return Double.MAX_VALUE; //distance;
    }
    distance = Math.sqrt(distance);
    return distance;
  }

  /**
   * Computes the difference between two given attribute values.
   *
   * @param index the index of the attribute
   * @param val1 the first value
   * @param val2 the second value
   * @return the (possibly normalized) difference between the two values
   */
  private double difference(int index, double val1, double val2) {
    switch (m_Train.attribute(index).type()) {
      case Attribute.NOMINAL:
        // If attribute is nominal
        if (Instance.isMissingValue(val1) ||
            Instance.isMissingValue(val2) ||
            ((int) val1 != (int) val2)) {
          return 1;
        } else {
          return 0;
        }
      case Attribute.NUMERIC:
        // If attribute is numeric
        if (Instance.isMissingValue(val1) ||
            Instance.isMissingValue(val2)) {
          if (Instance.isMissingValue(val1) &&
              Instance.isMissingValue(val2)) {
            if (m_NoAttribNorm == false) // we are doing normalization
              return 1;
            else
              return (m_Max[index] - m_Min[index]);
          } else {
            double diff;
            if (Instance.isMissingValue(val2)) {
              diff = (m_NoAttribNorm == false) ? norm(val1, index) : val1;
            } else {
              diff = (m_NoAttribNorm == false) ? norm(val2, index) : val2;
            }
            if (m_NoAttribNorm == false && diff < 0.5) {
              diff = 1.0 - diff;
            } else if (m_NoAttribNorm == true) {
              if ((m_Max[index] - diff) > (diff - m_Min[index]))
                return m_Max[index] - diff;
              else
                return diff - m_Min[index];
            }
            return diff;
          }
        } else {
          return (m_NoAttribNorm == false) ?
              (norm(val1, index) - norm(val2, index)) :
              (val1 - val2);
        }
      default:
        return 0;
    }
  }

  /**
   * Normalizes a given value of a numeric attribute.
   *
   * @param x the value to be normalized
   * @param i the attribute's index
   * @return the normalized value, or 0 if the attribute's range is unknown
   *         or zero
   */
  private double norm(double x, int i) {
    if (Double.isNaN(m_Min[i]) || Utils.eq(m_Max[i], m_Min[i])) {
      return 0;
    } else {
      return (x - m_Min[i]) / (m_Max[i] - m_Min[i]);
    }
  }

  /**
   * Updates the minimum and maximum values for all the attributes
   * based on a new instance.
   *
   * @param instance the new instance
   */
  private void updateMinMax(Instance instance) {
    for (int j = 0; j < m_Train.numAttributes(); j++) {
      if (!instance.isMissing(j)) {
        if (Double.isNaN(m_Min[j])) {
          m_Min[j] = instance.value(j);
          m_Max[j] = instance.value(j);
        } else if (instance.value(j) < m_Min[j]) {
          m_Min[j] = instance.value(j);
        } else if (instance.value(j) > m_Max[j]) {
          m_Max[j] = instance.value(j);
        }
      }
    }
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String[] argv) {
    try {
      System.out.println(Evaluation.evaluateModel(new LWL(), argv));
    } catch (Exception e) {
      e.printStackTrace();
      System.err.println(e.getMessage());
    }
  }

  private class MyHeap {
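    // Bounded max-heap over distances: the root (m_heap[1]) always holds the
    // largest stored distance, so it can be compared against a new candidate
    // and replaced whenever a smaller distance is found.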
    // m_heap[0].index contains the current size of the heap.
    // m_heap[0].distance is unused.
    MyHeapElement m_heap[] = null;

    public MyHeap(int maxSize) {
      if ((maxSize % 2) == 0)
        maxSize++;
      m_heap = new MyHeapElement[maxSize + 1];
      m_heap[0] = new MyHeapElement(0, 0);
      //System.err.println("m_heap size is: "+m_heap.length);
    }

    public int size() {
      return m_heap[0].index;
    }

    public MyHeapElement peek() {
      return m_heap[1];
    }

    public MyHeapElement get() throws Exception {
      if (m_heap[0].index == 0)
        throw new Exception("No elements present in the heap");
      MyHeapElement r = m_heap[1];
      m_heap[1] = m_heap[m_heap[0].index];
      m_heap[0].index--;
      downheap();
      return r;
    }

    public void put(int i, double d) throws Exception {
      if ((m_heap[0].index + 1) > (m_heap.length - 1))
        throw new Exception("the number of elements cannot exceed the " +
                            "initially set maximum limit");
      m_heap[0].index++;
      m_heap[m_heap[0].index] = new MyHeapElement(i, d);
      //m_heap[m_heap[0].index].index = i;
      //m_heap[m_heap[0].index].distance = d;
      //System.err.print("new size: "+(int)m_heap[0]+", ");
      upheap();
    }
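
    // Sift the most recently added element up until its parent's distance is
    // no smaller than its own.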
    private void upheap() {
      int i = m_heap[0].index;
      MyHeapElement temp;
      while (i > 1 && m_heap[i].distance > m_heap[i / 2].distance) {
        temp = m_heap[i];
        m_heap[i] = m_heap[i / 2];
        i = i / 2;
        m_heap[i] = temp; // this is i/2, done here to avoid another division
      }
    }
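
    // Restore the heap property from the root down after the maximum has been
    // removed, always swapping with the larger child.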
    private void downheap() {
      int i = 1;
      MyHeapElement temp;
      // Guard the right-child comparison so that a slot beyond the current
      // heap size is never read.
      while ((2 * i) <= m_heap[0].index &&
             (m_heap[i].distance < m_heap[2 * i].distance ||
              ((2 * i + 1) <= m_heap[0].index &&
               m_heap[i].distance < m_heap[2 * i + 1].distance))) {
        if ((2 * i + 1) <= m_heap[0].index) {
          if (m_heap[2 * i].distance > m_heap[2 * i + 1].distance) {
            temp = m_heap[i];
            m_heap[i] = m_heap[2 * i];
            i = 2 * i;
            m_heap[i] = temp;
          }
          else {
            temp = m_heap[i];
            m_heap[i] = m_heap[2 * i + 1];
            i = 2 * i + 1;
            m_heap[i] = temp;
          }
        }
        else {
          temp = m_heap[i];
          m_heap[i] = m_heap[2 * i];
          i = 2 * i;
          m_heap[i] = temp;
        }
      }
    }
  }

  private class MyHeapElement {
    int index;
    double distance;

    public MyHeapElement(int i, double d) {
      distance = d;
      index = i;
    }
  }
}
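
// Example use (a sketch, not taken from the original source): LWL is normally
// run through Weka's Evaluation framework, e.g. something along the lines of
//   java weka.classifiers.lazy.LWL -t train.arff -W weka.classifiers.trees.DecisionStump
// The package names and option flags shown here are assumptions and may differ
// between Weka versions.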