⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 crf.cpp

📁 pocket_crf_0.45
💻 CPP
📖 第 1 页 / 共 3 页
字号:
#include <stdio.h>#include <stdlib.h>#include <string.h>#include <assert.h>#include <math.h>#include <ctime>#ifdef _WIN32#include <io.h>#endif#include "crf.h"#include "fun.h"#include "lbfgs.h"const int PAGESIZE = 8192;const int MAXSTRLEN = 131072;//16 pagesbool CRF::learn(char* templet_file, char *training_file,char *model_file){	printf("pocket crf\nversion 0.%d\nCopyright(c)2005-2008 Xian Qian, all rights reserved\n",version);	if(!load_templet(templet_file))		return false;	printf("templates loaded\n");	set_chain_type();	if(!check_training(training_file))		return false;	if(!generate_features(training_file))		return false;	printf("training data loaded\n");	shrink_feature();	printf("features shrinked\n");	fmap=&fmap_tmp[0];			//temporary set	fmap_size=fmap_tmp.size();	sequences=&sequences_tmp[0];	//temporary set	sequence_num=sequences_tmp.size();	write_model(model_file,true);	fmap=NULL;	sequences=NULL;	//free memory	tags.clear();	vector<char *>(tags).swap(tags);	tag_str.clear();	xindex.clear();	map<char *,int,str_cmp>(xindex).swap(xindex);	x_freq.clear();	vector<int>(x_freq).swap(x_freq);	x_str.clear();	templets.clear();	vector<templet>(templets).swap(templets);	vector<vector<vector<int> > >(templet_group).swap(templet_group);	compress();			//begin learning	//set global temporary variables	path_num=pow((double)ysize,order+1);	node_anum=pow((double)ysize,order);//alpha(beta) number of each node	head_offset=-log((double)node_anum);	//initialize	int i,j,k;	lambda=new double[lambda_size];	gradient=new double[lambda_size];	threads[0].gradient=gradient;	if(algorithm==CRF_ALGORITHM)	{		if(prior==MEM_PRIOR)			unload();		printf("algorithm: crf\nsequence number: %d\nparameter number: %d\nsigma: %g\nl1: %g\nfreq_thresh: %d\n",sequence_num,lambda_size,sigma,l1,freq_thresh);		for(i=1;i<thread_num;i++)			threads[i].gradient=new double[lambda_size];				int orthant = l1>0?1:0;		for(;orthant>=0;orthant--)		{			int converge=0;			double old_fvalue=1e+37;			clock_t start_time=clock();			LBFGS l;			l.init(lambda_size,depth,prior);			memset(lambda,0,lambda_size*sizeof(double));						for(i=0;i<max_iter;i++)			{				fvalue=0;				if(chain_type==SIMPLE_CHAIN)				{					memset(transit_buff,0,sizeof(double)*(ysize*max_seq_size+path_num*(max_seq_size-1)));					for(j=0;j<path_num;j++)					{						int findex=fmap[transit+j];						if(findex>=0)						{							for(k=0;k<max_seq_size-1;k++)								transit_buff[ysize+(ysize+path_num)*k+ysize+j]=lambda[findex];						}					}				}				for(j=0;j<thread_num;j++)					threads[j].start();				for(j=0;j<thread_num;j++)					threads[j].join();				for(j=0;j<thread_num;j++)				{					fvalue += threads[j].obj;				}				for(j=0;j<lambda_size;j++)					for(k=1;k<thread_num;k++)						gradient[j]+=threads[k].gradient[j];				if(orthant)				{					for(j=0;j<lambda_size;j++)					{						fvalue+=fabs( lambda[j]) / l1;						if(lambda[j]==0.0)						{							if(gradient[j]<-1/l1)							{								gradient[j]+=1/l1;							}else if(gradient[j]>1/l1){								gradient[j]-=1/l1;							}else								gradient[j]=0;						}else{							if(lambda[j]>0)							{								gradient[j]+=1/l1;							}else{								gradient[j]-=1/l1;							}						}					}				}else{					//add sigma||\lambda||^2					for(j=0;j<lambda_size;j++)					{						fvalue+=lambda[j] * lambda[j] / (2*sigma);						gradient[j]+=lambda[j]/sigma;					}				}				double diff=i>0? fabs((old_fvalue-fvalue)/old_fvalue) : 1;				if(orthant)				{					int selected=0;					for(j=0;j<lambda_size;j++)					{						if(lambda[j]!=0)							selected++;					}					printf("iter: %d act: %d diff: %lf obj: %lf\n", i, selected, diff, fvalue);				}else					printf("iter: %d diff: %lf obj: %lf\n", i, diff, fvalue);				old_fvalue=fvalue;				if(diff<eta)					converge++;				else					converge=0;				if(i==max_iter||converge==3) break;//success				double *w0,*w1;				if(prior==SPEED_PRIOR){					w0=NULL;					w1=NULL;				}else{					w0=(double *)work_space;					w1=((double *)work_space)+lambda_size;				}				if(l.optimize(lambda,&fvalue,gradient,orthant,w0,w1)<=0)				{					printf("lbfgs error\n");					break;//error				}				if(prior==MEM_PRIOR)					load();			}			if(orthant)			{				adjust_data();				if(prior==MEM_PRIOR)					unload();//save data structure			}			double elapse = static_cast<double>(clock() - start_time) / CLOCKS_PER_SEC;			if(orthant)				printf("crf feature selection elapse: %g s, %d features selected\n",elapse,lambda_size);			else				printf("crf optimization elapse: %g s\n",elapse);		}	}else if(algorithm==AP_ALGORITHM){		printf("algorithm: ap\nsequence number: %d\nparameter number: %d\nfreq_thresh: %d\n",sequence_num,lambda_size,freq_thresh);		clock_t start_time=clock();		memset(gradient,0,lambda_size*sizeof(double));		memset(lambda,0,lambda_size*sizeof(double));		threads[0].times=max_iter*sequence_num;		for(i=0;i<max_iter;i++)		{			threads[0].start();			threads[0].join();			fvalue=threads[0].obj/total_nodes;			printf("iter: %d err: %lf \n",i,fvalue);		}		for(i=0;i<lambda_size;i++)			lambda[i]=gradient[i]/(max_iter*sequence_num);		adjust_data();		double elapse = static_cast<double>(clock() - start_time) / CLOCKS_PER_SEC;		printf("ap optimization elapse: %g s\n",elapse);	}else if(algorithm==PA_ALGORITHM){		printf("algorithm: pa\nsequence number: %d\nparameter number: %d\nC: %g\nfreq_thresh: %d\n",sequence_num,lambda_size,sigma,freq_thresh);		clock_t start_time=clock();		memset(gradient,0,lambda_size*sizeof(double));		memset(lambda,0,lambda_size*sizeof(double));		threads[0].times=max_iter*sequence_num;		for(i=0;i<max_iter;i++)		{			threads[0].start();			threads[0].join();			fvalue=threads[0].obj/total_nodes;			printf("iter: %d err: %lf \n",i,fvalue);		}		for(i=0;i<lambda_size;i++)			lambda[i]=gradient[i]/(max_iter*sequence_num);		adjust_data();		double elapse = static_cast<double>(clock() - start_time) / CLOCKS_PER_SEC;		printf("pa optimization elapse: %g s\n",elapse);	}	write_model(model_file,false);	//free all members' memory	templet_group.clear();	if(lambda)	{		delete [] lambda;		lambda=NULL;	}	if(gradient)	{		delete [] gradient;		gradient=NULL;	}	for(i=1;i<thread_num;i++)	{		delete [] threads[i].gradient;		threads[i].gradient=NULL;	}	return true;}CRF::CRF(){	depth=5;	algorithm=CRF_ALGORITHM;	transit_buff=NULL;	work_space=NULL;	sequences=NULL;	sequence1s=NULL;	fmap=NULL;	sigma=1;	prior=SPEED_PRIOR;	freq_thresh=0;	margin=false;	nbest=1;	x_str.set_size(PAGESIZE * 16);//each time alloc 16 pages	tag_str.set_size(256);//each time alloc 256 bytes	nodes.set_size(PAGESIZE);	cliques.set_size(PAGESIZE*16);	clique_node.set_size(PAGESIZE*16);	node_clique.set_size(PAGESIZE*16);	clique_feature.set_size(PAGESIZE*16);	order=-1;	max_iter=10000;	eta=0.0001;	lambda=NULL;	gradient=NULL;	l1=0;	version=45;//0.45	thread_num=1;	threads.resize(thread_num);	threads.front().start_i = 0;	threads.front().obj=0;	threads.front().c=this;}CRF::CRF(double c, int freq_t,bool margin_, int candidate_num_){	depth=5;	algorithm=CRF_ALGORITHM;	transit_buff=NULL;	work_space=NULL;	sequences=NULL;	sequence1s=NULL;	fmap=NULL;	sigma=c;	prior=SPEED_PRIOR;	freq_thresh=freq_t;	margin=margin_;	nbest=candidate_num_;	x_str.set_size(PAGESIZE * 16);//each time alloc 16 pages	tag_str.set_size(256);//each time alloc 256 bytes	nodes.set_size(PAGESIZE);	cliques.set_size(PAGESIZE*16);	clique_node.set_size(PAGESIZE*16);	node_clique.set_size(PAGESIZE*16);	clique_feature.set_size(PAGESIZE*16);	order=-1;	max_iter=10000;	eta=0.0001;	lambda=NULL;	gradient=NULL;	l1=0;	version=45;//0.45	thread_num=1;	threads.resize(thread_num);	threads.front().start_i = 0;	threads.front().obj=0;	threads.front().c=this;}CRF::~CRF(){	if(transit_buff)	{		delete [] transit_buff;		transit_buff=NULL;	}	if(lambda)	{		delete [] lambda;		lambda=NULL;	}	if(gradient)	{		delete [] gradient;		gradient=NULL;	}	if(work_space)	{		delete [] work_space;		work_space=NULL;		sequences=NULL;		sequence1s=NULL;		fmap=NULL;	}	if(sequences)	{		delete [] sequences;		sequences=NULL;	}	if(sequence1s)	{		delete [] sequence1s;		sequence1s=NULL;	}	if(fmap)	{		delete [] fmap;		fmap=NULL;	}	if(!access("__data1",0))		unlink("__data1");	if(!access("__data2",0))		unlink("__data2");}void CRF::set_chain_type(){	if(order!=1)	{		chain_type=GENERAL_CHAIN;		return;	}else{		for(int i=0;i<templets.size();i++)		{			bool null_word=true;			for(int j=0;j<templets[i].words.size();j++)				if(templets[i].words[j]!="")				{					null_word=false;					break;				}			if(templets[i].y.size()>1 && (templets[i].x.size() || !null_word))			{				chain_type=FIRST_CHAIN;				return;			}		}	}	chain_type=SIMPLE_CHAIN;}bool CRF::set_order(){	//set groupid,end_of_group,order	if(!templets.size())// no templets		return false;	order=-templets[0].y[0];	templets[0].groupid=0;	templets[0].end_of_group=false;	for(int i=1;i<templets.size();i++)	{		templet &n=templets[i];		templet &last_n=templets[i-1];		int j;		for(j=0;j<last_n.y.size();j++)		{			if(last_n.y[j]!=n.y[j])			{//start of group				last_n.end_of_group=true;				n.groupid=last_n.groupid+1;				if(-n.y[0]>order)		order=-n.y[0];				break;			}		}		if(j==last_n.y.size())//not start of group		{			n.groupid=last_n.groupid;			last_n.end_of_group=false;		}	}	templets.back().end_of_group=true;	gsize=templets.back().groupid+1;	templet_group.resize(gsize);	return true;}bool CRF::set_para(char *para_name, char *para_value){	if(!strcmp(para_name,"sigma"))		sigma=atof(para_value);	else if(!strcmp(para_name,"freq_thresh"))		freq_thresh=atoi(para_value);	else if(!strcmp(para_name,"margin"))		margin=atoi(para_value);	else if(!strcmp(para_name,"nbest"))		nbest=atoi(para_value);	else if(!strcmp(para_name,"max_iter"))		max_iter=atoi(para_value);	else if(!strcmp(para_name,"depth"))		depth=atoi(para_value);	else if(!strcmp(para_name,"eta"))		eta=atof(para_value);	else if(!strcmp(para_name,"l1"))		l1=atof(para_value);	else if(!strcmp(para_name,"prior"))		prior=atoi(para_value);	else if(!strcmp(para_name,"thread_num"))	{		if(algorithm==CRF_ALGORITHM)		{			thread_num=atoi(para_value);			threads.resize(thread_num);			for(int i=0;i<thread_num;i++)			{				threads[i].c=this;				threads[i].start_i = i;				threads[i].obj=0;			}		}	}else if(!strcmp(para_name,"algorithm")){		algorithm=atoi(para_value);		if(algorithm==AP_ALGORITHM || algorithm==PA_ALGORITHM)		{			thread_num=1;			threads.resize(1);			threads[0].start_i = 0;			threads[0].obj=0;			threads[0].c=this;		}	}	return true;}bool CRF::add_templet(char *line){	if(!line[0]||line[0]=='#') 		return false;	templet n;	char *p=line,*q;	char word[1000];		char index_str[1000];	int index1,index2;	while(q=catch_string(p,"%x[",word))	{		p=q;		n.words.push_back(word);		p=catch_string(p,",",index_str);		index1=atoi(index_str);		p=catch_string(p,"]",index_str);		index2=atoi(index_str);		n.x.push_back(make_pair(index1,index2));	}	q=catch_string(p,"%y[",word);	if(!q)	{		printf("templet: %s incorrect\n",line);		return false;	}	n.words.push_back(word);	p=q-3;	while(p=catch_string(p,"%y[","]",index_str))	{		index1=atoi(index_str);		n.y.push_back(index1);	}	int insert_pos;	vector_search(templets,n,index1,insert_pos,templet_cmp());	vector_insert(templets,n,insert_pos);	return true;}bool CRF::load_templet(char *templet_file){	FILE *fp;	//read template	fp=fopen(templet_file,"r");	if (!fp) 	{		printf("template file: %s not found\n",templet_file);		return false;	}	char line[MAXSTRLEN];	while(fgets(line,MAXSTRLEN-1,fp))	{		trim_line(line);		add_templet(line);	}	fclose(fp);	return set_order();}bool CRF::check_training(char *training_file){	FILE *fp;	if((fp=fopen(training_file,"r"))==NULL)	return false;	char line[MAXSTRLEN];	int lines=0;	cols=0;	while(fgets(line,MAXSTRLEN-1,fp))	{		lines++;		if(strlen(line)==1) continue;		trim_line(line);		vector<char *>columns;		if(!split_string(line,"\t",columns))		{			printf("columns should be greater than 1\n");			fclose(fp);			return false;//columns should greater than 1		}		if(cols && cols!=columns.size())		{//incompatible			printf("line %d: columns incompatible\n",lines);			fclose(fp);			return false;		}		cols=columns.size();		if(cols<2){			printf("columns should be greater than 1\n");			fclose(fp);			return false;//columns should greater than 1		}		char *t=columns.back();//tag		int index,insert_pos;		if(!vector_search(tags,t,index,insert_pos,str_cmp()))		{			char *p=tag_str.push_back(t);//copy string			vector_insert(tags,p,insert_pos);		}	}	fclose(fp);	ysize=tags.size();	set_group();	return true;}bool CRF::generate_features(char *filename){	char line[MAXSTRLEN];	max_seq_size=0;	total_nodes=0;	vector<char *>table;// table[i,j] = table[i*cols+j]	charlist table_str;	table_str.set_size(PAGESIZE);//1 page memory	lambda_size=0;//lambda size	int lines=0;	int i;	FILE *fp=fopen(filename,"r");	while(fgets(line,MAXSTRLEN-1,fp))	{		trim_line(line);		if(line[0])		{			vector<char *>columns;			split_string(line,"\t",columns);			for(i=0;i<cols;i++)			{				char *p=table_str.push_back(columns[i]);				table.push_back(p);			}		}else{//get one sequence			if(table.size())//non-empty line			{				if(max_seq_size<table.size())					max_seq_size=table.size();				total_nodes+=table.size();				add_x(table);				table.clear();//prepare for new table				table_str.free();				lines++;				if(!(lines%100))

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -