⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 la_svm.cpp

📁 用C语言实现的最新且最快的SVM源码,可用于解决多类分类问题
💻 CPP
📖 第 1 页 / 共 2 页
字号:
            if(mwrite)                for(j=0;j<max_index;j++) // set features for each example                    lasvm_sparsevector_set(v,j,val[j]);        }        else			// sparse binary file        {            f.read((char*)sz,2*sizeof(int)); // get label & sparsity of example i            if(mwrite)             {                if(splits.size()>0 && splits[splitpos-1].y!=0)                    Y.push_back(splits[splitpos-1].y);                else                    Y.push_back(sz[0]);            }            val.resize(sz[1]); ind.resize(sz[1]);            f.read((char*)(&ind[0]),sz[1]*sizeof(int));            f.read((char*)(&val[0]),sz[1]*sizeof(float));            if(mwrite)                for(j=0;j<sz[1];j++) // set features for each example                {                    lasvm_sparsevector_set(v,ind[j],val[j]);                    //printf("%d=%g\n",ind[j],val[j]);                    if(ind[j]>max_index) max_index=ind[j];                }        }		    }    f.close();    msz=X.size()-m;    printf("examples: %d   features: %d\n",msz,max_index);    return msz;}void load_data_file(char *filename){    int msz,i,ft;    splits.resize(0);     int bin=binary_files;    if(bin==0) // if ascii, check if it isn't a split file..    {        FILE *f=fopen(filename,"r");        if(f == NULL)        {            fprintf(stderr,"Can't open input file \"%s\"\n",filename);            exit(1);        }        char c; fscanf(f,"%c",&c);         if(c=='f') bin=2; // found split file!    }    switch(bin)  // load diferent file formats    {    case 0: // libsvm format        msz=libsvm_load_data(filename); break;    case 1:         msz=binary_load_data(filename); break;    case 2:        ft=split_file_load(filename);        if(ft==0) 	         {msz=libsvm_load_data(filename); break;}         else        {msz=binary_load_data(filename); break;}    default:        fprintf(stderr,"Illegal file type '-B %d'\n",bin);        exit(1);    }    if(kernel_type==RBF)    {        x_square.resize(m+msz);        for(i=0;i<msz;i++)            x_square[i+m]=lasvm_sparsevector_dot_product(X[i+m],X[i+m]);    }    if(kgamma==-1)        kgamma=1.0/ ((double) max_index); // same default as LIBSVM    m+=msz;}int sv1,sv2; double max_alpha,alpha_tol;int count_svs(){    int i;     max_alpha=0;     sv1=0;sv2=0;        for(i=0;i<m;i++) 	// Count svs..       {        if(alpha[i]>max_alpha) max_alpha=alpha[i];        if(-alpha[i]>max_alpha) max_alpha=-alpha[i];    }           alpha_tol=max_alpha/1000.0;        for(i=0;i<m;i++)     {        if(Y[i]>0)         {            if(alpha[i] >= alpha_tol) sv1++;         }        else            {            if(-alpha[i] >= alpha_tol) sv2++;         }                }    return sv1+sv2;}int libsvm_save_model(const char *model_file_name)    // saves the model in the same format as LIBSVM{    FILE *fp = fopen(model_file_name,"w");    if(fp==NULL) return -1;	    count_svs();    // printf("nSV=%d\n",sv1+sv2);    fprintf(fp,"svm_type c_svc\n");    fprintf(fp,"kernel_type %s\n", kernel_type_table[kernel_type]);    if(kernel_type == POLY)        fprintf(fp,"degree %g\n", degree);    if(kernel_type == POLY || kernel_type == RBF || kernel_type == SIGMOID)        fprintf(fp,"gamma %g\n", kgamma);    if(kernel_type == POLY || kernel_type == SIGMOID)        fprintf(fp,"coef0 %g\n", coef0);    fprintf(fp, "nr_class %d\n",2);    fprintf(fp, "total_sv %d\n",sv1+sv2);	    {        fprintf(fp, "rho %g\n",b0);    }	    fprintf(fp, "label 1 -1\n");    fprintf(fp, "nr_sv");    fprintf(fp," %d %d",sv1,sv2);    fprintf(fp, "\n");    fprintf(fp, "SV\n");    for(int j=0;j<2;j++)        for(int i=0;i<m;i++)        {            if (j==0 && Y[i]==-1) continue;            if (j==1 && Y[i]==1) continue;            if (alpha[i]*Y[i]< alpha_tol) continue; // not an SV	                fprintf(fp, "%.16g ",alpha[i]);            lasvm_sparsevector_pair_t *p1 = X[i]->pairs;            while (p1)            {                          fprintf(fp,"%d:%.8g ",p1->index,p1->data);                p1 = p1->next;            }            fprintf(fp, "\n");        }    fclose(fp);    return 0;}double kernel(int i, int j, void *kparam){    double dot;    kcalcs++;    dot=lasvm_sparsevector_dot_product(X[i],X[j]);        // sparse, linear kernel    switch(kernel_type)    {    case LINEAR:        return dot;    case POLY:        return pow(kgamma*dot+coef0,degree);    case RBF:        return exp(-kgamma*(x_square[i]+x_square[j]-2*dot));        case SIGMOID:        return tanh(kgamma*dot+coef0);        }    return 0;}   void finish(lasvm_t *sv){    int i,l;     if (optimizer==ONLINE_WITH_FINISHING)    {        fprintf(stdout,"..[finishing]");             int iter=0;        do {             iter += lasvm_finish(sv, epsgr);         } while (lasvm_get_delta(sv)>epsgr);    }    l=(int) lasvm_get_l(sv);    int *svind,svs; svind= new int[l];    svs=lasvm_get_sv(sv,svind);     alpha.resize(m);    for(i=0;i<m;i++) alpha[i]=0;    double *svalpha; svalpha=new double[l];    lasvm_get_alpha(sv,svalpha);     for(i=0;i<svs;i++) alpha[svind[i]]=svalpha[i];    b0=lasvm_get_b(sv);}void make_old(int val)    // move index <val> from new set into old set{    int i,ind=-1;    for(i=0;i<(int)inew.size();i++)    {        if(inew[i]==val) {ind=i; break;}    }    if (ind>=0)    {        inew[ind]=inew[inew.size()-1];        inew.pop_back();        iold.push_back(val);    }}int select(lasvm_t *sv) // selection strategy{    int s=-1;    int t,i,r,j;    double tmp,best; int ind=-1;    switch(selection_type)    {    case RANDOM:   // pick a random candidate        s=rand() % inew.size();        break;    case GRADIENT: // pick best gradient from 50 candidates        j=candidates; if((int)inew.size()<j) j=inew.size();        r=rand() % inew.size();        s=r;        best=1e20;        for(i=0;i<j;i++)        {            r=inew[s];            tmp=lasvm_predict(sv, r);              tmp*=Y[r];            //printf("%d: example %d   grad=%g\n",i,r,tmp);            if(tmp<best) {best=tmp;ind=s;}            s=rand() % inew.size();        }          s=ind;        break;    case MARGIN:  // pick closest to margin from 50 candidates        j=candidates; if((int)inew.size()<j) j=inew.size();        r=rand() % inew.size();        s=r;        best=1e20;        for(i=0;i<j;i++)        {            r=inew[s];            tmp=lasvm_predict(sv, r);              if (tmp<0) tmp=-tmp;             //printf("%d: example %d   grad=%g\n",i,r,tmp);            if(tmp<best) {best=tmp;ind=s;}            s=rand() % inew.size();        }          s=ind;        break;    }	    t=inew[s];     inew[s]=inew[inew.size()-1];    inew.pop_back();    iold.push_back(t);	    //printf("(%d %d)\n",iold.size(),inew.size());    return t;}void train_online(char *model_file_name){    int t1,t2=0,i,s,l,j,k;    double timer=0;    stopwatch *sw; // start measuring time after loading is finished    sw=new stopwatch;    // save timing information    char t[1000];    strcpy(t,model_file_name);    strcat(t,".time");        lasvm_kcache_t *kcache=lasvm_kcache_create(kernel, NULL);    lasvm_kcache_set_maximum_size(kcache, cache_size*1024*1024);    lasvm_t *sv=lasvm_create(kcache,use_b0,C*C_pos,C*C_neg);    printf("set cache size %d\n",cache_size);    // everything is new when we start    for(i=0;i<m;i++) inew.push_back(i);        // first add 5 examples of each class, just to balance the initial set    int c1=0,c2=0;    for(i=0;i<m;i++)    {        if(Y[i]==1 && c1<5) {lasvm_process(sv,i,(double) Y[i]); c1++; make_old(i);}        if(Y[i]==-1 && c2<5){lasvm_process(sv,i,(double) Y[i]); c2++; make_old(i);}        if(c1==5 && c2==5) break;    }        for(j=0;j<epochs;j++)    {        for(i=0;i<m;i++)        {            if(inew.size()==0) break; // nothing more to select            s=select(sv);            // selection strategy, select new point                        t1=lasvm_process(sv,s,(double) Y[s]);                        if (deltamax<=1000) // potentially multiple calls to reprocess..            {                //printf("%g %g\n",lasvm_get_delta(sv),deltamax);                t2=lasvm_reprocess(sv,epsgr);// at least one call to reprocess                while (lasvm_get_delta(sv)>deltamax && deltamax<1000)                {                    t2=lasvm_reprocess(sv,epsgr);                }            }                        if (verbosity==2)             {                l=(int) lasvm_get_l(sv);                printf("l=%d process=%d reprocess=%d\n",l,t1,t2);            }            else                if(verbosity==1)                    if( (i%100)==0){ fprintf(stdout, "..%d",i); fflush(stdout); }                        l=(int) lasvm_get_l(sv);            for(k=0;k<(int)select_size.size();k++)            {                 if   ( (termination_type==ITERATIONS && i==select_size[k])                        || (termination_type==SVS && l>=select_size[k])                       || (termination_type==TIME && sw->get_time()>=select_size[k])                    )			                 {                      if(saves>1) // if there is more than one model to save, give a new name                    {                        // save current version before potential finishing step                        int save_l,*save_sv; double *save_g, *save_alpha;                        save_l=(int)lasvm_get_l(sv);                        save_alpha= new double[l];lasvm_get_alpha(sv,save_alpha);				                        save_g= new double[l];lasvm_get_g(sv,save_g);                        save_sv= new int[l];lasvm_get_sv(sv,save_sv);				                        finish(sv);                         char tmp[1000];                        timer+=sw->get_time();                        //f << i << " " << count_svs() << " " << kcalcs << " " << timer << endl;                        if(termination_type==TIME)                          {                            sprintf(tmp,"%s_%dsecs",model_file_name,i);                             fprintf(stdout,"..[saving model_%d secs]..",i);                        }                        else                        {	                            fprintf(stdout,"..[saving model_%d pts]..",i);                            sprintf(tmp,"%s_%dpts",model_file_name,i);                          }                        libsvm_save_model(tmp);                                                  // get back old version                        //fprintf(stdout, "[restoring before finish]"); fflush(stdout);                         lasvm_init(sv, save_l, save_sv, save_alpha, save_g);                         delete save_alpha; delete save_sv; delete save_g;                        delete sw; sw=new stopwatch;    // reset clock                    }                      select_size[k]=select_size[select_size.size()-1];                    select_size.pop_back();                }            }            if(select_size.size()==0) break; // early stopping, all intermediate models saved        }        inew.resize(0);iold.resize(0); // start again for next epoch..        for(i=0;i<m;i++) inew.push_back(i);    }    if(saves<2)     {        finish(sv); // if haven't done any intermediate saves, do final save        timer+=sw->get_time();        //f << m << " " << count_svs() << " " << kcalcs << " " << timer << endl;    }    if(verbosity>0) printf("\n");    l=count_svs();     printf("nSVs=%d\n",l);    printf("||w||^2=%g\n",lasvm_get_w2(sv));    printf("kcalcs="); cout << kcalcs << endl;    //f.close();    lasvm_destroy(sv);}int main(int argc, char **argv)  {    printf("\n");    printf("la SVM\n");    printf("______\n");        char input_file_name[1024];    char model_file_name[1024];    parse_command_line(argc, argv, input_file_name, model_file_name);    load_data_file(input_file_name);    train_online(model_file_name);        libsvm_save_model(model_file_name);   }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -