⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 createid3.java

📁 商业智能中数据挖掘的决策树算法 用于数据分类
💻 JAVA
字号:
package id3;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;

/**
 * Builds an ID3-style decision tree from tabular training data and renders it as text.
 *
 * <p>Every training row is a {@code String[]}. The column chosen by the {@code index} argument
 * of {@link #create} is the class label; every other column is a candidate split attribute.
 * The set of values of any column is taken from the whole training set ({@link #getArrti}),
 * so an internal node gets one child per value seen anywhere in the data, even when no row
 * reaching that node carries it.
 *
 * <p>The instance keeps mutable, unsynchronized state (training data, attribute-use flags,
 * tree root) between calls.
 */
public class CreateId3 {
	/** Root of the built tree; {@code null} until {@link #createTree} builds one. */
	TreeNode Root;
	/**
	 * {@code Visable[i]} is {@code true} when column {@code i} must not be chosen as a split:
	 * it is the label column, or an attribute already used on the path currently being built.
	 */
	private boolean[] Visable;
	/** Sentinel for "no such column / no usable attribute". */
	private static final int NO_FOUND = -1;
	/** The complete training set passed to {@link #create}; rows are {@code String[]}. */
	private Object[] TrainingArray;
	/** Column index of the class label. */
	private int NodeIndex;

	/**
	 * Returns the column name at {@code index}.
	 *
	 * @param index   column index
	 * @param ColName column names
	 * @return {@code ColName[index]}, or {@code null} when {@code index} is out of range
	 */
	public String getNodeName( int index, String[] ColName ) {
		if( index < 0 || index >= ColName.length )
			return null;
		return ColName[index];
	}

	/**
	 * Returns the label for a leaf built from {@code a}: the most frequent value of the
	 * label column, with ties going to the value seen first. For a pure subset this is
	 * simply the shared label.
	 *
	 * @param a rows ({@code String[]}) that reach the leaf
	 * @return majority label, or {@code null} when {@code a} is {@code null} or empty
	 */
	public String getLeafName( Object[] a ) {
		if( a == null || a.length == 0 )
			return null;

		// LinkedHashMap keeps first-seen order, so ties resolve deterministically.
		Map<String, Integer> counts = new LinkedHashMap<String, Integer>();
		for( Object row : a ) {
			String label = ( ( String[] ) row )[NodeIndex];
			Integer seen = counts.get( label );
			counts.put( label, seen == null ? 1 : seen + 1 );
		}

		String best = null;
		int bestCount = -1;
		for( Map.Entry<String, Integer> e : counts.entrySet() ) {
			if( e.getValue() > bestCount ) {
				best = e.getKey();
				bestCount = e.getValue();
			}
		}
		return best;
	}

	/**
	 * Returns the distinct values of column {@code index} over the full training set,
	 * ordered and de-duplicated by {@code SequenceComparator}.
	 *
	 * @param index column index
	 * @return distinct values of that column
	 */
	// NOTE(review): the suppression presumably exists because SequenceComparator is a raw
	// Comparator — confirm against its declaration.
	@SuppressWarnings("unchecked")
	public String[] getArrti( int index ) {
		TreeSet<String> set = new TreeSet<String>( new SequenceComparator() );
		for( Object row : TrainingArray )
			set.add( ( ( String[] ) row )[index] );
		return set.toArray( new String[ set.size() ] );
	}

	/**
	 * Returns the position of {@code name} in {@code ColName}.
	 *
	 * @param name    column name to look up (must not be {@code null})
	 * @param ColName column names
	 * @return the first matching index, or {@code NO_FOUND} (-1) if absent
	 */
	public int getNodeIndex( String name, String[] ColName ) {
		for( int i = 0; i < ColName.length; i++ ) {
			if( name.equals( ColName[i] ) )
				return i;
		}
		return NO_FOUND;
	}

	/**
	 * Records the label column and sizes {@code Visable} from the width of the first row.
	 *
	 * @param a     training rows; {@code a[0]} must exist
	 * @param index column index of the class label
	 */
	public void init( Object[] a, int index ) {
		this.NodeIndex = index;
		Visable = new boolean[ ( ( String[] ) a[0] ).length ];
		// Only the label column starts blocked; every attribute is initially available.
		for( int i = 0; i < Visable.length; i++ )
			Visable[i] = ( i == index );
	}

	/**
	 * Returns the entropy contribution of the rows of {@code a} whose column {@code index}
	 * equals {@code arrti}: the label entropy of that subset, weighted by
	 * {@code PartOfH.P(subsetSize, allTotal)}.
	 *
	 * @param a        rows under evaluation
	 * @param index    attribute column
	 * @param arrti    attribute value selecting the subset
	 * @param allTotal size of the whole set the weight is relative to
	 * @return weighted subset entropy
	 */
	public double Ph( Object[] a, int index, String arrti, int allTotal ) {
		return branchEntropy( a, index, arrti, allTotal, getArrti( this.NodeIndex ) );
	}

	/** Implementation of {@link #Ph} with the class-label values supplied by the caller. */
	private double branchEntropy( Object[] a, int index, String arrti, int allTotal, String[] labels ) {
		int[] counts = new int[labels.length];
		for( Object row : a ) {
			String[] strs = ( String[] ) row;
			if( strs[index].equals( arrti ) )
				tally( counts, labels, strs[this.NodeIndex] );
		}

		int total = 0;
		for( int c : counts )
			total += c;

		// NOTE(review): PartOfH.PartResult is presumably the -p*log(p) entropy term and
		// PartOfH.P the ratio total/allTotal — confirm in PartOfH.
		double h = 0;
		for( int c : counts )
			h += PartOfH.PartResult( c, total );

		return PartOfH.P( total, allTotal ) * h;
	}

	/** Increments {@code counts[k]} for every {@code k} with {@code label.equals(labels[k])}. */
	private static void tally( int[] counts, String[] labels, String label ) {
		for( int k = 0; k < labels.length; k++ ) {
			if( label.equals( labels[k] ) )
				counts[k]++;
		}
	}

	/**
	 * Returns the information gain of splitting {@code a} on column {@code index}:
	 * label entropy of {@code a} minus the sum of {@link #Ph} over every value of the column.
	 *
	 * @param a     rows under evaluation
	 * @param index candidate attribute column
	 * @return information gain
	 */
	public double G( Object[] a, int index ) {
		// Label values depend only on the full training set; compute them once per call.
		String[] labels = getArrti( this.NodeIndex );

		int[] counts = new int[labels.length];
		for( Object row : a )
			tally( counts, labels, ( ( String[] ) row )[this.NodeIndex] );

		double H = 0;
		for( int c : counts )
			H += PartOfH.PartResult( c, a.length );

		double E = 0;
		for( String value : getArrti( index ) )
			E += branchEntropy( a, index, value, a.length, labels );

		return H - E;
	}

	/**
	 * Finds the still-available attribute with the highest information gain on {@code a}
	 * and marks it as used in {@code Visable}.
	 *
	 * @param a rows under evaluation
	 * @return a 3-element array: {@code [0]} the best gain as {@code Double} (0 when no gain is
	 *         positive), {@code [1]} its column as {@code Integer} ({@code -1} when none);
	 *         slot {@code [2]} is unused
	 */
	public Object[] Gmax( Object[] a ) {
		Object[] result = new Object[3];
		double gain = 0;
		int index = NO_FOUND;

		for( int i = 0; i < Visable.length; i++ ) {
			if( Visable[i] )
				continue;
			double x = G( a, i );
			// Strictly greater only: ties keep the earlier column, and NaN never wins.
			if( gain < x ) {
				gain = x;
				index = i;
			}
		}

		result[0] = gain;
		result[1] = index;

		if( index != NO_FOUND )
			Visable[index] = true;

		return result;
	}

	/**
	 * Returns the rows of {@code array} whose column {@code index} equals {@code arrti}.
	 *
	 * @param array rows to filter
	 * @param arrti required value
	 * @param index column to test
	 * @return matching rows, in their original order
	 */
	public Object[] chooseArray( Object[] array, String arrti, int index ) {
		List< String[] > list = new ArrayList< String[] >();
		for( Object row : array ) {
			String[] strs = ( String[] ) row;
			if( strs[index].equals( arrti ) )
				list.add( strs );
		}
		return list.toArray();
	}

	/**
	 * Recursively builds one child of {@code Pa} for each value in {@code Pa.Arrti}.
	 * <ul>
	 *   <li>A branch with no rows becomes a leaf with the parent's majority label.</li>
	 *   <li>A branch where no attribute has positive gain becomes a leaf with its own
	 *       majority label.</li>
	 *   <li>Any other branch becomes an internal node split on the best remaining
	 *       attribute.</li>
	 * </ul>
	 *
	 * @param a       rows that reach {@code Pa}
	 * @param Pa      node whose children are filled in ({@code Pa.ChildNodes} must be sized)
	 * @param ColName column names
	 */
	public void insertToTree( Object[] a, TreeNode Pa, String[] ColName ) {
		String[] arrti = Pa.Arrti;
		// The split column is the same for every branch of Pa.
		int splitIndex = getNodeIndex( Pa.NodeName, ColName );

		for( int i = 0; i < arrti.length; i++ ) {
			Object[] Chosen = chooseArray( a, arrti[i], splitIndex );

			if( Chosen.length == 0 ) {
				// No training row takes this branch: label it with the parent's majority
				// instead of computing gains over an empty set.
				Pa.ChildNodes[i] = newLeaf( Pa, arrti[i], getLeafName( a ) );
				continue;
			}

			Object[] info = Gmax( Chosen );
			double gain = ( ( Double ) info[0] ).doubleValue();

			if( gain != 0 ) {
				int index = ( ( Integer ) info[1] ).intValue();
				TreeNode currentNode = new TreeNode();
				currentNode.Pa = Pa;
				currentNode.PaArrti = arrti[i];
				currentNode.Arrti = getArrti( index );
				currentNode.NodeName = getNodeName( index, ColName );
				currentNode.ChildNodes = new TreeNode[currentNode.Arrti.length];
				Pa.ChildNodes[i] = currentNode;
				insertToTree( Chosen, currentNode, ColName );
				// Gmax marked this attribute as used for the subtree just built; release it
				// so sibling branches may still split on it.
				Visable[index] = false;
			}
			else {
				Pa.ChildNodes[i] = newLeaf( Pa, arrti[i], getLeafName( Chosen ) );
			}
		}
	}

	/** Creates a childless node under {@code parent}, reached via branch value {@code branch}. */
	private TreeNode newLeaf( TreeNode parent, String branch, String label ) {
		TreeNode leaf = new TreeNode();
		leaf.Pa = parent;
		leaf.PaArrti = branch;
		leaf.Arrti = new String[0];
		leaf.NodeName = label;
		leaf.ChildNodes = new TreeNode[0];
		return leaf;
	}

	/**
	 * Builds the tree into {@code Root} unless one already exists.
	 * If no attribute has positive gain on {@code a}, the root is a single leaf.
	 *
	 * @param a       training rows
	 * @param ColName column names
	 */
	public void createTree( Object[] a, String[] ColName ) {
		if( Root != null )
			return;

		Object[] maxgain = Gmax( a );
		int index = ( ( Integer ) maxgain[1] ).intValue();

		if( index == NO_FOUND ) {
			// Nothing to split on (e.g. all rows share one label): getArrti(-1) would throw.
			Root = newLeaf( null, null, getLeafName( a ) );
			return;
		}

		Root = new TreeNode();
		Root.Pa = null;
		Root.PaArrti = null;
		Root.Arrti = getArrti( index );
		Root.NodeName = getNodeName( index, ColName );
		Root.ChildNodes = new TreeNode[Root.Arrti.length];
		insertToTree( a, Root, ColName );
	}

	/**
	 * Appends a pre-order text dump of the subtree at {@code N} to {@code res}.
	 * Each node contributes its {@code NodeName} on one line. Each non-null child is
	 * preceded by a line holding its branch value ({@code PaArrti}).
	 *
	 * @param N   subtree root
	 * @param res text to prepend
	 * @return {@code res} followed by the dump
	 */
	public String printTree( TreeNode N, String res ) {
		StringBuilder sb = new StringBuilder();
		sb.append( res );
		appendTree( N, sb );
		return sb.toString();
	}

	/** Recursive worker for {@link #printTree}. */
	private void appendTree( TreeNode node, StringBuilder sb ) {
		sb.append( node.NodeName ).append( '\n' );
		for( TreeNode child : node.ChildNodes ) {
			if( child != null ) {
				sb.append( child.PaArrti ).append( '\n' );
				appendTree( child, sb );
			}
		}
	}

	/**
	 * Trains a fresh tree on {@code a} and returns its text dump.
	 *
	 * @param a       training rows ({@code String[]}); must be non-empty
	 * @param index   column index of the class label
	 * @param res     text to prepend to the dump
	 * @param ColName column names
	 * @return {@code res} followed by the tree dump
	 */
	public String create( Object[] a, int index, String res, String[] ColName ) {
		this.TrainingArray = a;
		init( a, index );
		// Drop any tree from a previous call; createTree only builds when Root is null.
		Root = null;
		createTree( a, ColName );
		return printTree( Root, res );
	}

}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -