📄 id3node.java
字号:
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
public class ID3Node {
//private ArrayList column = new ArrayList();
private int column_num = 0;
//
ID3Node left_leaf;
ID3Node right_leaf;
ID3Node(int data){
//把属性值,即他的列号赋给接点
column_num = data;
left_leaf = null;
right_leaf = null;
}
public void insert_leaf(ArrayList list,int data,ID3Node root){
ArrayList a1 = new ArrayList();
ArrayList a2 = new ArrayList();
//把列data号内为0的放入a1,为1的放入a2中,此时将一个集合分成两个,并且把data列删除
for(int i = 0 ; i<list.size()-1;i++ ){
LinkedList cpu = new LinkedList((LinkedList)list.get(i));
if(((Integer)cpu.get(data)).intValue() == 0){
cpu.remove(data);
a1.add(cpu);
}
else {
cpu.remove(data);
a2.add(cpu);
}
}
LinkedList linkedlist = new LinkedList((LinkedList) list.get(list.size()-1));
//提示GC对象已经不用,可以回收,什么时候回收?????
list = null;
linkedlist.remove(data);
//每一个子集的后一列跟着标签,相当于记录每个属性的名称
a1.add(linkedlist);
a2.add(linkedlist);
ReadTrain rtt = new ReadTrain();
/*
************************************************************************************************
*以下是求data属性为0阈值和求最大信息增益的属性号,集合a1的信息增益
************************************************************************************************
*/
//求子集合的火商值
int p_arrtri = 0;
int n_arrtri = 0;
double entropy = 0.0;
//System.out.println("a1.szie:" + a1.size());
//System.out.println("a2.szie:" + a2.size());
if(a1.size() == 1)
entropy = 0.0;
else{
for(int i = 0;i<a1.size()-1;i++){
LinkedList ss = new LinkedList((Collection) a1.get(i));
//int a = Integer.parseInt((String) ss.get(ss.size()-1));
if(((Integer)ss.get(ss.size()-1)).intValue() == 0 )
n_arrtri++;
else
p_arrtri++;
}
entropy = rtt.comput_entropy(n_arrtri,p_arrtri,(double)a1.size());
}
//System.out.println("entropy:"+entropy);
if(entropy == 0.0){
//叶子节点不是记录列号,记录的是最后的目标
LinkedList temp = new LinkedList((LinkedList)a1.get(0));
int mubiaoshuxing = ((Integer)temp.get(temp.size()-1)).intValue();
System.out.println("!!!!!!!!!!!!!!!!!!!");
root.left_leaf = new ID3Node(mubiaoshuxing);
}
else{
//计算属性中最大的信息增益
LinkedList temp = new LinkedList((LinkedList) a1.get(0));
//每列属性01数
int zero = 0;
int one = 0;
//目标属性在0上01各自的分布
int zero_zero = 0;
int zero_one = 0;
//目标属性在1上01各自的分布
int one_zero = 0;
int one_one = 0;
//定义信息增益
double gain = 0.0;
double max = 0.0;
//信息增益中最大的属性列号
int max_arrtri = 0;
//信息增益中最大的属性描述,用数字表示
int leaf_node = 0;
for(int i = 0 ;i<temp.size();i++){//列
for(int j = 0;j<a1.size()-1;j++){//行
LinkedList inner = new LinkedList((LinkedList)a1.get(j));
if(((Integer)inner.get(i)).intValue() == 0){
zero++;
if(((Integer)inner.get(inner.size()-1)).intValue() == 0)
zero_zero++;
else
zero_one++;
}
else {
one++;
if(((Integer)inner.get(inner.size()-1)).intValue() == 0)
one_zero++;
else
one_one++;
}
}
//得到0和1的火商值
double zero_entropy = rtt.comput_entropy(zero_zero,zero_one,zero_zero+zero_one);
double one_entropy = rtt.comput_entropy(one_zero,one_one,one_zero+one_one);
//得到信息增益
gain = entropy - zero_entropy*zero/(zero+one) - one_entropy*one/(zero+one);
if( gain > max){
max = gain;
//得到阈制是0的最大信息增益的列号
max_arrtri = i;
}
}//每一列都做完了信息增益
LinkedList temp11 = new LinkedList((LinkedList) a1.get(a1.size()-1));
leaf_node = ((Integer)temp11.get(max_arrtri)).intValue();
//System.out.println("左边:" + temp11.get(max_arrtri));
ID3Node leaf = new ID3Node(leaf_node);
left_leaf = leaf;
leaf.insert_leaf(a1,max_arrtri,leaf);
}
/*
*******************************************************************************************
*求节点属性中为1的子集合的信息增益,集合a2的信息增益
*******************************************************************************************
*/
int p_arrtri2 = 0;
int n_arrtri2 = 0;
double entropy2 = 0.0;
//求右子集合的火商值
if(a2.size() == 1)
entropy2 = 0.0;
else{
for(int i = 0;i<a2.size()-1;i++){
LinkedList ss2 = new LinkedList((Collection) a2.get(i));
int t = ((Integer)ss2.get(ss2.size()-1)).intValue();
if( t == 0 )
n_arrtri2++;
else
p_arrtri2++;
}
entropy2 = rtt.comput_entropy(n_arrtri,p_arrtri,(double)a1.size());
}
if(entropy2 == 0.0){
LinkedList temp2 = new LinkedList((LinkedList)a2.get(0));
int mubiaoshuxing = ((Integer)temp2.get(temp2.size()-1)).intValue();
System.out.println("@@@@@@@@@@@@@@@@@@@@");
root.right_leaf = new ID3Node(mubiaoshuxing);
}
else{
//计算属性中最大的信息增益
LinkedList temp2 = new LinkedList((Collection) a1.get(0));
//每列属性01数
int zero2 = 0;
int one2 = 0;
//目标属性在0上各自的分布
int zero_zero2 = 0;
int zero_one2 = 0;
//目标属性在1上各自的分布
int one_zero2 = 0;
int one_one2 = 0;
//定义信息增益
double gain2 = 0.0;
double max2 = 0.0;
//信息增益中最大的属性值
int max_arrtri2 = 0;
int right_node = 0;
for(int i = 0 ;i<temp2.size();i++){
for(int j = 0;j<a2.size()-1;j++){
LinkedList inner2 = new LinkedList((Collection)a2.get(j));
if(((Integer)inner2.get(i)).intValue() == 0){
zero2++;
if(((Integer)inner2.get(inner2.size())).intValue() == 0)
zero_zero2++;
else
zero_one2++;
}
else {
one2++;
if(((Integer)inner2.get(inner2.size())).intValue() == 0)
one_zero2++;
else
one_one2++;
}
}
double zero_entropy2 = rtt.comput_entropy(zero_zero2,zero_one2,zero_zero2+zero_one2);
double one_entropy2 = rtt.comput_entropy(one_zero2,one_one2,one_zero2+one_one2);
gain2 = entropy - zero_entropy2*zero2/(zero2+one2) - one_entropy2*one2/(zero2+one2);
if( gain2 > max2){
max2 = gain2;
//得到阈制是0的最大信息增益的列号
max_arrtri2 = i;
}
}
LinkedList temp22 = new LinkedList((LinkedList) a2.get(a2.size()-1));
right_node = ((Integer)temp22.get(max_arrtri2)).intValue();
ID3Node leaf = new ID3Node(right_node);
right_leaf = leaf;
leaf.insert_leaf(a2,max_arrtri2,leaf);
}
}//结束insert_leaf
public void test(ID3Node root){
//if(root.left_leaf == null)
System.out.println(root.column_num);
if(root.left_leaf != null)
test(root.left_leaf);
if(root.right_leaf != null)
test(root.right_leaf);
}
public int return_result(ID3Node leaf,int i){
//System.out.println(leaf.column_num);
if(leaf.left_leaf == null){
return leaf.column_num;
}
else{
if(DecisionTree.test[i][leaf.column_num] == 0)
return(return_result(leaf.left_leaf,i));
else
return(return_result(leaf.right_leaf,i));
}
}
public void compute(ID3Node root){
int right_num = 0;
for(int i = 0;i<34;i++){
//System.out.println(return_leaf(root,i));
if(i<20){
if(return_result(root,i) == 0 )
right_num ++;
else System.out.println("测试错误的节点:" + (i+1));
}
else{
if(return_result(root,i) == 1)
right_num ++;
else System.out.println("测试错误的节点:" + (i+1));
}
}
double the_num_I_want = (double)right_num/34;
System.out.println("测试集的正确率:" + the_num_I_want);
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -