% score.m
% The BNL toolbox is a set of Matlab functions for defining and estimating the parameters of a Bayesi
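function score=score(link,parms,restparms,joint_prob_tabs,design,equiv_class,equiv_class_time,terminal_merged_nodes,N)
%computes the score function (gradient of the log-likelihood) of a
%multinomial (categorical-response) regression model, evaluated at the
%parameter values parms. Supported link functions: 'multinomial',
%'cumulative' and 'adjacent'. For each parameter set i, only the score
%entries with restparms{i}~=1 are returned, stacked into one column vector.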


nrparset=length(parms); %#parameter sets
nrdes=length(design);%#CPTs
%determine the CPTs for which the sufficient statistics must be computed
%separately for each case (ungrouped data): design{i} is treated as
%ungrouped when the size of its last dimension equals N

grouped=zeros(nrdes,1);
n=ones(nrdes,1)*N;
for i=1:nrdes
    nd=ndims(design{i});
    siz=size(design{i},nd);
    if siz(end)~=N 
        n(i)=1;
        grouped(i)=1;
    end
end
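%Illustration (hypothetical sizes, N=50 cases): a 6-by-4-by-50 design{i}
%has ndims 3 and size(design{i},3)==50==N, so it is treated as ungrouped
%(one design slice per case) and grouped(i) stays 0; a 2-D 6-by-4
%design{i} has size(design{i},2)==4~=N, so grouped(i) is set to 1.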
score=[];
for i=1:nrparset
    rest=find(restparms{i}~=1);

    if ~isempty(rest)
        eclass_time=find(equiv_class_time==i);
        eclasses=equiv_class(eclass_time);
        t=[];
        for j=1:length(equiv_class)                   %find the numbers of the equivalence classes
            if any(eclasses==j)
                t=[t j];
            end
        end
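        %Illustration (hypothetical values): with equiv_class=[1 1 2 3] and
        %equiv_class_time=[1 2 1 1], parameter set i=1 gives eclass_time=[1 3 4],
        %eclasses=[1 2 3] and hence t=[1 2 3]. In general t lists, in increasing
        %order and without repeats, the values of eclasses that lie in
        %1:length(equiv_class).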
        final_freq=[];
        final_tot=[];
        final_des=[];
        for ii=t
            eclass=find(equiv_class==ii);
            post=joint_prob_tabs(eclass);
            if ~mysubset(eclass, terminal_merged_nodes.nodenrs)
                %hidden and partially observed nodes
                %partially observed nodes can be treated as hidden nodes
                %by entering hard evidence, entries become zero, and stay zero
                %during propagation 
            
                if grouped(ii)
                    freq=compute_suff_stats(post);
                else
                    freq=compute_suff_stats_ind(post);
                end
            else
                merged=[];
                for j=1:length(terminal_merged_nodes.nodenrs) %find the numbers of the merged nodes
                    if any(eclass==terminal_merged_nodes.nodenrs(j))
                        merged=[ merged j];
                    end
                end
                nrvars=terminal_merged_nodes.nrvars( merged(1));
                data=terminal_merged_nodes.data(merged);
                f=find(terminal_merged_nodes.nrvars(merged)~=nrvars);
                if ~isempty(f) error('equiv merged nodes should have same number of variables'), end
                S_item=terminal_merged_nodes.respcat{merged(1)}(1);
                if grouped(ii)
                    freq=compute_suff_stats(post,data,nrvars,S_item,N);
                else
                    freq=compute_suff_stats_ind(post,data,nrvars,S_item,N);
                end
            end
            %freq: dimensions are ordered from the highest to the lowest node
            %number, i.e. first the parents (then the items), then the response
            %categories. In parms and design the order is reversed, so the
            %dimensions of freq are reversed here to match.
            nd=ndims(freq);
            if ~isvector(freq)
                freq=permute(freq,[nd:-1:1]);
            end
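            %Illustration: size(permute(zeros(2,3,4),[3 2 1])) is [4 3 2], i.e.
            %reversing the dimension order turns a parents-by-...-by-categories
            %array into a categories-by-...-by-parents array.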
    
        
            if grouped(ii)
                siz=size(freq);
                freq=reshape(freq,siz(1),prod(siz(2:end)));
                freq=freq'+1;
            else
                freq=permute(freq,[2:nd 1]); %cases are the first dimension;
                %move them to the last one
                siz=size(freq);
                freq=reshape(freq,siz(1),prod(siz(2:end)));
                freq=freq';
            end
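            %Illustration (grouped case, hypothetical sizes): with the response
            %categories in the first dimension, a 3-by-2-by-4 freq array
            %(3 categories, 2*4 parent configurations) is reshaped to 3-by-8 and
            %transposed to 8-by-3, so each row holds the counts of one parent
            %configuration over the 3 categories (each count incremented by 1 in
            %this branch). In the ungrouped branch the case index, moved to the
            %last dimension, varies slowest over the rows.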
            final_freq=[final_freq;freq];
            total=sum(freq,2);
        
            final_tot=[final_tot;total];
            if grouped(ii)
                des=design{ii};
            else
                siz=size(design{ii});
                des=permute(design{ii},[ 1 3 2]);
                des=reshape(des,siz(1)*siz(3),siz(2));
            end
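            %Illustration (ungrouped case, hypothetical sizes): a 2-by-3-by-4
            %design{ii} (2 rows, 3 parameters, 4 cases) is permuted to 2-by-4-by-3
            %and reshaped to 8-by-3; row r+(c-1)*2 of des equals design{ii}(r,:,c),
            %so the design row index varies fastest and the case index slowest.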
            final_des=[final_des;des];
        end
        %remove rows with final_tot==0 (no data observed)
        f=find(final_tot==0);
        final_des(f,:)=[];
        final_freq(f,:)=[];
        final_tot(f,:)=[];
    
        %parms{i}=full(pars);
        [n,q]=size(final_freq);
        baselinecategory=1;
        rel_freq=final_freq./(final_tot*ones(1,q));
        rel_freq(:,baselinecategory)=[];  %discard baseline freqs
        rel_freq=rel_freq';
        rel_freq=rel_freq(:);% one vector: first all categories for first observation, etc.
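        %Illustration (hypothetical values): with n=2 observations, q=3
        %categories and relative frequencies [.2 .3 .5; .6 .1 .3], dropping the
        %baseline column gives [.3 .5; .1 .3], and rel_freq becomes the column
        %[.3; .5; .1; .3]: the q-1 non-baseline categories of observation 1,
        %then those of observation 2. reshape(lin_pred,q-1,n) below assumes the
        %rows of final_des follow the same ordering.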
        lin_pred=final_des*parms{i};
        lin_pred=reshape(lin_pred,q-1,n);
        li=link{eclasses(1)};
        if strmatch(li,'multinomial')
            mu=multinom_logistic([zeros(n,1) lin_pred']); %add column of zeros for multinom_logistic
        elseif strmatch(li,'cumulative')
            mu=cum_logistic([zeros(n,1) lin_pred']);
            D=deriv_cum_logist(lin_pred');
        elseif strmatch(li,'adjacent')
            mu=adj_logistic([zeros(n,1) lin_pred']);
            D=deriv_adj_logist(lin_pred');
        else
            error('no valid link function');
        end
            
        mu(:,1)=[];%remove first column again
        m=mu';
        m=m(:);% one vector: first all categories for first observation, etc.
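        %Illustration, assuming multinom_logistic implements the baseline-category
        %(softmax) transform mu_j = exp(eta_j)/sum_k exp(eta_k) with eta_1 = 0:
        %a linear-predictor row [0 log(2) log(3)] (baseline zero prepended) gives
        %mu = [1 2 3]/6. After mu(:,1)=[] the vector m stacks the q-1
        %non-baseline probabilities of each observation, in the same order as
        %rel_freq.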
        s=numel(m);
        if strmatch(li,'multinomial')
            omega=kron(sparse(1:n,1:n,final_tot),speye((q-1)));
            sc=final_des'*omega*(rel_freq-m);
            score2=full(sc);
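            %Why Omega works here (standard result for the baseline-category
            %logit, assuming multinom_logistic is the softmax): D_i =
            %diag(mu_i)-mu_i*mu_i' and the covariance of the relative frequencies
            %is Sigma_i = D_i/n_i, so D_i*inv(Sigma_i) = n_i*I and the score
            %reduces to
            %    s = X'*Omega*(y-mu),  Omega = blkdiag(n_1*I, ..., n_n*I),
            %which kron(diag(final_tot),speye(q-1)) builds.
            %Illustration (hypothetical values): one observation with n_1=10,
            %X=eye(2), y=[.3;.5] and mu=[.25;.25] gives s=10*[.05;.25]=[0.5;2.5].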
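        else
            %cumulative and adjacent links: DD collects the derivative blocks
            %D(:,:,k), V the inverse covariances of the relative frequencies
            DD=[];
            V=[];
            for k=1:size(D,3)
                %covariance of the relative frequencies of observation k
                varfnctn=(diag(mu(k,:))-mu(k,:)'*mu(k,:))/final_tot(k);
                DD=blkdiag(DD,sparse(D(:,:,k)));
                V=blkdiag(V,inv(varfnctn));
            end

            score2=full(final_des'*DD*V*(rel_freq-m));
        end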
        
        score=[score;score2(rest)];
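        %The cumulative and adjacent branches use the general score of a
        %multivariate GLM,
        %    s = sum_i X_i'*D_i*inv(Sigma_i)*(y_i-mu_i),
        %with D_i the derivative of the response function at eta_i (assumed to
        %be what deriv_cum_logist / deriv_adj_logist return) and
        %Sigma_i = (diag(mu_i)-mu_i*mu_i')/n_i. Only the free entries,
        %rest = find(restparms{i}~=1), are appended to score; e.g.
        %(hypothetical) restparms{i}=[1 0 0] keeps score2(2:3).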
    end
end