📄 chsvm.java
字号:
public class CHSVM {
    /** Lagrange multipliers, one per training point; each step keeps them inside [0, ERpenalty]. */
    private double[] alpha;
    /** Bias (threshold) term of the decision function; written by {@link #CalBias(Kernal)}. */
    private double bias;

    /**
     * Returns the current multiplier of training point {@code u}.
     *
     * @param u index of the training point, in [0, number of points)
     * @return the multiplier {@code alpha[u]}
     */
    public double getAlpha(int u)
    {
        return alpha[u];
    }

    /**
     * Returns the bias computed by the last {@link #CalBias(Kernal)} call
     * (0 after construction, and 0 when CalBias found no qualifying point).
     *
     * @return the bias term
     */
    public double getBias()
    {
        return bias;
    }

    /**
     * Returns the internal "unstable" flag written by {@link #CalBias(Kernal)}.
     *
     * <p>WARNING: despite its name this returns {@code true} when the solution is
     * <em>unstable</em>, i.e. when the last CalBias found no multiplier inside
     * (Accuracy, ERpenalty - Accuracy] and therefore left the bias at 0. The inverted
     * meaning is kept unchanged so existing callers keep working.</p>
     *
     * @return {@code true} if the last CalBias could not estimate the bias
     */
    public boolean IsStable()
    {
        return UnStable;
    }

    ////////////////////////////////////////////////////////////////////////////

    /** Set by CalBias: true when no multiplier lay strictly between the bounds (see IsStable). */
    private boolean UnStable;
    /**
     * Cached value  sum_j alpha[j] * y_j * K(i, j) - y_i  for every point i, where y comes from
     * Kernal.GetClass and K from Kernal.KernelFunction. Initialize sets it to -y_i (all alphas
     * zero); trainSVM updates it incrementally after each pair step.
     */
    private double[] alphaRowSum;
    /** Point with the largest alphaRowSum in the "low" candidate set; second member of the next pair. */
    private int LowmaxIdx;
    /** Point with the smallest alphaRowSum in the "up" candidate set; first member of the next pair. */
    private int UpminIdx;
    /** Number of training points the arrays are currently sized for. */
    private int NumOfPoint;
    /** Upper bound (penalty parameter) applied to every multiplier. */
    private double ERpenalty;
    /** Opaque token forwarded unchanged to every Kernal call. */
    private long m_Caller;
    /** Tolerance used by the optimality test and when classifying multipliers. */
    private double Accuracy;
    //private Kernal kernal;

    //////////////////////////////////////////////////////////////

    /**
     * Creates an empty trainer; {@link #Initialize} must be called before training.
     * The tolerance is currently fixed at 0.001.
     */
    public CHSVM()
    {
        alpha = null;
        alphaRowSum = null;
        LowmaxIdx = 0;
        UpminIdx = 0;
        NumOfPoint = 0;
        ERpenalty = 0;
        bias = 0;
        m_Caller = 0;
        // Candidate for a user-configurable setting.
        Accuracy = 0.001;
        UnStable = false;
    }

    /**
     * Prepares the trainer for a new problem: every multiplier becomes 0, every cached
     * row sum becomes -y_i, and the first working pair is seeded with the last point of
     * each class.
     *
     * @param numofpoint number of training points; the arrays are reallocated only when
     *                   this differs from the previous size
     * @param erpenalty  upper bound for every multiplier
     * @param caller     opaque token forwarded to every Kernal call
     * @param kernal     supplies the class label (GetClass) of each point
     */
    public void Initialize(int numofpoint, double erpenalty, long caller, Kernal kernal)
    {
        m_Caller = caller;
        UnStable = false;
        if (NumOfPoint != numofpoint)
        {
            alpha = new double[numofpoint];
            alphaRowSum = new double[numofpoint];
            NumOfPoint = numofpoint;
        }
        ERpenalty = erpenalty;
        // Reset the pair so a missing class cannot leave a stale index from a previous
        // (possibly larger) problem, which trainSVM would then read out of bounds.
        UpminIdx = 0;
        LowmaxIdx = 0;
        for (int i = 0; i < numofpoint; ++i)
        {
            alpha[i] = 0.0;
            int label = kernal.GetClass(i, m_Caller);
            alphaRowSum[i] = -label;
            // Any label other than 1 is treated as the opposite class
            // (assumes labels are +1 / -1 — TODO confirm with Kernal).
            if (label == 1)
            {
                UpminIdx = i;
            }
            else
            {
                LowmaxIdx = i;
            }
        }
    }

