fghd_lagr.m
function [f,grd,hess_diag,h]=fghd_lagr(w,X,fghpar)
% FGHD_LAGR  Objective, gradient, Hessian diagonal and Hessian of
%   f(W) = -log|det(W)| + (mu/T)*sum(h(W*X)),
% where h is the penalty selected by fghpar.pentype.
%
%   w       - vectorized Nsn-by-Nsn matrix W, stored row by row
%   X       - Nsn-by-T data matrix
%   fghpar  - struct with fields penpar, pentype and, for the
%             'boxconstr' and 'smooth_lq' penalties, lagrmult
%
% Note: h holds the penalty values unless four outputs are requested,
% in which case it is overwritten with the Hessian.

mu= 1e0;
penpar=fghpar.penpar;
pentype= fghpar.pentype;

[Nsn,T]=size(X);
W=reshape(w,Nsn,Nsn)';
S=W*X;
A=inv(W);

if strcmp(pentype,'boxconstr'),   % Quadr-log penalty with multipliers
    lagrmult=fghpar.lagrmult;
    [h1,dh1,d2h1]=penfun(-1-S,penpar,lagrmult(1:Nsn,:));
    [h2,dh2,d2h2]=penfun(S-1,penpar,lagrmult((Nsn+1):end,:));
    h=h1+h2; dh=dh2-dh1; d2h=d2h1+d2h2;
elseif strcmp(pentype,'smooth_abs')
    [h,dh,d2h]=psi(S,penpar);   % Smoothing abs value function
elseif strcmp(pentype,'smooth_lq')
    lagrmult=fghpar.lagrmult;
    [h,dh,d2h]=phimax_lq(S,penpar,lagrmult,-1,1);   % Smoothing abs value function
else
    error('unknown penalty function');
end

f= -log(abs(det(W))) + (mu/T)*sum(h(:));

if nargout > 1  %%%%%%%%%%%%%%% compute gradient
    grd= -A' + (mu/T)*dh*X';
    %natgrd=grd*W'*W;
    %natgrd=natgrd(:);
    %gnat=grd*W'*W; gnat=gnat'; gnat= gnat(:); %Natural gradient
    grd=grd'; grd= grd(:);   % same row-by-row ordering as w
end

%hess_diag=[];
if nargout > 2  %%%%%%% compute hessian diagonal for W=I
    XX=X.^2;
    %for i=1:Nsn
    %    hess_diag=[hess_diag;(mu/T)*XX*d2h(i,:)'+1];
    %end
    hess_diagM=(mu/T)*XX*d2h';
    hess_diag=hess_diagM(:);
end

if nargout > 3  %%%%%%% compute hessian
    A=eye(Nsn);   % the log-det block below is built with A = I, not inv(W)
    %h = diag(hess_diag);
    h = zeros(Nsn^2);
    k=1;
    for j=1:Nsn
        for i=1:Nsn
            tmp=A(:,j)*A(i,:);
            h(k,:)=h(k,:)+tmp(:)';
            k=k+1;
        end
    end

    k=1; tt=1:k:T; T1=length(tt);   % k=1 uses every sample
    ind=[1:Nsn];
    for i=1:Nsn
        Xd2h_i=mulmd(X(:,tt),(mu/T1)*d2h(i,tt));
        %hess(:,:,i)=Xd2h_i*X(:,tt)' + I;
        hh =Xd2h_i*X(:,tt)';
        h(ind,ind)=h(ind,ind)+hh;   % add the i-th diagonal block
        %hess(ind,ind)=hess(ind,ind)+diag(diag(h)); %%%%%% for testing only!!!
        ind=ind+Nsn;
    end
end
return