⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 la_test.cpp

📁 用C语言实现的最新且最快的SVM源码,可用于解决多类分类问题
💻 CPP
字号:
// -*- mode: c++; c-file-style: "stroustrup"; -*-using namespace std;#include <stdio.h>#include <vector>#include <math.h>#include <ctype.h>#include <iostream>#include <fstream>#include "vector.h"#define LINEAR  0#define POLY    1#define RBF     2#define SIGMOID 3 char *kernel_type_table[] = {"linear","polynomial","rbf","sigmoid"};class ID // class to hold split file indices and labels{public:	int x;	int y;    ID() : x(0), y(0) {}    ID(int x1,int y1) : x(x1), y(y1) {}};// IDs will be sorted by index, not by label.bool operator<(const ID& x, const ID& y){    return x.x < y.x;}int m,msv;                         // training and test set sizesvector <lasvm_sparsevector_t*> X; // feature vectors for test setvector <lasvm_sparsevector_t*> Xsv;// feature vectors for SVsvector <int> Y;                   // labelsvector <double> alpha;            // alpha_i, SV weightsdouble b0;                        // thresholdint use_b0=1;                     // use threshold via constraint \sum a_i y_i =0int kernel_type=RBF;              // LINEAR, POLY, RBF or SIGMOID kernelsdouble degree=3,kgamma=-1,coef0=0;// kernel paramsvector <double> x_square;         // norms of input vectors, used for RBFvector <double> xsv_square;        // norms of test vectors, used for RBFchar split_file_name[1024]="\0";         // filename for the splitsint binary_files=0;vector <ID> splits;             int max_index=0;void exit_with_help(){    fprintf(stdout,	    "\nUsage: la_test [options] test_set_file model_file output_file\n"	    "options:\n"            "-B file format : files are stored in the following format:\n"            "	0 -- libsvm ascii format (default)\n"            "	1 -- binary format\n"            "	2 -- split file format\n");     exit(1);}int split_file_load(char *f){    int binary_file=0,labs=0,inds=0;    FILE *fp;    fp=fopen(f,"r");     	if(fp==NULL) {printf("[couldn't load split file: %s]\n",f); exit(1);}    char dummy[100],dummy2[100];	unsigned int i,j=0; for(i=0;i<strlen(f);i++) if(f[i]=='/') j=i+1;    fscanf(fp,"%s %s",dummy,dummy2);    strcpy(&(f[j]),dummy2);        fscanf(fp,"%s %d",dummy,&binary_file);    fscanf(fp,"%s %d",dummy,&inds);    fscanf(fp,"%s %d",dummy,&labs);    printf("[split file: load:%s binary:%d new_indices:%d new_labels:%d]\n",dummy2,binary_file,inds,labs);    //printf("[split file:%s binary=%d]\n",dummy2,binary_file);    if(!inds) return binary_file;    while(1)    {        int i,j;        int c=fscanf(fp,"%d",&i);        if(labs) c=fscanf(fp,"%d",&j);        if(c==-1) break;        if (labs) 			splits.push_back(ID(i-1,j)); 		else 			splits.push_back(ID(i-1,0));    }    sort(splits.begin(),splits.end());		return binary_file;}int libsvm_load_data(char *filename)// loads the same format as LIBSVM{    int index; double value;    int elements, i;    FILE *fp = fopen(filename,"r");    lasvm_sparsevector_t* v;    if(fp == NULL)    {        fprintf(stderr,"Can't open input file \"%s\"\n",filename);        exit(1);    }    else        printf("loading \"%s\"..  \n",filename);    int splitpos=0;    int msz = 0;     elements = 0;    while(1)    {        int c = fgetc(fp);        switch(c)        {        case '\n':            if(splits.size()>0)             {                if(splitpos<(int)splits.size() && splits[splitpos].x==msz)                {                    v=lasvm_sparsevector_create();                    X.push_back(v);	splitpos++;                }            }            else            {                v=lasvm_sparsevector_create();                X.push_back(v);             }            ++msz;            //printf("%d\n",m);            elements=0;            break;        case ':':            ++elements;            break;        case EOF:            goto out;        default:            ;        }    } out:    rewind(fp);    max_index = 0;splitpos=0;    for(i=0;i<msz;i++)    {        int write=0;        if(splits.size()>0)        {            if(splitpos<(int)splits.size() && splits[splitpos].x==i)            {                write=2;splitpos++;            }        }        else            write=1;         int label;        fscanf(fp,"%d",&label);        //	printf("%d %d\n",i,label);        if(write)         {            if(splits.size()>0)            {                  if(splits[splitpos-1].y!=0)                    Y.push_back(splits[splitpos-1].y);                else                    Y.push_back(label);            }            else                Y.push_back(label);        }			        while(1)        {            int c;            do {                c = getc(fp);                if(c=='\n') goto out2;            } while(isspace(c));            ungetc(c,fp);            fscanf(fp,"%d:%lf",&index,&value);			            if (write==1) lasvm_sparsevector_set(X[m+i],index,value);            if (write==2) lasvm_sparsevector_set(X[splitpos-1],index,value);            if (index>max_index) max_index=index;        }    out2:        label=1; // dummy    }    fclose(fp);    msz=X.size()-m;    printf("examples: %d   features: %d\n",msz,max_index);    return msz;}int binary_load_data(char *filename){    int msz,i=0,j;    lasvm_sparsevector_t* v;    int nonsparse=0;    ifstream f;    f.open(filename,ios::in|ios::binary);        // read number of examples and number of features    int sz[2];     f.read((char*)sz,2*sizeof(int));    if (!f) { printf("File writing error in line %d.\n",i); exit(1);}    msz=sz[0]; max_index=sz[1];    vector <float> val;    vector <int>   ind;    val.resize(max_index);    if(max_index>0) nonsparse=1;    int splitpos=0;    for(i=0;i<msz;i++)     {        int mwrite=0;        if(splits.size()>0)        {            if(splitpos<(int)splits.size() && splits[splitpos].x==i)             {                 mwrite=1;splitpos++;                v=lasvm_sparsevector_create(); X.push_back(v);            }        }        else        {            mwrite=1;            v=lasvm_sparsevector_create(); X.push_back(v);        }		        if(nonsparse) // non-sparse binary file        {            f.read((char*)sz,1*sizeof(int)); // get label            if(mwrite)             {                if(splits.size()>0 && splits[splitpos-1].y!=0)                    Y.push_back(splits[splitpos-1].y);                else                    Y.push_back(sz[0]);            }            f.read((char*)(&val[0]),max_index*sizeof(float));            if(mwrite)                for(j=0;j<max_index;j++) // set features for each example                    lasvm_sparsevector_set(v,j,val[j]);        }        else			// sparse binary file        {            f.read((char*)sz,2*sizeof(int)); // get label & sparsity of example i            if(mwrite)             {                if(splits.size()>0 && splits[splitpos-1].y!=0)                    Y.push_back(splits[splitpos-1].y);                else                    Y.push_back(sz[0]);            }            val.resize(sz[1]); ind.resize(sz[1]);            f.read((char*)(&ind[0]),sz[1]*sizeof(int));            f.read((char*)(&val[0]),sz[1]*sizeof(float));            if(mwrite)                for(j=0;j<sz[1];j++) // set features for each example                {                    lasvm_sparsevector_set(v,ind[j],val[j]);                    //printf("%d=%g\n",ind[j],val[j]);                    if(ind[j]>max_index) max_index=ind[j];                }        }		    }    f.close();    msz=X.size()-m;    printf("examples: %d   features: %d\n",msz,max_index);    return msz;}void load_data_file(char *filename){    int msz,i,ft;    splits.resize(0);     int bin=binary_files;    if(bin==0) // if ascii, check if it isn't a split file..    {        FILE *f=fopen(filename,"r");        if(f == NULL)        {            fprintf(stderr,"Can't open input file \"%s\"\n",filename);            exit(1);        }        char c; fscanf(f,"%c",&c);         if(c=='f') bin=2; // found split file!    }    switch(bin)  // load diferent file formats    {    case 0: // libsvm format        msz=libsvm_load_data(filename); break;    case 1:         msz=binary_load_data(filename); break;    case 2:        ft=split_file_load(filename);        if(ft==0) 	         {msz=libsvm_load_data(filename); break;}         else        {msz=binary_load_data(filename); break;}    default:        fprintf(stderr,"Illegal file type '-B %d'\n",bin);        exit(1);    }    if(kernel_type==RBF)    {        x_square.resize(m+msz);        for(i=0;i<msz;i++)            x_square[i+m]=lasvm_sparsevector_dot_product(X[i+m],X[i+m]);    }    if(kgamma==-1)        kgamma=1.0/ ((double) max_index); // same default as LIBSVM    m+=msz;}  void libsvm_load_sv_data(FILE *fp)// loads the same format as LIBSVM{     int max_index; int oldindex=0;    int index; double value; int i;    lasvm_sparsevector_t* v;        alpha.resize(msv);    for(i=0;i<msv;i++)    {	v=lasvm_sparsevector_create();	Xsv.push_back(v);     }        max_index = 0;    for(i=0;i<msv;i++)    {	double label;	fscanf(fp,"%lf",&label);	//printf("%d:%g\n",i,label);	alpha[i] = label;	while(1)	{	    int c;	    do {		c = getc(fp);		if(c=='\n') goto out2;	    } while(isspace(c));	    ungetc(c,fp);	    fscanf(fp,"%d:%lf",&index,&value);	    if(index!=oldindex)	    {		lasvm_sparsevector_set(Xsv[i],index,value);	    }	    oldindex=index;	    if (index>max_index) max_index=index;	}	    out2:	label=1; // dummy    }        printf("loading model: %d svs\n",msv);        if(kernel_type==RBF)    {    	xsv_square.resize(msv);    	for(i=0;i<msv;i++)    	    xsv_square[i]=lasvm_sparsevector_dot_product(Xsv[i],Xsv[i]);    }    }int libsvm_load_model(const char *model_file_name)// saves the model in the same format as LIBSVM{    int i;    FILE *fp = fopen(model_file_name,"r");	    if(fp == NULL)    {	fprintf(stderr,"Can't open input file \"%s\"\n",model_file_name);	exit(1);    }    static char tmp[1001];    fscanf(fp,"%1000s",tmp); //svm_type    fscanf(fp,"%1000s",tmp); //c_svc    fscanf(fp,"%1000s",tmp); //kernel_type    fscanf(fp,"%1000s",tmp); //rbf,poly,..    kernel_type=LINEAR;    for(i=0;i<4;i++)	if (strcmp(tmp,kernel_type_table[i])==0) kernel_type=i;    if(kernel_type == POLY)    {	fscanf(fp,"%1000s",tmp); 	fscanf(fp,"%lf", &degree);    }    if(kernel_type == POLY || kernel_type == RBF || kernel_type == SIGMOID)    {	fscanf(fp,"%1000s",tmp); 	fscanf(fp,"%lf",&kgamma);    }    if(kernel_type == POLY || kernel_type == SIGMOID)    {	fscanf(fp,"%1000s",tmp); 	fscanf(fp,"%lf", &coef0);    }    fscanf(fp,"%1000s",tmp); // nr_class    fscanf(fp,"%1000s",tmp); // 2    fscanf(fp,"%1000s",tmp); // total_sv    fscanf(fp,"%d",&msv);     fscanf(fp,"%1000s",tmp); //rho    fscanf(fp,"%lf\n",&b0);    fscanf(fp,"%1000s",tmp); // label    fscanf(fp,"%1000s",tmp); // 1    fscanf(fp,"%1000s",tmp); // -1    fscanf(fp,"%1000s",tmp); // nr_sv    fscanf(fp,"%1000s",tmp); // num    fscanf(fp,"%1000s",tmp); // num    fscanf(fp,"%1000s",tmp); // SV	    // now load SV data...        libsvm_load_sv_data(fp);	    // finished!    fclose(fp);    return 0;}double kernel(int i, int j, void *kparam){    double dot;    dot=lasvm_sparsevector_dot_product(X[i],Xsv[j]);        // sparse, linear kernel    switch(kernel_type)    {    case LINEAR:        return dot;     case POLY:	return pow(kgamma*dot+coef0,degree);    case RBF:	return exp(-kgamma*(x_square[i]+xsv_square[j]-2*dot));        case SIGMOID:	return tanh(kgamma*dot+coef0);        }    return 0;}    void test(char *output_name){	    FILE *fp=fopen(output_name,"w");    int i,j; double y; double acc=0;    for(i=0;i<m;i++)    {	y=-b0;	for(j=0;j<msv;j++)	{	    y+=alpha[j]*kernel(i,j,NULL);	}	if(y>=0) y=1; else y=-1; 	if(((int)y)==Y[i]) acc++;     }    printf("accuracy= %g (%d/%d)\n",(acc/m)*100,((int)acc),m);    fclose(fp);}void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name, char *output_file_name){    int i;         // parse options    for(i=1;i<argc;i++)    {	if(argv[i][0] != '-') break;	++i;	switch(argv[i-1][1])	{	case 'B':	    binary_files=atoi(argv[i]);	    break;	default:	    fprintf(stderr,"unknown option\n");	    exit_with_help();	}    }    // determine filenames    if(i>=argc)	exit_with_help();    strcpy(input_file_name, argv[i]);    if(i<argc-1)	strcpy(model_file_name,argv[i+1]);    else    {	char *p = strrchr(argv[i],'/');	if(p==NULL)	    p = argv[i];	else	    ++p;	sprintf(model_file_name,"%s.model",p);    }    if(argc<i+3) exit_with_help();    strcpy(input_file_name, argv[i]);    strcpy(model_file_name, argv[i+1]);    strcpy(output_file_name, argv[i+2]);}int main(int argc, char **argv)  {    printf("\n");    printf("la test\n");    printf("_______\n");        char input_file_name[1024];    char model_file_name[1024];    char output_file_name[1024];    parse_command_line(argc, argv, input_file_name, model_file_name, output_file_name);         libsvm_load_model(model_file_name);// load model    load_data_file(input_file_name); // load test data        test(output_file_name);}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -