svm_struct_api.c
c.lhs=NULL;
c.rhs=NULL;
c.m=0;
}
else { /* add constraints so that all learned weights are
positive. WARNING: Currently, they are positive only up to
the precision epsilon set by -e. See the illustrative sketch
after this function. */
c.lhs=my_malloc(sizeof(DOC *)*sizePsi);
c.rhs=my_malloc(sizeof(double)*sizePsi);
for(i=0; i<sizePsi; i++) {
words[0].wnum=i+1;
words[0].weight=1.0;
words[1].wnum=0;
/* the following slackid is a hack; we will run into problems
if we have more than 1000000 slack sets (i.e. examples) */
c.lhs[i]=create_example(i,0,1000000+i,1,create_svector(words,"",1.0));
c.rhs[i]=0.0;
}
}
return(c);
}
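/* Illustrative sketch (added for exposition, not part of the original
file): how a single non-negativity constraint w_i >= 0 is encoded in
the loop above. The sparse vector has exactly one non-zero entry, so
the left-hand side of the constraint reduces to the i-th weight itself
and the right-hand side is 0. Only types and helpers already used in
this file appear; the function name is hypothetical. */
#if 0
static DOC *make_nonneg_constraint(long i)
{
WORD words[2];              /* one feature plus the terminating entry    */
words[0].wnum=i+1;          /* feature index of weight w_i               */
words[0].weight=1.0;        /* coefficient 1.0, so lhs evaluates to w_i  */
words[1].wnum=0;            /* wnum==0 terminates the sparse vector      */
/* slack id 1000000+i keeps these constraints out of the example slack
sets, assuming fewer than 1000000 examples (see the warning above) */
return(create_example(i,0,1000000+i,1,create_svector(words,"",1.0)));
}
#endif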
LABEL classify_struct_example(PATTERN x, STRUCTMODEL *sm,
STRUCT_LEARN_PARM *sparm)
{
/* Finds the label yhat for pattern x that scores the highest
according to the linear evaluation function in sm, especially the
weights sm.w. The returned label is taken as the prediction of sm
for the pattern x. The weights correspond to the features defined
by psi() and range from index 1 to index sm->sizePsi. If the
function cannot find a label, it shall return an empty label as
recognized by the function empty_label(y). */
LABEL y;
int i;
y.totdoc=x.totdoc;
y.class=(double *)my_malloc(sizeof(double)*y.totdoc);
/* simply classify by sign of inner product between example vector
and weight vector (see the sketch after this function) */
for(i=0;i<x.totdoc;i++) {
y.class[i]=classify_example(sm->svm_model,x.doc[i]);
}
return(y);
}
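/* Illustrative sketch (added for exposition, not part of the original
file): the loop above stores the raw decision value w*x_i for each
document in y.class[i]; per the comment above, the actual +1/-1
prediction is the sign of that value. A minimal sketch of turning the
stored scores into binary labels -- the helper name is hypothetical. */
#if 0
static void scores_to_signs(LABEL *y)
{
int i;
for(i=0;i<y->totdoc;i++)
y->class[i]=(y->class[i] >= 0) ? 1.0 : -1.0;  /* sign of w*x_i */
}
#endif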
LABEL find_most_violated_constraint_slackrescaling(PATTERN x, LABEL y,
STRUCTMODEL *sm,
STRUCT_LEARN_PARM *sparm)
{
/* Finds the label ybar for pattern x that is responsible for
the most violated constraint under the slack-rescaling
formulation. It has to take into account the scoring function in
sm, especially the weights sm.w, as well as the loss
function. The weights in sm.w correspond to the features defined
by psi() and range from index 1 to index sm->sizePsi. The
simplest case is the zero/one loss function: for the zero/one
loss, this function should return the highest scoring label ybar
if ybar is unequal to y; if it is equal to the correct label y,
then the function shall return the second highest scoring label.
If the function cannot find a label, it shall return an empty
label as recognized by the function empty_label(y). (See the
sketch of the maximized quantity after this function.) */
LABEL ybar;
if((sparm->loss_function == ZEROONE)
|| (sparm->loss_function == FONE)
|| (sparm->loss_function == ERRORRATE)
|| (sparm->loss_function == PRBEP)
|| (sparm->loss_function == PREC_K)
|| (sparm->loss_function == REC_K)) {
ybar=find_most_violated_constraint_thresholdmetric(x,y,sm,sparm,
sparm->loss_type);
}
else if((sparm->loss_function == SWAPPEDPAIRS)) {
printf("ERROR: Slack-rescaling is not implemented for this loss function!\n");
exit(1);
}
else {
printf("ERROR: Unknown loss function '%d'.\n",sparm->loss_function);
exit(1);
}
return(ybar);
}
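/* Illustrative sketch (added for exposition, not part of the original
file): under slack rescaling the quantity being maximized -- matching
the debug output printed by the threshold-metric finder below -- is
loss(y,ybar) + loss(y,ybar)*[ w*Psi(x,ybar) - w*Psi(x,y) ].
A brute-force sketch over an explicit candidate set; the candidate
array and the helper name are hypothetical, psi(), classify_example()
and the example helpers are the ones already used in this file, and
loss() is assumed to be the svm_struct API loss function declared
elsewhere in this file. */
#if 0
static LABEL argmax_slack_rescaling(PATTERN x, LABEL y, LABEL *cand,
int ncand, STRUCTMODEL *sm,
STRUCT_LEARN_PARM *sparm)
{
int i,best=0;
double val,bestval=0,wpsi_y,wpsi_ybar;
DOC *ex;
ex=create_example(0,0,1,1,psi(x,y,sm,sparm));
wpsi_y=classify_example(sm->svm_model,ex);      /* w*Psi(x,y)    */
free_example(ex,1);
for(i=0;i<ncand;i++) {
ex=create_example(0,0,1,1,psi(x,cand[i],sm,sparm));
wpsi_ybar=classify_example(sm->svm_model,ex);   /* w*Psi(x,ybar) */
free_example(ex,1);
val=loss(y,cand[i],sparm)*(1.0+wpsi_ybar-wpsi_y);
if((i==0) || (val > bestval)) {
bestval=val;
best=i;
}
}
return(cand[best]);
}
#endif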
LABEL find_most_violated_constraint_marginrescaling(PATTERN x, LABEL y,
STRUCTMODEL *sm,
STRUCT_LEARN_PARM *sparm)
{
/* Finds the label ybar for pattern x that is responsible for
the most violated constraint under the margin-rescaling
formulation. It has to take into account the scoring function in
sm, especially the weights sm.w, as well as the loss
function. The weights in sm.w correspond to the features defined
by psi() and range from index 1 to index sm->sizePsi. The
simplest case is the zero/one loss function: for the zero/one
loss, this function should return the highest scoring label ybar
if ybar is unequal to y; if it is equal to the correct label y,
then the function shall return the second highest scoring label.
If the function cannot find a label, it shall return an empty
label as recognized by the function empty_label(y). (See the
note on the maximized quantity after this function.) */
LABEL ybar;
if(sparm->loss_function == ERRORRATE)
ybar=find_most_violated_constraint_errorrate(x,y,sm,sparm,
sparm->loss_type);
/* ybar=find_most_violated_constraint_thresholdmetric(x,y,sm,sparm,
sparm->loss_type); */
else if((sparm->loss_function == ZEROONE)
|| (sparm->loss_function == FONE)
|| (sparm->loss_function == PRBEP)
|| (sparm->loss_function == PREC_K)
|| (sparm->loss_function == REC_K))
ybar=find_most_violated_constraint_thresholdmetric(x,y,sm,sparm,
sparm->loss_type);
else if((sparm->loss_function == SWAPPEDPAIRS))
ybar=find_most_violated_constraint_rankmetric(x,y,sm,sparm,
sparm->loss_type);
else if((sparm->loss_function == AVGPREC))
ybar=find_most_violated_constraint_avgprec(x,y,sm,sparm,
sparm->loss_type);
else {
printf("ERROR: Unknown loss function '%d'.\n",sparm->loss_function);
exit(1);
}
return(ybar);
}
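/* Illustrative note (added for exposition, not part of the original
file): under margin rescaling the quantity being maximized -- matching
the debug output "max_ybar {loss(y_i,ybar)+w*Psi(x,ybar)}" printed by
the finders below -- is
loss(y,ybar) + w*Psi(x,ybar),
i.e. the loss enters additively instead of multiplying the margin term
as in slack rescaling; w*Psi(x,y) is constant in ybar and can be
dropped. In the brute-force sketch given after
find_most_violated_constraint_slackrescaling above, only the scoring
line would change: */
#if 0
val=loss(y,cand[i],sparm)+wpsi_ybar;   /* margin rescaling */
#endif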
LABEL find_most_violated_constraint_errorrate(PATTERN x, LABEL y,
STRUCTMODEL *sm,
STRUCT_LEARN_PARM *sparm,
int loss_type)
{
/* Finds the most violated constraint for the errorrate loss under
margin rescaling only! (See the illustration of the per-example
flip rule after this function.) */
LABEL ybar;
int i,totwords;
double *score,valmax=0,diff,loss_step,*ortho_weights;
MODEL svm_model;
ybar.totdoc=x.totdoc;
ybar.class=(double *)my_malloc(sizeof(double)*x.totdoc);
score=(double *)my_malloc(sizeof(double)*(ybar.totdoc+1));
totwords=sm->svm_model->totwords;
svm_model=(*sm->svm_model);
/* For sparse kernel, replace weight vector with beta=gamma^T*L^-1 */
if(sm->sparse_kernel_type>0) {
svm_model.lin_weights=(double *)my_malloc(sizeof(double)*(totwords+1));
ortho_weights=prod_nvector_ltmatrix(sm->svm_model->lin_weights+1,sm->invL);
for(i=0;i<sm->invL->m;i++)
svm_model.lin_weights[i+1]=ortho_weights[i];
svm_model.lin_weights[0]=0;
free_nvector(ortho_weights);
}
loss_step=100.0/x.totdoc;
for(i=0;i<x.totdoc;i++) {
score[i]=classify_example(&svm_model,x.doc[i]);
diff=loss_step-2.0*y.class[i]*score[i];
if(diff > 0) {
ybar.class[i]=-y.class[i];
valmax+=loss_step;
}
else {
ybar.class[i]=y.class[i];
}
valmax+=ybar.class[i]*score[i];
}
/* Restore weight vector that was modified above */
if(sm->sparse_kernel_type>0) {
free(svm_model.lin_weights);
}
if(struct_verbosity >= 2) {
printf("\n max_ybar {loss(y_i,ybar)+w*Psi(x,ybar)}=%f\n",valmax);
SVECTOR *fy=psi(x,y,sm,sparm);
SVECTOR *fybar=psi(x,ybar,sm,sparm);
DOC *exy=create_example(0,0,1,1,fy);
DOC *exybar=create_example(0,0,1,1,fybar);
printf(" -> w*Psi(x,y_i)=%f, w*Psi(x,ybar)=%f\n",
classify_example(sm->svm_model,exy),
classify_example(sm->svm_model,exybar));
free_example(exy,1);
free_example(exybar,1);
}
free(score);
return(ybar);
}
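/* Illustrative sketch (added for exposition, not part of the original
file): with the errorrate loss (scaled to 0..100) under margin
rescaling, each example contributes independently to
loss(y,ybar)+w*Psi(x,ybar). Flipping example i adds
loss_step=100/totdoc to the loss and changes its score contribution
from +y_i*score_i to -y_i*score_i, a net gain of
100/totdoc - 2*y_i*score_i; the loop above flips exactly when that
gain is positive. A minimal sketch of the same rule for one example
(the helper name is hypothetical): */
#if 0
static double flip_or_keep(double y_i, double score_i, long totdoc)
{
double gain=100.0/totdoc - 2.0*y_i*score_i; /* objective gain from flipping */
return((gain > 0) ? -y_i : y_i);            /* flip iff the gain is positive */
}
#endif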
LABEL find_most_violated_constraint_thresholdmetric(PATTERN x, LABEL y,
STRUCTMODEL *sm,
STRUCT_LEARN_PARM *sparm,
int loss_type)
{
/* Finds the most violated constraint for metrics that are based on
a threshold. (See the sketch of the threshold sweep after this
function.) */
LABEL ybar;
int i,nump,numn,start,prec_rec_k,totwords;
double *score,*sump,*sumn;
STRUCT_ID_SCORE *scorep,*scoren;
int threshp=0,threshn=0;
int a,d;
double val,valmax,loss,score_y,*ortho_weights;
MODEL svm_model;
ybar.totdoc=x.totdoc;
ybar.class=(double *)my_malloc(sizeof(double)*x.totdoc);
score=(double *)my_malloc(sizeof(double)*(ybar.totdoc+1));
scorep=(STRUCT_ID_SCORE *)my_malloc(sizeof(STRUCT_ID_SCORE)*(ybar.totdoc+1));
scoren=(STRUCT_ID_SCORE *)my_malloc(sizeof(STRUCT_ID_SCORE)*(ybar.totdoc+1));
sump=(double *)my_malloc(sizeof(double)*(ybar.totdoc+1));
sumn=(double *)my_malloc(sizeof(double)*(ybar.totdoc+1));
totwords=sm->svm_model->totwords;
svm_model=(*sm->svm_model);
/* For sparse kernel, replace weight vector with beta=gamma^T*L^-1 */
if(sm->sparse_kernel_type>0) {
svm_model.lin_weights=(double *)my_malloc(sizeof(double)*(totwords+1));
ortho_weights=prod_nvector_ltmatrix(sm->svm_model->lin_weights+1,sm->invL);
for(i=0;i<sm->invL->m;i++)
svm_model.lin_weights[i+1]=ortho_weights[i];
svm_model.lin_weights[0]=0;
free_nvector(ortho_weights);
}
nump=0;
numn=0;
for(i=0;i<x.totdoc;i++) {
score[i]=fabs(y.class[i])*classify_example(&svm_model,x.doc[i]);
if(y.class[i] > 0) {
scorep[nump].score=score[i];
scorep[nump].tiebreak=0;
scorep[nump].id=i;
nump++;
}
else {
scoren[numn].score=score[i];
scoren[numn].tiebreak=0;
scoren[numn].id=i;
numn++;
}
}
/* Restore weight vector that was modified above */
if(sm->sparse_kernel_type>0) {
free(svm_model.lin_weights);
}
/* compute score of target label */
score_y=0;
if(loss_type==SLACK_RESCALING) {
for(i=0;i<x.totdoc;i++)
/* score_y+=y.class[i]*score[i]; */
score_y+=score[i];
}
if(nump)
qsort(scorep,nump,sizeof(STRUCT_ID_SCORE),comparedown);
sump[0]=0;
for(i=0;i<nump;i++) {
sump[i+1]=sump[i]+scorep[i].score;
}
if(numn)
qsort(scoren,numn,sizeof(STRUCT_ID_SCORE),compareup);
sumn[0]=0;
for(i=0;i<numn;i++) {
sumn[i+1]=sumn[i]+scoren[i].score;
}
/* find max of loss(ybar,y)+score(ybar) for margin rescaling or max
of loss(ybar,y)+loss*(score(ybar)-score(y)) for slack
rescaling */
valmax=0;
start=1;
prec_rec_k=(int)(nump*sparm->prec_rec_k_frac);
if(prec_rec_k<1) prec_rec_k=1;
for(a=0;a<=nump;a++) {
for(d=0;d<=numn;d++) {
if(sparm->loss_function == ZEROONE)
loss=zeroone_loss(a,numn-d,nump-a,d);
else if(sparm->loss_function == FONE)
loss=fone_loss(a,numn-d,nump-a,d);
else if(sparm->loss_function == ERRORRATE)
loss=errorrate_loss(a,numn-d,nump-a,d);
else if((sparm->loss_function == PRBEP) && (a+numn-d == nump))
loss=prbep_loss(a,numn-d,nump-a,d);
else if((sparm->loss_function == PREC_K) && (a+numn-d >= prec_rec_k))
loss=prec_k_loss(a,numn-d,nump-a,d);
else if((sparm->loss_function == REC_K) && (a+numn-d <= prec_rec_k))
loss=rec_k_loss(a,numn-d,nump-a,d);
else {
loss=0;
}
if(loss > 0) {
if(loss_type==SLACK_RESCALING) {
val=loss+loss*(sump[a]-(sump[nump]-sump[a])-sumn[d]+(sumn[numn]-sumn[d] - score_y));
}
else if(loss_type==MARGIN_RESCALING) {
val=loss+sump[a]-(sump[nump]-sump[a])-sumn[d]+(sumn[numn]-sumn[d]);
}
else {
printf("ERROR: Unknown loss type '%d'.\n",loss_type);
exit(1);
}
if((val > valmax) || (start)) {
start=0;
valmax=val;
threshp=a;
threshn=d;
}
}
}
}
/* assign labels that maximize score */
/* for(i=0;i<nump;i++) {
if(i<threshp)
ybar.class[scorep[i].id]=1;
else
ybar.class[scorep[i].id]=-1;
}
for(i=0;i<numn;i++) {
if(i<threshn)
ybar.class[scoren[i].id]=-1;
else
ybar.class[scoren[i].id]=1;
} */
for(i=0;i<nump;i++) {
if(i<threshp)
ybar.class[scorep[i].id]=y.class[scorep[i].id];
else
ybar.class[scorep[i].id]=-y.class[scorep[i].id];
}
for(i=0;i<numn;i++) {
if(i<threshn)
ybar.class[scoren[i].id]=y.class[scoren[i].id];
else
ybar.class[scoren[i].id]=-y.class[scoren[i].id];
}
if(struct_verbosity >= 2) {
if(loss_type==SLACK_RESCALING)
printf("\n max_ybar {loss(y_i,ybar)+loss(y_i,ybar)[w*Psi(x,ybar)-w*Psi(x,y)]}=%f\n",valmax);
else
printf("\n max_ybar {loss(y_i,ybar)+w*Psi(x,ybar)}=%f\n",valmax);
SVECTOR *fy=psi(x,y,sm,sparm);
SVECTOR *fybar=psi(x,ybar,sm,sparm);
DOC *exy=create_example(0,0,1,1,fy);
DOC *exybar=create_example(0,0,1,1,fybar);
printf(" -> w*Psi(x,y_i)=%f, w*Psi(x,ybar)=%f\n",
classify_example(sm->svm_model,exy),
classify_example(sm->svm_model,exybar));
free_example(exy,1);
free_example(exybar,1);
}
free(score);
free(scorep);
free(scoren);
free(sump);
free(sumn);
return(ybar);
}
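/* Illustrative sketch (added for exposition, not part of the original
file): the double loop over (a,d) above sweeps all labelings ybar in
which the a highest-scoring positives and the d lowest-scoring
negatives keep their correct label and every remaining document is
flipped. Judging from the label assignment at the end of the function,
the counts passed to the loss helpers are then
a        positives predicted positive (true positives),
numn-d   negatives predicted positive (false positives),
nump-a   positives predicted negative (false negatives),
d        negatives predicted negative (true negatives),
and with sump[]/sumn[] the prefix sums of the sorted scores,
w*Psi(x,ybar) telescopes into the expression used for 'val'. A sketch
of the margin-rescaling objective for one (a,d) pair (the helper name
is hypothetical): */
#if 0
static double margin_rescaling_val(double loss,double *sump,long nump,long a,
double *sumn,long numn,long d)
{
double wpsi_ybar;
wpsi_ybar= sump[a]-(sump[nump]-sump[a])   /* kept (+) minus flipped (-) positives */
-sumn[d]+(sumn[numn]-sumn[d]);            /* kept (-) and flipped (+) negatives   */
return(loss+wpsi_ybar);                   /* loss(y,ybar) + w*Psi(x,ybar)         */
}
#endif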
LABEL find_most_violated_constraint_rankmetric(PATTERN x, LABEL y,
STRUCTMODEL *sm,
STRUCT_LEARN_PARM *sparm,
int loss_type)
{
/* Finds the most violated constraint for metrics that are based on
rankings (here the number of swapped pairs, i.e. ROC area).
WARNING: Currently only implemented for margin rescaling! */
LABEL ybar;
long i,nump,numn,sump,sumn;
double *score,*ortho_weights;
STRUCT_ID_SCORE *scorep,*scoren,*predset;
int totwords;
MODEL svm_model;
ybar.totdoc=x.totdoc;
ybar.class=(double *)my_malloc(sizeof(double)*x.totdoc);
score=(double *)my_malloc(sizeof(double)*(ybar.totdoc+1));
scorep=(STRUCT_ID_SCORE *)my_malloc(sizeof(STRUCT_ID_SCORE)*(ybar.totdoc+1));
scoren=(STRUCT_ID_SCORE *)my_malloc(sizeof(STRUCT_ID_SCORE)*(ybar.totdoc+1));
predset=(STRUCT_ID_SCORE *)my_malloc(sizeof(STRUCT_ID_SCORE)*(ybar.totdoc+1));
totwords=sm->svm_model->totwords;
svm_model=(*sm->svm_model);
/* For sparse kernel, replace weight vector with beta=gamma^T*L^-1 */
if(sm->sparse_kernel_type>0) {
svm_model.lin_weights=(double *)my_malloc(sizeof(double)*(totwords+1));
ortho_weights=prod_nvector_ltmatrix(sm->svm_model->lin_weights+1,sm->invL);
for(i=0;i<sm->invL->m;i++)
svm_model.lin_weights[i+1]=ortho_weights[i];
svm_model.lin_weights[0]=0;
free_nvector(ortho_weights);
}
nump=0;
numn=0;
for(i=0;i<x.totdoc;i++) {
score[i]=classify_example(&svm_model,x.doc[i]);
if(y.class[i] > 0) {
scorep[nump].score=score[i];
scorep[nump].tiebreak=0;
scorep[nump].id=i;
nump++;
}
else {
scoren[numn].score=score[i];
scoren[numn].tiebreak=0;
scoren[numn].id=i;
numn++;
}
}
if(nump)
qsort(scorep,nump,sizeof(STRUCT_ID_SCORE),comparedown);
if(numn)
qsort(scoren,numn,sizeof(STRUCT_ID_SCORE),comparedown);
/* Restore weight vector that was modified above */
if(sm->sparse_kernel_type>0) {
free(svm_model.lin_weights);
}
/* find max of loss(ybar,y)+score(ybar) */
if(sparm->loss_function == SWAPPEDPAIRS) { /* number of swapped pairs (i.e. ROC area) */
for(i=0;i<nump;i++) {
predset[i]=scorep[i];
predset[i].score-=(0.5);
}
for(i=0;i<numn;i++) {
predset[nump+i]=scoren[i];
predset[nump+i].score+=(0.5);
}
qsort(predset,nump+numn,sizeof(STRUCT_ID_SCORE),comparedown);
sump=0;
sumn=0;
for(i=0;i<numn+nump;i++) {
if(y.class[predset[i].id] > 0) {