⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 fit_multinom_logistic.m

📁 The BNL toolbox is a set of Matlab functions for defining and estimating the parameters of a Bayesi
💻 M
字号:
function [parms,restparms,fischer,se]=fit_multinom_logistic( parms,restparms,design, freq,total,prec,baselinecategory,maxit)
%FIT_MULTINOM_LOGISTIC Fit a multinomial logistic regression model by Fisher scoring.
%
%   [parms,restparms,fischer,se] = fit_multinom_logistic(parms,restparms,design,freq,total,prec)
%   [...] = fit_multinom_logistic(...,baselinecategory)
%   [...] = fit_multinom_logistic(...,baselinecategory,maxit)
%
%   Inputs
%     parms            column vector of starting values, one per column of design
%     restparms        vector the size of parms; entries equal to 1 are kept
%                      fixed at their current value, all others are estimated
%     design           design matrix of size n*(q-1) by numel(parms). Rows are
%                      grouped per observation: the q-1 rows of observation 1
%                      first, then those of observation 2, etc. Within a group
%                      the rows follow the non-baseline columns of freq in
%                      ascending order.
%     freq             n by q matrix of counts (a row vector when n==1)
%     total            vector of length n with the total count of each row of freq
%     prec             convergence tolerance on the largest absolute parameter
%                      change between two iterations
%     baselinecategory (optional, default 1) column of freq used as the
%                      reference category (its linear predictor is fixed at 0)
%     maxit            (optional, default 1000) maximum number of Fisher scoring
%                      iterations; a warning is issued when it is reached
%                      before convergence
%
%   Outputs
%     parms      estimated parameters (column vector). After every update, any
%                parameter with magnitude above 15 is clipped to +/-15 and fixed.
%     restparms  restparms with the clipped parameters additionally set to 1
%     fischer    Fisher information matrix for all parameters, evaluated at
%                the returned parms
%     se         standard errors, sqrt(diag(inv(fischer)))

if nargin<7 || isempty(baselinecategory)
    baselinecategory=1;
end
if nargin<8 || isempty(maxit)
    maxit=1000;
end
PARM_BOUND=15;   % estimates beyond +/-PARM_BOUND are clipped and then held fixed

[n,q]=size(freq);
% Counts of the non-baseline categories as one column vector, ordered like the
% rows of design (all categories of observation 1 first, etc.). Working with
% counts instead of freq./total keeps rows with total==0 from producing NaN.
counts=freq;
counts(:,baselinecategory)=[];
counts=counts';
counts=counts(:);

free=find(restparms~=1);
v=Inf;
nr_it=0;
while ~isempty(free) && v>prec && nr_it<maxit
    nr_it=nr_it+1;
    [fischer,score]=local_fisher_score(parms,design,counts,total,n,q);
    parms_old=parms;
    % Fisher scoring step for the free parameters only
    parms(free)=parms_old(free)+fischer(free,free)\score(free);

    % Keep diverging estimates bounded and stop estimating them.
    a=find(abs(parms)>PARM_BOUND);
    parms(a)=sign(parms(a))*PARM_BOUND;
    restparms(a)=1;
    free=find(restparms~=1);

    v=max(abs(parms_old-parms));
end
if nr_it>=maxit && v>prec
    warning('fit_multinom_logistic:maxit', ...
        'No convergence after %d iterations (largest parameter change %g).', maxit, v);
end

if nargout>2
    % Evaluated at the returned estimate; this also defines fischer when no
    % parameter was free and the loop did not run.
    fischer=local_fisher_score(parms,design,counts,total,n,q);
end
if nargout>3
    se=sqrt(diag(inv(fischer)));
end


function [fischer,score]=local_fisher_score(parms,design,counts,total,n,q)
%LOCAL_FISHER_SCORE Fisher information and score of the multinomial logit model at parms.
%   counts is the stacked non-baseline count vector (ordered like design rows),
%   total the per-observation totals, n the number of observations and q the
%   number of categories.
k=q-1;
s=n*k;
lin_pred=reshape(design*parms,k,n);      % column i: linear predictors of observation i
% A zero column is prepended for the baseline category and dropped again, so
% mu keeps the k columns matching lin_pred. NOTE(review): assumes
% multinom_logistic maps each row of linear predictors to category
% probabilities (softmax) -- confirm against multinom_logistic.m.
mu=multinom_logistic([zeros(n,1) lin_pred']);
mu(:,1)=[];
m=mu';
m=m(:);                                  % ordered like the rows of design
w=kron(total(:),ones(k,1));              % each observation's total, repeated for its k rows
% Block-diagonal matrix with one k-by-k block m_i*m_i' per observation,
% built directly as a sparse s-by-s matrix.
blk=reshape(1:s,k,n);                    % column i: design-row indices of observation i
r=repmat(blk,k,1);
c=kron(blk,ones(k,1));
outer=sparse(r(:),c(:),m(r(:)).*m(c(:)),s,s);
varfnctn=sparse(1:s,1:s,m,s,s)-outer;    % per-observation diag(mu_i)-mu_i*mu_i'
fischer=design'*(varfnctn*sparse(1:s,1:s,w,s,s))*design;
score=design'*(counts-w.*m);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -