⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 regrf.c

📁 是基于linux系统的C++程序
💻 C
字号:
/*******************************************************************   Copyright (C) 2001-7 Leo Breiman, Adele Cutler and Merck & Co., Inc.   This program is free software; you can redistribute it and/or   modify it under the terms of the GNU General Public License   as published by the Free Software Foundation; either version 2   of the License, or (at your option) any later version.   This program is distributed in the hope that it will be useful,   but WITHOUT ANY WARRANTY; without even the implied warranty of   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the   GNU General Public License for more details.*******************************************************************/#include <R.h>#include "rf.h"void simpleLinReg(int nsample, double *x, double *y, double *coef,		  double *mse, int *hasPred);void regRF(double *x, double *y, int *xdim, int *sampsize,	   int *nthsize, int *nrnodes, int *nTree, int *mtry, int *imp,	   int *cat, int *maxcat, int *jprint, int *doProx, int *oobprox,           int *biasCorr, double *yptr, double *errimp, double *impmat,           double *impSD, double *prox, int *treeSize, int *nodestatus,           int *lDaughter, int *rDaughter, double *avnode, int *mbest,           double *upper, double *mse, int *keepf, int *replace,           int *testdat, double *xts, int *nts, double *yts, int *labelts,           double *yTestPred, double *proxts, double *msets, double *coef,           int *nout, int *inbag) {    /*************************************************************************   Input:   mdim=number of variables in data set   nsample=number of cases   nthsize=number of cases in a node below which the tree will not split,   setting nthsize=5 generally gives good results.   nTree=number of trees in run.  200-500 gives pretty good results   mtry=number of variables to pick to split on at each node.  mdim/3   seems to give genrally good performance, but it can be   altered up or down   imp=1 turns on variable importance.  This is computed for the   mth variable as the percent rise in the test set mean sum-of-   squared errors when the mth variable is randomly permuted.  *************************************************************************/    double errts = 0.0, averrb, meanY, meanYts, varY, varYts, r, xrand,	errb = 0.0, resid=0.0, ooberr, ooberrperm, delta, *resOOB;    double *yb, *xtmp, *xb, *ytr, *ytree, *tgini;    int k, m, mr, n, nOOB, j, jout, idx, ntest, last, ktmp, nPerm,        nsample, mdim, keepF, keepInbag;    int *oobpair, varImp, localImp, *varUsed;    int *in, *nind, *nodex, *nodexts;    nsample = xdim[0];    mdim = xdim[1];    ntest = *nts;    varImp = imp[0];    localImp = imp[1];    nPerm = imp[2];    keepF = keepf[0];    keepInbag = keepf[1];    if (*jprint == 0) *jprint = *nTree + 1;    yb         = (double *) S_alloc(*sampsize, sizeof(double));    xb         = (double *) S_alloc(mdim * *sampsize, sizeof(double));    ytr        = (double *) S_alloc(nsample, sizeof(double));    xtmp       = (double *) S_alloc(nsample, sizeof(double));    resOOB     = (double *) S_alloc(nsample, sizeof(double));    in        = (int *) S_alloc(nsample, sizeof(int));    nodex      = (int *) S_alloc(nsample, sizeof(int));    varUsed    = (int *) S_alloc(mdim, sizeof(int));    nind = *replace ? NULL : (int *) S_alloc(nsample, sizeof(int));    if (*testdat) {	ytree      = (double *) S_alloc(ntest, sizeof(double));	nodexts    = (int *) S_alloc(ntest, sizeof(int));    }    oobpair = (*doProx && *oobprox) ?	(int *) S_alloc(nsample * nsample, sizeof(int)) : NULL;    /* If variable importance is requested, tgini points to the second       "column" of errimp, otherwise it's just the same as errimp. */    tgini = varImp ? errimp + mdim : errimp;    averrb = 0.0;    meanY = 0.0;    varY = 0.0;    zeroDouble(yptr, nsample);    zeroInt(nout, nsample);    for (n = 0; n < nsample; ++n) {	varY += n * (y[n] - meanY)*(y[n] - meanY) / (n + 1);	meanY = (n * meanY + y[n]) / (n + 1);    }    varY /= nsample;    varYts = 0.0;    meanYts = 0.0;    if (*testdat) {	for (n = 0; n < ntest; ++n) {	    varYts += n * (yts[n] - meanYts)*(yts[n] - meanYts) / (n + 1);	    meanYts = (n * meanYts + yts[n]) / (n + 1);	}	varYts /= ntest;    }    if (*doProx) {        zeroDouble(prox, nsample * nsample);	if (*testdat) zeroDouble(proxts, ntest * (nsample + ntest));    }    if (varImp) {        zeroDouble(errimp, mdim * 2);	if (localImp) zeroDouble(impmat, nsample * mdim);    } else {        zeroDouble(errimp, mdim);    }    if (*labelts) zeroDouble(yTestPred, ntest);    /* print header for running output */    if (*jprint <= *nTree) {	Rprintf("     |      Out-of-bag   ");	if (*testdat) Rprintf("|       Test set    ");	Rprintf("|\n");	Rprintf("Tree |      MSE  %%Var(y) ");	if (*testdat) Rprintf("|      MSE  %%Var(y) ");	Rprintf("|\n");    }    GetRNGstate();    /*************************************     * Start the loop over trees.     *************************************/    for (j = 0; j < *nTree; ++j) {		idx = keepF ? j * *nrnodes : 0;		zeroInt(in, nsample);        zeroInt(varUsed, mdim);        /* Draw a random sample for growing a tree. */		if (*replace) { /* sampling with replacement */			for (n = 0; n < *sampsize; ++n) {				xrand = unif_rand();				k = xrand * nsample;				in[k] = 1;				yb[n] = y[k];				for(m = 0; m < mdim; ++m) {					xb[m + n * mdim] = x[m + k * mdim];				}			}		} else { /* sampling w/o replacement */			for (n = 0; n < nsample; ++n) nind[n] = n;			last = nsample - 1;			for (n = 0; n < *sampsize; ++n) {				ktmp = (int) (unif_rand() * (last+1));                k = nind[ktmp];                swapInt(nind[ktmp], nind[last]);				last--;				in[k] = 1;				yb[n] = y[k];				for(m = 0; m < mdim; ++m) {					xb[m + n * mdim] = x[m + k * mdim];				}			}		}		if (keepInbag) {			for (n = 0; n < nsample; ++n) inbag[n + j * nsample] = in[n];		}        /* grow the regression tree */		regTree(xb, yb, mdim, *sampsize, lDaughter + idx, rDaughter + idx,                upper + idx, avnode + idx, nodestatus + idx, *nrnodes,                treeSize + j, *nthsize, *mtry, mbest + idx, cat, tgini,                varUsed);        /* predict the OOB data with the current tree */		/* ytr is the prediction on OOB data by the current tree */		predictRegTree(x, nsample, mdim, lDaughter + idx,                       rDaughter + idx, nodestatus + idx, ytr, upper + idx,                       avnode + idx, mbest + idx, treeSize[j], cat, *maxcat,                       nodex);		/* yptr is the aggregated prediction by all trees grown so far */		errb = 0.0;		ooberr = 0.0;		jout = 0; /* jout is the number of cases that has been OOB so far */		nOOB = 0; /* nOOB is the number of OOB samples for this tree */		for (n = 0; n < nsample; ++n) {			if (in[n] == 0) {				nout[n]++;                nOOB++;				yptr[n] = ((nout[n]-1) * yptr[n] + ytr[n]) / nout[n];				resOOB[n] = ytr[n] - y[n];                ooberr += resOOB[n] * resOOB[n];			}            if (nout[n]) {				jout++;				errb += (y[n] - yptr[n]) * (y[n] - yptr[n]);			}		}		errb /= jout;		/* Do simple linear regression of y on yhat for bias correction. */		if (*biasCorr) simpleLinReg(nsample, yptr, y, coef, &errb, nout);		/* predict testset data with the current tree */		if (*testdat) {			predictRegTree(xts, ntest, mdim, lDaughter + idx,						   rDaughter + idx, nodestatus + idx, ytree,                           upper + idx, avnode + idx,						   mbest + idx, treeSize[j], cat, *maxcat, nodexts);			/* ytree is the prediction for test data by the current tree */			/* yTestPred is the average prediction by all trees grown so far */			errts = 0.0;			for (n = 0; n < ntest; ++n) {				yTestPred[n] = (j * yTestPred[n] + ytree[n]) / (j + 1);			}            /* compute testset MSE */			if (*labelts) {				for (n = 0; n < ntest; ++n) {					resid = *biasCorr ?                        yts[n] - (coef[0] + coef[1]*yTestPred[n]) :                        yts[n] - yTestPred[n];					errts += resid * resid;				}				errts /= ntest;			}		}        /* Print running output. */		if ((j + 1) % *jprint == 0) {			Rprintf("%4d |", j + 1);			Rprintf(" %8.4g %8.2f ", errb, 100 * errb / varY);			if(*labelts == 1) Rprintf("| %8.4g %8.2f ",									  errts, 100.0 * errts / varYts);			Rprintf("|\n");		}		mse[j] = errb;		if (*labelts) msets[j] = errts;		/*  DO PROXIMITIES */		if (*doProx) {			computeProximity(prox, *oobprox, nodex, in, oobpair, nsample);			/* proximity for test data */			if (*testdat) {                /* In the next call, in and oobpair are not used. */                computeProximity(proxts, 0, nodexts, in, oobpair, ntest);				for (n = 0; n < ntest; ++n) {					for (k = 0; k < nsample; ++k) {						if (nodexts[n] == nodex[k]) {							proxts[n + ntest * (k+ntest)] += 1.0;						}					}				}			}		}		/* Variable importance */		if (varImp) {			for (mr = 0; mr < mdim; ++mr) {                if (varUsed[mr]) { /* Go ahead if the variable is used */                    /* make a copy of the m-th variable into xtmp */                    for (n = 0; n < nsample; ++n)                        xtmp[n] = x[mr + n * mdim];                    ooberrperm = 0.0;                    for (k = 0; k < nPerm; ++k) {                        permuteOOB(mr, x, in, nsample, mdim);                        predictRegTree(x, nsample, mdim, lDaughter + idx,                                       rDaughter + idx, nodestatus + idx, ytr,                                       upper + idx, avnode + idx, mbest + idx,                                       treeSize[j], cat, *maxcat, nodex);                        for (n = 0; n < nsample; ++n) {                            if (in[n] == 0) {                                r = ytr[n] - y[n];                                ooberrperm += r * r;                                if (localImp) {                                    impmat[mr + n * mdim] +=                                        (r*r - resOOB[n]*resOOB[n]) / nPerm;                                }                            }                        }                    }                    delta = (ooberrperm / nPerm - ooberr) / nOOB;                    errimp[mr] += delta;                    impSD[mr] += delta * delta;                    /* copy original data back */                    for (n = 0; n < nsample; ++n)                        x[mr + n * mdim] = xtmp[n];                }            }        }    }    PutRNGstate();    /* end of tree iterations=======================================*/    if (*biasCorr) {  /* bias correction for predicted values */		for (n = 0; n < nsample; ++n) {			if (nout[n]) yptr[n] = coef[0] + coef[1] * yptr[n];		}		if (*testdat) {			for (n = 0; n < ntest; ++n) {				yTestPred[n] = coef[0] + coef[1] * yTestPred[n];			}		}    }    if (*doProx) {		for (n = 0; n < nsample; ++n) {			for (k = n + 1; k < nsample; ++k) {                prox[nsample*k + n] /= oobprox ?                    (oobpair[nsample*k + n] > 0 ? oobpair[nsample*k + n] : 1) :                    *nTree;                prox[nsample * n + k] = prox[nsample * k + n];            }			prox[nsample * n + n] = 1.0;        }		if (*testdat) {			for (n = 0; n < ntest; ++n)				for (k = 0; k < ntest + nsample; ++k)					proxts[ntest*k + n] /= *nTree;		}    }    if (varImp) {		for (m = 0; m < mdim; ++m) {			errimp[m] = errimp[m] / *nTree;			impSD[m] = sqrt( ((impSD[m] / *nTree) -							  (errimp[m] * errimp[m])) / *nTree );			if (localImp) {                for (n = 0; n < nsample; ++n) {                    impmat[m + n * mdim] /= nout[n];                }			}        }    }    for (m = 0; m < mdim; ++m) tgini[m] /= *nTree;}/*----------------------------------------------------------------------*/void regForest(double *x, double *ypred, int *mdim, int *n,               int *ntree, int *lDaughter, int *rDaughter,               int *nodestatus, int *nrnodes, double *xsplit,               double *avnodes, int *mbest, int *treeSize, int *cat,               int *maxcat, int *keepPred, double *allpred, int *doProx,               double *proxMat, int *nodes, int *nodex) {    int i, j, idx1, idx2, *junk;    double *ytree;    junk = NULL;    ytree = (double *) S_alloc(*n, sizeof(double));    if (*nodes) {	zeroInt(nodex, *n * *ntree);    } else {	zeroInt(nodex, *n);    }    if (*doProx) zeroDouble(proxMat, *n * *n);    if (*keepPred) zeroDouble(allpred, *n * *ntree);    idx1 = 0;    idx2 = 0;    for (i = 0; i < *ntree; ++i) {	zeroDouble(ytree, *n);	predictRegTree(x, *n, *mdim, lDaughter + idx1, rDaughter + idx1,                       nodestatus + idx1, ytree, xsplit + idx1,                       avnodes + idx1, mbest + idx1, treeSize[i], cat, *maxcat,                       nodex + idx2);	for (j = 0; j < *n; ++j) ypred[j] += ytree[j];	if (*keepPred) {	    for (j = 0; j < *n; ++j) allpred[j + i * *n] = ytree[j];	}	/* if desired, do proximities for this round */	if (*doProx) computeProximity(proxMat, 0, nodex + idx2, junk,				      junk, *n);	idx1 += *nrnodes; /* increment the offset */	if (*nodes) idx2 += *n;    }    for (i = 0; i < *n; ++i) ypred[i] /= *ntree;    if (*doProx) {	for (i = 0; i < *n; ++i) {	    for (j = i + 1; j < *n; ++j) {                proxMat[i + j * *n] /= *ntree;		proxMat[j + i * *n] = proxMat[i + j * *n];	    }	    proxMat[i + i * *n] = 1.0;	}    }}void simpleLinReg(int nsample, double *x, double *y, double *coef,		  double *mse, int *hasPred) {/* Compute simple linear regression of y on x, returning the coefficients,   the average squared residual, and the predicted values (overwriting y). */    int i, nout = 0;    double sxx=0.0, sxy=0.0, xbar=0.0, ybar=0.0;    double dx = 0.0, dy = 0.0, py=0.0;    for (i = 0; i < nsample; ++i) {	if (hasPred[i]) {	    nout++;	    xbar += x[i];	    ybar += y[i];	}    }    xbar /= nout;    ybar /= nout;    for (i = 0; i < nsample; ++i) {	if (hasPred[i]) {	    dx = x[i] - xbar;	    dy = y[i] - ybar;	    sxx += dx * dx;	    sxy += dx * dy;	}    }    coef[1] = sxy / sxx;    coef[0] = ybar - coef[1] * xbar;    *mse = 0.0;    for (i = 0; i < nsample; ++i) {	if (hasPred[i]) {            py = coef[0] + coef[1] * x[i];	    dy = y[i] - py;	    *mse += dy * dy;            /* y[i] = py; */	}    }    *mse /= nout;    return;}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -