📄 adtree.java
字号:
/**
 * Parses a given list of options. Valid options are:<p>
 *
 * -B num <br>
 * Set the number of boosting iterations
 * (default 10) <p>
 *
 * -E num <br>
 * Set the nodes to expand: -3(all), -2(weight), -1(z_pure), >=0 seed for random walk
 * (default -3) <p>
 *
 * -D <br>
 * Save the instance data with the model <p>
 *
 * If -B or -E is absent, the corresponding setting is left unchanged. The
 * save-instance-data setting is always assigned from the presence of -D.
 *
 * @param options the list of options as an array of strings
 * @exception Exception if an option is not supported, a numeric value cannot
 * be parsed, or the -E value is below -3
 */
public void setOptions(String[] options) throws Exception {
  String bString = Utils.getOption('B', options);
  if (bString.length() != 0) {
    setNumOfBoostingIterations(Integer.parseInt(bString));
  }

  String eString = Utils.getOption('E', options);
  if (eString.length() != 0) {
    int value = Integer.parseInt(eString);
    if (value >= 0) {
      // a non-negative value selects the random walk and doubles as its seed
      setSearchPath(new SelectedTag(SEARCHPATH_RANDOM, TAGS_SEARCHPATH));
      setRandomSeed(value);
    } else if (value >= -3) {
      // -3, -2, -1 are shifted onto search-path tag IDs 0, 1, 2
      setSearchPath(new SelectedTag(value + 3, TAGS_SEARCHPATH));
    } else {
      // reject explicitly instead of building a SelectedTag with a negative ID
      throw new IllegalArgumentException("Invalid value for -E: " + value
          + " (expected -3, -2, -1 or a non-negative random seed)");
    }
  }

  setSaveInstanceData(Utils.getFlag('D', options));
  Utils.checkForRemainingOptions(options);
}
/**
 * Returns the current ADTree settings as a command-line style option array.
 * The -E argument is the random seed when the random-walk search path is
 * selected; otherwise it is the search path ID minus 3. The -D flag is included
 * only when instance data is saved. Unused trailing slots are set to "".
 *
 * @return an array of strings suitable for passing to setOptions()
 */
public String[] getOptions() {
  String[] opts = new String[6];
  int next = 0;

  opts[next++] = "-B";
  opts[next++] = "" + getNumOfBoostingIterations();

  String expandArg;
  if (m_searchPath == SEARCHPATH_RANDOM) {
    expandArg = "" + m_randomSeed;
  } else {
    expandArg = "" + (m_searchPath - 3);
  }
  opts[next++] = "-E";
  opts[next++] = expandArg;

  if (getSaveInstanceData()) {
    opts[next++] = "-D";
  }

  // pad the remainder so callers never see null entries
  for (; next < opts.length; next++) {
    opts[next] = "";
  }
  return opts;
}
/**
 * Reports the tree size: the number of splitter plus prediction nodes
 * reachable from the root, as counted by numOfAllNodes.
 *
 * @return the tree size
 */
public double measureTreeSize() {
  int totalNodes = numOfAllNodes(m_root);
  return totalNodes;
}
/**
 * Reports the "leaf size" of the tree, which for an ADTree is the number of
 * prediction nodes reachable from the root, as counted by numOfPredictionNodes.
 *
 * @return the leaf size
 */
public double measureNumLeaves() {
  int predictionNodes = numOfPredictionNodes(m_root);
  return predictionNodes;
}
/**
 * Reports the number of prediction nodes that have no splitter children,
 * as counted by numOfPredictionLeafNodes starting from the root.
 *
 * @return the leaf size
 */
public double measureNumPredictionLeaves() {
  int leafCount = numOfPredictionLeafNodes(m_root);
  return leafCount;
}
/**
 * Reports the current value of the nodes-expanded counter.
 *
 * @return the number of nodes expanded during search
 */
public double measureNodesExpanded() {
  double expanded = m_nodesExpanded;
  return expanded;
}
/**
 * Returns the number of examples "counted", i.e. the current value of the
 * examples-counted counter.
 *
 * @return the number of examples processed during search
 */
public double measureExamplesProcessed() {
  return m_examplesCounted;
}
/**
 * Returns an enumeration of the names of the additional measures that can be
 * queried through getMeasure(String).
 *
 * @return an enumeration of the measure names
 */
public Enumeration enumerateMeasures() {
  // five names are added below, so size the vector for all of them
  Vector newVector = new Vector(5);
  newVector.addElement("measureTreeSize");
  newVector.addElement("measureNumLeaves");
  newVector.addElement("measureNumPredictionLeaves");
  newVector.addElement("measureNodesExpanded");
  newVector.addElement("measureExamplesProcessed");
  return newVector.elements();
}

/**
 * Misspelled alias of enumerateMeasures(), kept so that existing callers of
 * the old name continue to work.
 *
 * @return an enumeration of the measure names
 * @deprecated use enumerateMeasures() instead
 */
public Enumeration emerateMeasures() {
  return enumerateMeasures();
}
/**
 * Returns the value of the named additional measure. Names are matched
 * case-insensitively against the names listed by the measure enumeration.
 *
 * @param additionalMeasureName the name of the measure to query for its value
 * @return the value of the named measure
 * @exception IllegalArgumentException if the named measure is not supported
 */
public double getMeasure(String additionalMeasureName) {
  if (additionalMeasureName.equalsIgnoreCase("measureTreeSize")) {
    return measureTreeSize();
  } else if (additionalMeasureName.equalsIgnoreCase("measureNumLeaves")) {
    return measureNumLeaves();
  } else if (additionalMeasureName.equalsIgnoreCase("measureNumPredictionLeaves")) {
    return measureNumPredictionLeaves();
  } else if (additionalMeasureName.equalsIgnoreCase("measureNodesExpanded")) {
    return measureNodesExpanded();
  } else if (additionalMeasureName.equalsIgnoreCase("measureExamplesProcessed")) {
    return measureExamplesProcessed();
  } else {
    throw new IllegalArgumentException(additionalMeasureName
        + " not supported (ADTree)");
  }
}
/**
 * Counts all nodes in the (sub)tree rooted at the given prediction node: the
 * node itself, every splitter attached to it, and, recursively, the nodes
 * below every branch of each splitter. A null root contributes 0.
 *
 * @param root the root of the tree being measured
 * @return tree size in number of splitter + prediction nodes
 */
protected int numOfAllNodes(PredictionNode root) {
  if (root == null) {
    return 0;
  }
  int count = 1; // the prediction node itself
  Enumeration splitters = root.children();
  while (splitters.hasMoreElements()) {
    count++; // the splitter node
    Splitter splitter = (Splitter) splitters.nextElement();
    for (int b = 0; b < splitter.getNumOfBranches(); b++) {
      count += numOfAllNodes(splitter.getChildForBranch(b));
    }
  }
  return count;
}
/**
 * Counts the prediction nodes in the (sub)tree rooted at the given node,
 * recursing through every branch of every splitter child. Splitter nodes
 * themselves are not counted. A null root contributes 0.
 *
 * @param root the root of the tree being measured
 * @return tree size in number of prediction nodes
 */
protected int numOfPredictionNodes(PredictionNode root) {
  if (root == null) {
    return 0;
  }
  int count = 1; // this prediction node
  Enumeration splitters = root.children();
  while (splitters.hasMoreElements()) {
    Splitter splitter = (Splitter) splitters.nextElement();
    for (int b = 0; b < splitter.getNumOfBranches(); b++) {
      count += numOfPredictionNodes(splitter.getChildForBranch(b));
    }
  }
  return count;
}
/**
 * Returns the number of leaf nodes in a tree, i.e. prediction nodes without
 * any splitter children. A null root (for example, a tree that has not been
 * built) contributes 0, matching numOfAllNodes and numOfPredictionNodes.
 *
 * @param root the root of the tree being measured, may be null
 * @return tree leaf size in number of prediction nodes
 */
protected int numOfPredictionLeafNodes(PredictionNode root) {
  if (root == null) {
    // previously dereferenced unconditionally, throwing NullPointerException
    return 0;
  }
  if (root.getChildren().size() == 0) {
    return 1; // no splitters attached: this prediction node is a leaf
  }
  int numSoFar = 0;
  for (Enumeration e = root.children(); e.hasMoreElements(); ) {
    Splitter split = (Splitter) e.nextElement();
    for (int i = 0; i < split.getNumOfBranches(); i++) {
      numSoFar += numOfPredictionLeafNodes(split.getChildForBranch(i));
    }
  }
  return numSoFar;
}
/**
 * Draws the next pseudo-random integer from this tree's random generator.
 *
 * @param max exclusive upper bound on the returned value
 * @return a value between 0 and max-1 inclusive
 */
protected int getRandom(int max) {
  int drawn = m_random.nextInt(max);
  return drawn;
}
/**
 * Records that another splitter node has been added and returns its position
 * in the order of additions (the incremented splitter counter).
 *
 * @return the next number in the order
 */
public int nextSplitAddedOrder() {
  m_lastAddedSplitNum++;
  return m_lastAddedSplitNum;
}
/**
 * Trains the tree on the given instances. The method initializes the
 * classifier, runs the configured number of boosting iterations, and then
 * releases training-only state via done() unless instance data is to be saved.
 *
 * @param instances the instances to train the classifier with
 * @exception Exception if something goes wrong
 */
public void buildClassifier(Instances instances) throws Exception {
  initClassifier(instances);

  for (int iteration = 0; iteration < m_boostingIterations; iteration++) {
    boost();
  }

  boolean keepTrainingData = m_saveInstanceData;
  if (!keepTrainingData) {
    done();
  }
}
/**
 * Frees memory that a final model no longer needs. The training set is
 * replaced by new Instances(m_trainInstances, 0), and the random generator,
 * the attribute index arrays and the per-class instance references are
 * dropped. The classifier can no longer be incremented after this call.
 */
public void done() {
  m_trainInstances = new Instances(m_trainInstances, 0);
  m_posTrainInstances = null;
  m_negTrainInstances = null;
  m_numericAttIndices = null;
  m_nominalAttIndices = null;
  m_random = null;
}
/**
 * Creates a clone that is identical to the current tree, but is independent.
 * Deep copies the essential elements such as the tree nodes and the instances
 * (because the weights change). Reference copies several elements such as the
 * potential splitter sets, assuming that such elements should never differ
 * between clones. Configuration settings (boosting iterations, search path,
 * random seed, save-instance-data flag) and the search counters are copied too.
 *
 * @return the clone
 * @exception RuntimeException if the random number generator cannot be copied
 */
public Object clone() {
  ADTree clone = new ADTree();
  if (m_root != null) { // check for initialization first
    clone.m_root = (PredictionNode) m_root.clone(); // deep copy the tree
    clone.m_trainInstances = new Instances(m_trainInstances); // copy training instances

    // deep copy the random object via serialization
    if (m_random != null) {
      try {
        SerializedObject randomSerial = new SerializedObject(m_random);
        clone.m_random = (Random) randomSerial.getObject();
      } catch (Exception e) {
        // Random is serializable, so a failure here is a genuine error; surface it
        // with its cause rather than dereferencing a null SerializedObject
        throw new RuntimeException("Could not copy the random number generator", e);
      }
    }

    clone.m_lastAddedSplitNum = m_lastAddedSplitNum;
    clone.m_numericAttIndices = m_numericAttIndices; // shared by reference
    clone.m_nominalAttIndices = m_nominalAttIndices; // shared by reference
    clone.m_trainTotalWeight = m_trainTotalWeight;

    // reconstruct pos/negTrainInstances so they reference the clone's own copies
    if (m_posTrainInstances != null) {
      clone.m_posTrainInstances =
          new ReferenceInstances(clone.m_trainInstances, m_posTrainInstances.numInstances());
      clone.m_negTrainInstances =
          new ReferenceInstances(clone.m_trainInstances, m_negTrainInstances.numInstances());
      for (Enumeration e = clone.m_trainInstances.enumerateInstances();
           e.hasMoreElements(); ) {
        Instance inst = (Instance) e.nextElement();
        try { // instances whose class value cannot be read are skipped
          if ((int) inst.classValue() == 0) {
            clone.m_negTrainInstances.addReference(inst); // belongs in negative class
          } else {
            clone.m_posTrainInstances.addReference(inst); // belongs in positive class
          }
        } catch (Exception ignored) {}
      }
    }
  }
  clone.m_nodesExpanded = m_nodesExpanded;
  clone.m_examplesCounted = m_examplesCounted;
  clone.m_boostingIterations = m_boostingIterations;
  clone.m_searchPath = m_searchPath;
  clone.m_randomSeed = m_randomSeed;
  // previously omitted, so a clone silently lost its -D setting
  clone.m_saveInstanceData = m_saveInstanceData;
  return clone;
}
/**
 * Merges another tree into this one. This tree is modified; the tree passed
 * as a parameter is left untouched (cloned). Compatibility of the training
 * instances is not checked, so merging incompatible trees may give strange
 * results.
 *
 * @param mergeWith the tree to merge with
 * @exception Exception if merge could not be performed
 */
public void merge(ADTree mergeWith) throws Exception {
  boolean bothInitialized = m_root != null && mergeWith.m_root != null;
  if (!bothInitialized) {
    throw new Exception("Trying to merge an uninitialized tree");
  }
  m_root.merge(mergeWith.m_root, this);
}
/**
 * Command-line entry point. The method evaluates a fresh ADTree with the given
 * options and prints the evaluation output. On failure, it prints only the
 * exception message to standard error.
 *
 * @param argv the options
 */
public static void main(String [] argv) {
  try {
    ADTree classifier = new ADTree();
    String report = Evaluation.evaluateModel(classifier, argv);
    System.out.println(report);
  } catch (Exception e) {
    System.err.println(e.getMessage());
  }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -