📄 spreadsubsample.java
字号:
}
/**
 * Sets the maximum number of instances kept from each class when the
 * subsample is created. Only positive values act as a cap; zero or a
 * negative value leaves the per-class counts uncapped.
 *
 * @param maxcount the new max count; it is stored as an int, so any
 * fractional part is discarded (truncated toward zero)
 */
public void setMaxCount(double maxcount) {
  m_MaxCount = (int) maxcount;
}
/**
 * Returns the per-class instance cap currently configured.
 *
 * @return the stored max count, widened to a double
 */
public double getMaxCount() {
  double maxCount = m_MaxCount;
  return maxCount;
}
/**
 * Provides the GUI help text for the randomSeed property.
 *
 * @return a short description of the property, suitable for
 * displaying in the explorer/experimenter gui
 */
public String randomSeedTipText() {
  final String tip = "Sets the random number seed for subsampling.";
  return tip;
}
/**
 * Returns the seed used to initialise the random generator for subsampling.
 *
 * @return the current random number seed
 */
public int getRandomSeed() {
  return this.m_RandomSeed;
}
/**
 * Stores the seed used to initialise the random generator for subsampling.
 *
 * @param newSeed the seed to use from now on
 */
public void setRandomSeed(int newSeed) {
  this.m_RandomSeed = newSeed;
}
/**
 * Records the structure of the incoming data and uses it unchanged as the
 * output structure. Resets the filter so that the next batch is subsampled.
 *
 * @param instanceInfo an Instances object describing the input structure
 * (only the header is used; any instances it contains are ignored)
 * @return always true, because the output format is available immediately
 * @exception UnassignedClassException if no class attribute has been set.
 * @exception UnsupportedClassTypeException if the class attribute
 * is not nominal.
 */
public boolean setInputFormat(Instances instanceInfo)
  throws Exception {
  super.setInputFormat(instanceInfo);
  boolean nominalClass = instanceInfo.classAttribute().isNominal();
  if (!nominalClass) {
    throw new UnsupportedClassTypeException("The class attribute must be nominal.");
  }
  setOutputFormat(instanceInfo);
  // A new format means the next batch must be subsampled again.
  m_FirstBatchDone = false;
  return true;
}
/**
 * Accepts one instance. Until the first batch has been finished, the
 * instance is buffered so it can take part in subsampling. After that,
 * it is passed straight through to the output queue.
 *
 * @param instance the input instance
 * @return true if the instance was queued for output (pass-through mode),
 * false if it was buffered
 * @exception IllegalStateException if no input structure has been defined
 */
public boolean input(Instance instance) {
  if (getInputFormat() == null) {
    throw new IllegalStateException("No input instance format defined");
  }
  if (m_NewBatch) {
    // Start of a new batch: discard output left over from the previous one.
    resetQueue();
    m_NewBatch = false;
  }
  if (!m_FirstBatchDone) {
    bufferInput(instance);
    return false;
  }
  push(instance);
  return true;
}
/**
 * Marks the end of the current input batch. On the first batch, the
 * buffered instances are subsampled onto the output queue. The input
 * buffer is then cleared, and later batches pass through unchanged.
 *
 * @return true if there are instances pending output
 * @exception IllegalStateException if no input structure has been defined
 */
public boolean batchFinished() {
  if (getInputFormat() == null) {
    throw new IllegalStateException("No input instance format defined");
  }
  boolean firstBatch = !m_FirstBatchDone;
  if (firstBatch) {
    // Subsample the buffered instances before the buffer is cleared below.
    createSubsample();
  }
  flushInput();
  m_NewBatch = true;
  m_FirstBatchDone = true;
  int pending = numPendingOutput();
  return pending != 0;
}
/**
 * Creates a subsample of the current set of input instances and pushes
 * copies of the selected instances onto the output queue.
 *
 * <p>The input is first sorted on the class attribute, so that each class
 * occupies the contiguous index range reported by getClassIndices().
 *
 * <p>The number of instances kept from each class is determined as follows:
 * <ul>
 * <li>It is limited to {@code min * m_DistributionSpread}, where
 * {@code min} is the size of the smallest non-empty class. A spread of 0
 * disables this limit.</li>
 * <li>It is then limited to {@code m_MaxCount} when that is positive.</li>
 * <li>It never exceeds the number of instances the class actually has.</li>
 * </ul>
 *
 * <p>Instances are drawn without replacement from a Random seeded with
 * {@code m_RandomSeed}. If {@code m_AdjustWeights} is set, every kept
 * instance of a class receives the same weight, chosen so that the class's
 * total weight matches that of the original data. If no class has any
 * instances, a warning is printed to stderr and nothing is output.
 */
private void createSubsample() {
  Instances data = getInputFormat();
  int numClasses = data.numClasses();

  // Sort on the class attribute so every class occupies a contiguous range.
  data.sort(data.classIndex());
  int [] classIndices = getClassIndices();

  // Per-class instance counts and total weights (missing classes ignored).
  int [] counts = new int [numClasses];
  double [] weights = new double [numClasses];
  for (int i = 0; i < data.numInstances(); i++) {
    Instance current = data.instance(i);
    if (!current.classIsMissing()) {
      int c = (int) current.classValue();
      counts[c]++;
      weights[c] += current.weight();
    }
  }

  // Turn total weights into per-instance averages and find the smallest
  // non-empty class.
  int min = -1;
  for (int i = 0; i < numClasses; i++) {
    if (counts[i] > 0) {
      weights[i] = weights[i] / counts[i];
      if (min < 0 || counts[i] < min) {
        min = counts[i];
      }
    }
  }
  if (min < 0) {
    System.err.println("SpreadSubsample: *warning* none of the classes have any values in them.");
    return;
  }

  // Determine how many instances to keep from each class.
  int [] newCounts = new int [numClasses];
  for (int i = 0; i < numClasses; i++) {
    if (m_DistributionSpread == 0) {
      newCounts[i] = counts[i];
    } else {
      newCounts[i] = (int) Math.abs(Math.min(counts[i], min * m_DistributionSpread));
    }
    if (m_MaxCount > 0) {
      newCounts[i] = Math.min(newCounts[i], m_MaxCount);
    }
    // Never request more distinct instances than the class holds. Otherwise
    // (e.g. with a negative spread) the rejection loop below could never
    // finish, or would divide by zero for an empty class.
    newCounts[i] = Math.min(newCounts[i], counts[i]);
  }

  // Sample without replacement.
  Random random = new Random(m_RandomSeed);
  boolean [] used = new boolean [data.numInstances()];
  for (int j = 0; j < numClasses; j++) {
    double newWeight = 1.0;
    if (m_AdjustWeights && newCounts[j] > 0) {
      // Keep the class's total weight unchanged in the smaller sample.
      newWeight = weights[j] * counts[j] / newCounts[j];
    }
    int start = classIndices[j];
    int size = classIndices[j + 1] - start;
    for (int k = 0; k < newCounts[j]; k++) {
      int index;
      do {
        // Taking abs of the remainder (rather than of nextInt()) is always
        // non-negative, even for Integer.MIN_VALUE, and draws exactly the
        // same indices as Math.abs(nextInt()) % size for every other value.
        index = start + Math.abs(random.nextInt() % size);
      } while (used[index]);
      used[index] = true;
      Instance newInst = (Instance) data.instance(index).copy();
      if (m_AdjustWeights) {
        newInst.setWeight(newWeight);
      }
      push(newInst);
    }
  }
}
/**
 * Creates an index of the position where each class starts in
 * getInputFormat(). The input format must already be sorted on the class
 * attribute. Instances with a missing class value are assumed to come last;
 * the early break below relies on this.
 *
 * @return an array of length {@code numClasses() + 1}. The instances of
 * class {@code c} occupy indices {@code [result[c], result[c + 1])}.
 * {@code result[numClasses()]} is the index of the first instance with a
 * missing class, or {@code numInstances()} when there is none.
 */
private int []getClassIndices() {
  Instances data = getInputFormat();
  int numInstances = data.numInstances();
  int [] classIndices = new int [data.numClasses() + 1];
  int currentClass = 0;
  // End of the labelled region; instances from here on have a missing class.
  int labelledEnd = numInstances;
  classIndices[currentClass] = 0;
  for (int i = 0; i < numInstances; i++) {
    Instance current = data.instance(i);
    if (current.classIsMissing()) {
      labelledEnd = i;
      break;
    }
    int value = (int) current.classValue();
    if (value != currentClass) {
      // Classes currentClass+1 .. value all start here; any skipped
      // classes between them are empty.
      for (int j = currentClass + 1; j <= value; j++) {
        classIndices[j] = i;
      }
      currentClass = value;
    }
  }
  // The last seen class ends at the end of the labelled region, and every
  // remaining (empty) class starts there too. Using numInstances() here
  // would fold missing-class instances into the last class's range.
  for (int j = currentClass + 1; j < classIndices.length; j++) {
    classIndices[j] = labelledEnd;
  }
  return classIndices;
}
/**
 * Command-line entry point for running this filter.
 *
 * <p>With the -b flag, the filter is run through Filter.batchFilterFile;
 * otherwise Filter.filterFile is used. Any exception is printed to the
 * console instead of being propagated.
 *
 * @param argv should contain arguments to the filter:
 * use -h for help
 */
public static void main(String [] argv) {
  try {
    // Read the flag before constructing the filter, as the original order did.
    boolean batchMode = Utils.getFlag('b', argv);
    SpreadSubsample filter = new SpreadSubsample();
    if (batchMode) {
      Filter.batchFilterFile(filter, argv);
    } else {
      Filter.filterFile(filter, argv);
    }
  } catch (Exception ex) {
    ex.printStackTrace();
    System.out.println(ex.getMessage());
  }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -