svm_learn.cpp
Support vector machine classifier (for text classification)
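
/* Note: the listing opens mid-function; judging by the leave-one-out
   statistics and the buffers freed below, this is the tail of svm_learn(). */
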
                 fprintf(stdout,"Actual leave-one-outs computed:  %ld (rho=%.2f)\n",
                     loocomputed,learn_parm->rho);
                 sprintf(temstr,"Runtime for leave-one-out in cpu-seconds: %.2f\n",
                     (double)(get_runtime()-runtime_start_loo)/100.0);
                 printm(temstr);
             }
             
             // if(learn_parm->alphafile[0])
             //     write_alphas(learn_parm->alphafile,a,label,totdoc);
             
             shrink_state_cleanup(&shrink_state);
             
             free(inconsistent);
             free(unlabeled);
             free(a);
             free(a_fullset);
             free(xi_fullset);
             free(lin);
             free(learn_parm->svm_cost);
}



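/* Main chunking loop of SVM-light: repeatedly select a small working set,
   solve the QP subproblem restricted to it, update the dual gradient
   (lin[]), and test the KKT optimality conditions, shrinking the problem
   along the way, until no variable violates optimality by more than
   epsilon_crit. */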
long optimize_to_convergence(
                             DOC *docs,
                             long *label,
                             long totdoc,
                             long totwords,           
                             LEARN_PARM *learn_parm,    
                             KERNEL_PARM *kernel_parm,  
                             KERNEL_CACHE *kernel_cache,
                             SHRINK_STATE *shrink_state,
                             MODEL *model,              
                             long *inconsistent,
                             long *unlabeled,
                             double *a,
                             double *lin,
                             TIMING *timing_profile,
                             double *maxdiff,
                             long heldout,
                             long retrain)
{
    long *chosen,*key,i,j,jj,*last_suboptimal_at,noshrink;
    long inconsistentnum,choosenum,already_chosen=0,iteration;
    long misclassified,supvecnum=0,*active2dnum,inactivenum;
    long *working2dnum,*selexam;
    long activenum;
    double criterion,eq;
    double *a_old;
    long t0=0,t1=0,t2=0,t3=0,t4=0,t5=0,t6=0; /* timing */
    long transductcycle;
    long transduction;
    double epsilon_crit_org; 
    
    double *selcrit;  /* buffer for sorting */        
    CFLOAT *aicache;  /* buffer to keep one row of hessian */
    double *weights;  /* buffer for weight vector in linear case */
    QP qp;            /* buffer for one quadratic program */
    
    epsilon_crit_org=learn_parm->epsilon_crit; /* save org */
    if(kernel_parm->kernel_type == LINEAR)
    {
        learn_parm->epsilon_crit=2.0;
        kernel_cache=NULL;   /* caching makes no sense for linear kernel */
    } 
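    /* For linear kernels epsilon_crit starts deliberately loose (2.0); it
       is tightened step by step near the end of the main loop, once the
       largest remaining violation (*maxdiff) is known. */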
    learn_parm->epsilon_shrink=2;
    (*maxdiff)=1;
    
    learn_parm->totwords=totwords;
    
    chosen = (long *)my_malloc(sizeof(long)*totdoc);
    last_suboptimal_at = (long *)my_malloc(sizeof(long)*totdoc);
    key = (long *)my_malloc(sizeof(long)*(totdoc+11)); 
    selcrit = (double *)my_malloc(sizeof(double)*totdoc);
    selexam = (long *)my_malloc(sizeof(long)*totdoc);
    a_old = (double *)my_malloc(sizeof(double)*totdoc);
    aicache = (CFLOAT *)my_malloc(sizeof(CFLOAT)*totdoc);
    working2dnum = (long *)my_malloc(sizeof(long)*(totdoc+11));
    active2dnum = (long *)my_malloc(sizeof(long)*(totdoc+11));
    qp.opt_ce = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
    qp.opt_ce0 = (double *)my_malloc(sizeof(double));
    qp.opt_g = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize
        *learn_parm->svm_maxqpsize);
    qp.opt_g0 = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
    qp.opt_xinit = (double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
    qp.opt_low=(double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
    qp.opt_up=(double *)my_malloc(sizeof(double)*learn_parm->svm_maxqpsize);
    weights=(double *)my_malloc(sizeof(double)*(totwords+1));
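    
    /* The qp.* buffers hold one QP subproblem of at most svm_maxqpsize
       variables: by their sizes, opt_g is the (maxqpsize x maxqpsize)
       Hessian block, opt_g0 the linear part, opt_ce/opt_ce0 the equality
       constraint, opt_low/opt_up the box bounds, and opt_xinit the
       starting point for the inner QP solver. */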
    
    choosenum=0;
    inconsistentnum=0;
    transductcycle=0;
    transduction=0;
    if(!retrain) retrain=1;
    iteration=1;
    
    if(kernel_cache)
    {
        kernel_cache->time=iteration;  /* for lru cache */
        kernel_cache_reset_lru(kernel_cache);
    }
    
    for(i=0;i<totdoc;i++)
    {    /* various inits */
        chosen[i]=0;
        a_old[i]=a[i];
        last_suboptimal_at[i]=1;
        if(inconsistent[i]) 
            inconsistentnum++;
        if(unlabeled[i]) 
        {
            transduction=1;
        }
    }
    activenum=compute_index(shrink_state->active,totdoc,active2dnum);
    inactivenum=totdoc-activenum;
    clear_index(working2dnum);
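    
    /* active2dnum and working2dnum are -1-terminated index lists built by
       compute_index()/clear_index(); every inner loop below walks them
       with the (j=list[jj])>=0 idiom. */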
    
    /* repeat this loop until we have convergence */
    for(;retrain&&com_param.Running;iteration++) 
    {
        
        if(kernel_cache)
            kernel_cache->time=iteration;  /* for lru cache */
        t0=get_runtime();  /* start timing the working-set selection phase */
        i=0;
        for(jj=0;(j=working2dnum[jj])>=0;jj++)
        { /* clear working set */
            if((chosen[j]>=(learn_parm->svm_maxqpsize/
                minl(learn_parm->svm_maxqpsize,
                learn_parm->svm_newvarsinqp))) 
                || (inconsistent[j])
                || (j == heldout)) 
            {
                chosen[j]=0; 
                choosenum--; 
            }
            else 
            {
                chosen[j]++;
                working2dnum[i++]=j;
            }
        }
        working2dnum[i]=-1;
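        
        /* The loop above evicts variables that have stayed in the working
           set for svm_maxqpsize/newvarsinqp consecutive iterations (the
           chosen[] counter), along with inconsistent and held-out
           examples, freeing room for new candidates. */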
        
        if(retrain == 2)
        {
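            /* retrain==2 restarts after examples were flagged inconsistent:
               their alphas are zeroed and, for a biased hyperplane, other
               alphas are reduced (markers 99999/88888 keep them in the
               working set) until the constraint sum_i a_i*y_i = 0 holds. */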
            choosenum=0;
            for(jj=0;(j=working2dnum[jj])>=0;jj++) 
            { /* fully clear working set */
                chosen[j]=0; 
            }
            clear_index(working2dnum);
            for(i=0;i<totdoc;i++) 
            { /* set inconsistent examples to zero (-i 1) */
                if((inconsistent[i] || (heldout==i)) && (a[i] != 0.0)) 
                {
                    chosen[i]=99999;
                    choosenum++;
                    a[i]=0;
                }
            }
            if(learn_parm->biased_hyperplane)
            {
                eq=0;
                for(i=0;i<totdoc;i++) 
                { /* make sure we fulfill equality constraint */
                    eq+=a[i]*label[i];
                }
                for(i=0;(i<totdoc) && (fabs(eq) > learn_parm->epsilon_a);i++) 
                {
                    if((eq*label[i] > 0) && (a[i] > 0)) 
                    {
                        chosen[i]=88888;
                        choosenum++;
                        if((eq*label[i]) > a[i]) 
                        {
                            eq-=(a[i]*label[i]);
                            a[i]=0;
                        }
                        else 
                        {
                            a[i]-=(eq*label[i]);
                            eq=0;
                        }
                    }
                }
            }
            compute_index(chosen,totdoc,working2dnum);
        }
        else
        {      /* select working set according to steepest gradient */ 
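            /* Two-stage selection: with a nonlinear kernel and at least 4
               new variables to add, up to svm_newvarsinqp/2 are taken from
               examples whose kernel rows are already cached; the rest come
               from steepest feasible descent on the dual gradient. */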
            if((minl(learn_parm->svm_newvarsinqp,learn_parm->svm_maxqpsize)>=4) 
                && (kernel_parm->kernel_type != LINEAR)) 
            {
                /* select part of the working set from cache */
                already_chosen=select_next_qp_subproblem_grad_cache(
                    label,unlabeled,a,lin,totdoc,
                    minl((long)(learn_parm->svm_maxqpsize-choosenum),
                    (long)(learn_parm->svm_newvarsinqp/2)),
                    learn_parm,inconsistent,active2dnum,
                    working2dnum,selcrit,selexam,kernel_cache,
                    key,chosen);
                choosenum+=already_chosen;
            }
            choosenum+=select_next_qp_subproblem_grad(label,unlabeled,a,lin,totdoc,
                minl((long)(learn_parm->svm_maxqpsize-choosenum),
                (long)(learn_parm->svm_newvarsinqp-already_chosen)),
                learn_parm,inconsistent,active2dnum,
                working2dnum,selcrit,selexam,kernel_cache,key,
                chosen);
        }
        
        //      sprintf(temstr," %ld vectors chosen\n",choosenum);  printm(temstr);
        
        
        t1=get_runtime();
        
        if(kernel_cache) 
            cache_multiple_kernel_rows(kernel_cache,docs,working2dnum,choosenum,kernel_parm); 
        
        
        t2=get_runtime();
        if(retrain != 2)
        {
            optimize_svm(docs,label,unlabeled,chosen,active2dnum,model,totdoc,
                working2dnum,choosenum,a,lin,learn_parm,aicache,
                kernel_parm,&qp,&epsilon_crit_org);
        }
        
        
        t3=get_runtime();
        update_linear_component(docs,label,active2dnum,a,a_old,working2dnum,totdoc,
            totwords,kernel_parm,kernel_cache,lin,aicache,weights);
        
        
        
        t4=get_runtime();
        supvecnum=calculate_svm_model(docs,label,unlabeled,lin,a,a_old,learn_parm,
            working2dnum,model);
        
        
        t5=get_runtime();
        
        /* The following computation of the objective function works only */
        /* relative to the active variables */
        
        criterion=compute_objective_function(a,lin,label,active2dnum);
        //  sprintf(temstr,"Objective function (over active variables): %.16f\n",criterion);
        //  printm(temstr);
        
        for(jj=0;(i=working2dnum[jj])>=0;jj++)
        {
            a_old[i]=a[i];
        }
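        
        /* a_old is resynchronized with a: update_linear_component() above
           consumed the (a - a_old) deltas to refresh lin[] incrementally. */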
        
        if(retrain == 2) 
        {  /* reset inconsistent unlabeled examples */
            for(i=0;(i<totdoc);i++) 
            {
                if(inconsistent[i] && unlabeled[i]) 
                {
                    inconsistent[i]=0;
                    label[i]=0;
                }
            }
        }
        
        retrain=check_optimality(model,label,unlabeled,a,lin,totdoc,learn_parm,
            maxdiff,epsilon_crit_org,&misclassified,
            inconsistent,active2dnum,last_suboptimal_at,
            iteration,kernel_parm);
        
        
        t6=get_runtime();
        timing_profile->time_select+=t1-t0;
        timing_profile->time_kernel+=t2-t1;
        timing_profile->time_opti+=t3-t2;
        timing_profile->time_update+=t4-t3;
        timing_profile->time_model+=t5-t4;
        timing_profile->time_check+=t6-t5;
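        
        /* t0..t6 bracket the phases of the iteration just finished:
           working-set selection, kernel-row caching, QP optimization,
           gradient update, model update and optimality check. */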
        
        noshrink=0;
        if((!retrain) && (inactivenum>0) 
            && ((!learn_parm->skip_final_opt_check) 
            || (kernel_parm->kernel_type == LINEAR))) 
        { 
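            /* Shrinking is only a heuristic, so before accepting
               convergence all inactive variables are re-checked; if any
               still violates optimality by more than epsilon_crit,
               training resumes (retrain=1 below). */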
            if (com_pro.show_other)
            {
                sprintf(temstr," Checking optimality of inactive variables...");
                printm(temstr);
            }
            
            t1=get_runtime();
            reactivate_inactive_examples(label,unlabeled,a,shrink_state,lin,totdoc,
                totwords,iteration,learn_parm,inconsistent,
                docs,kernel_parm,kernel_cache,model,aicache,
                weights,maxdiff);
            /* Update to new active variables. */
            activenum=compute_index(shrink_state->active,totdoc,active2dnum);
            inactivenum=totdoc-activenum;
            /* termination criterion */
            noshrink=1;
            retrain=0;
            if((*maxdiff) > learn_parm->epsilon_crit) 
                retrain=1;
            timing_profile->time_shrink+=get_runtime()-t1;
        }
        
        if((!retrain) && (learn_parm->epsilon_crit>(*maxdiff))) 
            learn_parm->epsilon_crit=(*maxdiff);
        if((!retrain) && (learn_parm->epsilon_crit>epsilon_crit_org)) 
        {
            learn_parm->epsilon_crit/=2.0;
            retrain=1;
            noshrink=1;
        }
        if(learn_parm->epsilon_crit<epsilon_crit_org) 
            learn_parm->epsilon_crit=epsilon_crit_org;
        //  sprintf(temstr," => (%ld SV (incl. %ld SV at u-bound), max violation=%.5f)\n",
        //      supvecnum,model->at_upper_bound,(*maxdiff));
        //  printm(temstr);
        
        if((!retrain) && (transduction)) 
        {
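            /* Transductive mode: the current labeling has converged, so
               every example is reactivated and incorporate_unlabeled_examples()
               updates the labels of the unlabeled documents; it sets
               retrain again until the transductive labeling stabilizes. */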
            for(i=0;(i<totdoc);i++)
            {
                shrink_state->active[i]=1;
            }
            activenum=compute_index(shrink_state->active,totdoc,active2dnum);
            inactivenum=0;
            
            retrain=incorporate_unlabeled_examples(model,label,inconsistent,
                unlabeled,a,lin,totdoc,
                selcrit,selexam,key,
                transductcycle,kernel_parm,
                learn_parm);
            epsilon_crit_org=learn_parm->epsilon_crit;
            if(kernel_parm->kernel_type == LINEAR)
                learn_parm->epsilon_crit=1; 
            transductcycle++;
        } 
        else if(((iteration % 10) == 0) && (!noshrink)) 
        {
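            /* Regular shrinking: every 10th iteration, move variables that
               have long looked suboptimal (last_suboptimal_at) out of the
               active set, and shrink the kernel cache once it keeps more
               rows than the active set and support vectors need. */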
            activenum=shrink_problem(learn_parm,shrink_state,active2dnum,iteration,last_suboptimal_at,
                totdoc,maxl((long)(activenum/10),100),
                a,inconsistent);
            inactivenum=totdoc-activenum;
            if((kernel_cache)
                && (supvecnum>kernel_cache->max_elems)
                && ((kernel_cache->activenum-activenum)>maxl((long)(activenum/10),500))) 
            {
                kernel_cache_shrink(kernel_cache,totdoc,maxl((long)(activenum/10),500),
                    shrink_state->active); 
            }
        }
        
        if((!retrain) && learn_parm->remove_inconsistent) 
        {
            
            sprintf(temstr," Moving training errors to inconsistent examples...");
            printm(temstr);
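            
            /* remove_inconsistent picks how training errors are excluded:
               mode 1 flags examples whose alpha sits at the upper bound
               (identify_inconsistent), mode 2 flags misclassified examples
               (identify_misclassified). */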
            
            if(learn_parm->remove_inconsistent == 1) 
            {
                retrain=identify_inconsistent(a,label,unlabeled,totdoc,learn_parm,
                    &inconsistentnum,inconsistent); 
            }
            else if(learn_parm->remove_inconsistent == 2) 
            {
                retrain=identify_misclassified(lin,label,unlabeled,totdoc,
                    model,&inconsistentnum,inconsistent); 
            }
            else if(learn_parm->remove_inconsistent == 3) 
            {
