/***********************************************************************/
/*                                                                     */
/*   svm_struct_learn.c                                                */
/*                                                                     */
/*   Basic algorithm for learning structured outputs (e.g. parses,    */
/*   sequences, multi-label classification) with a Support Vector     */
/*   Machine.                                                          */
/*                                                                     */
/*   Author: Thorsten Joachims                                         */
/*   Date: 26.06.06                                                    */
/*                                                                     */
/*   Copyright (c) 2006  Thorsten Joachims - All rights reserved       */
/*                                                                     */
/*   This software is available for non-commercial use only. It must   */
/*   not be modified and distributed without prior permission of the   */
/*   author. The author is not responsible for implications from the   */
/*   use of this software.                                             */
/*                                                                     */
/***********************************************************************/
#include "svm_struct_learn.h"
#include "svm_struct_common.h"
#include "../svm_struct_api.h"
#include <assert.h>
#define MAX(x,y) ((x) < (y) ? (y) : (x))
#define MIN(x,y) ((x) > (y) ? (y) : (x))
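
/* svm_learn_struct() implements the n-slack cutting-plane algorithm for
   structural SVMs (Tsochantaridis et al., 2005). It maintains a working
   set of linear constraints of the form

       w * [Psi(x_i,y_i) - Psi(x_i,ybar)] >= loss(y_i,ybar) - xi_i

   (shown for margin rescaling; slack rescaling scales both sides by the
   loss) and alternates between (a) calling the application-specific
   separation oracle find_most_violated_constraint_* to find, for each
   example, the label ybar whose constraint is most violated by the
   current model, and (b) re-solving the quadratic program restricted to
   the working set. Training stops once no constraint is violated by
   more than sparm->epsilon. */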
void svm_learn_struct(SAMPLE sample, STRUCT_LEARN_PARM *sparm,
                      LEARN_PARM *lparm, KERNEL_PARM *kparm,
                      STRUCTMODEL *sm)
{
  int    i,j;
  int    numIt=0;
  long   argmax_count=0;
  long   newconstraints=0, totconstraints=0, activenum=0;
  int    opti_round, *opti, fullround;
  long   old_numConst=0;
  double epsilon,svmCnorm;
  long   tolerance,new_precision=1;
  double lossval,factor,dist;
  double margin=0;
  double slack, *slacks, slacksum, ceps;
  double dualitygap,modellength,alphasum;
  long   sizePsi;
  double *alpha=NULL;
  long   *alphahist=NULL,optcount=0,lastoptcount=0;
  CONSTSET cset;
  SVECTOR *diff=NULL;
  SVECTOR *fy, *fybar, *f, **fycache=NULL;
  SVECTOR *slackvec;
  WORD   slackv[2];
  MODEL  *svmModel=NULL;
  KERNEL_CACHE *kcache=NULL;
  LABEL  ybar;
  DOC    *doc;
  long   n=sample.n;
  EXAMPLE *ex=sample.examples;
  double rt_total=0, rt_opt=0, rt_init=0, rt_psi=0, rt_viol=0;
  double rt1,rt2;

  rt1=get_runtime();

  init_struct_model(sample,sm,sparm,lparm,kparm);
  sizePsi=sm->sizePsi+1;   /* sm must contain size of psi on return */
  /* initialize example selection heuristic */
  opti=(int*)my_malloc(n*sizeof(int));
  for(i=0;i<n;i++) {
    opti[i]=0;
  }
  opti_round=0;

  /* normalize regularization parameter C by the number of training examples */
  svmCnorm=sparm->C/n;

  if(sparm->slack_norm == 1) {
    lparm->svm_c=svmCnorm;          /* set upper bound C */
    lparm->sharedslack=1;
  }
  else if(sparm->slack_norm == 2) {
    lparm->svm_c=999999999999999.0; /* upper bound C must never be reached */
    lparm->sharedslack=0;
    if(kparm->kernel_type != LINEAR) {
      printf("ERROR: Kernels are not implemented for L2 slack norm!");
      fflush(stdout);
      exit(0);
    }
  }
  else {
    printf("ERROR: Slack norm must be L1 or L2!"); fflush(stdout);
    exit(0);
  }
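
  /* Note on the two slack norms: with L1 slack, all constraints that
     belong to one example share a single slack variable, and C acts as
     the usual upper bound on the dual variables. With L2 slack, the
     squared slack is folded into the feature space as one extra
     dimension per example (see the slackv construction in the example
     loop below), so the upper bound on alpha must never become active. */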
  epsilon=100.0;                  /* start with low precision and
                                     increase later */
  tolerance=MIN(n/3,MAX(n/100,5));/* increase precision, whenever less
                                     than that number of constraints
                                     is not fulfilled */
  lparm->biased_hyperplane=0;     /* set threshold to zero */

  cset=init_struct_constraints(sample, sm, sparm);
  if(cset.m > 0) {
    alpha=(double *)realloc(alpha,sizeof(double)*cset.m);
    alphahist=(long *)realloc(alphahist,sizeof(long)*cset.m);
    for(i=0; i<cset.m; i++) {
      alpha[i]=0;
      alphahist[i]=-1; /* -1 makes sure these constraints are never removed */
    }
  }

  /* set initial model and slack variables */
  svmModel=(MODEL *)my_malloc(sizeof(MODEL));
  lparm->epsilon_crit=epsilon;
  if(kparm->kernel_type != LINEAR)
    kcache=kernel_cache_init(MAX(cset.m,1),lparm->kernel_cache_size);
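
  /* The QP is set up over sizePsi+n dimensions: sizePsi weights for the
     feature map Psi, plus one dimension per training example that is
     used as the squared-slack feature in L2 slack mode. */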
  svm_learn_optimization(cset.lhs,cset.rhs,cset.m,sizePsi+n,
                         lparm,kparm,kcache,svmModel,alpha);
  if(kcache)
    kernel_cache_cleanup(kcache);
  add_weight_vector_to_linear_model(svmModel);
  sm->svm_model=svmModel;
  sm->w=svmModel->lin_weights; /* short cut to weight vector */

  /* create a cache of the feature vectors for the correct labels */
  if(USE_FYCACHE) {
    fycache=(SVECTOR **)malloc(n*sizeof(SVECTOR *));
    for(i=0;i<n;i++) {
      fy=psi(ex[i].x,ex[i].y,sm,sparm);
      if(kparm->kernel_type == LINEAR) {
        diff=add_list_ss(fy); /* store difference vector directly */
        free_svector(fy);
        fy=diff;
      }
      fycache[i]=fy;
    }
  }

  rt_init+=MAX(get_runtime()-rt1,0);
  rt_total+=MAX(get_runtime()-rt1,0);

  /*****************/
  /*** main loop ***/
  /*****************/
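  /* Three nested loops: the outer loop anneals the working precision
     epsilon down to sparm->epsilon; the middle loop repeats sweeps until
     a full round adds no more than 'tolerance' new constraints; the
     inner loop sweeps over the examples, skipping those that produced no
     violated constraint in the current round (shrinking via opti[]). */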
  do { /* iteratively increase precision */

    epsilon=MAX(epsilon*0.49999999999,sparm->epsilon);
    new_precision=1;
    if(epsilon == sparm->epsilon)  /* for final precision, find all SV */
      tolerance=0;
    lparm->epsilon_crit=epsilon/2; /* svm precision must be higher than eps */
    if(struct_verbosity>=1)
      printf("Setting current working precision to %g.\n",epsilon);

    do { /* iteration until (approx) all SV are found for current
            precision and tolerance */

      opti_round++;
      activenum=n;

      do { /* go through examples that keep producing new constraints */

        if(struct_verbosity>=1) {
          printf("Iter %i (%ld active): ",++numIt,activenum);
          fflush(stdout);
        }

        old_numConst=cset.m;
        ceps=0;
        fullround=(activenum == n);

        for(i=0; i<n; i++) { /*** example loop ***/

          rt1=get_runtime();

          if(opti[i] != opti_round) { /* if the example is not shrunk
                                         away, then see if it is necessary
                                         to add a new constraint */
            rt2=get_runtime();
            argmax_count++;
            if(sparm->loss_type == SLACK_RESCALING)
              ybar=find_most_violated_constraint_slackrescaling(ex[i].x,
                                                                ex[i].y,sm,
                                                                sparm);
            else
              ybar=find_most_violated_constraint_marginrescaling(ex[i].x,
                                                                 ex[i].y,sm,
                                                                 sparm);
            rt_viol+=MAX(get_runtime()-rt2,0);

            if(empty_label(ybar)) {
              if(opti[i] != opti_round) {
                activenum--;
                opti[i]=opti_round;
              }
              if(struct_verbosity>=2)
                printf("no-incorrect-found(%i) ",i);
              continue;
            }

            /**** get psi(y)-psi(ybar) ****/
            rt2=get_runtime();
            if(fycache)
              fy=copy_svector(fycache[i]);
            else
              fy=psi(ex[i].x,ex[i].y,sm,sparm);
            fybar=psi(ex[i].x,ybar,sm,sparm);
            rt_psi+=MAX(get_runtime()-rt2,0);

            /**** scale feature vector and margin by loss ****/
            lossval=loss(ex[i].y,ybar,sparm);
            if(sparm->slack_norm == 2)
              lossval=sqrt(lossval);
            if(sparm->loss_type == SLACK_RESCALING)
              factor=lossval;
            else            /* do not rescale vector for */
              factor=1.0;   /* margin rescaling loss type */
            for(f=fy;f;f=f->next)
              f->factor*=factor;
            for(f=fybar;f;f=f->next)
              f->factor*=-factor;
            margin=lossval;
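
            /* The list starting at fy now represents
               factor*(Psi(x_i,y_i)-Psi(x_i,ybar)) and margin holds the
               loss, so the constraint to be added reads
               w*factor*[Psi(x_i,y_i)-Psi(x_i,ybar)] >= margin - xi_i.
               For margin rescaling factor=1; for slack rescaling both
               the vector and the margin are scaled by the loss. */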
            /**** create constraint for current ybar ****/
            append_svector_list(fy,fybar); /* append the two vector lists */
            doc=create_example(cset.m,0,i+1,1,fy);

            /**** compute slack for this example ****/
            slack=0;
            for(j=0;j<cset.m;j++)
              if(cset.lhs[j]->slackid == i+1) {
                if(sparm->slack_norm == 2) /* works only for linear kernel */
                  slack=MAX(slack,cset.rhs[j]
                            -(classify_example(svmModel,cset.lhs[j])
                              -sm->w[sizePsi+i]/(sqrt(2*svmCnorm))));
                else
                  slack=MAX(slack,
                            cset.rhs[j]-classify_example(svmModel,cset.lhs[j]));
              }
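
            /* slack now holds the slack variable xi_i implied by the
               current working set: the largest violation of any existing
               constraint with slack id i+1. In L2 slack mode the score
               returned by classify_example() includes the slack feature,
               so its contribution sm->w[sizePsi+i]/sqrt(2*svmCnorm) is
               subtracted out to leave the structural part alone. */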
            /**** if `error' add constraint and recompute ****/
            dist=classify_example(svmModel,doc);
            ceps=MAX(ceps,margin-dist-slack);
            if(slack > (margin-dist+0.0001)) {
              printf("\nWARNING: Slack of most violated constraint is smaller than slack of working\n");
              printf("         set! There is probably a bug in 'find_most_violated_constraint_*'.\n");
              printf("Ex %d: slack=%f, newslack=%f\n",i,slack,margin-dist);
              /* exit(1); */
            }
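
            /* Add the new constraint only if it is violated by more than
               the current working precision epsilon relative to the
               slack already granted; otherwise mark the example inactive
               for this round. This epsilon test is what gives the
               cutting-plane loop its termination guarantee. */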
            if((dist+slack)<(margin-epsilon)) {
              if(struct_verbosity>=2)
                {printf("(%i,eps=%.2f) ",i,margin-dist-slack); fflush(stdout);}
              if(struct_verbosity==1)
                {printf("."); fflush(stdout);}

              /**** resize constraint matrix and add new constraint ****/
              cset.m++;
              cset.lhs=(DOC **)realloc(cset.lhs,sizeof(DOC *)*cset.m);
              if(kparm->kernel_type == LINEAR) {
                diff=add_list_ss(fy); /* store difference vector directly */
                if(sparm->slack_norm == 1)
                  cset.lhs[cset.m-1]=create_example(cset.m-1,0,i+1,1,
                                                    copy_svector(diff));
                else if(sparm->slack_norm == 2) {
                  /**** add squared slack variable to feature vector ****/
                  slackv[0].wnum=sizePsi+i;
                  slackv[0].weight=1/(sqrt(2*svmCnorm));
                  slackv[1].wnum=0; /* terminator */
                  slackvec=create_svector(slackv,"",1.0);
                  cset.lhs[cset.m-1]=create_example(cset.m-1,0,i+1,1,
                                                    add_ss(diff,slackvec));
                  free_svector(slackvec);
                }
                free_svector(diff);
              }
              else { /* kernel is used */
                if(sparm->slack_norm == 1)
                  cset.lhs[cset.m-1]=create_example(cset.m-1,0,i+1,1,
                                                    copy_svector(fy));
                else if(sparm->slack_norm == 2)
                  exit(1);
              }
              cset.rhs=(double *)realloc(cset.rhs,sizeof(double)*cset.m);
              cset.rhs[cset.m-1]=margin;
              alpha=(double *)realloc(alpha,sizeof(double)*cset.m);
              alpha[cset.m-1]=0;
              alphahist=(long *)realloc(alphahist,sizeof(long)*cset.m);
              alphahist[cset.m-1]=optcount;
              newconstraints++;
              totconstraints++;
            }
            else {
              printf("+"); fflush(stdout);
              if(opti[i] != opti_round) {
                activenum--;
                opti[i]=opti_round;
              }
            }

            free_example(doc,0);
            free_svector(fy); /* this also frees fybar */
            free_label(ybar);
          }
          /**** get new QP solution ****/
          if((newconstraints >= sparm->newconstretrain)
             || ((newconstraints > 0) && (i == n-1))
             || (new_precision && (i == n-1))) {

            if(struct_verbosity>=1) {
              printf("*");fflush(stdout);
            }
            rt2=get_runtime();
            free_model(svmModel,0);
            svmModel=(MODEL *)my_malloc(sizeof(MODEL));
            /* Always get a new kernel cache. It is not possible to use the
               same cache for two different training runs */
            if(kparm->kernel_type != LINEAR)
              kcache=kernel_cache_init(MAX(cset.m,1),lparm->kernel_cache_size);
            /* Run the QP solver on cset. */
            svm_learn_optimization(cset.lhs,cset.rhs,cset.m,sizePsi+n,
                                   lparm,kparm,kcache,svmModel,alpha);
            if(kcache)
              kernel_cache_cleanup(kcache);
            /* Always add weight vector, in case part of the kernel is
               linear. If not, ignore the weight vector since its
               content is bogus. */
            add_weight_vector_to_linear_model(svmModel);
            sm->svm_model=svmModel;
            sm->w=svmModel->lin_weights; /* short cut to weight vector */
            optcount++;
            /* keep track of when each constraint was last
               active. constraints marked with -1 are not updated */
            for(j=0;j<cset.m;j++)
              if((alphahist[j]>-1) && (alpha[j] != 0))
                alphahist[j]=optcount;
            rt_opt+=MAX(get_runtime()-rt2,0);

            new_precision=0;
            newconstraints=0;
          }

          rt_total+=MAX(get_runtime()-rt1,0);

        } /* end of example loop */
        rt1=get_runtime();

        if(struct_verbosity>=1)
          printf("(NumConst=%d, SV=%ld, CEps=%.4f, QPEps=%.4f)\n",cset.m,
                 svmModel->sv_num-1,ceps,svmModel->maxdiff);

        /* Check if some of the linear constraints have not been
           active in a while. Those constraints are then removed to
           avoid bloating the working set beyond necessity. */
        if(struct_verbosity>=2) {
          printf("Reducing working set...");fflush(stdout);
        }
        remove_inactive_constraints(&cset,alpha,optcount,alphahist,
                                    MAX(50,optcount-lastoptcount));
        lastoptcount=optcount;
        if(struct_verbosity>=2)
          printf("done. (NumConst=%d)\n",cset.m);

        rt_total+=MAX(get_runtime()-rt1,0);

      } while(activenum > 0);  /* repeat until all examples produced no
                                  constraint at least once */

    } while(((cset.m - old_numConst) > tolerance) || (!fullround));

  } while((epsilon > sparm->epsilon)
          || finalize_iteration(ceps,0,sample,sm,cset,alpha,sparm));
  if(struct_verbosity>=1) {
    /**** compute sum of slacks ****/
    /**** WARNING: If positivity constraints are used, then the
          maximum slack id is larger than what is allocated below ****/
    slacks=(double *)my_malloc(sizeof(double)*(n+1));
    for(i=0; i<=n; i++) {
      slacks[i]=0;
    }
    if(sparm->slack_norm == 1) {
      for(j=0;j<cset.m;j++)
        slacks[cset.lhs[j]->slackid]=MAX(slacks[cset.lhs[j]->slackid],
                       cset.rhs[j]-classify_example(svmModel,cset.lhs[j]));
    }
    else if(sparm->slack_norm == 2) {
      for(j=0;j<cset.m;j++)
        slacks[cset.lhs[j]->slackid]=MAX(slacks[cset.lhs[j]->slackid],
                cset.rhs[j]
                -(classify_example(svmModel,cset.lhs[j])
                  -sm->w[sizePsi+cset.lhs[j]->slackid-1]/(sqrt(2*svmCnorm))));
    }
    slacksum=0;
    for(i=1; i<=n; i++)
      slacksum+=slacks[i];
    free(slacks);

    alphasum=0;
    for(i=0; i<cset.m; i++)
      alphasum+=alpha[i]*cset.rhs[i];
    modellength=model_length_s(svmModel,kparm);
    dualitygap=(0.5*modellength*modellength+svmCnorm*(slacksum+n*ceps))
               -(alphasum-0.5*modellength*modellength);

    printf("Final epsilon on KKT-Conditions: %.5f\n",
           MAX(svmModel->maxdiff,epsilon));
    printf("Upper bound on duality gap: %.5f\n",dualitygap);
    printf("Dual objective value: dval=%.5f\n",
           alphasum-0.5*modellength*modellength);
    printf("Total number of constraints in final working set: %i (of %i)\n",
           (int)cset.m,(int)totconstraints);
    printf("Number of iterations: %d\n",numIt);
    printf("Number of calls to 'find_most_violated_constraint': %ld\n",
           argmax_count);
    if(sparm->slack_norm == 1) {
      printf("Number of SV: %ld \n",svmModel->sv_num-1);
      printf("Number of non-zero slack variables: %ld (out of %ld)\n",
             svmModel->at_upper_bound,n);
      printf("Norm of weight vector: |w|=%.5f\n",modellength);
    }
    else if(sparm->slack_norm == 2) {
      printf("Number of SV: %ld (including %ld at upper bound)\n",
             svmModel->sv_num-1,svmModel->at_upper_bound);
      printf("Norm of weight vector (including L2-loss): |w|=%.5f\n",
             modellength);
    }
    printf("Norm. sum of slack variables (on working set): sum(xi_i)/n=%.5f\n",
           slacksum/n);
    printf("Norm of longest difference vector: ||Psi(x,y)-Psi(x,ybar)||=%.5f\n",
           length_of_longest_document_vector(cset.lhs,cset.m,kparm));
    printf("Runtime in cpu-seconds: %.2f (%.2f%% for QP, %.2f%% for Argmax, %.2f%% for Psi, %.2f%% for init)\n",
           rt_total/100.0,(100.0*rt_opt)/rt_total,(100.0*rt_viol)/rt_total,
           (100.0*rt_psi)/rt_total,(100.0*rt_init)/rt_total);
  }

  if(svmModel) {
    sm->svm_model=copy_model(svmModel);
    sm->w=sm->svm_model->lin_weights; /* short cut to weight vector */
  }

  print_struct_learning_stats(sample,sm,cset,alpha,sparm);

  if(fycache) {
    for(i=0;i<n;i++)
      free_svector(fycache[i]);
    free(fycache);
  }
  if(svmModel)
    free_model(svmModel,0);
  free(alpha);
  free(alphahist);
  free(opti);
  free(cset.rhs);
  for(i=0;i<cset.m;i++)
    free_example(cset.lhs[i],1);
  free(cset.lhs);
}