// instances.java
// (code-viewer page residue) Font size:
public final void renameAttributeValue(int att, int val, String name) {
  // The rename is applied to a copy obtained from Attribute.copy(), and the
  // copy is what ends up in this dataset's header.
  Attribute renamed = (Attribute) attribute(att).copy();
  FastVector rebuiltHeader = new FastVector(numAttributes());
  renamed.setValue(val, name);

  // Rebuild the attribute list, putting the renamed copy at position att
  // and reusing every other attribute object as it is.
  int position = 0;
  while (position < numAttributes()) {
    rebuiltHeader.addElement(position == att ? renamed : attribute(position));
    position++;
  }
  m_Attributes = rebuiltHeader;
}
/**
 * Renames one value of a nominal (or string) attribute, identifying the
 * value by its current name. This change only affects this dataset.
 *
 * @param att the attribute whose value is renamed
 * @param val the current name of the value
 * @param name the new name for the value
 * @exception IllegalArgumentException if the attribute reports no index
 * for val (indexOfValue returns -1)
 */
public final void renameAttributeValue(Attribute att, String val,
                                       String name) {
  final int valueIndex = att.indexOfValue(val);
  if (valueIndex != -1) {
    // Delegate to the index-based variant.
    renameAttributeValue(att.index(), valueIndex, name);
    return;
  }
  throw new IllegalArgumentException(val + " not found");
}
/**
 * Builds a new dataset with as many instances as this one by repeatedly
 * picking an instance at a uniformly random position (sampling with
 * replacement).
 *
 * @param random a random number generator
 * @return the new dataset
 */
public final Instances resample(Random random) {
  final int total = numInstances();
  Instances sample = new Instances(this, total);
  // Keep drawing until the sample holds as many instances as this dataset.
  for (int drawn = sample.numInstances(); drawn < total;
       drawn = sample.numInstances()) {
    sample.add(instance(random.nextInt(total)));
  }
  return sample;
}
/**
 * Creates a new dataset of the same size using random sampling
 * with replacement, where the sampling weights are the current weights
 * of the instances. The weights of the instances in the new dataset are
 * set to one.
 *
 * @param random a random number generator
 * @return the new dataset
 */
public final Instances resampleWithWeights(Random random) {
  final int count = numInstances();
  double[] currentWeights = new double[count];
  // Collect every instance's weight, then delegate to the array variant.
  int idx = 0;
  while (idx < count) {
    currentWeights[idx] = instance(idx).weight();
    idx++;
  }
  return resampleWithWeights(random, currentWeights);
}
/**
 * Creates a new dataset of the same size using random sampling
 * with replacement according to the given weight vector. The
 * weights of the instances in the new dataset are set to one.
 * The length of the weight vector has to be the same as the
 * number of instances in the dataset, and no weight may be negative.
 * The complete weight vector is validated before any sampling is done.
 *
 * @param random a random number generator
 * @param weights the weight vector
 * @return the new dataset
 * @exception IllegalArgumentException if the weights array is of the wrong
 * length or contains negative weights.
 */
public final Instances resampleWithWeights(Random random,
                                           double[] weights) {
  if (weights.length != numInstances()) {
    throw new IllegalArgumentException("weights.length != numInstances.");
  }
  Instances newData = new Instances(this, numInstances());
  if (numInstances() == 0) {
    return newData;
  }
  // Reject negative weights before doing anything else. Checking them
  // lazily inside the sampling loop would miss entries that the loop never
  // reaches, and a negative entry would already have distorted the weight
  // total used to scale the thresholds below.
  for (int i = 0; i < weights.length; i++) {
    if (weights[i] < 0) {
      throw new IllegalArgumentException("Weights have to be positive.");
    }
  }
  // NOTE(review): a weight vector that sums to zero is not rejected here;
  // what happens then depends on Utils.normalize -- confirm.
  double[] probabilities = new double[numInstances()];
  double sumProbs = 0, sumOfWeights = Utils.sum(weights);
  // Running sums of random draws give an increasing sequence of thresholds.
  for (int i = 0; i < numInstances(); i++) {
    sumProbs += random.nextDouble();
    probabilities[i] = sumProbs;
  }
  // Presumably divides every threshold by the second argument, which would
  // rescale the sequence so that it ends at sumOfWeights -- verify against
  // Utils.normalize.
  Utils.normalize(probabilities, sumProbs / sumOfWeights);
  // Make sure that rounding errors don't mess things up
  probabilities[numInstances() - 1] = sumOfWeights;
  int k = 0; // index of the next threshold / next slot in newData
  int l = 0; // index of the source instance currently being considered
  sumProbs = 0; // reused as the running sum of weights[0..l]
  while ((k < numInstances()) && (l < numInstances())) {
    sumProbs += weights[l];
    // Instance l is copied once for every threshold that falls at or below
    // the cumulative weight reached so far.
    while ((k < numInstances()) &&
           (probabilities[k] <= sumProbs)) {
      newData.add(instance(l));
      newData.instance(k).setWeight(1);
      k++;
    }
    l++;
  }
  return newData;
}
/**
 * Makes the given attribute the class attribute by storing the index
 * that the attribute reports for itself. No range check is applied here.
 *
 * @param att attribute to be the class
 */
public final void setClass(Attribute att) {
  final int indexOfClass = att.index();
  m_ClassIndex = indexOfClass;
}
/**
 * Sets the class index of the set.
 * If the class index is negative there is assumed to be no class
 * (i.e. it is undefined); negative values are therefore accepted.
 *
 * @param classIndex the new class index
 * @exception IllegalArgumentException if the class index is not smaller
 * than the number of attributes
 */
public final void setClassIndex(int classIndex) {
  final boolean withinHeader = classIndex < numAttributes();
  if (!withinHeader) {
    throw new IllegalArgumentException("Invalid class index: " + classIndex);
  }
  m_ClassIndex = classIndex;
}
//<<18/03/2005, Frank J. Xu
//Sets the class index of scoring dataset without target attribute.
/**
 * Sets the class index for a scoring dataset, i.e. a dataset that does
 * not contain the target attribute. Unlike setClassIndex, an index equal
 * to the number of attributes is accepted here. Negative indices are
 * accepted as well.
 *
 * @param classIndex the new class index
 * @exception IllegalArgumentException if the class index is greater than
 * the number of attributes
 */
public final void setScoringDataClassIndex(int classIndex) {
  // Compare directly instead of testing (classIndex - 1) >= numAttributes():
  // the subtraction wraps around for Integer.MIN_VALUE and would reject a
  // negative index that is otherwise accepted.
  if (classIndex > numAttributes()) {
    throw new IllegalArgumentException("Invalid class index: " + classIndex);
  }
  m_ClassIndex = classIndex;
}
//18/03/2005, Frank J. Xu>>
/**
 * Replaces the name of the relation stored in this dataset.
 *
 * @param newName the new relation name.
 */
public final void setRelationName(String newName) {
  this.m_RelationName = newName;
}
/**
 * Sorts the instances based on an attribute. For numeric attributes,
 * instances are sorted in ascending order. For nominal attributes,
 * instances are sorted based on the attribute label ordering
 * specified in the header. Instances with missing values for the
 * attribute are placed at the end of the dataset.
 *
 * @param attIndex the attribute's index
 */
public final void sort(int attIndex) {
  // Phase 1: push every instance with a missing value to the tail.
  // "front" scans forward, "back" marks the last position not yet known
  // to hold a missing value.
  int front = 0;
  int back = numInstances() - 1;
  while (front <= back) {
    if (instance(back).isMissing(attIndex)) {
      // Already in the tail region; just shrink the active range.
      back--;
      continue;
    }
    if (instance(front).isMissing(attIndex)) {
      // Exchange with the non-missing instance found at the back.
      swap(front, back);
      back--;
    }
    front++;
  }
  // Phase 2: sort only the leading part, which has no missing values.
  quickSort(attIndex, 0, back);
}
/**
 * Sorts the instances based on an attribute by delegating to the
 * index-based sort with the attribute's own index. For numeric attributes,
 * instances are sorted into ascending order. For nominal attributes,
 * instances are sorted based on the attribute label ordering
 * specified in the header. Instances with missing values for the
 * attribute are placed at the end of the dataset.
 *
 * @param att the attribute
 */
public final void sort(Attribute att) {
  final int attIndex = att.index();
  sort(attIndex);
}
/**
 * Stratifies a set of instances according to its class values
 * if the class attribute is nominal (so that afterwards a
 * stratified cross-validation can be performed). If the class attribute
 * is not nominal the dataset is left unchanged.
 *
 * @param numFolds the number of folds in the cross-validation; must be
 * greater than 0
 * @exception IllegalArgumentException if numFolds is zero or negative
 * @exception UnassignedClassException if the class is not set
 */
public final void stratify(int numFolds) {
  if (numFolds <= 0) {
    // The message states the bound that is actually enforced here.
    throw new IllegalArgumentException("Number of folds must be greater than 0");
  }
  if (m_ClassIndex < 0) {
    throw new UnassignedClassException("Class index is negative (not set)!");
  }
  if (classAttribute().isNominal()) {
    // Group instances by class: "index" is the position that the next
    // instance of the current class is moved to.
    int index = 1;
    while (index < numInstances()) {
      // Representative of the class currently being collected.
      Instance instance1 = instance(index - 1);
      for (int j = index; j < numInstances(); j++) {
        Instance instance2 = instance(j);
        // Missing class values are tested separately and grouped together.
        if ((instance1.classValue() == instance2.classValue()) ||
            (instance1.classIsMissing() &&
             instance2.classIsMissing())) {
          swap(index, j);
          index++;
        }
      }
      // Step past the first instance of the next class.
      index++;
    }
    stratStep(numFolds);
  }
}
/**
 * Adds up the weights of all instances, in dataset order.
 *
 * @return the sum of all the instances' weights as a double
 */
public final double sumOfWeights() {
  double total = 0;
  int position = 0;
  while (position < numInstances()) {
    total += instance(position).weight();
    position++;
  }
  return total;
}
/**
 * Creates the test set for one fold of a cross-validation on
 * the dataset. The instances are split into numFolds consecutive
 * ranges; the first (numInstances % numFolds) folds hold one instance
 * more than the others.
 *
 * @param numFolds the number of folds in the cross-validation. Must
 * be greater than 1.
 * @param numFold 0 for the first fold, 1 for the second, ...
 * @return the test set as a set of weighted instances
 * @exception IllegalArgumentException if the number of folds is less than 2
 * or greater than the number of instances.
 */
public Instances testCV(int numFolds, int numFold) {
  if (numFolds < 2) {
    throw new IllegalArgumentException("Number of folds must be at least 2!");
  }
  if (numFolds > numInstances()) {
    throw new IllegalArgumentException("Can't have more folds than instances!");
  }
  final int total = numInstances();
  final int baseSize = total / numFolds;
  final int leftover = total % numFolds;
  // Folds with a number below "leftover" take one of the surplus instances.
  final boolean takesExtra = numFold < leftover;
  final int foldSize = takesExtra ? baseSize + 1 : baseSize;
  // Number of enlarged folds located before this one.
  final int offset = takesExtra ? numFold : leftover;
  final int first = numFold * baseSize + offset;

  Instances test = new Instances(this, foldSize);
  copyInstances(first, test, foldSize);
  return test;
}
/**
 * Returns the dataset as a string in ARFF format: the relation line
 * (with the relation name passed through Utils.quote), a blank line,
 * one line per attribute, a blank line, the data marker, and then one
 * line per instance. No newline follows the last instance.
 *
 * @return the dataset in ARFF format as a string
 */
public final String toString() {
  StringBuffer out = new StringBuffer();

  // Header: relation name followed by the attribute declarations.
  out.append(ARFF_RELATION);
  out.append(" ");
  out.append(Utils.quote(m_RelationName));
  out.append("\n\n");
  int a = 0;
  while (a < numAttributes()) {
    out.append(attribute(a));
    out.append("\n");
    a++;
  }

  // Data section: instances separated (not terminated) by newlines.
  out.append("\n");
  out.append(ARFF_DATA);
  out.append("\n");
  for (int row = 0; row < numInstances(); row++) {
    if (row > 0) {
      out.append('\n');
    }
    out.append(instance(row));
  }
  return out.toString();
}
/**
 * Creates the training set for one fold of a cross-validation
 * on the dataset: all instances before and after the range that forms
 * the fold's test set, in their original order. This variant does not
 * randomize the result.
 *
 * @param numFolds the number of folds in the cross-validation. Must
 * be greater than 1.
 * @param numFold 0 for the first fold, 1 for the second, ...
 * @return the training set
 * @exception IllegalArgumentException if the number of folds is less than 2
 * or greater than the number of instances.
 */
public Instances trainCV(int numFolds, int numFold) {
  if (numFolds < 2) {
    throw new IllegalArgumentException("Number of folds must be at least 2!");
  }
  if (numFolds > numInstances()) {
    throw new IllegalArgumentException("Can't have more folds than instances!");
  }
  final int total = numInstances();
  final int baseSize = total / numFolds;
  final int leftover = total % numFolds;
  // Folds with a number below "leftover" hold one extra instance.
  final boolean takesExtra = numFold < leftover;
  final int foldSize = takesExtra ? baseSize + 1 : baseSize;
  final int offset = takesExtra ? numFold : leftover;
  // Position of the first instance that belongs to the held-out fold.
  final int first = numFold * baseSize + offset;

  Instances train = new Instances(this, total - foldSize);
  // Everything in front of the held-out range ...
  copyInstances(0, train, first);
  // ... followed by everything behind it.
  final int afterFold = first + foldSize;
  copyInstances(afterFold, train, total - afterFold);
  return train;
}
/**
 * Creates the training set for one fold of a cross-validation
 * on the dataset and then randomizes it using the given random
 * number generator.
 *
 * @param numFolds the number of folds in the cross-validation. Must
 * be greater than 1.
 * @param numFold 0 for the first fold, 1 for the second, ...
 * @param random the random number generator
 * @return the training set
 * @exception IllegalArgumentException if the number of folds is less than 2
 * or greater than the number of instances.
 */
public Instances trainCV(int numFolds, int numFold, Random random) {
  Instances shuffled = trainCV(numFolds, numFold);
  shuffled.randomize(random);
  return shuffled;
}
/**
 * Computes the weighted sample variance of a numeric attribute.
 * Instances whose value for the attribute is missing are skipped. The
 * result is 0 when the total weight of the remaining instances is at
 * most 1, or when the computed value comes out negative.
 *
 * @param attIndex the numeric attribute
 * @return the variance if the attribute is numeric
 * @exception IllegalArgumentException if the attribute is not numeric
 */
public final double variance(int attIndex) {
  if (!attribute(attIndex).isNumeric()) {
    throw new IllegalArgumentException("Can't compute variance because attribute is " +
                                       "not numeric!");
  }

  double weightedSum = 0;
  double weightedSquares = 0;
  double totalWeight = 0;
  for (int i = 0; i < numInstances(); i++) {
    Instance current = instance(i);
    if (current.isMissing(attIndex)) {
      continue;
    }
    double w = current.weight();
    double v = current.value(attIndex);
    weightedSum += w * v;
    weightedSquares += w * v * v;
    totalWeight += w;
  }

  if (totalWeight <= 1) {
    return 0;
  }
  double estimate = (weightedSquares - (weightedSum * weightedSum / totalWeight)) /
    (totalWeight - 1);
  // We don't like negative variance
  return (estimate < 0) ? 0 : estimate;
}
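// (code-viewer page residue, not part of the source)
// Keyboard shortcuts:
//   Copy code            Ctrl + C
//   Search code          Ctrl + F
//   Full-screen mode     F11
//   Toggle theme         Ctrl + Shift + D
//   Show shortcuts       ?
//   Increase font size   Ctrl + =
//   Decrease font size   Ctrl + -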