/* svm.cpp */
        {
            learn_parm->svm_cost[i]=0;
        }
    }
    
    /* caching makes no sense for linear kernel */
    if(kernel_parm->kernel_type == LINEAR)
    {
        kernel_cache = NULL;   
    } 
    
    if(transduction) 
    {
        learn_parm->svm_iter_to_shrink=99999999;
        sprintf(temstr,"\nDeactivating Shrinking due to an incompatibility with the transductive \nlearner in the current version.\n\n");
        printm(temstr);
    }
    
    if(transduction && learn_parm->compute_loo) 
    {
        learn_parm->compute_loo=0;
        sprintf(temstr,"\nCannot compute leave-one-out estimates for transductive learner.\n\n");
        printm(temstr);
    }
    
    if(learn_parm->remove_inconsistent && learn_parm->compute_loo) 
    {
        learn_parm->compute_loo=0;
        sprintf(temstr,"\nCannot compute leave-one-out estimates when removing inconsistent examples.\n\n");
        printm(temstr);
    }    
    
    if((trainpos == 1) || (trainneg == 1)) 
    {
        learn_parm->compute_loo=0;
        sprintf(temstr,"\nCannot compute leave-one-out with only one example in one class.\n\n");
        printm(temstr);
    }    
    
    if (com_pro.show_action)
    {
        sprintf(temstr,"Optimizing...");
        printm(temstr);
    }

    /* train the svm */
    iterations=optimize_to_convergence(docs,label,totdoc,totwords,learn_parm,
        kernel_parm,kernel_cache,&shrink_state,model,inconsistent,unlabeled,
        a,lin,&timing_profile,&maxdiff,(long)-1,(long)1);
    if (com_pro.show_action)
    {
        sprintf(temstr,"done. (%ld iterations) ",iterations);
        printm(temstr);
    }
    
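    /* lin[i] caches the uncentered output sum_j a[j]*label[j]*K(x_j,x_i), so
       (lin[i]-model->b) is the decision value f(x_i); an example counts as
       misclassified when label[i]*f(x_i) <= 0. */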
    misclassified=0;
    for(i=0;(i<totdoc);i++)
    { /* get final statistic */
        if((lin[i]-model->b)*(double)label[i] <= 0.0) 
            misclassified++;
    }
    if (com_pro.show_action)
    {
        printm("optimization finished");
    }
    if (com_pro.show_trainresult)
    {
        sprintf(temstr," (%ld misclassified, maxdiff=%.5f).\n", misclassified,maxdiff);
        printm(temstr);
    }
    com_result.train_misclassify=misclassified;
    com_result.max_difference=maxdiff;
             
    runtime_end=get_runtime();
             
    if (learn_parm->remove_inconsistent)
    {     
        inconsistentnum=0;
        for(i=0;i<totdoc;i++) 
            if(inconsistent[i]) 
               inconsistentnum++;
        sprintf(temstr,"Number of SV: %ld (plus %ld inconsistent examples)\n", model->sv_num-1,inconsistentnum);
        printm(temstr);
    }
    
    else
    {
        upsupvecnum=0;
        for(i=1;i<model->sv_num;i++)
        {
            if(fabs(model->alpha[i]) >= (learn_parm->svm_cost[(model->supvec[i])->docnum]-learn_parm->epsilon_a))
                upsupvecnum++;
        }
        if (com_pro.show_trainresult)
        {
            sprintf(temstr,"Number of SV: %ld (including %ld at upper bound)\n", model->sv_num-1,upsupvecnum);
            printm(temstr);
        }
    }
	
    if(!learn_parm->skip_final_opt_check)
    {
        loss=0;
        model_length=0;
        for(i=0;i<totdoc;i++)
        {
            if((lin[i]-model->b)*(double)label[i] < 1.0-learn_parm->epsilon_crit)
                loss+=1.0-(lin[i]-model->b)*(double)label[i];
            model_length+=a[i]*label[i]*lin[i];
        }
        model_length=sqrt(model_length);
        sprintf(temstr,"L1 loss: loss=%.5f\n",loss);
        printm(temstr);
        sprintf(temstr,"Norm of weight vector: |w|=%.5f\n",model_length);
        printm(temstr);
        example_length=estimate_sphere(model,kernel_parm);
        sprintf(temstr,"Norm of longest example vector: |x|=%.5f\n",
            length_of_longest_document_vector(docs,totdoc,kernel_parm));
        printm(temstr);
        sprintf(temstr,"Estimated VCdim of classifier: VCdim<=%.5f\n",
            estimate_margin_vcdim(model,model_length,example_length,kernel_parm));
        printm(temstr);
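        /* XiAlpha-estimates (cf. Joachims' xi/alpha bounds): approximate
           leave-one-out error, recall and precision computed directly from the
           alphas and slacks of the single training run above, so no
           retraining is needed. */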
        if((!learn_parm->remove_inconsistent) && (!transduction))
        {
            runtime_start_xa=get_runtime();
            sprintf(temstr,"Computing XiAlpha-estimates...");
            printm(temstr);
            compute_xa_estimates(model,label,unlabeled,totdoc,docs,lin,a,
                kernel_parm,learn_parm,&(model->xa_error),
                &(model->xa_recall),&(model->xa_precision));

            sprintf(temstr,"Runtime for XiAlpha-estimates in cpu-seconds: %.2f\n",
                (get_runtime()-runtime_start_xa)/100.0);
            printm(temstr);

            fprintf(stdout,"XiAlpha-estimate of the error: error<=%.2f%% (rho=%.2f,depth=%ld)\n",
                model->xa_error,learn_parm->rho,learn_parm->xa_depth);
            fprintf(stdout,"XiAlpha-estimate of the recall: recall=>%.2f%% (rho=%.2f,depth=%ld)\n",
                model->xa_recall,learn_parm->rho,learn_parm->xa_depth);
            fprintf(stdout,"XiAlpha-estimate of the precision: precision=>%.2f%% (rho=%.2f,depth=%ld)\n",
                model->xa_precision,learn_parm->rho,learn_parm->xa_depth);
        }
        else if(!learn_parm->remove_inconsistent)
        {
            estimate_transduction_quality(model,label,unlabeled,totdoc,docs,lin);
        }
    }
    if (com_pro.show_trainresult)
    {
        sprintf(temstr,"Number of kernel evaluations: %ld\n",com_result.kernel_cache_statistic);
        printm(temstr);
    }
    /* leave-one-out testing starts now */
    if(learn_parm->compute_loo)
    {
        /* save results of training on full dataset for leave-one-out */
        runtime_start_loo=get_runtime();
        for(i=0;i<totdoc;i++)
        {
            xi_fullset[i]=1.0-((lin[i]-model->b)*(double)label[i]);
            a_fullset[i]=a[i];
        }
        sprintf(temstr,"Computing leave-one-out");
        printm(temstr);
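        /* Each held-out example falls into one of three cases below: if
           rho*a_i*r_delta_sq + xi_i < 1 the xi/alpha criterion guarantees it
           cannot become a leave-one-out error (no retraining needed); if
           xi_i > 1 it is guaranteed to be an error; only the remaining
           examples require an actual retraining pass with their cost bound
           temporarily forced to zero. */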
                 
        /* repeat this loop for every held-out example */
        for(heldout=0;(heldout<totdoc);heldout++)
        {
            if(learn_parm->rho*a_fullset[heldout]*r_delta_sq+xi_fullset[heldout]
                < 1.0)
            {
                /* guaranteed to not produce a leave-one-out error */
                sprintf(temstr,"+");
                printm(temstr);
            }
            else if(xi_fullset[heldout] > 1.0)
            {
                /* guaranteed to produce a leave-one-out error */
                loo_count++;
                if(label[heldout] > 0) loo_count_pos++; else loo_count_neg++;
                sprintf(temstr,"-");
                printm(temstr);
            }
            else
            {
                loocomputed++;
                heldout_c=learn_parm->svm_cost[heldout]; /* set upper bound to zero */
                learn_parm->svm_cost[heldout]=0;
                /* make sure heldout example is not currently */
                /* shrunk away. Assumes that lin is up to date! */
                shrink_state.active[heldout]=1;

                optimize_to_convergence(docs,label,totdoc,totwords,learn_parm,
                    kernel_parm,kernel_cache,&shrink_state,model,inconsistent,
                    unlabeled,a,lin,&timing_profile,&maxdiff,heldout,(long)2);

                /* printf("%f\n",(lin[heldout]-model->b)*(double)label[heldout]); */

                if(((lin[heldout]-model->b)*(double)label[heldout]) < 0.0)
                {
                    loo_count++;                           /* there was a loo-error */
                    if(label[heldout] > 0) loo_count_pos++; else loo_count_neg++;
                }
                /* now we need to restore the original data set */
                learn_parm->svm_cost[heldout]=heldout_c; /* restore upper bound */
            }
        } /* end of leave-one-out loop */
                 
                 
        sprintf(temstr,"\nRetrain on full problem");
        printm(temstr);
        optimize_to_convergence(docs,label,totdoc,totwords,learn_parm,
            kernel_parm,kernel_cache,&shrink_state,model,inconsistent,
            unlabeled,a,lin,&timing_profile,&maxdiff,(long)-1,(long)1);
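        /* Aggregate the counts: loo error is the fraction of leave-one-out
           errors over all examples, recall treats loo errors on positives as
           missed positives, and precision counts loo errors on negatives as
           false positives. */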
                 
                 
        /* after all leave-one-out computed */
        model->loo_error=100.0*loo_count/(double)totdoc;
        model->loo_recall=(1.0-(double)loo_count_pos/(double)trainpos)*100.0;
        model->loo_precision=(trainpos-loo_count_pos)/
            (double)(trainpos-loo_count_pos+loo_count_neg)*100.0;
        fprintf(stdout,"Leave-one-out estimate of the error: error=%.2f%%\n",
            model->loo_error);
        fprintf(stdout,"Leave-one-out estimate of the recall: recall=%.2f%%\n",
            model->loo_recall);
        fprintf(stdout,"Leave-one-out estimate of the precision: precision=%.2f%%\n",
            model->loo_precision);
        fprintf(stdout,"Actual leave-one-outs computed:  %ld (rho=%.2f)\n",
            loocomputed,learn_parm->rho);
        sprintf(temstr,"Runtime for leave-one-out in cpu-seconds: %.2f\n",
            (double)(get_runtime()-runtime_start_loo)/100.0);
        printm(temstr);
    }
             
    // if(learn_parm->alphafile[0])
    //     write_alphas(learn_parm->alphafile,a,label,totdoc);

    shrink_state_cleanup(&shrink_state);

    free(inconsistent);
    free(unlabeled);
    free(a);
    free(a_fullset);
    free(xi_fullset);
    free(lin);
    free(learn_parm->svm_cost);
}



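/* Core working-set (decomposition) optimizer: repeatedly selects a small
   subset of variables, solves the corresponding QP subproblem, and updates
   the gradient cache lin[] until the KKT conditions hold within epsilon_crit.
   Returns the iteration count. heldout is the index of an example to exclude
   (-1 for none); retrain==2 requests a restart that first zeroes the alphas
   of held-out/inconsistent examples, as used by the leave-one-out code above. */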
long CSVM::optimize_to_convergence(
                             DOC *docs,
                             long *label,
                             long totdoc,
                             long totwords,           
                             LEARN_PARM *learn_parm,    
                             KERNEL_PARM *kernel_parm,  
                             KERNEL_CACHE *kernel_cache,
                             SHRINK_STATE *shrink_state,
                             MODEL *model,              
                             long *inconsistent,
                             long *unlabeled,
                             double *a,
                             double *lin,
                             TIMING *timing_profile,
                             double *maxdiff,
                             long heldout,
                             long retrain)
{
    long *chosen,*key,i,j,jj,*last_suboptimal_at,noshrink;
    long inconsistentnum,choosenum,already_chosen=0,iteration;
    long misclassified,supvecnum=0,*active2dnum,inactivenum;
    long *working2dnum,*selexam;
    long activenum;
    double criterion,eq;
    double *a_old;
    long t0=0,t1=0,t2=0,t3=0,t4=0,t5=0,t6=0; /* timing */
    long transductcycle;
    long transduction;
    double epsilon_crit_org; 
    
    double *selcrit;  /* buffer for sorting */        
    CFLOAT *aicache;  /* buffer to keep one row of hessian */
    double *weights;  /* buffer for weight vector in linear case */
    QP qp;            /* buffer for one quadratic program */
    
    epsilon_crit_org=learn_parm->epsilon_crit; /* save org */
    if(kernel_parm->kernel_type == LINEAR)
    {
        learn_parm->epsilon_crit=2.0;
        kernel_cache=NULL;   /* caching makes no sense for linear kernel */
    } 
    learn_parm->epsilon_shrink=2;
    (*maxdiff)=1;
    
    learn_parm->totwords=totwords;
    
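    /* Per-example bookkeeping arrays are sized totdoc; the qp.* buffers hold
       a single QP subproblem of at most svm_maxqpsize variables; weights
       holds an explicit weight vector of totwords+1 doubles for the linear
       case. */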
    chosen = (long *)my_malloc(sizeof(long)*totdoc);
    last_suboptimal_at = (long *)my_malloc(sizeof(long)*totdoc);
    key = (long *)my_malloc(sizeof(long)*(totdoc+11)); 
    selcrit = (double *)my_malloc(sizeof(double)*totdoc);
    selexam = (long *)my_malloc(sizeof(long)*totdoc);
    a_old = (double *)my_malloc(sizeof(double)*totdoc);
    aicache = (CFLOAT *)my_malloc(sizeof(CFLOAT)*totdoc);
    working2dnum = (long *)my_malloc(sizeof(long)*(totdoc+11));
    active2dnum = (long *)my_malloc(sizeof(long)*(totdoc+11));
    qp.opt_ce = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
    qp.opt_ce0 = (double *)my_malloc(sizeof(double));
    qp.opt_g = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize
        *learn_parm->svm_maxqpsize);
    qp.opt_g0 = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
    qp.opt_xinit = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
    qp.opt_low=(double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
    qp.opt_up=(double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
    weights=(double *)my_malloc(sizeof(double)*(totwords+1));
    
    choosenum=0;
    inconsistentnum=0;
    transductcycle=0;
    transduction=0;
    if(!retrain) retrain=1;
    iteration=1;
    
    if(kernel_cache)
    {
        kernel_cache->time=iteration;  /* for lru cache */
        kernel_cache_reset_lru(kernel_cache);
    }
    
    for(i=0;i<totdoc;i++)
    {    /* various inits */
        chosen[i]=0;
        a_old[i]=a[i];
        last_suboptimal_at[i]=1;
        if(inconsistent[i]) 
            inconsistentnum++;
        if(unlabeled[i]) 
        {
            transduction=1;
        }
    }
    activenum=compute_index(shrink_state->active,totdoc,active2dnum);
    inactivenum=totdoc-activenum;
    clear_index(working2dnum);
    
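    /* Main decomposition loop: each pass drops stale or inconsistent
       variables from the working set, selects up to svm_newvarsinqp fresh
       ones, and reoptimizes over the chosen subset until retrain is cleared. */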
    /* repeat this loop until we have convergence */
    for(;retrain;iteration++) 
    {
        
        if(kernel_cache)
            kernel_cache->time=iteration;  /* for lru cache */
        i=0;
        for(jj=0;(j=working2dnum[jj])>=0;jj++)
        { /* clear working set */
            if((chosen[j]>=(learn_parm->svm_maxqpsize/
                minl(learn_parm->svm_maxqpsize,
                learn_parm->svm_newvarsinqp))) 
                || (inconsistent[j])
                || (j == heldout)) 
            {
                chosen[j]=0; 
                choosenum--; 
            }
            else 
            {
                chosen[j]++;
                working2dnum[i++]=j;
            }
        }
        working2dnum[i]=-1;
        
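        /* retrain==2: restart pass used by the leave-one-out code. Alphas of
           inconsistent and held-out examples are forced to zero, which can
           break the equality constraint sum_i a[i]*label[i] = 0 of a biased
           hyperplane, so the surplus eq is drained from other alphas of
           matching sign below. */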
        if(retrain == 2)
        {
            choosenum=0;
            for(jj=0;(j=working2dnum[jj])>=0;jj++) 
            { /* fully clear working set */
                chosen[j]=0; 
            }
            clear_index(working2dnum);
            for(i=0;i<totdoc;i++) 
            { /* set inconsistent examples to zero (-i 1) */
                if((inconsistent[i] || (heldout==i)) && (a[i] != 0.0)) 
                {
                    chosen[i]=99999;
                    choosenum++;
                    a[i]=0;
                }
            }
            if(learn_parm->biased_hyperplane)
            {
                eq=0;
                for(i=0;i<totdoc;i++) 
                { /* make sure we fulfill equality constraint */
                    eq+=a[i]*label[i];
                }
                for(i=0;(i<totdoc) && (fabs(eq) > learn_parm->epsilon_a);i++) 
                {
                    if((eq*label[i] > 0) && (a[i] > 0)) 
                    {
                        chosen[i]=88888;
                        choosenum++;
                        if((eq*label[i]) > a[i]) 
                        {
                            eq-=(a[i]*label[i]);
                            a[i]=0;
                        }
                        else 
                        {
                            a[i]-=(eq*label[i]);
                            eq=0;
                        }
                    }
                }
            }
            compute_index(chosen,totdoc,working2dnum);
        }
        else
        {      /* select working set according to steepest gradient */ 
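            /* Two-stage selection: if at least 4 new variables are requested
               and the kernel is nonlinear, roughly half of them are first
               picked among examples whose kernel rows are already cached
               (cheap), and the remainder by the plain steepest-gradient
               criterion over all active examples. */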
            if((minl(learn_parm->svm_newvarsinqp,learn_parm->svm_maxqpsize)>=4) 
                && (kernel_parm->kernel_type != LINEAR)) 
            {
                /* select part of the working set from cache */
                already_chosen=select_next_qp_subproblem_grad_cache(
                    label,unlabeled,a,lin,totdoc,
                    minl((long)(learn_parm->svm_maxqpsize-choosenum),
                    (long)(learn_parm->svm_newvarsinqp/2)),
                    learn_parm,inconsistent,active2dnum,
                    working2dnum,selcrit,selexam,kernel_cache,
                    key,chosen);
                choosenum+=already_chosen;
            }
            choosenum+=select_next_qp_subproblem_grad(label,unlabeled,a,lin,totdoc,
                minl((long)(learn_parm->svm_maxqpsize-choosenum),
                (long)(learn_parm->svm_newvarsinqp-already_chosen)),
                learn_parm,inconsistent,active2dnum,
                working2dnum,selcrit,selexam,kernel_cache,key,
                chosen);
        }
        
        //      sprintf(temstr," %ld vectors chosen\n",choosenum);  printm(temstr);
        
        
        t1=get_runtime();
        
    if(kernel_cache)
        cache_multiple_kernel_rows(kernel_cache,docs,working2dnum,
            choosenum,kernel_parm);