    /**
     * Performs one pairwise optimisation step on (UpminIdx, LowmaxIdx), refreshes every
     * cached row sum, and selects the pair for the next step.
     *
     * <p>Candidate sets used for the selection: a point is in the "up" set when its
     * multiplier is free, or is 0 with class 1, or equals ERpenalty with class -1. It is in
     * the "low" set when its multiplier is free, or is 0 with a class other than 1, or
     * equals ERpenalty with a class other than -1. The next pair is the minimum row sum
     * over "up" and the maximum over "low".</p>
     *
     * @param kernal supplies kernel values and class labels
     * @return {@code true} if a step was taken and both candidate sets were non-empty
     *         afterwards (call again); {@code false} when the current pair satisfies the
     *         optimality test within Accuracy, when nothing has been initialised, or when a
     *         candidate set is empty (illegal input)
     */
    public boolean trainSVM(Kernal kernal)
    {
        // Nothing to optimise until Initialize has sized the arrays.
        if (NumOfPoint <= 0 || alpha == null)
        {
            return false;
        }
        final int iUp = UpminIdx;
        final int iLow = LowmaxIdx;
        // Optimality test: stop once min(up) >= max(low) - Accuracy.
        if (!(alphaRowSum[iUp] < alphaRowSum[iLow] - Accuracy))
        {
            return false;
        }

        //////optimal Top
        double Kll = kernal.KernelFunction(iUp, iUp, m_Caller);
        double Krr = kernal.KernelFunction(iLow, iLow, m_Caller);
        double Klr = kernal.KernelFunction(iUp, iLow, m_Caller);
        int Yl = kernal.GetClass(iUp, m_Caller);
        int Yr = kernal.GetClass(iLow, m_Caller);
        int YlYr = Yl * Yr;
        // alpha[iUp] + YlYr * alpha[iLow] is kept constant by the step.
        double Rconst = alpha[iUp] + YlYr * alpha[iLow];

        // Feasible interval for the new alpha[iLow], such that both multipliers stay in [0, ERpenalty].
        double LowMinB;
        double LowMaxB;
        if (YlYr == 1)
        {
            LowMinB = Rconst - ERpenalty;
            LowMaxB = Rconst;
        }
        else
        {
            LowMinB = -Rconst;
            LowMaxB = ERpenalty - Rconst;
        }
        if (LowMinB < 0.0)
        {
            LowMinB = 0.0;
        }
        if (LowMaxB > ERpenalty)
        {
            LowMaxB = ERpenalty;
        }

        // Row sums with the pair's own contributions (and the -y term) removed.
        double otherUp = alphaRowSum[iUp] + Yl - alpha[iUp] * Yl * Kll - alpha[iLow] * Yr * Klr;
        double otherLow = alphaRowSum[iLow] + Yr - alpha[iLow] * Yr * Krr - alpha[iUp] * Yl * Klr;
        // Along the constraint line the objective has curvature -Amu and, at alpha[iLow] = 0,
        // slope -Bzi, so the unconstrained optimum is -Bzi / Amu.
        double Amu = Kll + Krr - 2 * Klr;
        double Bzi = Rconst * YlYr * (Klr - Kll) + Yr * (otherLow - otherUp) + YlYr - 1;

        // Capture the old values before any branch. Previously they stayed 0 when Amu == 0,
        // which made the row-sum refresh below add spurious terms.
        double oldAlphaUp = alpha[iUp];
        double oldAlphaLow = alpha[iLow];
        if (Amu != 0)
        {
            double BchuA = -Bzi / Amu;
            if (BchuA > LowMaxB)
            {
                BchuA = LowMaxB;
            }
            else if (BchuA < LowMinB)
            {
                BchuA = LowMinB;
            }
            alpha[iLow] = BchuA;
            alpha[iUp] = Rconst - YlYr * alpha[iLow];
        }
        else if (Bzi < 0)
        {
            // Amu == 0: the objective is linear with positive slope (-Bzi > 0), so the best
            // feasible point is the upper end. Leaving the pair unchanged here would make
            // the same pair be reselected forever.
            alpha[iLow] = LowMaxB;
            alpha[iUp] = Rconst - YlYr * alpha[iLow];
        }
        else if (Bzi > 0)
        {
            // Amu == 0 with negative slope: move to the lower end.
            alpha[iLow] = LowMinB;
            alpha[iUp] = Rconst - YlYr * alpha[iLow];
        }
        // else: Amu == 0 and zero slope — nothing to gain, multipliers stay as they are.
        //////optimal Bottom

        //////update rowSum Top
        // Weighted multiplier changes; loop-invariant, zero when nothing moved.
        double deltaUp = (alpha[iUp] - oldAlphaUp) * Yl;
        double deltaLow = (alpha[iLow] - oldAlphaLow) * Yr;
        boolean HasFindUp = false;
        boolean HasFindDown = false;
        double bestUp = 0.0;
        double bestLow = 0.0;
        for (int v = 0; v < NumOfPoint; ++v)
        {
            alphaRowSum[v] += deltaUp * kernal.KernelFunction(v, iUp, m_Caller)
                    + deltaLow * kernal.KernelFunction(v, iLow, m_Caller);

            boolean inUp;
            boolean inLow;
            if (alpha[v] == 0.0)
            {
                inUp = kernal.GetClass(v, m_Caller) == 1;
                inLow = !inUp;
            }
            else if (alpha[v] == ERpenalty)
            {
                inUp = kernal.GetClass(v, m_Caller) == -1;
                inLow = !inUp;
            }
            else
            {
                // Free multiplier: candidate for both sets.
                inUp = true;
                inLow = true;
            }

            double rowSum = alphaRowSum[v];
            if (inUp && (!HasFindUp || rowSum < bestUp))
            {
                bestUp = rowSum;
                UpminIdx = v;
                HasFindUp = true;
            }
            if (inLow && (!HasFindDown || rowSum > bestLow))
            {
                bestLow = rowSum;
                LowmaxIdx = v;
                HasFindDown = true;
            }
        }
        //////update rowSum Bottom

        // An empty candidate set means illegal input (e.g. only one class present).
        return HasFindDown && HasFindUp;
    }

    /**
     * Computes the bias as the average of  y_s - sum_t alpha[t] * y_t * K(t, s)  over all
     * points s whose multiplier lies in (Accuracy, ERpenalty - Accuracy].
     *
     * <p>When no such point exists the bias stays 0 and the "unstable" flag is set
     * (reported by {@link #IsStable()}, which returns true in that case).</p>
     *
     * @param kernal supplies kernel values and class labels
     */
    public void CalBias(Kernal kernal)
    {
        int Pntcount = 0;
        UnStable = true;
        bias = 0;
        for (int selg = 0; selg < NumOfPoint; ++selg)
        {
            if (alpha[selg] > Accuracy && alpha[selg] <= ERpenalty - Accuracy)
            {
                UnStable = false;
                double tempRule = 0;
                for (int tr = 0; tr < NumOfPoint; ++tr)
                {
                    // Zero multipliers contribute nothing; skip the kernel evaluation.
                    if (alpha[tr] != 0)
                    {
                        tempRule += alpha[tr] * kernal.GetClass(tr, m_Caller) * kernal.KernelFunction(tr, selg, m_Caller);
                    }
                }
                bias += (kernal.GetClass(selg, m_Caller) - tempRule);
                ++Pntcount;
            }
        }
        if (Pntcount != 0)
        {
            bias /= ((double) Pntcount);
        }
    }

    /**
     * Returns the fraction of training points counted as potential errors.
     *
     * <p>If ERpenalty >= 0.5, a point is counted when alpha[i] >= 0.5 - Accuracy. Otherwise
     * a point is counted when alpha[i] >= ERpenalty - Accuracy and
     * 2 * alpha[i] + (1 - y_i * f(x_i)) >= 1 - Accuracy, where
     * f(x_i) = sum_v alpha[v] * y_v * K(v, i) + bias. NOTE(review): this resembles Joachims'
     * xi-alpha leave-one-out estimate under the assumption K(x, x) <= 1 — confirm.</p>
     *
     * <p>Uses the bias from the last CalBias call. Returns NaN (0 / 0) when there are no
     * points.</p>
     *
     * @param kernal supplies kernel values and class labels
     * @return number of counted points divided by the number of points
     */
    public double CalError(Kernal kernal)
    {
        int errnum = 0;
        if (ERpenalty >= 0.5)
        {
            double threshold = 0.5 - Accuracy;
            for (int i = 0; i < NumOfPoint; ++i)
            {
                if (alpha[i] >= threshold)
                {
                    ++errnum;
                }
            }
        }
        else
        {
            // Hoisted: the threshold does not depend on the point.
            double threshold = ERpenalty - Accuracy;
            for (int i = 0; i < NumOfPoint; ++i)
            {
                if (alpha[i] >= threshold)
                {
                    double resultR = 0;
                    for (int v = 0; v < NumOfPoint; ++v)
                    {
                        if (alpha[v] != 0)
                        {
                            resultR += (kernal.GetClass(v, m_Caller) * kernal.KernelFunction(v, i, m_Caller) * alpha[v]);
                        }
                    }
                    resultR += bias;
                    resultR *= kernal.GetClass(i, m_Caller);
                    double relaxE = 1 - resultR;
                    if (2 * alpha[i] + relaxE >= 1 - Accuracy)
                    {
                        ++errnum;
                    }
                }
            }
        }
        return ((double) errnum / (double) NumOfPoint);
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -