⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 svm_learn.c

📁 这是一个采用C语言编写的用于机器学习文本分类的SVM算法的实现代码。
💻 C
📖 第 1 页 / 共 5 页
字号:
    (void)compute_index(index,totdoc,index2dnum);
    slack=(double *)my_malloc(sizeof(double)*(maxslackid+1));
    alphaslack=(double *)my_malloc(sizeof(double)*(maxslackid+1));
    for(i=0;i<=maxslackid;i++) {    /* init shared slacks */
      slack[i]=0;
      alphaslack[i]=0;
    }
    compute_shared_slacks(docs,label,a,lin,c,index2dnum,learn_parm,
			  slack,alphaslack);
    loss=0;
    model->at_upper_bound=0;
    svsetnum=0;
    for(i=0;i<=maxslackid;i++) {    /* create full index */
      loss+=slack[i];
      if(alphaslack[i] > (learn_parm->svm_c - learn_parm->epsilon_a)) 
	model->at_upper_bound++;
      if(alphaslack[i] > learn_parm->epsilon_a)
	svsetnum++;
    }
    free(index);
    free(index2dnum);
    free(slack);
    free(alphaslack);
  }
  
  if((verbosity>=1) && (!learn_parm->skip_final_opt_check)) {
    if(learn_parm->sharedslack) {
      printf("Number of SV: %ld\n",
	     model->sv_num-1);
      printf("Number of non-zero slack variables: %ld (out of %ld)\n",
	     model->at_upper_bound,svsetnum);
      fprintf(stdout,"L1 loss: loss=%.5f\n",loss);
    }
    else {
      upsupvecnum=0;
      for(i=1;i<model->sv_num;i++) {
	if(fabs(model->alpha[i]) >= 
	   (learn_parm->svm_cost[(model->supvec[i])->docnum]-
	    learn_parm->epsilon_a)) 
	  upsupvecnum++;
      }
      printf("Number of SV: %ld (including %ld at upper bound)\n",
	     model->sv_num-1,upsupvecnum);
      fprintf(stdout,"L1 loss: loss=%.5f\n",loss);
    }
    example_length=estimate_sphere(model,kernel_parm); 
    fprintf(stdout,"Norm of longest example vector: |x|=%.5f\n",
	    length_of_longest_document_vector(docs,totdoc,kernel_parm));
  }
  if(verbosity>=1) {
    printf("Number of kernel evaluations: %ld\n",kernel_cache_statistic);
  }
    
  if(alpha) {
    for(i=0;i<totdoc;i++) {    /* copy final alphas */
      alpha[i]=a[i];
    }
  }
 
  if(learn_parm->alphafile[0])
    write_alphas(learn_parm->alphafile,a,label,totdoc);
  
  shrink_state_cleanup(&shrink_state);
  free(label);
  free(unlabeled);
  free(inconsistent);
  free(c);
  free(a);
  free(lin);
  free(learn_parm->svm_cost);
}


long optimize_to_convergence(DOC **docs, long int *label, long int totdoc, 
			     long int totwords, LEARN_PARM *learn_parm, 
			     KERNEL_PARM *kernel_parm, 
			     KERNEL_CACHE *kernel_cache, 
			     SHRINK_STATE *shrink_state, MODEL *model, 
			     long int *inconsistent, long int *unlabeled, 
			     double *a, double *lin, double *c, 
			     TIMING *timing_profile, double *maxdiff, 
			     long int heldout, long int retrain)
     /* docs: Training vectors (x-part) */
     /* label: Training labels/value (y-part, zero if test example for
			      transduction) */
     /* totdoc: Number of examples in docs/label */
     /* totwords: Number of features (i.e. highest feature index) */
     /* learn_parm: Learning parameters */
     /* kernel_parm: Kernel parameters */
     /* kernel_cache: Initialized/partly filled Cache, if using a kernel. 
                      NULL if linear. */
     /* shrink_state: State of active variables */
     /* model: Returns learning result */
     /* inconsistent: examples thrown out as inconsistent */
     /* unlabeled: test examples for transduction */
     /* a: alphas */
     /* lin: linear component of gradient */
     /* c: right hand side of inequalities (margin) */
     /* maxdiff: returns maximum violation of KT-conditions */
     /* heldout: marks held-out example for leave-one-out (or -1) */
     /* retrain: selects training mode (1=regular / 2=holdout) */
{
  long *chosen,*key,i,j,jj,*last_suboptimal_at,noshrink;
  long inconsistentnum,choosenum,already_chosen=0,iteration;
  long misclassified,supvecnum=0,*active2dnum,inactivenum;
  long *working2dnum,*selexam;
  long activenum;
  double criterion,eq;
  double *a_old;
  long t0=0,t1=0,t2=0,t3=0,t4=0,t5=0,t6=0; /* timing */
  long transductcycle;
  long transduction;
  double epsilon_crit_org; 
  double bestmaxdiff;
  long   bestmaxdiffiter,terminate;

  double *selcrit;  /* buffer for sorting */        
  CFLOAT *aicache;  /* buffer to keep one row of hessian */
  double *weights;  /* buffer for weight vector in linear case */
  QP qp;            /* buffer for one quadratic program */

  epsilon_crit_org=learn_parm->epsilon_crit; /* save org */
  if(kernel_parm->kernel_type == LINEAR) {
    learn_parm->epsilon_crit=2.0;
    kernel_cache=NULL;   /* caching makes no sense for linear kernel */
  } 
  learn_parm->epsilon_shrink=2;
  (*maxdiff)=1;

  learn_parm->totwords=totwords;

  chosen = (long *)my_malloc(sizeof(long)*totdoc);
  last_suboptimal_at = (long *)my_malloc(sizeof(long)*totdoc);
  key = (long *)my_malloc(sizeof(long)*(totdoc+11)); 
  selcrit = (double *)my_malloc(sizeof(double)*totdoc);
  selexam = (long *)my_malloc(sizeof(long)*totdoc);
  a_old = (double *)my_malloc(sizeof(double)*totdoc);
  aicache = (CFLOAT *)my_malloc(sizeof(CFLOAT)*totdoc);
  working2dnum = (long *)my_malloc(sizeof(long)*(totdoc+11));
  active2dnum = (long *)my_malloc(sizeof(long)*(totdoc+11));
  qp.opt_ce = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
  qp.opt_ce0 = (double *)my_malloc(sizeof(double));
  qp.opt_g = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize
				 *learn_parm->svm_maxqpsize);
  qp.opt_g0 = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
  qp.opt_xinit = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
  qp.opt_low=(double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
  qp.opt_up=(double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
  weights=(double *)my_malloc(sizeof(double)*(totwords+1));

  choosenum=0;
  inconsistentnum=0;
  transductcycle=0;
  transduction=0;
  if(!retrain) retrain=1;
  iteration=1;
  bestmaxdiffiter=1;
  bestmaxdiff=999999999;
  terminate=0;

  if(kernel_cache) {
    kernel_cache->time=iteration;  /* for lru cache */
    kernel_cache_reset_lru(kernel_cache);
  }

  for(i=0;i<totdoc;i++) {    /* various inits */
    chosen[i]=0;
    a_old[i]=a[i];
    last_suboptimal_at[i]=1;
    if(inconsistent[i]) 
      inconsistentnum++;
    if(unlabeled[i]) {
      transduction=1;
    }
  }
  activenum=compute_index(shrink_state->active,totdoc,active2dnum);
  inactivenum=totdoc-activenum;
  clear_index(working2dnum);

                            /* repeat this loop until we have convergence */
  for(;retrain && (!terminate);iteration++) {

    if(kernel_cache)
      kernel_cache->time=iteration;  /* for lru cache */
    if(verbosity>=2) {
      printf(
	"Iteration %ld: ",iteration); fflush(stdout);
    }
    else if(verbosity==1) {
      printf("."); fflush(stdout);
    }

    if(verbosity>=2) t0=get_runtime();
    if(verbosity>=3) {
      printf("\nSelecting working set... "); fflush(stdout); 
    }

    if(learn_parm->svm_newvarsinqp>learn_parm->svm_maxqpsize) 
      learn_parm->svm_newvarsinqp=learn_parm->svm_maxqpsize;

    i=0;
    for(jj=0;(j=working2dnum[jj])>=0;jj++) { /* clear working set */
      if((chosen[j]>=(learn_parm->svm_maxqpsize/
		      minl(learn_parm->svm_maxqpsize,
			   learn_parm->svm_newvarsinqp))) 
	 || (inconsistent[j])
	 || (j == heldout)) {
	chosen[j]=0; 
	choosenum--; 
      }
      else {
	chosen[j]++;
	working2dnum[i++]=j;
      }
    }
    working2dnum[i]=-1;

    if(retrain == 2) {
      choosenum=0;
      for(jj=0;(j=working2dnum[jj])>=0;jj++) { /* fully clear working set */
	chosen[j]=0; 
      }
      clear_index(working2dnum);
      for(i=0;i<totdoc;i++) { /* set inconsistent examples to zero (-i 1) */
	if((inconsistent[i] || (heldout==i)) && (a[i] != 0.0)) {
	  chosen[i]=99999;
	  choosenum++;
	  a[i]=0;
	}
      }
      if(learn_parm->biased_hyperplane) {
	eq=0;
	for(i=0;i<totdoc;i++) { /* make sure we fulfill equality constraint */
	  eq+=a[i]*label[i];
	}
	for(i=0;(i<totdoc) && (fabs(eq) > learn_parm->epsilon_a);i++) {
	  if((eq*label[i] > 0) && (a[i] > 0)) {
	    chosen[i]=88888;
	    choosenum++;
	    if((eq*label[i]) > a[i]) {
	      eq-=(a[i]*label[i]);
	      a[i]=0;
	    }
	    else {
	      a[i]-=(eq*label[i]);
	      eq=0;
	    }
	  }
	}
      }
      compute_index(chosen,totdoc,working2dnum);
    }
    else {      /* select working set according to steepest gradient */
      if(iteration % 101) {
        already_chosen=0;
	if((minl(learn_parm->svm_newvarsinqp,
		 learn_parm->svm_maxqpsize-choosenum)>=4) 
	   && (kernel_parm->kernel_type != LINEAR)) {
	  /* select part of the working set from cache */
	  already_chosen=select_next_qp_subproblem_grad(
			      label,unlabeled,a,lin,c,totdoc,
			      (long)(minl(learn_parm->svm_maxqpsize-choosenum,
					  learn_parm->svm_newvarsinqp)
				     /2),
			      learn_parm,inconsistent,active2dnum,
			      working2dnum,selcrit,selexam,kernel_cache,1,
			      key,chosen);
	  choosenum+=already_chosen;
	}
	choosenum+=select_next_qp_subproblem_grad(
                              label,unlabeled,a,lin,c,totdoc,
                              minl(learn_parm->svm_maxqpsize-choosenum,
				   learn_parm->svm_newvarsinqp-already_chosen),
                              learn_parm,inconsistent,active2dnum,
			      working2dnum,selcrit,selexam,kernel_cache,0,key,
			      chosen);
      }
      else { /* once in a while, select a somewhat random working set
		to break out of infinite loops caused by numerical
		inaccuracies in the core qp-solver */
	choosenum+=select_next_qp_subproblem_rand(
                              label,unlabeled,a,lin,c,totdoc,
                              minl(learn_parm->svm_maxqpsize-choosenum,
				   learn_parm->svm_newvarsinqp),
                              learn_parm,inconsistent,active2dnum,
			      working2dnum,selcrit,selexam,kernel_cache,key,
			      chosen,iteration);
      }
    }

    if(verbosity>=2) {
      printf(" %ld vectors chosen\n",choosenum); fflush(stdout); 
    }

    if(verbosity>=2) t1=get_runtime();

    if(kernel_cache) 
      cache_multiple_kernel_rows(kernel_cache,docs,working2dnum,
				 choosenum,kernel_parm); 
    
    if(verbosity>=2) t2=get_runtime();
    if(retrain != 2) {
      optimize_svm(docs,label,unlabeled,inconsistent,0.0,chosen,active2dnum,
		   model,totdoc,working2dnum,choosenum,a,lin,c,learn_parm,
		   aicache,kernel_parm,&qp,&epsilon_crit_org);
    }

    if(verbosity>=2) t3=get_runtime();
    update_linear_component(docs,label,active2dnum,a,a_old,working2dnum,totdoc,
			    totwords,kernel_parm,kernel_cache,lin,aicache,
			    weights);

    if(verbosity>=2) t4=get_runtime();
    supvecnum=calculate_svm_model(docs,label,unlabeled,lin,a,a_old,c,
		                  learn_parm,working2dnum,active2dnum,model);

    if(verbosity>=2) t5=get_runtime();

    /* The following computation of the objective function works only */
    /* relative to the active variables */
    if(verbosity>=3) {
      criterion=compute_objective_function(a,lin,c,learn_parm->eps,label,
		                           active2dnum);
      printf("Objective function (over active variables): %.16f\n",criterion);
      fflush(stdout); 
    }

    for(jj=0;(i=working2dnum[jj])>=0;jj++) {
      a_old[i]=a[i];
    }

    if(retrain == 2) {  /* reset inconsistent unlabeled examples */
      for(i=0;(i<totdoc);i++) {
	if(inconsistent[i] && unlabeled[i]) {
	  inconsistent[i]=0;
	  label[i]=0;
	}
      }
    }

    retrain=check_optimality(model,label,unlabeled,a,lin,c,totdoc,learn_parm,
			     maxdiff,epsilon_crit_org,&misclassified,
			     inconsistent,active2dnum,last_suboptimal_at,
			     iteration,kernel_parm);

    if(verbosity>=2) {
      t6=get_runtime();
      timing_profile->time_select+=t1-t0;
      timing_profile->time_kernel+=t2-t1;
      timing_profile->time_opti+=t3-t2;
      timing_profile->time_update+=t4-t3;
      timing_profile->time_model+=t5-t4;
      timing_profile->time_check+=t6-t5;
    }

    /* checking whether optimizer got stuck */
    if((*maxdiff) < bestmaxdiff) {
      bestmaxdiff=(*maxdiff);
      bestmaxdiffiter=iteration;
    }
    if(iteration > (bestmaxdiffiter+learn_parm->maxiter)) { 
      /* long time no progress? */
      terminate=1;
      retrain=0;
      if(verbosity>=1) 
	printf("\nWARNING: Relaxing KT-Conditions due to slow progress! Terminating!\n");

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -