⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 dec.java

📁 wekaUT 是 University of Texas at Austin 开发的基于 weka 的半监督学习(semi-supervised learning)分类器
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
  /**
   * Sets the number of random instances to add at each iteration.
   * A negative value is interpreted as a fraction of the training set size
   * (see buildClassifier).
   *
   * @param new_random_size the number of random instances to add at each iteration.
   */
  public void setRandomSize(double new_random_size) {
    m_RandomSize = new_random_size;
  }

  /**
   * Gets the desired size of the committee.
   *
   * @return the desired committee size.
   */
  public int getDesiredSize() {
    return m_DesiredSize;
  }

  /**
   * Sets the desired size of the committee.
   *
   * @param new_desired_size the desired committee size.
   */
  public void setDesiredSize(int new_desired_size) {
    m_DesiredSize = new_desired_size;
  }

  /**
   * Sets the maximum number of ensemble-building iterations.
   *
   * @param numIterations the maximum number of iterations.
   */
  public void setNumIterations(int numIterations) {
    m_NumIterations = numIterations;
  }

  /**
   * Gets the maximum number of ensemble-building iterations.
   *
   * @return the maximum number of iterations.
   */
  public int getNumIterations() {
    return m_NumIterations;
  }

  /**
   * Set the seed for random number generation. A seed of -1 requests a
   * non-deterministic (time-seeded) random number generator.
   *
   * @param seed the seed
   */
  public void setSeed(int seed) {
    m_Seed = seed;
    if (m_Seed == -1) {
      random = new Random();
    } else {
      random = new Random(m_Seed);
    }
  }

  /**
   * Gets the seed for the random number generation.
   *
   * @return the seed for the random number generation.
   */
  public int getSeed() {
    return m_Seed;
  }

  /**
   * Get the value of the relabeling confidence threshold.
   *
   * @return value of threshold.
   */
  public double getThreshold() {
    return m_Threshold;
  }

  /**
   * Set the value of the relabeling confidence threshold.
   * A value of -1 means the threshold is chosen automatically from the
   * committee error (see selectThreshold).
   *
   * @param v value to assign to threshold.
   */
  public void setThreshold(double v) {
    this.m_Threshold = v;
  }

  /**
   * DEC method.
   *
   * @param data the training data to be used for generating the
   * bagged classifier.
* @exception Exception if the classifier could not be built successfully     */    public void buildClassifier(Instances data) throws Exception {		//DEUBG	if(m_UseWeights==1)	    System.out.println("\t>> Using weights...");	else	    System.out.println("\t>> Not using weights...");		//initialize ensemble wts to be equal 	m_EnsembleWts = new double [m_DesiredSize];	for(int j=0; j<m_DesiredSize; j++)	    m_EnsembleWts[j] = 1.0;		initMeasures();			if (m_Classifier == null) {	    throw new Exception("A base classifier has not been specified!");	}	if (data.checkForStringAttributes()) {	    throw new Exception("Can't handle string attributes!");	}		//number of random instances to add at each iteration	int random_size;	if(m_RandomSize<0){	    random_size = (int) (Math.abs(m_RandomSize)*data.numInstances());	    if(random_size==0) random_size=1;//atleast add one random example	} 	else	    random_size = (int) m_RandomSize;	System.out.println("random size = "+random_size);		//maximum number if random instances to generate (when using	//confidence thresholds not all randomly generated examples	//will be labeled).	int max_random_generated = 3*random_size;	int num_attributes = data.numAttributes();	committee = new Vector();//initialize new committee	double e_comm; //classification error of current committee	int i = 1;//current size of committee	int num_trials = 0;	Classifier classifier = m_Classifier;	Instances div_data = new Instances(data); //Local copy of data	Instances random_data = null;	computeStats(div_data);//Find mean and std devs for numeric data	//Create first committee member - base	m_Classifier.buildClassifier(div_data);	committee.add(classifier);		if(m_UseWeights==1)	    m_EnsembleWts[i-1] = computeEnsembleWt(classifier, data);//compute wt based on classifier accuracy 		e_comm = computeError(div_data);	System.out.println("Mem "+i+" added. 
E_comm = "+e_comm);	computeAccuracy(data);		if(e_comm >= 0.5) {//if the base classifier has an error > 0.5	    m_EnsembleWts[0] = 1.0;//reset the ensemble wt	    i=m_DesiredSize; //skip the following while loop	}		while(i<m_DesiredSize && num_trials<m_NumIterations){	    //System.out.println("\tTrial: "+num_trials);	    	    //Keeping generating random data until the desired number are actually labaled	    Instances total_random_data = new Instances(data, random_size);	    int labeled = 0, generated = 0;	    	    //Set confidence threshold for relabeling 	    double threshold = selectThreshold(e_comm);	    //System.out.println("\tThreshold: "+threshold);	    	    //Continue generating random examples until random_size of	    //them are labeled or the the number of examples generated	    //exceeds max_number_generated. The latter is a failsafe	    //to ensure this loop terminates	    while(labeled < random_size && generated < max_random_generated){		//System.out.println("\tGenerating ramdom instances...");		//Create random instances (diversity data)		//random_data = generateRandomData(random_size, num_attributes, data);				random_data = generateRandomData(random_size-labeled, num_attributes, data);		generated += (random_size-labeled);				//System.out.println("\tLabeling random instances...");		//Label the random data		random_data = labelData(random_data, threshold);				labeled += random_data.numInstances();				addInstances(total_random_data, random_data);	    }	    if(labeled!=random_size)		System.out.println(labeled+" random examples labeled out of the desired "+random_size);	    	    Assert.that(total_random_data.numInstances()==labeled,"Error in random example generation+labeling loop: "+total_random_data.numInstances()+" != "+labeled);	    	    //Remove all the diversity data from the previous step (if any)	    if(div_data.numInstances() > data.numInstances()) {		//System.out.println("\tRemoving previous random data...");		//removeInstances(div_data, 
random_size);		removeInstances(div_data, div_data.numInstances()-data.numInstances());	    }	    	    Assert.that(div_data.numInstances() == data.numInstances());	    	    //System.out.println("\tAdding new random instances...");	    //Add new random data	    addInstances(div_data, total_random_data);	    	    //System.out.println("\tBuild new classifier...");	    //initialize new classifier	    Classifier tmp[] = Classifier.makeCopies(m_Classifier,1);	    classifier = tmp[0]; 	    classifier.buildClassifier(div_data);	    	    committee.add(classifier);	    if(m_UseWeights==1)		m_EnsembleWts[i] = computeEnsembleWt(classifier, data);//compute wt based on classifier accuracy 	    	    //System.out.println("\tCompute current committee error...");	    double curr_error = computeError(data);	    if(m_EnsembleWts[i]>0 && curr_error <= e_comm){		//adding the new member did not increase the error and the new member has an error < 0.5		i++;		e_comm = curr_error;		System.out.println("Iteration: "+num_trials+"\tMem "+i+" added. E_comm = "+e_comm);	    }else{		committee.removeElementAt(committee.size()-1);//pop the last member	    }	    num_trials++;	}	System.out.println("Final ensemble size: "+committee.size());		//Set measures	computeEnsembleMeasures(data);	//DEBUG	Assert.that(m_TrainError == (100.0 * e_comm), "Bug in train error computation!"+m_TrainError+"\t"+(100.0 * e_comm));}            /**      * Find and store mean and std devs for numeric attributes.     
*     * @param data training instances     */    protected void computeStats(Instances data){	//Use to maintain the mean and std devs of numeric attributes	int num_attributes = data.numAttributes();	attribute_stats = new HashMap(num_attributes);		for(int j=0; j<num_attributes; j++){	    if(data.attribute(j).isNominal()){		if(m_DataCreationMethod == TRAINING_DIST){		    int []nom_counts = (data.attributeStats(j)).nominalCounts;		    double []counts = new double[nom_counts.length];		    //Laplace smoothing		    for(int i=0; i<counts.length; i++)			counts[i] = nom_counts[i] + 1;		    Utils.normalize(counts);		    double []stats = new double[counts.length - 1];		    stats[0] = counts[0];		    //Calculate cummalitive probabilities		    for(int i=1; i<stats.length; i++)			stats[i] = stats[i-1] + counts[i];		    		    attribute_stats.put(new Integer(j),stats);		}	    }else if(data.attribute(j).isNumeric()){		if(m_DataCreationMethod == UNIFORM){		    //get range of numeric attribute from the training data		    Stats s = (data.attributeStats(j)).numericStats;		    double []stats = new double[2];		    stats[0] = s.min; 		    stats[1] = s.max;		    attribute_stats.put(new Integer(j),stats);		}else if(m_DataCreationMethod == TRAINING_DIST){		    //get mean and standard deviation from the training data		    double []stats = new double[2];		    stats[0] = data.meanOrMode(j);		    stats[1] = Math.sqrt(data.variance(j));		    attribute_stats.put(new Integer(j),stats);		}	    }	}    }        protected Instances generateRandomData(int random_size, int num_attributes, Instances data){	Instances random_data = new Instances(data, random_size);	double []att; 	Instance random_instance;		for(int i=0; i<random_size; i++){	    att = new double[num_attributes];	    	    for(int j=0; j<num_attributes; j++){		if(data.attribute(j).isNominal()){		    if(m_DataCreationMethod == UNIFORM || m_DataCreationMethod == MIXED){			att[j] = (double) random.nextInt(data.numDistinctValues(j));		    }else 
if (m_DataCreationMethod == TRAINING_DIST) {
                        // Sample a nominal value index from the training distribution.
                        double[] stats = (double[]) attribute_stats.get(new Integer(j));
                        att[j] = selectNominalValue(stats);
                    }
                } else if (data.attribute(j).isNumeric()) {
                    if (m_DataCreationMethod == UNIFORM) {
                        // Draw uniformly from the attribute's [min, max] range.
                        double[] stats = (double[]) attribute_stats.get(new Integer(j));
                        double min = stats[0];
                        double max = stats[1];
                        att[j] = (random.nextDouble() * (max - min)) + min;
                    } else if (m_DataCreationMethod == TRAINING_DIST || m_DataCreationMethod == MIXED) {
                        // Draw from a Gaussian with the training mean/std-dev.
                        double[] stats = (double[]) attribute_stats.get(new Integer(j));
                        att[j] = (random.nextGaussian() * stats[1]) + stats[0];
                    }
                } else {
                    System.err.println("Current version of DEC cannot deal with STRING attributes.");
                }
            }

            random_instance = new Instance(1.0, att);
            random_data.add(random_instance);
        }

        Assert.that(random_data.numInstances() == random_size);
        return random_data;
    }

    /**
     * Given cumulative probabilities, select a nominal value index.
     * May return cumm.length when the draw exceeds every stored entry
     * (the final bucket, whose cumulative probability is implicitly 1.0,
     * is not stored — see computeStats).
     */
    protected double selectNominalValue(double[] cumm) {
        double rnd = random.nextDouble();
        int index = 0;
        while (index < cumm.length && rnd > cumm[index]) {
            index++;
        }
        return ((double) index);
    }

    /**
     * Set threshold for relabeling based on a user-specified threshold
     * or on the error of the current committee.
     *
     * @param error error of the current committee
     * @return the selected threshold
     */
    protected double selectThreshold(double error) {
        double threshold;

        if (m_Threshold == -1) {
            if (error >= 0.5)
                threshold = 1.0;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -