test_lwr_2d.m
function test_lwr_2D
% Test of locally weighted regression (LWR) on a 2D data set. Because a
% single distance metric is optimized globally, overfitting can be
% observed in the flat parts of the function.

n = 500;

% a random training set using the CROSS function, with additive noise
X = (rand(n,2)-.5)*2;
Y = max([exp(-X(:,1).^2 * 10),exp(-X(:,2).^2 * 50),1.25*exp(-(X(:,1).^2+X(:,2).^2)*5)]');
Y = Y' + randn(n,1)*0.05;

% a systematic test set: a 41 x 41 grid over [-1,1] x [-1,1]
Xt = [];
for i=-1:0.05:1
  for j=-1:0.05:1
    Xt = [Xt; i j];
  end
end
Yt = max([exp(-Xt(:,1).^2 * 10),exp(-Xt(:,2).^2 * 50),1.25*exp(-(Xt(:,1).^2+Xt(:,2).^2)*5)]');
Yt = Yt';

% find the optimal distance metric by leave-one-out cross validation;
% D runs from Dmin to Dmax in n_iter+1 steps, with D-(Dmin-1) spaced geometrically
Dmin = 10;
Dmax = 1000;
n_iter = 10;
R = zeros(n_iter+1,2);   % each row: [D, cross-validation MSE]

for j=0:n_iter
  D = Dmin-1 + exp(log(Dmax-(Dmin-1))/n_iter*j);
  DD = diag([ D D ]);
  mse_cv = 0;
  for i=1:n
    % leave out the i-th training point and predict it from the rest
    XX=X;
    YY=Y;
    XX(i,:)=[];
    YY(i,:)=[];
    [beta,yq]=lwr(XX,YY,DD,X(i,:)');
    mse_cv = mse_cv+(Y(i)-yq)^2;
  end
  mse_cv = mse_cv/n;
  R(j+1,:)=[D,mse_cv];
  fprintf('%3d: D=%f mse_cv=%f\n',j,D,mse_cv);
end

% pick the D with the lowest cross-validation error
[val,ind] = min(R(:,2));
D = R(ind,1);
DD = diag([ D D ]);

% create the final LWR fit on the test grid
Yp = zeros(size(Yt));
for i=1:size(Xt,1)
  [beta,yq]=lwr(X,Y,DD,Xt(i,:)');
  Yp(i,1) = yq;
end

figure(1);
clf;

% training data
subplot(2,2,1);
plot3(X(:,1),X(:,2),Y,'*');
axis([-1 1 -1 1 -.5 1.5]);

% LWR prediction with the optimal D
subplot(2,2,2);
[x,y,z]=makesurf([Xt,Yp],sqrt(size(Xt,1)));
surfl(x,y,z);
axis([-1 1 -1 1 -.5 1.5]);
title(sprintf('Optimal D=%f',D));

% true (noise-free) function
subplot(2,2,3);
[x,y,z]=makesurf([Xt,Yt],sqrt(size(Xt,1)));
surfl(x,y,z);
axis([-1 1 -1 1 -.5 1.5]);

% cross-validation error versus D
subplot(2,2,4);
plot(R(:,1),R(:,2));

% --------------------------------------------------------------------------------
function [X,Y,Z]=makesurf(data,nx)
% [X,Y,Z]=makesurf(data,nx) converts the 3-column matrix data into the
% three matrices needed by surf(). nx is the number of rows of each
% output matrix, i.e. how many consecutive rows of data form one column.

[m,n]=size(data);

n=0;
for i=1:nx:m
  n = n+1;
  X(:,n) = data(i:i+nx-1,1);
  Y(:,n) = data(i:i+nx-1,2);
  Z(:,n) = data(i:i+nx-1,3);
end
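The script calls `lwr(X, Y, D, xq)`, which this file does not define. If the original routine is not available, the following is a minimal stand-in sketch, assuming `lwr` performs Gaussian-kernel locally weighted linear regression with `D` as the distance metric and returns the local coefficients `beta` and the prediction `yq`. The kernel form, the centring on the query point and the small ridge term are assumptions, not the original author's implementation. Save it as `lwr.m` on the MATLAB path.

```matlab
function [beta, yq] = lwr(X, Y, D, xq)
% LWR  Hypothetical stand-in for the lwr() routine used by test_lwr_2D.
%   [beta, yq] = lwr(X, Y, D, xq) fits a locally weighted linear model
%   around the query point xq (d x 1), given training inputs X (n x d),
%   targets Y (n x 1) and a d x d distance metric D.
%   beta : (d+1) x 1 local coefficients [slopes; intercept]
%   yq   : prediction at xq
%   ASSUMPTION: Gaussian kernel weights and a 1e-8 ridge term.

  [n, d] = size(X);

  % offsets of every training point from the query (n x d)
  dx = X - repmat(xq', n, 1);

  % Gaussian kernel weights w_i = exp(-0.5 * dx_i' * D * dx_i)
  w = exp(-0.5 * sum((dx * D) .* dx, 2));

  % linear model in coordinates centred on xq plus a bias column,
  % so the intercept is the prediction at the query point
  Xc = [dx, ones(n, 1)];

  % weighted least squares: (Xc' W Xc + ridge*I) beta = Xc' W Y
  WX = Xc .* repmat(w, 1, d + 1);
  beta = (Xc' * WX + 1e-8 * eye(d + 1)) \ (WX' * Y);

  yq = beta(end);
end
```

Centring the inputs on `xq` means the prediction is simply the fitted intercept, and `repmat` is used instead of implicit expansion so the sketch also works on older MATLAB releases. Any cross-validation results obtained with it reflect these assumptions rather than the original routine.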