⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 semisupdecorate.java

📁 wekaUT 是由 University of Texas at Austin 开发的、基于 Weka 的半监督学习(semi-supervised learning)分类器
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
     * Ignore low-confidence examples.     */    protected Instances labelIgnoreLow(Instances instances, Instances used) throws Exception {	Instance curr;	double []probs;	int label; 	double highestProb;		for(int i=0; i<instances.numInstances(); i++){	    curr = instances.instance(i);	    //compute the class membership probs predicted by the current ensemble 	    probs = distributionForInstance(curr);	    label = (int) classifyInstance(curr);	    highestProb = probs[label];	    	    if(highestProb >= m_Threshold){		curr.setClassValue(label);		used.add(curr);	    }		}	return used;    }        /**      * Label low-confidence examples with inverse of ensemble's prediction.     * Ignore high-confidence examples.     */    protected Instances labelIgnoreHigh(Instances instances, Instances used) throws Exception {	Instance curr;	double []probs;	int label; 	double highestProb;		for(int i=0; i<instances.numInstances(); i++){	    curr = instances.instance(i);	    //compute the class membership probs predicted by the current ensemble 	    probs = distributionForInstance(curr);	    label = (int) classifyInstance(curr);	    highestProb = probs[label];	    	    if(highestProb < m_Threshold){		curr.setClassValue(inverseLabel(probs));		used.add(curr);	    }		}	return used;    }           /**      * Label low-confidence examples with inverse of ensemble's prediction.     * Use ensemble's prediction for high-confidence examples.     
*/    protected Instances labelFlipLow(Instances instances) throws Exception {	Instance curr;	double []probs;	int label; 	double highestProb;		int a=0, b=0;	for(int i=0; i<instances.numInstances(); i++){	    curr = instances.instance(i);	    //compute the class membership probs predicted by the current ensemble 	    probs = distributionForInstance(curr);	    label = (int) classifyInstance(curr);	    highestProb = probs[label];	    	    if(highestProb >= m_Threshold){		curr.setClassValue(label);		a++;	    }else{		curr.setClassValue(inverseLabel(probs));		b++;	    }	}	System.out.println("As is: "+a+"\tFlipped: "+b);	return m_Unlabeled;    }            //Helper method to print arrays    protected void printArray(double []array){	for(int i=0; i<array.length; i++)	    System.out.print(array[i]+" ");	System.out.println();    }        /** Returns class predictions of each ensemble member */    public double []getEnsemblePredictions(Instance instance) throws Exception{	double preds[] = new double [m_Committee.size()];	for(int i=0; i<m_Committee.size(); i++)	    preds[i] = ((Classifier) m_Committee.get(i)).classifyInstance(instance);		return preds;    }        /**      * Returns vote weights of ensemble members.     *     * @return vote weights of ensemble members     */    public double []getEnsembleWts(){	return m_EnsembleWts;    }        /** Returns size of ensemble */    public double getEnsembleSize(){	return m_Committee.size();    }            /**      * Compute and store statistics required for generating artificial data.     
*     * @param data training instances     * @exception Exception if statistics could not be calculated successfully     */    protected void computeStats(Instances data) throws Exception{	int numAttributes = data.numAttributes();	m_AttributeStats = new Vector(numAttributes);//use to map attributes to their stats		for(int j=0; j<numAttributes; j++){	    if(data.attribute(j).isNominal()){		//Compute the probability of occurence of each distinct value 		int []nomCounts = (data.attributeStats(j)).nominalCounts;		double []counts = new double[nomCounts.length];		if(counts.length < 2) throw new Exception("Nominal attribute has less than two distinct values!"); 		//Perform Laplace smoothing		for(int i=0; i<counts.length; i++)		    counts[i] = nomCounts[i] + 1;		Utils.normalize(counts);		double []stats = new double[counts.length - 1];		stats[0] = counts[0];		//Calculate cumulative probabilities		for(int i=1; i<stats.length; i++)		    stats[i] = stats[i-1] + counts[i];		m_AttributeStats.add(j,stats);	    }else if(data.attribute(j).isNumeric()){		//Get mean and standard deviation from the training data		double []stats = new double[2];		stats[0] = data.meanOrMode(j);		stats[1] = Math.sqrt(data.variance(j));		m_AttributeStats.add(j,stats);	    }else System.err.println("SemiSupDecorate can only handle numeric and nominal values.");	}    }    /**     * Generate artificial training examples.     
* @param artSize size of examples set to create     * @param data training data     * @return the set of unlabeled artificial examples     */    protected Instances generateArtificialData(int artSize, Instances data){	int numAttributes = data.numAttributes();	Instances artData = new Instances(data, artSize);	double []att; 	Instance artInstance;		for(int i=0; i<artSize; i++){	    att = new double[numAttributes];	    for(int j=0; j<numAttributes; j++){		if(data.attribute(j).isNominal()){		    //Select nominal value based on the frequency of occurence in the training data  		    double []stats = (double [])m_AttributeStats.get(j);		    att[j] =  (double) selectIndexProbabilistically(stats);		}		else if(data.attribute(j).isNumeric()){		    //Generate numeric value from the Guassian distribution 		    //defined by the mean and std dev of the attribute		    double []stats = (double [])m_AttributeStats.get(j);		    att[j] = (m_Random.nextGaussian()*stats[1])+stats[0];		}else System.err.println("SemiSupDecorate can only handle numeric and nominal values.");	    }	    artInstance = new Instance(1.0, att);	    artData.add(artInstance);	}	return artData;    }            /**      * Labels the artificially generated data.     *     * @param artData the artificially generated instances     * @exception Exception if instances cannot be labeled successfully      */    protected void labelData(Instances artData) throws Exception {	Instance curr;	double []probs;		for(int i=0; i<artData.numInstances(); i++){	    curr = artData.instance(i);	    //compute the class membership probs predicted by the current ensemble 	    probs = distributionForInstance(curr);	    //select class label inversely proportional to the ensemble predictions	    curr.setClassValue(inverseLabel(probs));	}	    }        /**      * Select class label such that the probability of selection is     * inversely proportional to the ensemble's predictions.     
*     * @param probs class membership probabilities of instance     * @return index of class label selected     * @exception Exception if instances cannot be labeled successfully      */    protected int inverseLabel(double []probs) throws Exception{	double []invProbs = new double[probs.length];	//Produce probability distribution inversely proportional to the given	for(int i=0; i<probs.length; i++){	    if(probs[i]==0){		invProbs[i] = Double.MAX_VALUE/probs.length; 		//Account for probability values of 0 - to avoid divide-by-zero errors		//Divide by probs.length to make sure normalizing works properly	    }else{		invProbs[i] = 1.0 / probs[i];	    }	}	Utils.normalize(invProbs);	double []cdf = new double[invProbs.length];	//Compute cumulative probabilities 	cdf[0] = invProbs[0];	for(int i=1; i<invProbs.length; i++){	    cdf[i] = invProbs[i]+cdf[i-1];	}		if(Double.isNaN(cdf[invProbs.length-1]))	    System.err.println("Cumulative class membership probability is NaN!"); 	return selectIndexProbabilistically(cdf);    }        /**      * Given cumulative probabilities select a nominal attribute value index      *     * @param cdf array of cumulative probabilities     * @return index of attribute selected based on the probability distribution      */    protected int selectIndexProbabilistically(double []cdf){	double rnd = m_Random.nextDouble();	int index = 0;	while(index < cdf.length && rnd > cdf[index]){	    index++;	}	return index;    }         /**     * Removes a specified number of instances from the given set of instances.     *     * @param data given instances     * @param numRemove number of instances to delete from the given instances     */    protected void removeInstances(Instances data, int numRemove){	int num = data.numInstances();	for(int i=num - 1; i>num - 1 - numRemove;i--){	    data.delete(i);	}    }        /**     * Add new instances to the given set of instances.     
*     * @param data given instances     * @param newData set of instances to add to given instances     */    protected void addInstances(Instances data, Instances newData){	for(int i=0; i<newData.numInstances(); i++)	    data.add(newData.instance(i));    }        /**      * Computes the error in classification on the given data.     *     * @param data the instances to be classified     * @return classification error     * @exception Exception if error can not be computed successfully     */    protected double computeError(Instances data) throws Exception {	double error = 0.0;	int numInstances = data.numInstances();	Instance curr;		for(int i=0; i<numInstances; i++){	    curr = data.instance(i);	    //Check if the instance has been misclassified	    if(curr.classValue() != ((int) classifyInstance(curr))) error++;	}	return (error/numInstances);    }      /**   * Calculates the class membership probabilities for the given test instance.   *   * @param instance the instance to be classified   * @return predicted class probability distribution   * @exception Exception if distribution can't be computed successfully   */  public double[] distributionForInstance(Instance instance) throws Exception {      if (instance.classAttribute().isNumeric()) {	  throw new UnsupportedClassTypeException("SemiSupDecorate can't handle a numeric class!");      }      double [] sums = new double [instance.numClasses()], newProbs;       Classifier curr;            for (int i = 0; i < m_Committee.size(); i++) {	  curr = (Classifier) m_Committee.get(i);	  if (curr instanceof DistributionClassifier) {	      newProbs = ((DistributionClassifier)curr).distributionForInstance(instance);	      for (int j = 0; j < newProbs.length; j++)		  sums[j] += newProbs[j];	  } else {	      sums[(int)curr.classifyInstance(instance)]++;	  }      }      if (Utils.eq(Utils.sum(sums), 0)) {	  return sums;      } else {	  Utils.normalize(sums);	  return sums;      }  }        /**     * Returns description of the 
SemiSupDecorate classifier.
     *
     * @return description of the SemiSupDecorate classifier as a string
     */
    public String toString() {
        if (m_Committee == null) {
            return "SemiSupDecorate: No model built yet.";
        }

        StringBuffer description = new StringBuffer();
        description.append("SemiSupDecorate base classifiers: \n\n");
        // One paragraph per committee member, in committee order.
        for (int member = 0; member < m_Committee.size(); member++) {
            Classifier classifier = (Classifier) m_Committee.get(member);
            description.append(classifier.toString()).append("\n\n");
        }
        description.append("Number of classifier in the ensemble: ")
                   .append(m_Committee.size())
                   .append("\n");
        return description.toString();
    }

    /**
     * Main method for testing this class. Evaluates a fresh SemiSupDecorate
     * with the given options and prints the result; on failure only the
     * exception message is written to standard error.
     *
     * @param argv the options
     */
    public static void main(String[] argv) {
        try {
            System.out.println(Evaluation.evaluateModel(new SemiSupDecorate(), argv));
        } catch (Exception e) {
            System.err.println(e.getMessage());
        }
    }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -