📄 crf_thread.cpp
字号:
#include "crf_thread.h"
#include "fun.h"

// Start of vertex i's slot in the `path` buffer used by the sequence1 (chain) overloads.
//
// Layout, as built by build_lattice(sequence1&):
//   vertex 0         : ysize vertex scores at [0, ysize)
//   vertex i (i >= 1): ysize vertex scores at [off, off + ysize), followed by
//                      path_num scores for the edge (i-1 -> i) at [off + ysize, off + ysize + path_num)
// with off = (path_num + ysize) * i - path_num.  Edge entry j pairs the previous
// label j / ysize with the current label j % ysize (see the forward pass).
static inline int lattice_vertex_offset(int i, int path_num, int ysize)
{
    return i ? (path_num + ysize) * i - path_num : 0;
}

// Build the score lattice of a general (clique-based) sequence.
//
// After the call, path[path_num*i + j] holds, for node i and path j, the sum of the
// lambda weights of every active clique feature whose template group contains path j.
// These scores are later combined with log_sum_exp / exp, i.e. they are log-domain.
void crf_thread::build_lattice(sequence &seq)
{
    // r*f for all paths ending at each node
    if (path.size() < c->path_num * seq.node_num)
        path.resize(c->path_num * seq.node_num);
    fill(path.begin(), path.end(), 0);

    for (int i = 0; i < seq.node_num; i++)
    {
        double *cur_path = &path[c->path_num * i];
        node &nod = seq.nodes[i];
        // accumulate r*f for the paths, clique by clique
        for (int j = 0; j < nod.clique_num; j++)
        {
            if (!nod.cliques[j])
                continue;
            clique &cli = *nod.cliques[j];
            vector<vector<int> > &group = c->templet_group[cli.groupid];
            const int group_count = static_cast<int>(group.size());
            for (int k = 0; k < cli.feature_num; k++)
            {
                for (int ii = 0; ii < group_count; ii++)
                {
                    // The feature index depends only on (feature k, group ii), so it is
                    // looked up once per group rather than once per group member.
                    int findex = c->fmap[cli.fvector[k] + ii];
                    if (findex < 0)
                        continue;
                    const int member_count = static_cast<int>(group[ii].size());
                    for (int jj = 0; jj < member_count; jj++)
                        cur_path[group[ii][jj]] += c->lambda[findex];
                }
            }
        }
    }
}

// Build the score lattice of a chain sequence (FIRST_CHAIN or SIMPLE_CHAIN).
//
// `path` is laid out as described at lattice_vertex_offset(): per-vertex label scores
// interleaved with per-edge transition scores.
//   - FIRST_CHAIN : everything starts at 0; edge scores come from the edge features.
//   - SIMPLE_CHAIN: edge scores are shared transition weights, copied from
//                   c->transit_buff (CRF) or taken from lambda[fmap[transit + j]] (AP/PA).
// Vertex feature weights are then added on top in every case.
void crf_thread::build_lattice(sequence1 &seq1)
{
    const int ysize = c->ysize;
    const int path_num = c->path_num;
    const int needed = path_num * (seq1.vertex_num - 1) + ysize * seq1.vertex_num;
    if (path.size() < needed)
        path.resize(needed);

    if (c->chain_type == FIRST_CHAIN)
    {
        fill(path.begin(), path.end(), 0);
    }
    else if (c->chain_type == SIMPLE_CHAIN && c->algorithm == CRF_ALGORITHM)
    {
        // NOTE(review): copies path.size() doubles, which can exceed `needed` once `path` has
        // grown for a longer sequence; this assumes transit_buff holds at least that many
        // values, laid out like `path` (with vertex slots zeroed) -- TODO confirm.
        memcpy(&path[0], c->transit_buff, sizeof(double) * path.size());
    }
    else if (c->chain_type == SIMPLE_CHAIN &&
             (c->algorithm == AP_ALGORITHM || c->algorithm == PA_ALGORITHM))
    {
        fill(path.begin(), path.end(), 0);
        // every edge slot receives the same transition weights
        for (int i = 0; i < path_num; i++)
        {
            int findex = c->fmap[c->transit + i];
            if (findex < 0)
                continue;
            for (int j = 1; j < seq1.vertex_num; j++)
                path[lattice_vertex_offset(j, path_num, ysize) + ysize + i] = c->lambda[findex];
        }
    }
    // NOTE(review): for any other chain_type/algorithm combination `path` is not reset here.

    for (int i = 0; i < seq1.vertex_num; i++)
    {
        const int cur_path_offset = lattice_vertex_offset(i, path_num, ysize);
        vertex &vtx = seq1.vertexes[i];

        // vertex features: one weight per label
        for (int j = 0; j < vtx.feature_num; j++)
        {
            for (int k = 0; k < ysize; k++)
            {
                int findex = c->fmap[vtx.fvector[j] + k];
                if (findex >= 0)
                    path[cur_path_offset + k] += c->lambda[findex];
            }
        }

        // edge features (FIRST_CHAIN only): one weight per (previous, current) label pair
        if (i && c->chain_type == FIRST_CHAIN)
        {
            edge &e = seq1.edges[i - 1];
            for (int j = 0; j < e.feature_num; j++)
            {
                for (int k = 0; k < path_num; k++)
                {
                    int findex = c->fmap[e.fvector[j] + k];
                    if (findex >= 0)
                        path[cur_path_offset + ysize + k] += c->lambda[findex];
                }
            }
        }
    }
}

// Score of the labeling stored in the node keys: the sum of path[path_num*i + key_i].
double crf_thread::path_cost(sequence &seq)
{
    double cost = 0;
    for (int i = 0; i < seq.node_num; i++)
        cost += path[c->path_num * i + seq.nodes[i].key];
    return cost;
}

// Score of the labeling stored in the vertex keys of a chain sequence.
// Vertex 0 contributes path[key_0].  Every later vertex contributes its label score
// (key % ysize) plus the edge score indexed by the full key.
double crf_thread::path_cost(sequence1 &seq1)
{
    const int ysize = c->ysize;
    double cost = path[seq1.vertexes[0].key];
    for (int i = 1; i < seq1.vertex_num; i++)
    {
        const int offset = lattice_vertex_offset(i, c->path_num, ysize);
        const vertex &vtx = seq1.vertexes[i];
        cost += path[offset + vtx.key % ysize];   // label score of vertex i
        cost += path[offset + ysize + vtx.key];   // edge score of (i-1 -> i)
    }
    return cost;
}

// Per-sequence objective and gradient contribution for a general sequence.
// Returns z - s, where z is the log partition value from forward_backward and s is the
// score of the labeling held in the node keys.  When the keys hold the reference tags,
// this is the sequence's negative log-likelihood.
double crf_thread::seq_fx_gx(sequence &seq)
{
    double z = 0;
    build_lattice(seq);
    const double s = path_cost(seq);
    forward_backward(seq, z);
    calculate_gradient(seq, z);
    return z - s;
}

// Chain-sequence counterpart of seq_fx_gx(sequence&); returns z - s.
double crf_thread::seq_fx_gx(sequence1 &seq1)
{
    double z = 0;
    build_lattice(seq1);
    const double s = path_cost(seq1);
    forward_backward(seq1, z);
    calculate_gradient(seq1, z);
    return z - s;
}

// Log-space forward-backward over the general lattice built by build_lattice(sequence&).
//
// alpha/beta hold node_anum states per node.  Path j moves from previous state j / ysize
// to current state j % node_anum.  The first node adds c->head_offset to every path.
// first_cal marks states that have not received a value yet, so the first contribution
// is assigned and later ones are merged with log_sum_exp.
// On return, z = log_sum_exp over the last node's alpha.
void crf_thread::forward_backward(sequence &seq, double &z)
{
    const int node_anum = c->node_anum;
    const int ysize = c->ysize;
    const int path_num = c->path_num;

    const int alpha_size = node_anum * seq.node_num;
    if (alpha.size() < alpha_size)
    {
        alpha.resize(alpha_size);
        beta.resize(alpha_size);
        first_cal.resize(alpha_size);
    }

    // forward
    fill(alpha.begin(), alpha.end(), 0);
    fill(first_cal.begin(), first_cal.end(), 1);
    for (int i = 0; i < seq.node_num; i++)
    {
        const double *cur_path = &path[path_num * i];
        double *cur_alpha = &alpha[i * node_anum];
        int *cur_first = &first_cal[i * node_anum];
        for (int j = 0; j < path_num; j++)
        {
            const int ii = j % node_anum;   // state reached by path j
            // previous node's alpha, or the head offset for the first node
            const double incoming = (i > 0) ? alpha[(i - 1) * node_anum + j / ysize]
                                            : c->head_offset;
            const double score = incoming + cur_path[j];
            if (cur_first[ii])
            {
                cur_alpha[ii] = score;
                cur_first[ii] = 0;
            }
            else
            {
                cur_alpha[ii] = log_sum_exp(score, cur_alpha[ii]);
            }
        }
    }

    // backward: the last node's beta stays 0 (log 1) from the fill below
    fill(beta.begin(), beta.end(), 0);
    fill(first_cal.begin(), first_cal.end(), 1);
    for (int i = seq.node_num - 2; i >= 0; i--)
    {
        double *cur_beta = &beta[i * node_anum];
        const double *next_beta = &beta[(i + 1) * node_anum];
        const double *next_path = &path[path_num * (i + 1)];
        int *cur_first = &first_cal[i * node_anum];
        for (int j = 0; j < path_num; j++)
        {
            const int k = j % node_anum;   // state at node i+1
            const int ii = j / ysize;      // state at node i
            const double score = next_beta[k] + next_path[j];
            if (cur_first[ii])
            {
                cur_beta[ii] = score;
                cur_first[ii] = 0;
            }
            else
            {
                cur_beta[ii] = log_sum_exp(score, cur_beta[ii]);
            }
        }
    }

    // z(x): log-sum-exp over the last node's forward scores
    const double *final_alpha = &alpha[node_anum * (seq.node_num - 1)];
    z = final_alpha[0];
    for (int i = 1; i < node_anum; i++)
        z = log_sum_exp(z, final_alpha[i]);
}

// Log-space forward-backward over the chain lattice built by build_lattice(sequence1&).
//
// alpha[ysize*i + y] includes vertex i's own label score.  beta[ysize*i + y] excludes it,
// and instead includes vertex i+1's label score through the edge term.
// On return, z = log_sum_exp over the last vertex's alpha.
void crf_thread::forward_backward(sequence1 &seq1, double &z)
{
    const int ysize = c->ysize;
    const int path_num = c->path_num;

    const int alpha_size = ysize * seq1.vertex_num;
    if (alpha.size() < alpha_size)
    {
        alpha.resize(alpha_size);
        beta.resize(alpha_size);
        first_cal.resize(alpha_size);
    }

    // forward
    fill(alpha.begin(), alpha.end(), 0);
    fill(first_cal.begin(), first_cal.end(), 1);
    for (int i = 0; i < seq1.vertex_num; i++)
    {
        const double *cur_path = &path[lattice_vertex_offset(i, path_num, ysize)];
        double *cur_alpha = &alpha[i * ysize];
        if (i == 0)
        {
            for (int j = 0; j < ysize; j++)
                cur_alpha[j] = cur_path[j];
            continue;
        }

        const double *last_alpha = &alpha[(i - 1) * ysize];
        int *cur_first = &first_cal[i * ysize];
        for (int j = 0; j < path_num; j++)
        {
            const int ii = j % ysize;   // current label
            const int k = j / ysize;    // previous label
            const double score = last_alpha[k] + cur_path[j + ysize];
            if (cur_first[ii])
            {
                cur_alpha[ii] = score;
                cur_first[ii] = 0;
            }
            else
            {
                cur_alpha[ii] = log_sum_exp(score, cur_alpha[ii]);
            }
        }
        // the vertex's own label scores are added once, after merging the incoming edges
        for (int j = 0; j < ysize; j++)
            cur_alpha[j] += cur_path[j];
    }

    // backward: the last vertex's beta stays 0 (log 1) from the fill below
    fill(beta.begin(), beta.end(), 0);
    fill(first_cal.begin(), first_cal.end(), 1);
    for (int i = seq1.vertex_num - 2; i >= 0; i--)
    {
        double *cur_beta = &beta[i * ysize];
        const double *next_beta = &beta[(i + 1) * ysize];
        const double *next_path = &path[lattice_vertex_offset(i + 1, path_num, ysize)];
        int *cur_first = &first_cal[i * ysize];
        for (int j = 0; j < path_num; j++)
        {
            const int k = j % ysize;    // label at vertex i+1
            const int ii = j / ysize;   // label at vertex i
            // edge (i -> i+1) score plus vertex i+1's label score
            const double score = next_beta[k] + next_path[j + ysize] + next_path[k];
            if (cur_first[ii])
            {
                cur_beta[ii] = score;
                cur_first[ii] = 0;
            }
            else
            {
                cur_beta[ii] = log_sum_exp(score, cur_beta[ii]);
            }
        }
    }

    // z(x): log-sum-exp over the last vertex's forward scores
    const double *final_alpha = &alpha[ysize * (seq1.vertex_num - 1)];
    z = final_alpha[0];
    for (int i = 1; i < ysize; i++)
        z = log_sum_exp(z, final_alpha[i]);
}

// Accumulate one general sequence's gradient into `gradient`.
//
// For every node, margin[j] = exp(path score + incoming alpha (or head_offset) + beta - z)
// is the marginal of path j.  Each clique feature then receives the margins of the paths
// in its template group (expected count), and is decremented once for the group selected
// by cli.key (observed count).
void crf_thread::calculate_gradient(sequence &seq, double &z)
{
    const int path_num = c->path_num;
    const int node_anum = c->node_anum;
    const int ysize = c->ysize;

    // grow whenever too small (resizing only when empty could leave it undersized)
    if (margin.size() < path_num)
        margin.resize(path_num);

    for (int i = 0; i < seq.node_num; i++)
    {
        const double *cur_path = &path[path_num * i];
        const double *cur_beta = &beta[node_anum * i];
        node &nod = seq.nodes[i];
        fill(margin.begin(), margin.end(), 0);

        // path marginals; the first node uses head_offset instead of a previous alpha
        for (int j = 0; j < path_num; j++)
        {
            const double incoming = (i > 0) ? alpha[node_anum * (i - 1) + j / ysize]
                                            : c->head_offset;
            margin[j] = exp(cur_path[j] + incoming + cur_beta[j % node_anum] - z);
        }

        for (int j = 0; j < nod.clique_num; j++)
        {
            if (!nod.cliques[j])
                continue;
            clique &cli = *nod.cliques[j];
            vector<vector<int> > &group = c->templet_group[cli.groupid];
            const int group_count = static_cast<int>(group.size());

            // expected counts
            for (int k = 0; k < group_count; k++)
            {
                const int member_count = static_cast<int>(group[k].size());
                for (int ii = 0; ii < member_count; ii++)
                {
                    for (int jj = 0; jj < cli.feature_num; jj++)
                    {
                        int findex = c->fmap[cli.fvector[jj] + k];
                        if (findex >= 0)
                            gradient[findex] += margin[group[k][ii]];
                    }
                }
            }

            // observed counts
            for (int k = 0; k < cli.feature_num; k++)
            {
                int findex = c->fmap[cli.fvector[k] + cli.key];
                if (findex >= 0)
                    gradient[findex]--;
            }
        }
    }
}

// Accumulate one chain sequence's gradient into `gradient`.
//
// Vertex features receive exp(alpha + beta - z) per label, and lose 1 for the label
// key % ysize.  Edge weights receive the edge marginal per label pair: the edge features
// for FIRST_CHAIN, or the shared transit weights for SIMPLE_CHAIN.  They lose 1 for the
// pair given by the full key.
void crf_thread::calculate_gradient(sequence1 &seq1, double &z)
{
    const int ysize = c->ysize;
    const int path_num = c->path_num;

    for (int i = 0; i < seq1.vertex_num; i++)
    {
        const double *cur_path = &path[lattice_vertex_offset(i, path_num, ysize)];
        const double *cur_beta = &beta[ysize * i];
        const double *cur_alpha = &alpha[ysize * i];
        vertex &vtx = seq1.vertexes[i];

        // vertex features: expected counts
        for (int j = 0; j < ysize; j++)
        {
            const double margin1 = exp(cur_alpha[j] + cur_beta[j] - z);
            for (int k = 0; k < vtx.feature_num; k++)
            {
                int findex = c->fmap[vtx.fvector[k] + j];
                if (findex >= 0)
                    gradient[findex] += margin1;
            }
        }
        // vertex features: observed counts
        for (int j = 0; j < vtx.feature_num; j++)
        {
            int findex = c->fmap[vtx.fvector[j] + vtx.key % ysize];
            if (findex >= 0)
                gradient[findex]--;
        }

        if (i == 0)
            continue;   // vertex 0 has no incoming edge

        const double *last_alpha = &alpha[ysize * (i - 1)];
        if (c->chain_type == FIRST_CHAIN)
        {
            edge &e = seq1.edges[i - 1];
            for (int j = 0; j < path_num; j++)
            {
                const double margin1 = exp(cur_path[j + ysize] + cur_path[j % ysize] +
                                           last_alpha[j / ysize] + cur_beta[j % ysize] - z);
                for (int k = 0; k < e.feature_num; k++)
                {
                    int findex = c->fmap[e.fvector[k] + j];
                    if (findex >= 0)
                        gradient[findex] += margin1;
                }
            }
            for (int j = 0; j < e.feature_num; j++)
            {
                int findex = c->fmap[e.fvector[j] + vtx.key];
                if (findex >= 0)
                    gradient[findex]--;
            }
        }
        else if (c->chain_type == SIMPLE_CHAIN)
        {
            int findex;
            for (int j = 0; j < path_num; j++)
            {
                const double margin1 = exp(cur_path[j + ysize] + cur_path[j % ysize] +
                                           last_alpha[j / ysize] + cur_beta[j % ysize] - z);
                findex = c->fmap[c->transit + j];
                if (findex >= 0)
                    gradient[findex] += margin1;
            }
            findex = c->fmap[c->transit + vtx.key];
            if (findex >= 0)
                gradient[findex]--;
        }
    }
}

// Thread body.
//  - CRF_ALGORITHM: zero `gradient` (lambda_size entries) and `obj`, then add seq_fx_gx
//    for sequences start_i, start_i + thread_num, ... (this thread's share).
//  - AP_ALGORITHM / PA_ALGORITHM: for every sequence in order, build the lattice, run
//    viterbi, then apply ap_update / pa_update.  `times` is decremented once per sequence
//    (presumably consumed by the update rules -- TODO confirm).
void crf_thread::run()
{
    const bool chain = (c->chain_type == FIRST_CHAIN || c->chain_type == SIMPLE_CHAIN);

    if (c->algorithm == CRF_ALGORITHM)
    {
        memset(gradient, 0, sizeof(double) * c->lambda_size);
        obj = 0;
        if (c->chain_type == GENERAL_CHAIN)
        {
            for (int i = start_i; i < c->sequence_num; i += c->thread_num)
                obj += seq_fx_gx(c->sequences[i]);
        }
        else if (chain)
        {
            for (int i = start_i; i < c->sequence_num; i += c->thread_num)
                obj += seq_fx_gx(c->sequence1s[i]);
        }
    }
    else if (c->algorithm == AP_ALGORITHM || c->algorithm == PA_ALGORITHM)
    {
        const bool use_pa = (c->algorithm == PA_ALGORITHM);
        obj = 0;
        for (int i = 0; i < c->sequence_num; i++)
        {
            if (c->chain_type == GENERAL_CHAIN)
            {
                sequence &seq = c->sequences[i];
                build_lattice(seq);
                viterbi(seq);
                if (use_pa)
                    pa_update(seq);
                else
                    ap_update(seq);
            }
            else if (chain)
            {
                sequence1 &seq1 = c->sequence1s[i];
                build_lattice(seq1);
                viterbi(seq1);
                if (use_pa)
                    pa_update(seq1);
                else
                    ap_update(seq1);
            }
            times--;
        }
    }
}

// Store a tag assignment in the keys of a general sequence.
//
// Node i's key encodes the tags of nodes max(0, i-order) .. i in base ysize, with the
// earliest tag most significant.  Each clique's key encodes, in base ysize, the last tag
// (key % ysize) of each of its nodes, in the clique's node order.
void crf_thread::assign_tag(sequence &seq, vector<int> &node_tag)
{
    const int order = c->order;
    const int ysize = c->ysize;

    for (int i = 0; i < seq.node_num; i++)
    {
        node &nod = seq.nodes[i];
        nod.key = 0;
        for (int j = 0; j <= order; j++)
        {
            if (i + j >= order)   // skip positions before the start of the sequence
                nod.key = nod.key * ysize + node_tag[i + j - order];
        }
    }

    for (int i = 0; i < seq.node_num; i++)
    {
        node &nod = seq.nodes[i];
        for (int j = 0; j < nod.clique_num; j++)
        {
            if (!nod.cliques[j])
                continue;
            clique &cli = *(nod.cliques[j]);
            int key = 0;
            for (int k = 0; k < cli.node_num; k++)
                key = key * ysize + cli.nodes[k]->key % ysize;
            cli.key = key;
        }
    }
}

void crf_thread::ap_update(sequence &seq)
{
    int i,j,k;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -