/***********************************************************************/
/*                                                                     */
/*   svm_struct_api.c       (modified for PCFG parsing)                */
/*                                                                     */
/*   Definition of the API for implementing SVM learning of            */
/*   structures (e.g. parsing, multi-label classification, HMM)        */ 
/*                                                                     */
/*   Based on CKY parser by Mark Johnson                               */
/*                                                                     */
/*   Author: Thorsten Joachims                                         */
/*   Date: 12.07.04                                                    */
/*                                                                     */
/*   Copyright (c) 2004  Thorsten Joachims - All rights reserved       */
/*                                                                     */
/*   This software is available for non-commercial use only. It must   */
/*   not be modified and distributed without prior permission of the   */
/*   author. The author is not responsible for implications from the   */
/*   use of this software.                                             */
/*                                                                     */
/***********************************************************************/

#include <stdio.h>
#include <string.h>
#include "svm_struct/svm_struct_common.h"
#include "svm_struct_api.h"
#include "svm_struct/svm_struct_learn.h"

#include <assert.h>
#include <math.h>
#include <stdlib.h>
#include <time.h>
#include <sys/time.h>
#include <unistd.h>
#include "local-trees.h"	/* local tree count format */
#include "mmm.h"		/* memory debugger */
#include "hash-string.h" 	/* hash tables and string-index tables */
#include "hash-templates.h"
#include "vindex.h"
/* #include "ftree.h" use a tree that can have a feature vector at each node */
#include "tree.h"
#include "ledge.h"
#include "grammar.h"
#include "hash.h"
#include "hash-templates.h"

#define RAND_SEED	time(0)

#define CKY_MAX         0
#define CKY_STRATIFIED  1

#define NO_ADD_NEW_SYMBOLS 0
#define ADD_NEW_SYMBOLS 1

#define CKY_INITSTRAT   200  /* initial size of hash table for strat cells */

#define MAXUNARYDEPTH   2    /* maximum number of unary rules to apply
				in chain */

#define CHART_SIZE(n)			((n)*((n)+1)/2)
#define CHART_ENTRY(chart, i, j)	((chart)[(j)*((j)-1)/2+(i)])
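
/* The chart is stored as a packed lower triangle: the cells for span
 * (i,j), 0 <= i < j <= n, live at index j*(j-1)/2 + i, so a sentence
 * of n terminals needs CHART_SIZE(n) = n*(n+1)/2 entries.  For n = 3:
 *   (0,1)->0  (0,2)->1  (1,2)->2  (0,3)->3  (1,3)->4  (2,3)->5
 */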

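/* A strat_cell is one stratum of a chart cell in loss-augmented
 * (CKY_STRATIFIED) parsing: it appears to record, for the best
 * subtree in this stratum, how many constituents it shares with the
 * correct tree (correct) and how many constituents it has in total
 * (size), so that an F1-type loss can be computed at the root (see
 * fone() below). */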
typedef struct strat_cell {
  unsigned int   correct;
  unsigned int   size;
  struct bintree tree;
  FLOAT		 prob;
  struct strat_cell     *next;
} *strat_cell;

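/* HASH_HEADER/HASH_CODE are macro templates from hash-templates.h:
 * they declare and define a hash table type mapping the key type to
 * the value type; the remaining arguments appear to supply the hash,
 * inequality, copy, and free operations for keys and values
 * (NO_OP/NULL meaning none). */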
HASH_HEADER(strat_hash, long, strat_cell)
HASH_CODE(strat_hash, long, strat_cell, IDENTITY, NEQ, IDENTITY, NO_OP, NULL, NO_OP)

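/* A chart_cell holds the best binary subtree found for one
 * (label, span) pair, together with its score (prob) and what looks
 * like a count of alternative derivations (nalt); the *2 fields
 * appear to hold a runner-up parse, with present2 flagging whether
 * one was found.  In stratified mode the cell additionally keeps its
 * strata in a list and a hash table. */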
typedef struct chart_cell {
  struct bintree tree;
  FLOAT		 prob;
  unsigned int   nalt;
  int            present2;
  struct bintree tree2;
  FLOAT		 prob2;
  unsigned int   nalt2;
  strat_cell     strat_list;
  strat_hash     strat_hash;
} *chart_cell;


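/* make_chart_cell() allocates and initializes a chart cell for the
 * given label and child trees, storing the score prob and the
 * alternative count nalt; the runner-up fields are cleared.  In
 * CKY_STRATIFIED mode it also allocates the per-cell hash table of
 * strata. */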
chart_cell
make_chart_cell(si_index label, bintree left, bintree right,
		FLOAT prob, unsigned int nalt, int correct, int mode)
{
  chart_cell c = MALLOC(sizeof(struct chart_cell));
  c->tree.label = label;
  c->tree.left = left;
  c->tree.right = right;
  c->prob = prob;
  c->nalt = nalt;
  c->present2 = 0;
  c->tree2.label = label;
  c->tree2.left = NULL;
  c->tree2.right = NULL;
  c->prob2 = 0;
  c->nalt2 = nalt;

  c->strat_list = NULL;
  c->strat_hash = NULL;
  if(mode == CKY_STRATIFIED) {
    c->strat_hash = make_strat_hash(CKY_INITSTRAT);
  }
  return c;
}


/* chart_cell_free() frees the memory associated with this chart cell.
 * A chart cell has a tree associated with it, but since every tree
 * node is associated with exactly one chart cell, only free the
 * top-most node of each tree.
 */
void chart_cell_free(chart_cell c)
{
  strat_cell s,snext;
  /* printf("cprt=%ld %ld %ld %ld %ld %ld\n",c,c->strat_index,c->strat_correct,c->strat_size,c->strat_tree,c->strat_prob);fflush(stdout); */
  if(c->strat_hash) free_strat_hash(c->strat_hash);
  s=c->strat_list;
  while(s) {
    snext=s->next;
    FREE(s);
    s=snext;
  }
  FREE(c);
}

HASH_HEADER(sihashcc, si_index, chart_cell)
HASH_CODE(sihashcc, si_index, chart_cell, IDENTITY, NEQ, IDENTITY, NO_OP, 
	  NULL, chart_cell_free)

typedef sihashcc *chart;
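/* A chart is a packed triangular array of label->cell hash tables,
 * one per span; the cell for label l over span (i,j) is fetched as
 * in this sketch (the same accessors the root lookup below uses):
 *
 *   sihashcc   cells = CHART_ENTRY(c, i, j);
 *   chart_cell cell  = sihashcc_ref(cells, l);
 */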

void count_local_trees(const tree tree, vihashl localtree_ht);
	/* adds local tree counts from local trees in tree to localtree_ht */

void write_local_trees(FILE *fh, const vihashl localtree_ht, si_t si);
	/* writes local tree hash table to fh */

int tree_eq(const tree t1, const tree t2);

grammar create_grammar(vihashl localtree_ht, vihashl weightid, si_t si);

void chart_free(chart c, size_t n);

SVECTOR *collect_phi(bintree parse, STRUCTMODEL *sm, 
		     size_t lpos, size_t *rpos, size_t start, size_t end);
SVECTOR *phi_urule(long weightid, bintree child, STRUCTMODEL *sm, 
		   int lpos, int rpos, int start, int end);
SVECTOR *phi_brule(long weightid, bintree left, bintree right, STRUCTMODEL *sm,
		   int lpos, int mpos, int rpos, int start, int end);
int add_feature(WORD *feat, int pos, long fnum, long weight);
int encode_number(WORD *feat,int pos,long basefnum,long number,double weight,
		  long a, long b, long c, long d, long e, long f);
double urule_value(urule rule, bintree child, STRUCTMODEL *sm, 
		   int lpos, int rpos, int start, int end);
double brule_value(brule rule, bintree left, bintree right, STRUCTMODEL *sm, 
		   int lpos, int mpos, int rpos, int start, int end);

chart cky(struct vindex terms, struct ledges *l, STRUCTMODEL *sm, si_t si, 
	  int mode);
chart cky_wrap(struct vindex terms, tree correct_tree, STRUCTMODEL *sm, 
	       si_t si);
chart cky_maxloss(struct vindex terms, PATTERN x, LABEL y, STRUCTMODEL *sm, 
		  STRUCT_LEARN_PARM *sparm, double *loss);

double fone(int correct, int trueledgenum, int predledgenum);
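	/* presumably an F1-type score from the numbers of correct,
	   true, and predicted labeled edges:
	   2*correct/(trueledgenum+predledgenum) */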


void        svm_struct_learn_api_init(int argc, char* argv[])
{
  /* Called in learning part before anything else is done to allow
     any initializations that might be necessary. */
}

void        svm_struct_learn_api_exit()
{
  /* Called in learning part at the very end to allow any clean-up
     that might be necessary. */
}

void        svm_struct_classify_api_init(int argc, char* argv[])
{
  /* Called in prediction part before anything else is done to allow
     any initializations that might be necessary. */
}

void        svm_struct_classify_api_exit()
{
  /* Called in prediction part at the very end to allow any clean-up
     that might be necessary. */
}

SAMPLE      read_struct_examples(char *file, STRUCT_LEARN_PARM *sparm)
{
  /* Reads struct examples and returns them in sample. The number of
     examples must be written into sample.n */
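  /* The training file is expected to contain one bracketed parse
     tree per sentence, as read by readtree_root(); something like
       (ROOT (S (NP (DT the) (NN dog)) (VP (VBZ barks))))
     though the exact bracketing conventions are those of the tree
     reader. */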
  SAMPLE   sample;  /* sample */
  EXAMPLE  *examples;
  long     totsen;
  long     n;       /* number of examples */

  tree     t;
  si_t     si;
  FILE     *fp;

  if(!sparm->si)
    si = make_si(100);
  else
    si = sparm->si;

  n=0;
  totsen=0;
  examples=(EXAMPLE *)MALLOC(sizeof(EXAMPLE)*100000);   /* hack: fixed-size buffer; fails beyond 100000 examples */
  if ((fp = fopen (file, "r")) == NULL)
    { perror (file); exit (1); } 
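  /* each tree is preprocessed: collapse_identical_unary() presumably
     collapses unary chains whose parent and child labels coincide,
     and annotate_with_parent() optionally splits each label by its
     parent's label (Johnson-style parent annotation) */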
  while ((t = readtree_root(fp, si))) {
    tree p = collapse_identical_unary(t); free_tree(t);
    if(sparm->parent_annotation) {
      t = annotate_with_parent(p, si); free_tree(p);
      p=t;
    }
    examples[n].y.parse = p;
    examples[n].y.si = si;
    examples[n].x.sentence = tree_terms(examples[n].y.parse);
    examples[n].x.si = si;
    /*
      int i;
      display_tree(stdout,examples[n].y.parse, si, 0); 
      printf("\n");
      for(i=0;i<examples[n].x.sentence.n;i++) {
        printf(" %s",si_index_string(si, examples[n].x.sentence.e[i]));
      }
      printf("\n");
    */
    if(examples[n].x.sentence.n <= sparm->maxsentlen) {
      n++;
      if((struct_verbosity>=1) && ((n % 100) == 0)) { 
	printf("%ld..",n); 
	fflush(stdout);
      }
    }
    else {
      free_label(examples[n].y);
      free(examples[n].x.sentence.e);
    }
    totsen++;
  }
  fclose(fp);
  if(struct_verbosity>=1) { 
      printf("(read %ld out of %ld)..",n,totsen); 
      fflush(stdout);
  }
  sample.n=n;
  sample.examples=examples;
  return(sample);
}

void        init_struct_model(SAMPLE sample, STRUCTMODEL *sm, 
			      STRUCT_LEARN_PARM *sparm, LEARN_PARM *lparm, 
			      KERNEL_PARM *kparm)
{
  /* Initialize structmodel sm. The weight vector w does not need to be
     initialized, but you need to provide the maximum size of the
     feature space in sizePsi. This is the maximum number of different
     weights that can be learned. Later, the weight vector w will
     contain the learned weights for the model. */

  si_t     si;
  long     i;
  vihashl  localtree_ht = make_vihashl(1000);
  vihashl  weightid_ht = make_vihashl(1000);
  grammar  g;
  
  si=sample.examples[0].y.si;

  printf("Extracting grammar rules from training examples..."); fflush(stdout);
  for(i=0;i<sample.n;i++) {
    count_local_trees(sample.examples[i].y.parse, localtree_ht);
  }
  printf("done\n"); fflush(stdout);

  printf("Creating grammar..."); fflush(stdout);
  g = create_grammar(localtree_ht, weightid_ht, si);
  printf("done\n"); fflush(stdout);

  free_vihashl(localtree_ht);
  sm->grammar=g;
  sm->si=si;
  sm->weightid_ht=weightid_ht;
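  /* create_grammar() assigned every distinct rule a weight id
     (recorded in weightid_ht); idMax is presumably the largest id
     handed out, so idMax+1 bounds the dimensionality of psi */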
  sm->sizePsi=g.idMax+1;
  sm->sparm=sparm;
}

CONSTSET    init_struct_constraints(SAMPLE sample, STRUCTMODEL *sm, 
				    STRUCT_LEARN_PARM *sparm)
{
  /* Initializes the optimization problem. Typically, you do not need
     to change this function, since you want to start with an empty
     set of constraints. However, if for example you have constraints
     that certain weights need to be positive, you might put that in
     here. The constraints are represented as lhs[i]*w >= rhs[i]. lhs
     is an array of feature vectors, rhs is an array of doubles. m is
     the number of constraints. The function returns the initial
     set of constraints. */
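  /* For example, to force w_i >= 0 one would add the constraint
     e_i * w >= 0: lhs[i] is the unit vector with a single nonzero
     at feature i+1, and rhs[i] = 0. That is what the disabled
     branch below constructs. */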
  CONSTSET c;
  long     sizePsi=sm->sizePsi;
  long     i;
  WORD     words[2];

  if(1) { /* normal case: start with empty set of constraints */
    c.lhs=NULL;
    c.rhs=NULL;
    c.m=0;
  }
  else { /* add constraints so that all learned weights are
            positive. WARNING: Currently, they are positive only up to
            precision epsilon set by -e. */
    c.lhs=my_malloc(sizeof(DOC *)*sizePsi);
    c.rhs=my_malloc(sizeof(double)*sizePsi);
    for(i=0; i<sizePsi; i++) {
      words[0].wnum=i+1;
      words[0].weight=1.0;
      words[1].wnum=0;
      /* the following slackid is a hack. we will run into problems
         if we have more than 1000000 slack sets (i.e. examples) */
      c.lhs[i]=create_example(i,0,1000000+i,1,create_svector(words,"",1.0));
      c.rhs[i]=0.0;
    }
  }
  return(c);
}

LABEL       classify_struct_example(PATTERN x, STRUCTMODEL *sm, 
				    STRUCT_LEARN_PARM *sparm)
{
  /* Finds the label yhat for pattern x that scores the highest
     according to the linear evaluation function in sm, especially the
     weights sm.w. The returned label is taken as the prediction of sm
     for the pattern x. The weights correspond to the features defined
     by psi() and range from index 1 to index sm->sizePsi. If the
     function cannot find a label, it shall return an empty label as
     recognized by the function empty_label(y). */
  LABEL   ybar;
  vindex  terms;
  grammar g;
  si_t    si;
  chart   c,c2=NULL;
  chart_cell	root_cell,root_cell2;
  FILE    *tracefp=NULL;
  FILE    *parsefp=NULL;
  double  *w;

  terms=&(x.sentence);
  w=sm->w;
  g=sm->grammar;
  si=sm->si;

  if (tracefp) {
    int i;
    fprintf(tracefp, "\nSentence:\n");
    for (i=0; i<terms->n; i++)
      fprintf(tracefp, " %s", si_index_string(si, terms->e[i]));
    fprintf(tracefp, "\n"); fflush(tracefp);
  }
     
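  /* run the CKY parser in max mode: find the highest-scoring
     (Viterbi) binary parse of the terminal string under the current
     weight vector */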
  if (tracefp) { printf("Parsing sentence..."); fflush(stdout); }
  c = cky(*terms, NULL, sm, si, CKY_MAX);
  if (tracefp) /* cross-check parse against results from other parser*/
    c2 = cky_wrap(*terms, NULL, sm, si); 
  if (tracefp) { printf("done\n"); fflush(stdout); }

  /* fetch best root node */

  if (tracefp) { printf("Fetching root..."); fflush(stdout); }
  root_cell = sihashcc_ref(CHART_ENTRY(c, 0, terms->n),
			   si_string_index(si, ROOT));
  if (c2) { 
    root_cell2 = sihashcc_ref(CHART_ENTRY(c2, 0, terms->n),
			      si_string_index(si, ROOT));
  }
  if (tracefp) { printf("done\n"); fflush(stdout); }

  ybar.si=si;
  if (root_cell) {
    if (c2) { 
      if(fabs(root_cell->prob - root_cell2->prob) > 0.0000001) {
	printf("ERROR:\n");
	printf("prob=%g \t prob2=%g\n",root_cell->prob,root_cell2->prob);
	printf("Parse1: "); 
	write_tree(tracefp, bintree_tree(&root_cell->tree, si), si);
	printf("\nParse2: "); 
	write_tree(tracefp, bintree_tree(&root_cell2->tree, si), si);
	printf("\n"); 
	fflush(stdout); 
      }
      assert(fabs(root_cell->prob - root_cell2->prob) < 0.0000001); 
    }

    if (tracefp) { printf("Getting parse tree..."); fflush(stdout); }
    ybar.parse = bintree_tree(&root_cell->tree, si);
    if (tracefp) { printf("done\n"); fflush(stdout); }
    
    double logprob = (double) root_cell->prob;
    if (parsefp) { fprintf(parsefp, "Prob = %g ", logprob); fflush(stdout); }
    ybar.prob=logprob;
  }
  else {
    ybar.parse=NULL;
    fprintf(stdout, "Failed to parse\n");
    if (parsefp)
      fprintf(parsefp, "parse_failure.\n");
  }

  chart_free(c, terms->n);			/* free the chart */
  if(c2) chart_free(c2, terms->n); 		/* free the chart */

  return(ybar);    
}

LABEL       find_most_violated_constraint_slackrescaling(PATTERN x, LABEL y, 
						     STRUCTMODEL *sm, 
						     STRUCT_LEARN_PARM *sparm)
{
  /* Finds the label ybar for pattern x that is responsible for
     the most violated constraint for the slack rescaling
     formulation. For linear slack variables, this is that label ybar
     that maximizes

            argmax_{ybar} loss(y,ybar)*(1 - w*psi(x,y) + w*psi(x,ybar))

     Note that ybar may be equal to y (i.e. the max is 0), which is
     different from the algorithms described in
     [Tsochantaridis/05]. Note that this argmax has to take into
     account the scoring function in sm, especially the weights sm.w,
     as well as the loss function, and whether linear or quadratic
     slacks are used. The weights in sm.w correspond to the features
     defined by psi() and range from index 1 to index
     sm->sizePsi. Most simple is the case of the zero/one loss
     function. For the zero/one loss, this function should return the
     highest scoring label ybar (which may be equal to the correct
     label y), or the second highest scoring label ybar, if
     w*Psi(x,ybar) > w*Psi(x,y)-1. If the function cannot find a label, it
     shall return an empty label as recognized by the function
     empty_label(y). */
  LABEL   ybar,ybar2;
  vindex  terms;
  grammar g;
  si_t    si;
  chart   c,c2;
  chart_cell	root_cell,root_cell2;
  FILE    *tracefp=NULL;
  FILE    *parsefp=NULL;
  double  *w;
  double  lossval;

  terms=&(x.sentence);
  w=sm->w;
  g=sm->grammar;
  si=sm->si;

  if (tracefp) {
    int i;
    fprintf(tracefp, "\nSentence:\n");
    for (i=0; i<terms->n; i++)
      fprintf(tracefp, " %s", si_index_string(si, terms->e[i]));
    fprintf(tracefp, "\n"); fflush(tracefp);
  }
     
  if (tracefp) { printf("Parsing sentence..."); fflush(stdout); }

  if(sparm->loss_function == 0) { 
    c = cky(*terms, NULL, sm, si, CKY_MAX); 
    /* c2 = cky_wrap(*terms, y.parse, sm, si); */
    c2=c;
    lossval=100.0;   /* constant loss; 0/1 loss scaled to 100 */
  }
  else if(sparm->loss_function >= 1) { /* loss functions >= 1: stratified losses */
    c = cky_maxloss(*terms, x, y, sm, sparm, &lossval); 
    c2=c;
  }
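
  /* cky_maxloss() performs the loss-augmented search: it presumably
     finds the parse maximizing model score plus loss against the
     correct tree y (using the stratified chart cells defined above)
     and returns the loss of that maximizer through lossval */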
