📄 bayestree
字号:
# Name of the package this code belongs to.
.packageName <- "BayesTree"

#' Bayesian Additive Regression Trees (BART) for a continuous response.
#'
#' Validates the inputs, rescales `y.train` linearly onto [-0.5, 0.5], runs
#' the compiled sampler `mbart` through `.C()`, and maps the returned draws
#' back onto the original scale of `y.train`.
#'
#' @param x.train Double matrix of training predictors (one row per case).
#' @param y.train Double vector of responses, length `nrow(x.train)`.
#' @param x.test Double matrix of test predictors with `ncol(x.train)`
#'   columns; the default 0 x 0 matrix means "no test set".
#' @param sigest Error-sd estimate on the scale of `y.train`. When `NA` it is
#'   taken from `lm()` fitted to the (rescaled) training data.
#' @param sigdf,sigquant,k,power,base,ntree,printevery,usequants,printcutoffs,verbose
#'   Passed to the compiled sampler (see the `.C()` call for the C types).
#' @param ndpost,nskip,keepevery Kept draws, burn-in draws and thinning
#'   interval; `ndpost` and `nskip` are each divided by `keepevery` (rounded
#'   down) to get the number of stored draws.
#' @param numcut Number of cut points per predictor; a single value is
#'   recycled to `ncol(x.train)`.
#' @param keeptrainfits If TRUE, return the fitted draws at the training rows.
#' @return Invisibly, an object of class "bart": a list with `call`,
#'   `first.sigma` (burn-in sigma draws), `sigma` (kept sigma draws), `sigest`,
#'   `yhat.train`, `yhat.train.mean`, `yhat.test`, `yhat.test.mean`,
#'   `varcount` and `y`. Draw matrices have one row per kept draw.
bart <- function(x.train, y.train, x.test = matrix(0.0, 0, 0),
                 sigest = NA, sigdf = 3, sigquant = .90,
                 k = 2.0, power = 2.0, base = .95,
                 ntree = 200,
                 ndpost = 1000, nskip = 100,
                 printevery = 100, keepevery = 1, keeptrainfits = TRUE,
                 usequants = FALSE, numcut = 100, printcutoffs = 0,
                 verbose = TRUE) {
  # ---- check input arguments ----
  if ((!is.matrix(x.train)) || (typeof(x.train) != "double")) {
    stop("argument x.train must be a double matrix")
  }
  if ((!is.matrix(x.test)) || (typeof(x.test) != "double")) {
    stop("argument x.test must be a double matrix")
  }
  if ((!is.vector(y.train)) || (typeof(y.train) != "double")) {
    stop("argument y.train must be a double vector")
  }
  if (nrow(x.train) != length(y.train)) {
    stop("number of rows in x.train must equal length of y.train")
  }
  if ((nrow(x.test) > 0) && (ncol(x.test) != ncol(x.train))) {
    stop("input x.test must have the same number of columns as x.train")
  }
  if ((!is.na(sigest)) && (typeof(sigest) != "double")) {
    stop("input sigest must be double")
  }
  if ((!is.na(sigest)) && (sigest < 0.0)) {
    stop("input sigest must be positive")
  }
  # Scalar tuning arguments that must be non-negative numbers; checked in
  # this order so the first offending argument is the one reported.
  numeric_args <- list(sigdf = sigdf, printevery = printevery,
                       keepevery = keepevery, sigquant = sigquant,
                       ntree = ntree, ndpost = ndpost, nskip = nskip, k = k)
  for (nm in names(numeric_args)) {
    val <- numeric_args[[nm]]
    if ((mode(val) != "numeric") || (val < 0)) {
      stop("input ", nm, " must be a positive number")
    }
  }
  if (mode(numcut) != "numeric") {
    stop("input numcut must be a numeric vector")
  }
  if (length(numcut) == 1) {
    numcut <- rep(numcut, ncol(x.train))
  }
  if (length(numcut) != ncol(x.train)) {
    stop("length of numcut must equal number of columns of x.train")
  }
  numcut <- as.integer(numcut)
  if (min(numcut) < 1) {
    stop("numcut must be >= 1")
  }
  if (typeof(usequants) != "logical") {
    stop("input usequants must be a logical variable")
  }
  if (typeof(keeptrainfits) != "logical") {
    stop("input keeptrainfits must be a logical variable")
  }
  if (typeof(verbose) != "logical") {
    stop("input verbose must be a logical variable")
  }
  if (mode(printcutoffs) != "numeric") {
    stop("input printcutoffs must be numeric")
  }
  printcutoffs <- as.integer(printcutoffs)
  if (printcutoffs < 0) {
    stop("input printcutoffs must be >=0")
  }
  if (power <= 0) {
    stop("power must be positive")
  }
  if (base <= 0) {
    stop("base must be positive")
  }

  # ---- rescale y onto [-0.5, 0.5] ----
  rgy <- range(y.train)
  yspan <- rgy[2] - rgy[1]
  y <- -.5 + (y.train - rgy[1]) / yspan

  # sigest is used on the transformed scale, so a default estimate is taken
  # from a linear fit to the rescaled y, and a user value is rescaled.
  if (is.na(sigest)) {
    templm <- lm(y ~ x.train)
    sigest <- summary(templm)$sigma
  } else {
    sigest <- sigest / yspan
  }

  # ---- run the sampler ----
  ncskip <- floor(nskip / keepevery)
  ncpost <- floor(ndpost / keepevery)
  nctot <- ncskip + ncpost
  totnd <- keepevery * nctot

  cres <- .C("mbart",
             as.integer(nrow(x.train)), as.integer(ncol(x.train)),
             as.integer(nrow(x.test)),
             as.double(x.train), as.double(y), as.double(x.test),
             as.double(sigest), as.integer(sigdf), as.double(sigquant),
             as.double(k), as.double(power), as.double(base),
             as.integer(ntree), as.integer(totnd), as.integer(printevery),
             as.integer(keepevery), as.integer(keeptrainfits),
             as.integer(numcut), as.integer(usequants),
             as.integer(printcutoffs), as.integer(verbose),
             sdraw = double(nctot),
             trdraw = double(nrow(x.train) * nctot),
             tedraw = double(nrow(x.test) * nctot),
             vcdraw = integer(ncol(x.train) * nctot))

  # ---- read back the results ----
  # Stored draws are burn-in first, then the kept draws. seq_len() keeps the
  # index sets correct even when nskip or ndpost yields zero stored draws.
  burn <- seq_len(ncskip)
  keep <- ncskip + seq_len(ncpost)

  all_sigma <- cres$sdraw * yspan
  first.sigma <- all_sigma[burn]
  sigma <- all_sigma[keep]
  sigest <- sigest * yspan  # report sigest on the original y scale

  # Flat draw vectors hold one draw after another, so reshape row-wise and
  # keep the post-burn-in rows. drop = FALSE keeps a matrix when only a
  # single draw is kept (apply() below needs the dimensions).
  kept_rows <- function(v) {
    matrix(v, nrow = nctot, byrow = TRUE)[keep, , drop = FALSE]
  }
  to_y_scale <- function(f) yspan * (f + .5) + rgy[1]

  yhat.train <- yhat.test <- yhat.train.mean <- yhat.test.mean <- NULL
  if (keeptrainfits) {
    yhat.train <- to_y_scale(kept_rows(cres$trdraw))
    yhat.train.mean <- apply(yhat.train, 2, mean)
  }
  if (nrow(x.test) > 0) {
    yhat.test <- to_y_scale(kept_rows(cres$tedraw))
    yhat.test.mean <- apply(yhat.test, 2, mean)
  }
  varcount <- kept_rows(cres$vcdraw)

  retval <- list(
    call = match.call(),
    first.sigma = first.sigma,
    sigma = sigma,
    sigest = sigest,
    yhat.train = yhat.train,
    yhat.train.mean = yhat.train.mean,
    yhat.test = yhat.test,
    yhat.test.mean = yhat.test.mean,
    varcount = varcount,
    y = y.train
  )
  class(retval) <- "bart"
  invisible(retval)
}

# Default evaluation grid for each predictor column in `xind`. Uses the sorted
# distinct values when there are no more of them than `levquants`; otherwise
# uses the distinct sample quantiles at `levquants`. Shared by pdbart() and
# pd2bart() so both choose grids the same way.
.bart_pd_levels <- function(x.train, xind, levquants) {
  lapply(xind, function(j) {
    ux <- unique(x.train[, j])
    if (length(ux) <= length(levquants)) {
      sort(ux)
    } else {
      unique(quantile(x.train[, j], probs = levquants))
    }
  })
}

#' Two-variable partial dependence from a BART fit.
#'
#' Evaluates the BART fit on every combination of grid values for the two
#' predictors `xind`, averaging over the training rows for the remaining
#' predictors.
#'
#' @param x.train,y.train Training data, passed to bart().
#' @param xind Column indices of the two predictors (first two are used).
#' @param levs Optional list of two grids; computed from `levquants` if NULL.
#' @param levquants Quantile levels used to build default grids.
#' @param pl If TRUE, plot the result.
#' @param plquants Lower/upper quantiles passed to the plot method.
#' @param ... Further arguments passed to bart().
#' @return An object of class "pd2bart"; `fd` is a draws x grid-points matrix
#'   whose columns run over `expand.grid(levs[[1]], levs[[2]])`.
pd2bart <- function(x.train, y.train, xind = 1:2, levs = NULL,
                    levquants = c(.05, (1:9) / 10, .95),
                    pl = TRUE, plquants = c(.05, .95), ...) {
  n <- nrow(x.train)
  if (is.null(levs)) {
    levs <- .bart_pd_levels(x.train, xind[1:2], levquants)
  }
  xvals <- as.matrix(expand.grid(levs[[1]], levs[[2]]))
  nxvals <- nrow(xvals)

  if (ncol(x.train) == 2) {
    cat("special case: only 2 xs\n")
    # Put each grid column into the predictor column it belongs to, so that
    # xind = c(2, 1) is handled correctly.
    x.test <- matrix(0.0, nxvals, 2)
    x.test[, xind[1:2]] <- xvals
  } else {
    # One copy of x.train per grid point, with the two predictors fixed.
    x.test <- do.call(rbind, lapply(seq_len(nxvals), function(v) {
      temp <- x.train
      temp[, xind[1]] <- xvals[v, 1]
      temp[, xind[2]] <- xvals[v, 2]
      temp
    }))
  }

  pdbrt <- bart(x.train, y.train, x.test, ...)

  if (ncol(x.train) == 2) {
    fdr <- pdbrt$yhat.test
  } else {
    # Average each draw over the n rows belonging to the same grid point.
    fdr <- matrix(0, nrow(pdbrt$yhat.test), nxvals)
    for (i in seq_len(nxvals)) {
      cind <- (i - 1) * n + seq_len(n)
      fdr[, i] <- apply(pdbrt$yhat.test[, cind, drop = FALSE], 1, mean)
    }
  }

  xlbs <- if (is.null(colnames(x.train))) {
    paste0("x", xind)
  } else {
    colnames(x.train)[xind]
  }
  retval <- list(fd = fdr, levs = levs, xlbs = xlbs,
                 bartcall = pdbrt$call, yhat.train = pdbrt$yhat.train,
                 first.sigma = pdbrt$first.sigma, sigma = pdbrt$sigma,
                 yhat.train.mean = pdbrt$yhat.train.mean,
                 sigest = pdbrt$sigest, y = pdbrt$y)
  class(retval) <- "pd2bart"
  if (pl) {
    plot(retval, plquants = plquants)
  }
  retval
}

#' One-variable partial dependence from a BART fit.
#'
#' For each predictor in `xind` and each of its grid values, fixes that
#' predictor at the value for all training rows and averages the fit.
#'
#' @param x.train,y.train Training data, passed to bart().
#' @param xind Column indices of the predictors to examine.
#' @param levs Optional list of grids (one per `xind`); computed if NULL.
#' @param levquants Quantile levels used to build default grids.
#' @param pl If TRUE, plot the result.
#' @param plquants Lower/upper quantiles passed to the plot method.
#' @param ... Further arguments passed to bart().
#' @return An object of class "pdbart"; `fd` is a list with one
#'   draws x grid-values matrix per predictor in `xind`.
pdbart <- function(x.train, y.train, xind = 1:ncol(x.train), levs = NULL,
                   levquants = c(.05, (1:9) / 10, .95),
                   pl = TRUE, plquants = c(.05, .95), ...) {
  n <- nrow(x.train)
  nvar <- length(xind)
  if (is.null(levs)) {
    levs <- .bart_pd_levels(x.train, xind, levquants)
  }
  nlevels <- vapply(levs, length, integer(1))

  # One copy of x.train per (predictor, grid value) pair, in that order.
  blocks <- vector("list", sum(nlevels))
  b <- 0
  for (i in seq_len(nvar)) {
    for (v in levs[[i]]) {
      b <- b + 1
      temp <- x.train
      temp[, xind[i]] <- v
      blocks[[b]] <- temp
    }
  }
  x.test <- do.call(rbind, blocks)

  pdbrt <- bart(x.train, y.train, x.test, ...)

  # Walk the test columns in the same order they were built; `cnt` is the
  # offset of the current predictor's first block.
  ndraw <- nrow(pdbrt$yhat.test)
  fdr <- vector("list", nvar)
  cnt <- 0
  for (j in seq_len(nvar)) {
    fdrtemp <- matrix(0, ndraw, nlevels[j])
    for (i in seq_len(nlevels[j])) {
      cind <- cnt + (i - 1) * n + seq_len(n)
      fdrtemp[, i] <- apply(pdbrt$yhat.test[, cind, drop = FALSE], 1, mean)
    }
    fdr[[j]] <- fdrtemp
    cnt <- cnt + n * nlevels[j]
  }

  xlbs <- if (is.null(colnames(x.train))) {
    paste0("x", xind)
  } else {
    colnames(x.train)[xind]
  }
  retval <- list(fd = fdr, levs = levs, xlbs = xlbs,
                 bartcall = pdbrt$call, yhat.train = pdbrt$yhat.train,
                 first.sigma = pdbrt$first.sigma, sigma = pdbrt$sigma,
                 yhat.train.mean = pdbrt$yhat.train.mean,
                 sigest = pdbrt$sigest, y = pdbrt$y)
  class(retval) <- "pdbart"
  if (pl) {
    plot(retval, plquants = plquants)
  }
  retval
}

#' Plot method for "bart" objects.
#'
#' Left panel: all sigma draws (burn-in in red, kept draws in black).
#' Right panel: posterior median of f(x) at each training point against the
#' observed y, with `plquants` intervals and a dashed 45-degree line.
#' Sets `par(mfrow = c(1, 2))` and leaves it set.
plot.bart <- function(x, plquants = c(.05, .95), cols = c("blue", "black"),
                      ...) {
  par(mfrow = c(1, 2))
  plot(c(x$first.sigma, x$sigma),
       col = rep(c("red", "black"),
                 c(length(x$first.sigma), length(x$sigma))),
       ylab = "sigma", ...)
  # No training fits are stored when bart() ran with keeptrainfits = FALSE.
  if (is.null(x$yhat.train)) {
    return(invisible(NULL))
  }
  ql <- apply(x$yhat.train, 2, quantile, probs = plquants[1])
  qm <- apply(x$yhat.train, 2, quantile, probs = .5)
  qu <- apply(x$yhat.train, 2, quantile, probs = plquants[2])
  plot(x$y, qm, ylim = range(ql, qu), xlab = "y",
       ylab = "posterior interval for E(Y|x)", ...)
  segments(x$y, ql, x$y, qu, col = cols[1])
  abline(0, 1, lty = 2, col = cols[2])
}

#' Plot method for "pd2bart" objects.
#'
#' Draws an image with contour overlay of the partial-dependence median, or
#' of the lower quantile, median and upper quantile side by side when
#' `justmedian = FALSE` (this sets `par(mfrow = c(1, 3))`).
plot.pd2bart <- function(x, plquants = c(.05, .95), contour.color = "white",
                         justmedian = TRUE, ...) {
  pdquants <- apply(x$fd, 2, quantile,
                    probs = c(plquants[1], .5, plquants[2]))
  # Reshape each quantile row back onto the levs[[1]] x levs[[2]] grid.
  qq <- lapply(1:3, function(i) {
    matrix(pdquants[i, ], nrow = length(x$levs[[1]]))
  })
  if (justmedian) {
    zlim <- range(qq[[2]])
    vind <- 2
  } else {
    par(mfrow = c(1, 3))
    zlim <- range(qq)
    vind <- 1:3
  }
  titles <- c("Lower quantile", "Median", "Upper quantile")
  for (i in vind) {
    image(x = x$levs[[1]], y = x$levs[[2]], qq[[i]], zlim = zlim,
          xlab = x$xlbs[1], ylab = x$xlbs[2], ...)
    contour(x = x$levs[[1]], y = x$levs[[2]], qq[[i]], zlim = zlim,
            add = TRUE, method = "edge", col = contour.color)
    title(main = titles[i])
  }
}

#' Plot method for "pdbart" objects.
#'
#' One plot per selected predictor: the median partial dependence (cols[1])
#' and the `plquants` bounds (cols[2]) over that predictor's grid. All plots
#' share a y range taken from every predictor's draws.
plot.pdbart <- function(x, xind = seq_along(x$fd), plquants = c(.05, .95),
                        cols = c("black", "blue"), ...) {
  rgy <- range(x$fd)
  for (i in xind) {
    tsum <- apply(x$fd[[i]], 2, quantile,
                  probs = c(plquants[1], .5, plquants[2]))
    plot(range(x$levs[[i]]), rgy, type = "n", xlab = x$xlbs[i],
         ylab = "partial-dependence", ...)
    lines(x$levs[[i]], tsum[2, ], col = cols[1], type = "b")
    lines(x$levs[[i]], tsum[1, ], col = cols[2], type = "b")
    lines(x$levs[[i]], tsum[3, ], col = cols[2], type = "b")
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -