RegMethod.java
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/**
* Title: XELOPES Data Mining Library
* Description: The XELOPES library is an open platform-independent and data-source-independent library for Embedded Data Mining.
* Copyright: Copyright (c) 2002 Prudential Systems Software GmbH
* Company: ZSoft (www.zsoft.ru), Prudsys (www.prudsys.com)
* @author Michael Thess
* @author Victor Borichev
* @author Valentine Stepanenko (valentine.stepanenko@zsoft.ru)
* @version 1.0
*/
package com.prudsys.pdm.Models.Regression.SVM.Algorithms.RegularizationNetworks;
import java.util.Vector;
import com.prudsys.pdm.Core.MiningAttribute;
import com.prudsys.pdm.Core.MiningDataSpecification;
import com.prudsys.pdm.Core.MiningException;
import com.prudsys.pdm.Input.MiningInputStream;
import com.prudsys.pdm.Input.MiningVector;
import com.prudsys.pdm.Input.Records.Arff.MiningArffStream;
import com.prudsys.pdm.Models.Supervised.Classifier;
/**
* Training of regularization networks.
*
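* <p>A minimal usage sketch, mirroring the test routine in main(); the ARFF
* file, the attribute name and the parameter values are examples only, not
* fixed by the library:</p>
* <pre>
* MiningArffStream data  = new MiningArffStream("data/arff/regress.arff");
* MiningAttribute target = data.getMetaData().getMiningAttribute("three");
* RegMethod reg = new RegMethod();
* reg.getParam().reg_type = 3;    // values taken from the test routine below
* reg.getParam().p        = 0.01;
* reg.getParam().C        = 10.0;
* reg.buildClassifier(data, target);
* data.reset();
* if (data.next())
*     System.out.println("prediction = " + reg.apply(data.read()));
* </pre>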
*/
public class RegMethod implements Classifier {
/**
* Regularization problem.
*/
protected RegProblem m_prob;
/**
* Regularization network parameters.
*/
protected RegParameters m_param;
/**
* Classifier data.
*/
protected RegModel m_model;
/**
* Name of the model file.
*/
protected String m_model_file_name = "data/test.model";
private int classIndex = -1;
/**
* Constructor initializes parameters with default values.
*/
public RegMethod() {
// Init parameters:
m_param = new RegParameters();
}
/**
* Reads data from the data source and converts it into a RegProblem object.
* The data is written into the member variable m_prob.
*
* @param data data source
* @param max_lines maximum number of lines to be read, -1 for unbounded
* @param target target attribute (class label)
* @return number of lines read
* @exception MiningException if a reading error occurs
*/
private int readData(MiningInputStream data, int max_lines, MiningAttribute target) throws MiningException {
Vector vy = new Vector();
Vector vx = new Vector();
int max_index = 0;
int n_missing = 0;
int n_line = 0;
data.reset();
MiningDataSpecification metaData = data.getMetaData();
int pos = metaData.getAttributeIndex(target);
classIndex = pos;
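// Convert every mining vector into a sparse RegNode array: zero values
// are skipped and the target attribute is removed from the coordinates;
// its value is stored separately as the label in vy.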
while (data.next()) {
MiningVector vec = data.read();
// Add class label:
double val = vec.getValue(pos);
// if (Math.abs(val) < 0.0001)
// val = -1;
if (Double.isNaN(val)) {
val = 0;
n_missing = n_missing + 1;
};
vy.addElement( new Double(val) );
// Add coordinate values:
int numbValues = 0;
double[] values = vec.getValues();
for(int i=0; i<values.length;i++)
if(values[i]!=0.0 || i == pos) numbValues++;
// int shift = 1;
// if (data.getClassIndex() == AbsDataSource.CLASS_INDEX_LAST_COL)
// shift = 0;
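// numbValues counts all non-zero attributes plus the target position,
// so one slot is subtracted: only the input features become RegNodes.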
int m = numbValues - 1;
RegNode[] x = new RegNode[m];
m = 0;
for(int j = 0; j < values.length; j++) {
if(values[j] == 0.0 || j == pos) continue;
x[m] = new RegNode();
x[m].index = j;//vec.getIndex(j);
// val = vec.getValueSparse(j+shift);
val = values[j];
if (Double.isNaN(val)) {
val = 0.0;
n_missing = n_missing + 1;
};
x[m++].value = val;
};
if (m > 0)
max_index = Math.max(max_index, x[m-1].index);
vx.addElement(x);
// Count lines for break:
n_line = n_line + 1;
if (n_line == max_lines)
break;
};
// Fill regularization problem data:
m_prob = new RegProblem();
m_prob.l = vy.size();
m_prob.x = new RegNode[m_prob.l][];
for(int i = 0; i < m_prob.l; i++)
m_prob.x[i] = (RegNode[]) vx.elementAt(i);
m_prob.y = new double[m_prob.l];
for(int i = 0; i < m_prob.l; i++)
m_prob.y[i] = ((Double)vy.elementAt(i)).doubleValue();
m_prob.max_c = max_index+1;
m_prob.n_missing = n_missing;
return m_prob.l;
}
/**
* Reads a mining vector and converts it into an array of RegNode objects.
* If the vector is not sparse, a sparse representation is created.
* If a class attribute is included, it is ignored.
*
* @param vec vector to be read
* @param classIndex index of the class attribute, -1 if none
* @return sparse RegNode representation of vec
* @exception MiningException if a reading error occurs
*/
private RegNode[] readDataVector(MiningVector vec, int classIndex) throws MiningException {
// Init:
int max_index = 0;
int n_missing = 0;
// Get vector in sparse format:
/* SparseVector vec = null;
if (bvec instanceof SparseVector)
vec = (SparseVector) bvec;
else
vec = new SparseVector(bvec);*/
// Add coordinate values:
/* int classIndex = AbsDataSource.NO_CLASS_INDEX;
try {
classIndex = bvec.getClassIndex();
}
catch (Exception ex) {
classIndex = AbsDataSource.NO_CLASS_INDEX;
};*/
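// Build the sparse representation: count the relevant entries first, then
// copy every non-zero, non-class value into a RegNode together with its
// original attribute index.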
int numbValues = 0;
double[] values = vec.getValues();
for(int i=0; i<values.length;i++)
if(values[i]!=0.0 || (classIndex != -1 && i == classIndex)) numbValues++;
int m = numbValues;
if(classIndex != -1) m--;
/* int shift = -1;
if (classIndex == AbsDataSource.CLASS_INDEX_FIRST_COL) {
m = m - 1;
shift = 1;
}
else if (classIndex == AbsDataSource.CLASS_INDEX_LAST_COL) {
m = m - 1;
shift = 0;
}
else if (classIndex == AbsDataSource.NO_CLASS_INDEX)
shift = 0;
if (shift == -1)
throw new Exception("Class index > 0 forbidden!");
*/
RegNode[] x = new RegNode[m];
m = 0;
for(int j = 0; j < values.length; j++) {
if(values[j] == 0.0 || (classIndex != -1 && j == classIndex)) continue;
x[m] = new RegNode();
x[m].index = j;//vec.getIndex(j);
double val = values[j];
if (Double.isNaN(val)) {
val = 0.0;
n_missing = n_missing + 1;
};
x[m++].value = val;
};
// Determine maximum index:
if (m > 0)
max_index = Math.max(max_index, x[m-1].index);
return x;
}
/**
* Generates the classifier.
*
* @param data set of data serving as training data
* @param target target attribute (class label)
* @exception MiningException if the classifier has not been generated successfully
*/
public void buildClassifier(MiningInputStream data, MiningAttribute target) throws MiningException {
// read data:
readData(data, -1, target);
// Modify parameters:
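// If gamma was left at its default of 0, fall back to 1/(number of input
// attributes), presumably the same 1/num_features heuristic LIBSVM uses.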
if (m_param.gamma == 0)
m_param.gamma = 1.0/(m_prob.max_c-1);
// Build classifier:
// m_model = RegNetwork.svm_train(m_prob, m_param);
m_model = LIBSVMAdapter.svm_train(m_prob, m_param);
// Save classifier:
try {
// RegNetwork.svm_save_model(m_model_file_name, m_model);
LIBSVMAdapter.svm_save_model(m_model_file_name, m_model);
} catch(Exception ex) {
throw new MiningException("svm_save_model failed");
}
}
/**
* Classifies a given vector.
*
* If the vector is still connected to its data source,
* the information about its class index is used: a class
* attribute contained in the vector is ignored.
* Otherwise, the vector is assumed to contain no class attribute.
*
* @param vector the vector to be classified
* @return the predicted value
* @exception MiningException if vector could not be classified
* successfully
*/
public double apply(MiningVector vector) throws MiningException {
RegNode[] rNode = readDataVector(vector, classIndex);
// return RegNetwork.svm_predict(m_model, rNode);
return LIBSVMAdapter.svm_predict(m_model, rNode);
}
//<<17/03/2005, Frank J. Xu
//Do not add any changes for the current version except implementing the interface.
/**
* Classifies a given vector.
*
* If the vector is still connected to its data source,
* the information about its class index is used: a class
* attribute contained in the vector is ignored.
* Otherwise, the vector is assumed to contain no class attribute.
*
* @param vector the vector to be classified
* @param a_wekaInstances additional Weka instances object; not used by this implementation
* @return the predicted value
* @exception MiningException if vector could not be classified
* successfully
*/
public double apply(MiningVector vector, Object a_wekaInstances) throws MiningException {
RegNode[] rNode = readDataVector(vector, classIndex);
// return RegNetwork.svm_predict(m_model, rNode);
return LIBSVMAdapter.svm_predict(m_model, rNode);
}
//17/03/2005, Frank J. Xu>>
/**
* Returns a description of this classifier.
*
* @return a description of this classifier as a string.
*/
public String toString() {
return ("Regularization network classifier.");
}
/**
* Returns SVM model.
*
* @return SVM model
*/
public RegModel getModel()
{
return m_model;
}
/**
* Sets new SVM model.
*
* @param model new SVM model
*/
public void setModel(RegModel model)
{
this.m_model = model;
}
/**
* Returns SVM parameters.
*
* @return SVM parameters
*/
public RegParameters getParam()
{
return m_param;
}
/**
* Sets SVM parameters.
*
* @param param new SVM parameters
*/
public void setParam(RegParameters param)
{
this.m_param = param;
}
/**
* Returns SVM problem.
*
* @return SVM problem
*/
public RegProblem getProb()
{
return m_prob;
}
/**
* Sets new SVM problem.
*
* @param prob new SVM problem
*/
public void setProb(RegProblem prob)
{
this.m_prob = prob;
}
/**
* Test routine.
*/
public static void main(String[] args) {
try {
// Test file:
// File inFile = new File("RegNetworks\\heart_scale"); // sparse
// File inFile = new File("RegNetworks\\test.txt"); // dense
// Create data source for sparse test file:
// DataSourceFileSparse dsa = new DataSourceFileSparse(inFile, 0,
// AbsDataSource.CLASS_INDEX_FIRST_COL); // sparse
// DataSourceFile dsa = new DataSourceFile(inFile); // dense
MiningArffStream aid = new MiningArffStream("data/arff/regress.arff");
MiningAttribute target = aid.getMetaData().getMiningAttribute("three");
// Create instance of this class:
RegMethod reg = new RegMethod();
// Set and show parameters:
/* String[] param = new String[6];
param[0] = "-C";
param[1] = "10";
param[2] = "-Y";
param[3] = "3";
param[4] = "-P";
param[5] = "0.01";
reg.setParameters(param);*/
/* Enumeration em = reg.getParameterList();
while (em.hasMoreElements())
System.out.println( ((Parameter) em.nextElement()).toString() );
param = reg.getParameters();
for (int i = 0; i < param.length; i++)
System.out.print(param[i] + " ");
System.out.println();
*/
// Set parameters:
// dsa.setClassIndex(AbsDataSource.CLASS_INDEX_FIRST_COL); // sparse
// dsa.setClassIndex(AbsDataSource.CLASS_INDEX_LAST_COL); // dense
reg.m_param.reg_type = 3;
reg.m_param.p = 0.01;
reg.m_param.C = 10.;
// Build classifier:
reg.buildClassifier(aid,target);
// Evaluate classifier:
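// Re-read the training data and count a prediction as wrong when its sign
// differs from the stored target value; the printed rate is the percentage
// of sign-consistent predictions (a rough sanity check only).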
int i = 0;
int wrong = 0;
aid.reset();
int classIndex = aid.getMetaData().getAttributeIndex(target);
while (aid.next()) {
MiningVector vec = aid.read();
double val = reg.apply(vec);
if ( val*reg.m_prob.y[i] < 0 )
wrong = wrong + 1;
i = i + 1;
System.out.println("val=" + val);
}
System.out.println("rate = " + (100.0 - ((double) wrong / i)*100.0) );
}
catch (MiningException ex) {
ex.printStackTrace();
};
}
}