// svm_learn.cpp
//////////////////////////////////////////////////////////////////////
#include "stdafx.h"
#include "svm.h"
#include "svm_learn.h"
#include "svm_hideo.h"
#ifdef _DEBUG
#undef THIS_FILE
static char THIS_FILE[]=__FILE__;
#define new DEBUG_NEW
#endif
//////////////////////////////////////////////////////////////////////
// Construction/Destruction
//////////////////////////////////////////////////////////////////////
/* interface to QP-solver */
double *optimize_qp(QP *, double *, long, double *, LEARN_PARM *);
/*---------------------------------------------------------------------------*/
/* Learns an SVM model based on the training data in docs/label. The resulting
model is returned in the structure model. */
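/* Usage sketch (illustrative only; load_problem() is a hypothetical helper
   standing in for whatever the caller uses to read docs/label and fill the
   parameter structs):

     DOC *docs; long *label,totwords,totdoc;
     LEARN_PARM learn_parm; KERNEL_PARM kernel_parm; MODEL model;
     load_problem(&docs,&label,&totwords,&totdoc,&learn_parm,&kernel_parm);
     svm_learn(docs,label,totdoc,totwords,&learn_parm,&kernel_parm,
               NULL,&model);   // a NULL cache suffices for LINEAR, see below
*/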
void svm_learn(
DOC *docs,
long *label,
long totdoc,
long totwords,
LEARN_PARM *learn_parm,
KERNEL_PARM *kernel_parm,
KERNEL_CACHE *kernel_cache,
MODEL *model
)
{
long *inconsistent,i;
long inconsistentnum;
long misclassified,upsupvecnum;
double loss,model_length,example_length;
double maxdiff,*lin,*a;
long runtime_start,runtime_end;
long iterations;
long *unlabeled,transduction;
long heldout;
long loo_count=0,loo_count_pos=0,loo_count_neg=0,trainpos=0,trainneg=0;
long loocomputed=0,runtime_start_loo=0,runtime_start_xa=0;
double heldout_c=0,r_delta_sq=0,r_delta,r_delta_avg;
double *xi_fullset; /* buffer for storing xi on full sample in loo */
double *a_fullset; /* buffer for storing alpha on full sample in loo */
TIMING timing_profile;
SHRINK_STATE shrink_state;
runtime_start=get_runtime();
timing_profile.time_kernel=0;
timing_profile.time_opti=0;
timing_profile.time_shrink=0;
timing_profile.time_update=0;
timing_profile.time_model=0;
timing_profile.time_check=0;
timing_profile.time_select=0;
com_result.kernel_cache_statistic=0;
learn_parm->totwords=totwords;
/* make sure -n value is reasonable */
if((learn_parm->svm_newvarsinqp < 2) || (learn_parm->svm_newvarsinqp > learn_parm->svm_maxqpsize))
{
learn_parm->svm_newvarsinqp=learn_parm->svm_maxqpsize;
}
init_shrink_state(&shrink_state,totdoc,(long)10000);
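/* assumption: the third argument (10000) is the maximum length of the
   per-example shrinking history, as in the standard SVM-light signature
   init_shrink_state(shrink_state,totdoc,maxhistory) */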
inconsistent = (long *)my_malloc(sizeof(long)*totdoc);
unlabeled = (long *)my_malloc(sizeof(long)*totdoc);
a = (double *)my_malloc(sizeof(double)*totdoc);
a_fullset = (double *)my_malloc(sizeof(double)*totdoc);
xi_fullset = (double *)my_malloc(sizeof(double)*totdoc);
lin = (double *)my_malloc(sizeof(double)*totdoc);
learn_parm->svm_cost = (double *)my_malloc(sizeof(double)*totdoc);
model->supvec = (DOC **)my_malloc(sizeof(DOC *)*(totdoc+2));
model->alpha = (double *)my_malloc(sizeof(double)*(totdoc+2));
model->index = (long *)my_malloc(sizeof(long)*(totdoc+2));
model->at_upper_bound=0;
model->b=0;
model->supvec[0]=0; /* element 0 reserved and empty for now */
model->alpha[0]=0;
model->lin_weights=NULL;
model->totwords=totwords;
model->totdoc=totdoc;
model->kernel_parm=(*kernel_parm);
model->sv_num=1;
model->loo_error=-1;
model->loo_recall=-1;
model->loo_precision=-1;
model->xa_error=-1;
model->xa_recall=-1;
model->xa_precision=-1;
inconsistentnum=0;
transduction=0;
r_delta=estimate_r_delta(docs,totdoc,kernel_parm);
r_delta_sq=r_delta*r_delta;
r_delta_avg=estimate_r_delta_average(docs,totdoc,kernel_parm);
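/* r_delta estimates the radius R of a sphere enclosing the training
   examples in feature space (r_delta_sq = R^2 feeds the xi-alpha
   leave-one-out criterion below); r_delta_avg is the corresponding
   average-norm estimate used to pick the default C */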
if(learn_parm->svm_c == 0.0)
{ /* default value for C */
learn_parm->svm_c=1.0/(r_delta_avg*r_delta_avg);
if (com_pro.show_compute_1)
{
sprintf(temstr,"Setting default regularization parameter C=%.4f\n",learn_parm->svm_c);
printm(temstr);
}
}
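/* The default C = 1/R_avg^2 puts the regularization on the scale of the
   squared example norms, so the default is unaffected by a uniform
   rescaling of the input vectors; e.g. an average norm of 10 yields
   C = 1/100 = 0.01. */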
for(i=0;i<totdoc;i++)
{ /* various inits */
inconsistent[i]=0;
a[i]=0;
lin[i]=0;
unlabeled[i]=0;
if(label[i] == 0)
{
unlabeled[i]=1;
transduction=1;
}
if(label[i] > 0)
{
learn_parm->svm_cost[i]=learn_parm->svm_c*learn_parm->svm_costratio*
fabs((double)label[i]);
label[i]=1;
trainpos++;
}
else if(label[i] < 0)
{
learn_parm->svm_cost[i]=learn_parm->svm_c*fabs((double)label[i]);
label[i]=-1;
trainneg++;
}
else
{
learn_parm->svm_cost[i]=0;
}
}
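/* Summary of the per-example upper bounds set above:
   label[i] > 0:  svm_cost[i] = C*costratio*|label[i]|  (cost factor for
                  unbalanced classes)
   label[i] < 0:  svm_cost[i] = C*|label[i]|
   label[i] == 0: svm_cost[i] = 0 and the example is marked unlabeled,
                  which turns on transductive training */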
/* caching makes no sense for linear kernel */
if(kernel_parm->kernel_type == LINEAR)
{
kernel_cache = NULL;
}
if(transduction)
{
learn_parm->svm_iter_to_shrink=99999999;
sprintf(temstr,"\nDeactivating Shrinking due to an incompatibility with the transductive \nlearner in the current version.\n\n");
printm(temstr);
}
if(transduction && learn_parm->compute_loo)
{
learn_parm->compute_loo=0;
sprintf(temstr,"\nCannot compute leave-one-out estimates for transductive learner.\n\n");
printm(temstr);
}
if(learn_parm->remove_inconsistent && learn_parm->compute_loo)
{
learn_parm->compute_loo=0;
sprintf(temstr,"\nCannot compute leave-one-out estimates when removing inconsistent examples.\n\n");
printm(temstr);
}
if((trainpos == 1) || (trainneg == 1))
{
learn_parm->compute_loo=0;
sprintf(temstr,"\nCannot compute leave-one-out with only one example in one class.\n\n");
printm(temstr);
}
if (com_pro.show_action)
{
sprintf(temstr,"Optimizing...");
printm(temstr);
}
/* train the svm */
iterations=optimize_to_convergence(docs,label,totdoc,totwords,learn_parm,
             kernel_parm,kernel_cache,&shrink_state,model,inconsistent,
             unlabeled,a,lin,&timing_profile,&maxdiff,(long)-1,(long)1);
if (com_pro.show_action)
{
sprintf(temstr,"done. (%ld iterations) ",iterations);
printm(temstr);
}
misclassified=0;
for(i=0;(i<totdoc);i++)
{ /* get final statistic */
if((lin[i]-model->b)*(double)label[i] <= 0.0)
misclassified++;
}
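/* lin[i] holds sum_j a[j]*label[j]*K(x_j,x_i), so the learned decision
   function is f(x_i) = lin[i] - model->b and example i counts as
   misclassified when label[i]*f(x_i) <= 0 */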
if (com_pro.show_action)
{
printm("optimization finished");
}
if (com_pro.show_trainresult)
{
sprintf(temstr," (%ld misclassified, maxdiff=%.5f).\n", misclassified,maxdiff);
printm(temstr);
}
com_result.train_misclassify=misclassified;
com_result.max_difference=maxdiff;
runtime_end=get_runtime();
if (learn_parm->remove_inconsistent)
{
inconsistentnum=0;
for(i=0;i<totdoc;i++)
if(inconsistent[i])
inconsistentnum++;
sprintf(temstr,"Number of SV: %ld (plus %ld inconsistent examples)\n", model->sv_num-1,inconsistentnum);
printm(temstr);
}
else
{
upsupvecnum=0;
for(i=1;i<model->sv_num;i++)
{
if(fabs(model->alpha[i]) >= (learn_parm->svm_cost[(model->supvec[i])->docnum]-learn_parm->epsilon_a))
upsupvecnum++;
}
if (com_pro.show_trainresult)
{
sprintf(temstr,"Number of SV: %ld (including %ld at upper bound)\n", model->sv_num-1,upsupvecnum);
printm(temstr);
}
}
if(!learn_parm->skip_final_opt_check)
{
loss=0;
model_length=0;
for(i=0;i<totdoc;i++)
{
if((lin[i]-model->b)*(double)label[i] < 1.0-learn_parm->epsilon_crit)
loss+=1.0-(lin[i]-model->b)*(double)label[i];
model_length+=a[i]*label[i]*lin[i];
}
model_length=sqrt(model_length);
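/* Why this is |w|: since lin[i] = sum_j a[j]*label[j]*K(x_j,x_i),
   sum_i a[i]*label[i]*lin[i] = sum_{i,j} a_i*a_j*y_i*y_j*K(x_i,x_j)
   = w*w, hence model_length = sqrt(w*w) = |w|. */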
sprintf(temstr,"L1 loss: loss=%.5f\n",loss); printm(temstr);
sprintf(temstr,"Norm of weight vector: |w|=%.5f\n",model_length);printm(temstr);
example_length=estimate_sphere(model,kernel_parm);
sprintf(temstr,"Norm of longest example vector: |x|=%.5f\n", length_of_longest_document_vector(docs,totdoc,kernel_parm));
printm(temstr);
sprintf(temstr,"Estimated VCdim of classifier: VCdim<=%.5f\n", estimate_margin_vcdim(model,model_length,example_length, kernel_parm));
printm(temstr);
if((!learn_parm->remove_inconsistent) && (!transduction))
{
runtime_start_xa=get_runtime();
sprintf(temstr,"Computing XiAlpha-estimates...");
printm(temstr);
compute_xa_estimates(model,label,unlabeled,totdoc,docs,lin,a,
kernel_parm,learn_parm,&(model->xa_error),
&(model->xa_recall),&(model->xa_precision));
sprintf(temstr,"Runtime for XiAlpha-estimates in cpu-seconds: %.2f\n",
(get_runtime()-runtime_start_xa)/100.0);
printm(temstr);
fprintf(stdout,"XiAlpha-estimate of the error: error<=%.2f%% (rho=%.2f,depth=%ld)\n",
model->xa_error,learn_parm->rho,learn_parm->xa_depth);
fprintf(stdout,"XiAlpha-estimate of the recall: recall=>%.2f%% (rho=%.2f,depth=%ld)\n",
model->xa_recall,learn_parm->rho,learn_parm->xa_depth);
fprintf(stdout,"XiAlpha-estimate of the precision: precision=>%.2f%% (rho=%.2f,depth=%ld)\n",
model->xa_precision,learn_parm->rho,learn_parm->xa_depth);
}
else if(!learn_parm->remove_inconsistent)
{
estimate_transduction_quality(model,label,unlabeled,totdoc,docs,lin);
}
}
if (com_pro.show_trainresult)
{
sprintf(temstr,"Number of kernel evaluations: %ld\n",com_result.kernel_cache_statistic);
printm(temstr);
}
/* leave-one-out testing starts now */
if(learn_parm->compute_loo)
{
/* save results of training on full dataset for leave-one-out */
runtime_start_loo=get_runtime();
for(i=0;i<totdoc;i++)
{
xi_fullset[i]=1.0-((lin[i]-model->b)*(double)label[i]);
a_fullset[i]=a[i];
}
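/* xi_fullset[i] = 1 - y_i*(lin[i]-b) is the slack of example i on the
   full sample: xi > 0 means the example violates the margin, xi > 1
   means it is misclassified */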
sprintf(temstr,"Computing leave-one-out");
printm(temstr);
/* repeat this loop for every held-out example */
for(heldout=0;(heldout<totdoc);heldout++)
{
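/* Three cases per held-out example (Joachims' xi-alpha argument):
   (1) rho*alpha_i*R^2 + xi_i < 1: predicted to cause no loo error,
       retraining can be skipped;
   (2) xi_i > 1: already misclassified on the full sample, counts as a
       loo error without retraining;
   (3) otherwise retrain with the example's upper bound forced to 0. */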
if(learn_parm->rho*a_fullset[heldout]*r_delta_sq+xi_fullset[heldout]
< 1.0)
{
/* guaranteed to not produce a leave-one-out error */
sprintf(temstr,"+");
printm(temstr);
}
else if(xi_fullset[heldout] > 1.0)
{
/* guaranteed to produce a leave-one-out error */
loo_count++;
if(label[heldout] > 0) loo_count_pos++; else loo_count_neg++;
sprintf(temstr,"-"); printm(temstr);
}
else
{
loocomputed++;
heldout_c=learn_parm->svm_cost[heldout]; /* set upper bound to zero */
learn_parm->svm_cost[heldout]=0;
/* make sure heldout example is not currently */
/* shrunk away. Assumes that lin is up to date! */
shrink_state.active[heldout]=1;
optimize_to_convergence(docs,label,totdoc,totwords,learn_parm,
kernel_parm,
kernel_cache,&shrink_state,model,inconsistent,unlabeled,
a,lin,&timing_profile,
&maxdiff,heldout,(long)2);
/* printf("%f\n",(lin[heldout]-model->b)*(double)label[heldout]); */
if(((lin[heldout]-model->b)*(double)label[heldout]) < 0.0)
{
loo_count++; /* there was a loo-error */
if(label[heldout] > 0) loo_count_pos++; else loo_count_neg++;
sprintf(temstr,"-"); printm(temstr); /* same progress marker as above */
}
else
{
sprintf(temstr,"+"); printm(temstr); /* no loo-error on this example */
}
/* now we need to restore the original data set*/
learn_parm->svm_cost[heldout]=heldout_c; /* restore upper bound */
}
} /* end of leave-one-out loop */
sprintf(temstr,"\nRetrain on full problem"); printm(temstr);
optimize_to_convergence(docs,label,totdoc,totwords,learn_parm,
kernel_parm,
kernel_cache,&shrink_state,model,inconsistent,unlabeled,
a,lin,&timing_profile,
&maxdiff,(long)-1,(long)1);
/* after all leave-one-out computed */
model->loo_error=100.0*loo_count/(double)totdoc;
model->loo_recall=(1.0-(double)loo_count_pos/(double)trainpos)*100.0;
model->loo_precision=(trainpos-loo_count_pos)/
(double)(trainpos-loo_count_pos+loo_count_neg)*100.0;
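/* With TP = trainpos-loo_count_pos (positives still correct when held
   out) and FP = loo_count_neg, these are recall = TP/trainpos and
   precision = TP/(TP+FP). */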
fprintf(stdout,"Leave-one-out estimate of the error: error=%.2f%%\n",
model->loo_error);
fprintf(stdout,"Leave-one-out estimate of the recall: recall=%.2f%%\n",
model->loo_recall);
fprintf(stdout,"Leave-one-out estimate of the precision: precision=%.2f%%\n",
model->loo_precision);
fprintf(stdout,"Actual leave-one-outs computed: %ld (rho=%.2f)\n",
loocomputed,learn_parm->rho);
fprintf(stdout,"Runtime for leave-one-out in cpu-seconds: %.2f\n",
(get_runtime()-runtime_start_loo)/100.0);
} /* end of leave-one-out computation */
/* free all training buffers; this closing mirrors the allocations at the
   top of svm_learn and the standard SVM-light svm_learn closing */
shrink_state_cleanup(&shrink_state);
free(inconsistent);
free(unlabeled);
free(a);
free(a_fullset);
free(xi_fullset);
free(lin);
free(learn_parm->svm_cost);
}