📄 gentreeutils.java
字号:
for(int i = 0; i < doubles.length; i++)
{
sum += doubles[i];
}
normalize(doubles, sum);
}
/**
 * Divides every element of the given array, in place, by a divisor.
 *
 * @param doubles the array whose elements are rescaled
 * @param sum the divisor applied to each element
 * @exception IllegalArgumentException if sum is zero or NaN
 */
public static void normalize(double[] doubles, double sum)
{
    // The two guards are mutually exclusive, so their order does not matter.
    if (sum == 0) {
        // Maybe this should just be a return.
        throw new IllegalArgumentException("Can't normalize array. Sum is zero.");
    }
    if (Double.isNaN(sum)) {
        throw new IllegalArgumentException("Can't normalize array. Sum is NaN.");
    }

    final int count = doubles.length;
    for (int k = 0; k < count; k++) {
        doubles[k] = doubles[k] / sum;
    }
}
/**
 * Computes the impurity measure of one attribute with respect to the
 * classification attribute.
 *
 * A contingency table is built first: for a numeric attribute the values
 * are sorted together with their class labels and passed to
 * ImpurityMeasures.InformationGainDiscretization; for any other attribute
 * (which must then be a CategoricalAttribute) the class occurrences are
 * counted per category. The selected impurity measure is then evaluated
 * on that table.
 *
 * @param miningVectors whole dataset
 * @param calculatedAttribute attribute for the impurity measure to calculate
 * @param classificationAttribute classifying attribute
 * @param cuts array of cutpoints, only passed on for numeric attributes
 * @param impurityMeasureType type of impurity measure, one of
 *        GenTreeAlgorithm.FS_InfGain, FS_InfGain_MC, FS_GainRatio,
 *        FS_GainRatio_MC
 * @param discreteType type of discretization of numeric attributes (ignored)
 * @param minNodeSize minimum cases in node (percentage), only passed on
 *        for numeric attributes
 * @return impurity measure for the calculated attribute; 0.0 if no
 *         contingency table was produced
 * @exception MiningException if the impurity measure type is unknown
 *            (called methods may presumably raise it as well)
 */
public static double computeImpMeas( MiningStoredData miningVectors,
        MiningAttribute calculatedAttribute,
        CategoricalAttribute classificationAttribute,
        double[] cuts,
        int impurityMeasureType,
        int discreteType,
        double minNodeSize)
        throws MiningException
{
    // Copy calculated and classification row values:
    int numberOfVectors = miningVectors.size();
    double[] f = new double[numberOfVectors];
    int[] C = new int[numberOfVectors];
    for (int i = 0; i < numberOfVectors; i++) {
        MiningVector mv = miningVectors.read(i);
        f[i] = mv.getValue(calculatedAttribute);
        C[i] = (int) mv.getValue(classificationAttribute);
    }

    // Create contingency table:
    ContTable ct;
    if (calculatedAttribute instanceof NumericAttribute) {
        // Numeric attribute => use discretization on the sorted values:
        Quicksort(f, C, 0, numberOfVectors-1);
        ct = ImpurityMeasures.InformationGainDiscretization(f, C, cuts, minNodeSize);
    }
    else {
        // Categorical attribute => create contingency table manually:
        CategoricalAttribute calcAtt = (CategoricalAttribute) calculatedAttribute;
        int ncont = calcAtt.getCategoriesNumber();
        int nclass = classificationAttribute.getCategoriesNumber();

        // classCounts[k][c] = number of vectors with category k and class c.
        // The stored values are used directly as array indexes.
        int[][] classCounts = new int[ncont][nclass];
        for (int i = 0; i < numberOfVectors; i++) {
            classCounts[(int) f[i]][C[i]]++;
        }

        ContEntry[] ce = new ContEntry[ncont];
        for (int i = 0; i < ncont; i++) {
            ce[i] = new ContEntry(classCounts[i]);
        }
        ct = new ContTable(ce);
    }

    // Calculate impurity measure:
    double impMeasure = 0.0;
    if (ct != null) {
        if (impurityMeasureType == GenTreeAlgorithm.FS_InfGain)
            impMeasure = ImpurityMeasures.InformationGainFS(ct, false);
        else if (impurityMeasureType == GenTreeAlgorithm.FS_InfGain_MC)
            impMeasure = ImpurityMeasures.InformationGainFS(ct, true);
        else if (impurityMeasureType == GenTreeAlgorithm.FS_GainRatio)
            impMeasure = ImpurityMeasures.GainRatioFS(ct, false);
        else if (impurityMeasureType == GenTreeAlgorithm.FS_GainRatio_MC)
            impMeasure = ImpurityMeasures.GainRatioFS(ct, true);
        else
            throw new MiningException("unknown impurity measure");
    }
    return impMeasure;
}
/**
 * Partitions a dataset by the value of a categorical attribute.
 *
 * One partition is created per category of the attribute; each vector is
 * appended to the partition whose index equals the vector's value for the
 * attribute, truncated to an int.
 *
 * @param miningVectors mining input stream to split
 * @param splitingAttribute categorical attribute whose value selects the partition
 * @return one mining input stream per category, in category-index order
 */
