📄 crf.cpp
字号:
#include <stdio.h>#include <stdlib.h>#include <string.h>#include <assert.h>#include <math.h>#include <ctime>#ifdef _WIN32#include <io.h>#endif#include "crf.h"#include "fun.h"#include "lbfgs.h"const int PAGESIZE = 8192;const int MAXSTRLEN = 131072;//16 pagesbool CRF::learn(char* templet_file, char *training_file,char *model_file){ printf("pocket crf\nversion 0.%d\nCopyright(c)2005-2008 Xian Qian, all rights reserved\n",version); if(!load_templet(templet_file)) return false; printf("templates loaded\n"); set_chain_type(); if(!check_training(training_file)) return false; if(!generate_features(training_file)) return false; printf("training data loaded\n"); shrink_feature(); printf("features shrinked\n"); fmap=&fmap_tmp[0]; //temporary set fmap_size=fmap_tmp.size(); sequences=&sequences_tmp[0]; //temporary set sequence_num=sequences_tmp.size(); write_model(model_file,true); fmap=NULL; sequences=NULL; //free memory tags.clear(); vector<char *>(tags).swap(tags); tag_str.clear(); xindex.clear(); map<char *,int,str_cmp>(xindex).swap(xindex); x_freq.clear(); vector<int>(x_freq).swap(x_freq); x_str.clear(); templets.clear(); vector<templet>(templets).swap(templets); vector<vector<vector<int> > >(templet_group).swap(templet_group); compress(); //begin learning //set global temporary variables path_num=pow((double)ysize,order+1); node_anum=pow((double)ysize,order);//alpha(beta) number of each node head_offset=-log((double)node_anum); //initialize int i,j,k; lambda=new double[lambda_size]; gradient=new double[lambda_size]; threads[0].gradient=gradient; if(algorithm==CRF_ALGORITHM) { if(prior==MEM_PRIOR) unload(); printf("algorithm: crf\nsequence number: %d\nparameter number: %d\nsigma: %g\nl1: %g\nfreq_thresh: %d\n",sequence_num,lambda_size,sigma,l1,freq_thresh); for(i=1;i<thread_num;i++) threads[i].gradient=new double[lambda_size]; int orthant = l1>0?1:0; for(;orthant>=0;orthant--) { int converge=0; double old_fvalue=1e+37; clock_t start_time=clock(); LBFGS l; l.init(lambda_size,depth,prior); memset(lambda,0,lambda_size*sizeof(double)); for(i=0;i<max_iter;i++) { fvalue=0; if(chain_type==SIMPLE_CHAIN) { memset(transit_buff,0,sizeof(double)*(ysize*max_seq_size+path_num*(max_seq_size-1))); for(j=0;j<path_num;j++) { int findex=fmap[transit+j]; if(findex>=0) { for(k=0;k<max_seq_size-1;k++) transit_buff[ysize+(ysize+path_num)*k+ysize+j]=lambda[findex]; } } } for(j=0;j<thread_num;j++) threads[j].start(); for(j=0;j<thread_num;j++) threads[j].join(); for(j=0;j<thread_num;j++) { fvalue += threads[j].obj; } for(j=0;j<lambda_size;j++) for(k=1;k<thread_num;k++) gradient[j]+=threads[k].gradient[j]; if(orthant) { for(j=0;j<lambda_size;j++) { fvalue+=fabs( lambda[j]) / l1; if(lambda[j]==0.0) { if(gradient[j]<-1/l1) { gradient[j]+=1/l1; }else if(gradient[j]>1/l1){ gradient[j]-=1/l1; }else gradient[j]=0; }else{ if(lambda[j]>0) { gradient[j]+=1/l1; }else{ gradient[j]-=1/l1; } } } }else{ //add sigma||\lambda||^2 for(j=0;j<lambda_size;j++) { fvalue+=lambda[j] * lambda[j] / (2*sigma); gradient[j]+=lambda[j]/sigma; } } double diff=i>0? fabs((old_fvalue-fvalue)/old_fvalue) : 1; if(orthant) { int selected=0; for(j=0;j<lambda_size;j++) { if(lambda[j]!=0) selected++; } printf("iter: %d act: %d diff: %lf obj: %lf\n", i, selected, diff, fvalue); }else printf("iter: %d diff: %lf obj: %lf\n", i, diff, fvalue); old_fvalue=fvalue; if(diff<eta) converge++; else converge=0; if(i==max_iter||converge==3) break;//success double *w0,*w1; if(prior==SPEED_PRIOR){ w0=NULL; w1=NULL; }else{ w0=(double *)work_space; w1=((double *)work_space)+lambda_size; } if(l.optimize(lambda,&fvalue,gradient,orthant,w0,w1)<=0) { printf("lbfgs error\n"); break;//error } if(prior==MEM_PRIOR) load(); } if(orthant) { adjust_data(); if(prior==MEM_PRIOR) unload();//save data structure } double elapse = static_cast<double>(clock() - start_time) / CLOCKS_PER_SEC; if(orthant) printf("crf feature selection elapse: %g s, %d features selected\n",elapse,lambda_size); else printf("crf optimization elapse: %g s\n",elapse); } }else if(algorithm==AP_ALGORITHM){ printf("algorithm: ap\nsequence number: %d\nparameter number: %d\nfreq_thresh: %d\n",sequence_num,lambda_size,freq_thresh); clock_t start_time=clock(); memset(gradient,0,lambda_size*sizeof(double)); memset(lambda,0,lambda_size*sizeof(double)); threads[0].times=max_iter*sequence_num; for(i=0;i<max_iter;i++) { threads[0].start(); threads[0].join(); fvalue=threads[0].obj/total_nodes; printf("iter: %d err: %lf \n",i,fvalue); } for(i=0;i<lambda_size;i++) lambda[i]=gradient[i]/(max_iter*sequence_num); adjust_data(); double elapse = static_cast<double>(clock() - start_time) / CLOCKS_PER_SEC; printf("ap optimization elapse: %g s\n",elapse); }else if(algorithm==PA_ALGORITHM){ printf("algorithm: pa\nsequence number: %d\nparameter number: %d\nC: %g\nfreq_thresh: %d\n",sequence_num,lambda_size,sigma,freq_thresh); clock_t start_time=clock(); memset(gradient,0,lambda_size*sizeof(double)); memset(lambda,0,lambda_size*sizeof(double)); threads[0].times=max_iter*sequence_num; for(i=0;i<max_iter;i++) { threads[0].start(); threads[0].join(); fvalue=threads[0].obj/total_nodes; printf("iter: %d err: %lf \n",i,fvalue); } for(i=0;i<lambda_size;i++) lambda[i]=gradient[i]/(max_iter*sequence_num); adjust_data(); double elapse = static_cast<double>(clock() - start_time) / CLOCKS_PER_SEC; printf("pa optimization elapse: %g s\n",elapse); } write_model(model_file,false); //free all members' memory templet_group.clear(); if(lambda) { delete [] lambda; lambda=NULL; } if(gradient) { delete [] gradient; gradient=NULL; } for(i=1;i<thread_num;i++) { delete [] threads[i].gradient; threads[i].gradient=NULL; } return true;}CRF::CRF(){ depth=5; algorithm=CRF_ALGORITHM; transit_buff=NULL; work_space=NULL; sequences=NULL; sequence1s=NULL; fmap=NULL; sigma=1; prior=SPEED_PRIOR; freq_thresh=0; margin=false; nbest=1; x_str.set_size(PAGESIZE * 16);//each time alloc 16 pages tag_str.set_size(256);//each time alloc 256 bytes nodes.set_size(PAGESIZE); cliques.set_size(PAGESIZE*16); clique_node.set_size(PAGESIZE*16); node_clique.set_size(PAGESIZE*16); clique_feature.set_size(PAGESIZE*16); order=-1; max_iter=10000; eta=0.0001; lambda=NULL; gradient=NULL; l1=0; version=45;//0.45 thread_num=1; threads.resize(thread_num); threads.front().start_i = 0; threads.front().obj=0; threads.front().c=this;}CRF::CRF(double c, int freq_t,bool margin_, int candidate_num_){ depth=5; algorithm=CRF_ALGORITHM; transit_buff=NULL; work_space=NULL; sequences=NULL; sequence1s=NULL; fmap=NULL; sigma=c; prior=SPEED_PRIOR; freq_thresh=freq_t; margin=margin_; nbest=candidate_num_; x_str.set_size(PAGESIZE * 16);//each time alloc 16 pages tag_str.set_size(256);//each time alloc 256 bytes nodes.set_size(PAGESIZE); cliques.set_size(PAGESIZE*16); clique_node.set_size(PAGESIZE*16); node_clique.set_size(PAGESIZE*16); clique_feature.set_size(PAGESIZE*16); order=-1; max_iter=10000; eta=0.0001; lambda=NULL; gradient=NULL; l1=0; version=45;//0.45 thread_num=1; threads.resize(thread_num); threads.front().start_i = 0; threads.front().obj=0; threads.front().c=this;}CRF::~CRF(){ if(transit_buff) { delete [] transit_buff; transit_buff=NULL; } if(lambda) { delete [] lambda; lambda=NULL; } if(gradient) { delete [] gradient; gradient=NULL; } if(work_space) { delete [] work_space; work_space=NULL; sequences=NULL; sequence1s=NULL; fmap=NULL; } if(sequences) { delete [] sequences; sequences=NULL; } if(sequence1s) { delete [] sequence1s; sequence1s=NULL; } if(fmap) { delete [] fmap; fmap=NULL; } if(!access("__data1",0)) unlink("__data1"); if(!access("__data2",0)) unlink("__data2");}void CRF::set_chain_type(){ if(order!=1) { chain_type=GENERAL_CHAIN; return; }else{ for(int i=0;i<templets.size();i++) { bool null_word=true; for(int j=0;j<templets[i].words.size();j++) if(templets[i].words[j]!="") { null_word=false; break; } if(templets[i].y.size()>1 && (templets[i].x.size() || !null_word)) { chain_type=FIRST_CHAIN; return; } } } chain_type=SIMPLE_CHAIN;}bool CRF::set_order(){ //set groupid,end_of_group,order if(!templets.size())// no templets return false; order=-templets[0].y[0]; templets[0].groupid=0; templets[0].end_of_group=false; for(int i=1;i<templets.size();i++) { templet &n=templets[i]; templet &last_n=templets[i-1]; int j; for(j=0;j<last_n.y.size();j++) { if(last_n.y[j]!=n.y[j]) {//start of group last_n.end_of_group=true; n.groupid=last_n.groupid+1; if(-n.y[0]>order) order=-n.y[0]; break; } } if(j==last_n.y.size())//not start of group { n.groupid=last_n.groupid; last_n.end_of_group=false; } } templets.back().end_of_group=true; gsize=templets.back().groupid+1; templet_group.resize(gsize); return true;}bool CRF::set_para(char *para_name, char *para_value){ if(!strcmp(para_name,"sigma")) sigma=atof(para_value); else if(!strcmp(para_name,"freq_thresh")) freq_thresh=atoi(para_value); else if(!strcmp(para_name,"margin")) margin=atoi(para_value); else if(!strcmp(para_name,"nbest")) nbest=atoi(para_value); else if(!strcmp(para_name,"max_iter")) max_iter=atoi(para_value); else if(!strcmp(para_name,"depth")) depth=atoi(para_value); else if(!strcmp(para_name,"eta")) eta=atof(para_value); else if(!strcmp(para_name,"l1")) l1=atof(para_value); else if(!strcmp(para_name,"prior")) prior=atoi(para_value); else if(!strcmp(para_name,"thread_num")) { if(algorithm==CRF_ALGORITHM) { thread_num=atoi(para_value); threads.resize(thread_num); for(int i=0;i<thread_num;i++) { threads[i].c=this; threads[i].start_i = i; threads[i].obj=0; } } }else if(!strcmp(para_name,"algorithm")){ algorithm=atoi(para_value); if(algorithm==AP_ALGORITHM || algorithm==PA_ALGORITHM) { thread_num=1; threads.resize(1); threads[0].start_i = 0; threads[0].obj=0; threads[0].c=this; } } return true;}bool CRF::add_templet(char *line){ if(!line[0]||line[0]=='#') return false; templet n; char *p=line,*q; char word[1000]; char index_str[1000]; int index1,index2; while(q=catch_string(p,"%x[",word)) { p=q; n.words.push_back(word); p=catch_string(p,",",index_str); index1=atoi(index_str); p=catch_string(p,"]",index_str); index2=atoi(index_str); n.x.push_back(make_pair(index1,index2)); } q=catch_string(p,"%y[",word); if(!q) { printf("templet: %s incorrect\n",line); return false; } n.words.push_back(word); p=q-3; while(p=catch_string(p,"%y[","]",index_str)) { index1=atoi(index_str); n.y.push_back(index1); } int insert_pos; vector_search(templets,n,index1,insert_pos,templet_cmp()); vector_insert(templets,n,insert_pos); return true;}bool CRF::load_templet(char *templet_file){ FILE *fp; //read template fp=fopen(templet_file,"r"); if (!fp) { printf("template file: %s not found\n",templet_file); return false; } char line[MAXSTRLEN]; while(fgets(line,MAXSTRLEN-1,fp)) { trim_line(line); add_templet(line); } fclose(fp); return set_order();}bool CRF::check_training(char *training_file){ FILE *fp; if((fp=fopen(training_file,"r"))==NULL) return false; char line[MAXSTRLEN]; int lines=0; cols=0; while(fgets(line,MAXSTRLEN-1,fp)) { lines++; if(strlen(line)==1) continue; trim_line(line); vector<char *>columns; if(!split_string(line,"\t",columns)) { printf("columns should be greater than 1\n"); fclose(fp); return false;//columns should greater than 1 } if(cols && cols!=columns.size()) {//incompatible printf("line %d: columns incompatible\n",lines); fclose(fp); return false; } cols=columns.size(); if(cols<2){ printf("columns should be greater than 1\n"); fclose(fp); return false;//columns should greater than 1 } char *t=columns.back();//tag int index,insert_pos; if(!vector_search(tags,t,index,insert_pos,str_cmp())) { char *p=tag_str.push_back(t);//copy string vector_insert(tags,p,insert_pos); } } fclose(fp); ysize=tags.size(); set_group(); return true;}bool CRF::generate_features(char *filename){ char line[MAXSTRLEN]; max_seq_size=0; total_nodes=0; vector<char *>table;// table[i,j] = table[i*cols+j] charlist table_str; table_str.set_size(PAGESIZE);//1 page memory lambda_size=0;//lambda size int lines=0; int i; FILE *fp=fopen(filename,"r"); while(fgets(line,MAXSTRLEN-1,fp)) { trim_line(line); if(line[0]) { vector<char *>columns; split_string(line,"\t",columns); for(i=0;i<cols;i++) { char *p=table_str.push_back(columns[i]); table.push_back(p); } }else{//get one sequence if(table.size())//non-empty line { if(max_seq_size<table.size()) max_seq_size=table.size(); total_nodes+=table.size(); add_x(table); table.clear();//prepare for new table table_str.free(); lines++; if(!(lines%100))
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -