⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 svm_learn.c

📁 这是一个采用C语言编写的用于机器学习文本分类的SVM算法的实现代码。
💻 C
📖 第 1 页 / 共 5 页
字号:
    }

    noshrink=0;
    if((!retrain) && (inactivenum>0) 
       && ((!learn_parm->skip_final_opt_check) 
	   || (kernel_parm->kernel_type == LINEAR))) { 
      if(((verbosity>=1) && (kernel_parm->kernel_type != LINEAR)) 
	 || (verbosity>=2)) {
	if(verbosity==1) {
	  printf("\n");
	}
	printf(" Checking optimality of inactive variables..."); 
	fflush(stdout);
      }
      t1=get_runtime();
      reactivate_inactive_examples(label,unlabeled,a,shrink_state,lin,c,totdoc,
				   totwords,iteration,learn_parm,inconsistent,
				   docs,kernel_parm,kernel_cache,model,aicache,
				   weights,maxdiff);
      /* Update to new active variables. */
      activenum=compute_index(shrink_state->active,totdoc,active2dnum);
      inactivenum=totdoc-activenum;
      /* reset watchdog */
      bestmaxdiff=(*maxdiff);
      bestmaxdiffiter=iteration;
      /* termination criterion */
      noshrink=1;
      retrain=0;
      if((*maxdiff) > learn_parm->epsilon_crit) 
	retrain=1;
      timing_profile->time_shrink+=get_runtime()-t1;
      if(((verbosity>=1) && (kernel_parm->kernel_type != LINEAR)) 
	 || (verbosity>=2)) {
	printf("done.\n");  fflush(stdout);
        printf(" Number of inactive variables = %ld\n",inactivenum);
      }		  
    }

    if((!retrain) && (learn_parm->epsilon_crit>(*maxdiff))) 
      learn_parm->epsilon_crit=(*maxdiff);
    if((!retrain) && (learn_parm->epsilon_crit>epsilon_crit_org)) {
      learn_parm->epsilon_crit/=2.0;
      retrain=1;
      noshrink=1;
    }
    if(learn_parm->epsilon_crit<epsilon_crit_org) 
      learn_parm->epsilon_crit=epsilon_crit_org;
    
    if(verbosity>=2) {
      printf(" => (%ld SV (incl. %ld SV at u-bound), max violation=%.5f)\n",
	     supvecnum,model->at_upper_bound,(*maxdiff)); 
      fflush(stdout);
    }
    if(verbosity>=3) {
      printf("\n");
    }

    if((!retrain) && (transduction)) {
      for(i=0;(i<totdoc);i++) {
	shrink_state->active[i]=1;
      }
      activenum=compute_index(shrink_state->active,totdoc,active2dnum);
      inactivenum=0;
      if(verbosity==1) printf("done\n");
      retrain=incorporate_unlabeled_examples(model,label,inconsistent,
					     unlabeled,a,lin,totdoc,
					     selcrit,selexam,key,
					     transductcycle,kernel_parm,
					     learn_parm);
      epsilon_crit_org=learn_parm->epsilon_crit;
      if(kernel_parm->kernel_type == LINEAR)
	learn_parm->epsilon_crit=1; 
      transductcycle++;
      /* reset watchdog */
      bestmaxdiff=(*maxdiff);
      bestmaxdiffiter=iteration;
    } 
    else if(((iteration % 10) == 0) && (!noshrink)) {
      activenum=shrink_problem(docs,learn_parm,shrink_state,kernel_parm,
			       active2dnum,last_suboptimal_at,iteration,totdoc,
			       maxl((long)(activenum/10),
				    maxl((long)(totdoc/500),100)),
			       a,inconsistent);
      inactivenum=totdoc-activenum;
      if((kernel_cache)
	 && (supvecnum>kernel_cache->max_elems)
	 && ((kernel_cache->activenum-activenum)>maxl((long)(activenum/10),500))) {
	kernel_cache_shrink(kernel_cache,totdoc,
			    minl((kernel_cache->activenum-activenum),
				 (kernel_cache->activenum-supvecnum)),
			    shrink_state->active); 
      }
    }

    if((!retrain) && learn_parm->remove_inconsistent) {
      if(verbosity>=1) {
	printf(" Moving training errors to inconsistent examples...");
	fflush(stdout);
      }
      if(learn_parm->remove_inconsistent == 1) {
	retrain=identify_inconsistent(a,label,unlabeled,totdoc,learn_parm,
				      &inconsistentnum,inconsistent); 
      }
      else if(learn_parm->remove_inconsistent == 2) {
	retrain=identify_misclassified(lin,label,unlabeled,totdoc,
				       model,&inconsistentnum,inconsistent); 
      }
      else if(learn_parm->remove_inconsistent == 3) {
	retrain=identify_one_misclassified(lin,label,unlabeled,totdoc,
				   model,&inconsistentnum,inconsistent);
      }
      if(retrain) {
	if(kernel_parm->kernel_type == LINEAR) { /* reinit shrinking */
	  learn_parm->epsilon_crit=2.0;
	} 
      }
      if(verbosity>=1) {
	printf("done.\n");
	if(retrain) {
	  printf(" Now %ld inconsistent examples.\n",inconsistentnum);
	}
      }
    }
  } /* end of loop */

  free(chosen);
  free(last_suboptimal_at);
  free(key);
  free(selcrit);
  free(selexam);
  free(a_old);
  free(aicache);
  free(working2dnum);
  free(active2dnum);
  free(qp.opt_ce);
  free(qp.opt_ce0);
  free(qp.opt_g);
  free(qp.opt_g0);
  free(qp.opt_xinit);
  free(qp.opt_low);
  free(qp.opt_up);
  free(weights);

  learn_parm->epsilon_crit=epsilon_crit_org; /* restore org */
  model->maxdiff=(*maxdiff);

  return(iteration);
}

long optimize_to_convergence_sharedslack(DOC **docs, long int *label, 
			     long int totdoc, 
			     long int totwords, LEARN_PARM *learn_parm, 
			     KERNEL_PARM *kernel_parm, 
			     KERNEL_CACHE *kernel_cache, 
			     SHRINK_STATE *shrink_state, MODEL *model, 
			     double *a, double *lin, double *c, 
			     TIMING *timing_profile, double *maxdiff)
     /* docs: Training vectors (x-part) */
     /* label: Training labels/value (y-part, zero if test example for
			      transduction) */
     /* totdoc: Number of examples in docs/label */
     /* totwords: Number of features (i.e. highest feature index) */
     /* learn_parm: Learning paramenters */
     /* kernel_parm: Kernel paramenters */
     /* kernel_cache: Initialized/partly filled Cache, if using a kernel. 
                      NULL if linear. */
     /* shrink_state: State of active variables */
     /* model: Returns learning result */
     /* a: alphas */
     /* lin: linear component of gradient */
     /* c: right hand side of inequalities (margin) */
     /* maxdiff: returns maximum violation of KT-conditions */
{
  long *chosen,*key,i,j,jj,*last_suboptimal_at,noshrink,*unlabeled;
  long *inconsistent,choosenum,already_chosen=0,iteration;
  long misclassified,supvecnum=0,*active2dnum,inactivenum;
  long *working2dnum,*selexam,*ignore;
  long activenum,retrain,maxslackid,slackset,jointstep;
  double criterion,eq_target;
  double *a_old,*alphaslack;
  long t0=0,t1=0,t2=0,t3=0,t4=0,t5=0,t6=0; /* timing */
  double epsilon_crit_org,maxsharedviol; 
  double bestmaxdiff;
  long   bestmaxdiffiter,terminate;

  double *selcrit;  /* buffer for sorting */        
  CFLOAT *aicache;  /* buffer to keep one row of hessian */
  double *weights;  /* buffer for weight vector in linear case */
  QP qp;            /* buffer for one quadratic program */
  double *slack;    /* vector of slack variables for optimization with
		       shared slacks */

  epsilon_crit_org=learn_parm->epsilon_crit; /* save org */
  if(kernel_parm->kernel_type == LINEAR) {
    learn_parm->epsilon_crit=2.0;
    kernel_cache=NULL;   /* caching makes no sense for linear kernel */
  } 
  learn_parm->epsilon_shrink=2;
  (*maxdiff)=1;

  learn_parm->totwords=totwords;

  chosen = (long *)my_malloc(sizeof(long)*totdoc);
  unlabeled = (long *)my_malloc(sizeof(long)*totdoc);
  inconsistent = (long *)my_malloc(sizeof(long)*totdoc);
  ignore = (long *)my_malloc(sizeof(long)*totdoc);
  key = (long *)my_malloc(sizeof(long)*(totdoc+11)); 
  selcrit = (double *)my_malloc(sizeof(double)*totdoc);
  selexam = (long *)my_malloc(sizeof(long)*totdoc);
  a_old = (double *)my_malloc(sizeof(double)*totdoc);
  aicache = (CFLOAT *)my_malloc(sizeof(CFLOAT)*totdoc);
  working2dnum = (long *)my_malloc(sizeof(long)*(totdoc+11));
  active2dnum = (long *)my_malloc(sizeof(long)*(totdoc+11));
  qp.opt_ce = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
  qp.opt_ce0 = (double *)my_malloc(sizeof(double));
  qp.opt_g = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize
				 *learn_parm->svm_maxqpsize);
  qp.opt_g0 = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
  qp.opt_xinit = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
  qp.opt_low=(double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
  qp.opt_up=(double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
  weights=(double *)my_malloc(sizeof(double)*(totwords+1));
  maxslackid=0;
  for(i=0;i<totdoc;i++) {    /* determine size of slack array */
    if(maxslackid<docs[i]->slackid)
      maxslackid=docs[i]->slackid;
  }
  slack=(double *)my_malloc(sizeof(double)*(maxslackid+1));
  alphaslack=(double *)my_malloc(sizeof(double)*(maxslackid+1));
  last_suboptimal_at = (long *)my_malloc(sizeof(long)*(maxslackid+1));
  for(i=0;i<=maxslackid;i++) {    /* init shared slacks */
    slack[i]=0;
    alphaslack[i]=0;
    last_suboptimal_at[i]=1;
  }

  choosenum=0;
  retrain=1;
  iteration=1;
  bestmaxdiffiter=1;
  bestmaxdiff=999999999;
  terminate=0;

  if(kernel_cache) {
    kernel_cache->time=iteration;  /* for lru cache */
    kernel_cache_reset_lru(kernel_cache);
  }

  for(i=0;i<totdoc;i++) {    /* various inits */
    chosen[i]=0;
    unlabeled[i]=0;
    inconsistent[i]=0;
    ignore[i]=0;
    a_old[i]=a[i];
  }
  activenum=compute_index(shrink_state->active,totdoc,active2dnum);
  inactivenum=totdoc-activenum;
  clear_index(working2dnum);

  /* call to init slack and alphaslack */
  compute_shared_slacks(docs,label,a,lin,c,active2dnum,learn_parm,
			slack,alphaslack);

                            /* repeat this loop until we have convergence */
  for(;retrain && (!terminate);iteration++) {

    if(kernel_cache)
      kernel_cache->time=iteration;  /* for lru cache */
    if(verbosity>=2) {
      printf(
	"Iteration %ld: ",iteration); fflush(stdout);
    }
    else if(verbosity==1) {
      printf("."); fflush(stdout);
    }

    if(verbosity>=2) t0=get_runtime();
    if(verbosity>=3) {
      printf("\nSelecting working set... "); fflush(stdout); 
    }

    if(learn_parm->svm_newvarsinqp>learn_parm->svm_maxqpsize) 
      learn_parm->svm_newvarsinqp=learn_parm->svm_maxqpsize;

    /* select working set according to steepest gradient */
    jointstep=0;
    eq_target=0;
    if(iteration % 101) {
      slackset=select_next_qp_slackset(docs,label,a,lin,slack,alphaslack,c,
				       learn_parm,active2dnum,&maxsharedviol);
      if((iteration % 2) 
	 || (!slackset) || (maxsharedviol<learn_parm->epsilon_crit)){
	/* do a step with examples from different slack sets */
	if(verbosity >= 2) {
	  printf("(i-step)"); fflush(stdout);
	}
	i=0;
	for(jj=0;(j=working2dnum[jj])>=0;jj++) { /* clear old part of working set */
	  if((chosen[j]>=(learn_parm->svm_maxqpsize/
			  minl(learn_parm->svm_maxqpsize,
			       learn_parm->svm_newvarsinqp)))) {
	    chosen[j]=0; 
	    choosenum--; 
	  }
	  else {
	    chosen[j]++;
	    working2dnum[i++]=j;
	  }
	}
	working2dnum[i]=-1;
	
	already_chosen=0;
	if((minl(learn_parm->svm_newvarsinqp,
		 learn_parm->svm_maxqpsize-choosenum)>=4) 
	   && (kernel_parm->kernel_type != LINEAR)) {
	  /* select part of the working set from cache */
	  already_chosen=select_next_qp_subproblem_grad(
			      label,unlabeled,a,lin,c,totdoc,
			      (long)(minl(learn_parm->svm_maxqpsize-choosenum,
					  learn_parm->svm_newvarsinqp)
				     /2),
			      learn_parm,inconsistent,active2dnum,
			      working2dnum,selcrit,selexam,kernel_cache,
			      (long)1,key,chosen);
	  choosenum+=already_chosen;
	}
	choosenum+=select_next_qp_subproblem_grad(
                              label,unlabeled,a,lin,c,totdoc,
                              minl(learn_parm->svm_maxqpsize-choosenum,
				   learn_parm->svm_newvarsinqp-already_chosen),
                              learn_parm,inconsistent,active2dnum,
			      working2dnum,selcrit,selexam,kernel_cache,
			      (long)0,key,chosen);
      }
      else { /* do a step with all examples from same slack set */
	if(verbosity >= 2) {
	  printf("(j-step on %ld)",slackset); fflush(stdout);
	}
	jointstep=1;
	for(jj=0;(j=working2dnum[jj])>=0;jj++) { /* clear working set */
	    chosen[j]=0; 
	}
	working2dnum[0]=-1;
	eq_target=alphaslack[slackset];
	for(j=0;j<totdoc;j++) {                  /* mask all but slackset */
	  /* for(jj=0;(j=active2dnum[jj])>=0;jj++) { */
	  if(docs[j]->slackid != slackset)
	    ignore[j]=1; 
	  else {
	    ignore[j]=0; 
	    learn_parm->svm_cost[j]=learn_parm->svm_c;
	    /* printf("Inslackset(%ld,%ld)",j,shrink_state->active[j]); */
	  }
	}
	learn_parm->biased_hyperplane=1;
	choosenum=select_next_qp_subproblem_grad(
                              label,unlabeled,a,lin,c,totdoc,
                              learn_parm->svm_maxqpsize,
                              learn_parm,ignore,active2dnum,
			      working2dnum,selcrit,selexam,kernel_cache,
			      (long)0,key,chosen);
	learn_parm->biased_hyperplane=0;
      }
    }
    else { /

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -