📄 knn.java
字号:
package learner;
/**
 * Binary k-nearest-neighbour classifier over one-dimensional {@code double}
 * samples whose labels are {@code -1} or {@code +1}.
 *
 * <p>Classification walks the neighbour list returned by
 * {@code data.distance(sample)} in order. It returns the first label to collect
 * {@code k} votes. Assuming that list is sorted nearest-first (TODO confirm
 * against {@code Data}), this equals a majority vote over the
 * {@code 2k - 1} nearest neighbours.
 */
public class Knn implements Classifier {
    /** Votes a label needs to win in {@link #classify(double)}; must be at least 1 there. */
    int k;
    /** Source of neighbours ({@code distance}) and train/test splits ({@code split}, {@code test}). */
    Data data;

    /**
     * Creates a classifier over {@code data}.
     *
     * @param data source of neighbours and train/test splits
     * @param k    votes required for a label to win. It is validated when
     *             {@link #classify(double)} runs, not here, so a placeholder
     *             value is acceptable when only {@link #findparameter()} is used.
     */
    Knn(Data data, int k) {
        this.data = data;
        this.k = k;
    }

    /**
     * Classifies every sample in {@code testdata} and reports the accuracy.
     *
     * @param testdata labelled samples to evaluate
     * @return percentage (0-100) of samples whose predicted label equals their
     *         stored label; {@code NaN} if {@code testdata} is empty (0 / 0)
     */
    public double test(Datastructure[] testdata) {
        int good = 0;
        for (Datastructure sample : testdata) {
            if (classify(sample.data) == sample.label) {
                good++;
            }
        }
        return (good * 100.0) / testdata.length;
    }

    /**
     * Predicts the label of a single sample.
     *
     * <p>Neighbours are consumed in the order returned by
     * {@code data.distance(testdata)}, starting with the first one. A neighbour
     * labelled {@code -1} votes for {@code -1}; any other label votes for
     * {@code +1}. The first label to reach {@code k} votes wins. If the neighbour
     * list runs out before either label reaches {@code k}, the label with more
     * votes is returned, and ties go to {@code -1}.
     *
     * @param testdata the sample to classify
     * @return {@code -1} or {@code +1}
     * @throws IllegalStateException if {@code k < 1} or there are no neighbours
     */
    public int classify(double testdata) {
        if (k < 1) {
            // With k <= 0 no label can ever "reach" k votes, so no decision is possible.
            throw new IllegalStateException("k must be at least 1, was " + k);
        }
        Datastructure[] neighbours = data.distance(testdata);
        if (neighbours.length == 0) {
            throw new IllegalStateException("no training samples to compare against");
        }
        int class1 = 0, class2 = 0;
        // Count every neighbour, including the first (index 0).
        for (Datastructure neighbour : neighbours) {
            if (neighbour.label == -1) {
                class1++;
            } else {
                class2++;
            }
            if (class1 == k) {
                return -1;
            }
            if (class2 == k) {
                return +1;
            }
        }
        // Fewer than 2k - 1 neighbours were available: fall back to plain majority
        // instead of indexing past the end of the array.
        return class1 >= class2 ? -1 : +1;
    }

    /**
     * Runs {@code folds}-fold cross-validation.
     *
     * <p>For each fold index {@code i} in {@code [0, folds)} this calls
     * {@code data.split(i, folds)}, which presumably repartitions the shared
     * {@code data} and updates {@code data.test}. It then records
     * {@link #test(Datastructure[])} on {@code data.test}. Callers that need the
     * mean accuracy can average the returned array.
     *
     * @param folds number of folds
     * @return per-fold accuracy percentages, indexed by fold
     */
    public double[] crossvalidate(int folds) {
        double[] results = new double[folds];
        for (int i = 0; i < folds; i++) {
            data.split(i, folds);
            results[i] = test(data.test);
        }
        return results;
    }

    /**
     * Searches for the {@code k} that maximises accuracy on {@code data.test}.
     *
     * <p>Calls {@code data.split(1, 1)} first. NOTE(review): {@link #crossvalidate}
     * uses 0-based fold indices, so confirm what fold index 1 of 1 means in
     * {@code Data.split}. Candidates {@code k = 1 .. data.test.length / 2 - 1}
     * are tried, and ties keep the smallest {@code k}. The winner is printed
     * and returned; the field {@code this.k} is not modified.
     *
     * @return the best {@code k} found, or {@code 0} if no candidate was evaluated
     */
    public double findparameter() {
        int bestK = 0;
        double bestPerformance = Double.NEGATIVE_INFINITY;
        data.split(1, 1);
        // Start at 1: k = 0 can never be satisfied by classify().
        for (int candidate = 1; candidate < data.test.length / 2; candidate++) {
            double performance = new Knn(data, candidate).test(data.test);
            if (performance > bestPerformance) {
                bestPerformance = performance;
                bestK = candidate;
            }
        }
        System.out.println("Best parameter found: " + bestK);
        return bestK;
    }

    /**
     * Returns the data this classifier reads from.
     *
     * @return the {@code Data} instance passed to the constructor
     */
    public Data getdata() {
        return this.data;
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -