⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 svm_learn.c

📁 SVM-light Version llf_dqy_hhu
💻 C
📖 第 1 页 / 共 5 页
字号:
/***********************************************************************/
/*                                                                     */
/*   svm_learn.c                                                       */
/*                                                                     */
/*   Learning module of Support Vector Machine.                        */
/*                                                                     */
/*   Author: Thorsten Joachims                                         */
/*   Date: 02.07.02                                                    */
/*                                                                     */
/*   Copyright (c) 2002  Thorsten Joachims - All rights reserved       */
/*                                                                     */
/*   This software is available for non-commercial use only. It must   */
/*   not be modified and distributed without prior permission of the   */
/*   author. The author is not responsible for implications from the   */
/*   use of this software.                                             */
/*                                                                     */
/***********************************************************************/

# include "svm_common.h"
# include "svm_learn.h"

/* Interface to the QP-solver. Only the prototype is given here; the
   definition is not part of this section of the file.
   NOTE(review): presumably supplied by a separate solver module --
   confirm against the build. */
double *optimize_qp(QP *, double *, long, double *, LEARN_PARM *);

/*---------------------------------------------------------------------------*/

/* Learns an SVM classification model based on the training data in
   docs/label. The resulting model is returned in the structure
   model. */

/* NOTE: the definition of svm_learn_classification() below is DISABLED.
   The whole function is enclosed in this single block comment, so the
   compiler never sees it. Its inner comments are written as "// ... //"
   so that they cannot terminate the enclosing block comment early.
   NOTE(review): presumably a replacement definition exists elsewhere in
   this file -- confirm before re-enabling this one.

   Outline of the disabled code, as written below:
     1. reset the timing profile and sanitize learn_parm->svm_newvarsinqp;
     2. allocate the working arrays and initialise the empty model;
     3. derive a default C from estimate_r_delta_average() when svm_c is 0;
     4. set label / per-example cost for every document (class 0 marks an
        unlabeled example and switches on transduction);
     5. call optimize_to_convergence() on the full problem and print
        statistics according to verbosity;
     6. optionally run the leave-one-out loop, then retrain on the full
        problem and store the loo estimates in the model;
     7. optionally write the alphas to learn_parm->alphafile and free the
        working arrays.

void svm_learn_classification(DOC *docs, double *class, long int
                              totdoc, long int totwords,
                              LEARN_PARM *learn_parm,
                              KERNEL_PARM *kernel_parm,
                              KERNEL_CACHE *kernel_cache, MODEL *model)
     // docs:        Training vectors (x-part) //
     // class:       Training labels (y-part, zero if test example for transduction) //
     // totdoc:      Number of examples in docs/label //
     // totwords:    Number of features (i.e. highest feature index) //
     // learn_parm:  Learning parameters //
     // kernel_parm: Kernel parameters //
     // kernel_cache:Initialized Cache of size 1*totdoc //
     // model:       Returns learning result (assumed empty before called) //
{
  long *inconsistent,i,*label;
  long inconsistentnum;
  long misclassified,upsupvecnum;
  double loss,model_length,example_length;
  double maxdiff,*lin,*a,*c;
  long runtime_start,runtime_end;
  long iterations;
  long *unlabeled,transduction;
  long heldout;
  long loo_count=0,loo_count_pos=0,loo_count_neg=0,trainpos=0,trainneg=0;
  long loocomputed=0,runtime_start_loo=0,runtime_start_xa=0;
  double heldout_c=0,r_delta_sq=0,r_delta,r_delta_avg;
  double *xi_fullset; // buffer for storing xi on full sample in loo //
  double *a_fullset;  // buffer for storing alpha on full sample in loo //
  TIMING timing_profile;
  SHRINK_STATE shrink_state;

  runtime_start=get_runtime();
  timing_profile.time_kernel=0;
  timing_profile.time_opti=0;
  timing_profile.time_shrink=0;
  timing_profile.time_update=0;
  timing_profile.time_model=0;
  timing_profile.time_check=0;
  timing_profile.time_select=0;
  kernel_cache_statistic=0;

  learn_parm->totwords=totwords;

  // make sure -n value is reasonable //
  if((learn_parm->svm_newvarsinqp < 2)
     || (learn_parm->svm_newvarsinqp > learn_parm->svm_maxqpsize)) {
    learn_parm->svm_newvarsinqp=learn_parm->svm_maxqpsize;
  }

  init_shrink_state(&shrink_state,totdoc,(long)MAXSHRINK);

  label = (long *)my_malloc(sizeof(long)*totdoc);
  inconsistent = (long *)my_malloc(sizeof(long)*totdoc);
  unlabeled = (long *)my_malloc(sizeof(long)*totdoc);
  c = (double *)my_malloc(sizeof(double)*totdoc);
  a = (double *)my_malloc(sizeof(double)*totdoc);
  a_fullset = (double *)my_malloc(sizeof(double)*totdoc);
  xi_fullset = (double *)my_malloc(sizeof(double)*totdoc);
  lin = (double *)my_malloc(sizeof(double)*totdoc);
  learn_parm->svm_cost = (double *)my_malloc(sizeof(double)*totdoc);
  model->supvec = (DOC **)my_malloc(sizeof(DOC *)*(totdoc+2));
  model->alpha = (double *)my_malloc(sizeof(double)*(totdoc+2));
  model->index = (long *)my_malloc(sizeof(long)*(totdoc+2));

  model->at_upper_bound=0;
  model->b=0;
  model->supvec[0]=0;  // element 0 reserved and empty for now //
  model->alpha[0]=0;
  model->lin_weights=NULL;
  model->totwords=totwords;
  model->totdoc=totdoc;
  model->kernel_parm=(*kernel_parm);
  model->sv_num=1;
  model->loo_error=-1;
  model->loo_recall=-1;
  model->loo_precision=-1;
  model->xa_error=-1;
  model->xa_recall=-1;
  model->xa_precision=-1;
  inconsistentnum=0;
  transduction=0;

  r_delta=estimate_r_delta(docs,totdoc,kernel_parm);
  r_delta_sq=r_delta*r_delta;

  r_delta_avg=estimate_r_delta_average(docs,totdoc,kernel_parm);
  if(learn_parm->svm_c == 0.0) {  // default value for C //
    learn_parm->svm_c=1.0/(r_delta_avg*r_delta_avg);
    if(verbosity>=1)
      printf("Setting default regularization parameter C=%.4f\n",
             learn_parm->svm_c);
  }

  learn_parm->eps=-1.0;      // equivalent regression epsilon for classification //

  for(i=0;i<totdoc;i++) {    // various inits //
    docs[i].docnum=i;
    inconsistent[i]=0;
    a[i]=0;
    lin[i]=0;
    c[i]=0.0;
    unlabeled[i]=0;
    if(class[i] == 0) {
      unlabeled[i]=1;
      label[i]=0;
      transduction=1;
    }
    if(class[i] > 0) {
      learn_parm->svm_cost[i]=learn_parm->svm_c*learn_parm->svm_costratio*
        docs[i].costfactor;
      label[i]=1;
      trainpos++;
    }
    else if(class[i] < 0) {
      learn_parm->svm_cost[i]=learn_parm->svm_c*docs[i].costfactor;
      label[i]=-1;
      trainneg++;
    }
    else {
      learn_parm->svm_cost[i]=0;
    }
  }
  if(verbosity>=2) {
    printf("%ld positive, %ld negative, and %ld unlabeled examples.\n",trainpos,trainneg,totdoc-trainpos-trainneg); fflush(stdout);
  }

  // caching makes no sense for linear kernel //
  if(kernel_parm->kernel_type == LINEAR) {
    kernel_cache = NULL;
  }

  if(transduction) {
    learn_parm->svm_iter_to_shrink=99999999;
    if(verbosity >= 1)
      printf("\nDeactivating Shrinking due to an incompatibility with the transductive \nlearner in the current version.\n\n");
  }

  if(transduction && learn_parm->compute_loo) {
    learn_parm->compute_loo=0;
    if(verbosity >= 1)
      printf("\nCannot compute leave-one-out estimates for transductive learner.\n\n");
  }

  if(learn_parm->remove_inconsistent && learn_parm->compute_loo) {
    learn_parm->compute_loo=0;
    printf("\nCannot compute leave-one-out estimates when removing inconsistent examples.\n\n");
  }

  if(learn_parm->compute_loo && ((trainpos == 1) || (trainneg == 1))) {
    learn_parm->compute_loo=0;
    printf("\nCannot compute leave-one-out with only one example in one class.\n\n");
  }

  if(verbosity==1) {
    printf("Optimizing"); fflush(stdout);
  }

  // train the svm //
  iterations=optimize_to_convergence(docs,label,totdoc,totwords,learn_parm,
                                     kernel_parm,kernel_cache,&shrink_state,model,
                                     inconsistent,unlabeled,a,lin,
                                     c,&timing_profile,
                                     &maxdiff,(long)-1,
                                     (long)1);

  if(verbosity>=1) {
    if(verbosity==1) printf("done. (%ld iterations)\n",iterations);

    misclassified=0;
    for(i=0;(i<totdoc);i++) { // get final statistic //
      if((lin[i]-model->b)*(double)label[i] <= 0.0)
        misclassified++;
    }

    printf("Optimization finished (%ld misclassified, maxdiff=%.5f).\n",
           misclassified,maxdiff);

    runtime_end=get_runtime();
    if(verbosity>=2) {
      printf("Runtime in cpu-seconds: %.2f (%.2f%% for kernel/%.2f%% for optimizer/%.2f%% for final/%.2f%% for update/%.2f%% for model/%.2f%% for check/%.2f%% for select)\n",
        ((float)runtime_end-(float)runtime_start)/100.0,
        (100.0*timing_profile.time_kernel)/(float)(runtime_end-runtime_start),
        (100.0*timing_profile.time_opti)/(float)(runtime_end-runtime_start),
        (100.0*timing_profile.time_shrink)/(float)(runtime_end-runtime_start),
        (100.0*timing_profile.time_update)/(float)(runtime_end-runtime_start),
        (100.0*timing_profile.time_model)/(float)(runtime_end-runtime_start),
        (100.0*timing_profile.time_check)/(float)(runtime_end-runtime_start),
        (100.0*timing_profile.time_select)/(float)(runtime_end-runtime_start));
    }
    else {
      printf("Runtime in cpu-seconds: %.2f\n",
             (runtime_end-runtime_start)/100.0);
    }

    if(learn_parm->remove_inconsistent) {
      inconsistentnum=0;
      for(i=0;i<totdoc;i++)
        if(inconsistent[i])
          inconsistentnum++;
      printf("Number of SV: %ld (plus %ld inconsistent examples)\n",
             model->sv_num-1,inconsistentnum);
    }
    else {
      upsupvecnum=0;
      for(i=1;i<model->sv_num;i++) {
        if(fabs(model->alpha[i]) >=
           (learn_parm->svm_cost[(model->supvec[i])->docnum]-
            learn_parm->epsilon_a))
          upsupvecnum++;
      }
      printf("Number of SV: %ld (including %ld at upper bound)\n",
             model->sv_num-1,upsupvecnum);
    }

    if((verbosity>=1) && (!learn_parm->skip_final_opt_check)) {
      loss=0;
      model_length=0;
      for(i=0;i<totdoc;i++) {
        if((lin[i]-model->b)*(double)label[i] < 1.0-learn_parm->epsilon_crit)
          loss+=1.0-(lin[i]-model->b)*(double)label[i];
        model_length+=a[i]*label[i]*lin[i];
      }
      model_length=sqrt(model_length);
      fprintf(stdout,"L1 loss: loss=%.5f\n",loss);
      fprintf(stdout,"Norm of weight vector: |w|=%.5f\n",model_length);
      example_length=estimate_sphere(model,kernel_parm);
      fprintf(stdout,"Norm of longest example vector: |x|=%.5f\n",
              length_of_longest_document_vector(docs,totdoc,kernel_parm));
      fprintf(stdout,"Estimated VCdim of classifier: VCdim<=%.5f\n",
              estimate_margin_vcdim(model,model_length,example_length,
                                    kernel_parm));
      if((!learn_parm->remove_inconsistent) && (!transduction)) {
        runtime_start_xa=get_runtime();
        if(verbosity>=1) {
          printf("Computing XiAlpha-estimates..."); fflush(stdout);
        }
        compute_xa_estimates(model,label,unlabeled,totdoc,docs,lin,a,
                             kernel_parm,learn_parm,&(model->xa_error),
                             &(model->xa_recall),&(model->xa_precision));
        if(verbosity>=1) {
          printf("done\n");
        }
        printf("Runtime for XiAlpha-estimates in cpu-seconds: %.2f\n",
               (get_runtime()-runtime_start_xa)/100.0);

        fprintf(stdout,"XiAlpha-estimate of the error: error<=%.2f%% (rho=%.2f,depth=%ld)\n",
                model->xa_error,learn_parm->rho,learn_parm->xa_depth);
        fprintf(stdout,"XiAlpha-estimate of the recall: recall=>%.2f%% (rho=%.2f,depth=%ld)\n",
                model->xa_recall,learn_parm->rho,learn_parm->xa_depth);
        fprintf(stdout,"XiAlpha-estimate of the precision: precision=>%.2f%% (rho=%.2f,depth=%ld)\n",
                model->xa_precision,learn_parm->rho,learn_parm->xa_depth);
      }
      else if(!learn_parm->remove_inconsistent) {
        estimate_transduction_quality(model,label,unlabeled,totdoc,docs,lin);
      }
    }
    if(verbosity>=1) {
      printf("Number of kernel evaluations: %ld\n",kernel_cache_statistic);
    }
  }

  //leave-one-out testing starts now //
  if(learn_parm->compute_loo) {
    // save results of training on full dataset for leave-one-out //
    runtime_start_loo=get_runtime();
    for(i=0;i<totdoc;i++) {
      xi_fullset[i]=1.0-((lin[i]-model->b)*(double)label[i]);
      if(xi_fullset[i]<0) xi_fullset[i]=0;
      a_fullset[i]=a[i];
    }
    if(verbosity>=1) {
      printf("Computing leave-one-out");
    }

    // repeat this loop for every held-out example //
    for(heldout=0;(heldout<totdoc);heldout++) {
      if(learn_parm->rho*a_fullset[heldout]*r_delta_sq+xi_fullset[heldout]
         < 1.0) {
        //guaranteed to not produce a leave-one-out error //
        if(verbosity==1) {
          printf("+"); fflush(stdout);
        }
      }
      else if(xi_fullset[heldout] > 1.0) {
        // guaranteed to produce a leave-one-out error //
        loo_count++;
        if(label[heldout] > 0)  loo_count_pos++; else loo_count_neg++;
        if(verbosity==1) {
          printf("-"); fflush(stdout);
        }
      }
      else {
        loocomputed++;
        heldout_c=learn_parm->svm_cost[heldout]; //set upper bound to zero //
        learn_parm->svm_cost[heldout]=0;
        // make sure heldout example is not currently  //
        // shrunk away. Assumes that lin is up to date! //
        shrink_state.active[heldout]=1;
        if(verbosity>=2)
          printf("\nLeave-One-Out test on example %ld\n",heldout);
        if(verbosity>=1) {
          printf("(?[%ld]",heldout); fflush(stdout);
        }

        optimize_to_convergence(docs,label,totdoc,totwords,learn_parm,
                                kernel_parm,
                                kernel_cache,&shrink_state,model,inconsistent,unlabeled,
                                a,lin,c,&timing_profile,
                                &maxdiff,heldout,(long)2);

        // printf("%.20f\n",(lin[heldout]-model->b)*(double)label[heldout]); //

        if(((lin[heldout]-model->b)*(double)label[heldout]) <= 0.0) {
          loo_count++;                            // there was a loo-error //
          if(label[heldout] > 0)  loo_count_pos++; else loo_count_neg++;
          if(verbosity>=1) {
            printf("-)"); fflush(stdout);
          }
        }
        else {
          if(verbosity>=1) {
            printf("+)"); fflush(stdout);
          }
        }
        // now we need to restore the original data set//
        learn_parm->svm_cost[heldout]=heldout_c; // restore upper bound //
      }
    } // end of leave-one-out loop //


    if(verbosity>=1) {
      printf("\nRetrain on full problem"); fflush(stdout);
    }
    optimize_to_convergence(docs,label,totdoc,totwords,learn_parm,
                            kernel_parm,
                            kernel_cache,&shrink_state,model,inconsistent,unlabeled,
                            a,lin,c,&timing_profile,
                            &maxdiff,(long)-1,(long)1);
    if(verbosity >= 1)
      printf("done.\n");


    // after all leave-one-out computed //
    model->loo_error=100.0*loo_count/(double)totdoc;
    model->loo_recall=(1.0-(double)loo_count_pos/(double)trainpos)*100.0;
    model->loo_precision=(trainpos-loo_count_pos)/
      (double)(trainpos-loo_count_pos+loo_count_neg)*100.0;
    if(verbosity >= 1) {
      fprintf(stdout,"Leave-one-out estimate of the error: error=%.2f%%\n",
              model->loo_error);
      fprintf(stdout,"Leave-one-out estimate of the recall: recall=%.2f%%\n",
              model->loo_recall);
      fprintf(stdout,"Leave-one-out estimate of the precision: precision=%.2f%%\n",
              model->loo_precision);
      fprintf(stdout,"Actual leave-one-outs computed:  %ld (rho=%.2f)\n",
              loocomputed,learn_parm->rho);
      printf("Runtime for leave-one-out in cpu-seconds: %.2f\n",
             (double)(get_runtime()-runtime_start_loo)/100.0);
    }
  }

  if(learn_parm->alphafile[0])
    write_alphas(learn_parm->alphafile,a,label,totdoc);

  shrink_state_cleanup(&shrink_state);
  free(label);
  free(inconsistent);
  free(unlabeled);
  free(c);
  free(a);
  free(a_fullset);
  free(xi_fullset);
  free(lin);
  free(learn_parm->svm_cost);
}*/

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -