📄 la_svm.cpp
字号:
if(mwrite) for(j=0;j<max_index;j++) // set features for each example lasvm_sparsevector_set(v,j,val[j]); } else // sparse binary file { f.read((char*)sz,2*sizeof(int)); // get label & sparsity of example i if(mwrite) { if(splits.size()>0 && splits[splitpos-1].y!=0) Y.push_back(splits[splitpos-1].y); else Y.push_back(sz[0]); } val.resize(sz[1]); ind.resize(sz[1]); f.read((char*)(&ind[0]),sz[1]*sizeof(int)); f.read((char*)(&val[0]),sz[1]*sizeof(float)); if(mwrite) for(j=0;j<sz[1];j++) // set features for each example { lasvm_sparsevector_set(v,ind[j],val[j]); //printf("%d=%g\n",ind[j],val[j]); if(ind[j]>max_index) max_index=ind[j]; } } } f.close(); msz=X.size()-m; printf("examples: %d features: %d\n",msz,max_index); return msz;}void load_data_file(char *filename){ int msz,i,ft; splits.resize(0); int bin=binary_files; if(bin==0) // if ascii, check if it isn't a split file.. { FILE *f=fopen(filename,"r"); if(f == NULL) { fprintf(stderr,"Can't open input file \"%s\"\n",filename); exit(1); } char c; fscanf(f,"%c",&c); if(c=='f') bin=2; // found split file! } switch(bin) // load diferent file formats { case 0: // libsvm format msz=libsvm_load_data(filename); break; case 1: msz=binary_load_data(filename); break; case 2: ft=split_file_load(filename); if(ft==0) {msz=libsvm_load_data(filename); break;} else {msz=binary_load_data(filename); break;} default: fprintf(stderr,"Illegal file type '-B %d'\n",bin); exit(1); } if(kernel_type==RBF) { x_square.resize(m+msz); for(i=0;i<msz;i++) x_square[i+m]=lasvm_sparsevector_dot_product(X[i+m],X[i+m]); } if(kgamma==-1) kgamma=1.0/ ((double) max_index); // same default as LIBSVM m+=msz;}int sv1,sv2; double max_alpha,alpha_tol;int count_svs(){ int i; max_alpha=0; sv1=0;sv2=0; for(i=0;i<m;i++) // Count svs.. { if(alpha[i]>max_alpha) max_alpha=alpha[i]; if(-alpha[i]>max_alpha) max_alpha=-alpha[i]; } alpha_tol=max_alpha/1000.0; for(i=0;i<m;i++) { if(Y[i]>0) { if(alpha[i] >= alpha_tol) sv1++; } else { if(-alpha[i] >= alpha_tol) sv2++; } } return sv1+sv2;}int libsvm_save_model(const char *model_file_name) // saves the model in the same format as LIBSVM{ FILE *fp = fopen(model_file_name,"w"); if(fp==NULL) return -1; count_svs(); // printf("nSV=%d\n",sv1+sv2); fprintf(fp,"svm_type c_svc\n"); fprintf(fp,"kernel_type %s\n", kernel_type_table[kernel_type]); if(kernel_type == POLY) fprintf(fp,"degree %g\n", degree); if(kernel_type == POLY || kernel_type == RBF || kernel_type == SIGMOID) fprintf(fp,"gamma %g\n", kgamma); if(kernel_type == POLY || kernel_type == SIGMOID) fprintf(fp,"coef0 %g\n", coef0); fprintf(fp, "nr_class %d\n",2); fprintf(fp, "total_sv %d\n",sv1+sv2); { fprintf(fp, "rho %g\n",b0); } fprintf(fp, "label 1 -1\n"); fprintf(fp, "nr_sv"); fprintf(fp," %d %d",sv1,sv2); fprintf(fp, "\n"); fprintf(fp, "SV\n"); for(int j=0;j<2;j++) for(int i=0;i<m;i++) { if (j==0 && Y[i]==-1) continue; if (j==1 && Y[i]==1) continue; if (alpha[i]*Y[i]< alpha_tol) continue; // not an SV fprintf(fp, "%.16g ",alpha[i]); lasvm_sparsevector_pair_t *p1 = X[i]->pairs; while (p1) { fprintf(fp,"%d:%.8g ",p1->index,p1->data); p1 = p1->next; } fprintf(fp, "\n"); } fclose(fp); return 0;}double kernel(int i, int j, void *kparam){ double dot; kcalcs++; dot=lasvm_sparsevector_dot_product(X[i],X[j]); // sparse, linear kernel switch(kernel_type) { case LINEAR: return dot; case POLY: return pow(kgamma*dot+coef0,degree); case RBF: return exp(-kgamma*(x_square[i]+x_square[j]-2*dot)); case SIGMOID: return tanh(kgamma*dot+coef0); } return 0;} void finish(lasvm_t *sv){ int i,l; if (optimizer==ONLINE_WITH_FINISHING) { fprintf(stdout,"..[finishing]"); int iter=0; do { iter += lasvm_finish(sv, epsgr); } while (lasvm_get_delta(sv)>epsgr); } l=(int) lasvm_get_l(sv); int *svind,svs; svind= new int[l]; svs=lasvm_get_sv(sv,svind); alpha.resize(m); for(i=0;i<m;i++) alpha[i]=0; double *svalpha; svalpha=new double[l]; lasvm_get_alpha(sv,svalpha); for(i=0;i<svs;i++) alpha[svind[i]]=svalpha[i]; b0=lasvm_get_b(sv);}void make_old(int val) // move index <val> from new set into old set{ int i,ind=-1; for(i=0;i<(int)inew.size();i++) { if(inew[i]==val) {ind=i; break;} } if (ind>=0) { inew[ind]=inew[inew.size()-1]; inew.pop_back(); iold.push_back(val); }}int select(lasvm_t *sv) // selection strategy{ int s=-1; int t,i,r,j; double tmp,best; int ind=-1; switch(selection_type) { case RANDOM: // pick a random candidate s=rand() % inew.size(); break; case GRADIENT: // pick best gradient from 50 candidates j=candidates; if((int)inew.size()<j) j=inew.size(); r=rand() % inew.size(); s=r; best=1e20; for(i=0;i<j;i++) { r=inew[s]; tmp=lasvm_predict(sv, r); tmp*=Y[r]; //printf("%d: example %d grad=%g\n",i,r,tmp); if(tmp<best) {best=tmp;ind=s;} s=rand() % inew.size(); } s=ind; break; case MARGIN: // pick closest to margin from 50 candidates j=candidates; if((int)inew.size()<j) j=inew.size(); r=rand() % inew.size(); s=r; best=1e20; for(i=0;i<j;i++) { r=inew[s]; tmp=lasvm_predict(sv, r); if (tmp<0) tmp=-tmp; //printf("%d: example %d grad=%g\n",i,r,tmp); if(tmp<best) {best=tmp;ind=s;} s=rand() % inew.size(); } s=ind; break; } t=inew[s]; inew[s]=inew[inew.size()-1]; inew.pop_back(); iold.push_back(t); //printf("(%d %d)\n",iold.size(),inew.size()); return t;}void train_online(char *model_file_name){ int t1,t2=0,i,s,l,j,k; double timer=0; stopwatch *sw; // start measuring time after loading is finished sw=new stopwatch; // save timing information char t[1000]; strcpy(t,model_file_name); strcat(t,".time"); lasvm_kcache_t *kcache=lasvm_kcache_create(kernel, NULL); lasvm_kcache_set_maximum_size(kcache, cache_size*1024*1024); lasvm_t *sv=lasvm_create(kcache,use_b0,C*C_pos,C*C_neg); printf("set cache size %d\n",cache_size); // everything is new when we start for(i=0;i<m;i++) inew.push_back(i); // first add 5 examples of each class, just to balance the initial set int c1=0,c2=0; for(i=0;i<m;i++) { if(Y[i]==1 && c1<5) {lasvm_process(sv,i,(double) Y[i]); c1++; make_old(i);} if(Y[i]==-1 && c2<5){lasvm_process(sv,i,(double) Y[i]); c2++; make_old(i);} if(c1==5 && c2==5) break; } for(j=0;j<epochs;j++) { for(i=0;i<m;i++) { if(inew.size()==0) break; // nothing more to select s=select(sv); // selection strategy, select new point t1=lasvm_process(sv,s,(double) Y[s]); if (deltamax<=1000) // potentially multiple calls to reprocess.. { //printf("%g %g\n",lasvm_get_delta(sv),deltamax); t2=lasvm_reprocess(sv,epsgr);// at least one call to reprocess while (lasvm_get_delta(sv)>deltamax && deltamax<1000) { t2=lasvm_reprocess(sv,epsgr); } } if (verbosity==2) { l=(int) lasvm_get_l(sv); printf("l=%d process=%d reprocess=%d\n",l,t1,t2); } else if(verbosity==1) if( (i%100)==0){ fprintf(stdout, "..%d",i); fflush(stdout); } l=(int) lasvm_get_l(sv); for(k=0;k<(int)select_size.size();k++) { if ( (termination_type==ITERATIONS && i==select_size[k]) || (termination_type==SVS && l>=select_size[k]) || (termination_type==TIME && sw->get_time()>=select_size[k]) ) { if(saves>1) // if there is more than one model to save, give a new name { // save current version before potential finishing step int save_l,*save_sv; double *save_g, *save_alpha; save_l=(int)lasvm_get_l(sv); save_alpha= new double[l];lasvm_get_alpha(sv,save_alpha); save_g= new double[l];lasvm_get_g(sv,save_g); save_sv= new int[l];lasvm_get_sv(sv,save_sv); finish(sv); char tmp[1000]; timer+=sw->get_time(); //f << i << " " << count_svs() << " " << kcalcs << " " << timer << endl; if(termination_type==TIME) { sprintf(tmp,"%s_%dsecs",model_file_name,i); fprintf(stdout,"..[saving model_%d secs]..",i); } else { fprintf(stdout,"..[saving model_%d pts]..",i); sprintf(tmp,"%s_%dpts",model_file_name,i); } libsvm_save_model(tmp); // get back old version //fprintf(stdout, "[restoring before finish]"); fflush(stdout); lasvm_init(sv, save_l, save_sv, save_alpha, save_g); delete save_alpha; delete save_sv; delete save_g; delete sw; sw=new stopwatch; // reset clock } select_size[k]=select_size[select_size.size()-1]; select_size.pop_back(); } } if(select_size.size()==0) break; // early stopping, all intermediate models saved } inew.resize(0);iold.resize(0); // start again for next epoch.. for(i=0;i<m;i++) inew.push_back(i); } if(saves<2) { finish(sv); // if haven't done any intermediate saves, do final save timer+=sw->get_time(); //f << m << " " << count_svs() << " " << kcalcs << " " << timer << endl; } if(verbosity>0) printf("\n"); l=count_svs(); printf("nSVs=%d\n",l); printf("||w||^2=%g\n",lasvm_get_w2(sv)); printf("kcalcs="); cout << kcalcs << endl; //f.close(); lasvm_destroy(sv);}int main(int argc, char **argv) { printf("\n"); printf("la SVM\n"); printf("______\n"); char input_file_name[1024]; char model_file_name[1024]; parse_command_line(argc, argv, input_file_name, model_file_name); load_data_file(input_file_name); train_online(model_file_name); libsvm_save_model(model_file_name); }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -