
svm_struct_api.c

SVMcfg: Learns a weighted context free grammar from examples. Training examples (e.g. for natural language ...
Language: C
  else{
    c=NULL;
    c2=NULL;
    assert(0);   /* unreachable: the branches above must have set the charts */
  }
  if (tracefp) { printf("done\n"); fflush(stdout); }

  /* fetch best root node */

  if (tracefp) { printf("Fetching root..."); fflush(stdout); }
  root_cell = sihashcc_ref(CHART_ENTRY(c, 0, terms->n),
			   si_string_index(si, ROOT));
  root_cell2 = sihashcc_ref(CHART_ENTRY(c2, 0, terms->n),
			    si_string_index(si, ROOT));
  if (tracefp) { printf("done\n"); fflush(stdout); }

  ybar.si=si;
  if (root_cell) {
    assert(fabs(root_cell->prob - root_cell2->prob) < 0.0000001); 

    if (tracefp) { printf("Getting parse tree..."); fflush(stdout); }
    ybar.parse = bintree_tree(&root_cell->tree, si);
    if (tracefp) { printf("done\n"); fflush(stdout); }
    
    double logprob = (double) root_cell->prob;
    if (parsefp) { fprintf(parsefp, "Prob = %g ", logprob); fflush(stdout); }
    ybar.prob=logprob;
    
    if(sparm->loss_function == 0) 
      ybar.loss=loss(y,ybar,sparm);
    else
      ybar.loss = lossval;

    { /* sanity check: the score of the best parse found in the CKY chart
	 must equal the linear score w*psi(x,ybar) of its feature vector */
      SVECTOR *v=psi(x,ybar,sm,sparm);
      DOC *ex=create_example(-1,-1,-1,1,v);
      double logprobpsi=classify_example(sm->svm_model,ex);
      free_example(ex,0);
      if (parsefp) { fprintf(parsefp, "ProbPsi = %g ", logprobpsi); fflush(parsefp); }
      if(fabs(logprob - logprobpsi) > 0.000001) 
	printf(" psi-cky mismatch\n"); 
      assert(fabs(logprob - logprobpsi) < 0.000001); 
      free_svector(v);
      if(root_cell->present2) {
	LABEL ybar2;
	ybar2.parse = bintree_tree(&root_cell->tree2, si);
	assert(root_cell->prob2 <= root_cell->prob);
	double logprob2 = (double) root_cell->prob2; /* score of the 2nd-best parse */
	v=psi(x,ybar2,sm,sparm);
	ex=create_example(-1,-1,-1,1,v);
	logprobpsi=classify_example(sm->svm_model,ex);
	free_example(ex,0);
	if (parsefp) { fprintf(parsefp, " > %g ", logprobpsi); fflush(stdout); }
	if(fabs(logprob2 - logprobpsi) > 0.0000001)
	  printf(" psi-cky2 mismatch\n");
	if(tree_eq(ybar.parse,ybar2.parse) || (fabs(logprob2 - logprobpsi) > 0.0000001)) {
	  printf("\n   prob1=%f      prob2=%f\n",root_cell->prob,root_cell->prob2);
	  printf("\npsiprob1=%f   psiprob2=%f(%f)\n",classify_example(sm->svm_model,create_example(-1,-1,-1,1,psi(x,ybar,sm,sparm))),classify_example(sm->svm_model,create_example(-1,-1,-1,1,psi(x,ybar2,sm,sparm))),logprobpsi);
	  write_bintree(stdout,&root_cell->tree, si);printf("\n");
	  write_bintree(stdout,&root_cell->tree2, si);printf("\n");
	  write_tree(stdout,ybar.parse, si);printf("\n");
	  write_tree(stdout,ybar2.parse, si);printf("\n");
	}
	assert(!tree_eq(ybar.parse,ybar2.parse)); 
	/* assert(fabs(logprob - logprobpsi) < 0.0000001);  */
	free_svector(v);
	free_tree(ybar2.parse);
      }
    }

    /*
    if(((sparm->loss_function >= 1)) && (ybar.loss == 0)) {
      free_label(ybar);
      ybar.parse=NULL;
    }
    */

    /* check whether to return second highest scoring parse for 0/1-loss */
    if(tree_eq(ybar.parse,y.parse) && (sparm->loss_function == 0)) { 
      if (tracefp) { 
	fprintf(tracefp, "Best parse is correct:\n"); 
	fprintf(tracefp, "Correct: "); 
	write_tree(tracefp, y.parse, si);
	fprintf(tracefp, "\nBest   : "); 
	write_tree(tracefp, ybar.parse, si);
	fprintf(tracefp, "\n"); 
	fflush(tracefp); 
      }

      if(root_cell->present2) {
	ybar2.parse = bintree_tree(&root_cell->tree2, si);
	ybar2.si=si;
	ybar2.prob=root_cell->prob2;
	ybar2.loss=loss(y,ybar2,sparm);
	if (tracefp && ybar2.parse) { 
	  fprintf(tracefp, "2ndBest: "); 
	  write_tree(tracefp, ybar2.parse, si);
	  fprintf(tracefp, "\n"); 
	  fflush(tracefp); 
	}
	if(ybar.prob > (ybar2.prob + ybar2.loss)) {
	  free_label(ybar2); /* return y */
	}
	else { /* return second highest scoring parse */
	  free_label(ybar);
	  ybar=ybar2;
	}
      }
      else {
	if (tracefp) { 
	  fprintf(tracefp, "No 2nd best parse found!\n"); 
	  fflush(tracefp); 
	}
      }
    }
    if (parsefp && ybar.parse)
      write_tree(parsefp, ybar.parse, si);
  }
  else {
    ybar.parse=NULL;
    fprintf(stderr, "Failed to parse\n");
    if (parsefp)
      fprintf(parsefp, "parse_failure.\n");
  }

  chart_free(c, terms->n);			/* free the chart */
  if(c != c2) chart_free(c2, terms->n);		/* free the chart */

  return(ybar);
}

LABEL       find_most_violated_constraint_marginrescaling(PATTERN x, LABEL y, 
						     STRUCTMODEL *sm, 
						     STRUCT_LEARN_PARM *sparm)
{
  /* Finds the label ybar for pattern x that is responsible for the most
     violated constraint for the margin rescaling formulation. For linear
     slack variables, this is the label ybar that maximizes

            argmax_{ybar} loss(y,ybar) + psi(x,ybar)*w

     Note that ybar may be equal to y (i.e. the max is 0), which is
     different from the algorithms described in
     [Tsochantaridis et al./05]. Note that this argmax has to take into
     account the scoring function in sm, especially the weights sm.w, as
     well as the loss function, and whether linear or quadratic slacks
     are used. The weights in sm.w correspond to the features defined by
     psi() and range from index 1 to index sm->sizePsi. The simplest case
     is the zero/one loss function. For the zero/one loss, this function
     should return the highest scoring label ybar (which may be equal to
     the correct label y), or the second highest scoring label ybar, if
     Psi(x,ybar)>Psi(x,y)-1. If the function cannot find a label, it
     shall return an empty label as recognized by the function
     empty_label(y). */
  LABEL ybar;

  ybar=find_most_violated_constraint_slackrescaling(x,y,sm,sparm);

  return(ybar);
}
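
/* Illustrative sketch, not part of SVMcfg: for a small, enumerable set of
   candidate labels, the margin-rescaling argmax above reduces to scoring
   every candidate with loss(y,ybar) + w*psi(x,ybar) and keeping the best.
   The candidate array and its length are hypothetical inputs; the real
   implementation searches the exponential space of parse trees with the
   CKY chart instead of enumerating. */
static LABEL argmax_margin_rescaling_sketch(PATTERN x, LABEL y,
					    LABEL *cand, long n_cand,
					    STRUCTMODEL *sm,
					    STRUCT_LEARN_PARM *sparm)
{
  LABEL  best=cand[0];
  double bestval=-1e30,val;
  long   i;
  for(i=0;i<n_cand;i++) {
    SVECTOR *v=psi(x,cand[i],sm,sparm);
    DOC *ex=create_example(-1,-1,-1,1,v);
    val=loss(y,cand[i],sparm)
        +classify_example(sm->svm_model,ex);  /* w*psi(x,ybar) */
    free_example(ex,1);                       /* deep free also releases v */
    if(val>bestval) { bestval=val; best=cand[i]; }
  }
  return(best);
}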

int         empty_label(LABEL y)
{
  /* Returns true if y is an empty label. An empty label might be
     returned by find_most_violated_constraint_???(x, y, sm) if there
     is no incorrect label that can be found for x, or if it is unable
     to label x at all. */
  return(y.parse == NULL);
}

SVECTOR     *psi(PATTERN x, LABEL y, STRUCTMODEL *sm,
		 STRUCT_LEARN_PARM *sparm)
{
  /* Returns a feature vector describing the match between pattern x and
     label y. The feature vector is returned as an SVECTOR
     (i.e. pairs <featurenumber:featurevalue>), where the last pair has
     featurenumber 0 as a terminator. Featurenumbers start with 1 and end
     with sizePsi. This feature vector determines the linear evaluation
     function that is used to score labels. There will be one weight in
     sm.w for each feature. Note that psi has to match
     find_most_violated_constraint_???(x, y, sm) and vice versa. In
     particular, find_most_violated_constraint_???(x, y, sm) finds the
     ybar!=y that maximizes psi(x,ybar,sm)*sm.w (where * is the inner
     vector product) and the appropriate function of the loss. */
  SVECTOR *psi;
  bintree btree;
  struct vindex  terms;
  size_t  rpos;

  terms=tree_terms(y.parse);
  btree=right_binarize(y.parse,sm->si);
  psi=collect_phi(btree,sm,0,&rpos,0,terms.n);
  free(terms.e);
  free_bintree(btree);

  return(psi);
}
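
/* Illustrative sketch, not part of SVMcfg: the SVECTOR returned by psi()
   is a 0-terminated array of <featurenumber:featurevalue> WORD pairs, so
   the linear score w*psi(x,y) can be computed by a walk like the one
   below (svm_light's sprod_ns() does essentially this); w is assumed to
   be indexed from 1 to sizePsi as described above. */
static double linear_score_sketch(double *w, SVECTOR *v)
{
  WORD   *a;
  double  sum=0.0;
  for(a=v->words; a->wnum; a++)   /* featurenumber 0 terminates the list */
    sum+=w[a->wnum]*a->weight;
  return(sum*v->factor);
}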

double      loss(LABEL y, LABEL ybar, STRUCT_LEARN_PARM *sparm)
{
  /* loss for correct label y and predicted label ybar. The loss for
     y==ybar has to be zero. sparm->loss_function is set with the -l option. */
  if(sparm->loss_function == 0) { /* type 0 loss: 0/1 loss */
                                  /* return 0 if y==ybar, else return
                                     100 (the 0/1 loss is scaled by 100) */
    if(tree_eq(ybar.parse,y.parse))
      return(0);
    else 
      return(100);
  }
  else if(sparm->loss_function == 1) {  
    /* type 1 loss: 1 - F1 over labelled brackets, scaled by 100 */
    tree t = remove_parent_annotation(y.parse, y.si); 
    struct ledges *test_ledges = tree_ledges(t);   
    free_tree(t);
    t = remove_parent_annotation(ybar.parse, ybar.si); 
    struct ledges *parse_ledges = tree_ledges(t);   
    free_tree(t);
    int common_bracket_count = common_ledge_count(test_ledges, parse_ledges);
    double loss=1.0-fone(common_bracket_count,test_ledges->n,parse_ledges->n);
    free_ledges(test_ledges);
    free_ledges(parse_ledges);

    /* printf("loss: %g ybar.loss=%g\n",loss,ybar.loss); */
    /* assert(fabs(loss - ybar.loss) < 0.00000001); */
    return(100.0*loss);  
  }
  else {
    /* Put your code for different loss functions here. But then
       find_most_violated_constraint_???(x, y, sm) has to return the
       highest scoring label with the largest loss. */
    assert(0);
    return(0);
  }
}
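
/* Illustrative sketch, not part of SVMcfg: fone() is defined elsewhere in
   the package. A definition consistent with its use here and in
   print_struct_testing_stats() computes the F1 measure from the common
   bracket count and the bracket counts of the gold and predicted trees:
   with precision P=common/parse_n and recall R=common/test_n,
   F1 = 2PR/(P+R) = 2*common/(test_n+parse_n). */
static double fone_sketch(long common, long test_n, long parse_n)
{
  if(test_n+parse_n == 0) return(0.0);
  return(2.0*common/(double)(test_n+parse_n));
}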

int         finalize_iteration(double ceps, int cached_constraint,
			       SAMPLE sample, STRUCTMODEL *sm,
			       CONSTSET cset, double *alpha, 
			       STRUCT_LEARN_PARM *sparm)
{
  /* This function is called just before the end of each cutting plane
     iteration. ceps is the amount by which the most violated constraint
     found in the current iteration was violated. cached_constraint is
     true if the added constraint was constructed from the cache. If the
     return value is FALSE, then the algorithm is allowed to terminate.
     If it is TRUE, the algorithm will keep iterating even if the desired
     precision sparm->epsilon is already reached. */
  return(0);
}

void        print_struct_learning_stats(SAMPLE sample, STRUCTMODEL *sm,
					CONSTSET cset, double *alpha, 
					STRUCT_LEARN_PARM *sparm)
{
  /* This function is called after training and allows final touches to
     the model sm. But primarily it allows computing and printing any
     kind of statistic (e.g. training error) you might want. */
  long i,correct=0;
  LABEL ybar;
  EXAMPLE *ex;
  long n;
  double cumloss=0;
  /* double valy,valybar; */
  /* DOC *extemp; */
  /* SVECTOR *fy,*fybar; */

  n=sample.n;
  ex=sample.examples;

  if(struct_verbosity >= 1)
    printf("Classifying training examples");

  /* classify the training examples */
  for(i=0; i<n;i++) { 
    /* find best parse */
    ybar=classify_struct_example(ex[i].x,sm,sparm);
    /* assert(ybar.parse); */

    /* compute scores for sanity check */ 
    /*
    fy=psi(ex[i].x,ex[i].y,sm,sparm);
    extemp=create_example(-1,-1,-1,1,fy);
    valy=classify_example(sm->svm_model,extemp);
    free_example(extemp,1);
    fybar=psi(ex[i].x,ybar,sm,sparm);
    extemp=create_example(-1,-1,-1,1,fybar);
    valybar=classify_example(sm->svm_model,extemp);
    free_example(extemp,1);
    */

    if(tree_eq(ybar.parse,ex[i].y.parse)) {
      correct++;
      printf("+");fflush(stdout);
      /* assert(valy >= valybar); */
    }
    else {
      printf("-");fflush(stdout);
      cumloss+=loss(ex[i].y,ybar,sparm);
      /* assert(valy <= valybar); */
    }

    /*
    fprintf(stdout,"Prob = %g ", ybar.prob); fflush(stdout);
    fprintf(stdout,"Xi = %g = %g ", xi, sm->w[sm->sizePsi+1+i]/sqrt(2*sparm->C)); fflush(stdout);
    if(sparm->parent_annotation) {
      tree t = remove_parent_annotation(ybar.parse, sm->si);
      free_tree(ybar.parse);
      ybar.parse=t;
    }
    write_prolog_tree(stdout, ybar.parse, sm->si);
    */

    free_label(ybar);
  }
  printf("done\n");
  if(struct_verbosity>=1) {
    printf("Number of correct parses: %i out of %i (%.2f%%)\n",
	   (int)correct,(int)n, 100.0*correct/n);
    printf("Average training loss: %.4f\n", cumloss/n);
  }
}

void        print_struct_testing_stats(SAMPLE sample, STRUCTMODEL *sm,
				       STRUCT_LEARN_PARM *sparm, 
				       STRUCT_TEST_STATS *teststats)
{
  /* This function is called after making all test predictions in
     svm_struct_classify and allows computing and printing any kind of
     evaluation (e.g. precision/recall) you might want. You can use
     the function eval_prediction to accumulate the necessary
     statistics for each prediction. */
  if(verbosity >= 1) {
    printf("Fraction of parsed sentences: %.2f%% (%ld/%d)\n",
	   100.0*teststats->parsed_sentences/sample.n,
	   teststats->parsed_sentences,sample.n);
    printf("Labelled bracket precision: %.2f%% (%ld/%ld)\n", 
	    100.0*teststats->common_bracket_sum/teststats->parse_bracket_sum,
	    teststats->common_bracket_sum, teststats->parse_bracket_sum);
    printf("Labelled bracket recall: %.2f%% (%ld/%ld)\n", 
	    100.0*teststats->common_bracket_sum/teststats->test_bracket_sum,
	    teststats->common_bracket_sum, teststats->test_bracket_sum);
    printf("Labelled bracket F1: %.2f%%\n", 
	   100.0*fone(teststats->common_bracket_sum,
		      teststats->test_bracket_sum, 
		      teststats->parse_bracket_sum));
   }
}

void        eval_prediction(long exnum, EXAMPLE ex, LABEL ypred, 
			    STRUCTMODEL *sm, STRUCT_LEARN_PARM *sparm, 
			    STRUCT_TEST_STATS *teststats)
{
  /* This function allows you to accumulate statistics for how well the
     prediction matches the labeled example. It is called from
     svm_struct_classify. See also the function
     print_struct_testing_stats. */
  if(exnum == 0) { /* this is the first time the function is
		      called. So initialize the teststats */
    teststats->parsed_sentences=0;
    teststats->test_bracket_sum=0;
    teststats->parse_bracket_sum=0;
    teststats->common_bracket_sum=0;
  }

  if(!empty_label(ypred)) {
    teststats->parsed_sentences++;
    tree t = remove_parent_annotation(ex.y.parse, ex.y.si); 
    struct ledges *test_ledges=tree_ledges(t);
    free_tree(t);
    t = remove_parent_annotation(ypred.parse, ypred.si); 
    struct ledges *parse_ledges=tree_ledges(t);
    free_tree(t);
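    /* the counts are summed over the whole test set, so
       print_struct_testing_stats() reports corpus-level (micro-averaged)
       labelled bracket precision, recall, and F1 from these sums */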
    teststats->common_bracket_sum+=common_ledge_count(test_ledges, parse_ledges);
    teststats->test_bracket_sum+=test_ledges->n;
    teststats->parse_bracket_sum+=parse_ledges->n;
    free_ledges(test_ledges);
    free_ledges(parse_ledges);
  }
}

void        write_struct_model(char *file, STRUCTMODEL *sm, 
			       STRUCT_LEARN_PARM *sparm)
{
  vihashlit hit;
  vindex    vi;
  long      weightid,i;
  FILE      *fp;
  char      file_svm[500],file_grammar[500];

  /* First write the SVM model */
  strcpy(file_svm,file);
  strcat(file_svm,".svm");
  write_model(file_svm,sm->svm_model);

  /* Then write the grammar */
  strcpy(file_grammar,file);
  strcat(file_grammar,".grammar");
  if ((fp = fopen (file_grammar, "w")) == NULL)
    { perror (file_grammar); exit (1); }
  fprintf(fp,"%ld # number of attributes in weight vector\n",sm->sizePsi); 
  fprintf(fp,"%d # loss function\n",sparm->loss_function);
  fprintf(fp,"%d # use parent annotation\n",sparm->parent_annotation);
  fprintf(fp,"%d # maximum sentence length\n",sparm->maxsentlen);
  fprintf(fp,"%d # use border features\n",sparm->feat_borders);
  fprintf(fp,"%d # use span length as feature\n",
	  sparm->feat_parent_span_length);
  fprintf(fp,"%d # use children span length as features\n",
	  sparm->feat_children_span_length);
  fprintf(fp,"%d # use difference in children span lengths as feature\n",
	  sparm->feat_diff_children_length);
  for (hit = vihashlit_init(sm->weightid_ht); vihashlit_ok(hit); hit = vihashlit_next(hit)) {
    vi=hit.key;
    weightid=vihashl_ref(sm->weightid_ht,vi);
    assert(vi->n);
    fprintf(fp,"%ld %s " REWRITES "",weightid,si_index_string(sm->si, vi->e[0]));
    for(i=1;i<vi->n;i++) {
      fprintf(fp," %s",si_index_string(sm->si, vi->e[i]));
    }
    fprintf(fp,"\n");
  }
  fclose(fp);
}
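
/* Sketch of the .grammar file written above, with made-up ids and rules;
   the actual rule separator is whatever the REWRITES macro expands to,
   which is not visible in this listing:

     12345 # number of attributes in weight vector
     0 # loss function
     ...six more "value # comment" header lines...
     17 S REWRITES NP VP
     42 NP REWRITES DT NN
     43 VB REWRITES saw

   Each rule line pairs a feature/weight id with a binary or unary rule;
   read_struct_model() below recovers the mapping with read_grammar(). */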

STRUCTMODEL read_struct_model(char *file, STRUCT_LEARN_PARM *sparm)
{
  /* Reads structural model sm from file file. This function is used
     only in the prediction module, not in the learning module. */
  STRUCTMODEL sm;
  vindex    vi;
  long      i,id;
  FILE      *fp;
  char      file_svm[500],file_grammar[500];
  sihashbrsit	bhit;
  sihashursit	uhit;

  /* First read the SVM model */
  strcpy(file_svm,file);
  strcat(file_svm,".svm");
  sm.svm_model=read_model(file_svm);

  /* Then read the grammar */
  strcpy(file_grammar,file);
  strcat(file_grammar,".grammar");
  if ((fp = fopen (file_grammar, "r")) == NULL)
    { perror (file_grammar); exit (1); }
  fscanf(fp,"%ld%*[^\n]\n",&sm.sizePsi); 
  fscanf(fp,"%d%*[^\n]\n", &sparm->loss_function);  
  fscanf(fp,"%d%*[^\n]\n", &sparm->parent_annotation);  
  fscanf(fp,"%d%*[^\n]\n", &sparm->maxsentlen);  
  fscanf(fp,"%d%*[^\n]\n", &sparm->feat_borders);  
  fscanf(fp,"%d%*[^\n]\n", &sparm->feat_parent_span_length);  
  fscanf(fp,"%d%*[^\n]\n", &sparm->feat_children_span_length);  
  fscanf(fp,"%d%*[^\n]\n", &sparm->feat_diff_children_length);  
  sm.si=make_si(1024);
  sm.grammar=read_grammar(fp,sm.si);
  sm.grammar.idMax=sm.sizePsi;
  sm.sparm=sparm;
  fclose(fp);

  /* Reconstruct the hashtable for mapping rules to feature numbers */
  sm.weightid_ht = make_vihashl(1000);
  for (bhit=sihashbrsit_init(sm.grammar.brs); sihashbrsit_ok(bhit); bhit=sihashbrsit_next(bhit)) 
    for (i=0; i<bhit.value.n; i++) {
      vi=make_vindex(4); 
      vi->n=3;
      vi->e[0]=bhit.value.e[i]->parent;
      vi->e[1]=bhit.value.e[i]->left;
      vi->e[2]=bhit.value.e[i]->right;
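      /* the weight id written as the first field of each rule line ends up
	 in the rule's prob field when read_grammar() parses it back; adding
	 0.1 before truncating to long guards against the double landing
	 slightly below the intended integer value */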
      id=(long)(bhit.value.e[i]->prob+0.1);
      vihashl_set(sm.weightid_ht,vi,id);	
      bhit.value.e[i]->weightid=id;
      vindex_free(vi);
    }
  for (uhit=sihashursit_init(sm.grammar.urs); sihashursit_ok(uhit); uhit=sihashursit_next(uhit)) 
    for (i=0; i<uhit.value.n; i++) {
      vi=make_vindex(4); 
      vi->n=2;
      vi->e[0]=uhit.value.e[i]->parent;
      vi->e[1]=uhit.value.e[i]->child;
      id=(long)(uhit.value.e[i]->prob+0.1);
      vihashl_set(sm.weightid_ht,vi,id);	
      uhit.value.e[i]->weightid=id;
      vindex_free(vi);
    }

  sparm->si=sm.si;
  return(sm);
}
