// svm_learn.cpp
fprintf(stdout,"Actual leave-one-outs computed: %ld (rho=%.2f)\n",
loocomputed,learn_parm->rho);
sprintf(temstr,"Runtime for leave-one-out in cpu-seconds: %.2f\n",
(double)(get_runtime()-runtime_start_loo)/100.0);
printm(temstr);
}
// if(learn_parm->alphafile[0])
// write_alphas(learn_parm->alphafile,a,label,totdoc);
shrink_state_cleanup(&shrink_state);
free(inconsistent);
free(unlabeled);
free(a);
free(a_fullset);
free(xi_fullset);
free(lin);
free(learn_parm->svm_cost);
}
long optimize_to_convergence(
DOC *docs,
long *label,
long totdoc,
long totwords,
LEARN_PARM *learn_parm,
KERNEL_PARM *kernel_parm,
KERNEL_CACHE *kernel_cache,
SHRINK_STATE *shrink_state,
MODEL *model,
long *inconsistent,
long *unlabeled,
double *a,
double *lin,
TIMING *timing_profile,
double *maxdiff,
long heldout,
long retrain)
{
long *chosen,*key,i,j,jj,*last_suboptimal_at,noshrink;
long inconsistentnum,choosenum,already_chosen=0,iteration;
long misclassified,supvecnum=0,*active2dnum,inactivenum;
long *working2dnum,*selexam;
long activenum;
double criterion,eq;
double *a_old;
long t0=0,t1=0,t2=0,t3=0,t4=0,t5=0,t6=0; /* timing */
long transductcycle;
long transduction;
double epsilon_crit_org;
double *selcrit; /* buffer for sorting */
CFLOAT *aicache; /* buffer to keep one row of hessian */
double *weights; /* buffer for weight vector in linear case */
QP qp; /* buffer for one quadratic program */
epsilon_crit_org=learn_parm->epsilon_crit; /* save org */
if(kernel_parm->kernel_type == LINEAR)
{
learn_parm->epsilon_crit=2.0;
kernel_cache=NULL; /* caching makes no sense for linear kernel */
}
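/* For the linear kernel the stopping precision is deliberately relaxed
   at first; it is tightened again below (epsilon_crit is halved each
   time the loose criterion is met) until epsilon_crit_org is reached. */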
learn_parm->epsilon_shrink=2;
(*maxdiff)=1;
learn_parm->totwords=totwords;
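/* Per-example bookkeeping: chosen counts how long an example has been in
   the working set, last_suboptimal_at records the last iteration in which
   it violated the optimality conditions (used for shrinking), and a_old
   keeps the previous alphas for the incremental gradient update. */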
chosen = (long *)my_malloc(sizeof(long)*totdoc);
last_suboptimal_at = (long *)my_malloc(sizeof(long)*totdoc);
key = (long *)my_malloc(sizeof(long)*(totdoc+11));
selcrit = (double *)my_malloc(sizeof(double)*totdoc);
selexam = (long *)my_malloc(sizeof(long)*totdoc);
a_old = (double *)my_malloc(sizeof(double)*totdoc);
aicache = (CFLOAT *)my_malloc(sizeof(CFLOAT)*totdoc);
working2dnum = (long *)my_malloc(sizeof(long)*(totdoc+11));
active2dnum = (long *)my_malloc(sizeof(long)*(totdoc+11));
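/* Buffers for one QP subproblem of at most svm_maxqpsize variables:
   linear part (opt_g0), Hessian (opt_g), equality constraint
   (opt_ce/opt_ce0), box constraints (opt_low/opt_up), and the
   starting point (opt_xinit). */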
qp.opt_ce = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
qp.opt_ce0 = (double *)my_malloc(sizeof(double));
qp.opt_g = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize
*learn_parm->svm_maxqpsize);
qp.opt_g0 = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
qp.opt_xinit = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
qp.opt_low=(double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
qp.opt_up=(double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
weights=(double *)my_malloc(sizeof(double)*(totwords+1));
choosenum=0;
inconsistentnum=0;
transductcycle=0;
transduction=0;
if(!retrain) retrain=1;
iteration=1;
if(kernel_cache)
{
kernel_cache->time=iteration; /* for lru cache */
kernel_cache_reset_lru(kernel_cache);
}
for(i=0;i<totdoc;i++)
{ /* various inits */
chosen[i]=0;
a_old[i]=a[i];
last_suboptimal_at[i]=1;
if(inconsistent[i])
inconsistentnum++;
if(unlabeled[i])
{
transduction=1;
}
}
activenum=compute_index(shrink_state->active,totdoc,active2dnum);
inactivenum=totdoc-activenum;
clear_index(working2dnum);
/* repeat this loop until we have convergence */
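/* Each iteration: (1) select a working set of at most svm_maxqpsize
   variables, (2) cache the kernel rows it needs, (3) solve the QP
   subproblem, (4) update the gradient lin[] incrementally, (5) check
   the optimality conditions, and (6) periodically shrink the problem. */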
for(;retrain && com_param.Running;iteration++)
{
if(kernel_cache)
kernel_cache->time=iteration; /* for lru cache */
i=0;
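/* Retire working-set members that have been optimized too many times in
   a row, as well as inconsistent and held-out examples; the remaining
   ones stay chosen and keep their slot in working2dnum. */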
for(jj=0;(j=working2dnum[jj])>=0;jj++)
{ /* clear working set */
if((chosen[j]>=(learn_parm->svm_maxqpsize/
minl(learn_parm->svm_maxqpsize,
learn_parm->svm_newvarsinqp)))
|| (inconsistent[j])
|| (j == heldout))
{
chosen[j]=0;
choosenum--;
}
else
{
chosen[j]++;
working2dnum[i++]=j;
}
}
working2dnum[i]=-1;
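/* retrain==2 signals a full restart (e.g. after the transductive label
   assignment changed): clear the working set, zero the alphas of
   inconsistent and held-out examples, and repair the equality
   constraint sum(a[i]*label[i])=0 before re-optimizing. */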
if(retrain == 2)
{
choosenum=0;
for(jj=0;(j=working2dnum[jj])>=0;jj++)
{ /* fully clear working set */
chosen[j]=0;
}
clear_index(working2dnum);
for(i=0;i<totdoc;i++)
{ /* set inconsistent examples to zero (-i 1) */
if((inconsistent[i] || (heldout==i)) && (a[i] != 0.0))
{
chosen[i]=99999;
choosenum++;
a[i]=0;
}
}
if(learn_parm->biased_hyperplane)
{
eq=0;
for(i=0;i<totdoc;i++)
{ /* make sure we fulfill equality constraint */
eq+=a[i]*label[i];
}
for(i=0;(i<totdoc) && (fabs(eq) > learn_parm->epsilon_a);i++)
{
if((eq*label[i] > 0) && (a[i] > 0))
{
chosen[i]=88888;
choosenum++;
if((eq*label[i]) > a[i])
{
eq-=(a[i]*label[i]);
a[i]=0;
}
else
{
a[i]-=(eq*label[i]);
eq=0;
}
}
}
}
compute_index(chosen,totdoc,working2dnum);
}
else
{ /* select working set according to steepest gradient */
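/* For non-linear kernels, up to half of the new variables are picked
   among examples whose kernel rows are already cached; the rest are
   selected by steepest feasible descent over all active examples. */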
if((minl(learn_parm->svm_newvarsinqp,learn_parm->svm_maxqpsize)>=4)
&& (kernel_parm->kernel_type != LINEAR))
{
/* select part of the working set from cache */
already_chosen=select_next_qp_subproblem_grad_cache(
label,unlabeled,a,lin,totdoc,
minl((long)(learn_parm->svm_maxqpsize-choosenum),
(long)(learn_parm->svm_newvarsinqp/2)),
learn_parm,inconsistent,active2dnum,
working2dnum,selcrit,selexam,kernel_cache,
key,chosen);
choosenum+=already_chosen;
}
choosenum+=select_next_qp_subproblem_grad(label,unlabeled,a,lin,totdoc,
minl((long)(learn_parm->svm_maxqpsize-choosenum),
(long)(learn_parm->svm_newvarsinqp-already_chosen)),
learn_parm,inconsistent,active2dnum,
working2dnum,selcrit,selexam,kernel_cache,key,
chosen);
}
// sprintf(temstr," %ld vectors chosen\n",choosenum); printm(temstr);
t1=get_runtime();
if(kernel_cache)
cache_multiple_kernel_rows(kernel_cache,docs,working2dnum,choosenum,kernel_parm);
t2=get_runtime();
if(retrain != 2)
{
optimize_svm(docs,label,unlabeled,chosen,active2dnum,model,totdoc,
working2dnum,choosenum,a,lin,learn_parm,aicache,
kernel_parm,&qp,&epsilon_crit_org);
}
t3=get_runtime();
update_linear_component(docs,label,active2dnum,a,a_old,working2dnum,totdoc,
totwords,kernel_parm,kernel_cache,lin,aicache,weights);
t4=get_runtime();
supvecnum=calculate_svm_model(docs,label,unlabeled,lin,a,a_old,learn_parm,
working2dnum,model);
t5=get_runtime();
/* The following computation of the objective function works only */
/* relative to the active variables */
criterion=compute_objective_function(a,lin,label,active2dnum);
// sprintf(temstr,"Objective function (over active variables): %.16f\n",criterion);
// printm(temstr);
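/* Remember the alphas of the working set; update_linear_component uses
   the difference between a[] and a_old[] to update lin[] incrementally. */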
for(jj=0;(i=working2dnum[jj])>=0;jj++)
{
a_old[i]=a[i];
}
if(retrain == 2)
{ /* reset inconsistent unlabeled examples */
for(i=0;(i<totdoc);i++)
{
if(inconsistent[i] && unlabeled[i])
{
inconsistent[i]=0;
label[i]=0;
}
}
}
retrain=check_optimality(model,label,unlabeled,a,lin,totdoc,learn_parm,
maxdiff,epsilon_crit_org,&misclassified,
inconsistent,active2dnum,last_suboptimal_at,
iteration,kernel_parm);
t6=get_runtime();
timing_profile->time_select+=t1-t0;
timing_profile->time_kernel+=t2-t1;
timing_profile->time_opti+=t3-t2;
timing_profile->time_update+=t4-t3;
timing_profile->time_model+=t5-t4;
timing_profile->time_check+=t6-t5;
noshrink=0;
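/* Converged on the active variables: wake up all shrunk examples and
   verify optimality on the full problem. If any inactive variable still
   violates the conditions, retrain is set and optimization continues. */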
if((!retrain) && (inactivenum>0)
&& ((!learn_parm->skip_final_opt_check)
|| (kernel_parm->kernel_type == LINEAR)))
{
if (com_pro.show_other)
{
sprintf(temstr," Checking optimality of inactive variables...");
printm(temstr);
}
t1=get_runtime();
reactivate_inactive_examples(label,unlabeled,a,shrink_state,lin,totdoc,
totwords,iteration,learn_parm,inconsistent,
docs,kernel_parm,kernel_cache,model,aicache,
weights,maxdiff);
/* Update to new active variables. */
activenum=compute_index(shrink_state->active,totdoc,active2dnum);
inactivenum=totdoc-activenum;
/* termination criterion */
noshrink=1;
retrain=0;
if((*maxdiff) > learn_parm->epsilon_crit)
retrain=1;
timing_profile->time_shrink+=get_runtime()-t1;
}
if((!retrain) && (learn_parm->epsilon_crit>(*maxdiff)))
learn_parm->epsilon_crit=(*maxdiff);
if((!retrain) && (learn_parm->epsilon_crit>epsilon_crit_org))
{
learn_parm->epsilon_crit/=2.0;
retrain=1;
noshrink=1;
}
if(learn_parm->epsilon_crit<epsilon_crit_org)
learn_parm->epsilon_crit=epsilon_crit_org;
// sprintf(temstr," => (%ld SV (incl. %ld SV at u-bound), max violation=%.5f)\n",
// supvecnum,model->at_upper_bound,(*maxdiff));
// printm(temstr);
if((!retrain) && (transduction))
{
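/* Transductive mode: after convergence, reactivate all examples and let
   incorporate_unlabeled_examples (re)assign labels to the unlabeled
   ones; training resumes until the assignment is stable. */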
for(i=0;(i<totdoc);i++)
{
shrink_state->active[i]=1;
}
activenum=compute_index(shrink_state->active,totdoc,active2dnum);
inactivenum=0;
retrain=incorporate_unlabeled_examples(model,label,inconsistent,
unlabeled,a,lin,totdoc,
selcrit,selexam,key,
transductcycle,kernel_parm,
learn_parm);
epsilon_crit_org=learn_parm->epsilon_crit;
if(kernel_parm->kernel_type == LINEAR)
learn_parm->epsilon_crit=1;
transductcycle++;
}
else if(((iteration % 10) == 0) && (!noshrink))
{
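/* Shrinking: deactivate examples whose alpha has been stuck at a bound
   without violating the optimality conditions for many iterations
   (tracked in last_suboptimal_at), but only if enough of them (at least
   maxl(activenum/10,100)) can be removed at once. */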
activenum=shrink_problem(learn_parm,shrink_state,active2dnum,iteration,last_suboptimal_at,
totdoc,maxl((long)(activenum/10),100),
a,inconsistent);
inactivenum=totdoc-activenum;
if((kernel_cache)
&& (supvecnum>kernel_cache->max_elems)
&& ((kernel_cache->activenum-activenum)>maxl((long)(activenum/10),500)))
{
kernel_cache_shrink(kernel_cache,totdoc,maxl((long)(activenum/10),500),
shrink_state->active);
}
}
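/* With -i set, examples that keep the problem from converging are
   flagged as inconsistent and removed from training: mode 1 by their
   alpha reaching the upper bound, modes 2 and 3 by misclassification
   (all at once vs. one at a time). */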
if((!retrain) && learn_parm->remove_inconsistent)
{
sprintf(temstr," Moving training errors to inconsistent examples...");
printm(temstr);
if(learn_parm->remove_inconsistent == 1)
{
retrain=identify_inconsistent(a,label,unlabeled,totdoc,learn_parm,
&inconsistentnum,inconsistent);
}
else if(learn_parm->remove_inconsistent == 2)
{
retrain=identify_misclassified(lin,label,unlabeled,totdoc,
model,&inconsistentnum,inconsistent);
}
else if(learn_parm->remove_inconsistent == 3)
{
/* Mode 3 in the SVM-light distribution calls identify_one_misclassified;
   restored here by analogy with the two parallel cases above. */
retrain=identify_one_misclassified(lin,label,unlabeled,totdoc,
model,&inconsistentnum,inconsistent);
}
}
/* ... remainder of optimize_to_convergence truncated in the source ... */