⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 svm_struct_api.c

📁 SVMcfg: Learns a weighted context free grammar from examples. Training examples (e.g. for natural la
💻 C
📖 第 1 页 / 共 4 页
字号:
		 FLOAT prob, int add_mode, si_t si)
{
  if(add_mode == NO_ADD_NEW_SYMBOLS) 
    if(!sihashcc_ref(chart_entry,label))
      return(0);

  chart_cell *cp = sihashcc_valuep(chart_entry, label);
  chart_cell cc = *cp;

  if (cc == NULL) {  /* construct a new chart entry */
    *cp = make_chart_cell(label, left, right, prob, 1, 0, CKY_MAX);
    return(1);
  }  
  
  /* we're dealing with an old chart entry */
  
  /* avoid cyclic chains of unary rules */
  if((!right)                /* unary rule is applied */
     && is_urule_cycle(left,label,si)) { 
    /* printf("found cycle with %ld\n",label); fflush(stdout); */
    return(0); 
  }

  assert(cc->tree.label==label);
  assert(cc->tree2.label==label);
  assert((cc->prob2 <= cc->prob) || (!cc->present2));

  /* check for duplicate rule application and side effects of unary rules */
  if((cc->tree.left == left) && (cc->tree.right == right)) {
    if(cc->prob < prob) { /* update prob of first */
      cc->prob=prob;
      return(1);
    }
    else 
      return(0);
  }

  if ((cc->prob2 >= prob) && (cc->present2))
    return(0);      /* second best chart cell entry is better than current */

  if (cc->prob > prob) { /* better than second best, but worse than best */
    cc->present2=1;
    cc->tree2.left = left;
    cc->tree2.right = right;
    cc->prob2 = prob;
    cc->nalt2 = 1;
    return(1);
    /* WARNING: Ties are not broken at random for second best */
  }

  if (cc->prob < prob) {  /* new entry is better than the best old one */
    /* move former best down to second best */
    cc->present2=1;
    cc->tree2.left = cc->tree.left;
    cc->tree2.right = cc->tree.right;
    cc->prob2 = cc->prob;
    cc->nalt2 = 1;
    /* make new one the best */
    cc->tree.left = left;
    cc->tree.right = right;
    cc->prob = prob;
    cc->nalt = 1;
    return(1);
  }

  /* old best and new entry have same probability */

  assert(cc->nalt<UINT_MAX);

  if (rand() > RAND_MAX/(++(cc->nalt))) {
    /* make new one the second best */
    cc->present2=1;
    cc->tree2.left = left;
    cc->tree2.right = right;
    cc->prob2 = prob;
    cc->nalt2 = 1;
    return(1);
  }

  /* make new one the best */
  /* move former best down to second best */
  cc->present2=1;
  cc->tree2.left = cc->tree.left;
  cc->tree2.right = cc->tree.right;
  cc->prob2 = cc->prob;
  cc->nalt2 = 1;
  /* make new one the best */
  cc->tree.left = left;
  cc->tree.right = right;
  cc->prob = prob;
  return(1);
}

/*
 * Offer one derivation of `label` (children `left`/`right`, score `prob`)
 * to the stratum of chart cell `cc` identified by (correct, size).
 *
 * Each stratum keeps only its highest-scoring derivation.  The cell-level
 * best (cc->tree.left/right, cc->prob) is overwritten whenever a new
 * stratum is created, and otherwise only when the derivation beats both
 * its stratum and the current cell-level score.
 *
 * A unary application (right is NULL) is rejected when is_urule_cycle()
 * reports a cycle.
 *
 * Returns 1 if the cell was changed, 0 otherwise.
 */
int
add_edge_stratified(chart_cell cc, si_index label, bintree left, 
		    bintree right, FLOAT prob, int correct, int size, si_t si)
{
  long stratum_key;
  strat_cell *slot, entry;

  assert(label == cc->tree.label);

  /* refuse unary rule applications that would form a cyclic chain */
  if (!right && is_urule_cycle(left, label, si))
    return 0;

  /* NOTE(review): distinct (correct, size) pairs share a key once
     size reaches 1000 -- presumably never happens in practice; confirm */
  stratum_key = correct*1000 + size;
  slot = strat_hash_valuep(cc->strat_hash, stratum_key);
  entry = *slot;

  if (entry != NULL) {
    /* stratum already exists: only a strictly better score replaces it */
    if (!(entry->prob < prob))
      return 0;

    entry->tree.left = left;
    entry->tree.right = right;
    entry->prob = prob;

    if (cc->prob < prob) {   /* also the best over all strata so far */
      cc->tree.left = left;
      cc->tree.right = right;
      cc->prob = prob;
    }
    return 1;
  }

  /* first derivation seen for this stratum: allocate and fill it */
  *slot = MALLOC(sizeof(struct strat_cell));
  entry = *slot;
  entry->correct = correct;
  entry->size = size;
  entry->tree.label = label;
  entry->tree.left = left;
  entry->tree.right = right;
  entry->prob = prob;

  /* push onto the front of the cell's stratum list */
  entry->next = cc->strat_list;
  cc->strat_list = entry;

  /* the cell-level summary is set unconditionally for a new stratum */
  cc->tree.left = left;
  cc->tree.right = right;
  cc->prob = prob;

  return 1;
}

/*
 * Insert into `chart_entry` the edge(s) for category `label` produced by
 * one rule application with rule score `prob`.
 *
 *   cr != NULL              binary rule over child cells cl and cr
 *   cr == NULL, cl != NULL  unary rule over child cell cl
 *   both NULL               terminal (inserted with score 0.0)
 *
 * mode == CKY_MAX:        combinations of the best / second-best child
 *                         derivations are passed to add_edge_zeroone().
 * mode == CKY_STRATIFIED: every stratum of the children (every pair of
 *                         strata for a binary rule) is passed to
 *                         add_edge_stratified(); the parent cell is created
 *                         first if missing, unless add_mode is
 *                         NO_ADD_NEW_SYMBOLS, in which case a missing cell
 *                         makes the call return 0.
 * Any other mode does nothing.
 *
 * Returns the sum of the values returned by the add_edge_* calls.
 */
int
add_edges(sihashcc chart_entry, si_index label, chart_cell cl, chart_cell cr,
	  double prob, int correct, int mode, int add_mode, si_t si)
{
  int n_changed = 0;

  if (mode == CKY_MAX) {
    if (cr) {
      /* binary rule: best/best, best/second and second/best */
      n_changed += add_edge_zeroone(chart_entry, label,
                                    &cl->tree, &cr->tree,
                                    cl->prob + cr->prob + prob, add_mode, si);
      if (cr->present2) {
        n_changed += add_edge_zeroone(chart_entry, label,
                                      &cl->tree, &cr->tree2,
                                      cl->prob + cr->prob2 + prob,
                                      add_mode, si);
      }
      if (cl->present2) {
        n_changed += add_edge_zeroone(chart_entry, label,
                                      &cl->tree2, &cr->tree,
                                      cl->prob2 + cr->prob + prob,
                                      add_mode, si);
      }
    }
    else if (cl) {
      /* unary rule: best and, if present, second-best child */
      n_changed += add_edge_zeroone(chart_entry, label,
                                    &cl->tree, NULL,
                                    cl->prob + prob, add_mode, si);
      if (cl->present2) {
        n_changed += add_edge_zeroone(chart_entry, label,
                                      &cl->tree2, NULL,
                                      cl->prob2 + prob, add_mode, si);
      }
    }
    else {
      /* terminal */
      n_changed += add_edge_zeroone(chart_entry, label, NULL, NULL, 0.0,
                                    add_mode, si);
    }
    return n_changed;
  }

  if (mode != CKY_STRATIFIED)
    return n_changed;

  /* in NO_ADD_NEW_SYMBOLS mode only already-present categories are updated */
  if (add_mode == NO_ADD_NEW_SYMBOLS && !sihashcc_ref(chart_entry, label))
    return 0;

  strat_cell lstrat, rstrat;
  chart_cell *slot = sihashcc_valuep(chart_entry, label);
  if (*slot == NULL)   /* construct a new chart entry */
    *slot = make_chart_cell(label, NULL, NULL, 0.0, 1, 0, CKY_STRATIFIED);
  chart_cell parent = *slot;

  if (cr) {
    /* binary rule: every pair of left and right strata */
    for (lstrat = cl->strat_list; lstrat; lstrat = lstrat->next) {
      for (rstrat = cr->strat_list; rstrat; rstrat = rstrat->next) {
        n_changed += add_edge_stratified(parent, label,
                                         &lstrat->tree, &rstrat->tree,
                                         lstrat->prob + rstrat->prob + prob,
                                         lstrat->correct + rstrat->correct + correct,
                                         lstrat->size + rstrat->size + is_not_binarized(label, si),
                                         si);
      }
    }
  }
  else if (cl) {
    /* unary rule: every stratum of the child */
    for (lstrat = cl->strat_list; lstrat; lstrat = lstrat->next) {
      n_changed += add_edge_stratified(parent, label, &lstrat->tree, NULL,
                                       lstrat->prob + prob,
                                       lstrat->correct + correct,
                                       lstrat->size + is_not_binarized(label, si),
                                       si);
    }
  }
  else {
    /* terminal */
    n_changed += add_edge_stratified(parent, label, NULL, NULL, 0.0, 0, 0, si);
  }

  return n_changed;
}

/*
 * Run one pass of unary rules over a single chart entry.
 *
 * For every child category in sm->grammar.urs that is present in
 * `chart_entry`, each rule listed for that child is applied through
 * add_edges() with the rule's parent as label and the score returned by
 * urule_value().  When `l` is non-NULL, the `correct` argument handed to
 * add_edges() is the result of member_ledges() for the parent label with
 * its parent annotation removed; when `l` is NULL it stays 0.
 *
 * Returns the accumulated return values of add_edges(); callers use a
 * non-zero result as "something changed, run another pass".
 */
int
apply_unary(sihashcc chart_entry, struct ledges *l, 
	    STRUCTMODEL *sm, si_t si, int mode, int add_mode, 
	    int lpos, int rpos, int start, int end)
{
  sihashursit   it;
  size_t        r;
  int           n_changed = 0;
  int           is_correct = 0;

  it = sihashursit_init(sm->grammar.urs);
  while (sihashursit_ok(it)) {
    /* the iterator key is the child category of the rules in it.value */
    chart_cell child = sihashcc_ref(chart_entry, it.key);

    if (child) {   /* this category occurs in the cell */
      for (r = 0; r < it.value.n; r++) {
        /* is the parent label one of the correct ledges? */
        if (l) {
          is_correct = member_ledges(l,
                                     remove_parent_from_label(it.value.e[r]->parent, si));
        }

        n_changed += add_edges(chart_entry, it.value.e[r]->parent,
                               child, NULL,
                               urule_value(it.value.e[r], &child->tree, sm,
                                           lpos, rpos, start, end),
                               is_correct, mode, add_mode, si);
      }
    }

    it = sihashursit_next(it);
  }

  return n_changed;
}

/*
 * Combine two adjacent chart entries with every applicable binary rule.
 *
 * For each left category in sm->grammar.brs that is present in
 * `left_entry`, and each rule for it whose right category is present in
 * `right_entry`, the rule's parent is added to `parent_entry` through
 * add_edges() (always with ADD_NEW_SYMBOLS) using the score returned by
 * brule_value().  When `l` is non-NULL, the `correct` argument is the
 * result of member_ledges() for the parent label with its parent
 * annotation removed; when `l` is NULL it stays 0.
 */
static void
apply_binary(sihashcc parent_entry, struct ledges *l,
	     sihashcc left_entry, sihashcc right_entry,
	     STRUCTMODEL *sm, si_t si, int mode, 
	     int lpos, int mpos, int rpos, int start, int end)
{
  sihashbrsit   it;
  size_t        r;
  int           is_correct = 0;

  it = sihashbrsit_init(sm->grammar.brs);
  while (sihashbrsit_ok(it)) {
    /* the iterator key is the left category of the rules in it.value */
    chart_cell lcell = sihashcc_ref(left_entry, it.key);

    if (lcell) {   /* left category occurs in the left cell */
      for (r = 0; r < it.value.n; r++) {
        chart_cell rcell = sihashcc_ref(right_entry, it.value.e[r]->right);

        if (!rcell)
          continue;   /* right category missing: rule not applicable */

        /* is the parent label one of the correct ledges? */
        if (l) {
          is_correct = member_ledges(l,
                                     remove_parent_from_label(it.value.e[r]->parent, si));
        }

        add_edges(parent_entry, it.value.e[r]->parent,
                  lcell, rcell,
                  brule_value(it.value.e[r], &lcell->tree, &rcell->tree, sm,
                              lpos, mpos, rpos, start, end),
                  is_correct, mode, ADD_NEW_SYMBOLS, si);
      }
    }

    it = sihashbrsit_next(it);
  }
}

/*
 * Build the CKY chart for the terminal sequence `terms`.
 *
 * Width-1 spans are seeded with their terminal; wider spans are filled by
 * apply_binary() over every split point.  Each span is then closed under
 * unary rules: first up to MAXUNARYDEPTH passes that may introduce new
 * categories, then passes restricted to existing categories until
 * apply_unary() reports no change.
 *
 * When `l` is non-NULL, the ledges returned by filter_ledges() for the
 * current span are passed to the rule-application routines and released
 * afterwards.
 *
 * Returns the chart created by chart_make().
 */
chart
cky(struct vindex terms, struct ledges *l, STRUCTMODEL *sm, si_t si, int mode)
{
  int left, right, mid, depth;
  const int n = (int) terms.n;
  struct ledges *span_ledges = NULL;
  chart c;

  c = chart_make(terms.n);

  /* width-1 spans: insert lexical items */
  for (left = 0; left < n; left++) {
    right = left + 1;
    si_index terminal = terms.e[left];
    sihashcc cell = make_sihashcc(NLABELS);
    CHART_ENTRY(c, left, right) = cell;
    add_edges(cell, terminal, NULL, NULL, 0.0, 0, mode, ADD_NEW_SYMBOLS, si);

    /* find the correct labels for this span */
    if (l)
      span_ledges = filter_ledges(l, left, right);

    /* close under unary rules, bounded while new symbols may appear */
    depth = 0;
    while ((depth < MAXUNARYDEPTH)
           && apply_unary(cell, span_ledges, sm, si, mode, ADD_NEW_SYMBOLS,
                          left, left + 1, 0, n))
      depth++;

    /* keep propagating until no existing value changes any more */
    while (apply_unary(cell, span_ledges, sm, si, mode, NO_ADD_NEW_SYMBOLS,
                       left, left + 1, 0, n))
      ;

    if (span_ledges)
      free_ledges(span_ledges);
  }

  /* wider spans, ordered by increasing right end and decreasing left end */
  for (right = 2; right <= n; right++) {
    for (left = right - 2; left >= 0; left--) {
      sihashcc cell = make_sihashcc(CHART_CELLS);
      CHART_ENTRY(c, left, right) = cell;

      /* find the correct labels for this span */
      if (l)
        span_ledges = filter_ledges(l, left, right);

      for (mid = left + 1; mid < right; mid++) {
        apply_binary(cell, span_ledges, CHART_ENTRY(c, left, mid),
                     CHART_ENTRY(c, mid, right), sm, si, mode,
                     left, mid, right, 0, n);
      }

      depth = 0;
      while ((depth < MAXUNARYDEPTH)
             && apply_unary(cell, span_ledges, sm, si, mode, ADD_NEW_SYMBOLS,
                            left, right, 0, n))
        depth++;

      while (apply_unary(cell, span_ledges, sm, si, mode, NO_ADD_NEW_SYMBOLS,
                         left, right, 0, n))
        ;

      if (span_ledges)
        free_ledges(span_ledges);
    }
  }

  return c;
}

/*
 * Parse `terms` in CKY_STRATIFIED mode and select the highest-scoring
 * stratum of the ROOT cell.
 *
 * If `correct_tree` is non-NULL, its ledges (taken after removing parent
 * annotation) are passed to cky(); otherwise NULL is passed.
 *
 * When a ROOT cell exists for the full span, its tree and prob are
 * overwritten with those of the stratum having the largest prob (the
 * earliest one in list order on ties).
 *
 * Returns the chart produced by cky().
 */
chart cky_wrap(struct vindex terms, tree correct_tree, STRUCTMODEL *sm, 
	       si_t si)
{
  struct ledges *gold = NULL;
  chart_cell top;
  strat_cell s;
  double best = 0;
  int have_best = 0;

  if (correct_tree) {
    tree plain = remove_parent_annotation(correct_tree, si); /* just in case */
    gold = tree_ledges(plain);                              /* correct ledges */
    free_tree(plain);
  }

  chart c = cky(terms, gold, sm, si, CKY_STRATIFIED);

  top = sihashcc_ref(CHART_ENTRY(c, 0, terms.n),
                     si_string_index(si, ROOT));
  if (top) {
    for (s = top->strat_list; s; s = s->next) {
      /* the first stratum is always taken; later ones must be strictly better */
      if (have_best && !(s->prob > best))
        continue;
      best = s->prob;
      top->tree = s->tree;
      top->prob = s->prob;
      have_best = 1;
    }
  }

  if (gold)
    free_ledges(gold);

  return c;
}

/*
 * Parse `terms` in CKY_STRATIFIED mode and select the ROOT stratum with
 * the largest violation relative to the labelled example (x, y).
 *
 * With chi = score(y) - score(stratum) and
 *   slack_norm 1: loss = 100 * (1 - F1)
 *   slack_norm 2: loss = sqrt(100 * (1 - F1))
 * the violation is
 *   SLACK_RESCALING:  (1 - chi) * loss
 *   MARGIN_RESCALING: loss - chi
 * Only loss_function == 1 with these loss types and norms is supported;
 * any other combination trips assert(0).
 *
 * The ROOT cell's tree/prob are overwritten with the chosen stratum and
 * its loss is stored in *maxloss.  *maxloss stays 0.0 and the ROOT cell is
 * untouched if no stratum has a violation above -1.0.
 *
 * Returns the chart produced by cky().
 */
chart cky_maxloss(struct vindex terms, PATTERN x, LABEL y,  
		  STRUCTMODEL *sm, STRUCT_LEARN_PARM *sparm, 
		  double *maxloss)
{
  int best_correct = 0, best_size = 0;
  chart_cell top;
  double best_viol;
  struct ledges *gold;
  tree plain;
  FILE *tracefp = NULL;
  strat_cell s;

  /* score of the labelled parse under the current model */
  SVECTOR *features = psi(x, y, sm, sparm);
  DOC *example = create_example(-1, -1, -1, 1, features);
  double yprob = classify_example(sm->svm_model, example);
  free_example(example, 1);

  plain = remove_parent_annotation(y.parse, sm->si); /* just in case */
  gold = tree_ledges(plain);                         /* correct ledges */
  free_tree(plain);

  chart c = cky(terms, gold, sm, sm->si, CKY_STRATIFIED);

  top = sihashcc_ref(CHART_ENTRY(c, 0, terms.n),
                     si_string_index(sm->si, ROOT));
  (*maxloss) = 0.0;
  best_viol = -1.0;

  if (top) {
    const int slack = (sparm->loss_type == SLACK_RESCALING);
    const int margin = (sparm->loss_type == MARGIN_RESCALING);
    const int supported = (sparm->loss_function == 1)
                          && (slack || margin)
                          && ((sparm->slack_norm == 1)
                              || (sparm->slack_norm == 2));

    for (s = top->strat_list; s; s = s->next) {
      double prob = s->prob;
      int correct = s->correct;
      int size = s->size;
      double chi = (yprob - prob);
      double loss = 0;
      double viol = 0;

      if (supported) {
        loss = 100.0*(1.0 - fone(correct, gold->n, size));
        if (sparm->slack_norm == 2)
          loss = sqrt(loss);
        viol = slack ? (1.0 - chi)*loss : loss - chi;
      }
      else {
        assert(0);
      }

      if (viol > best_viol) {
        best_viol = viol;
        (*maxloss) = loss;
        best_correct = correct;
        best_size = size;
        top->tree = s->tree;
        top->prob = prob;
      }
    }
  }

  /* tracefp is never opened in this function, so this trace never fires */
  if (tracefp)
    fprintf(tracefp,"return: maxviol %g with loss %g at %d/%d of %ld\n",best_viol,(*maxloss),best_correct,best_size,(long)gold->n); 

  free_ledges(gold);

  return c;
}

/*
 * F1 score (harmonic mean of precision and recall) of a predicted set of
 * ledges against the true set.
 *
 *   correct       number of predicted ledges that are also true ledges
 *   trueledgenum  number of ledges in the true parse
 *   predledgenum  number of ledges in the predicted parse
 *
 * Returns 0.0 when either set is empty or when nothing is correct.
 * Without the `correct == 0` guard, precision and recall would both be 0
 * and the final expression would evaluate 0.0/0.0, yielding NaN.
 */
double fone(int correct, int trueledgenum, int predledgenum) 
{
  if((trueledgenum == 0) || (predledgenum == 0) || (correct == 0))
    return(0.0);
  double precision=(double)correct/predledgenum;
  double recall=(double)correct/trueledgenum;
  return(2.0*precision*recall/(precision+recall));
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -