⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 chsvm.java

📁 java实现的支持向量机分类
💻 JAVA
字号:

/**
 * Two-class soft-margin support vector machine trained with an SMO-style
 * pairwise solver.
 *
 * <p>Labels and kernel values are obtained from a caller-supplied
 * {@code Kernal}; every call forwards the opaque {@code caller} handle given
 * to {@link #Initialize}. Labels are expected to be +1 or -1 (the code tests
 * against exactly those values).
 *
 * <p>Call sequence implied by the methods (not enforced here):
 * {@link #Initialize}, then {@link #trainSVM} repeatedly until it returns
 * {@code false}, then {@link #CalBias}; {@link #CalError} uses the bias, so it
 * belongs after {@link #CalBias}.
 *
 * <p>Instances hold mutable training state and perform no synchronization.
 */
public class CHSVM {
	/** Dual coefficient of each training point, kept in [0, ERpenalty] (up to rounding). */
	private double[] alpha;
	/** Bias term computed by {@link #CalBias}. */
	private double bias;

	/**
	 * Returns the dual coefficient of one training point.
	 *
	 * @param u point index in {@code [0, numofpoint)}
	 * @return {@code alpha[u]}
	 */
	public double getAlpha(int u){
		return alpha[u];
	}

	/**
	 * Returns the bias computed by the most recent {@link #CalBias} call
	 * (0 on a freshly constructed instance).
	 *
	 * @return the current bias
	 */
	public double getBias(){
		return bias;
	}

	/**
	 * Returns the internal {@code UnStable} flag.
	 *
	 * <p>NOTE(review): despite its name this returns {@code true} when the last
	 * {@link #CalBias} found no free support vector (so the bias was left at
	 * 0), i.e. when the result is <em>un</em>stable. The inverted behavior is
	 * kept unchanged because existing callers may rely on it.
	 *
	 * @return {@code true} if the last bias computation found no free support vector
	 */
	public boolean IsStable(){
		return UnStable;
	}

	/**
	 * Sets the tolerance used by the stopping test in {@link #trainSVM} and by
	 * {@link #CalBias} / {@link #CalError}. The default is 0.001.
	 *
	 * @param accuracy a strictly positive tolerance
	 * @throws IllegalArgumentException if {@code accuracy} is not {@code > 0}
	 */
	public void setAccuracy(double accuracy)
	{
		// A zero/negative/NaN tolerance would let the stopping test pick the
		// same index as both members of the working pair.
		if (!(accuracy > 0))
		{
			throw new IllegalArgumentException("accuracy must be > 0: " + accuracy);
		}
		Accuracy = accuracy;
	}

	/** Set by CalBias when no free support vector exists; cleared by Initialize. */
	private boolean UnStable;
	/**
	 * Cached {@code sum_j alpha[j] * y_j * K(i, j) - y_i} for each point i,
	 * maintained incrementally by trainSVM.
	 */
	private double[] alphaRowSum;
	/** "Low" member of the working pair: argmax of alphaRowSum over the low set. */
	private int LowmaxIdx;
	/** "Up" member of the working pair: argmin of alphaRowSum over the up set. */
	private int UpminIdx;

	/** Number of training points; length of alpha and alphaRowSum. */
	private int NumOfPoint;
	/** Box constraint C: every alpha is kept in [0, ERpenalty]. */
	private double ERpenalty;
	/** Opaque handle forwarded to every Kernal call. */
	private long m_Caller;
	/** Tolerance used by the stopping test, CalBias and CalError. */
	private double Accuracy;

	/** Creates an empty, untrained model; call {@link #Initialize} before training. */
	public CHSVM()
	{
		alpha = null;
		alphaRowSum = null;
		LowmaxIdx = 0;
		UpminIdx = 0;
		NumOfPoint = 0;
		ERpenalty = 0;
		bias = 0;
		m_Caller = 0;
		Accuracy = 0.001;
		UnStable = false;
	}

	/**
	 * Prepares the solver for a new training run.
	 *
	 * <p>All alphas are reset to 0 and each cached row sum to {@code -y_i}
	 * (its value when every alpha is 0). The initial working pair is the last
	 * point labelled +1 ("up") and the last point with any other label ("low").
	 * The arrays are reallocated only when the point count changes.
	 *
	 * @param numofpoint number of training points
	 * @param erpenalty  box constraint C applied to every alpha
	 * @param caller     opaque handle forwarded to every {@code kernal} call
	 * @param kernal     supplies the label of each point
	 */
	public void Initialize(int numofpoint,double erpenalty,long caller,Kernal kernal)
	{
		m_Caller = caller;
		UnStable = false;
		if (NumOfPoint != numofpoint)
		{
			alpha = new double[numofpoint];
			alphaRowSum = new double[numofpoint];
			NumOfPoint = numofpoint;
		}
		ERpenalty = erpenalty;

		// Reset the pair: if one class has no points, its index would otherwise
		// keep a stale value from a previous (possibly larger) problem and
		// trainSVM would index past the end of the arrays.
		UpminIdx = 0;
		LowmaxIdx = 0;
		for (int i = 0; i < numofpoint; ++i)
		{
			alpha[i] = 0.0;
			final int label = kernal.GetClass(i, m_Caller);
			alphaRowSum[i] = -label;
			if (label == 1)
			{
				UpminIdx = i;
			}
			else
			{
				LowmaxIdx = i;
			}
		}
	}

	/**
	 * Performs one SMO optimisation step on the current working pair and
	 * selects the pair for the next step.
	 *
	 * <p>The current pair is ({@code UpminIdx}, {@code LowmaxIdx}). If
	 * {@code alphaRowSum[up] < alphaRowSum[low] - Accuracy} does not hold, the
	 * pair satisfies the stopping criterion: nothing changes and {@code false}
	 * is returned.
	 *
	 * <p>Otherwise the two alphas are jointly re-optimised while keeping
	 * {@code alpha[up] + y_up*y_low*alpha[low]} constant, clipped to
	 * {@code [0, ERpenalty]}, and the cached row sums are updated. The next
	 * pair is then chosen as follows:
	 * <ul>
	 *   <li>"up" is the argmin of the row sum over points with
	 *       (alpha == 0 and y == 1), (alpha == C and y == -1), or a free alpha;</li>
	 *   <li>"low" is the argmax over points with (alpha == 0 and y != 1),
	 *       (alpha == C and y != -1), or a free alpha.</li>
	 * </ul>
	 *
	 * @param kernal supplies labels and kernel values
	 * @return {@code true} if a step was taken and both members of a new pair
	 *         were found; {@code false} if the stopping criterion holds, or if
	 *         one of the candidate sets was empty (illegal input)
	 */
	public boolean trainSVM(Kernal kernal)
	{
		final int up = UpminIdx;
		final int low = LowmaxIdx;

		// Written as a negation so that NaN row sums also stop training.
		if (!(alphaRowSum[up] < alphaRowSum[low] - Accuracy))
		{
			return false;
		}

		final double kll = kernal.KernelFunction(up, up, m_Caller);
		final double krr = kernal.KernelFunction(low, low, m_Caller);
		final double klr = kernal.KernelFunction(up, low, m_Caller);
		final int yl = kernal.GetClass(up, m_Caller);
		final int yr = kernal.GetClass(low, m_Caller);
		final int ylyr = yl * yr;

		// Captured unconditionally so the row-sum deltas below are exact even
		// when the step degenerates.
		final double oldAlphaUp = alpha[up];
		final double oldAlphaLow = alpha[low];
		final double rconst = oldAlphaUp + ylyr * oldAlphaLow;

		// Feasible interval for alpha[low] given the equality constraint and the box.
		final double lowMinB;
		final double lowMaxB;
		if (ylyr == 1)
		{
			lowMinB = Math.max(0.0, rconst - ERpenalty);
			lowMaxB = Math.min(ERpenalty, rconst);
		}
		else
		{
			lowMinB = Math.max(0.0, -rconst);
			lowMaxB = Math.min(ERpenalty, ERpenalty - rconst);
		}

		// Contributions of all points other than the pair (row sums with the
		// pair's own terms and the -y offset removed).
		final double vl = alphaRowSum[up] + yl - oldAlphaUp * yl * kll - oldAlphaLow * yr * klr;
		final double vr = alphaRowSum[low] + yr - oldAlphaLow * yr * krr - oldAlphaUp * yl * klr;

		// The sub-problem in alpha[low] is a quadratic with curvature -amu/2;
		// its stationary point is -bzi/amu.
		final double amu = kll + krr - 2 * klr;
		final double bzi = rconst * ylyr * (klr - kll) + yr * (vr - vl) + ylyr - 1;

		double newAlphaLow;
		if (amu > 0)
		{
			newAlphaLow = -bzi / amu;
			if (newAlphaLow > lowMaxB)
			{
				newAlphaLow = lowMaxB;
			}
			else if (newAlphaLow < lowMinB)
			{
				newAlphaLow = lowMinB;
			}
		}
		else
		{
			// Non-positive curvature (duplicate points in feature space, a
			// non-PSD kernel, or rounding). There is no interior maximum, so the
			// previous "leave unchanged" approach stalled on the same pair. Move
			// to the bound in the ascent direction instead. The slope in
			// alpha[low] is y_low * (rowSum[up] - rowSum[low]), which is
			// negative when y_low == 1 and positive otherwise for a violating pair.
			newAlphaLow = (yr == 1) ? lowMinB : lowMaxB;
		}
		alpha[low] = newAlphaLow;
		alpha[up] = rconst - ylyr * newAlphaLow;

		// Label-weighted alpha changes; both are 0 when the step was a no-op.
		final double deltaUp = (alpha[up] - oldAlphaUp) * yl;
		final double deltaLow = (alpha[low] - oldAlphaLow) * yr;

		boolean foundUp = false;
		boolean foundLow = false;
		double bestUp = 0;
		double bestLow = 0;
		for (int v = 0; v < NumOfPoint; ++v)
		{
			alphaRowSum[v] +=
				(deltaUp * kernal.KernelFunction(v, up, m_Caller)
				+ deltaLow * kernal.KernelFunction(v, low, m_Caller));
			final double rowSum = alphaRowSum[v];

			// Which direction(s) alpha[v] can still move in.
			final boolean inUp;
			final boolean inLow;
			if (alpha[v] == 0.0)
			{
				inUp = kernal.GetClass(v, m_Caller) == 1;
				inLow = !inUp;
			}
			else if (alpha[v] == ERpenalty)
			{
				inUp = kernal.GetClass(v, m_Caller) == -1;
				inLow = !inUp;
			}
			else
			{
				inUp = true;
				inLow = true;
			}

			if (inUp && (!foundUp || rowSum < bestUp))
			{
				bestUp = rowSum;
				UpminIdx = v;
				foundUp = true;
			}
			if (inLow && (!foundLow || rowSum > bestLow))
			{
				bestLow = rowSum;
				LowmaxIdx = v;
				foundLow = true;
			}
		}
		// An empty up or low set means the input is illegal (e.g. a single class).
		return foundUp && foundLow;
	}

	/**
	 * Computes the bias as the average, over free support vectors s
	 * ({@code Accuracy < alpha[s] <= ERpenalty - Accuracy}), of
	 * {@code y_s - sum_j alpha[j] * y_j * K(j, s)}.
	 *
	 * <p>If no point qualifies, the bias is left at 0 and the {@code UnStable}
	 * flag (returned by {@link #IsStable}) is set.
	 *
	 * @param kernal supplies labels and kernel values
	 */
	public void CalBias(Kernal kernal)
	{
		UnStable = true;
		bias = 0;
		int freeCount = 0;
		for (int s = 0; s < NumOfPoint; ++s)
		{
			if (alpha[s] > Accuracy && alpha[s] <= ERpenalty - Accuracy)
			{
				UnStable = false;
				double decision = 0;
				for (int j = 0; j < NumOfPoint; ++j)
				{
					if (alpha[j] != 0)
					{
						decision += alpha[j] * kernal.GetClass(j, m_Caller) * kernal.KernelFunction(j, s, m_Caller);
					}
				}
				bias += (kernal.GetClass(s, m_Caller) - decision);
				++freeCount;
			}
		}
		if (freeCount != 0)
		{
			bias /= freeCount;
		}
	}

	/**
	 * Returns the fraction of training points counted as errors.
	 *
	 * <p>When {@code ERpenalty >= 0.5}, a point counts if
	 * {@code alpha >= 0.5 - Accuracy}. Otherwise only points with
	 * {@code alpha >= ERpenalty - Accuracy} are examined. For each of these,
	 * the slack {@code 1 - y*(sum_j alpha[j] y_j K(j, i) + bias)} is computed,
	 * and the point counts if {@code 2*alpha + slack >= 1 - Accuracy}.
	 *
	 * <p>NOTE(review): this looks like the xi-alpha leave-one-out estimate with
	 * the kernel-radius term taken as 1. Confirm that this is the intent and
	 * that it suits the kernel in use.
	 *
	 * @param kernal supplies labels and kernel values
	 * @return error fraction in [0, 1]; 0 when there are no points
	 */
	public double CalError(Kernal kernal)
	{
		if (NumOfPoint == 0)
		{
			// Previously 0.0 / 0.0 = NaN.
			return 0.0;
		}
		int errnum = 0;
		if (ERpenalty >= 0.5)
		{
			final double threshold = 0.5 - Accuracy;
			for (int i = 0; i < NumOfPoint; ++i)
			{
				if (alpha[i] >= threshold)
				{
					++errnum;
				}
			}
		}
		else
		{
			final double threshold = ERpenalty - Accuracy;
			for (int i = 0; i < NumOfPoint; ++i)
			{
				if (alpha[i] < threshold)
				{
					continue;
				}
				double margin = 0;
				for (int v = 0; v < NumOfPoint; ++v)
				{
					if (alpha[v] != 0)
					{
						margin += (kernal.GetClass(v, m_Caller) * kernal.KernelFunction(v, i, m_Caller) * alpha[v]);
					}
				}
				margin += bias;
				margin *= kernal.GetClass(i, m_Caller);
				final double slack = 1 - margin;
				if (2 * alpha[i] + slack >= 1 - Accuracy)
				{
					++errnum;
				}
			}
		}
		return ((double) errnum / (double) NumOfPoint);
	}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -