📄 fit_multinom_logistic.m
字号:
function [parms,restparms,fischer,se]=fit_multinom_logistic( parms,restparms,design, freq,total,prec,baselinecategory,maxit)
%FIT_MULTINOM_LOGISTIC Fit a multinomial logistic regression model by Fisher scoring.
%
%   [parms,restparms,fischer,se] = fit_multinom_logistic(parms,restparms,design,freq,total,prec)
%   [...] = fit_multinom_logistic(...,baselinecategory)
%   [...] = fit_multinom_logistic(...,baselinecategory,maxit)
%
%   Inputs
%     parms            column vector of starting parameter values (#pars by 1)
%     restparms        vector of the same length as parms; entries equal to 1
%                      mark parameters that are held fixed (not updated)
%     design           design matrix of size n*(q-1) by #pars; the q-1 rows of
%                      observation 1 come first, then those of observation 2, etc.
%     freq             n by q matrix of category counts (a row vector when n=1)
%     total            column vector of length n with the total count per row
%     prec             convergence tolerance on the maximum absolute change
%                      of the parameter vector between iterations
%     baselinecategory column of freq used as the reference category
%                      (default 1)
%     maxit            maximum number of Fisher-scoring iterations
%                      (default 1000); a warning is issued if it is reached
%
%   Outputs
%     parms      fitted parameters (column vector). Parameters whose magnitude
%                exceeds 15 are clipped to +/-15 and marked as restricted.
%     restparms  input restparms with clipped parameters additionally set to 1
%     fischer    Fisher information matrix for all parameters, evaluated at
%                the returned parms
%     se         sqrt(diag(inv(fischer))), i.e. computed from the full
%                information matrix including restricted parameters
%                NOTE(review): for fixed parameters these values may not be
%                meaningful, and clipped parameters can make fischer
%                near-singular.
%
%   Relies on multinom_logistic (defined elsewhere). It is presumably a
%   row-wise softmax mapping linear predictors to category probabilities;
%   TODO confirm.
if nargin<7 || isempty(baselinecategory)
    baselinecategory=1;
end
if nargin<8 || isempty(maxit)
    maxit=1000;
end
clipval=15; % parameters beyond +/-clipval are clipped and then held fixed

[n,q]=size(freq);
rel_freq=freq./(total*ones(1,q));
% rows with zero total give 0/0 = NaN; they carry zero weight, so use 0
% to keep NaN from propagating into the score vector
rel_freq(total==0,:)=0;
rel_freq(:,baselinecategory)=[]; % discard baseline frequencies
% one vector: first all categories for the first observation, etc.
rel_freq=reshape(rel_freq',[],1);

free=find(restparms~=1);
nr_it=0;
while ~isempty(free)
    [fischer,score]=local_fisher_score(parms,design,rel_freq,total,n,q);
    parms_old=parms;
    % Fisher-scoring step, restricted to the free parameters
    parms(free)=parms_old(free)+fischer(free,free)\score(free);
    % clip diverging parameters (e.g. under separation) and fix them
    a=find(abs(parms)>clipval);
    parms(a)=sign(parms(a))*clipval;
    restparms(a)=1;
    free=find(restparms~=1);
    nr_it=nr_it+1;
    v=max(abs(parms_old-parms));
    if ~(v>prec) % converged (also stops on NaN, as the original loop did)
        break
    end
    if nr_it>=maxit
        warning('fit_multinom_logistic:maxit', ...
            'Fisher scoring did not converge within %d iterations.',maxit);
        break
    end
end

if nargout>2
    % evaluate the information at the returned estimates; this also covers
    % the case where no parameter was free and the loop never ran
    fischer=local_fisher_score(parms,design,rel_freq,total,n,q);
end
if nargout>3
    se=sqrt(diag(inv(fischer)));
end


function [fischer,score]=local_fisher_score(parms,design,rel_freq,total,n,q)
%LOCAL_FISHER_SCORE Fisher information and score of the multinomial logit
%model at parms, built with sparse block-diagonal weight matrices.
lin_pred=reshape(design*parms,q-1,n);
mu=multinom_logistic([zeros(n,1) lin_pred']); % baseline gets linear predictor 0
mu(:,1)=[]; % drop the baseline column again
m=reshape(mu',[],1); % first all categories for the first observation, etc.
s=numel(m);
% omega = blockdiag(total(i)*eye(q-1))
omega=kron(sparse(1:n,1:n,total),speye(q-1));
% varfnctn = blockdiag(diag(mu_i) - mu_i'*mu_i), assembled sparsely
[r,c]=find(kron(speye(n),ones(q-1)));
varfnctn=sparse(1:s,1:s,m)-sparse(r,c,m(r).*m(c),s,s);
fischer=design'*(varfnctn*omega)*design;
score=design'*omega*(rel_freq-m);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -