regmethod.java

Java source code for the complete workflow of the data mining software ALPHAMINER.
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/**
 * Title: XELOPES Data Mining Library
 * Description: The XELOPES library is an open platform-independent and data-source-independent library for Embedded Data Mining.
 * Copyright: Copyright (c) 2002 Prudential Systems Software GmbH
 * Company: ZSoft (www.zsoft.ru), Prudsys (www.prudsys.com)
 * @author Michael Thess
 * @author Victor Borichev
 * @author Valentine Stepanenko (valentine.stepanenko@zsoft.ru)
 * @version 1.0
 */

package com.prudsys.pdm.Models.Regression.SVM.Algorithms.RegularizationNetworks;

import java.util.Vector;

import com.prudsys.pdm.Core.MiningAttribute;
import com.prudsys.pdm.Core.MiningDataSpecification;
import com.prudsys.pdm.Core.MiningException;
import com.prudsys.pdm.Input.MiningInputStream;
import com.prudsys.pdm.Input.MiningVector;
import com.prudsys.pdm.Input.Records.Arff.MiningArffStream;
import com.prudsys.pdm.Models.Supervised.Classifier;

/**
 * Training of regularization networks.
 *
 */
public class RegMethod implements Classifier {

  /**
   * Regularization problem.
   */
  protected RegProblem m_prob;

  /**
   * Regularization network parameters.
   */
  protected RegParameters m_param;

  /**
   * Classifier data.
   */
  protected RegModel m_model;

  /**
   * Name of the file used to persist the classifier model.
   */
  protected String m_model_file_name = "data/test.model";

  private int classIndex = -1;

  /**
   * Constructor initializes parameter values to their defaults.
   */
  public RegMethod() {

    // Init parameters:
    m_param = new RegParameters();
  }
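
  // A minimal sketch of overriding the default parameters before training;
  // the field names (reg_type, p, C, gamma) are the ones referenced elsewhere
  // in this class, while the concrete values are only illustrative:
  //
  //   RegMethod reg = new RegMethod();
  //   reg.getParam().reg_type = 3;    // as in the test routine in main()
  //   reg.getParam().p        = 0.01;
  //   reg.getParam().C        = 10.0;
  //   reg.getParam().gamma    = 0.0;  // 0 = let buildClassifier derive a default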

  /**
   * Reads data from the data source and converts it into a RegProblem object.
   * The data is written into the field m_prob. A worked example of the
   * conversion follows the method body.
   *
   * @param data data source
   * @param max_lines maximum number of lines to be read, -1 for unbounded
   * @param target target (class) attribute
   * @return number of lines read
   * @exception MiningException if a reading error occurs
   */
  private int readData(MiningInputStream data, int max_lines, MiningAttribute target) throws MiningException {

    Vector vy     = new Vector();
    Vector vx     = new Vector();
    int max_index = 0;
    int n_missing = 0;
    int n_line    = 0;
    data.reset();
    MiningDataSpecification metaData = data.getMetaData();
    int pos = metaData.getAttributeIndex(target);
    classIndex = pos;
    while (data.next()) {

        MiningVector vec = data.read();

        // Add class label:

        double val = vec.getValue(pos);
//        if (Math.abs(val) < 0.0001)
//          val = -1;
        if (Double.isNaN(val)) {
          val = 0;
          n_missing = n_missing + 1;
        };
        vy.addElement( new Double(val) );

        // Add coordinate values:
        int numbValues = 0;
        double[] values = vec.getValues();
        for(int i=0; i<values.length;i++)
          if(values[i]!=0.0 || i == pos) numbValues++;
//        int shift = 1;
//        if (data.getClassIndex() == AbsDataSource.CLASS_INDEX_LAST_COL)
//          shift = 0;
        int m       = numbValues - 1;
        RegNode[] x = new RegNode[m];
        m = 0;
        for(int j = 0; j < values.length; j++) {
          if(values[j] == 0.0 || j == pos) continue;
          x[m]       = new RegNode();
          x[m].index = j;//vec.getIndex(j);
//          val        = vec.getValueSparse(j+shift);
          val = values[j];
          if (Double.isNaN(val)) {
            val       = 0.0;
            n_missing = n_missing + 1;
          };
          x[m++].value = val;
        };

        if (m > 0)
          max_index = Math.max(max_index, x[m-1].index);
        vx.addElement(x);

        // Count lines for break:
        n_line = n_line + 1;
        if (n_line == max_lines)
          break;
    };

    // Fill regularization problem data:
    m_prob   = new RegProblem();
    m_prob.l = vy.size();
    m_prob.x = new RegNode[m_prob.l][];
    for(int i = 0; i < m_prob.l; i++)
      m_prob.x[i] = (RegNode[]) vx.elementAt(i);
    m_prob.y = new double[m_prob.l];
    for(int i = 0; i < m_prob.l; i++)
      m_prob.y[i] = ((Double)vy.elementAt(i)).doubleValue();

    m_prob.max_c     = max_index+1;
    m_prob.n_missing = n_missing;

    return m_prob.l;
  }
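
  // Worked example of the conversion above (hypothetical row): for values
  // [0.0, 2.5, 7.0, 1.0] with the target attribute at index 2, readData
  // stores y = 7.0 and the sparse coordinates
  //   x = { {index=1, value=2.5}, {index=3, value=1.0} },
  // since zero entries and the target column are dropped; NaN entries are
  // replaced by 0.0 and counted in m_prob.n_missing.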

  /**
   * Reads a mining vector and converts it into an array of RegNode objects.
   * If the vector is not sparse, a sparse representation is created.
   * If a class attribute is included, it is ignored. A worked example
   * follows the method body.
   *
   * @param vec vector to be read
   * @param classIndex index of the class attribute, -1 if there is none
   * @return RegNode representation of vec
   * @exception MiningException if a reading error occurs
   */
  private RegNode[] readDataVector(MiningVector vec, int classIndex) throws MiningException {

    // Init:
    int max_index = 0;
    int n_missing = 0;

    // Get vector in sparse format:
/*    SparseVector vec = null;
    if (bvec instanceof SparseVector)
       vec = (SparseVector) bvec;
    else
       vec = new SparseVector(bvec);*/

    // Add coordinate values:
/*    int classIndex = AbsDataSource.NO_CLASS_INDEX;
    try {
      classIndex = bvec.getClassIndex();
    }
    catch (Exception ex) {
      classIndex = AbsDataSource.NO_CLASS_INDEX;
    };*/
    int numbValues = 0;
    double[] values = vec.getValues();
    for(int i=0; i<values.length;i++)
      if(values[i]!=0.0 || (classIndex != -1 && i == classIndex)) numbValues++;

    int m = numbValues;
    if(classIndex != -1) m--;
/*    int shift = -1;
    if (classIndex == AbsDataSource.CLASS_INDEX_FIRST_COL) {
       m     = m - 1;
       shift = 1;
    }
    else if (classIndex == AbsDataSource.CLASS_INDEX_LAST_COL) {
       m     = m - 1;
       shift = 0;
    }
    else if (classIndex == AbsDataSource.NO_CLASS_INDEX)
       shift = 0;
    if (shift == -1)
      throw new Exception("Class index > 0 forbidden!");
*/
    RegNode[] x = new RegNode[m];
    m = 0;
    for(int j = 0; j < values.length; j++) {
      if(values[j] == 0.0 || (classIndex != -1 && j == classIndex)) continue;
      x[m] = new RegNode();
      x[m].index = j;//vec.getIndex(j);
      double val = values[j];
      if (Double.isNaN(val)) {
        val = 0.0;
        n_missing = n_missing + 1;
      };
      x[m++].value = val;
    };

    // Determine maximum index:
    if (m > 0)
      max_index = Math.max(max_index, x[m-1].index);

    return x;
  }
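
  // Worked example (hypothetical vector): for values [1.0, 0.0, 5.0] and
  // classIndex = 2, the method returns a single node {index=0, value=1.0};
  // the zero entry is dropped and the class column is skipped. With
  // classIndex = -1, all non-zero entries are kept.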

  /**
   * Generates the classifier.
   *
   * @param data set of data serving as training data
   * @exception MiningException if the classifier has not been generated successfully
   */
  public void buildClassifier(MiningInputStream data, MiningAttribute target) throws MiningException {

    // read data:
    readData(data, -1, target);

    // Modify parameters:
    if (m_param.gamma == 0)
      m_param.gamma = 1.0/(m_prob.max_c-1);

    // Building classfier:
//    m_model = RegNetwork.svm_train(m_prob, m_param);
    m_model = LIBSVMAdapter.svm_train(m_prob, m_param);

    // Save classifier:
    try {
//      RegNetwork.svm_save_model(m_model_file_name, m_model);
      LIBSVMAdapter.svm_save_model(m_model_file_name, m_model);
    } catch(Exception ex) {
      throw new MiningException("svm_save_model failed: " + ex.getMessage());
    }
  }
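
  // Note on the gamma default above: max_c is the largest coordinate index
  // plus one (set in readData), so for a hypothetical data set whose largest
  // used index is 12, max_c = 13 and gamma defaults to 1/12. The trained
  // model is also persisted to m_model_file_name ("data/test.model").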

  /**
   * Classifies a given vector.
   *
   * If the vector is still connected to its data source, the class index
   * information of that source is used: if a class attribute is contained
   * in the vector, it is ignored. Otherwise the vector is assumed not to
   * contain any class attribute.
   *
   * @param vector the vector to be classified
   * @return the predicted value
   * @exception MiningException if the vector could not be classified
   * successfully
   */
  public double apply(MiningVector vector) throws MiningException {

    RegNode[] rNode = readDataVector(vector, classIndex);

//    return RegNetwork.svm_predict(m_model, rNode);
    return LIBSVMAdapter.svm_predict(m_model, rNode);
  }

  //<<17/03/2005, Frank J. Xu
  //Do not add any changes for the current version except implementing the interface.
  /**
   * Classifies a given vector.
   *
   * If the vector is still connected to its data source, the class index
   * information of that source is used: if a class attribute is contained
   * in the vector, it is ignored. Otherwise the vector is assumed not to
   * contain any class attribute.
   *
   * @param vector the vector to be classified
   * @return the predicted value
   * @exception MiningException if the vector could not be classified
   * successfully
   */
  public double apply(MiningVector vector, Object a_wekaInstances) throws MiningException {

    RegNode[] rNode = readDataVector(vector, classIndex);

//    return RegNetwork.svm_predict(m_model, rNode);
    return LIBSVMAdapter.svm_predict(m_model, rNode);
  }
  //17/03/2005, Frank J. Xu>>
  /**
   * Returns a description of this classifier.
   *
   * @return a description of this classifier as a string.
   */
  public String toString() {

    return ("Regularization network classifier.");
  }

  /**
   * Returns SVM model.
   *
   * @return SVM model
   */
  public RegModel getModel()
  {
    return m_model;
  }

  /**
   * Sets new SVM model.
   *
   * @param model new SVM model
   */
  public void setModel(RegModel model)
  {
    this.m_model = model;
  }

  /**
   * Returns SVM parameters.
   *
   * @return SVM parameters
   */
  public RegParameters getParam()
  {
    return m_param;
  }

  /**
   * Sets SVM parameters.
   *
   * @param param new SVM parameters
   */
  public void setParam(RegParameters param)
  {
    this.m_param = param;
  }

  /**
   * Returns SVM problem.
   *
   * @return SVM problem
   */
  public RegProblem getProb()
  {
    return m_prob;
  }

  /**
   * Sets new SVM problem.
   *
   * @param prob new SVM problem
   */
  public void setProb(RegProblem prob)
  {
    this.m_prob = prob;
  }


  /**
   * Test routine.
   */
  public static void main(String[] args) {

    try {
      // Test file:
//    File inFile = new File("RegNetworks\\heart_scale");  // sparse
//      File inFile = new File("RegNetworks\\test.txt");     // dense

      // Create data source for sparse test file:
//    DataSourceFileSparse dsa = new DataSourceFileSparse(inFile, 0,
//                               AbsDataSource.CLASS_INDEX_FIRST_COL); // sparse
//      DataSourceFile dsa = new DataSourceFile(inFile);                 // dense

      MiningArffStream aid = new MiningArffStream("data/arff/regress.arff");
      MiningAttribute target = aid.getMetaData().getMiningAttribute("three");
      // Create instance of this class:
      RegMethod reg = new RegMethod();

      // Set and show parameters:
 /*     String[] param      = new String[6];
      param[0]            = "-C";
      param[1]            = "10";
      param[2]            = "-Y";
      param[3]            = "3";
      param[4]            = "-P";
      param[5]            = "0.01";
      reg.setParameters(param);*/
/*      Enumeration em = reg.getParameterList();
      while (em.hasMoreElements())
        System.out.println( ((Parameter) em.nextElement()).toString() );
      param = reg.getParameters();
      for (int i = 0; i < param.length; i++)
        System.out.print(param[i] + " ");
      System.out.println();
*/
      // Set parameters:
//    dsa.setClassIndex(AbsDataSource.CLASS_INDEX_FIRST_COL);  // sparse
//      dsa.setClassIndex(AbsDataSource.CLASS_INDEX_LAST_COL);   // dense
      reg.m_param.reg_type = 3;
      reg.m_param.p = 0.01;
      reg.m_param.C = 10.;
      // Build classifier:
      reg.buildClassifier(aid,target);

      // Evaluate classifier:
      int i     = 0;
      int wrong = 0;
      aid.reset();
      int classIndex = aid.getMetaData().getAttributeIndex(target);
      while (aid.next()) {
        MiningVector vec = aid.read();
        double val = reg.apply(vec);
        if ( val*reg.m_prob.y[i] < 0 )
          wrong = wrong + 1;
        i = i + 1;
        System.out.println("val=" + val);
      }
      System.out.println("rate = " + (100.0 - ((double) wrong / i)*100.0) );
    }
    catch (MiningException ex) {
      ex.printStackTrace();
    };
  }
}
