/* svm_struct_api.c */
else{
c=NULL;
c2=NULL;
assert(0); /* this case should not occur */
}
if (tracefp) { printf("done\n"); fflush(stdout); }
/* fetch best root node */
if (tracefp) { printf("Fetching root..."); fflush(stdout); }
root_cell = sihashcc_ref(CHART_ENTRY(c, 0, terms->n),
si_string_index(si, ROOT));
root_cell2 = sihashcc_ref(CHART_ENTRY(c2, 0, terms->n),
si_string_index(si, ROOT));
if (tracefp) { printf("done\n"); fflush(stdout); }
ybar.si=si;
if (root_cell) {
assert(fabs(root_cell->prob - root_cell2->prob) < 0.0000001);
if (tracefp) { printf("Getting parse tree..."); fflush(stdout); }
ybar.parse = bintree_tree(&root_cell->tree, si);
if (tracefp) { printf("done\n"); fflush(stdout); }
double logprob = (double) root_cell->prob;
if (parsefp) { fprintf(parsefp, "Prob = %g ", logprob); fflush(stdout); }
ybar.prob=logprob;
if(sparm->loss_function == 0)
ybar.loss=loss(y,ybar,sparm);
else
ybar.loss = lossval;
{ /* sanity check */
SVECTOR *v=psi(x,ybar,sm,sparm);
DOC *ex=create_example(-1,-1,-1,1,v);
double logprobpsi=classify_example(sm->svm_model,ex);
free_example(ex,0);
if (parsefp) { fprintf(parsefp, "ProbPsi = %g ", logprobpsi); fflush(stdout); }
if(fabs(logprob - logprobpsi) > 0.000001)
printf(" psi-cky mismatch\n");
assert(fabs(logprob - logprobpsi) < 0.000001);
free_svector(v);
if(root_cell->present2) {
LABEL ybar2;
ybar2.parse = bintree_tree(&root_cell->tree2, si);
assert(root_cell->prob2 <= root_cell->prob);
double logprob = (double) root_cell->prob2;
v=psi(x,ybar2,sm,sparm);
ex=create_example(-1,-1,-1,1,v);
logprobpsi=classify_example(sm->svm_model,ex);
free_example(ex,0);
if (parsefp) { fprintf(parsefp, " > %g ", logprobpsi); fflush(stdout); }
if(fabs(logprob - logprobpsi) > 0.0000001)
printf(" psi-cky2 mismatch\n");
if(tree_eq(ybar.parse,ybar2.parse) || (fabs(logprob - logprobpsi) > 0.0000001)) {
printf("\n prob1=%f prob2=%f\n",root_cell->prob,root_cell->prob2);
printf("\npsiprob1=%f psiprob2=%f(%f)\n",classify_example(sm->svm_model,create_example(-1,-1,-1,1,psi(x,ybar,sm,sparm))),classify_example(sm->svm_model,create_example(-1,-1,-1,1,psi(x,ybar2,sm,sparm))),logprobpsi);
write_bintree(stdout,&root_cell->tree, si);printf("\n");
write_bintree(stdout,&root_cell->tree2, si);printf("\n");
write_tree(stdout,ybar.parse, si);printf("\n");
write_tree(stdout,ybar2.parse, si);printf("\n");
}
assert(!tree_eq(ybar.parse,ybar2.parse));
/* assert(fabs(logprob - logprobpsi) < 0.0000001); */
free_svector(v);
free_tree(ybar2.parse);
}
}
/*
if(((sparm->loss_function >= 1)) && (ybar.loss == 0)) {
free_label(ybar);
ybar.parse=NULL;
}
*/
/* check whether to return second highest scoring parse for 0/1-loss */
if(tree_eq(ybar.parse,y.parse) && (sparm->loss_function == 0)) {
if (tracefp) {
fprintf(tracefp, "Best parse is correct:\n");
fprintf(tracefp, "Correct: ");
write_tree(tracefp, y.parse, si);
fprintf(tracefp, "\nBest : ");
write_tree(tracefp, ybar.parse, si);
fprintf(tracefp, "\n");
fflush(tracefp);
}
if(root_cell->present2) {
LABEL ybar2;
ybar2.parse = bintree_tree(&root_cell->tree2, si);
ybar2.si=si;
ybar2.prob=root_cell->prob2;
ybar2.loss=loss(y,ybar2,sparm);
if (tracefp && ybar2.parse) {
fprintf(tracefp, "2ndBest: ");
write_tree(tracefp, ybar2.parse, si);
fprintf(tracefp, "\n");
fflush(tracefp);
}
if(ybar.prob > (ybar2.prob + ybar2.loss)) {
free_label(ybar2); /* return y */
}
else { /* return second highest scoring parse */
free_label(ybar);
ybar=ybar2;
}
}
else {
if (tracefp) {
fprintf(tracefp, "No 2nd best parse found!\n");
fflush(tracefp);
}
}
}
if (parsefp && ybar.parse)
write_tree(parsefp, ybar.parse, si);
}
else {
ybar.parse=NULL;
fprintf(stderr, "Failed to parse\n");
if (parsefp)
fprintf(parsefp, "parse_failure.\n");
}
chart_free(c, terms->n); /* free the chart */
if(c != c2) chart_free(c2, terms->n); /* free the chart */
return(ybar);
}
LABEL find_most_violated_constraint_marginrescaling(PATTERN x, LABEL y,
STRUCTMODEL *sm,
STRUCT_LEARN_PARM *sparm)
{
/* Finds the label ybar for pattern x that is responsible for
the most violated constraint for the margin rescaling
formulation. For linear slack variables, this is the label ybar
that maximizes
argmax_{ybar} loss(y,ybar) + sm.w*psi(x,ybar)
Note that ybar may be equal to y (i.e. the max is 0), which is
different from the algorithms described in
[Tsochantaridis/05]. Note that this argmax has to take into
account the scoring function in sm, especially the weights sm.w,
as well as the loss function, and whether linear or quadratic
slacks are used. The weights in sm.w correspond to the features
defined by psi() and range from index 1 to index
sm->sizePsi. Most simple is the case of the zero/one loss
function. For the zero/one loss, this function should return the
highest scoring label ybar (which may be equal to the correct
label y), or the second highest scoring label ybar, if
Psi(x,ybar)>Psi(x,y)-1. If the function cannot find a label, it
shall return an empty label as recognized by the function
empty_label(y). */
LABEL ybar;
ybar=find_most_violated_constraint_slackrescaling(x,y,sm,sparm);
return(ybar);
}
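/* For the 0/1 loss (scaled to 0/100 in loss() below), this argmax only needs
   the two highest scoring parses: when the best parse equals y, the constraint
   can only be violated by the second-best parse. A rough pseudocode sketch of
   the decision rule, matching the comparison of ybar.prob against
   ybar2.prob + ybar2.loss in find_most_violated_constraint_slackrescaling
   above:

       if score(best) > score(second_best) + loss(y, second_best)
           return best            (== y, no violated constraint)
       else
           return second_best     (the most violating incorrect parse)
*/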
int empty_label(LABEL y)
{
/* Returns true, if y is an empty label. An empty label might be
returned by find_most_violated_constraint_???(x, y, sm) if there
is no incorrect label that can be found for x, or if it is unable
to label x at all */
return(y.parse == NULL);
}
SVECTOR *psi(PATTERN x, LABEL y, STRUCTMODEL *sm,
STRUCT_LEARN_PARM *sparm)
{
/* Returns a feature vector describing the match between pattern x and
label y. The feature vector is returned as an SVECTOR
(i.e. pairs <featurenumber:featurevalue>), where the last pair has
featurenumber 0 as a terminator. Featurenumbers start with 1 and end with
sizePsi. This feature vector determines the linear evaluation
function that is used to score labels. There will be one weight in
sm.w for each feature. Note that psi has to match
find_most_violated_constraint_???(x, y, sm) and vice versa. In
particular, find_most_violated_constraint_???(x, y, sm) finds that
ybar!=y that maximizes psi(x,ybar,sm)*sm.w (where * is the inner
vector product) and the appropriate function of the loss. */
SVECTOR *psi;
bintree btree;
struct vindex terms;
size_t rpos;
terms=tree_terms(y.parse);
btree=right_binarize(y.parse,sm->si);
psi=collect_phi(btree,sm,0,&rpos,0,terms.n);
free(terms.e);
free_bintree(btree);
return(psi);
}
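/* Illustrative sketch of how such a feature vector is scored against the
   weights in sm->w (valid indices 1..sizePsi, terminating pair with feature
   number 0, as described above). fvec is a hypothetical name for the
   SVECTOR returned by psi(), and the field names wnum/weight assume the
   usual SVM-light sparse vector layout:

       double score = 0.0;
       WORD *a;
       for(a = fvec->words; a->wnum; a++)
           score += sm->w[a->wnum] * a->weight;

   Up to the model's constant offset, this is the score that
   classify_example(sm->svm_model, ...) computes in the sanity checks
   above. */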
double loss(LABEL y, LABEL ybar, STRUCT_LEARN_PARM *sparm)
{
/* loss for correct label y and predicted label ybar. The loss for
y==ybar has to be zero. sparm->loss_function is set with the -l option. */
if(sparm->loss_function == 0) { /* type 0 loss: 0/1 loss */
/* return 0 if y==ybar, otherwise return 100 (the 0/1 loss is scaled by 100) */
if(tree_eq(ybar.parse,y.parse))
return(0);
else
return(100);
}
else if(sparm->loss_function == 1) {
/* type 1 loss: 1 - F1 over labelled brackets */
tree t = remove_parent_annotation(y.parse, y.si);
struct ledges *test_ledges = tree_ledges(t);
free_tree(t);
t = remove_parent_annotation(ybar.parse, ybar.si);
struct ledges *parse_ledges = tree_ledges(t);
free_tree(t);
int common_bracket_count = common_ledge_count(test_ledges, parse_ledges);
double loss=1.0-fone(common_bracket_count,test_ledges->n,parse_ledges->n);
free_ledges(test_ledges);
free_ledges(parse_ledges);
/* printf("loss: %g ybar.loss=%g\n",loss,ybar.loss); */
/* assert(fabs(loss - ybar.loss) < 0.00000001); */
return(100.0*loss);
}
else {
/* Put your code for different loss functions here. But then
find_most_violated_constraint_???(x, y, sm) has to return the
highest scoring label with the largest loss. */
assert(0);
return(0);
}
}
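/* Worked example for the type 1 loss, assuming fone(c,t,p) computes the
   labelled bracket F1 score 2*c/(t+p) (its exact definition lives in
   another file): with 8 brackets in common, 10 brackets in the correct
   tree and 12 in the predicted tree, F1 = 16/22 = 0.727 and the returned
   loss is 100*(1-0.727) = 27.3. The counts are made up for illustration. */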
int finalize_iteration(double ceps, int cached_constraint,
SAMPLE sample, STRUCTMODEL *sm,
CONSTSET cset, double *alpha,
STRUCT_LEARN_PARM *sparm)
{
/* This function is called just before the end of each cutting plane
   iteration. ceps is the amount by which the most violated constraint
   found in the current iteration was violated. cached_constraint is
   true if the added constraint was constructed from the cache. If the
   return value is FALSE, then the algorithm is allowed to terminate.
   If it is TRUE, the algorithm will keep iterating even if the desired
   precision sparm->epsilon is already reached. */
return(0);
}
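/* A minimal sketch of a non-default policy: keep the cutting plane loop
   going while the violation of the newly found constraint is still large
   relative to the target precision (the factor 10 is an arbitrary choice
   for illustration):

       return(ceps > 10.0 * sparm->epsilon);

   Returning 0, as done above, never forces extra iterations once
   sparm->epsilon is reached. */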
void print_struct_learning_stats(SAMPLE sample, STRUCTMODEL *sm,
CONSTSET cset, double *alpha,
STRUCT_LEARN_PARM *sparm)
{
/* This function is called after training and allows final touches to
the model sm. But primarily it allows computing and printing any
kind of statistic (e.g. training error) you might want. */
long i,correct=0;
LABEL ybar;
EXAMPLE *ex;
long n;
double cumloss=0;
/* double valy,valybar; */
/* DOC *extemp; */
/* SVECTOR *fy,*fybar; */
n=sample.n;
ex=sample.examples;
if(struct_verbosity >= 1)
printf("Classifying training examples");
/* classify the training examples */
for(i=0; i<n;i++) {
/* find best parse */
ybar=classify_struct_example(ex[i].x,sm,sparm);
/* assert(ybar.parse); */
/* compute scores for sanity check */
/*
fy=psi(ex[i].x,ex[i].y,sm,sparm);
extemp=create_example(-1,-1,-1,1,fy);
valy=classify_example(sm->svm_model,extemp);
free_example(extemp,1);
fybar=psi(ex[i].x,ybar,sm,sparm);
extemp=create_example(-1,-1,-1,1,fybar);
valybar=classify_example(sm->svm_model,extemp);
free_example(extemp,1);
*/
if(tree_eq(ybar.parse,ex[i].y.parse)) {
correct++;
printf("+");fflush(stdout);
/* assert(valy >= valybar); */
}
else {
printf("-");fflush(stdout);
cumloss+=loss(ex[i].y,ybar,sparm);
/* assert(valy <= valybar); */
}
/*
fprintf(stdout,"Prob = %g ", ybar.prob); fflush(stdout);
fprintf(stdout,"Xi = %g = %g ", xi, sm->w[sm->sizePsi+1+i]/sqrt(2*sparm->C)); fflush(stdout);
if(sparm->parent_annotation) {
tree t = remove_parent_annotation(ybar.parse, sm->si);
free_tree(ybar.parse);
ybar.parse=t;
}
write_prolog_tree(stdout, ybar.parse, sm->si);
*/
free_label(ybar);
}
printf("done\n");
if(struct_verbosity>=1) {
printf("Number of correct parses: %i out of %i (%.2f%%)\n",
(int)correct,(int)n, 100.0*correct/n);
printf("Average training loss: %.4f\n", cumloss/n);
}
}
void print_struct_testing_stats(SAMPLE sample, STRUCTMODEL *sm,
STRUCT_LEARN_PARM *sparm,
STRUCT_TEST_STATS *teststats)
{
/* This function is called after making all test predictions in
svm_struct_classify and allows computing and printing any kind of
evaluation (e.g. precision/recall) you might want. You can use
the function eval_prediction to accumulate the necessary
statistics for each prediction. */
if(verbosity >= 1) {
printf("Fraction of parsed sentences: %.2f%% (%ld/%d)\n",
100.0*teststats->parsed_sentences/sample.n,
teststats->parsed_sentences,sample.n);
printf("Labelled bracket precision: %.2f%% (%ld/%ld)\n",
100.0*teststats->common_bracket_sum/teststats->parse_bracket_sum,
teststats->common_bracket_sum, teststats->parse_bracket_sum);
printf("Labelled bracket recall: %.2f%% (%ld/%ld)\n",
100.0*teststats->common_bracket_sum/teststats->test_bracket_sum,
teststats->common_bracket_sum, teststats->test_bracket_sum);
printf("Labelled bracket F1: %.2f%%\n",
100.0*fone(teststats->common_bracket_sum,
teststats->test_bracket_sum,
teststats->parse_bracket_sum));
}
}
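/* Example of how these corpus level counts combine, with made-up numbers:
   common=900, test=1000 and parse=1100 labelled brackets give
   precision = 900/1100 = 81.82%, recall = 900/1000 = 90.00% and, assuming
   fone(c,t,p) = 2*c/(t+p), F1 = 1800/2100 = 85.71%. */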
void eval_prediction(long exnum, EXAMPLE ex, LABEL ypred,
STRUCTMODEL *sm, STRUCT_LEARN_PARM *sparm,
STRUCT_TEST_STATS *teststats)
{
/* This function allows you to accumulate statistics on how well the
prediction matches the labeled example. It is called from
svm_struct_classify. See also the function
print_struct_testing_stats. */
if(exnum == 0) { /* this is the first time the function is
called. So initialize the teststats */
teststats->parsed_sentences=0;
teststats->test_bracket_sum=0;
teststats->parse_bracket_sum=0;
teststats->common_bracket_sum=0;
}
if(!empty_label(ypred)) {
teststats->parsed_sentences++;
tree t = remove_parent_annotation(ex.y.parse, ex.y.si);
struct ledges *test_ledges=tree_ledges(t);
free_tree(t);
t = remove_parent_annotation(ypred.parse, ypred.si);
struct ledges *parse_ledges=tree_ledges(t);
free_tree(t);
teststats->common_bracket_sum+=common_ledge_count(test_ledges, parse_ledges);
teststats->test_bracket_sum+=test_ledges->n;
teststats->parse_bracket_sum+=parse_ledges->n;
free_ledges(test_ledges);
free_ledges(parse_ledges);
}
}
void write_struct_model(char *file, STRUCTMODEL *sm,
STRUCT_LEARN_PARM *sparm)
{
vihashlit hit;
vindex vi;
long weightid,i;
FILE *fp;
char file_svm[500],file_grammar[500];
/* First write the SVM model */
strcpy(file_svm,file);
strcat(file_svm,".svm");
write_model(file_svm,sm->svm_model);
/* Then write the grammar */
strcpy(file_grammar,file);
strcat(file_grammar,".grammar");
if ((fp = fopen (file_grammar, "w")) == NULL)
{ perror (file_grammar); exit (1); }
fprintf(fp,"%ld # number of attributes in weight vector\n",sm->sizePsi);
fprintf(fp,"%d # loss function\n",sparm->loss_function);
fprintf(fp,"%d # use parent annotation\n",sparm->parent_annotation);
fprintf(fp,"%d # maximum sentence length\n",sparm->maxsentlen);
fprintf(fp,"%d # use border features\n",sparm->feat_borders);
fprintf(fp,"%d # use span length as feature\n",
sparm->feat_parent_span_length);
fprintf(fp,"%d # use children span length as features\n",
sparm->feat_children_span_length);
fprintf(fp,"%d # use difference in children span lengths as feature\n",
sparm->feat_diff_children_length);
for (hit = vihashlit_init(sm->weightid_ht); vihashlit_ok(hit); hit = vihashlit_next(hit)) {
vi=hit.key;
weightid=vihashl_ref(sm->weightid_ht,vi);
assert(vi->n);
fprintf(fp,"%ld %s " REWRITES "",weightid,si_index_string(sm->si, vi->e[0]));
for(i=1;i<vi->n;i++) {
fprintf(fp," %s",si_index_string(sm->si, vi->e[i]));
}
fprintf(fp,"\n");
}
fclose(fp);
}
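/* The resulting <file>.grammar is plain text: eight header lines of the
   form "<value> # <description>" written above, followed by one line per
   rule of the form

       <weightid> <parent> <REWRITES token> <child> [<child>]

   e.g. a hypothetical binary rule with feature number 17 would appear as
   "17 NP", then whatever the REWRITES macro expands to, then " DT NN".
   All ids and symbols here are made up for illustration; read_struct_model()
   below parses this format back in. */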
STRUCTMODEL read_struct_model(char *file, STRUCT_LEARN_PARM *sparm)
{
/* Reads structural model sm from file file. This function is used
only in the prediction module, not in the learning module. */
STRUCTMODEL sm;
vindex vi;
long i,id;
FILE *fp;
char file_svm[500],file_grammar[500];
sihashbrsit bhit;
sihashursit uhit;
/* First read the SVM model */
strcpy(file_svm,file);
strcat(file_svm,".svm");
sm.svm_model=read_model(file_svm);
/* Then read the grammar */
strcpy(file_grammar,file);
strcat(file_grammar,".grammar");
if ((fp = fopen (file_grammar, "r")) == NULL)
{ perror (file_grammar); exit (1); }
fscanf(fp,"%ld%*[^\n]\n",&sm.sizePsi);
fscanf(fp,"%d%*[^\n]\n", &sparm->loss_function);
fscanf(fp,"%d%*[^\n]\n", &sparm->parent_annotation);
fscanf(fp,"%d%*[^\n]\n", &sparm->maxsentlen);
fscanf(fp,"%d%*[^\n]\n", &sparm->feat_borders);
fscanf(fp,"%d%*[^\n]\n", &sparm->feat_parent_span_length);
fscanf(fp,"%d%*[^\n]\n", &sparm->feat_children_span_length);
fscanf(fp,"%d%*[^\n]\n", &sparm->feat_diff_children_length);
sm.si=make_si(1024);
sm.grammar=read_grammar(fp,sm.si);
sm.grammar.idMax=sm.sizePsi;
sm.sparm=sparm;
fclose(fp);
/* Reconstruct the hashtable for mapping rules to feature numbers */
sm.weightid_ht = make_vihashl(1000);
for (bhit=sihashbrsit_init(sm.grammar.brs); sihashbrsit_ok(bhit); bhit=sihashbrsit_next(bhit))
for (i=0; i<bhit.value.n; i++) {
vi=make_vindex(4);
vi->n=3;
vi->e[0]=bhit.value.e[i]->parent;
vi->e[1]=bhit.value.e[i]->left;
vi->e[2]=bhit.value.e[i]->right;
id=(long)(bhit.value.e[i]->prob+0.1); /* the grammar file stores the feature id in the rule's prob field; +0.1 guards against rounding down */
vihashl_set(sm.weightid_ht,vi,id);
bhit.value.e[i]->weightid=id;
vindex_free(vi);
}
for (uhit=sihashursit_init(sm.grammar.urs); sihashursit_ok(uhit); uhit=sihashursit_next(uhit))
for (i=0; i<uhit.value.n; i++) {
vi=make_vindex(4);
vi->n=2;
vi->e[0]=uhit.value.e[i]->parent;
vi->e[1]=uhit.value.e[i]->child;
id=(long)(uhit.value.e[i]->prob+0.1);
vihashl_set(sm.weightid_ht,vi,id);
uhit.value.e[i]->weightid=id;
vindex_free(vi);
}
sparm->si=sm.si;
return(sm);
}
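/* Sketch of how the reconstructed weightid_ht is meant to be used: given the
   si indices of a binary rule, its feature number can be looked up the same
   way write_struct_model() does. parent_sym, left_sym and right_sym are
   hypothetical variables holding symbol indices:

       vindex vi = make_vindex(4);
       vi->n = 3;
       vi->e[0] = parent_sym;
       vi->e[1] = left_sym;
       vi->e[2] = right_sym;
       long fid = vihashl_ref(sm.weightid_ht, vi);
       vindex_free(vi);
*/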