⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 pfsals_svmtrain.c

📁 崭新矢量机SVM源码
💻 C
字号:
#include "mex.h"
#include "math.h"
#include <stdlib.h>
#include <time.h>


// Largest absolute value in gradvec[0..n-1] and where it occurs.
//   gradvec   - values to scan; element 0 is always read
//   n         - number of entries to examine
//   gradindex - out: position of the first entry attaining the maximum
// Returns the maximum of |gradvec[k]| over the scanned entries.
double absmax(double *gradvec, int n, int *gradindex)
{
	double best = fabs(gradvec[0]);
	int best_pos = 0;
	int k;

	for (k = 1; k < n; ++k) {
		double mag = fabs(gradvec[k]);
		// strict comparison: on ties the earliest position is kept
		if (mag > best) {
			best = mag;
			best_pos = k;
		}
	}

	*gradindex = best_pos;
	return best;
}

// Inner product of two samples stored back to back in psamples.
// Sample k occupies psamples[k*dim .. k*dim + dim - 1].
//   p, q     - indices of the two samples
//   dim      - number of components per sample
//   psamples - flat sample storage
// Returns sum over k of sample_p[k] * sample_q[k] (0 when dim <= 0).
double dot(int p, int q, int dim,double *psamples)
{
	const double *a = psamples + p*dim;
	const double *b = psamples + q*dim;
	double acc = 0;
	int k;

	for (k = 0; k < dim; k++)
		acc += a[k] * b[k];

	return acc;
}

// Gaussian kernel value between samples p and q:
//   exp(-kernelparam * (square[p] + square[q] - 2 * <x_p, x_q>))
//   kernelparam - multiplier applied to the squared distance
//   p, q        - sample indices
//   dim         - number of components per sample
//   psamples    - flat sample storage, as consumed by dot()
//   square      - per-sample values used as the squared norms
//                 (expected to hold dot(i, i, ...) for every sample i)
double kernel(double kernelparam, int p,int q, int dim, double *psamples, double *square)
{
	double sqdist = square[p] + square[q] - 2*dot(p, q, dim, psamples);

	return exp(-kernelparam*sqdist);
}

// Grow the leading n-by-n block of R (row stride cnum) to (n+1)-by-(n+1).
//   1. R[i][j] += beta[i]*beta[j]*gamma      for 0 <= i, j < n
//   2. R[i][n]  = -gamma*beta[i]             new last column
//   3. R[n][i]  = -gamma*beta[i]             new last row
//   4. R[n][n]  = gamma
// R must have room for row and column index n, i.e. cnum >= n + 1 and
// at least n + 1 rows.
void Rank1Update(double *R,double *beta, double gamma, int cnum, int n)
{
	double *last_row = R + n*cnum;
	int row, col;

	// rank-one correction of the existing block
	for (row = 0; row < n; row++) {
		double *cur = R + row*cnum;
		for (col = 0; col < n; col++)
			cur[col] = cur[col] + beta[row] * beta[col]*gamma;
	}

	// new border column
	for (row = 0; row < n; row++)
		R[row*cnum + n] = -gamma * beta[row];

	// new border row
	for (col = 0; col < n; col++)
		last_row[col] = -gamma * beta[col];

	// new corner element
	last_row[n] = gamma;
}

// main function
//
// MEX gateway: greedily builds a set of selected samples ("support
// vectors") one at a time. At every step a small random subset of the
// not-yet-selected samples is scored, the one with the largest absolute
// gradient is added, and the inverse system matrix (invK) and the
// coefficient vector (talpha) are updated incrementally via Rank1Update.
//
// Inputs (prhs):
//   [0] samples, real double matrix, one sample per column (d rows, n columns)
//   [1] labels, real double array with at least n elements
//   [2] not read by this function
//   [3] kernel parameter passed to kernel()
//   [4] lamda; 0.5/lamda is added to the kernel diagonal
//   [5] tol; iteration stops once the sampled max |gradient| drops below it
//
// Outputs (plhs), created only as far as the caller requested them:
//   [0] coefficients of the selected samples (nsv-by-1)
//   [1] 1-based column indices of the selected samples (nsv-by-1)
//   [2] number of selected samples
//   [3] bias term (talpha[0])

// Pseudo-random position in [0, count-1] drawn with rand().
// The division is done in double by RAND_MAX + 1.0, which avoids the signed
// overflow of 1 + RAND_MAX and keeps the result strictly below count.
static int random_below(int count)
{
	return (int)(rand() / ((double)RAND_MAX + 1.0) * count);
}

void mexFunction (int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
	double *psamples, *plabels, *square, *talpha, *alpha, *sv, *gradvec, *invK, *beta, *keval;
	double param, lamda, tol, gamma, gradmax, diag, delta, grad;
	int *svi, *pindex, *Q;
	int n, d, cnum, maxsv, pssize, nsel, nout, lenQ;
	int i, j, p, q, best, cand, tgradindex, selected;

	// validate the call before touching any argument
	if (nrhs < 6)
		mexErrMsgTxt("pfsals_svmtrain: six input arguments are required.");
	if (nlhs > 4)
		mexErrMsgTxt("pfsals_svmtrain: at most four output arguments are produced.");
	if (!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0]) ||
	    !mxIsDouble(prhs[1]) || mxIsComplex(prhs[1]) || mxIsSparse(prhs[1]))
		mexErrMsgTxt("pfsals_svmtrain: samples and labels must be full real double arrays.");

	psamples	= mxGetPr(prhs[0]);
	plabels		= mxGetPr(prhs[1]);
	param		= mxGetScalar(prhs[3]);
	lamda		= mxGetScalar(prhs[4]);
	tol			= mxGetScalar(prhs[5]);
	d			= (int)mxGetM(prhs[0]);
	n			= (int)mxGetN(prhs[0]);
	lamda		= (1.0/2.0)/lamda;

	if (n < 1)
		mexErrMsgTxt("pfsals_svmtrain: the sample matrix must have at least one column.");
	if (mxGetNumberOfElements(prhs[1]) < (size_t)n)
		mexErrMsgTxt("pfsals_svmtrain: one label per sample is required.");

	pssize		= 149;		// candidates scored per selection step
	maxsv		= 10000;	// hard cap on the number of selected samples
	// invK holds a (nsel+1)-by-(nsel+1) matrix; size its stride to the largest
	// matrix that can actually occur instead of a fixed 10000-by-10000 block.
	cnum		= (n < maxsv ? n : maxsv) + 1;

	// mxCalloc zero-fills; talpha relies on that for entries not yet written
	square		= mxCalloc(n, sizeof(double));
	gradvec		= mxCalloc(pssize, sizeof(double));
	Q			= mxCalloc(n, sizeof(int));
	pindex		= mxCalloc(pssize, sizeof(int));	// one slot per candidate
	keval		= mxCalloc(n, sizeof(double));
	talpha		= mxCalloc(n+1, sizeof(double));
	invK		= mxCalloc((size_t)cnum*cnum, sizeof(double));
	beta		= mxCalloc(n+1, sizeof(double));
	svi			= mxCalloc(n, sizeof(int));

	// squared norm of every sample, reused by kernel()
	for (i=0; i<n; i++)
		square[i]	= dot(i, i, d, psamples);

	// Q[0..lenQ-1] lists the samples that have not been selected yet
	lenQ = n;
	for (i=0; i<lenQ; i++)
		Q[i] = i;

	// select the first index: with no model yet the gradient is -label
	for (i=0; i<pssize; i++)
	{
		pindex[i]	= random_below(lenQ);
		gradvec[i]	= -plabels[Q[pindex[i]]];
	}
	gradmax		= absmax(gradvec, pssize, &best);
	tgradindex	= pindex[best];		// position inside Q
	selected	= Q[tgradindex];	// sample index

	nsel = 0;
	while (nsel < n && nsel < maxsv && !(gradmax < tol))
	{
		i = nsel;
		diag = kernel(param, selected, selected, d, psamples, square) + lamda;

		if (i == 0)
		{
			// closed-form inverse of the 2-by-2 system [0 1; 1 diag]
			invK[0]			= -1*diag;
			invK[1]			= 1;
			invK[cnum]		= 1;
			invK[cnum+1]	= 0;
			talpha[0]		= invK[1]*plabels[selected];
			talpha[1]		= invK[cnum+1]*plabels[selected];
		}
		else
		{
			// kernel values between the new sample and the selected ones
			for (j=0; j<i; j++)
				keval[j] = kernel(param, svi[j], selected, d, psamples, square);

			// beta = invK * [1; keval]
			for (p=0; p<=i; p++)
			{
				beta[p] = invK[p*cnum];
				for (q=1; q<=i; q++)
				{
					beta[p] = beta[p] + invK[p*cnum+q]*keval[q-1];
				}
			}

			// gamma = 1 / (diag - [1; keval]' * beta)
			gamma = beta[0];
			for (q=1; q<=i; q++)
			{
				gamma = gamma + beta[q]*keval[q-1];
			}
			gamma = 1/(diag - gamma);

			// extend invK by one row and one column
			Rank1Update(invK, beta, gamma, cnum, i+1);

			// update the coefficients for the enlarged system
			delta = 0;
			for (q=1; q<=i; q++)
			{
				delta = delta + beta[q]*plabels[svi[q-1]];
			}
			delta = gamma*(delta - plabels[selected]);
			for (q=0; q<=i; q++)
			{
				talpha[q] = talpha[q] + delta*beta[q];
			}
			talpha[i+1] = talpha[i+1] - delta;
		}
		svi[i] = selected;
		nsel++;

		// drop the chosen sample from the candidate pool
		lenQ = lenQ-1;
		for (p=tgradindex; p<lenQ; p++)
			Q[p] = Q[p+1];
		if (lenQ == 0)
			break;		// every sample has been selected

		// select the next index: score pssize random remaining samples
		for (p=0; p<pssize; p++)
		{
			pindex[p]	= random_below(lenQ);
			cand		= Q[pindex[p]];
			grad		= talpha[0] - plabels[cand];
			for (q=1; q<=nsel; q++)
			{
				grad = grad + talpha[q]*kernel(param, cand, svi[q-1], d, psamples, square);
			}
			gradvec[p] = grad;
		}

		gradmax		= absmax(gradvec, pssize, &best);
		tgradindex	= pindex[best];
		selected	= Q[tgradindex];
	}

	// MATLAB always accepts one output, even when nlhs is 0; never write
	// past the outputs the caller asked for.
	nout = (nlhs > 1) ? nlhs : 1;

	plhs[0]	= mxCreateDoubleMatrix(nsel, 1, mxREAL);
	alpha	= mxGetPr(plhs[0]);
	for (q=1; q<=nsel; q++)
		alpha[q-1] = talpha[q];

	if (nout > 1)
	{
		plhs[1]	= mxCreateDoubleMatrix(nsel, 1, mxREAL);
		sv		= mxGetPr(plhs[1]);
		for (q=1; q<=nsel; q++)
			sv[q-1] = svi[q-1]+1;	// convert to 1-based indices
	}
	if (nout > 2)
		plhs[2] = mxCreateDoubleScalar((double)nsel);
	if (nout > 3)
		plhs[3] = mxCreateDoubleScalar(talpha[0]);

	// release the work buffers
	mxFree(svi);
	mxFree(beta);
	mxFree(invK);
	mxFree(talpha);
	mxFree(keval);
	mxFree(pindex);
	mxFree(Q);
	mxFree(gradvec);
	mxFree(square);
	return;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -