// SVMMethod.java
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/**
* Title: XELOPES Data Mining Library
* Description: The XELOPES library is an open platform-independent and data-source-independent library for Embedded Data Mining.
* Copyright: Copyright (c) 2002 Prudential Systems Software GmbH
* Company: ZSoft (www.zsoft.ru), Prudsys (www.prudsys.com)
* @author Michael Thess
* @author Victor Borichev
* @author Valentine Stepanenko (valentine.stepanenko@zsoft.ru)
* @version 1.0
*/
package com.prudsys.pdm.Models.Regression.SVM.Algorithms.SparseSVM;
import java.util.Vector;
import com.prudsys.pdm.Core.MetaDataOperations;
import com.prudsys.pdm.Core.MiningAttribute;
import com.prudsys.pdm.Core.MiningDataSpecification;
import com.prudsys.pdm.Core.MiningException;
import com.prudsys.pdm.Input.MiningInputStream;
import com.prudsys.pdm.Input.MiningSparseVector;
import com.prudsys.pdm.Input.MiningVector;
import com.prudsys.pdm.Models.Regression.SVM.SupportVectorClassifier;
import com.prudsys.pdm.Models.Regression.SVM.SupportVectorSettings;
/**
* Sparse SVM.
*
*/
public class SVMMethod extends SupportVectorClassifier {

    /** SVM problem: the training data in LIBSVM sparse format, filled by readData. */
    protected SVMProblem m_prob;

    /** SVM training parameters (kernel type, gamma, etc.). */
    protected SVMParameters m_param;

    /** SVM model, produced by training (buildClassifier) or rebuilt from
     *  the inherited support-vector fields (constructModel). */
    protected SVMModel m_model;

    /** Name of the file the trained model is saved to by buildClassifier. */
    protected String m_model_file_name = "data/test.model";

    /**
     * Constructor initializes parameter values by default.
     */
    public SVMMethod() {
        // Init parameters:
        m_param = new SVMParameters();
    }

    /**
     * Tests given value to be equal to zero
     * in some small interval (|dNumb| &lt; 1e-6).
     *
     * NOTE(review): appears unused within this class — possibly kept for
     * subclasses or historical reasons; confirm before removing.
     *
     * @param dNumb number to be tested
     * @return true if zero, otherwise false
     */
    private boolean isZero(double dNumb) {
        if (Math.abs(dNumb) < 0.000001)
            return true;
        else
            return false;
    }

    /**
     * Reads data from input stream and converts it into SVMProblem class.
     * The data is written into the global variable m_prob.
     *
     * Side effect: also sets the inherited field classIndex to the index of
     * the target attribute within the stream's meta data.
     *
     * @param data input stream
     * @param max_lines maximum lines to be read, -1 for unbounded
     * @param target target attribute
     * @return number of lines read
     * @exception MiningException reading error
     */
    private int readData(MiningInputStream data, int max_lines, MiningAttribute target) throws MiningException {
        Vector vy = new Vector();  // class labels, one Double per row
        Vector vx = new Vector();  // sparse rows, one SVMNode[] per row
        int max_index = 0;         // largest attribute index seen so far
        int n_line = 0;            // number of rows read so far
        MiningDataSpecification metaData = data.getMetaData();
        classIndex = metaData.getAttributeIndex(target);
        data.reset();
        while (data.next()) {
            // Get mining vector:
            MiningSparseVector vec = new MiningSparseVector( data.read() );
            vec.setMetaData( data.getMetaData() );
            // Add class label:
            double val = vec.getValue(classIndex);
            vy.addElement( new Double(val) );
            // Add coordinate values: reserve one SVMNode per stored sparse
            // entry, minus one slot if the target attribute itself is stored
            // (i.e. has a non-zero value) and must be excluded.
            int m = vec.getNumValuesSparse();
            int ind = vec.locateIndex(classIndex);
            if ( ind > -1 && vec.getIndex( ind ) == classIndex )
                m = m - 1; // non-zero target attribute value => remove
            SVMNode[] x = new SVMNode[m];
            m = 0;
            // Copy every stored sparse entry except the target attribute:
            for (int j = 0; j < vec.getNumValuesSparse(); j++) {
                if (vec.getIndex(j) == classIndex) continue;
                x[m] = new SVMNode();
                x[m].index = vec.getIndex(j);
                x[m++].value = vec.getValueSparse(j);
            };
            // NOTE(review): only the LAST copied entry is inspected here,
            // which assumes sparse indices are stored in ascending order —
            // TODO confirm MiningSparseVector guarantees that ordering.
            if (m > 0)
                max_index = Math.max(max_index, x[m-1].index);
            vx.addElement(x);
            // Count lines for break:
            n_line = n_line + 1;
            if (n_line == max_lines)
                break;
        };
        // Fill SVM problem data from the collected rows:
        m_prob = new SVMProblem();
        m_prob.l = vy.size();
        m_prob.x = new SVMNode[m_prob.l][];
        for(int i = 0; i < m_prob.l; i++)
            m_prob.x[i] = (SVMNode[]) vx.elementAt(i);
        m_prob.y = new double[m_prob.l];
        for(int i = 0; i < m_prob.l; i++)
            m_prob.y[i] = ((Double)vy.elementAt(i)).doubleValue();
        m_prob.max_c = max_index+1;
        return m_prob.l;
    }

    /**
     * Reads a mining vector for model application and converts it
     * into an SVMNode array. If the vector is not sparse, the sparse
     * representation is created here (zero values are skipped).
     * If a classifying feature is included, it will be ignored.
     *
     * NOTE(review): numbValues counts position classIndex even when its
     * value is zero, and m = numbValues - 1 then removes exactly one slot
     * for it. This assumes classIndex &lt; values.length, i.e. that the
     * vector always contains the class attribute — TODO confirm; apply()
     * transforms the vector to the model's meta data first, which
     * presumably guarantees this.
     *
     * @param vec vector to be read
     * @return SVMNode array built from vec
     * @exception MiningException reading error
     */
    private SVMNode[] readDataVector(MiningVector vec) throws MiningException {
        double[] values = vec.getValues();
        // Count entries to keep: every non-zero value, plus the class
        // position (counted unconditionally, then subtracted below).
        int numbValues = 0;
        for (int i = 0; i < values.length; i++)
            if (values[i] != 0.0 || i == classIndex) numbValues++;
        // Add coordinate values:
        int m = numbValues - 1;
        SVMNode[] x = new SVMNode[m];
        m = 0;
        for (int j = 0; j < values.length; j++) {
            if (values[j] == 0.0 || j == classIndex) continue;
            x[m] = new SVMNode();
            x[m].index = j;
            x[m++].value = values[j];
        };
        return x;
    }

    /**
     * Generates the classifier.
     *
     * Trains via LIBSVMAdapter.svm_train, saves the model to
     * m_model_file_name, and copies the support vectors and coefficients
     * into the inherited SupportVectorClassifier fields.
     *
     * @param data set of data serving as training data
     * @param target target attribute
     * @exception MiningException if the classifier has not been generated successfully
     */
    public void buildClassifier(MiningInputStream data, MiningAttribute target) throws MiningException {
        // read data:
        readData(data, -1, target);
        // Modify parameters: default gamma to 1/(number of features),
        // mirroring the LIBSVM convention for gamma == 0.
        if (m_param.gamma == 0)
            m_param.gamma = 1.0/(m_prob.max_c-1);
        // Building classifier:
        m_model = LIBSVMAdapter.svm_train(m_prob, m_param);
        // Save classifier:
        try {
            LIBSVMAdapter.svm_save_model(m_model_file_name, m_model);
        }
        catch(Exception ex) {
            // NOTE(review): the original cause is discarded here; consider
            // chaining ex if MiningException supports a cause argument.
            throw new MiningException("svm_save_model failed");
        }
        // Copy to SparseVectorClassifier structure: one MiningSparseVector
        // per LIBSVM support vector.
        SVMNode[][] sv = m_model.SV;
        MiningSparseVector[] msv = new MiningSparseVector[sv.length];
        for (int i = 0; i < sv.length; i++)
        {
            double[] values = new double[sv[i].length];
            int[] indices = new int[sv[i].length];
            for (int j = 0; j < sv[i].length; j++)
            {
                values[j] = sv[i][j].value;
                indices[j] = sv[i][j].index;
            };
            msv[i] = new MiningSparseVector(0.f, values, indices);
            msv[i].setMetaData( data.getMetaData() );
        };
        this.supportVectors = msv;
        // Only the first coefficient row / rho entry is used (binary case):
        this.coefficients = m_model.sv_coef[0];
        this.absoluteCoefficient = m_model.rho[0];
    }

    /**
     * Classifies a given vector.
     *
     * If a class attribute is contained in the vector it will be ignored.
     * Otherwise, it is supposed that the vector does not contain any
     * class attribute.
     *
     * @param vector the vector to be classified
     * @return index of the predicted value
     * @exception MiningException if vector could not be classified
     * successfully
     */
    public double apply(MiningVector vector) throws MiningException {
        // Bring mining vector to same meta data like model:
        MetaDataOperations metaOp = metaData.getMetaDataOp();
        metaOp.setUsageType( MetaDataOperations.USE_ATT_NAMES_AND_TYPES );
        vector = metaOp.transform(vector);
        // Change vector to sparse format:
        SVMNode[] rNode = readDataVector(vector);
        // Apply SVM:
        return LIBSVMAdapter.svm_predict(m_model, rNode);
    }

    /**
     * Constructs internal SVM classifier model from data read.
     *
     * Rebuilds an SVMModel from the inherited fields (supportVectors,
     * coefficients, absoluteCoefficient, svmType, kernelType, degree,
     * coef0, gamma) so a previously exported model can be applied.
     *
     * @throws MiningException could not construct internal SVM model
     */
    protected void constructModel() throws MiningException {
        SVMModel model = new SVMModel();
        SVMParameters param = new SVMParameters();
        model.param = param;
        model.label = null;
        model.nSV = null;
        param.svm_type = svmType;
        param.kernel_type = kernelType;
        param.degree = degree;
        param.coef0 = coef0;
        param.gamma = gamma;
        model.nr_class = 2; // binary classification assumed
        model.l = 0;
        if (supportVectors != null)
            model.l = supportVectors.length;
        model.rho = new double[1];
        model.rho[0] = absoluteCoefficient;
        /************** Still not ready: ***************************/
        if (svmType == SupportVectorSettings.SVM_C_SVC) {
            int n = model.nr_class;
            model.label = new int[n];
            for (int i = 0; i < n; i++)
                model.label[i] = i;
            model.nSV = new int[n];
            // NOTE(review): nSV[i] = i looks like a placeholder — nSV is
            // presumably meant to hold the per-class support-vector counts.
            for (int i = 0; i < n; i++)
                model.nSV[i] = i;
        };
        /************** ...still not ready. ************************/
        // Copy support vectors and coefficients into LIBSVM layout; every
        // attribute position is materialized (dense SVMNode rows).
        int m = model.nr_class - 1;
        int l = model.l;
        model.sv_coef = new double[m][l];
        model.SV = new SVMNode[l][];
        for (int i = 0 ; i < l; i++)
        {
            for (int k = 0; k < m; k++)
                model.sv_coef[k][i] = coefficients[i];
            int n = supportVectors[i].getMetaData().getAttributesNumber();
            model.SV[i] = new SVMNode[n];
            for(int j = 0; j < n; j++)
            {
                model.SV[i][j] = new SVMNode();
                model.SV[i][j].index = j;
                model.SV[i][j].value = supportVectors[i].getValue(j);
            };
        };
        this.m_model = model;
    }

    /**
     * Returns a description of this classifier.
     *
     * @return a description of this classifier as a string.
     */
    public String toString() {
        return ("Sparse SVM classifier.");
    }

    /**
     * Returns SVM model.
     *
     * @return SVM model
     */
    public SVMModel getModel()
    {
        return m_model;
    }

    /**
     * Sets new SVM model.
     *
     * @param model new SVM model
     */
    public void setModel(SVMModel model)
    {
        this.m_model = model;
    }

    /**
     * Returns SVM parameters.
     *
     * @return SVM parameters
     */
    public SVMParameters getParam()
    {
        return m_param;
    }

    /**
     * Sets SVM parameters.
     *
     * @param param new SVM parameters
     */
    public void setParam(SVMParameters param)
    {
        this.m_param = param;
    }

    /**
     * Returns SVM problem.
     *
     * @return SVM problem
     */
    public SVMProblem getProb()
    {
        return m_prob;
    }

    /**
     * Sets new SVM problem.
     *
     * @param prob new SVM problem
     */
    public void setProb(SVMProblem prob)
    {
        this.m_prob = prob;
    }
}