public static MiningStoredData[] splitData( MiningStoredData miningVectors, CategoricalAttribute splitingAttribute )
{
    final int categoryCount = splitingAttribute.getCategoriesNumber();
    final int vectorCount = miningVectors.size();

    // Start with one empty partition per category.
    final MiningStoredData[] partitions = new MiningStoredData[categoryCount];
    for (int p = 0; p < partitions.length; p++) {
        partitions[p] = new MiningStoredData();
    }

    // Route every vector to the partition named by its category index.
    for (int v = 0; v < vectorCount; v++) {
        final MiningVector current = (MiningVector) miningVectors.get(v);
        final int target = (int) current.getValue(splitingAttribute);
        partitions[target].add(current);
    }
    return partitions;
}
/**
 * Splits a dataset according to the cut points of a numeric attribute.
 *
 * The result has cuts.length + 1 partitions. A vector is stored in
 * partition k, where k is the number of leading cut points that are
 * strictly smaller than the vector's value: partition 0 holds values
 * up to and including cuts[0], partition k holds values in
 * (cuts[k-1], cuts[k]], and the last partition holds values greater
 * than the last cut point. A NaN value compares false against every
 * cut point and therefore lands in partition 0. With an empty cuts
 * array all vectors go to the single partition 0.
 *
 * @param miningVectors mining input stream to split
 * @param splitingAttribute splitting attribute
 * @param cuts array of cut points; assumed to be sorted in ascending order
 * @return splitted mining input streams, one per interval
 */
public static MiningStoredData[] splitData( MiningStoredData miningVectors, NumericAttribute splitingAttribute, double[] cuts )
{
    int numberValues = cuts.length + 1;
    int numberVectors = miningVectors.size();
    MiningStoredData[] splitData = new MiningStoredData[ numberValues ];
    for (int j = 0; j < numberValues; j++)
    {
        splitData[j] = new MiningStoredData();
    }
    for (int i = 0; i < numberVectors; i++)
    {
        MiningVector vector = (MiningVector)miningVectors.get( i );
        double attValue = vector.getValue( splitingAttribute );

        // Advance past every cut point the value exceeds; the resulting
        // index is the interval the value belongs to. The previous code
        // mapped (cuts[j], cuts[j+1]) to j instead of j+1, sent values
        // equal to an inner cut point to partition 0, and failed on an
        // empty cuts array.
        int ind = 0;
        while (ind < cuts.length && attValue > cuts[ind])
        {
            ind++;
        }
        splitData[ind].add( vector );
    }
    return splitData;
}
/**
 * Main for tests.
 *
 * Runs Partition on a small descending sample and prints the returned
 * value, then prints a feature/class sample before and after Quicksort,
 * and finally prints the value of 1 / ln(2).
 *
 * @param arguments (ignored)
 */
public static void main(String arguments[]){
    double partitionFeatures[] = {7,4,1,0};
    int partitionClasses[] = {7,4,1,0};
    int r = Partition(partitionFeatures, partitionClasses, 0, partitionFeatures.length-1);
    System.out.println("r===============" + r);

    double features[]={6,8,5,9,7,1,9};
    int classes[]={1, -1,1,3,-1,-1,2};
    for (int i = 0; i < features.length; i++) {
        System.out.println(features[i]+" "+classes[i]);
    }

    // Quicksort is static, so it is called directly rather than through
    // a throw-away instance; bounds come from the array itself.
    Quicksort(features, classes, 0, features.length-1);

    System.out.println("sorted: ");
    for (int i = 0; i < features.length; i++) {
        System.out.println(features[i]+" "+classes[i]);
    }
    System.out.println("log2: " + 1 / Math.log(2) );
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -