⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 svm_learn.cpp

📁 支持向量机分类器(可分类文本
💻 CPP
📖 第 1 页 / 共 5 页
字号:
        if((!inconsistent[i]) && (!unlabeled[i]) 
            && (a[i]>=(learn_parm->svm_cost[i]-learn_parm->epsilon_a))) 
        { 
            (*inconsistentnum)++;
            inconsistent[i]=1;  /* never choose again */
            retrain=2;          /* start over */
            
            sprintf(temstr,"inconsistent(%ld)..",i);  printm(temstr);
        }
    }
    return(retrain);
}

/*
 * Exclude every labeled example that the current model misclassifies
 * (the -i 2 option).
 *
 * An example i is treated as misclassified when
 * (lin[i] - model->b) * label[i] <= 0. Examples already flagged in
 * inconsistent[] or marked in unlabeled[] are ignored. Each newly excluded
 * example gets inconsistent[i] = 1, increments *inconsistentnum, and is
 * reported through temstr/printm.
 *
 * This is only a heuristic for finding a close-to-minimum set of examples
 * to drop so that the problem becomes separable with the desired margin.
 *
 * Returns 2 if at least one example was excluded (caller should start
 * over), otherwise 0.
 */
long identify_misclassified(
                            double *lin,
                            long *label,long *unlabeled,long totdoc,
                            MODEL *model,
                            long *inconsistentnum,long *inconsistent)
{
    long idx;
    long status=0;
    double margin;

    for(idx=0;idx<totdoc;idx++)
    {
        /* only labeled examples that are still in play are candidates */
        if(inconsistent[idx] || unlabeled[idx])
        {
            continue;
        }

        /* signed 'distance' from the hyperplane */
        margin=(lin[idx]-model->b)*(double)label[idx];
        if(margin <= 0)
        {
            (*inconsistentnum)++;
            inconsistent[idx]=1;    /* never choose again */
            status=2;               /* start over */

            sprintf(temstr,"inconsistent(%ld)..",idx);
            printm(temstr);
        }
    }
    return(status);
}

/*
 * Exclude the single 'most misclassified' labeled example (the -i 3 option).
 *
 * Scans all examples that are neither flagged in inconsistent[] nor marked
 * in unlabeled[] and picks the one with the most negative value of
 * (lin[i] - model->b) * label[i]. Only strictly negative values qualify;
 * on ties the first such example wins. The chosen example gets
 * inconsistent[maxex] = 1, *inconsistentnum is incremented, and its index
 * is reported through temstr/printm.
 *
 * This is only a heuristic for finding a close-to-minimum set of examples
 * to drop so that the problem becomes separable with the desired margin.
 *
 * Returns 2 if an example was excluded (caller should start over),
 * otherwise 0.
 */
long identify_one_misclassified(
                                double *lin,
                                long *label,long *unlabeled,long totdoc,
                                MODEL *model,
                                long *inconsistentnum,long *inconsistent)
{
    long i,retrain,maxex=-1;
    double dist,maxdist=0;
    
    retrain=0;
    for(i=0;i<totdoc;i++) 
    {
        if((!inconsistent[i]) && (!unlabeled[i]))
        {
            dist=(lin[i]-model->b)*(double)label[i];/* 'distance' from hyperplane*/  
            if(dist<maxdist)
            {
                /* new worst offender so far */
                maxdist=dist;
                maxex=i;
            }
        }
    }
    if(maxex>=0) 
    {
        (*inconsistentnum)++;
        inconsistent[maxex]=1;  /* never choose again */
        retrain=2;          /* start over */
        
        /* Report the excluded example itself; the loop counter i equals */
        /* totdoc at this point and would name a non-existent example.   */
        sprintf(temstr,"inconsistent(%ld)..",maxex);  printm(temstr);
        
    }
    return(retrain);
}

/*
 * Keep track of the linear component lin[] of the gradient by updating it
 * based on the change of the variables in the current working set.
 *
 * working2dnum and active2dnum are index lists terminated by a negative
 * entry. For every working-set example i whose a[i] differs from a_old[i],
 * the contribution (a[i] - a_old[i]) * label[i] * K(i, j) is added to
 * lin[j] for every active example j.
 *
 * kernel_type == 0 is the special linear case: the changes are first
 * accumulated into weights[] (cleared over totwords entries) and then
 * applied with one sparse product per active example.
 * NOTE(review): clear_vector_n / add_vector_ns / sprod_ns are defined
 * elsewhere; the description above assumes their names reflect what they
 * do - confirm against their definitions.
 *
 * In the general case get_kernel_row is called for each changed example
 * and aicache[j] is then read as the kernel value between i and j
 * (presumably filled by that call - verify).
 */
void update_linear_component(
                             DOC *docs,
                             long *label,
                             long *active2dnum,
                             double *a,double *a_old,
                             long *working2dnum,long totdoc,long totwords,
                             KERNEL_PARM *kernel_parm,
                             KERNEL_CACHE *kernel_cache,
                             double *lin,
                             CFLOAT *aicache,
                             double *weights)
{
    long wpos,apos;     /* positions in working2dnum / active2dnum */
    long wdoc,adoc;     /* example indices read from those lists */
    double kval;

    if(kernel_parm->kernel_type==0)
    {
        /* special linear case */
        clear_vector_n(weights,totwords);
        for(wpos=0;(wdoc=working2dnum[wpos])>=0;wpos++)
        {
            if(a[wdoc] == a_old[wdoc])
            {
                continue;   /* variable did not move */
            }
            add_vector_ns(weights,docs[wdoc].words,
                          ((a[wdoc]-a_old[wdoc])*(double)label[wdoc]));
        }
        for(apos=0;(adoc=active2dnum[apos])>=0;apos++)
        {
            lin[adoc]+=sprod_ns(weights,docs[adoc].words);
        }
        return;
    }

    /* general case */
    for(wpos=0;(wdoc=working2dnum[wpos])>=0;wpos++)
    {
        if(a[wdoc] == a_old[wdoc])
        {
            continue;       /* variable did not move */
        }
        get_kernel_row(kernel_cache,docs,wdoc,totdoc,active2dnum,aicache,
                       kernel_parm);
        for(apos=0;(adoc=active2dnum[apos])>=0;apos++)
        {
            kval=aicache[adoc];
            lin[adoc]+=(((a[wdoc]*kval)-(a_old[wdoc]*kval))*(double)label[wdoc]);
        }
    }
}


long incorporate_unlabeled_examples(
                                    MODEL *model,
                                    long *label,
                                    long *inconsistent,long *unlabeled,
                                    double *a,double *lin,
                                    long totdoc,
                                    double *selcrit,
                                    long *select,long *key,long transductcycle,
                                    KERNEL_PARM *kernel_parm,
                                    LEARN_PARM *learn_parm)
{
    long i,j,k,j1,j2,j3,j4,unsupaddnum1=0,unsupaddnum2=0;
    long pos,neg,upos,uneg,orgpos,orgneg,nolabel,newpos,newneg,allunlab;
    double dist,model_length,posratio,negratio;
    long check_every=2;
    double loss;
    static double switchsens=0.0,switchsensorg=0.0;
    double umin,umax,sumalpha;
    long imin=0,imax=0;
    static long switchnum=0;
    
    switchsens/=1.2;
    
    /* assumes that lin[] is up to date -> no inactive vars */
    
    orgpos=0;
    orgneg=0;
    newpos=0;
    newneg=0;
    nolabel=0;
    allunlab=0;
    for(i=0;i<totdoc;i++)
    {
        if(!unlabeled[i]) 
        {
            if(label[i] > 0) 
            {
                orgpos++;
            }
            else 
            {
                orgneg++;
            }
        }
        else 
        {
            allunlab++;
            if(unlabeled[i]) 
            {
                if(label[i] > 0)
                {
                    newpos++;
                }
                else if(label[i] < 0) 
                {
                    newneg++;
                }
            }
        }
        if(label[i]==0)
        {
            nolabel++;
        }
    }
    
    if(learn_parm->transduction_posratio >= 0) 
    {
        posratio=learn_parm->transduction_posratio;
    }
    else 
    {
        posratio=(double)orgpos/(double)(orgpos+orgneg); /* use ratio of pos/neg */
    }                                                  /* in training data */
    negratio=1.0-posratio;
    
    learn_parm->svm_costratio=1.0;                     /* global */
    if(posratio>0)
    {
        learn_parm->svm_costratio_unlab=negratio/posratio;
    }
    else 
    {
        learn_parm->svm_costratio_unlab=1.0;
    }
    
    pos=0;
    neg=0;
    upos=0;
    uneg=0;
    for(i=0;i<totdoc;i++)
    {
        dist=(lin[i]-model->b);  /* 'distance' from hyperplane*/
        if(dist>0)
        {
            pos++;
        }
        else
        {
            neg++;
        }
        if(unlabeled[i]) 
        {
            if(dist>0)
            {
                upos++;
            }
            else
            {
                uneg++;
            }
        }
        if((!unlabeled[i]) && (a[i]>(learn_parm->svm_cost[i]-learn_parm->epsilon_a))) 
        {
            /*      printf("Ubounded %ld (class %ld, unlabeled %ld)\n",i,label[i],unlabeled[i]); */
        }
    }
    
    sprintf(temstr,"POS=%ld, ORGPOS=%ld, ORGNEG=%ld\n",pos,orgpos,orgneg);
    printm(temstr);
    sprintf(temstr,"POS=%ld, NEWPOS=%ld, NEWNEG=%ld\n",pos,newpos,newneg);
    printm(temstr);
    sprintf(temstr,"pos ratio = %f (%f).\n",(double)(upos)/(double)(allunlab),posratio);
    printm(temstr);
    
    
    if(transductcycle == 0) 
    {
        j1=0; 
        j2=0;
        j4=0;
        for(i=0;i<totdoc;i++) 
        {
            dist=(lin[i]-model->b);  /* 'distance' from hyperplane*/
            if((label[i]==0) && (unlabeled[i]))
            {
                selcrit[j4]=dist;
                key[j4]=i;
                j4++;
            }
        }
        unsupaddnum1=0; 
        unsupaddnum2=0; 
        select_top_n(selcrit,j4,select,(long)(allunlab*posratio+0.5));
        for(k=0;(k<(long)(allunlab*posratio+0.5));k++) 
        {
            i=key[select[k]];
            label[i]=1;
            unsupaddnum1++; 
            j1++;
        }
        for(i=0;i<totdoc;i++)
        {
            if((label[i]==0) && (unlabeled[i])) 
            {
                label[i]=-1;
                j2++;
                unsupaddnum2++;
            }
        }
        for(i=0;i<totdoc;i++) 
        {  /* set upper bounds on vars */
            if(unlabeled[i])
            {
                if(label[i] == 1) 
                {
                    learn_parm->svm_cost[i]=learn_parm->svm_c*learn_parm->svm_costratio_unlab*learn_parm->svm_unlabbound;
                }
                else if(label[i] == -1) 
                {
                    learn_parm->svm_cost[i]=learn_parm->svm_c*learn_parm->svm_unlabbound;
                }
            }
        }
        
        sprintf(temstr,"costratio %lf, costratio_unlab %lf, unlabbound %lf\n",
            learn_parm->svm_costratio,learn_parm->svm_costratio_unlab,
            learn_parm->svm_unlabbound); 
        printm(temstr);
        sprintf(temstr,"Classifying unlabeled data as %ld POS / %ld NEG.\n",
            unsupaddnum1,unsupaddnum2); 
        printm(temstr);
        
        sprintf(temstr,"Retraining.");
        printm(temstr);
        sprintf(temstr,"\n");
        printm(temstr);
        
        return((long)3);
    }
    if((transductcycle % check_every) == 0) 
    {
        
        sprintf(temstr,"Retraining.");
        printm(temstr);
        sprintf(temstr,"\n");
        printm(temstr);
        j1=0;
        j2=0;
        unsupaddnum1=0;
        unsupaddnum2=0;
        for(i=0;i<totdoc;i++) 
        {
            if((unlabeled[i] == 2)) 
            {
                unlabeled[i]=1;
                label[i]=1;
                j1++;
                unsupaddnum1++;
            }
            else if((unlabeled[i] == 3))
            {
                unlabeled[i]=1;
                label[i]=-1;
                j2++;
                unsupaddnum2++;
            }
        }
        for(i=0;i<totdoc;i++)
        {  /* set upper bounds on vars */
            if(unlabeled[i]) 
            {
                if(label[i] == 1)
                {
                    learn_parm->svm_cost[i]=learn_parm->svm_c*
                        learn_parm->svm_costratio_unlab*learn_parm->svm_unlabbound;
                }
                else if(label[i] == -1) 
                {
                    learn_parm->svm_cost[i]=learn_parm->svm_c*
                        learn_parm->svm_unlabbound;
                }
            }
        }
        
        
        sprintf(temstr,"costratio %lf, costratio_unlab %lf, unlabbound %lf\n",
            learn_parm->svm_costratio,learn_parm->svm_costratio_unlab,
            learn_parm->svm_unlabbound); 
        printm(temstr);
        sprintf(temstr,"%ld positive -> Added %ld POS / %ld NEG unlabeled examples.\n",
            upos,unsupaddnum1,unsupaddnum2); 
        printm(temstr);
        
        if(learn_parm->svm_unlabbound == 1) 
        {
            learn_parm->epsilon_crit=0.001; /* do the last run right */

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -