⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 mybase.cpp

📁 实现决策树分类训练试验。 源自c4.5
💻 CPP
📖 第 1 页 / 共 2 页
字号:

    Raw    = (Tree *) calloc(TRIALS, sizeof(Tree));
    Pruned = (Tree *) calloc(TRIALS, sizeof(Tree));

    /*  If necessary, set initial size of window to 20% (or twice
	the sqrt, if this is larger) of the number of data items,
	and the maximum number of items that can be added to the
	window at each iteration to 20% of the initial window size  */

    if ( ! WINDOW )
    {
		WINDOW = Max(2 * sqrt(MaxItem+1.0), (MaxItem+1) / 5);
    }

    if ( ! INCREMENT )
    {
		INCREMENT = Max(WINDOW / 5, 1);
    }

    FormTarget(WINDOW);

    /*  Form set of trees by iteration and prune  */

    ForEach(t, 0, TRIALS-1 )
    {
        FormInitialWindow();
		fprintf(fLog,"\n--------\nTrial %d\n--------\n\n", t);
		Raw[t] = Iterate(WINDOW, INCREMENT);
		fprintf(fLog,"\n");
		PrintTree(Raw[t]);
		
		SaveTree(Raw[t], ".unpruned");
		Pruned[t] = CopyTree(Raw[t]);
		//**
		if ( Prune(Pruned[t]) )
		{
			fprintf(fLog,"\nSimplified ");
			PrintTree(Pruned[t]);
		}/***/
		if ( Pruned[t]->Errors < Pruned[Best]->Errors )
		{
			Best = t;
		}
    }

    fprintf(fLog,"\n--------\n");
    return Best;
}



/*************************************************************************/
/*									 */
/*  The windowing approach seems to work best when the class		 */
/*  distribution of the initial window is as close to uniform as	 */
/*  possible.  FormTarget generates this initial target distribution,	 */
/*  setting up a TargetClassFreq value for each class.			 */
/*									 */
/*************************************************************************/

/*
 *  Compute TargetClassFreq[c], the number of items of each class c that
 *  the initial window should contain, for a window of Size items.
 *
 *  Classes are handled from least to most frequent.  Each class receives
 *  either all of its items or an equal share of the window space that is
 *  still unallocated, whichever is smaller; space a rare class cannot
 *  fill is therefore passed on to the more common classes.
 *
 *  Size  - number of items the window should hold (consumed locally).
 *
 *  Writes TargetClassFreq[0..MaxClass]; reads Item[0..MaxItem].
 */
void FormTarget(ItemNo Size)
{
    ItemNo i, *ClassFreq;
    ClassNo c, Smallest, ClassesLeft=0;

    ClassFreq = (ItemNo *) calloc(MaxClass+1, sizeof(ItemNo));

    /*  Generate the class frequency distribution  */

    ForEach(i, 0, MaxItem)
    {
        ClassFreq[ Class(Item[i]) ]++;
    }

    /*  Count the classes that actually occur; classes with no items
        get a target of zero straight away  */

    ForEach(c, 0, MaxClass)
    {
        if ( ClassFreq[c] )
        {
            ClassesLeft++;
        }
        else
        {
            TargetClassFreq[c] = 0;
        }
    }

    while ( ClassesLeft )
    {
        /*  Find least common class of which there are some items  */

        Smallest = -1;
        ForEach(c, 0, MaxClass)
        {
            if ( ClassFreq[c] && ( Smallest < 0 || ClassFreq[c] < ClassFreq[Smallest] ) )
            {
                Smallest = c;
            }
        }

        /*  Allocate the no. of items of this class to use in the window  */

        TargetClassFreq[Smallest] = Min(ClassFreq[Smallest], Round(Size/ClassesLeft));

        /*  Zeroing the count removes the class from further consideration  */

        ClassFreq[Smallest] = 0;
        Size -= TargetClassFreq[Smallest];
        ClassesLeft--;
    }

    /*  The buffer came from calloc, so it must be released with free;
        pairing calloc with delete is undefined behaviour in C++  */

    free(ClassFreq);
}

/*************************************************************************/
/*									 */
/*  Form initial window, attempting to obtain the target class profile	 */
/*  in TargetClassFreq.  This is done by placing the targeted number     */
/*  of items of each class at the beginning of the set of data items.	 */
/*									 */
/*************************************************************************/
void  FormInitialWindow()
{
    ItemNo i, Start=0, More;
    ClassNo c;
 
    Shuffle();

    ForEach(c, 0, MaxClass)
    {
		More = TargetClassFreq[c];
		for ( i = Start ; More ; i++ )
		{
			if ( Class(Item[i]) == c )
			{
				Swap(Start, i);
				Start++;
				More--;
			}
		}
    }
}

/*************************************************************************/
/*									 */
/*		Shuffle the data items randomly				 */
/*									 */
/*************************************************************************/
void Shuffle()
{
    ItemNo This, Alt, Left;
    Description Hold;

    This = 0;
    for( Left = MaxItem+1 ; Left ; )
    {
        Alt = This + (Left--) * Random;
        Hold = Item[This];
        Item[This++] = Item[Alt];
        Item[Alt] = Hold;
    }
}



/*************************************************************************/
/*									 */
/*  Grow a tree iteratively with initial window size Window and		 */
/*  initial window increment IncExceptions.				 */
/*									 */
/*  Construct a classifier tree using the data items in the		 */
/*  window, then test for the successful classification of other	 */
/*  data items by this tree.  If there are misclassified items,		 */
/*  put them immediately after the items in the window, increase	 */
/*  the size of the window and build another classifier tree, and	 */
/*  so on until we have a tree which successfully classifies all	 */
/*  of the test items or no improvement is apparent.			 */
/*									 */
/*  On completion, return the tree which produced the least errors.	 */
/*									 */
/*************************************************************************/


/*  Window         - number of leading items of Item[] used to build the
		     first tree; grown on every cycle
    IncExceptions  - cap on one of the two candidate window increments
		     (see the end of the loop)
    Returns the tree with the lowest TotalErrors seen over all cycles;
    every other tree built here is released before returning.  */
Tree Iterate(ItemNo Window, ItemNo IncExceptions)
{
    Tree Classifier, BestClassifier=Nil;
    ItemNo i, Errors, TotalErrors, BestTotalErrors=MaxItem+1,
	   Exceptions, Additions;
    ClassNo Assigned;
    short Cycle=0;

    /*  Column headings for the per-cycle report written to fLog  */
    fprintf(fLog,"Cycle   Tree    -----Cases----");
    fprintf(fLog,"    -----------------Errors-----------------\n");
    fprintf(fLog,"        size    window   other");
    fprintf(fLog,"    window  rate   other  rate   total  rate\n");
    fprintf(fLog,"-----   ----    ------  ------");
    fprintf(fLog,"    ------  ----  ------  ----  ------  ----\n");

    do
    {
		/*  Build a classifier tree with the first Window items  */
		InitialiseWeights();
		AllKnown = true;
		Classifier = FormTree(0, Window-1);
		/*  Error analysis: errors on the window items themselves  */
		Errors = Round(Classifier->Errors);
		/*  Move all items that are incorrectly classified by the
			classifier tree to immediately after the items in the
			current window.  Exceptions is used here as the index of
			the next free slot after the window.  */
		Exceptions = Window;
		ForEach(i, Window, MaxItem)
		{
			Assigned = Category(Item[i], Classifier);
			if ( Assigned != Class(Item[i]) )
			{
				Swap(Exceptions, i);
				Exceptions++;
			}
		}
		
		/*  Convert the slot index into a count of misclassified
			items outside the window  */
		Exceptions -= Window;
		TotalErrors = Errors + Exceptions;

		/*  Print error analysis.  The 1.001 in the "other" rate keeps
			the divisor non-zero when the window covers every item.  */

		fprintf(fLog,"%3i  %7i  %8i  %6i  %8i%5.1f%%  %6i%5.1f%%  %6i%5.1f%%\n",
			   ++Cycle, TreeSize(Classifier), Window, MaxItem-Window+1,
			   Errors, 100*(float)Errors/Window,
			   Exceptions, 100*Exceptions/(MaxItem-Window+1.001),
			   TotalErrors, 100*TotalErrors/(MaxItem+1.0));

		/*  Keep track of the most successful classifier tree so far;
			whichever tree loses the comparison is released  */

		if ( ! BestClassifier || TotalErrors < BestTotalErrors )
		{
			if ( BestClassifier ) ReleaseTree(BestClassifier);
			BestClassifier = Classifier;
			BestTotalErrors = TotalErrors;
		}
		else
		{
			ReleaseTree(Classifier);
		}

		/*  Increment window size by the larger of
			min(Exceptions, IncExceptions) and half the exceptions,
			never growing beyond the full set of MaxItem+1 items  */

		Additions = Min(Exceptions, IncExceptions);
		Window = Min(Window + Max(Additions, Exceptions / 2), MaxItem + 1);
    } while ( Exceptions );	/*  stop once nothing outside the window is misclassified  */

	return BestClassifier;
}



/*************************************************************************/
/*									 */
/*	Print report of errors for each of the trials			 */
/*									 */
/*************************************************************************/
/*
 *  Write to fLog, for every trial, the size and error count of the
 *  unpruned and pruned trees measured over Item[0..MaxItem], together
 *  with the pruned tree's own error estimate.
 *
 *  CMInfo  - when set, also accumulate and print a confusion matrix for
 *            the pruned tree of trial Saved
 *  Saved   - trial whose row is flagged with "<<" in the report
 *
 *  If fRST is open, one "real predicted" class pair per item is written
 *  to it for every trial.
 */
void Evaluate(Boolean CMInfo, short Saved)
{
    ClassNo RealClass, PrunedClass;
    short t;
    ItemNo *ConfusionMat = NULL, i, RawErrors, PrunedErrors;

    if ( CMInfo )
    {
        ConfusionMat = (ItemNo *) calloc((MaxClass+1)*(MaxClass+1), sizeof(ItemNo));
        if ( ! ConfusionMat )
        {
            /*  Carry on with the error report; only the matrix is lost  */

            fprintf(fLog,"\nCannot allocate confusion matrix; it will be omitted\n");
        }
    }

    fprintf(fLog,"\n");

    if ( TRIALS > 1 )
    {
        fprintf(fLog,"Trial\t Before Pruning           After Pruning\n");
        fprintf(fLog,"-----\t----------------   ---------------------------\n");
    }
    else
    {
        fprintf(fLog,"\t Before Pruning           After Pruning\n");
        fprintf(fLog,"\t----------------   ---------------------------\n");
    }
    fprintf(fLog,"\tSize      Errors   Size      Errors   Estimate\n\n");

    ForEach(t, 0, TRIALS-1)
    {
        /*  Classify every item with both trees of this trial  */

        RawErrors = PrunedErrors = 0;
        ForEach(i, 0, MaxItem)
        {
            RealClass = Class(Item[i]);

            if ( Category(Item[i], Raw[t]) != RealClass ) RawErrors++;

            PrunedClass = Category(Item[i], Pruned[t]);

            if ( PrunedClass != RealClass ) PrunedErrors++;

            if ( fRST != NULL )
            {
                fprintf(fRST,"%i %i\n",RealClass,PrunedClass);
            }

            /*  The matrix is row-major: row = real class, column =
                class assigned by the pruned tree.  ConfusionMat is
                non-null only when CMInfo was set and the allocation
                succeeded.  */

            if ( ConfusionMat && t == Saved )
            {
                ConfusionMat[RealClass*(MaxClass+1)+PrunedClass]++;
            }
        }

        if ( TRIALS > 1 )
        {
            fprintf(fLog,"%4d", t);
        }

        fprintf(fLog,"\t%4d  %3d(%4.1f%%)   %4d  %3d(%4.1f%%)    (%4.1f%%)%s\n",
               TreeSize(Raw[t]), RawErrors, 100.0*RawErrors / (MaxItem+1.0),
               TreeSize(Pruned[t]), PrunedErrors, 100.0*PrunedErrors / (MaxItem+1.0),
               100 * Pruned[t]->Errors / Pruned[t]->Items,
               ( t == Saved ? "   <<" : "" ));
    }

    if ( ConfusionMat )
    {
        PrintConfusionMatrix(ConfusionMat);

        /*  Allocated with calloc, so release with free (not delete)  */

        free(ConfusionMat);
    }
}

/*
 *  Print a confusion matrix to fLog.
 *
 *  ConfusionMat  - row-major (MaxClass+1) x (MaxClass+1) array of counts;
 *                  row = real class, column = assigned class
 *
 *  Each row is followed by the percentage of that row's items lying on
 *  the diagonal (correctly classified) and by the class name.  Rows that
 *  contain no items report 0.0%.  Nothing is printed when MaxClass > 20.
 */
void PrintConfusionMatrix(ItemNo *ConfusionMat)
{
    short Row, Col;
    ItemNo Count, RowTotal;
    double PercentCorrect;

    if ( MaxClass > 20 ) return;  /* Don't print nonsensical matrices */

    /*  Print the heading, then each row  */

    fprintf(fLog,"\n\n\t");
    ForEach(Col, 0, MaxClass)
    {
        fprintf(fLog,"  (%c)", 'a' + Col);
    }

    fprintf(fLog,"\t<-classified as\n\t");
    ForEach(Col, 0, MaxClass)
    {
        fprintf(fLog," ----");
    }
    fprintf(fLog,"\n");

    ForEach(Row, 0, MaxClass)
    {
        /*  A per-row scalar replaces the former 20-element array, which
            was one slot too small for the 21 rows the guard above lets
            through  */

        RowTotal = 0;
        fprintf(fLog,"\t");
        ForEach(Col, 0, MaxClass)
        {
            Count = ConfusionMat[Row*(MaxClass+1) + Col];
            if ( Count )
            {
                fprintf(fLog,"%5d", Count);
            }
            else
            {
                fprintf(fLog,"     ");
            }
            RowTotal += Count;
        }

        /*  Guard against 0/0 for classes that have no items  */

        PercentCorrect = ( RowTotal
                           ? 100.0 * ConfusionMat[Row*(MaxClass+1) + Row] / RowTotal
                           : 0.0 );

        /*  A literal percent sign must be written as %% in a format  */

        fprintf(fLog,"\t%4.1f%%", PercentCorrect);
        fprintf(fLog,"\t(%c): class %s\n", 'a' + Row, ClassName[Row]);
    }
    fprintf(fLog,"\n");
}



/*************************************************************************/
/*									 */
/*  Compute the additional errors if the error rate increases to the	 */
/*  upper limit of the confidence level.  The coefficient is the	 */
/*  square of the number of standard deviations corresponding to the	 */
/*  selected confidence level.  (Taken from Documenta Geigy Scientific	 */
/*  Tables (Sixth Edition), p185 (with modifications).)			 */
/*									 */
/*************************************************************************/

/****
   Move to top of the file!
float Val[] = {  0,  0.001, 0.005, 0.01, 0.05, 0.10, 0.20, 0.40, 1.00},
      Dev[] = {4.0,  3.09,  2.58,  2.33, 1.65, 1.28, 0.84, 0.25, 0.00};
/*********/

/*
 *  Return the number of additional errors to add to e observed errors
 *  among N items, using the confidence level CF.
 *
 *  On the first call the squared deviation coefficient is interpolated
 *  from the Val/Dev tables at CF and cached in a static for later calls.
 */
float AddErrs(ItemCount N, ItemCount e)
{
    static float Coeff=0;
    float Val0, Pr;

    if ( ! Coeff )
    {
        /*  Locate the first table entry with Val >= CF and interpolate
            linearly between it and its predecessor  */

        int Hi = 0, Lo;

        while ( CF > Val[Hi] ) Hi++;
        Lo = Hi - 1;

        Coeff = Dev[Lo] +
              (Dev[Hi] - Dev[Lo]) * (CF - Val[Lo]) /(Val[Hi] - Val[Lo]);
        Coeff = Coeff * Coeff;
    }

    /*  Effectively no observed errors  */

    if ( e < 1E-6 )
    {
        return N * (1 - exp(log(CF) / N));
    }

    /*  Fewer than one error: interpolate between the zero-error value
        and the value for exactly one error  */

    if ( e < 0.9999 )
    {
        Val0 = N * (1 - exp(log(CF) / N));
        return Val0 + e * (AddErrs(N, 1.0) - Val0);
    }

    /*  Nearly every item is already an error  */

    if ( e + 0.5 >= N )
    {
        return 0.67 * (N - e);
    }

    /*  General case: upper limit on the error rate, less the errors
        already counted  */

    Pr = (e + 0.5 + Coeff/2
            + sqrt(Coeff * ((e + 0.5) * (1 - (e + 0.5)/N) + Coeff/4)) )
             / (N + Coeff);

    return (N * Pr - e);
}


/*************************************************************************/
/*									 */
/*	Sort items from Fp to Lp on attribute Att			 */
/*									 */
/*************************************************************************/
/*
 *  Sort Item[Fp..Lp] into ascending order of the continuous value of
 *  attribute Att, using the last item's value as the pivot.
 *
 *  Exchange  - 0 exchanges items with Swap, any other value with
 *              SwapUnweighted
 */
void  Quicksort(ItemNo Fp, ItemNo Lp, Attribute Att,int Exchange)
{
    ItemNo Scan, Boundary, LastBelow;
    float Pivot;

    /*  Fewer than two items: nothing to do  */

    if ( ! ( Fp < Lp ) ) return;

    Pivot = CVal(Item[Lp], Att);

    /*  Pass 1: gather every item with value <= pivot at the front;
        Boundary ends up as the first index holding a larger value  */

    Boundary = Fp;
    Scan = Fp;
    while ( Scan < Lp )
    {
        if ( CVal(Item[Scan], Att) <= Pivot )
        {
            if ( Scan != Boundary )
            {
                if ( Exchange == 0 )
                {
                    Swap(Boundary, Scan);
                }
                else
                {
                    SwapUnweighted(Boundary, Scan);
                }
            }
            Boundary++;
        }
        Scan++;
    }

    /*  Pass 2: within that front part, push the items equal to the
        pivot to its top end; LastBelow ends up as the last index whose
        value is strictly below the pivot  */

    LastBelow = Boundary - 1;
    Scan = LastBelow;
    while ( Scan >= Fp )
    {
        if ( CVal(Item[Scan], Att) == Pivot )
        {
            if ( Scan != LastBelow )
            {
                if ( Exchange == 0 )
                {
                    Swap(LastBelow, Scan);
                }
                else
                {
                    SwapUnweighted(LastBelow, Scan);
                }
            }
            LastBelow--;
        }
        Scan--;
    }

    /*  Sort the strictly smaller values  */

    Quicksort(Fp, LastBelow, Att, Exchange);

    /*  Drop the pivot item in just after the values equal to it  */

    if ( Exchange == 0 )
    {
        Swap(Boundary, Lp);
    }
    else
    {
        SwapUnweighted(Boundary, Lp);
    }

    /*  Sort the larger values  */

    Quicksort(Boundary+1, Lp, Att, Exchange);
}


/*  Lookup tables allocated and filled by GenerateLogs():
    LogItemNo[i] = log(i) / Log2, with LogItemNo[0] set to -1E38;
    LogFact[i]   = LogItemNo[2] + ... + LogItemNo[i], with
		   LogFact[0] = LogFact[1] = 0.  */
float	*LogItemNo;
double	*LogFact;
/*************************************************************************/
/*									 */
/*  Set up the array LogItemNo to contain the logs of integers and	 */
/*  the array LogFact to contain logs of factorials (all to base 2)	 */
/*									 */
/*************************************************************************/
void GenerateLogs()
{
    ItemNo i;

    LogItemNo = (float *) malloc((MaxItem+100) * sizeof(float));
    LogFact = (double *) malloc((MaxItem+100) * sizeof(double));

    LogItemNo[0] = -1E38;
    LogItemNo[1] = 0;
    LogFact[0] = LogFact[1] = 0;

    ForEach(i, 2, MaxItem+99)
    {
		LogItemNo[i] = log((float) i) / Log2;
		LogFact[i] = LogFact[i-1] + LogItemNo[i];
    }

	return;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -