⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ensembleclassifier.java

📁 wekaUT是 university texas austin 开发的基于weka的半指导学习(semi supervised learning)的分类器
💻 JAVA
字号:
/* *    This program is free software; you can redistribute it and/or modify *    it under the terms of the GNU General Public License as published by *    the Free Software Foundation; either version 2 of the License, or *    (at your option) any later version. * *    This program is distributed in the hope that it will be useful, *    but WITHOUT ANY WARRANTY; without even the implied warranty of *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *    GNU General Public License for more details. * *    You should have received a copy of the GNU General Public License *    along with this program; if not, write to the Free Software *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* *    EnsembleClassifier.java *    Copyright (C) 2003 Prem Melville * */package weka.classifiers;import weka.core.*;import com.jmage.*;import java.util.*;/**  * Abstract class for Ensemble Classifiers * * @author Prem Melville * @version $Revision: 1.3 $ */public abstract class EnsembleClassifier extends DistributionClassifier implements AdditionalMeasureProducer{        /** the error on the training data */    protected double m_TrainError=0;    /** the average error of the ensemble on the training data */    protected double m_TrainEnsembleError=0;    /** the ensemble diversity computed in the training data */    protected double m_TrainEnsembleDiversity=0;    /** Sum of ensemble weights */    protected double m_SumEnsembleWts=0;    /** Vote weights of ensemble members */    protected double []m_EnsembleWts;    /** Returns class predictions of each ensemble member */    public abstract double []getEnsemblePredictions(Instance instance) 	throws Exception;        /**      * Returns vote weights of ensemble members.     *     * @return vote weights of ensemble members     */    public abstract double []getEnsembleWts();        /** Returns size of ensemble */    public abstract double getEnsembleSize();  /**   * Returns an enumeration of the additional measure names   * @return an enumeration of the measure names   */  public Enumeration enumerateMeasures() {    Vector newVector = new Vector(3);    newVector.addElement("measureTrainError");    newVector.addElement("measureTrainEnsembleError");    newVector.addElement("measureTrainEnsembleDiversity");    return newVector.elements();  }  /**   * Returns the value of the named measure   * @param measureName the name of the measure to query for its value   * @return the value of the named measure   * @exception IllegalArgumentException if the named measure is not supported   */  public double getMeasure(String additionalMeasureName) {    if (additionalMeasureName.compareTo("measureTrainError") == 0) {	return measureTrainError();    } else if (additionalMeasureName.compareTo("measureTrainEnsembleError") == 0) {	return measureTrainEnsembleError();    } else if (additionalMeasureName.compareTo("measureTrainEnsembleDiversity") == 0) {	return measureTrainEnsembleDiversity();    } else {      throw new IllegalArgumentException(additionalMeasureName 			  + " not supported (DEC)");    }  }        /**     * @return the error on the training data     **/    public double measureTrainError(){	return m_TrainError;    }        /**      * @return the average error of the ensemble on the training data     */    public double measureTrainEnsembleError(){	return m_TrainEnsembleError;    }        /**     * @return the ensemble diversity computed in the training data     */    public double measureTrainEnsembleDiversity(){	return m_TrainEnsembleDiversity;    }        /** Initialize measures */    protected void initMeasures(){	m_SumEnsembleWts=0;	m_TrainError=0;	m_TrainEnsembleError=0;	m_TrainEnsembleDiversity=0;    }        /**     * Compute ensemble measures.     * @param data training instances     */    protected void computeEnsembleMeasures(Instances data) throws Exception{	for(int j=0; j<getEnsembleSize(); j++)	    m_SumEnsembleWts += m_EnsembleWts[j];		//DEBUG	//System.out.println("Ensemble size = "+getEnsembleSize());	if(m_SumEnsembleWts == 0.0){	    System.out.println("Ensemble wts sum to 0!");	    for(int j=0; j<m_EnsembleWts.length; j++)		System.out.print("\t"+m_EnsembleWts[j]);	    System.out.println();	}		double totalInstanceWt=0;	Instance curr;	for (int i = 0; i < data.numInstances(); i++) {	    curr = data.instance(i); 	    totalInstanceWt += curr.weight();	    if(curr.weight() != 1.0) System.out.println(">>> Instance Weight = "+curr.weight());	    updateEnsembleStats(classifyInstance(curr), curr, getEnsemblePredictions(curr));	}	//DEBUG	Assert.that(totalInstanceWt==data.numInstances(),"Instance wts don't total to num of instances!");		m_TrainError = 100.0 * (m_TrainError/totalInstanceWt);	m_TrainEnsembleError = 100.0 * m_TrainEnsembleError/totalInstanceWt;	m_TrainEnsembleDiversity = 100.0 * m_TrainEnsembleDiversity/totalInstanceWt;    }        /**     * Update statistics for ensemble classifiers.     *     * @param pred ensemble prediction     * @param instance training instance     * @param ensemblePreds predictions of ensemble members     */    protected void updateEnsembleStats(double pred, Instance instance, double []ensemblePreds){	//System.out.print("Updating Ensemble Stats...");	double sumEnsembleError = 0, sumEnsembleDiversity = 0;	double actualClass = instance.classValue();		for(int i=0; i<getEnsembleSize(); i++){	    if(actualClass != ensemblePreds[i])		sumEnsembleError += m_EnsembleWts[i];	    	    //if member's prediction differs from the ensemble prediction, diversity increases 	    if(pred != ensemblePreds[i])		sumEnsembleDiversity += m_EnsembleWts[i];	}		if(pred != actualClass) m_TrainError += instance.weight();	m_TrainEnsembleError += ((sumEnsembleError/m_SumEnsembleWts)*instance.weight());	m_TrainEnsembleDiversity += ((sumEnsembleDiversity/m_SumEnsembleWts)*instance.weight());    }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -