📄 crf.cpp
字号:
printf("%d.. ",lines); } } } if(table.size())//non-empty line { if(max_seq_size<table.size()) max_seq_size=table.size(); total_nodes+=table.size(); add_x(table); table.clear();//prepare for new table table_str.free(); lines++; if(!(lines%100)) printf("%d.. ",lines); } fclose(fp); //set transit_buff if(chain_type==SIMPLE_CHAIN && max_seq_size && algorithm==CRF_ALGORITHM) transit_buff=new double[ysize*max_seq_size+ysize*ysize*(max_seq_size-1)]; return true;}void CRF::set_group(){ //calculate templet_group //set the size of each group //order and tags.size() of CRF must be known int i,j; for(i=0,j=0;i<templets.size();i++) { if(templets[i].end_of_group) { int n=pow((double)ysize,(int)templets[i].y.size()); templet_group[j++].resize(n);//group j has n offsets } } vector<int> path_index(order+1,0); int path_size=pow((double)ysize,order+1); for(i=0;i<path_size;i++)//assosiate path i with templet_group { int cur_group=0; for(j=0;j<templets.size();j++) { if(templets[j].end_of_group) { vector<int> &ytemp=templets[j].y; int k,offset; vector<int> temp; for(k=0,offset=0;k<ytemp.size();k++) offset=offset*ysize+path_index[-ytemp[k]]; templet_group[cur_group++][offset].push_back(i);//path i added to current group's offset } } for(j=0;j<order+1 && path_index[j]==ysize-1;j++); if(j==order+1) break; path_index[j]++; for(j--;j>=0;j--) path_index[j]=0; } for(i=0;i<templet_group.size();i++) for(j=0;j<templet_group[i].size();j++) { ((vector<int>)(templet_group[i][j])).swap(templet_group[i][j]); }}bool CRF::add_x(vector<char *> &table){ int i,j,k,c; int rows=table.size()/cols; char s[1024]; char s1[1024]; char s2[1024]; sequence seq; vector<int> y; node* nod=nodes.alloc(rows); seq.node_num=rows; seq.nodes=nod; sequences_tmp.push_back(seq); vector<vector<char *> > ext_table(table.size());//split each unit by " " for(i=0;i<table.size();i++){ vector<char *> units; split_string(table[i]," ",units); for(j=0;j<units.size();) { if(!units[j][0]) units.erase(units.begin()+j); else j++; } ext_table[i]=units; } //below, using ext_table for(i=0;i<rows;i++) { y.resize(y.size()+1); vector_search(tags,ext_table[(i+1)*cols-1][0],y.back(),j,str_cmp());//get the tag of current node,j is invalid nod[i].key=0; for(j=i-order;j<=i;j++) if(j>=0) nod[i].key=nod[i].key*ysize+y[j]; vector<clique*> clisp;//features that affect on current nodes vector<int> feature_vector; for(j=0;j<templets.size();j++) { //get first y's offset templet &pat=templets[j]; if(pat.y[0]+i<0) continue; if(pat.x.size()>0){//if has x int index1,index2; bool has_xstring=true;//false, if unit="" vector<int> xid(pat.x.size(),0);//xid=(0,0): 0 th units + 0 th units vector<int> xtop(pat.x.size(),0); //get xtop for(k=0;k<pat.x.size();k++) { index1=pat.x[k].first+i; index2=pat.x[k].second; if(index1<0) { xtop[k]=1; }else if(index1>=rows){ xtop[k]=1; }else if(!ext_table[index1*cols+index2].size()){//no string here xtop[k]=0; has_xstring=false; break; }else{ xtop[k]=ext_table[index1*cols+index2].size(); } } if(has_xstring) { xid.back()=-1; while(1) { //increase and check whether stop for(k=xid.size()-1;k>=0 && xid[k]+1 == xtop[k];k--) xid[k]=0; if(k<0) break;//stop xid[k]++; sprintf(s, "%d", j); strcat(s,":"); //get x for(k=0;k<pat.x.size();k++) { strcat(s,pat.words[k].c_str()); strcat(s,"//"); index1=pat.x[k].first+i; index2=pat.x[k].second; assert(index2>=0 && index2<cols-1); if(index1<0) { index1=-index1-1; strcpy(s1,"B_"); sprintf(s2,"%d",index1); strcat(s1,s2);//B_0 for example }else if(index1>=rows){ index1-=rows; strcpy(s1,"E_"); sprintf(s2,"%d",index1); strcat(s1,s2);//E_0 for example }else{ assert(ext_table[index1*cols+index2].size()); strcpy(s1,ext_table[index1*cols+index2][xid[k]]); } strcat(s,s1); strcat(s,"//"); } strcat(s,pat.words[k].c_str()); //x obtained, insert x int index;//index of feature s if(insert_x(s,index)) { c=pow((double)ysize,(int)pat.y.size()); lambda_size+=c; fmap_tmp.resize(lambda_size,0); } //get clique feature_vector.push_back(index); } } }else{//else , no x sprintf(s, "%d", j); strcat(s,":"); strcat(s,pat.words[0].c_str()); //x obtained, insert x int index;//index of feature s if(insert_x(s,index)) { c=pow((double)ysize,(int)pat.y.size()); lambda_size+=c; fmap_tmp.resize(lambda_size,0); } //get clique feature_vector.push_back(index); } if(pat.end_of_group) {//creat new clique clique cli; vector<node*> ns; int key=0; for(k=0;k<pat.y.size();k++) { ns.push_back(nod+i+pat.y[k]); key=key*ysize+ y[i+pat.y[k]]; } //set feature count for(k=0;k<feature_vector.size();k++) { fmap_tmp[feature_vector[k]+key]++; } node ** np=clique_node.push_back(&ns[0],ns.size()); cli.nodes=np; cli.node_num=ns.size(); cli.key=key; int *f=NULL; if(feature_vector.size()) f=clique_feature.push_back(&feature_vector[0],feature_vector.size()); cli.fvector=f; cli.feature_num=feature_vector.size(); cli.groupid=templets[j].groupid; clique *new_clique=cliques.push_back(&cli,1); clisp.push_back(new_clique); feature_vector.clear(); } } //set node -> clique if(clisp.size()) nod[i].cliques = node_clique.push_back(&clisp[0],clisp.size()); else nod[i].cliques = NULL; nod[i].clique_num =clisp.size(); } return true;}bool CRF::insert_x(char *target, int &index){ map<char *, int , str_cmp>::iterator p; p=xindex.find(target); if(p!=xindex.end()) { index=p->second; x_freq[index]++; return false; }else{ char *q=x_str.push_back(target); xindex.insert(make_pair(q,lambda_size)); index=lambda_size; x_freq.resize(index+1,0); x_freq[index]=1; return true; }}void CRF::compress(){ //count bytes unsigned long bytes=0; int i,j,k,ii; if(chain_type==GENERAL_CHAIN){ //sequence data for(i=0;i<sequences_tmp.size();i++) { sequence &seq=sequences_tmp[i]; for(j=0;j<seq.node_num;j++) { node &nod=seq.nodes[j]; for(k=0;k<nod.clique_num;k++) { if(!nod.cliques[k]) continue; clique &cli=*nod.cliques[k]; bytes+=cli.feature_num*sizeof(int); bytes+=cli.node_num*sizeof(node*); bytes+=sizeof(clique); } bytes+=nod.clique_num*sizeof(clique*); } bytes+=seq.node_num*sizeof(node); bytes+=sizeof(sequence); } }else if(chain_type==FIRST_CHAIN){ //sequence data for(i=0;i<sequences_tmp.size();i++) { sequence &seq=sequences_tmp[i]; for(j=0;j<seq.node_num;j++) { node &nod=seq.nodes[j]; for(k=0;k<nod.clique_num;k++) { if(!nod.cliques[k]) continue; clique &cli=*nod.cliques[k]; if(templet_group[cli.groupid].size()==ysize)//vertex feature bytes+=cli.feature_num*sizeof(int); else bytes+=cli.feature_num*sizeof(int);//edge feature } } bytes+=seq.node_num*sizeof(vertex); bytes+=(seq.node_num-1)*sizeof(edge); bytes+=sizeof(sequence1); } }else if(chain_type==SIMPLE_CHAIN){ //sequence data for(i=0;i<sequences_tmp.size();i++) { sequence &seq=sequences_tmp[i]; for(j=0;j<seq.node_num;j++) { node &nod=seq.nodes[j]; for(k=0;k<nod.clique_num;k++) { if(!nod.cliques[k]) continue; clique &cli=*nod.cliques[k]; if(templet_group[cli.groupid].size()==ysize)//vertex feature bytes+=cli.feature_num*sizeof(int); } } bytes+=seq.node_num*sizeof(vertex); bytes+=sizeof(sequence1); } } //fmap bytes+=fmap_tmp.size()*sizeof(int); if(prior== MEM_PRIOR && algorithm==CRF_ALGORITHM && bytes<lambda_size*sizeof(double)*2) bytes=lambda_size*sizeof(double)*2; work_size=bytes; //allocate work_space=new char[work_size]; char *p=work_space; //copy //sequence data if(chain_type==GENERAL_CHAIN) { memcpy(p,&sequences_tmp[0],sequences_tmp.size()*sizeof(sequence)); sequence_num=sequences_tmp.size(); sequences_tmp.clear(); vector<sequence>(sequences_tmp).swap(sequences_tmp); sequences=(sequence *)p; p+=sequence_num*sizeof(sequence); for(i=0;i<sequence_num;i++) { sequence &seq=sequences[i]; memcpy(p,seq.nodes,sizeof(node)*seq.node_num); node *tmp_nodes=seq.nodes; seq.nodes=(node *)p; p+=sizeof(node)*seq.node_num; for(j=0;j<seq.node_num;j++) { node &nod=seq.nodes[j]; vector<clique*> clisp(nod.clique_num); for(k=0;k<nod.clique_num;k++) { if(!nod.cliques[k]) { clisp[k]=NULL; continue; } clique &cli=*nod.cliques[k]; if(cli.feature_num) { memcpy(p,cli.fvector,cli.feature_num*sizeof(int)); cli.fvector=(int*) p; p+=cli.feature_num*sizeof(int); }//else cli.fvector=NULL; vector<node*> ns(cli.node_num); for(ii=0;ii<cli.node_num;ii++) ns[ii]=cli.nodes[ii]-tmp_nodes+seq.nodes; memcpy(p,&ns[0],cli.node_num*sizeof(node*)); cli.nodes=(node**) p; p+=cli.node_num*sizeof(node*); memcpy(p,&cli,sizeof(clique)); clisp[k]=(clique*)p; p+=sizeof(clique); } memcpy(p,&clisp[0],sizeof(clique*)*nod.clique_num); nod.cliques=(clique**)p; p+=sizeof(clique*)*nod.clique_num; } } }else if(chain_type==FIRST_CHAIN){ sequence_num=sequences_tmp.size(); vector<sequence1> seq1s(sequence_num); for(i=0;i<sequence_num;i++) { sequence &seq=sequences_tmp[i]; seq1s[i].vertex_num=seq.node_num; vector<vertex> vertexes(seq.node_num); vector<edge> edges(seq.node_num-1); for(j=0;j<seq.node_num;j++) { node &nod=seq.nodes[j]; for(k=0;k<nod.clique_num;k++) { if(!nod.cliques[k]) continue; clique &cli=*nod.cliques[k]; if(templet_group[cli.groupid].size()==ysize)//vertex feature { memcpy(p,cli.fvector,cli.feature_num*sizeof(int)); vertexes[j].fvector=(int *)p; p+=cli.feature_num*sizeof(int); vertexes[j].feature_num=cli.feature_num; vertexes[j].key=nod.key; }else{//edge feature memcpy(p,cli.fvector,cli.feature_num*sizeof(int)); edges[j-1].fvector=(int *)p; p+=cli.feature_num*sizeof(int); edges[j-1].feature_num=cli.feature_num; } } } memcpy(p,&vertexes[0],seq.node_num*sizeof(vertex)); seq1s[i].vertexes=(vertex*)p; p+=seq.node_num*sizeof(vertex); if(seq.node_num>1){ memcpy(p,&edges[0],edges.size()*sizeof(edge)); seq1s[i].edges=(edge*)p; p+=edges.size()*sizeof(edge); }else{ seq1s[i].edges=NULL; } } memcpy(p,&seq1s[0],seq1s.size()*sizeof(sequence1)); sequence1s=(sequence1*)p; p+=seq1s.size()*sizeof(sequence1); sequences_tmp.clear(); vector<sequence>(sequences_tmp).swap(sequences_tmp); }else if(chain_type==SIMPLE_CHAIN){ transit=-1; sequence_num=sequences_tmp.size(); vector<sequence1> seq1s(sequence_num); for(i=0;i<sequence_num;i++) { sequence &seq=sequences_tmp[i]; seq1s[i].vertex_num=seq.node_num; vector<vertex> vertexes(seq.node_num); for(j=0;j<seq.node_num;j++) { node &nod=seq.nodes[j]; for(k=0;k<nod.clique_num;k++) { if(!nod.cliques[k]) continue; clique &cli=*nod.cliques[k]; if(templet_group[cli.groupid].size()==ysize)//vertex feature { memcpy(p,cli.fvector,cli.feature_num*sizeof(int)); vertexes[j].fvector=(int *)p; p+=cli.feature_num*sizeof(int); vertexes[j].feature_num=cli.feature_num; vertexes[j].key=nod.key; }else if(transit<0){//edge feature transit=cli.fvector[0]; } } } memcpy(p,&vertexes[0],seq.node_num*sizeof(vertex)); seq1s[i].vertexes=(vertex*)p; p+=seq.node_num*sizeof(vertex); seq1s[i].edges=NULL; } memcpy(p,&seq1s[0],seq1s.size()*sizeof(sequence1)); sequence1s=(sequence1*)p; p+=seq1s.size()*sizeof(sequence1); sequences_tmp.clear(); vector<sequence>(sequences_tmp).swap(sequences_tmp); } nodes.clear(); cliques.clear(); clique_node.clear();//clique affected nodes node_clique.clear();//node->clique clique_feature.clear();//clique->feature //fmap memcpy(p,&fmap_tmp[0],sizeof(int)*fmap_tmp.size()); fmap_size=fmap_tmp.size(); fmap_tmp.clear(); vector<int>(fmap_tmp).swap(fmap_tmp); fmap=(int *)p; p+=sizeof(int)*fmap_size;}void CRF::write_model(char *model_file, bool first_part){ FILE *fout; if(first_part) { fout=fopen(model_file,"w"); if(!fout) { printf("can not open model file: %s\n",model_file); return; } int i,j; //write version fprintf(fout,"version\t%d\n",version); //write templets for(i=0;i<templets.size();i++) { templet &cur_templet=templets[i]; for(j=0;j<cur_templet.x.size();j++) fprintf(fout,"%s%%x[%d,%d]",cur_templet.words[j].c_str(),cur_templet.x[j].first,cur_templet.x[j].second); fprintf(fout,"%s",cur_templet.words[j].c_str()); for(j=0;j<cur_templet.y.size();j++) fprintf(fout,"%%y[%d]",cur_templet.y[j]); fprintf(fout,"\n"); } fprintf(fout,"\n"); //write y fprintf(fout,"%d\n",ysize); for(i=0;i<tags.size();i++) fprintf(fout,"%s\n",tags[i]); fprintf(fout,"\n"); //write x fprintf(fout,"%d\n\n",cols); fprintf(fout,"%d\n",xindex.size()); map<char*, int, str_cmp>::iterator it; for(it = xindex.begin(); it != xindex.end(); it++) fprintf(fout,"%s\t%d\n",it->first,it->second); fprintf(fout,"\n"); fclose(fout); }else{ int i,j; if(l1>0 || algorithm==AP_ALGORITHM || algorithm==PA_ALGORITHM){ //compress fmap // index fmap ffmap[index]=new_index new_fmap[new_index]=fmap[index] // 0 0 0 0 // 1 -1 1 -1 // 2 -2 -1 // 3 -2 -1 // 4 1 2 1 // 5 2 3 2 // // vector<int> ffmap(fmap_size); vector<int> new_fmap; int compress_fmap_index=0; for(i=0;i<fmap_size;i++){ if(fmap[i]!=-2){ ffmap[i]=compress_fmap_index++; new_fmap.push_back(fmap[i]); }else{ ffmap[i]=-1; } } char line[MAXSTRLEN]; FILE *fp; fp=fopen(model_file,"r"); fgets(line,MAXSTRLEN-1,fp);//skip version line //load templates while(fgets(line,MAXSTRLEN-1,fp)) { trim_line(line); if(!add_templet(line)) break; } set_order(); //get ysize fgets(line,MAXSTRLEN-1,fp); trim_line(line); ysize=atoi(line); tags.resize(ysize); for(i=0;i<ysize;i++) { fgets(line,MAXSTRLEN-1,fp); trim_line(line); char *q=tag_str.push_back(line); tags[i]=q; } set_group(); //get cols fgets(line,MAXSTRLEN-1,fp); fgets(line,MAXSTRLEN-1,fp); trim_line(line); cols=atoi(line); //load x int index; int x_num; fgets(line,MAXSTRLEN-1,fp); fgets(line,MAXSTRLEN-1,fp);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -