📄 instances.java
字号:
int v = att.indexOfValue(val); if (v == -1) throw new IllegalArgumentException(val + " not found"); renameAttributeValue(att.index(), v, name); } /** * Creates a new dataset of the same size using random sampling * with replacement. * * @param random a random number generator * @return the new dataset */ public final Instances resample(Random random) { Instances newData = new Instances(this, numInstances()); while (newData.numInstances() < numInstances()) { newData.add(instance(random.nextInt(numInstances()))); } return newData; } /** * Creates a new dataset of the same size using random sampling * with replacement according to the current instance weights. The * weights of the instances in the new dataset are set to one. * * @param random a random number generator * @return the new dataset */ public final Instances resampleWithWeights(Random random) { double [] weights = new double[numInstances()]; for (int i = 0; i < weights.length; i++) { weights[i] = instance(i).weight(); } return resampleWithWeights(random, weights); } /** * Creates a new dataset of the same size using random sampling * with replacement according to the given weight vector. The * weights of the instances in the new dataset are set to one. * The length of the weight vector has to be the same as the * number of instances in the dataset, and all weights have to * be positive. * * @param random a random number generator * @param weights the weight vector * @return the new dataset * @exception IllegalArgumentException if the weights array is of the wrong * length or contains negative weights. */ public final Instances resampleWithWeights(Random random, double[] weights) { if (weights.length != numInstances()) { throw new IllegalArgumentException("weights.length != numInstances."); } Instances newData = new Instances(this, numInstances()); double[] probabilities = new double[numInstances()]; double sumProbs = 0, sumOfWeights = Utils.sum(weights); for (int i = 0; i < numInstances(); i++) { sumProbs += random.nextDouble(); probabilities[i] = sumProbs; } Utils.normalize(probabilities, sumProbs / sumOfWeights); // Make sure that rounding errors don't mess things up probabilities[numInstances() - 1] = sumOfWeights; int k = 0; int l = 0; sumProbs = 0; while ((k < numInstances() && (l < numInstances()))) { if (weights[l] < 0) { throw new IllegalArgumentException("Weights have to be positive."); } sumProbs += weights[l]; while ((k < numInstances()) && (probabilities[k] <= sumProbs)) { newData.add(instance(l)); newData.instance(k).setWeight(1); k++; } l++; } return newData; } /** * Sets the class attribute. * * @param att attribute to be the class */ public final void setClass(Attribute att) { m_ClassIndex = att.index(); } /** * Sets the class index of the set. * If the class index is negative there is assumed to be no class. * (ie. it is undefined) * * @param classIndex the new class index * @exception IllegalArgumentException if the class index is too big or < 0 */ public final void setClassIndex(int classIndex) { if (classIndex >= numAttributes()) { throw new IllegalArgumentException("Invalid class index: " + classIndex); } m_ClassIndex = classIndex; } /** * Sets the relation's name. * * @param newName the new relation name. */ public final void setRelationName(String newName) { m_RelationName = newName; } /** * Sorts the instances based on an attribute. For numeric attributes, * instances are sorted in ascending order. For nominal attributes, * instances are sorted based on the attribute label ordering * specified in the header. Instances with missing values for the * attribute are placed at the end of the dataset. * * @param attIndex the attribute's index */ public final void sort(int attIndex) { int i,j; // move all instances with missing values to end j = numInstances() - 1; i = 0; while (i <= j) { if (instance(j).isMissing(attIndex)) { j--; } else { if (instance(i).isMissing(attIndex)) { swap(i,j); j--; } i++; } } quickSort(attIndex, 0, j); } /** * Sorts the instances based on an attribute. For numeric attributes, * instances are sorted into ascending order. For nominal attributes, * instances are sorted based on the attribute label ordering * specified in the header. Instances with missing values for the * attribute are placed at the end of the dataset. * * @param att the attribute */ public final void sort(Attribute att) { sort(att.index()); } /** * Stratifies a set of instances according to its class values * if the class attribute is nominal (so that afterwards a * stratified cross-validation can be performed). * * @param numFolds the number of folds in the cross-validation * @exception UnassignedClassException if the class is not set */ public final void stratify(int numFolds) { if (numFolds <= 0) { throw new IllegalArgumentException("Number of folds must be greater than 1"); } if (m_ClassIndex < 0) { throw new UnassignedClassException("Class index is negative (not set)!"); } if (classAttribute().isNominal()) { // sort by class int index = 1; while (index < numInstances()) { Instance instance1 = instance(index - 1); for (int j = index; j < numInstances(); j++) { Instance instance2 = instance(j); if ((instance1.classValue() == instance2.classValue()) || (instance1.classIsMissing() && instance2.classIsMissing())) { swap(index,j); index++; } } index++; } stratStep(numFolds); } } /** * Computes the sum of all the instances' weights. * * @return the sum of all the instances' weights as a double */ public final double sumOfWeights() { double sum = 0; for (int i = 0; i < numInstances(); i++) { sum += instance(i).weight(); } return sum; } /** * Creates the test set for one fold of a cross-validation on * the dataset. * * @param numFolds the number of folds in the cross-validation. Must * be greater than 1. * @param numFold 0 for the first fold, 1 for the second, ... * @return the test set as a set of weighted instances * @exception IllegalArgumentException if the number of folds is less than 2 * or greater than the number of instances. */ public Instances testCV(int numFolds, int numFold) { int numInstForFold, first, offset; Instances test; if (numFolds < 2) { throw new IllegalArgumentException("Number of folds must be at least 2!"); } if (numFolds > numInstances()) { throw new IllegalArgumentException("Can't have more folds than instances!"); } numInstForFold = numInstances() / numFolds; if (numFold < numInstances() % numFolds){ numInstForFold++; offset = numFold; }else offset = numInstances() % numFolds; test = new Instances(this, numInstForFold); first = numFold * (numInstances() / numFolds) + offset; copyInstances(first, test, numInstForFold); return test; } /** * Returns the dataset as a string in ARFF format. Strings * are quoted if they contain whitespace characters, or if they * are a question mark. * * @return the dataset in ARFF format as a string */ public final String toString() { StringBuffer text = new StringBuffer(); text.append("@relation " + Utils.quote(m_RelationName) + "\n\n"); for (int i = 0; i < numAttributes(); i++) { text.append(attribute(i) + "\n"); } text.append("\n@data\n"); for (int i = 0; i < numInstances(); i++) { text.append(instance(i)); if (i < numInstances() - 1) { text.append('\n'); } } return text.toString(); } /** * Creates the training set for one fold of a cross-validation * on the dataset. * * @param numFolds the number of folds in the cross-validation. Must * be greater than 1. * @param numFold 0 for the first fold, 1 for the second, ... * @return the training set as a set of weighted * instances * @exception IllegalArgumentException if the number of folds is less than 2 * or greater than the number of instances. */ public Instances trainCV(int numFolds, int numFold) { int numInstForFold, first, offset; Instances train; if (numFolds < 2) { throw new IllegalArgumentException("Number of folds must be at least 2!"); } if (numFolds > numInstances()) { throw new IllegalArgumentException("Can't have more folds than instances!"); } numInstForFold = numInstances() / numFolds; if (numFold < numInstances() % numFolds) { numInstForFold++; offset = numFold; }else offset = numInstances() % numFolds; train = new Instances(this, numInstances() - numInstForFold); first = numFold * (numInstances() / numFolds) + offset; copyInstances(0, train, first); copyInstances(first + numInstForFold, train, numInstances() - first - numInstForFold); return train; } /** * Computes the variance for a numeric attribute. * * @param attIndex the numeric attribute * @return the variance if the attribute is numeric * @exception IllegalArgumentException if the attribute is not numeric */ public final double variance(int attIndex) { double sum = 0, sumSquared = 0, sumOfWeights = 0; if (!attribute(attIndex).isNumeric()) { throw new IllegalArgumentException("Can't compute variance because attribute is " + "not numeric!"); } for (int i = 0; i < numInstances(); i++) { if (!instance(i).isMissing(attIndex)) { sum += instance(i).weight() * instance(i).value(attIndex); sumSquared += instance(i).weight() * instance(i).value(attIndex) * instance(i).value(attIndex); sumOfWeights += instance(i).weight(); } } if (Utils.smOrEq(sumOfWeights, 1)) { return 0; } double result = (sumSquared - (sum * sum / sumOfWeights)) / (sumOfWeights - 1); // We don't like negative variance if (result < 0) { return 0; } else { return result; } } /** * Computes the variance for a numeric attribute. * * @param att the numeric attribute * @return the variance if the attribute is numeric * @exception IllegalArgumentException if the attribute is not numeric */ public final double variance(Attribute att) { return variance(att.index()); } /** * Calculates summary statistics on the values that appear in this * set of instances for a specified attribute. * * @param index the index of the attribute to summarize. * @return an AttributeStats object with it's fields calculated. */ public AttributeStats attributeStats(int index) { AttributeStats result = new AttributeStats(); if (attribute(index).isNominal()) { result.nominalCounts = new int [attribute(index).numValues()]; } if (attribute(index).isNumeric()) { result.numericStats = new weka.experiment.Stats(); } result.totalCount = numInstances(); double [] attVals = attributeToDoubleArray(index); int [] sorted = Utils.sort(attVals); int currentCount = 0; double prev = Instance.missingValue(); for (int j = 0; j < numInstances(); j++) { Instance current = instance(sorted[j]); if (current.isMissing(index)) { result.missingCount = numInstances() - j; break; } if (Utils.eq(current.value(index), prev)) { currentCount++; } else { result.addDistinct(prev, currentCount); currentCount = 1; prev = current.value(index); } } result.addDistinct(prev, currentCount); result.distinctCount--; // So we don't count "missing" as a value return result; } /** * Gets the value of all instances in this dataset for a particular * attribute. Useful in conjunction with Utils.sort to allow iterating * through the dataset in sorted order for some attribute. * * @param index the index of the attribute. * @return an array containing the value of the desired attribute for * each instance in the dataset. */ public double [] attributeToDoubleArray(int index) { double [] result = new double[numInstances()]; for (int i = 0; i < result.length; i++) { result[i] = instance(i).value(index); } return result; } /** * Generates a string summarizing the set of instances. Gives a breakdown * for each attribute indicating the number of missing/discrete/unique * values and other information. * * @return a string summarizing the dataset */ public String toSummaryString() { StringBuffer result = new StringBuffer(); result.append("Relation Name: ").append(relationName()).append('\n'); result.append("Num Instances: ").append(numInstances()).append('\n'); result.append("Num Attributes: ").append(numAttributes()).append('\n'); result.append('\n'); result.append(Utils.padLeft("", 5)).append(Utils.padRight("Name", 25));
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -