
📄 adacost.java

📁 AdaCost algorithm source code, written in Java; it can be added to the Weka system.
💻 JAVA
📖 Page 1 of 2
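
The description above says this class can be dropped into a Weka installation. The sketch below shows how such a classifier might then be driven through the Weka API once compiled onto the classpath. It is only an illustration under assumptions: the class name AdaCost, its package, and the setClassifier(...) setter are guesses (the listing starts mid-class and never shows its declaration), while setCostMatrix(...), buildClassifier(...), distributionForInstance(...) and the CostMatrix(Reader) constructor all appear in the source itself.

import java.io.BufferedReader;
import java.io.FileReader;
import weka.classifiers.CostMatrix;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class AdaCostDemo {
    public static void main(String[] args) throws Exception {
        // Load a nominal-class dataset; the file name is a placeholder.
        Instances data = DataSource.read("credit-g.arff");
        data.setClassIndex(data.numAttributes() - 1);

        AdaCost adaCost = new AdaCost();          // hypothetical class name for adacost.java
        adaCost.setClassifier(new J48());         // assumed setter inherited from SingleClassifierEnhancer
        adaCost.setCostMatrix(new CostMatrix(     // setCostMatrix(...) is called in the listing itself
            new BufferedReader(new FileReader("credit-g.cost"))));

        adaCost.buildClassifier(data);

        // Boosted, cost-aware class distribution for the first instance.
        Instance first = data.instance(0);
        double[] dist = adaCost.distributionForInstance(first);
        System.out.println(java.util.Arrays.toString(dist));
    }
}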
    /** Returns the capabilities of this classifier, restricted to nominal and binary class problems. */
    public Capabilities getCapabilities()
    {
        Capabilities capabilities = super.getCapabilities();

        // Only nominal (including binary) class attributes are supported.
        capabilities.disableAllClasses();
        capabilities.disableAllClassDependencies();
        if (super.getCapabilities().handles(weka.core.Capabilities.Capability.NOMINAL_CLASS))
            capabilities.enable(weka.core.Capabilities.Capability.NOMINAL_CLASS);
        if (super.getCapabilities().handles(weka.core.Capabilities.Capability.BINARY_CLASS))
            capabilities.enable(weka.core.Capabilities.Capability.BINARY_CLASS);
        return capabilities;
    }
    
    /**
     * Builds the cost-sensitive boosted classifier: loads the cost matrix,
     * optionally folds the costs into the instance weights, then boosts.
     */
    public void buildClassifier(Instances instances)
        throws Exception
    {
        getCapabilities().testWithFail(instances);

        // Work on a copy and remove instances with a missing class value.
        instances = new Instances(instances);
        instances.deleteWithMissingClass();

        if (m_Classifier == null)
            throw new Exception("No base classifier has been set!");

        // Obtain the cost matrix: either load it on demand from a file named
        // after the relation, or fall back to the explicitly supplied cost file.
        if (m_MatrixSource == MATRIX_ON_DEMAND)
        {
            String costName = instances.relationName() + CostMatrix.FILE_EXTENSION;
            File costFile = new File(getOnDemandDirectory(), costName);
            if (!costFile.exists())
                throw new Exception("On-demand cost file doesn't exist: " + costFile);
            setCostMatrix(new CostMatrix(new BufferedReader(new FileReader(costFile))));
        }
        else if (m_CostMatrix == null)
        {
            m_CostMatrix = new CostMatrix(instances.numClasses());
            m_CostMatrix.readOldFormat(new BufferedReader(new FileReader(m_CostFile)));
        }

        // Unless costs are handled at prediction time (minimum expected cost),
        // reweight the training data according to the cost matrix before boosting.
        if (!m_MinimizeExpectedCost)
        {
            Random random = null;
            if (!(m_Classifier instanceof WeightedInstancesHandler))
                random = new Random(m_Seed);
            instances = m_CostMatrix.applyCostMatrix(instances, random);
        }

        startbuildClassifier(instances);
    }
    
    /**
     * Performs the actual boosting, choosing between reweighting and
     * resampling depending on whether the base classifier can use weights.
     */
    public void startbuildClassifier(Instances instances)
        throws Exception
    {
        super.buildClassifier(instances);

        getCapabilities().testWithFail(instances);
        instances = new Instances(instances);
        instances.deleteWithMissingClass();
        m_NumClasses = instances.numClasses();

        if (!m_UseResampling && (m_Classifier instanceof WeightedInstancesHandler))
            buildClassifierWithWeights(instances);
        else
            buildClassifierUsingResampling(instances);
    }

    /**
     * Boosting via resampling: each round a training sample is drawn according
     * to the current instance weights, the base classifier is built on it, and
     * the weights of misclassified instances are then increased.
     */
    protected void buildClassifierUsingResampling(Instances instances)
        throws Exception
    {
        int numInstances = instances.numInstances();
        Random random = new Random(m_Seed);
        m_Betas = new double[m_Classifiers.length];
        m_NumIterationsPerformed = 0;

        // Copy of the training data whose weights are normalised to sum to one.
        Instances training = new Instances(instances, 0, numInstances);
        double sumOfWeights = training.sumOfWeights();
        for (int k = 0; k < training.numInstances(); k++)
            training.instance(k).setWeight(training.instance(k).weight() / sumOfWeights);

        for (m_NumIterationsPerformed = 0; m_NumIterationsPerformed < m_Classifiers.length; m_NumIterationsPerformed++)
        {
            if (m_Debug)
                System.err.println("Training classifier " + (m_NumIterationsPerformed + 1));

            // Optionally train only on the highest-weighted quantile of instances.
            Instances trainData;
            if (m_WeightThreshold < 100)
                trainData = selectWeightQuantile(training, (double) m_WeightThreshold / 100D);
            else
                trainData = new Instances(training);

            double[] weights = new double[trainData.numInstances()];
            for (int l = 0; l < weights.length; l++)
                weights[l] = trainData.instance(l).weight();

            // Resample until the classifier makes at least one error (or we give up).
            double epsilon;
            int resamplingIterations = 0;
            do
            {
                Instances sample = trainData.resampleWithWeights(random, weights);
                m_Classifiers[m_NumIterationsPerformed].buildClassifier(sample);
                Evaluation evaluation = new Evaluation(instances);
                evaluation.evaluateModel(m_Classifiers[m_NumIterationsPerformed], training);
                epsilon = evaluation.errorRate();
                resamplingIterations++;
            } while (Utils.eq(epsilon, 0.0D) && resamplingIterations < MAX_NUM_RESAMPLING_ITERATIONS);

            // Stop if the weak learner is no better than chance or is perfect.
            if (Utils.grOrEq(epsilon, 0.5D) || Utils.eq(epsilon, 0.0D))
            {
                if (m_NumIterationsPerformed == 0)
                    m_NumIterationsPerformed = 1;  // still use the single classifier built
                break;
            }

            // Classifier vote weight (beta) and weight-update factor for this round.
            m_Betas[m_NumIterationsPerformed] = Math.log((1.0D - epsilon) / epsilon);
            double reweight = (1.0D - epsilon) / epsilon;
            if (m_Debug)
                System.err.println("\terror rate = " + epsilon + "  beta = " + m_Betas[m_NumIterationsPerformed]);
            setWeights(training, reweight);
        }
    }
    
    
    /**
     * Multiplies the weight of each misclassified instance by the given
     * factor, then rescales all weights so that their total is unchanged.
     * A worked numeric example of this update follows the class listing.
     */
    protected void setWeights(Instances instances, double reweight)
        throws Exception
    {
        double oldSumOfWeights = instances.sumOfWeights();

        // Boost the weight of every instance the current classifier gets wrong.
        Enumeration enumeration = instances.enumerateInstances();
        while (enumeration.hasMoreElements())
        {
            Instance instance = (Instance) enumeration.nextElement();
            if (!Utils.eq(m_Classifiers[m_NumIterationsPerformed].classifyInstance(instance), instance.classValue()))
                instance.setWeight(instance.weight() * reweight);
        }

        // Renormalise so the sum of weights stays the same as before the update.
        double newSumOfWeights = instances.sumOfWeights();
        enumeration = instances.enumerateInstances();
        while (enumeration.hasMoreElements())
        {
            Instance instance = (Instance) enumeration.nextElement();
            instance.setWeight(instance.weight() * oldSumOfWeights / newSumOfWeights);
        }
    }

    /**
     * Boosting with weights: the base classifier is trained directly on the
     * weighted data each round (only for WeightedInstancesHandler learners).
     */
    protected void buildClassifierWithWeights(Instances instances)
        throws Exception
    {
        int numInstances = instances.numInstances();
        Random random = new Random(m_Seed);
        m_Betas = new double[m_Classifiers.length];
        m_NumIterationsPerformed = 0;

        Instances training = new Instances(instances, 0, numInstances);

        for (m_NumIterationsPerformed = 0; m_NumIterationsPerformed < m_Classifiers.length; m_NumIterationsPerformed++)
        {
            if (m_Debug)
                System.err.println("Training classifier " + (m_NumIterationsPerformed + 1));

            // Optionally train only on the highest-weighted quantile of instances.
            Instances trainData;
            if (m_WeightThreshold < 100)
                trainData = selectWeightQuantile(training, (double) m_WeightThreshold / 100D);
            else
                trainData = new Instances(training, 0, numInstances);

            // Give randomizable base classifiers a fresh seed each round.
            if (m_Classifiers[m_NumIterationsPerformed] instanceof Randomizable)
                ((Randomizable) m_Classifiers[m_NumIterationsPerformed]).setSeed(random.nextInt());

            m_Classifiers[m_NumIterationsPerformed].buildClassifier(trainData);

            // Error rate of this round's classifier on the full training copy.
            Evaluation evaluation = new Evaluation(instances);
            evaluation.evaluateModel(m_Classifiers[m_NumIterationsPerformed], training);
            double epsilon = evaluation.errorRate();

            // Stop if the weak learner is no better than chance or is perfect.
            if (Utils.grOrEq(epsilon, 0.5D) || Utils.eq(epsilon, 0.0D))
            {
                if (m_NumIterationsPerformed == 0)
                    m_NumIterationsPerformed = 1;  // still use the single classifier built
                break;
            }

            m_Betas[m_NumIterationsPerformed] = Math.log((1.0D - epsilon) / epsilon);
            double reweight = (1.0D - epsilon) / epsilon;
            if (m_Debug)
                System.err.println("\terror rate = " + epsilon + "  beta = " + m_Betas[m_NumIterationsPerformed]);
            setWeights(training, reweight);
        }
    }


    /**
     * Class distribution for an instance. When minimum expected cost is
     * requested, the predicted distribution is replaced by a 0/1 vector that
     * picks the class with the lowest expected misclassification cost.
     */
    public double[] distributionForInstance(Instance instance)
        throws Exception
    {
        if (!m_MinimizeExpectedCost)
            return AgaindistributionForInstance(instance);

        double[] pred = m_Classifier.distributionForInstance(instance);
        double[] costs = m_CostMatrix.expectedCosts(pred);
        int minCostIndex = Utils.minIndex(costs);
        for (int j = 0; j < pred.length; j++)
            pred[j] = (j == minCostIndex) ? 1.0D : 0.0D;
        return pred;
    }
        
    /**
     * Standard boosted class distribution: each base classifier votes for a
     * class with weight beta, and the votes are converted to probabilities.
     */
    public double[] AgaindistributionForInstance(Instance instance)
        throws Exception
    {
        if (m_NumIterationsPerformed == 0)
            throw new Exception("No model built");
        if (m_NumIterationsPerformed == 1)
            return m_Classifiers[0].distributionForInstance(instance);

        double[] sums = new double[instance.numClasses()];
        for (int i = 0; i < m_NumIterationsPerformed; i++)
            sums[(int) m_Classifiers[i].classifyInstance(instance)] += m_Betas[i];
        return Utils.logs2probs(sums);
    }
      
      public int graphType()
    {
        if(m_Classifier instanceof Drawable)
            return ((Drawable)m_Classifier).graphType();
        else
            return 0;
    }

    public String graph()
        throws Exception
    {
        if (m_Classifier instanceof Drawable)
            return ((Drawable) m_Classifier).graph();
        else
            throw new Exception("Classifier: " + getClassifierSpec() + " cannot be graphed");
    }

    /**
     * Emits the boosted model as Java source: a classify() method that adds
     * each base classifier's beta to the class it predicts.
     */
    public String toSource(String className)
        throws Exception
    {
        if (m_NumIterationsPerformed == 0)
            throw new Exception("No model built yet");
        if (!(m_Classifiers[0] instanceof Sourcable))
            throw new Exception("Base learner " + m_Classifier.getClass().getName() + " is not Sourcable");

        StringBuffer text = new StringBuffer("class ");
        text.append(className).append(" {\n\n");
        text.append("  public static double classify(Object [] i) {\n");
        if (m_NumIterationsPerformed == 1)
        {
            text.append("    return " + className + "_0.classify(i);\n");
        }
        else
        {
            text.append("    double [] sums = new double [" + m_NumClasses + "];\n");
            for (int i = 0; i < m_NumIterationsPerformed; i++)
                text.append("    sums[(int) " + className + '_' + i + ".classify(i)] += " + m_Betas[i] + ";\n");
            text.append("    double maxV = sums[0];\n    int maxI = 0;\n    for (int j = 1; j < "
                + m_NumClasses + "; j++) {\n"
                + "      if (sums[j] > maxV) { maxV = sums[j]; maxI = j; }\n"
                + "    }\n    return (double) maxI;\n");
        }
        text.append("  }\n}\n");

        // Append the source of every generated base classifier.
        for (int j = 0; j < m_Classifiers.length; j++)
            text.append(((Sourcable) m_Classifiers[j]).toSource(className + '_' + j));

        return text.toString();
    }
    
    
    public String toString()
    {
        StringBuffer text = new StringBuffer();
        if (m_NumIterationsPerformed == 0)
            text.append("AdaCost: No model built yet.\n");
        else if (m_NumIterationsPerformed == 1)
        {
            text.append("AdaCost: No boosting possible, one classifier used!\n");
            text.append(m_Classifiers[0].toString()).append("\n");
        }
        else
        {
            text.append("AdaCost: Base classifiers and their weights: \n\n");
            for (int i = 0; i < m_NumIterationsPerformed; i++)
            {
                text.append(m_Classifiers[i].toString()).append("\n\n");
                text.append("Weight: ").append(Utils.roundDouble(m_Betas[i], 2)).append("\n\n");
            }
            text.append("Number of performed Iterations: ").append(m_NumIterationsPerformed).append("\n");
        }

        text.append("AdaCost using ");
        if (m_MinimizeExpectedCost)
            text.append("minimized expected misclassification cost\n");
        else
            text.append("reweighted training instances\n");
        text.append("\n").append(getClassifierSpec())
            .append("\n\nClassifier Model\n").append(m_Classifier.toString())
            .append("\n\nCost Matrix\n").append(m_CostMatrix.toString());

        return text.toString();
    }
    
    

    /** Maximum number of resampling attempts per boosting round. */
    private static int MAX_NUM_RESAMPLING_ITERATIONS = 10;
    /** Log-odds vote weight (beta) of each base classifier. */
    protected double m_Betas[];
    /** Number of boosting iterations actually performed. */
    protected int m_NumIterationsPerformed;
    /** Weight-quantile threshold (percentage) used to prune low-weight instances. */
    protected int m_WeightThreshold;
    /** Whether to force boosting by resampling instead of reweighting. */
    protected boolean m_UseResampling;
    /** Number of classes in the training data. */
    protected int m_NumClasses;

    /** Constants and tags describing where the cost matrix comes from. */
    public static final int MATRIX_ON_DEMAND = 1;
    public static final int MATRIX_SUPPLIED = 2;
    public static final Tag TAGS_MATRIX_SOURCE[] = {
        new Tag(1, "Load cost matrix on demand"), new Tag(2, "Use explicit cost matrix")
    };
    /** Selected cost-matrix source. */
    protected int m_MatrixSource;
    /** Directory searched for on-demand cost files. */
    protected File m_OnDemandDirectory;
    /** Name of an explicit cost file in the old format. */
    protected String m_CostFile;
    /** The cost matrix in use. */
    protected CostMatrix m_CostMatrix;
    /** If true, costs are applied at prediction time via minimum expected cost. */
    protected boolean m_MinimizeExpectedCost;
    
    
}
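
To make the reweighting arithmetic above concrete: if a boosting round has error rate epsilon = 0.2, its vote weight is beta = ln((1 - 0.2) / 0.2) = ln 4 ≈ 1.386, every misclassified instance has its weight multiplied by (1 - 0.2) / 0.2 = 4, and all weights are then rescaled so their total is unchanged. The small standalone sketch below reproduces that update with example numbers; it is not part of adacost.java.

public class ReweightSketch {
    public static void main(String[] args) {
        double epsilon = 0.2;                              // example error rate of one boosting round
        double beta = Math.log((1 - epsilon) / epsilon);   // classifier vote weight, about 1.386
        double reweight = (1 - epsilon) / epsilon;         // factor applied to misclassified instances, 4.0

        // Five instances with equal weight; suppose only the last one was misclassified.
        double[] w = {0.2, 0.2, 0.2, 0.2, 0.2};
        boolean[] misclassified = {false, false, false, false, true};

        double oldSum = 0;
        for (double v : w) oldSum += v;
        for (int i = 0; i < w.length; i++)
            if (misclassified[i]) w[i] *= reweight;

        // Renormalise so the total weight is unchanged, as setWeights() does.
        double newSum = 0;
        for (double v : w) newSum += v;
        for (int i = 0; i < w.length; i++)
            w[i] = w[i] * oldSum / newSum;

        System.out.printf("beta = %.3f, weights = %s%n", beta, java.util.Arrays.toString(w));
    }
}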
