📄 data.java
字号:
package learner;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.Vector;
import java.util.regex.Pattern;
/**
 * Holds a labelled data set with one numeric feature and one numeric class label per row,
 * and produces training/test partitions of it (e.g. for k-fold cross-validation).
 * <p>
 * When loading, a field of "-" is read as -1 and "+" as +1, so +/- class columns can be used
 * directly as labels.
 */
public class Data {
    /** Number of rows whose label equals -1 in the current training set. */
    int c1 = 0;
    /** Number of rows with any other label in the current training set. */
    int c2 = 0;
    /** Every loaded row, in load order (or the order left by {@link #shuffle()}). */
    Datastructure[] fulldata;
    /** Training partition built by the constructor or by the last {@link #split} call. */
    Datastructure[] training;
    /** Test partition built by the constructor or by the last {@link #split} call. */
    Datastructure[] test;

    /**
     * An optionally signed decimal with at least one digit, e.g. "3", "-0.5", ".25", "7.".
     * Requiring a digit rejects "" and ".", which would otherwise fail to parse.
     */
    private static final Pattern NUMBER = Pattern.compile("[-+]?(\\d+\\.?\\d*|\\.\\d+)");

    /**
     * Loads a feature column and a label column from a comma-separated resource.
     * <p>
     * Both columns are read from the same pass over the file. A row is kept only when BOTH fields
     * are valid, so features and labels can never drift out of alignment. Every kept row starts
     * with the same weight, 1/n. Before any {@link #split}, {@code training} and {@code test} are
     * independent copies of the full data set.
     *
     * @param filename   resource name, resolved with {@link Class#getResourceAsStream(String)}
     * @param dataindex  zero-based column holding the feature (documented choices: 1, 2, 7, 10,
     *                   13, 14)
     * @param labelindex zero-based column holding the class label (documented choice: 15)
     * @throws IllegalArgumentException if either column index is negative
     */
    Data(String filename, int dataindex, int labelindex) {
        List<double[]> rows = readColumns(filename, new int[] {dataindex, labelindex});
        int n = rows.size();
        // uniform initial weights that sum to 1
        double weight = 1.0 / n;
        fulldata = new Datastructure[n];
        for (int i = 0; i < n; i++) {
            double[] row = rows.get(i);
            fulldata[i] = new Datastructure(row[0], row[1], weight);
        }
        training = copyAll(fulldata);
        test = copyAll(fulldata);
        countTrainingLabels();
    }

    /**
     * Rebuilds {@code training} and {@code test} from {@code fulldata} and recounts c1/c2 over
     * the new training set.
     * <p>
     * With {@code folds <= 1}, both partitions are full copies of the data and {@code index} is
     * ignored. Otherwise fold {@code index} is the contiguous slice of
     * {@code fulldata.length / folds} rows starting at {@code index * unit}. That slice becomes the
     * test set, and every other row (including the remainder that does not fill a fold) becomes
     * the training set.
     *
     * @param index zero-based fold to hold out
     * @param folds number of folds
     * @throws IllegalArgumentException if {@code folds > 1} and {@code index} is not in
     *                                  {@code [0, folds)}
     */
    public void split(int index, int folds) {
        if (folds <= 1) {
            training = copyAll(fulldata);
            test = copyAll(fulldata);
        } else {
            if (index < 0 || index >= folds) {
                throw new IllegalArgumentException(
                        "fold index " + index + " is outside [0, " + folds + ")");
            }
            int unit = fulldata.length / folds;
            int start = index * unit;
            int end = start + unit;
            test = new Datastructure[unit];
            training = new Datastructure[fulldata.length - unit];
            int testindex = 0;
            int trainindex = 0;
            for (int i = 0; i < fulldata.length; i++) {
                if (i >= start && i < end) {
                    test[testindex++] = copy(fulldata[i]);
                } else {
                    training[trainindex++] = copy(fulldata[i]);
                }
            }
        }
        countTrainingLabels();
    }

    /**
     * Builds one row per training row whose data is the absolute difference from
     * {@code testdata}, with the training label and weight 0. The result is sorted with
     * Datastructure's natural ordering (presumably ascending distance, which makes this a
     * nearest-neighbour list; verify Datastructure.compareTo).
     *
     * @param testdata the query feature value
     * @return a new array; {@code training} itself is not modified
     */
    public Datastructure[] distance(double testdata) {
        Datastructure[] neighbours = new Datastructure[training.length];
        for (int i = 0; i < training.length; i++) {
            neighbours[i] = new Datastructure(Math.abs(training[i].data - testdata),
                    training[i].label, 0);
        }
        Arrays.sort(neighbours);
        return neighbours;
    }

    /**
     * Sample standard deviation (divides by n - 1).
     *
     * @param data the values
     * @return the sample standard deviation, or NaN when fewer than two values are given
     */
    public static double std(double[] data) {
        int n = data.length;
        if (n < 2) {
            // undefined for 0 or 1 values; an empty array used to yield -0.0 here
            return Double.NaN;
        }
        double mean = mean(data);
        double sum = 0;
        for (double value : data) {
            double deviation = value - mean;
            sum += deviation * deviation;
        }
        return Math.sqrt(sum / (n - 1));
    }

    /**
     * Arithmetic mean.
     *
     * @param data the values
     * @return the mean, or NaN for an empty array
     */
    public static double mean(double[] data) {
        double sum = 0;
        for (double value : data) {
            sum += value;
        }
        return sum / data.length;
    }

    /**
     * Shuffles {@code fulldata} in place. {@code training} and {@code test} are not rebuilt until
     * the next {@link #split} call.
     */
    public void shuffle() {
        Collections.shuffle(Arrays.asList(fulldata));
    }

    /**
     * Shuffles {@code fulldata} in place using the given source of randomness, so experiments can
     * be reproduced with a seeded {@link Random}.
     *
     * @param random the randomness source
     */
    public void shuffle(Random random) {
        Collections.shuffle(Arrays.asList(fulldata), random);
    }

    /**
     * Renders rows as text.
     *
     * @param structure rows to render
     * @param index     0 = data values, 1 = labels, 2 = weights (each followed by "|");
     *                  any other value = one tab-separated "data, label, weight" line per row
     * @return the rendered text
     */
    public String print(Datastructure[] structure, int index) {
        StringBuilder output = new StringBuilder();
        for (Datastructure row : structure) {
            switch (index) {
                case 0:
                    output.append(row.data).append("|");
                    break;
                case 1:
                    output.append(row.label).append("|");
                    break;
                case 2:
                    output.append(row.weight).append("|");
                    break;
                default:
                    output.append(row.data).append("\t").append(row.label).append("\t")
                            .append(row.weight).append("\n");
                    break;
            }
        }
        return output.toString();
    }

    /**
     * Collapses each run of rows that share the same data value into a single row.
     * <p>
     * {@code structure} is sorted in place with Datastructure's natural ordering first. This
     * assumes the ordering is by data value (verify Datastructure.compareTo), so that equal values
     * end up adjacent. Each run becomes one row whose label is -1 when the run has at least as
     * many -1 labels as other labels (otherwise 1), and whose weight is the sum of the run's
     * weights.
     *
     * @param structure rows to merge; reordered by this call
     * @return one row per run, in sorted order
     */
    public Datastructure[] filter(Datastructure[] structure) {
        Arrays.sort(structure);
        List<Datastructure> merged = new ArrayList<Datastructure>();
        int negatives = 0;
        int others = 0;
        double sumofweights = 0;
        for (int i = 0; i < structure.length; i++) {
            if (structure[i].label == -1.0) {
                negatives++;
            } else {
                others++;
            }
            sumofweights += structure[i].weight;
            // A run also ends at the final element, so the last group is emitted too.
            boolean endOfRun = i == structure.length - 1
                    || structure[i].data != structure[i + 1].data;
            if (endOfRun) {
                merged.add(new Datastructure(structure[i].data,
                        negatives >= others ? -1 : 1, sumofweights));
                negatives = 0;
                others = 0;
                sumofweights = 0;
            }
        }
        return merged.toArray(new Datastructure[merged.size()]);
    }

    /**
     * Reads one column of a comma-separated resource. A field of "-" becomes -1 and "+" becomes
     * +1. Any other field must be a decimal number. Rows with a missing or invalid field are
     * skipped with a warning.
     *
     * @param filename  resource name, resolved with {@link Class#getResourceAsStream(String)}
     * @param dataindex zero-based column to read
     * @return the parsed values of the valid rows, in file order; empty if the resource is missing
     * @throws IllegalArgumentException if {@code dataindex} is negative
     */
    public double[] loaddata(String filename, int dataindex) {
        List<double[]> rows = readColumns(filename, new int[] {dataindex});
        double[] output = new double[rows.size()];
        for (int i = 0; i < output.length; i++) {
            output[i] = rows.get(i)[0];
        }
        return output;
    }

    /**
     * Reads the given columns from every line of a comma-separated resource. A line is kept only
     * when all requested columns are present and valid (see {@link #parseField}).
     * <p>
     * A missing resource is reported on stderr and yields an empty list. I/O errors are reported
     * with a stack trace, and the rows read so far are returned. This is the same best-effort
     * contract the loader always had, but the reader is now always closed.
     *
     * @return one {@code double[columns.length]} per kept line, in file order
     */
    private List<double[]> readColumns(String filename, int[] columns) {
        for (int column : columns) {
            if (column < 0) {
                throw new IllegalArgumentException("column index must be >= 0: " + column);
            }
        }
        List<double[]> rows = new ArrayList<double[]>();
        InputStream stream = getClass().getResourceAsStream(filename);
        if (stream == null) {
            System.err.println("Warning: resource not found: " + filename);
            return rows;
        }
        try (BufferedReader reader =
                new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                double[] row = parseRow(line.split(","), columns);
                if (row == null) {
                    System.out.println("Warning: Found wrong data, removed");
                } else {
                    rows.add(row);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return rows;
    }

    /** Parses the requested columns of one split line; returns null if any is missing/invalid. */
    private static double[] parseRow(String[] fields, int[] columns) {
        double[] row = new double[columns.length];
        for (int c = 0; c < columns.length; c++) {
            int column = columns[c];
            Double value = column < fields.length ? parseField(fields[column]) : null;
            if (value == null) {
                return null;
            }
            row[c] = value;
        }
        return row;
    }

    /**
     * Converts one field: "-" gives -1, "+" gives +1, a decimal number is parsed, and anything
     * else (e.g. an empty field or a missing-value marker) gives null.
     */
    private static Double parseField(String field) {
        if (field.equals("-")) {
            return -1.0;
        }
        if (field.equals("+")) {
            return 1.0;
        }
        if (NUMBER.matcher(field).matches()) {
            return Double.parseDouble(field);
        }
        return null;
    }

    /** Returns a fresh row with the same data, label and weight as {@code row}. */
    private static Datastructure copy(Datastructure row) {
        return new Datastructure(row.data, row.label, row.weight);
    }

    /** Returns an array of fresh copies of every row in {@code rows}. */
    private static Datastructure[] copyAll(Datastructure[] rows) {
        Datastructure[] copies = new Datastructure[rows.length];
        for (int i = 0; i < rows.length; i++) {
            copies[i] = copy(rows[i]);
        }
        return copies;
    }

    /** Recomputes c1 (labels equal to -1) and c2 (all other labels) over {@code training}. */
    private void countTrainingLabels() {
        c1 = 0;
        c2 = 0;
        for (Datastructure row : training) {
            if (row.label == -1) {
                c1++;
            } else {
                c2++;
            }
        }
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -