📄 regtree.c
字号:
/*******************************************************************
   Copyright (C) 2001-7 Leo Breiman, Adele Cutler and Merck & Co., Inc.

   This program is free software; you can redistribute it and/or
   modify it under the terms of the GNU General Public License
   as published by the Free Software Foundation; either version 2
   of the License, or (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.
*******************************************************************/

/******************************************************************
 * buildtree and findbestsplit routines translated from Leo's
 * original Fortran code.
 *
 *      copyright 1999 by leo Breiman
 *      this is free software and can be used for any purpose.
 *      It comes with no guarantee.
 *
 ******************************************************************/
#include <Rmath.h>
#include <R.h>
#include "rf.h"

/*
 * Running mean of y over the cases stored at positions first..last
 * (inclusive) of jdex.  jdex holds 1-based case numbers.  Returns 0.0 for
 * an empty range.  The update formula is kept exactly as in the original
 * in-line loops so that node means are numerically unchanged.
 */
static double regNodeMean(const double *y, const int *jdex, int first,
                          int last) {
    double av = 0.0;
    int j, m;

    for (j = first; j <= last; ++j) {
        m = j - first;          /* number of cases already averaged */
        av = (m * av + y[jdex[j] - 1]) / (m + 1);
    }
    return av;
}

/*
 * Grow one regression tree.
 *
 *   x          predictors, indexed as x[var + case * mdim]
 *   y          response, one value per case
 *   mdim       number of predictor variables
 *   nsample    number of cases
 *   lDaughter, rDaughter
 *              output: 1-based node numbers of the two daughters of each
 *              split node
 *   upper      output: split value of each split node (for a categorical
 *              split, the value produced by pack() in findBestSplit)
 *   avnode     output: mean of y over the cases in each node
 *   nodestatus output: NODE_INTERIOR or NODE_TERMINAL for every node in use,
 *              0 for unused slots
 *   nrnodes    capacity of the per-node arrays
 *   treeSize   output: nrnodes minus the number of slots whose status is 0
 *   nthsize    a daughter holding nthsize or fewer cases is made terminal
 *   mtry       number of candidate variables tried at each node
 *   mbest      output: 1-based index of the split variable of each node
 *   cat        per variable: 1 = numeric, >1 = number of categories
 *   tgini      in/out: per-variable accumulated split improvement
 *   varUsed    output: set to 1 for every variable used in a split
 */
void regTree(double *x, double *y, int mdim, int nsample, int *lDaughter,
             int *rDaughter, double *upper, double *avnode, int *nodestatus,
             int nrnodes, int *treeSize, int nthsize, int mtry, int *mbest,
             int *cat, double *tgini, int *varUsed) {
    int i, k, ncur, *jdex, *nodestart, *nodepop;
    int ndstart, ndend, ndendl, nodecnt, jstat, msplit;
    double decsplit, ubest, sumnode;

    nodestart = (int *) Calloc(nrnodes, int);
    nodepop = (int *) Calloc(nrnodes, int);

    /* initialize some arrays for the tree */
    zeroInt(nodestatus, nrnodes);
    zeroInt(nodestart, nrnodes);
    zeroInt(nodepop, nrnodes);
    zeroDouble(avnode, nrnodes);

    /* jdex[p] is the 1-based case number stored at position p.  The cases of
       a node always occupy one contiguous run of positions. */
    jdex = (int *) Calloc(nsample, int);
    for (i = 1; i <= nsample; ++i) jdex[i - 1] = i;

    /* The root node holds every case. */
    ncur = 0;
    nodestart[0] = 0;
    nodepop[0] = nsample;
    nodestatus[0] = NODE_TOSPLIT;
    avnode[0] = regNodeMean(y, jdex, 0, nsample - 1);

    /* start main loop */
    for (k = 0; k < nrnodes - 2; ++k) {
        /* Stop when no node is left, or no room remains for two more. */
        if (k > ncur || ncur >= nrnodes - 2) break;
        /* skip if the node is not to be split */
        if (nodestatus[k] != NODE_TOSPLIT) continue;

        /* initialize for next call to findbestsplit */
        ndstart = nodestart[k];
        ndend = ndstart + nodepop[k] - 1;
        nodecnt = nodepop[k];
        sumnode = nodecnt * avnode[k];
        jstat = 0;
        decsplit = 0.0;

        findBestSplit(x, jdex, y, mdim, nsample, ndstart, ndend, &msplit,
                      &decsplit, &ubest, &ndendl, &jstat, mtry, sumnode,
                      nodecnt, cat);
        if (jstat == 1) {
            /* Node is terminal: Mark it as such and move on to the next. */
            nodestatus[k] = NODE_TERMINAL;
            continue;
        }

        /* Found the best split. */
        mbest[k] = msplit;
        varUsed[msplit - 1] = 1;
        upper[k] = ubest;
        tgini[msplit - 1] += decsplit;
        nodestatus[k] = NODE_INTERIOR;

        /* leftnode no.= ncur+1, rightnode no. = ncur+2. */
        nodepop[ncur + 1] = ndendl - ndstart + 1;
        nodepop[ncur + 2] = ndend - ndendl;
        nodestart[ncur + 1] = ndstart;
        nodestart[ncur + 2] = ndendl + 1;

        /* mean of the left daughter node */
        avnode[ncur + 1] = regNodeMean(y, jdex, ndstart, ndendl);
        nodestatus[ncur + 1] = NODE_TOSPLIT;
        if (nodepop[ncur + 1] <= nthsize) {
            nodestatus[ncur + 1] = NODE_TERMINAL;
        }

        /* mean of the right daughter node */
        avnode[ncur + 2] = regNodeMean(y, jdex, ndendl + 1, ndend);
        nodestatus[ncur + 2] = NODE_TOSPLIT;
        if (nodepop[ncur + 2] <= nthsize) {
            nodestatus[ncur + 2] = NODE_TERMINAL;
        }

        /* map the daughter nodes (stored as 1-based node numbers) */
        lDaughter[k] = ncur + 1 + 1;
        rDaughter[k] = ncur + 2 + 1;
        /* Augment the tree by two nodes. */
        ncur += 2;
    }

    /* Count the slots in use; nodes still waiting to be split when the
       loop ended become terminal. */
    *treeSize = nrnodes;
    for (k = nrnodes - 1; k >= 0; --k) {
        if (nodestatus[k] == 0) (*treeSize)--;
        if (nodestatus[k] == NODE_TOSPLIT) {
            nodestatus[k] = NODE_TERMINAL;
        }
    }

    Free(nodestart);
    Free(jdex);
    Free(nodepop);
}

/*--------------------------------------------------------------*/
/*
 * Find the best split of one node among mtry randomly chosen variables.
 *
 *   x, y, mdim, nsample, cat   as in regTree
 *   jdex       in/out: 1-based case numbers; positions ndstart..ndend hold
 *              the cases of this node.  When a split is found they are
 *              reordered so that the left daughter's cases come first.
 *   ndstart, ndend
 *              first and last position (0-based, inclusive) of the node
 *   msplit     output: 1-based index of the best variable, -1 if none
 *   decsplit   output: improvement achieved by the best split
 *   ubest      output: split point; for a categorical variable the value
 *              returned by pack() for the 0/1 "goes left" indicators
 *   ndendl     output: last position belonging to the left daughter
 *   jstat      output: set to 1 when no split was found (left untouched
 *              otherwise)
 *   mtry       number of variables to try, drawn without replacement
 *   sumnode    sum of y over the node
 *   nodecnt    number of cases in the node
 *
 * Categorical variables are limited to 32 categories by the fixed-size work
 * arrays.
 */
void findBestSplit(double *x, int *jdex, double *y, int mdim, int nsample,
                   int ndstart, int ndend, int *msplit, double *decsplit,
                   double *ubest, int *ndendl, int *jstat, int mtry,
                   double sumnode, int nodecnt, int *cat) {
    int last, ncat[32], icat[32], lc, nl, nr, npopl, npopr;
    int i, j, kv, l, *mind, *ncase;
    double *xt, *ut, *v, *yl, sumcat[32], avcat[32], tavcat[32], ubestt;
    double crit, critmax, critvar, suml, sumr, d, critParent;

    ut = (double *) Calloc(nsample, double);
    xt = (double *) Calloc(nsample, double);
    v = (double *) Calloc(nsample, double);
    yl = (double *) Calloc(nsample, double);
    mind = (int *) Calloc(mdim, int);
    ncase = (int *) Calloc(nsample, int);
    zeroDouble(avcat, 32);
    zeroDouble(tavcat, 32);

    /* START BIG LOOP */
    *msplit = -1;
    *decsplit = 0.0;
    critmax = 0.0;
    ubestt = 0.0;
    for (i = 0; i < mdim; ++i) mind[i] = i;

    last = mdim - 1;
    for (i = 0; i < mtry; ++i) {
        critvar = 0.0;
        /* Draw a variable from mind[0..last], then move it out of the pool
           so that it cannot be drawn again for this node. */
        j = (int) (unif_rand() * (last + 1));
        kv = mind[j];
        swapInt(mind[j], mind[last]);
        last--;

        lc = cat[kv];
        if (lc == 1) {
            /* numeric variable */
            for (j = ndstart; j <= ndend; ++j) {
                xt[j] = x[kv + (jdex[j] - 1) * mdim];
                yl[j] = y[jdex[j] - 1];
            }
        } else {
            /* categorical variable */
            zeroInt(ncat, 32);
            zeroDouble(sumcat, 32);
            for (j = ndstart; j <= ndend; ++j) {
                l = (int) x[kv + (jdex[j] - 1) * mdim];
                sumcat[l - 1] += y[jdex[j] - 1];
                ncat[l - 1]++;
            }
            /* Compute means of Y by category. */
            for (j = 0; j < lc; ++j) {
                avcat[j] = ncat[j] ? sumcat[j] / ncat[j] : 0.0;
            }
            /* Make the category mean the `pseudo' X data.  Only positions
               inside the node are ever read back, so only those are
               filled. */
            for (j = ndstart; j <= ndend; ++j) {
                xt[j] = avcat[(int) x[kv + (jdex[j] - 1) * mdim] - 1];
                yl[j] = y[jdex[j] - 1];
            }
        }

        /* Copy the x data in this node and sort it.  ncase[j] starts as the
           1-based position j + 1 and is permuted along with v, so after the
           sort ncase[j] - 1 is the position that supplied v[j]. */
        for (j = ndstart; j <= ndend; ++j) {
            v[j] = xt[j];
            ncase[j] = j + 1;
        }
        R_qsort_I(v, ncase, ndstart + 1, ndend + 1);
        /* All values equal: this variable cannot split the node. */
        if (v[ndstart] >= v[ndend]) continue;

        /* Scan the sorted values in ascending order, moving one case at a
           time from the right group to the left group. */
        critParent = sumnode * sumnode / nodecnt;
        suml = 0.0;
        sumr = sumnode;
        npopl = 0;
        npopr = nodecnt;
        crit = 0.0;
        /* Search through the "gaps" in the x-variable. */
        for (j = ndstart; j <= ndend - 1; ++j) {
            d = yl[ncase[j] - 1];
            suml += d;
            sumr -= d;
            npopl++;
            npopr--;
            if (v[j] < v[j + 1]) {
                /* Reduction in the sum of squares of y from splitting
                   between v[j] and v[j+1]. */
                crit = (suml * suml / npopl) + (sumr * sumr / npopr) -
                    critParent;
                if (crit > critvar) {
                    ubestt = (v[j] + v[j + 1]) / 2.0;
                    critvar = crit;
                }
            }
        }
        if (critvar > critmax) {
            /* Best variable so far: remember its split point, its (pseudo)
               x values and, if categorical, its category means. */
            *ubest = ubestt;
            *msplit = kv + 1;
            critmax = critvar;
            for (j = ndstart; j <= ndend; ++j) {
                ut[j] = xt[j];
            }
            if (cat[kv] > 1) {
                for (j = 0; j < cat[kv]; ++j) tavcat[j] = avcat[j];
            }
        }
    }
    *decsplit = critmax;

    /* If best split can not be found, set to terminal node and return. */
    if (*msplit != -1) {
        /* Reorder the node's cases: those at or below the split point
           first, then the rest.  ncase is reused as scratch space. */
        nl = ndstart;
        for (j = ndstart; j <= ndend; ++j) {
            if (ut[j] <= *ubest) {
                nl++;
                ncase[nl - 1] = jdex[j];
            }
        }
        *ndendl = imax2(nl - 1, ndstart);
        nr = *ndendl + 1;
        for (j = ndstart; j <= ndend; ++j) {
            if (ut[j] > *ubest) {
                if (nr >= nsample) break;
                nr++;
                ncase[nr - 1] = jdex[j];
            }
        }
        if (*ndendl >= ndend) *ndendl = ndend - 1;
        for (j = ndstart; j <= ndend; ++j) jdex[j] = ncase[j];

        lc = cat[*msplit - 1];
        if (lc > 1) {
            /* A category goes left when its mean of y is below the split
               point; the indicators are combined by pack(). */
            for (j = 0; j < lc; ++j) {
                icat[j] = (tavcat[j] < *ubest) ? 1 : 0;
            }
            *ubest = pack(lc, icat);
        }
    } else *jstat = 1;

    Free(ncase);
    Free(mind);
    Free(v);
    Free(yl);
    Free(xt);
    Free(ut);
}

/*====================================================================*/
/*
 * Drop every case down one regression tree.
 *
 *   x          predictors, indexed as x[var + case * mdim]
 *   nsample    number of cases to predict
 *   mdim       number of predictor variables
 *   lDaughter, rDaughter
 *              1-based daughter node numbers of each split node
 *   nodestatus node status; descent stops at NODE_TERMINAL
 *   ypred      output: prediction for each case
 *   split      split value of each node; for a categorical split variable
 *              an integer whose bit c is 1 when category c + 1 goes left
 *   nodepred   prediction stored at each node
 *   splitVar   1-based index of the split variable of each node
 *   treeSize   number of nodes in the tree
 *   cat        per variable: 1 = numeric, >1 = categorical
 *   maxcat     largest entry of cat; when it is 1 no categorical decoding
 *              is done
 *   nodex      output: 1-based number of the terminal node each case
 *              lands in
 */
void predictRegTree(double *x, int nsample, int mdim,
                    int *lDaughter, int *rDaughter, int *nodestatus,
                    double *ypred, double *split, double *nodepred,
                    int *splitVar, int treeSize, int *cat, int maxcat,
                    int *nodex) {
    int i, j, k, m, *cbestsplit = NULL;
    /* Unsigned on purpose: a split over 32 categories can have bit 31 set.
       Such a value does not fit in an int, and a negative int would never
       be shifted down to 0 by an arithmetic >>, so the unpack loop would
       run past the end of cbestsplit. */
    unsigned int npack;

    /* decode the categorical splits */
    if (maxcat > 1) {
        cbestsplit = (int *) Calloc(maxcat * treeSize, int);
        zeroInt(cbestsplit, maxcat * treeSize);
        for (i = 0; i < treeSize; ++i) {
            if (nodestatus[i] != NODE_TERMINAL && cat[splitVar[i] - 1] > 1) {
                npack = (unsigned int) split[i];
                /* unpack `npack' into bits */
                for (j = 0; npack; npack >>= 1, ++j) {
                    cbestsplit[j + i * maxcat] = npack & 1;
                }
            }
        }
    }

    for (i = 0; i < nsample; ++i) {
        k = 0;
        while (nodestatus[k] != NODE_TERMINAL) { /* go down the tree */
            m = splitVar[k] - 1;
            if (cat[m] == 1) {
                k = (x[m + i * mdim] <= split[k]) ?
                    lDaughter[k] - 1 : rDaughter[k] - 1;
            } else {
                /* Split by a categorical predictor */
                k = cbestsplit[(int) x[m + i * mdim] - 1 + k * maxcat] ?
                    lDaughter[k] - 1 : rDaughter[k] - 1;
            }
        }
        /* terminal node: assign prediction and move on to next */
        ypred[i] = nodepred[k];
        nodex[i] = k + 1;
    }
    if (maxcat > 1) Free(cbestsplit);
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -