📄 dec.java
字号:
/**
 * Sets the number of random instances to add at each iteration.
 *
 * @param new_random_size the number of random instances to add at each
 * iteration; a negative value is interpreted downstream (in
 * buildClassifier) as a fraction of the training set size.
 */
public void setRandomSize(double new_random_size) {
  m_RandomSize = new_random_size;
}

/**
 * Gets the desired size of the committee.
 *
 * @return the desired number of committee members.
 */
public int getDesiredSize() {
  return m_DesiredSize;
}

/**
 * Sets the desired size of the committee.
 *
 * @param new_desired_size the desired number of committee members.
 */
public void setDesiredSize(int new_desired_size) {
  m_DesiredSize = new_desired_size;
}

/**
 * Sets the maximum number of bagging iterations.
 *
 * @param numIterations the maximum number of bagging iterations
 */
public void setNumIterations(int numIterations) {
  m_NumIterations = numIterations;
}

/**
 * Gets the number of bagging iterations.
 *
 * @return the maximum number of bagging iterations
 */
public int getNumIterations() {
  return m_NumIterations;
}

/**
 * Set the seed for random number generation and (re)create the
 * internal random number generator.
 *
 * @param seed the seed; -1 means use an unseeded (time-based) generator
 */
public void setSeed(int seed) {
  m_Seed = seed;
  if (m_Seed == -1) {
    random = new Random();
  } else {
    random = new Random(m_Seed);
  }
}

/**
 * Gets the seed for the random number generation.
 *
 * @return the seed for the random number generation
 */
public int getSeed() {
  return m_Seed;
}

/**
 * Get the value of the relabeling confidence threshold.
 *
 * @return value of threshold.
 */
public double getThreshold() {
  return m_Threshold;
}

/**
 * Set the value of the relabeling confidence threshold.
 *
 * @param v Value to assign to threshold.
 */
public void setThreshold(double v) {
  this.m_Threshold = v;
}

/**
 * DEC method.
 *
 * @param data the training data to be used for generating the
 * bagged classifier.
 * @exception Exception if the classifier could not be built successfully */
public void buildClassifier(Instances data) throws Exception {
  // DEBUG: report whether accuracy-based member weights are in use
  if (m_UseWeights == 1)
    System.out.println("\t>> Using weights...");
  else
    System.out.println("\t>> Not using weights...");

  // Initialize all ensemble weights to be equal.
  m_EnsembleWts = new double[m_DesiredSize];
  for (int j = 0; j < m_DesiredSize; j++)
    m_EnsembleWts[j] = 1.0;

  initMeasures();

  if (m_Classifier == null) {
    throw new Exception("A base classifier has not been specified!");
  }
  if (data.checkForStringAttributes()) {
    throw new Exception("Can't handle string attributes!");
  }

  // Number of random (diversity) instances to add at each iteration.
  // A negative m_RandomSize is interpreted as a fraction of the
  // training-set size (at least one example is always added).
  int random_size;
  if (m_RandomSize < 0) {
    random_size = (int) (Math.abs(m_RandomSize) * data.numInstances());
    if (random_size == 0)
      random_size = 1; // at least add one random example
  } else
    random_size = (int) m_RandomSize;
  System.out.println("random size = " + random_size);

  // Maximum number of random instances to generate per iteration (when
  // using confidence thresholds not all randomly generated examples will
  // be labeled); this bounds the generation loop below so it terminates.
  int max_random_generated = 3 * random_size;

  int num_attributes = data.numAttributes();
  committee = new Vector(); // initialize new committee
  double e_comm;            // classification error of current committee
  int i = 1;                // current size of committee
  int num_trials = 0;
  Classifier classifier = m_Classifier;
  Instances div_data = new Instances(data); // local copy of data
  Instances random_data = null;

  computeStats(div_data); // find mean and std devs for numeric data

  // Create the first committee member from the base classifier.
  m_Classifier.buildClassifier(div_data);
  committee.add(classifier);
  if (m_UseWeights == 1)
    m_EnsembleWts[i - 1] = computeEnsembleWt(classifier, data); // weight based on classifier accuracy
  e_comm = computeError(div_data);
  System.out.println("Mem " + i + " added. E_comm = " + e_comm);
  computeAccuracy(data);

  // If the base classifier has an error >= 0.5, keep only it.
  if (e_comm >= 0.5) {
    m_EnsembleWts[0] = 1.0; // reset the ensemble wt
    i = m_DesiredSize;      // skip the following while loop
  }

  while (i < m_DesiredSize && num_trials < m_NumIterations) {
    // Keep generating random data until the desired number are actually labeled.
    Instances total_random_data = new Instances(data, random_size);
    int labeled = 0, generated = 0;

    // Confidence threshold for relabeling the artificial examples.
    double threshold = selectThreshold(e_comm);

    // Continue generating random examples until random_size of them are
    // labeled or the number of examples generated exceeds
    // max_random_generated. The latter is a failsafe to ensure this loop
    // terminates.
    while (labeled < random_size && generated < max_random_generated) {
      // Create random instances (diversity data) and label them.
      random_data = generateRandomData(random_size - labeled, num_attributes, data);
      generated += (random_size - labeled);
      random_data = labelData(random_data, threshold);
      labeled += random_data.numInstances();
      addInstances(total_random_data, random_data);
    }
    if (labeled != random_size)
      System.out.println(labeled + " random examples labeled out of the desired " + random_size);
    Assert.that(total_random_data.numInstances() == labeled,
        "Error in random example generation+labeling loop: "
        + total_random_data.numInstances() + " != " + labeled);

    // Remove all the diversity data from the previous step (if any).
    if (div_data.numInstances() > data.numInstances()) {
      removeInstances(div_data, div_data.numInstances() - data.numInstances());
    }
    Assert.that(div_data.numInstances() == data.numInstances());

    // Add the newly generated random instances.
    addInstances(div_data, total_random_data);

    // Build a fresh copy of the base classifier on the augmented data.
    Classifier tmp[] = Classifier.makeCopies(m_Classifier, 1);
    classifier = tmp[0];
    classifier.buildClassifier(div_data);
    committee.add(classifier);
    if (m_UseWeights == 1)
      m_EnsembleWts[i] = computeEnsembleWt(classifier, data); // weight based on classifier accuracy

    // Keep the new member only if adding it did not increase the committee
    // error and it has a positive ensemble weight (i.e. error < 0.5).
    double curr_error = computeError(data);
    if (m_EnsembleWts[i] > 0 && curr_error <= e_comm) {
      i++;
      e_comm = curr_error;
      System.out.println("Iteration: " + num_trials + "\tMem " + i + " added. E_comm = " + e_comm);
    } else {
      committee.removeElementAt(committee.size() - 1); // pop the last member
    }
    num_trials++;
  }
  System.out.println("Final ensemble size: " + committee.size());

  // Set measures.
  computeEnsembleMeasures(data);
  // DEBUG: the stored train error must agree with the committee error.
  Assert.that(m_TrainError == (100.0 * e_comm),
      "Bug in train error computation!" + m_TrainError + "\t" + (100.0 * e_comm));
}
/**
 * Find and store mean and std devs for numeric attributes.
 *
 * @param data training instances
 */
protected void computeStats(Instances data) {
  // attribute_stats maintains, per attribute index, the sampling statistics
  // needed by generateRandomData: cumulative probabilities for nominal
  // attributes, <min,max> or <mean,stddev> for numeric ones.
  int num_attributes = data.numAttributes();
  attribute_stats = new HashMap(num_attributes);
  for (int j = 0; j < num_attributes; j++) {
    if (data.attribute(j).isNominal()) {
      if (m_DataCreationMethod == TRAINING_DIST) {
        int []nom_counts = (data.attributeStats(j)).nominalCounts;
        double []counts = new double[nom_counts.length];
        // Laplace smoothing
        for (int i = 0; i < counts.length; i++)
          counts[i] = nom_counts[i] + 1;
        Utils.normalize(counts);
        // stats holds cumulative probabilities for the first n-1 values;
        // the last value's cumulative probability is implicitly 1.
        double []stats = new double[counts.length - 1];
        stats[0] = counts[0];
        // Calculate cumulative probabilities
        for (int i = 1; i < stats.length; i++)
          stats[i] = stats[i - 1] + counts[i];
        attribute_stats.put(new Integer(j), stats);
      }
    } else if (data.attribute(j).isNumeric()) {
      if (m_DataCreationMethod == UNIFORM) {
        // Get range of numeric attribute from the training data.
        Stats s = (data.attributeStats(j)).numericStats;
        double []stats = new double[2];
        stats[0] = s.min;
        stats[1] = s.max;
        attribute_stats.put(new Integer(j), stats);
      } else if (m_DataCreationMethod == TRAINING_DIST) {
        // Get mean and standard deviation from the training data.
        double []stats = new double[2];
        stats[0] = data.meanOrMode(j);
        stats[1] = Math.sqrt(data.variance(j));
        attribute_stats.put(new Integer(j), stats);
      }
    }
  }
}

/**
 * Generate artificial (diversity) instances by sampling each attribute
 * according to m_DataCreationMethod, using the statistics cached by
 * computeStats. The generated instances are unlabeled copies of the
 * training format with weight 1.0.
 *
 * @param random_size number of random instances to generate
 * @param num_attributes number of attributes per instance
 * @param data dataset supplying the header/format
 * @return a new Instances object containing random_size random instances
 */
protected Instances generateRandomData(int random_size, int num_attributes, Instances data) {
  Instances random_data = new Instances(data, random_size);
  double []att;
  Instance random_instance;
  for (int i = 0; i < random_size; i++) {
    att = new double[num_attributes];
    for (int j = 0; j < num_attributes; j++) {
      if (data.attribute(j).isNominal()) {
        if (m_DataCreationMethod == UNIFORM || m_DataCreationMethod == MIXED) {
          // Pick a nominal value index uniformly at random.
          att[j] = (double) random.nextInt(data.numDistinctValues(j));
        } else if (m_DataCreationMethod == TRAINING_DIST) {
          // Sample from the (Laplace-smoothed) training distribution.
          double []stats = (double [])attribute_stats.get(new Integer(j));
          att[j] = selectNominalValue(stats);
        }
      } else if (data.attribute(j).isNumeric()) {
        if (m_DataCreationMethod == UNIFORM) {
          // Sample uniformly from the observed [min, max] range.
          double []stats = (double [])attribute_stats.get(new Integer(j));
          double min = stats[0];
          double max = stats[1];
          att[j] = (random.nextDouble() * (max - min)) + min;
        } else if (m_DataCreationMethod == TRAINING_DIST || m_DataCreationMethod == MIXED) {
          // Sample from a Gaussian with the training mean and std dev.
          double []stats = (double [])attribute_stats.get(new Integer(j));
          att[j] = (random.nextGaussian() * stats[1]) + stats[0];
        }
      } else {
        System.err.println("Current version of DEC cannot deal with STRING attributes.");
      }
    }
    random_instance = new Instance(1.0, att);
    random_data.add(random_instance);
  }
  Assert.that(random_data.numInstances() == random_size);
  return random_data;
}

/**
 * Given cumulative probabilities, select a nominal value index.
 * cumm covers the first n-1 values; a draw beyond the last entry
 * selects the final value (index cumm.length).
 */
protected double selectNominalValue(double []cumm) {
  double rnd = random.nextDouble();
  int index = 0;
  while (index < cumm.length && rnd > cumm[index]) {
    index++;
  }
  return ((double) index);
}

/**
 * Set threshold for relabeling based on user specified threshold
 * or on error of current committee
 *
 * @param error Error of current committee
 * @return the selected threshold
 */
protected double selectThreshold(double error) {
  double threshold;
  if (m_Threshold == -1) {
    if (error >= 0.5) threshold = 1.0;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -