// dec.java
else threshold = (error/(1.0 - error))*0.5 + 0.5; }else threshold=m_Threshold; return threshold; } /** * Labels the randomly generated data. * * @param random_data the randomly generated instances * @param threhsold confidence threshold for relabeling data * @return labeled data * @exception Exception if instances cannot be labeled successfully */ protected Instances labelData(Instances random_data, double threshold) throws Exception { Instances labeled = new Instances(random_data,1); Instance curr; double []probs; int ctr = 0; for(int i=0; i<random_data.numInstances(); i++){ curr = random_data.instance(i); probs = distributionForInstance(curr); if(probs[Utils.maxIndex(probs)] <= threshold || committee.size()==1){ ctr++; if(labeling_method == LOW_PROB){ curr.setClassValue(lowProbLabel(probs)); }else if(labeling_method == HIGH_PROB){ curr.setClassValue(highProbLabel(probs)); }else if(labeling_method == LEAST_LIKELY){ curr.setClassValue(Utils.minIndex(probs));//Assign the least likely label }else if(labeling_method == MOST_LIKELY){ curr.setClassValue(Utils.maxIndex(probs));//Assign the most likely label }else{ System.err.println("Unknown labeling method!"); } labeled.add(curr); } } return labeled; } /** * Probabilisticly select class label - (high probability). 
* * @param probs posterior probability of each class * @return highly likely class label probabilistically selected */ protected int highProbLabel(double []probs){ double []cumm = new double[probs.length]; //System.out.println("enter hi prob"); //System.out.println("prob length = "+probs.length); //Compute cumulative probabilities cumm[0] = probs[0]; for(int i=1; i<probs.length; i++){ cumm[i] = probs[i]+cumm[i-1]; } if(Double.isNaN(cumm[probs.length-1])) System.err.println("Calculated cummaltive probability is NaN"); //System.out.println("cumm = "+cumm[probs.length-1]); //Assert.that(Math.abs(cumm[probs.length-1] - 1)<0.00001,"Cummalative probability sums to "+cumm[probs.length-1]+" instead of 1."); //last value should be very close to one float rnd = random.nextFloat(); int index = 0; while(rnd > cumm[index]){ index++; } //System.out.println("exit hi prob"); return index; } /** * Probabilisticly select class label - (low probability). * * @param probs posterior probability of each class * @return low probability class label probabilistically selected * @exception Exception if instances cannot be labeled successfully */ protected int lowProbLabel(double []probs) throws Exception{ double []inv_probs = new double[probs.length]; //System.out.println("enter low prob"); //System.out.println("prob length = "+probs.length); for(int i=0; i<probs.length; i++){ if(probs[i]==0){ inv_probs[i] = Double.MAX_VALUE/probs.length; //Hack to fix probability values of 0 //Divide by probs.length to make sure normalizing works properly }else{ inv_probs[i] = 1.0 / probs[i]; } } Utils.normalize(inv_probs); //System.out.println("call hi prob"); return highProbLabel(inv_probs); } /** * * @param div_data given instances * @param random_size number of instances to delete from the end of given instances */ protected void removeInstances(Instances div_data, int random_size){ int num = div_data.numInstances(); for(int i=num - 1; i>num - 1 - random_size;i--){ div_data.delete(i); } 
Assert.that(div_data.numInstances() == num - random_size); } /** * * @param div_data given instances * @param random_data set of instances to add to given instances */ protected void addInstances(Instances div_data, Instances random_data){ for(int i=0; i<random_data.numInstances(); i++){ div_data.add(random_data.instance(i)); } } /** * Computes the error in classification on the given data. * * @param data the instances to be classified * @return classification error * @exception Exception if error can not be computed successfully */ protected double computeError(Instances data) throws Exception { double error = 0.0; int num_instances = data.numInstances(); Instance curr; for(int i=0; i<num_instances; i++){ curr = data.instance(i); if(curr.classValue() != ((int) classifyInstance(curr))){//misclassified error++; } } return (error/num_instances); } /** * Compute ensemble weight. * * @param classifier current classifier * @param data instances to compute accuracy on * @return computed vote weight for given classifier * @exception Exception if weight cannot be computed successfully */ protected double computeEnsembleWt(Classifier classifier, Instances data) throws Exception{ double wt = 0.0; //Compute error of classifier on data double error = 0.0; int num_instances = data.numInstances(); Instance curr; double invBeta; for(int i=0; i<num_instances; i++){ curr = data.instance(i); if(curr.classValue() != ((int) classifier.classifyInstance(curr))){//misclassified error++; } } error = error/num_instances; if(error == 0.0)//prevent divide by zero error invBeta = Double.MAX_VALUE; else invBeta = ((1-error)/error); wt = Math.log(invBeta); return wt; } /** * Computes classification accuracy on the given data. 
* * @param data the instances to be classified * @return classification accuracy * @exception Exception if error can not be computed successfully */ protected double computeAccuracy(Instances data) throws Exception { double acc = 0.0; int num_instances = data.numInstances(); Instance curr; for(int i=0; i<num_instances; i++){ curr = data.instance(i); if(curr.classValue() == ((int) classifyInstance(curr))){//correctly classified acc++; } } //System.out.println("# correctly classified: "+acc); //System.out.println("total #: "+num_instances); return (acc/num_instances); } /** * Calculates the class membership probabilities for the given test instance. * * @param instance the instance to be classified * @return predicted class probability distribution * @exception Exception if distribution can't be computed successfully */ public double[] distributionForInstance(Instance instance) throws Exception { if(m_UseWeights==1){ return distributionForInstanceUsingWeights(instance); }else{ double [] sums = new double [instance.numClasses()], newProbs; Classifier curr; for (int i = 0; i < committee.size(); i++) { curr = (Classifier) committee.get(i); if (instance.classAttribute().isNumeric() == true) { sums[0] += curr.classifyInstance(instance); } else if (curr instanceof DistributionClassifier) { newProbs = ((DistributionClassifier)curr).distributionForInstance(instance); for (int j = 0; j < newProbs.length; j++) sums[j] += newProbs[j]; } else { sums[(int)curr.classifyInstance(instance)]++; } } if (instance.classAttribute().isNumeric() == true) { sums[0] /= (double)(committee.size()); return sums; } else if (Utils.eq(Utils.sum(sums), 0)) { return sums; } else { Utils.normalize(sums); return sums; } } } /** * Calculates the class membership probabilities for the given test instance. * Incorporates vote weights. 
* @param instance the instance to be classified * @return predicted class probability distribution * @exception Exception if distribution can't be computed successfully */ public double[] distributionForInstanceUsingWeights(Instance instance) throws Exception { int commSize = committee.size(); if (commSize == 0) { throw new Exception("No model built"); } double [] sums = new double [instance.numClasses()]; Classifier curr; if (commSize == 1) { curr = (Classifier) committee.get(0); if (curr instanceof DistributionClassifier) { return ((DistributionClassifier)curr).distributionForInstance(instance); } else { sums[(int)curr.classifyInstance(instance)] ++; } } else {//commSize > 1 for (int i = 0; i < commSize; i++) { curr = (Classifier) committee.get(i); sums[(int)curr.classifyInstance(instance)] += m_EnsembleWts[i]; } } if (Utils.eq(Utils.sum(sums), 0)) { return sums; } else { Utils.normalize(sums); return sums; } } /** Returns class predictions of each ensemble member */ public double []getEnsemblePredictions(Instance instance) throws Exception{ double preds[] = new double [committee.size()]; for(int i=0; i<committee.size(); i++) preds[i] = ((Classifier) committee.get(i)).classifyInstance(instance); return preds; } /** * Returns vote weights of ensemble members. * * @return vote weights of ensemble members */ public double []getEnsembleWts(){ return m_EnsembleWts; } /** Returns size of ensemble */ public double getEnsembleSize(){ return committee.size(); } /** * Returns description of the bagged classifier. * * @return description of the bagged classifier as a string */ public String toString() { if (committee == null) { return "DEC: No model built yet."; } StringBuffer text = new StringBuffer(); text.append("All the base classifiers: \n\n"); for (int i = 0; i < committee.size(); i++) text.append(((Classifier) committee.get(i)).toString() + "\n\n"); return text.toString(); } /** * Main method for testing this class. 
* * @param argv the options */ public static void main(String [] argv) { try { System.out.println(Evaluation. evaluateModel(new DEC(), argv)); } catch (Exception e) { System.err.println(e.getMessage()); } }}
// (removed: code-viewer hotkey help text accidentally captured with the source)