📄 svmfusvmbase.cpp
字号:
// This is a part of the SvmFu library, a library for training// Support Vector Machines.//// Copyright (C) 2000 rif and MIT//// Contact: rif@mit.edu// This program is free software; you can redistribute it and/or// modify it under the terms of the GNU General Public License as// published by the Free Software Foundation; either version 2 of// the License, or (at your option) any later version.// This program is distributed in the hope that it will be useful,// but WITHOUT ANY WARRANTY; without even the implied warranty of// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the// GNU General Public License for more details.// You should have received a copy of the GNU General Public// License along with this program; if not, write to the Free// Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,// MA 02111-1307 USA#include "SvmFuSvmBase.h"#define CallTemplate(rettype, funcname) \template<class DataPt, class KernVal> \rettype SvmBase<DataPt, KernVal>::funcname#define InlineTemplate(rettype, funcname) \template<class DataPt, class KernVal> \rettype SvmBase<DataPt, KernVal>::funcname template<class DataPt, class KernVal> SvmBase<DataPt, KernVal>::SvmBase(int svmSize, const IntVec y, DataPt *trnSetPtr, const KernVal (*kernProdFuncPtr)(const DataPt &pt1, const DataPt &pt2), double C, double eps) : kernProdFuncPtr_(kernProdFuncPtr), svmSize_(svmSize), y_(y), trnSetPtr_(trnSetPtr), eps_(eps) { initSvmBase(C);}CallTemplate(void, initSvmBase)(double C) { numSupVecs_ = 0; numUnbndSupVecs_ = 0; numEverSV_ = 0; b_ = 0; alphas_ = new double[svmSize_]; cVec_ = new double[svmSize_]; unbndSet_ = new int[svmSize_]; unbndPos_ = new int[svmSize_]; SV_Set_ = new int[svmSize_]; SV_Pos_ = new int[svmSize_]; everSV_P_ = new bool[svmSize_]; for (int i = 0; i < svmSize_; i++) { alphas_[i] = 0.0; cVec_[i] = C; unbndSet_[i] = -1; unbndPos_[i] = -1; SV_Set_[i] = -1; SV_Pos_[i] = -1; everSV_P_[i] = false; }} template<class DataPt, class KernVal> SvmBase<DataPt, KernVal>::~SvmBase() { delete[] alphas_; delete[] cVec_; delete[] unbndSet_; delete[] unbndPos_; delete[] SV_Set_; delete[] SV_Pos_; delete[] everSV_P_;} InlineTemplate(int, classifyDataPt)(const DataPt &pt) const { return (outputAtDataPt(pt) > 0 ? 1 : -1);} CallTemplate(double, outputAtDataPt)(const DataPt &pt) const { double result = 0.0; int nSV = getNumSupVecs(); IntVec SV_IDs = getSupVecIDsPtr(); for (int i = 0; i < nSV; i++) { int ex = SV_IDs[i]; result += getY(ex)*getAlpha(ex)* (double)kernProdFuncPtr_(getTrainingExample(ex), pt); } result += b_; return result;}InlineTemplate(int, classifyTrainingExample)(int ex) const { return (outputAtTrainingExample(ex) > 0 ? 1 : -1);} CallTemplate(double, outputAtTrainingExample)(int ex) const { return outputAtDataPt(getTrainingExample(ex));} CallTemplate(double, computeTrainingSetPerf)(bool printInfo) const { int correct = 0; int size = getSize(); for (int i = 0; i < size; i++) { if (classifyTrainingExample(i) == getY(i)) { correct++; } } double perf = correct/(double)size; if (printInfo) { cout << "Training Set Performance: " << perf << " (" << correct << "/" << size << ")" << endl; } return perf;}CallTemplate(double, dualObjFunc)() const { double obj = 0.0; int nSV = getNumSupVecs(); IntVec svIDs = getSupVecIDsPtr(); for (int i = 0; i < nSV; i++) { double aI = getAlpha(svIDs[i]); obj += aI; for (int j = 0; j < nSV; j++) { obj -= .5*aI*getAlpha(svIDs[j])*getY(svIDs[i])*getY(svIDs[j])* kernProd(svIDs[i],svIDs[j]); } } return obj;}InlineTemplate(KernVal, kernProd)(int ex1, int ex2) const { return kernProdFuncPtr_(getTrainingExample(ex1), getTrainingExample(ex2));}InlineTemplate(bool, supVecP)(int ex) const { return SV_Pos_[ex] != -1; }InlineTemplate(bool, unbndSupVecP)(int ex) const { return unbndPos_[ex] != -1;}InlineTemplate(bool, everSV_P)(int ex) const { return everSV_P_[ex]; }InlineTemplate(int, getNumSupVecs)() const { return numSupVecs_; }InlineTemplate(int, getNumUnbndSupVecs)() const { return numUnbndSupVecs_; }InlineTemplate(int, getNumEverSV)() const { return numEverSV_; } InlineTemplate(IntVec, getSupVecIDs)() const { int nSV = getNumSupVecs(); IntVec IDs = new int[nSV]; for (int i = 0; i < nSV; i++) { IDs[i] = SV_Set_[i]; } return IDs;}InlineTemplate(IntVec, getSupVecIDsPtr)() const { return SV_Set_;} InlineTemplate(IntVec, getUnbndSupVecIDs)() const { int numUnbnd = getNumUnbndSupVecs(); IntVec IDs = new int[numUnbnd]; for (int i = 0; i < numUnbnd; i++) { IDs[i] = unbndSet_[i]; }; return IDs;}InlineTemplate(IntVec, getUnbndSupVecIDsPtr)() const { return unbndSet_;}CallTemplate(void, setAlpha)(int ex, double newAlpha) { double eps = getEpsilon(); double cEx = getC(ex); if (newAlpha >= cEx+eps) { cerr << "Error: C(" << ex << ") is " << getC(ex) << ", alpha(" << ex << ") may not be set to " << newAlpha << ". Exiting." << endl; exit(1); } double oA = getAlpha(ex), nA = newAlpha; double cExEps = cEx - eps; if (nA >= eps) { // It's an SV now. if (oA < eps) { // It wasn't an SV before addToSV_Set(ex); if (nA < cExEps) { // Is it now a USV? addToUnbndSet(ex); } } else { // It's an SV now and before, check BSV <-> USV if ((oA >= cExEps) && (nA < cExEps)) { addToUnbndSet(ex); } if ((nA >= cExEps) && (oA < cExEps)) { removeFromUnbndSet(ex); } } } else { // It's NOT an SV now if (oA >= eps) { // Was it an SV before? removeFromSV_Set(ex); if (oA < cExEps) { removeFromUnbndSet(ex); } // Was it a USV? } } alphas_[ex] = newAlpha;}CallTemplate(void, setAllAlphas)(const DoubleVec alphas) { int svSize = getSize(); for (int i = 0; i < svSize; i++) { setAlpha(i, alphas[i]); }}InlineTemplate(double, getAlpha)(int ex) const { return alphas_[ex]; }CallTemplate(DoubleVec, getAllAlphas)(void) const { DoubleVec alphas = new double[svmSize_]; for (int i = 0; i < svmSize_; i++) { alphas[i] = alphas_[i]; } return alphas;}CallTemplate(DoubleVec, getSupVecAlphas)(void) const { int nSV = getNumSupVecs(); const IntVec svIDs = getSupVecIDsPtr(); DoubleVec alphas = new double[nSV]; for (int i = 0; i < nSV; i++) { alphas[i] = getAlpha(svIDs[i]); } return alphas;}CallTemplate(DoubleVec, getUnbndSupVecAlphas)(void) const { int nSV = getNumUnbndSupVecs(); const IntVec svIDs = getUnbndSupVecIDsPtr(); DoubleVec alphas = new double[nSV]; for (int i = 0; i < nSV; i++) { alphas[i] = getAlpha(svIDs[i]); } return alphas;} CallTemplate(void, setB)(double b) { b_ = b;}InlineTemplate(double, getB)() const { return b_; }InlineTemplate(double, getEpsilon)() const { return eps_; }CallTemplate(void, setAllC)(double c) { for (int i = 0; i < svmSize_; i++) { setC(i, c); }}CallTemplate(void, setPosNegC)(double posC, double negC) { for (int i = 0; i < svmSize_; i++) { setC(i, y_[i] == 1 ? posC : negC); }}CallTemplate(void, setCVec)(const DoubleVec cVec) { for (int i = 0; i < svmSize_; i++) { setC(i, cVec[i]); }} CallTemplate(void, setC)(int ex, double newC) { double aI = getAlpha(ex); double oldC = getC(ex); double eps = getEpsilon(); if (newC < aI-eps_) { cerr << "Error: alpha(" << ex << ") = " << aI << ": cannot set C(" << ex << ") to " << newC << " Exiting." << endl; exit(-1); } if (aI > eps) { // aI is currently an SV if ((aI < oldC-eps) && (aI >= newC-eps)) { // USV->BSV removeFromUnbndSet(ex); } if ((aI >= oldC-eps) && (aI < newC-eps)) { addToUnbndSet(ex); } } cVec_[ex] = newC;}CallTemplate(DoubleVec, getCVec)() const { int svSize = getSize(); DoubleVec cVec = new double[svSize]; for (int i = 0; i < svSize; i++) { cVec[i] = getC(i); } return cVec;} InlineTemplate(double, getC)(int ex) const { return cVec_[ex]; } InlineTemplate(int, getSize)() const { return svmSize_; }// InlineTemplate(int, getY)(int ex) const { return y_[ex]; }InlineTemplate(void, addToUnbndSet)(int ex) { if (unbndPos_[ex] != -1) { cerr << "ERROR: Trying to add point " << ex << " to the unbounded set; it's already there! Exiting." << endl; exit(1); } unbndPos_[ex] = numUnbndSupVecs_; unbndSet_[numUnbndSupVecs_++] = ex;}InlineTemplate(void, removeFromUnbndSet)(int ex) { if (unbndPos_[ex] == -1) { cerr << "ERROR: Trying to remove point " << ex << " from the unbounded set; it's not there! Exiting." << endl; exit(1); } unbndSet_[unbndPos_[ex]] = unbndSet_[--numUnbndSupVecs_]; unbndPos_[unbndSet_[numUnbndSupVecs_]] = unbndPos_[ex]; unbndPos_[ex] = -1;}InlineTemplate(void, addToSV_Set)(int ex) { if (SV_Pos_[ex] != -1) { cerr << "ERROR: Trying to add point " << ex << " to the SV set; it's already there! Exiting." << endl; exit(1); } SV_Pos_[ex] = numSupVecs_; SV_Set_[numSupVecs_++] = ex;}InlineTemplate(void, removeFromSV_Set)(int ex) { if (SV_Pos_[ex] == -1) { cerr << "ERROR: Trying to remove point " << ex << " from the SV set; it's not there! Exiting." << endl; exit(1); } SV_Set_[SV_Pos_[ex]] = SV_Set_[--numSupVecs_]; SV_Pos_[SV_Set_[numSupVecs_]] = SV_Pos_[ex]; SV_Pos_[ex] = -1;}InlineTemplate(const DataPt, getTrainingExample)(int ex) const { return trnSetPtr_[ex];} #include "SvmFuSvmDataPoint.h"#define IterateTypes(datatype, kerntype) \ template class SvmBase<DataPoint<datatype>, kerntype>;#include "SvmFuSvmTypes.h"
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -