⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 svm_struct_api.c

📁 New training algorithm for linear classification SVMs that can be much faster than SVMlight for larg
💻 C
📖 第 1 页 / 共 4 页
字号:
    c.lhs=NULL;
    c.rhs=NULL;
    c.m=0;
  }
  else { /* add constraints so that all learned weights are
            positive. WARNING: Currently, they are positive only up to
            precision epsilon set by -e. */
    c.lhs=my_malloc(sizeof(DOC *)*sizePsi);
    c.rhs=my_malloc(sizeof(double)*sizePsi);
    for(i=0; i<sizePsi; i++) {
      words[0].wnum=i+1;
      words[0].weight=1.0;
      words[1].wnum=0;
      /* the following slackid is a hack. we will run into problems,
         if we have move than 1000000 slack sets (i.e. examples) */
      c.lhs[i]=create_example(i,0,1000000+i,1,create_svector(words,"",1.0));
      c.rhs[i]=0.0;
    }
  }
  return(c);
}

LABEL       classify_struct_example(PATTERN x, STRUCTMODEL *sm, 
				    STRUCT_LEARN_PARM *sparm)
{
  /* Predicts a label for pattern x with the linear model stored in
     sm->svm_model.

     For every document x.doc[k] the value returned by
     classify_example(sm->svm_model, x.doc[k]) is stored unchanged in
     yhat.class[k]. The value is NOT thresholded here. Callers presumably
     read its sign as the class decision and keep the magnitude as a
     score (TODO confirm against the evaluation code).

     Returns a label with yhat.totdoc == x.totdoc and a freshly
     my_malloc'ed class array that the caller owns. This implementation
     never returns an empty label. sparm is not used. */
  LABEL yhat;
  int k;

  yhat.totdoc=x.totdoc;
  yhat.class=(double *)my_malloc(sizeof(double)*yhat.totdoc);

  k=0;
  while(k<x.totdoc) {
    /* raw decision value of document k */
    yhat.class[k]=classify_example(sm->svm_model,x.doc[k]);
    k++;
  }
  return(yhat);
}

LABEL       find_most_violated_constraint_slackrescaling(PATTERN x, LABEL y, 
						     STRUCTMODEL *sm, 
						     STRUCT_LEARN_PARM *sparm)
{
  /* Finds the label ybar for pattern x that is responsible for the most
     violated constraint in the slack-rescaling formulation.

     The search accounts for both the scoring function in sm and the loss
     function. The weights sm.w correspond to the features defined by
     psi() and range from index 1 to index sm->sizePsi.

     Dispatch on sparm->loss_function:
       ZEROONE, FONE, ERRORRATE, PRBEP, PREC_K, REC_K
           -> find_most_violated_constraint_thresholdmetric(), called with
              sparm->loss_type.
       SWAPPEDPAIRS, AVGPREC
           -> no slack-rescaling search exists; prints an error and
              terminates the program with exit(1).
       any other value
           -> prints "Unknown loss function" and terminates with exit(1).

     Returns the label produced by the selected search. Its class array
     is owned by the caller. */
  LABEL ybar;

  if((sparm->loss_function == ZEROONE) 
     || (sparm->loss_function == FONE) 
     || (sparm->loss_function == ERRORRATE)
     || (sparm->loss_function == PRBEP) 
     || (sparm->loss_function == PREC_K) 
     || (sparm->loss_function == REC_K)) {
    ybar=find_most_violated_constraint_thresholdmetric(x,y,sm,sparm,
						       sparm->loss_type);
  }
  else if((sparm->loss_function == SWAPPEDPAIRS)
	  || (sparm->loss_function == AVGPREC)) {
    /* Both losses are recognised options but only have a
       margin-rescaling search. Report them as unsupported here rather
       than falling through to the misleading "Unknown loss function"
       message. */
    printf("ERROR: Slack-rescaling is not implemented for this loss function!\n");
    exit(1);
  }
  else {
    printf("ERROR: Unknown loss function '%d'.\n",sparm->loss_function);
    exit(1);
  }
  return(ybar);
}

LABEL       find_most_violated_constraint_marginrescaling(PATTERN x, LABEL y, 
						     STRUCTMODEL *sm, 
						     STRUCT_LEARN_PARM *sparm)
{
  /* Finds the label ybar for pattern x that is responsible for the most
     violated constraint in the margin-rescaling formulation, i.e. the
     maximiser of loss(y,ybar) + w*Psi(x,ybar).

     The weights sm.w correspond to the features defined by psi() and
     range from index 1 to index sm->sizePsi. The work is delegated to a
     loss-specific search selected by sparm->loss_function. Every search
     receives sparm->loss_type.

     An unrecognised loss function prints an error and terminates the
     program with exit(1). The returned label's class array is owned by
     the caller. */
  LABEL ybar;

  switch(sparm->loss_function) {
  case ERRORRATE:
    /* Error rate decomposes per example and has a dedicated search.
       The generic threshold search also supports ERRORRATE but is
       deliberately not used here. */
    ybar=find_most_violated_constraint_errorrate(x,y,sm,sparm,
						 sparm->loss_type);
    break;

  case ZEROONE:
  case FONE:
  case PRBEP:
  case PREC_K:
  case REC_K:
    ybar=find_most_violated_constraint_thresholdmetric(x,y,sm,sparm,
						       sparm->loss_type);
    break;

  case SWAPPEDPAIRS:
    ybar=find_most_violated_constraint_rankmetric(x,y,sm,sparm,
						  sparm->loss_type);
    break;

  case AVGPREC:
    ybar=find_most_violated_constraint_avgprec(x,y,sm,sparm,
					       sparm->loss_type);
    break;

  default:
    printf("ERROR: Unknown loss function '%d'.\n",sparm->loss_function);
    exit(1);
  }
  return(ybar);
}

LABEL       find_most_violated_constraint_errorrate(PATTERN x, LABEL y, 
						    STRUCTMODEL *sm, 
						    STRUCT_LEARN_PARM *sparm,
						    int loss_type)
{
  /* Most violated constraint for the error-rate loss under margin
     rescaling. loss_type is part of the common signature but is not
     read here.

     The objective sum_k ybar_k*s_k + loss(y,ybar) decomposes per
     example, where s_k is the classifier output for x.doc[k]. Each
     wrongly labelled example adds step = 100/x.totdoc to the loss
     (presumably the error rate in percent). Flipping example k therefore
     changes the objective by step - 2*y_k*s_k. The label is flipped
     exactly when that gain is strictly positive.

     Returns a label of length x.totdoc whose class array is freshly
     my_malloc'ed and owned by the caller. */
  LABEL ybar;
  MODEL local_model;
  double *scores;
  double *beta;
  double objective=0;
  double step_loss;
  double margin_gain;
  int n_words;
  int k;

  ybar.totdoc=x.totdoc;
  ybar.class=(double *)my_malloc(sizeof(double)*x.totdoc);
  scores=(double *)my_malloc(sizeof(double)*(ybar.totdoc+1));

  /* Struct copy of the model, so its weight pointer can be swapped
     below without modifying *sm->svm_model. */
  n_words=sm->svm_model->totwords;
  local_model=*(sm->svm_model);

  /* Sparse kernel: score with beta=gamma^T*L^-1 instead of the stored
     weights. The temporary weight vector is released after scoring. */
  if(sm->sparse_kernel_type > 0) {
    local_model.lin_weights=(double *)my_malloc(sizeof(double)*(n_words+1));
    beta=prod_nvector_ltmatrix(sm->svm_model->lin_weights+1,sm->invL);
    local_model.lin_weights[0]=0;
    for(k=0;k<sm->invL->m;k++)
      local_model.lin_weights[k+1]=beta[k];
    free_nvector(beta);
  }

  step_loss=100.0/x.totdoc;
  for(k=0;k<x.totdoc;k++) {
    scores[k]=classify_example(&local_model,x.doc[k]);
    margin_gain=step_loss-2.0*y.class[k]*scores[k];
    ybar.class[k]=(margin_gain > 0) ? -y.class[k] : y.class[k];
    if(margin_gain > 0)
      objective+=step_loss;          /* flipped example counts as an error */
    objective+=ybar.class[k]*scores[k];
  }

  /* Drop the temporary sparse-kernel weights allocated above. */
  if(sm->sparse_kernel_type > 0)
    free(local_model.lin_weights);

  if(struct_verbosity >= 2) {
    SVECTOR *fy;
    SVECTOR *fybar;
    DOC *exy;
    DOC *exybar;

    printf("\n max_ybar {loss(y_i,ybar)+w*Psi(x,ybar)}=%f\n",objective);
    fy=psi(x,y,sm,sparm);
    fybar=psi(x,ybar,sm,sparm);
    exy=create_example(0,0,1,1,fy);
    exybar=create_example(0,0,1,1,fybar);
    printf(" -> w*Psi(x,y_i)=%f, w*Psi(x,ybar)=%f\n",
	   classify_example(sm->svm_model,exy),
	   classify_example(sm->svm_model,exybar));
    free_example(exy,1);
    free_example(exybar,1);
  }

  free(scores);
  return(ybar);
}
   


LABEL       find_most_violated_constraint_thresholdmetric(PATTERN x, LABEL y, 
						     STRUCTMODEL *sm, 
						     STRUCT_LEARN_PARM *sparm,
						     int loss_type)
{
  /* Finds the most violated constraint for metrics that are based on
     a threshold (ZEROONE, FONE, ERRORRATE, PRBEP, PREC_K, REC_K).

     Examples are split by the sign of y.class[] into positives and
     negatives. Positives are sorted with comparedown and negatives with
     compareup (presumably descending / ascending by score — confirm).
     Each candidate ybar is then described by two counts:
       a = number of leading (sorted) positives kept positive  -> tp
       d = number of leading (sorted) negatives kept negative  -> tn
     All (a,d) pairs with 0<=a<=nump, 0<=d<=numn are enumerated. The loss
     is evaluated on (tp,fp,fn,tn) = (a, numn-d, nump-a, d). Only pairs
     with loss > 0 are candidates. The prefix sums sump[]/sumn[] give the
     score term of each pair in O(1).

     The selected pair maximises
       MARGIN_RESCALING: loss + w*Psi(x,ybar)
       SLACK_RESCALING:  loss + loss*(w*Psi(x,ybar) - score_y)
     PRBEP, PREC_K and REC_K only admit pairs whose number of predicted
     positives (a+numn-d) is == nump, >= prec_rec_k, or <= prec_rec_k
     respectively.

     Returns a label of length x.totdoc whose class array is freshly
     my_malloc'ed and owned by the caller. If some pair has loss > 0 and
     loss_type is neither SLACK_RESCALING nor MARGIN_RESCALING, an error
     is printed and the program exits with status 1. */
  LABEL ybar;
  int i,nump,numn,start,prec_rec_k,totwords;
  double *score,*sump,*sumn;
  STRUCT_ID_SCORE *scorep,*scoren;
  int threshp=0,threshn=0;   /* best (a,d); stay 0 if no pair has loss > 0 */
  int a,d;
  double val,valmax,loss,score_y,*ortho_weights;
  MODEL svm_model;

  ybar.totdoc=x.totdoc;
  ybar.class=(double *)my_malloc(sizeof(double)*x.totdoc);
  score=(double *)my_malloc(sizeof(double)*(ybar.totdoc+1));
  scorep=(STRUCT_ID_SCORE *)my_malloc(sizeof(STRUCT_ID_SCORE)*(ybar.totdoc+1));
  scoren=(STRUCT_ID_SCORE *)my_malloc(sizeof(STRUCT_ID_SCORE)*(ybar.totdoc+1));
  /* prefix sums: sump[k] / sumn[k] = sum of the first k sorted scores */
  sump=(double *)my_malloc(sizeof(double)*(ybar.totdoc+1));
  sumn=(double *)my_malloc(sizeof(double)*(ybar.totdoc+1));

  totwords=sm->svm_model->totwords;
  /* struct copy, so lin_weights can be swapped without touching sm */
  svm_model=(*sm->svm_model);  

  /* For sparse kernel, replace weight vector with beta=gamma^T*L^-1 */
  if(sm->sparse_kernel_type>0) {
    svm_model.lin_weights=(double *)my_malloc(sizeof(double)*(totwords+1));
    ortho_weights=prod_nvector_ltmatrix(sm->svm_model->lin_weights+1,sm->invL);
    for(i=0;i<sm->invL->m;i++)
      svm_model.lin_weights[i+1]=ortho_weights[i];
    svm_model.lin_weights[0]=0;
    free_nvector(ortho_weights);
  }

  /* Score every example and split into positives / negatives by the
     sign of y.class[i]. The score is multiplied by |y.class[i]|, which
     leaves it unchanged for +-1 labels (NOTE(review): presumably labels
     are +-1; other magnitudes act as per-example weights). */
  nump=0;
  numn=0;
  for(i=0;i<x.totdoc;i++) {
    score[i]=fabs(y.class[i])*classify_example(&svm_model,x.doc[i]);
    if(y.class[i] > 0) {
      scorep[nump].score=score[i];
      scorep[nump].tiebreak=0;
      scorep[nump].id=i;
      nump++;
    }
    else {
      scoren[numn].score=score[i];
      scoren[numn].tiebreak=0;
      scoren[numn].id=i;
      numn++;
    }
  }

  /* Restore weight vector that was modified above */
  if(sm->sparse_kernel_type>0) {
    free(svm_model.lin_weights);
  }

  /* compute score of target label (only needed for slack rescaling) */
  score_y=0;
  if(loss_type==SLACK_RESCALING) {
    for(i=0;i<x.totdoc;i++) 
      /* NOTE(review): the active line sums score[i] regardless of the
         sign of y.class[i]; the commented-out alternative weights by
         y.class[i]. Confirm which one equals w*Psi(x,y) for psi(). */
      /*      score_y+=y.class[i]*score[i]; */
      score_y+=score[i]; 
 }
 
  if(nump)
    qsort(scorep,nump,sizeof(STRUCT_ID_SCORE),comparedown);
  sump[0]=0;
  for(i=0;i<nump;i++) {
    sump[i+1]=sump[i]+scorep[i].score;
  }
  if(numn)
    qsort(scoren,numn,sizeof(STRUCT_ID_SCORE),compareup);
  sumn[0]=0;
  for(i=0;i<numn;i++) {
    sumn[i+1]=sumn[i]+scoren[i].score;
  }

  /* find max of loss(ybar,y)+score(ybar) for margin rescaling or max
     of loss(ybar,y)+loss*(score(ybar)-score(y)) for slack
     rescaling */
  valmax=0;
  start=1;   /* accept the first loss>0 pair even if its val is <= 0 */
  prec_rec_k=(int)(nump*sparm->prec_rec_k_frac);
  if(prec_rec_k<1) prec_rec_k=1;
  for(a=0;a<=nump;a++) {
    for(d=0;d<=numn;d++) {
      /* loss on (tp,fp,fn,tn) = (a,numn-d,nump-a,d); pairs outside the
         PRBEP / PREC_K / REC_K constraints get loss 0 and are skipped */
      if(sparm->loss_function == ZEROONE)
	loss=zeroone_loss(a,numn-d,nump-a,d);
      else if(sparm->loss_function == FONE)
	loss=fone_loss(a,numn-d,nump-a,d);
      else if(sparm->loss_function == ERRORRATE)
	loss=errorrate_loss(a,numn-d,nump-a,d);
      else if((sparm->loss_function == PRBEP) && (a+numn-d == nump))
	loss=prbep_loss(a,numn-d,nump-a,d);
      else if((sparm->loss_function == PREC_K) && (a+numn-d >= prec_rec_k))
	loss=prec_k_loss(a,numn-d,nump-a,d);
      else if((sparm->loss_function == REC_K) && (a+numn-d <= prec_rec_k)) 
	loss=rec_k_loss(a,numn-d,nump-a,d);
      else {
	loss=0;
      }
      if(loss > 0) {
	/* w*Psi(x,ybar) = (kept positives) - (flipped positives)
	                   - (kept negatives) + (flipped negatives) */
	if(loss_type==SLACK_RESCALING) {
	  val=loss+loss*(sump[a]-(sump[nump]-sump[a])-sumn[d]+(sumn[numn]-sumn[d] - score_y));
	}
	else if(loss_type==MARGIN_RESCALING) {
	  val=loss+sump[a]-(sump[nump]-sump[a])-sumn[d]+(sumn[numn]-sumn[d]);
	}
	else {
	  printf("ERROR: Unknown loss type '%d'.\n",loss_type);
	  exit(1);
	}
	if((val > valmax) || (start)) {
	  start=0;
	  valmax=val;
	  threshp=a;
	  threshn=d;
	}
      }
    }
  }

  /* assign labels that maximize score */
  /* Earlier variant that assigned hard +1/-1 labels; superseded by the
     loops below, which copy or negate y.class[] so label magnitudes are
     preserved. */
  /*  for(i=0;i<nump;i++) {
    if(i<threshp) 
      ybar.class[scorep[i].id]=1;
    else 
      ybar.class[scorep[i].id]=-1;
  }
  for(i=0;i<numn;i++) {
    if(i<threshn) 
      ybar.class[scoren[i].id]=-1;
    else 
      ybar.class[scoren[i].id]=1;
      } */
  /* The first threshp sorted positives and the first threshn sorted
     negatives keep their label; all others are flipped. If no pair had
     loss > 0, threshp == threshn == 0 and every label is flipped. */
  for(i=0;i<nump;i++) {
    if(i<threshp) 
      ybar.class[scorep[i].id]=y.class[scorep[i].id];
    else 
      ybar.class[scorep[i].id]=-y.class[scorep[i].id];
  }
  for(i=0;i<numn;i++) {
    if(i<threshn) 
      ybar.class[scoren[i].id]=y.class[scoren[i].id];
    else 
      ybar.class[scoren[i].id]=-y.class[scoren[i].id];
  }

  if(struct_verbosity >= 2) {
    if(loss_type==SLACK_RESCALING) 
      printf("\n max_ybar {loss(y_i,ybar)+loss(y_i,ybar)[w*Psi(x,ybar)-w*Psi(x,y)]}=%f\n",valmax);
    else
      printf("\n max_ybar {loss(y_i,ybar)+w*Psi(x,ybar)}=%f\n",valmax);
    /* diagnostic only: note this uses the original sm->svm_model, not
       the sparse-kernel beta weights used for the search above */
    SVECTOR *fy=psi(x,y,sm,sparm);
    SVECTOR *fybar=psi(x,ybar,sm,sparm);
    DOC *exy=create_example(0,0,1,1,fy);
    DOC *exybar=create_example(0,0,1,1,fybar);
    printf(" -> w*Psi(x,y_i)=%f, w*Psi(x,ybar)=%f\n",
	   classify_example(sm->svm_model,exy),
	   classify_example(sm->svm_model,exybar));
    free_example(exy,1);
    free_example(exybar,1);
  }
  free(score);
  free(scorep);
  free(scoren);
  free(sump);
  free(sumn);

  return(ybar);
}

LABEL       find_most_violated_constraint_rankmetric(PATTERN x, LABEL y, 
						     STRUCTMODEL *sm, 
						     STRUCT_LEARN_PARM *sparm,
						     int loss_type)
{
  /* Finds the most violated constraint for metrics that are based on
     a threshold. 
     WARNING: Currently only implemented for margin-rescaling!!! */
  LABEL ybar;
  long i,nump,numn,sump,sumn;
  double *score,*ortho_weights;
  STRUCT_ID_SCORE *scorep,*scoren,*predset;
  int totwords;
  MODEL svm_model;

  ybar.totdoc=x.totdoc;
  ybar.class=(double *)my_malloc(sizeof(double)*x.totdoc);
  score=(double *)my_malloc(sizeof(double)*(ybar.totdoc+1));
  scorep=(STRUCT_ID_SCORE *)my_malloc(sizeof(STRUCT_ID_SCORE)*(ybar.totdoc+1));
  scoren=(STRUCT_ID_SCORE *)my_malloc(sizeof(STRUCT_ID_SCORE)*(ybar.totdoc+1));
  predset=(STRUCT_ID_SCORE *)my_malloc(sizeof(STRUCT_ID_SCORE)*(ybar.totdoc+1));
  totwords=sm->svm_model->totwords;
  svm_model=(*sm->svm_model);  

  /* For sparse kernel, replace weight vector with beta=gamma^T*L^-1 */
  if(sm->sparse_kernel_type>0) {
    svm_model.lin_weights=(double *)my_malloc(sizeof(double)*(totwords+1));
    ortho_weights=prod_nvector_ltmatrix(sm->svm_model->lin_weights+1,sm->invL);
    for(i=0;i<sm->invL->m;i++)
      svm_model.lin_weights[i+1]=ortho_weights[i];
    svm_model.lin_weights[0]=0;
    free_nvector(ortho_weights);
  }

  nump=0;
  numn=0;
  for(i=0;i<x.totdoc;i++) {
    score[i]=classify_example(&svm_model,x.doc[i]);
    if(y.class[i] > 0) {
      scorep[nump].score=score[i];
      scorep[nump].tiebreak=0;
      scorep[nump].id=i;
      nump++;
    }
    else {
      scoren[numn].score=score[i];
      scorep[numn].tiebreak=0;
      scoren[numn].id=i;
      numn++;
    }
  }
  if(nump)
    qsort(scorep,nump,sizeof(STRUCT_ID_SCORE),comparedown);
  if(numn)
    qsort(scoren,numn,sizeof(STRUCT_ID_SCORE),comparedown);

  /* Restore weight vector that was modified above */
  if(sm->sparse_kernel_type>0) {
    free(svm_model.lin_weights);
  }

  /* find max of loss(ybar,y)+score(ybar) */
  if(sparm->loss_function == SWAPPEDPAIRS) { /* number of swapped pairs (ie. ROC Area) */
    for(i=0;i<nump;i++) {
      predset[i]=scorep[i];
      predset[i].score-=(0.5); 
    }
    for(i=0;i<numn;i++) {
      predset[nump+i]=scoren[i];
      predset[nump+i].score+=(0.5); 
    }
    qsort(predset,nump+numn,sizeof(STRUCT_ID_SCORE),comparedown);
    sump=0;
    sumn=0;
    for(i=0;i<numn+nump;i++) {
      if(y.class[predset[i].id] > 0) {

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -