📄 decisiontreeminingmodel.java
字号:
// Categorical attribute, DISCOVERER-specific:
else {
preTypes[i] = 1;
CategoricalAttribute att = (CategoricalAttribute) metaData.getMiningAttribute( mAttName );
if (att != null) {
CategoryHierarchy cah = att.getTaxonomy();
if (cah != null) {
CategoricalAttribute cmAtt = (CategoricalAttribute) mAtt;
// Look for '_fehlend_' category of missing values:
Category missCat = new Category("_fehlend_");
double key = att.getKey(missCat);
if ( !Category.isMissingValue(key) ) {
missCat = getRootCategory(cah, missCat);
System.out.println("missCat=" + missCat);
value = cmAtt.getKey(missCat);
if ( Category.isMissingValue(value) )
value = cmAtt.addCategory(missCat);
}
// Use first root category:
else {
Vector par = cah.getAllRoots();
if (par.size() > 0) {
Category cat = (Category) par.elementAt(0);
value = cmAtt.getKey(cat);
if ( Category.isMissingValue(value) )
value = cmAtt.addCategory(cat);
};
};
// Get all categories for category replacement:
ArrayList categs = att.getValues();
if ( categs.size() > 0 ) {
repCategs[i] = new Hashtable();
for (int j = 0; j < categs.size(); j++) {
Category cat = (Category) categs.get(j);
Category root = getRootCategory(cah, cat);
if (root == null)
continue;
if ( Category.isMissingValue( cmAtt.getKey(root)) )
cmAtt.addCategory(root);
repCategs[i].put(cat, root);
};
};
};
};
};
// Set replacement value:
repValues[i] = value;
};
// For caching:
preMetaData = inputMetaData;
}
// Replace missing values and outliers:
for (int i = 0; i < nAtt; i++) {
// Target attribute:
if ( preTypes[i] == 2 ) continue;
// Get value:
double value = miningVector.getValue(i);
// Replace missing values:
if ( Category.isMissingValue(value) ) {
value = repValues[i];
}
// Numeric attribute without missing value:
else if ( preTypes[i] == 0 ) {
// Check for explicit missing values:
if ( missValuesNum[i] != null) {
for (int j = 0; j < missValuesNum[i].length; j++)
if ( Math.abs(value - missValuesNum[i][j]) < 1.0e-5 ) {
value = repValues[i];
System.out.println("-------------------------->replace missing");
}
}
else
continue;
}
// Categorical without missing values, check for outliers and category replacement:
else {
CategoricalAttribute cmAtt = (CategoricalAttribute) inputMetaData.getMiningAttribute(i);
CategoricalAttribute att = (CategoricalAttribute) metaData.getMiningAttribute( cmAtt.getName() );
if (att == null) continue;
Category cat = cmAtt.getCategory(value);
if (cat == null || Category.isMissingValue( att.getKey(cat) ) ) {
value = repValues[i];
}
else {
if (repCategs[i] == null)
continue;
// Check for category replacement:
Category repCat = (Category) repCategs[i].get(cat);
if (repCat == null)
continue;
value = cmAtt.getKey(repCat);
if ( Category.isMissingValue(value) )
continue;
}
}
// Set replacement value:
miningVector.setValue(i, value);
}
return miningVector;
}
/**
 * Looks up the root category belonging to a category within a taxonomy.
 * The first entry of the root parents reported by the taxonomy is
 * returned; when the taxonomy reports no root parents at all, the given
 * category itself is returned (presumably because it already is a root).
 *
 * @param cah taxonomy
 * @param cat category
 * @return root category of given category
 */
private Category getRootCategory(CategoryHierarchy cah, Category cat) {
    Vector rootParents = cah.getAllRootParents(cat);
    return (rootParents.size() > 0) ? (Category) rootParents.elementAt(0) : cat;
}
// -----------------------------------------------------------------------------
// ... DISCOVERER-specific preprocessing.
// -----------------------------------------------------------------------------
/**
 * Recomputes the total and the cumulated record counts stored in the
 * terminal nodes (leaves) of the tree referenced by <code>classifier</code>.
 * <p>
 * The leaf distributions are summed per target index into a "this" total;
 * the "other" total of an index is the grand total minus its "this" total.
 * Both total vectors are handed to every leaf. Afterwards, for each target
 * index, the leaves are sorted by the natural ordering of
 * <code>slistent</code> and their individual counts are accumulated along
 * that order.
 * <p>
 * When <code>false</code> is returned, the leaves visited before the
 * mismatching one have already had their cumulated count arrays replaced.
 *
 * @return true if successful, false if the leaf distributions do not all
 *         have the same length
 */
public boolean updateCumulatedCounts()
{
    // Collect all leaves below the classifier's root node:
    DecisionTreeNode treeRoot = (DecisionTreeNode) classifier;
    Vector leaves = new Vector();
    traverse(treeRoot, leaves);
    final int leafCount = leaves.size();

    // The distribution of the first leaf fixes the expected length:
    double[] distribution = ((DecisionTreeNode) leaves.elementAt(0)).getDistribution();
    final int targetCount = (distribution == null) ? 0 : distribution.length;

    // Freshly allocated arrays are zero-filled by the language:
    double[] totalThis = new double[targetCount];
    double[] totalOther = new double[targetCount];

    // Sum up the distributions and give each leaf new cumulated arrays:
    for (int leafIdx = 0; leafIdx < leafCount; leafIdx++) {
        DecisionTreeNode leaf = (DecisionTreeNode) leaves.elementAt(leafIdx);
        distribution = leaf.getDistribution();
        int length = (distribution == null) ? 0 : distribution.length;
        if (length != targetCount) {
            return false;
        }
        for (int t = 0; t < targetCount; t++) {
            totalThis[t] += distribution[t];
        }
        leaf.setCumulatedRecordCountThis(new double[targetCount]);
        leaf.setCumulatedRecordCountOther(new double[targetCount]);
    }

    // "Other" total of a target = grand total minus its own total:
    double grandTotal = 0;
    for (int t = 0; t < targetCount; t++) {
        grandTotal += totalThis[t];
    }
    for (int t = 0; t < targetCount; t++) {
        totalOther[t] = grandTotal - totalThis[t];
    }

    // Every leaf receives the very same two total vectors:
    for (int leafIdx = 0; leafIdx < leafCount; leafIdx++) {
        DecisionTreeNode leaf = (DecisionTreeNode) leaves.elementAt(leafIdx);
        leaf.setTotalRecordCountThis(totalThis);
        leaf.setTotalRecordCountOther(totalOther);
    }

    // Clear the cumulated counts of all leaves for all targets:
    for (int leafIdx = 0; leafIdx < leafCount; leafIdx++) {
        DecisionTreeNode leaf = (DecisionTreeNode) leaves.elementAt(leafIdx);
        for (int t = 0; t < targetCount; t++) {
            leaf.setCumulatedRecordCounts(t, 0, 0);
        }
    }

    // Sort the leaves once per target and cumulate along the sorted order:
    slistent[] ranking = new slistent[leafCount];
    for (int t = 0; t < targetCount; t++) {
        for (int leafIdx = 0; leafIdx < leafCount; leafIdx++) {
            DecisionTreeNode leaf = (DecisionTreeNode) leaves.elementAt(leafIdx);
            distribution = leaf.getDistribution();

            slistent entry = new slistent();
            entry.i = leafIdx;
            entry.sum_this = (int) distribution[t];
            // sum_all is an int field: every += narrows the double sum back to int.
            for (int k = 0; k < targetCount; k++) {
                entry.sum_all += distribution[k];
            }
            ranking[leafIdx] = entry;

            // Start from the leaf's own counts for this target:
            leaf.setCumulatedRecordCounts(t, entry.sum_this, entry.sum_all - entry.sum_this);
        }

        java.util.Arrays.sort(ranking);

        // Running sums are kept as int, as in the per-leaf entries:
        int runningThis = 0;
        int runningOther = 0;
        for (int rank = 0; rank < leafCount; rank++) {
            DecisionTreeNode leaf = (DecisionTreeNode) leaves.elementAt(ranking[rank].i);
            runningThis += leaf.getCumulatedRecordCountThis(t);
            runningOther += leaf.getCumulatedRecordCountOther(t);
            leaf.setCumulatedRecordCounts(t, runningThis, runningOther);
        }
    }
    return true;
}
/**
 * Appends every leaf found at or below the given node to a list.
 * A node is treated as a leaf when it reports a child count of zero.
 * Leaves are appended depth-first, following the child index order.
 *
 * @param dtn initial node
 * @param nlist list receiving all leaves
 */
private static void traverse(DecisionTreeNode dtn, Vector nlist)
{
    final int childCount = dtn.getChildCount();
    if (childCount == 0) {
        // The node itself is a leaf; nothing to descend into.
        nlist.add(dtn);
        return;
    }
    int child = 0;
    while (child < childCount) {
        traverse((DecisionTreeNode) dtn.getChildAt(child), nlist);
        child++;
    }
}
/**
 * Sortable entry describing one leaf with respect to one target index.
 * <p>
 * Natural ordering: entries with a larger ratio
 * <code>sum_this / sum_all</code> come first; among entries with the same
 * ratio the one with the larger <code>sum_this</code> comes first. Entries
 * that agree in both respects compare as equal, so the ordering satisfies
 * the <code>Comparable</code> contract required by
 * <code>java.util.Arrays.sort</code>.
 */
private class slistent implements java.lang.Comparable
{
    /** Position of the described leaf in the list of leaves. */
    int i;
    /** Count for the target under consideration. */
    int sum_this;
    /** Count over all targets. */
    int sum_all;

    /**
     * Compares this entry with another one according to the ordering
     * described on the class.
     *
     * @param o the other <code>slistent</code> entry
     * @return negative if this entry sorts first, positive if it sorts
     *         later, 0 if both are equivalent
     * @throws ClassCastException if <code>o</code> is not a <code>slistent</code>
     */
    public int compareTo(Object o) {
        slistent other = (slistent) o;
        // Compare (sum_this / sum_all) with (other.sum_this / other.sum_all)
        // by cross-multiplication; use long so the products cannot overflow.
        long diff = (long) sum_this * other.sum_all - (long) other.sum_this * sum_all;
        if (diff < 0)
            return +1;
        if (diff > 0)
            return -1;
        // Secondary criterion: larger absolute count first.
        if (sum_this < other.sum_this)
            return +1;
        if (sum_this > other.sum_this)
            return -1;
        // True tie: must report equality to keep the ordering consistent.
        return 0;
    }
};
// -----------------------------------------------------------------------
// Methods of PMML handling
// -----------------------------------------------------------------------
/**
* Creates PMML document of tree model.
* PMMLs TreeModel is used.
*
* @param writer writer for PMML model
* @exception MiningException cannot write PMML model
* @see com.prudsys.pdm.Adapters.PmmlVersion20.TreeModel
*/
public void writePmml( Writer writer ) throws MiningException
{
PMML pmml = new PMML();
pmml.setVersion( "2.0" );
// Set PMML header:
pmml.setHeader( (Header)PmmlUtils.getHeader(applicationName, applicationVersion) );
// Add data and transformation dictionary:
MiningDataSpecification metaData = miningSettings.getDataSpecification();
if ( metaData.isTransformed() )
{
pmml.setDataDictionary( (DataDictionary)metaData.getPretransformedMetaData().createPmmlObject() );
pmml.setTransformationDictionary( (com.prudsys.pdm.Adapters.PmmlVersion20.TransformationDictionary)metaData.getMiningTransformationActivity().createPmmlObject() );
}
else
{
pmml.setDataDictionary( (DataDictionary)metaData.createPmmlObject() );
};
// Add support vector dictionary if NDT with SVM nodes:
if (globalSupportVectors != null) {
VectorDictionary vdic = new VectorDictionary();
int nSupp = globalSupportVectors.size();
vdic.setNumberOfVectors(String.valueOf(nSupp));
Enumeration gsv = globalSupportVectors.keys();
while ( gsv.hasMoreElements() ) {
String suppVecId = (String) gsv.nextElement();
MiningVector suppVec = (MiningVector) globalSupportVectors.get(suppVecId);
VectorInstance xmv = new VectorInstance();
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -