/**
 * Part of a Java implementation of conditional random field (CRF) models for
 * machine learning.
 */
package lcrf.regression;

import java.io.Serializable;
import java.io.StringReader;
import java.util.List;
import java.util.Vector;

import lcrf.logic.Atom;
import lcrf.logic.Variable;
import lcrf.stuff.FileWriter;

import org.apache.log4j.Logger;

import xprolog.Engine;
import xprolog.ParseException;

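/**
 * Trains regression trees whose inner nodes are Prolog tests. Candidate tests
 * are enumerated from the background knowledge. At each node the trainer picks
 * the test whose true/false split of the examples has the lowest sum of squared
 * errors, with each side predicted by its mean. Recursion stops at maxTreeDepth
 * or when fewer than minLeafSize examples remain.
 *
 * @author bgutmann
 */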
public class PrologRegressionTreeTrainer implements RegressionModelTrainer<List<Atom>>, Serializable {
    private static final long serialVersionUID = 3257008743811134257L;

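    // Helper clauses prepended to the background knowledge:
    // listsucceds(Tests, Test, E) holds if every test in the list Tests, and then
    // Test, succeeds on example E. succeds/2 itself is not defined here; it is
    // expected to come from the user-supplied background knowledge.
    protected static final String fixedBGPart = "listsucceds([],Test,E):-succeds(Test,E)."
            + "listsucceds([H|Tail],Test,E):-succeds(H,E),listsucceds(Tail,Test,E).";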

    // the value returned by a tree trained from zero examples
    private static final double defaultValue = 0.0;

    // nodes at a depth below this value are split further
    private int maxTreeDepth;

    // nodes with at least this many examples are split further;
    // must be at least 2
    private int minLeafSize;

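    // scratch arrays used during training: mapping[i] records whether example i
    // passed the test currently being evaluated, and bestmapping[i] records the
    // same for the best test found so far
    private boolean[] bestmapping;

    private boolean[] mapping;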

    private Engine prologEngine;

    private String backgroundKnowledge;

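    /**
     * Creates a trainer whose candidate tests are the answers to the query
     * test(X) against the given background knowledge.
     *
     * @param maxTreeDepth maximum depth of the learned tree (must be >= 0)
     * @param minLeafSize minimum number of examples a node needs before it is
     *            split further (must be >= 2)
     * @param backgroundKnowledge Prolog clauses; expected to define test/1 and
     *            the succeds/2 predicate used to evaluate a test on an example
     * @param outputPositions Prolog list asserted as the fact
     *            outputPositions(...), e.g. <pre>[1,2,3,4,5,6]</pre>
     */
    public PrologRegressionTreeTrainer(int maxTreeDepth, int minLeafSize, String backgroundKnowledge,
            String outputPositions) {
        assert minLeafSize > 1;
        assert maxTreeDepth >= 0;
        assert backgroundKnowledge != null;
        assert outputPositions != null;

        this.maxTreeDepth = maxTreeDepth;
        this.minLeafSize = minLeafSize;

        // fixed helper clauses + user background knowledge + outputPositions(...) fact
        this.backgroundKnowledge = fixedBGPart + backgroundKnowledge + "outputPositions(" + outputPositions
                + ").";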

        try {
            prologEngine = new Engine(new StringReader(this.backgroundKnowledge));
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    // Placeholder: not implemented and not called anywhere in this class.
    private Iterable<String> getTestIterator(List<RegressionExample<List<Atom>>> examples,
            List<Variable> boundedVars, int actualDepth) {

        return null;
    }

    public RegressionModel<List<Atom>> trainFromExamples(List<RegressionExample<List<Atom>>> examples) {
        if (examples == null || examples.size() == 0) {
            return new PrologRegressionTree(defaultValue, null, null, null, null);
        }


        double examplesSum = 0.0;
        for (int i = 0; i < examples.size(); i++) {
            examplesSum += examples.get(i).value;

            // transform the content into Prolog list notation, e.g. [a(1),b(2)],
            // and cache it in auxObject for the queries built in trainIntern
            StringBuilder term = new StringBuilder("[");
            boolean notfirst = false;
            for (Atom atom : examples.get(i).content) {
                if (notfirst) {
                    term.append(',');
                }
                notfirst = true;
                term.append(atom.toString());
            }
            // close only the list; trainIntern appends the query's closing ")"
            term.append(']');

            examples.get(i).auxObject = term.toString();
        }

        double average = examplesSum / examples.size();

        // prepare the scratch arrays shared by all recursive calls
        this.bestmapping = new boolean[examples.size()];
        this.mapping = new boolean[examples.size()];

        PrologRegressionTree t = trainIntern(examples, new Vector<Variable>(30), 0, average, "");

        // clear global variables
        this.bestmapping = null;
        this.mapping = null;

        // debug output: write the learned tree in Graphviz DOT format
        FileWriter.writeToFile("models/prt", ".dot", t.toDotString());

        return t;
    }

    private PrologRegressionTree trainIntern(List<RegressionExample<List<Atom>>> examples,
            Vector<Variable> boundedVars, int depth, double average, String testsSoFar) {

        if (depth >= maxTreeDepth || examples.size() < minLeafSize) {
            return new PrologRegressionTree(average, null, null, null, null);
        }

        double bestsumsquarederror = Double.MAX_VALUE;
        String bestTest = null;

        int bestTrueExamplesCount = 0;
        int bestFalseExamplesCount = 0;
        double bestTrueExamplesSum = 0;
        double bestFalseExamplesSum = 0;

        // collect the candidate tests from the answers to the query test(X)
        Logger.getLogger(getClass()).debug("Start finding possible Tests");
        Vector<String> possibleTests = new Vector<String>();
        try {
            if (prologEngine.setQuery("test(X)")) {
                do {
                    possibleTests.add(prologEngine.answer().term.toString());
                } while (prologEngine.more());
            }
        } catch (ParseException e) {
            // this should never happen :)
            throw new RuntimeException(e);
        }
        Logger.getLogger(getClass()).debug(Integer.toString(possibleTests.size())+" tests found.");
        
        for (String test : possibleTests) {
            int trueExamplesCount = 0;
            int falseExamplesCount = 0;
            double trueExamplesSum = 0;
            double falseExamplesSum = 0;
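
            // at the root, evaluate the new test alone; deeper in the tree, evaluate
            // it together with all tests on the path from the root to this node
            String queryPrefix;
            if (testsSoFar.equals("")) {
                queryPrefix = "succeds(" + test + ",";
            } else {
                queryPrefix = "listsucceds([" + testsSoFar + "]," + test + ",";
            }

            // split the examples into true and false parts
            Logger.getLogger(getClass()).debug("Test : " + test);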
            for (int i = 0; i < examples.size(); i++) {
                String query = queryPrefix + examples.get(i).auxObject.toString() + ")";

                try {
                    if (prologEngine.setQuery(query)) {
                        mapping[i] = true;
                        trueExamplesCount++;
                        trueExamplesSum += examples.get(i).value;
                    } else {
                        mapping[i] = false;
                        falseExamplesCount++;
                        falseExamplesSum += examples.get(i).value;
                    }
                } catch (ParseException e) {
                    // this should never happen :)
                    throw new RuntimeException(e);
                }            
            }

            if (trueExamplesCount == 0 || falseExamplesCount == 0) {
                // the test does not split the examples, so it is useless here
                continue;
            }

            double trueExamplesAverage = trueExamplesSum / trueExamplesCount;
            double falseExamplesAverage = falseExamplesSum / falseExamplesCount;
            double sumsquarederror = 0.0;

            for (int i = 0; i < examples.size(); i++) {
                double val = examples.get(i).value
                        - (mapping[i] ? trueExamplesAverage : falseExamplesAverage);
                sumsquarederror += val * val;
            }

            if (sumsquarederror < bestsumsquarederror) {
                // keep this test's mapping by swapping the two buffers instead of copying
                boolean[] tmp = mapping;
                mapping = bestmapping;
                bestmapping = tmp;

                bestsumsquarederror = sumsquarederror;
                bestTest = test;

                bestTrueExamplesCount = trueExamplesCount;
                bestFalseExamplesCount = falseExamplesCount;
                bestTrueExamplesSum = trueExamplesSum;
                bestFalseExamplesSum = falseExamplesSum;
            }
        }

        // no test produced a proper split: make this node a leaf
        if (bestTrueExamplesCount == 0 || bestFalseExamplesCount == 0) {
            return new PrologRegressionTree(average, null, null, null, null);
        }

        // split training data for both subtrees
        List<RegressionExample<List<Atom>>> trueExamples = new Vector<RegressionExample<List<Atom>>>(
                bestTrueExamplesCount);
        List<RegressionExample<List<Atom>>> falseExamples = new Vector<RegressionExample<List<Atom>>>(
                bestFalseExamplesCount);

        for (int i = 0; i < examples.size(); i++) {
            RegressionExample<List<Atom>> ex = examples.get(i);
            if (bestmapping[i]) {
                trueExamples.add(ex);
            } else {
                falseExamples.add(ex);
            }
        }

        // free memory (note: this also empties the list passed in by the caller)
        examples.clear();
        examples = null;

        String testsSoFarFalseTree = (testsSoFar.equals("")) ? "not(" + bestTest + ")" : testsSoFar
                + ", not(" + bestTest + ")";
        String testsSoFarTrueTree = (testsSoFar.equals("")) ? bestTest : testsSoFar + ", " + bestTest;

        int sizeold = falseExamples.size();

        PrologRegressionTree falseSubTree = trainIntern(falseExamples, boundedVars, depth + 1,
                bestFalseExamplesSum / bestFalseExamplesCount, testsSoFarFalseTree);
        falseSubTree.dirty_ExampleSize = sizeold;

        // free memory
        falseExamples.clear();
        falseExamples = null;

        sizeold = trueExamples.size();
        PrologRegressionTree trueSubTree = trainIntern(trueExamples, boundedVars, depth + 1,
                bestTrueExamplesSum / bestTrueExamplesCount, testsSoFarTrueTree);
        trueSubTree.dirty_ExampleSize = sizeold;

        // free memory
        trueExamples.clear();
        trueExamples = null;

        return new PrologRegressionTree(average, bestTest, trueSubTree, falseSubTree, prologEngine);
    }

    @Override
    public String toString() {
        return "PrologRegressionTreeTrainer(" + maxTreeDepth + "," + minLeafSize + ")";
    }

}
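
/*
 * Illustrative sketch, not part of the original trainer: how an example and the
 * path of tests above a node could be turned into a Prolog query with balanced
 * parentheses. It assumes, as the trainer does, that each atom's toString() is
 * valid Prolog and that succeds/2 is defined by the background knowledge. The
 * class and method names (PrologQuerySketch, toPrologList, buildQuery) are
 * hypothetical.
 */
final class PrologQuerySketch {

    private PrologQuerySketch() {
    }

    /** Renders the atoms as a Prolog list, e.g. [a(1),b(2)]. */
    static String toPrologList(Iterable<?> atoms) {
        StringBuilder sb = new StringBuilder("[");
        boolean first = true;
        for (Object atom : atoms) {
            if (!first) {
                sb.append(',');
            }
            first = false;
            sb.append(atom); // uses String.valueOf(atom)
        }
        return sb.append(']').toString();
    }

    /**
     * Builds succeds(Test,Example) at the root (empty testsSoFar), otherwise
     * listsucceds([TestsSoFar],Test,Example): one opening and one closing
     * parenthesis around the arguments in either case.
     */
    static String buildQuery(String testsSoFar, String test, Iterable<?> atoms) {
        String example = toPrologList(atoms);
        if (testsSoFar.isEmpty()) {
            return "succeds(" + test + "," + example + ")";
        }
        return "listsucceds([" + testsSoFar + "]," + test + "," + example + ")";
    }
}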

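/*
 * Illustrative sketch, not part of the original trainer: the split score used in
 * trainIntern, isolated on plain arrays. A candidate test partitions the
 * examples into a "true" and a "false" part, each part predicts its own mean,
 * and the test is scored by the summed squared error of those predictions
 * (lower is better). The names SplitScoreSketch, sumSquaredError and
 * bestCandidate are hypothetical. A test that leaves either part empty scores
 * +infinity here, so it can never win. The trainer gets the same effect by
 * skipping such tests with "continue".
 */
final class SplitScoreSketch {

    private SplitScoreSketch() {
    }

    /** Sum of squared errors around the per-part means, or +infinity if a part is empty. */
    static double sumSquaredError(double[] values, boolean[] mapping) {
        int trueCount = 0, falseCount = 0;
        double trueSum = 0.0, falseSum = 0.0;
        for (int i = 0; i < values.length; i++) {
            if (mapping[i]) {
                trueCount++;
                trueSum += values[i];
            } else {
                falseCount++;
                falseSum += values[i];
            }
        }
        if (trueCount == 0 || falseCount == 0) {
            return Double.POSITIVE_INFINITY; // the test does not split the examples
        }
        double trueMean = trueSum / trueCount;
        double falseMean = falseSum / falseCount;
        double sse = 0.0;
        for (int i = 0; i < values.length; i++) {
            double residual = values[i] - (mapping[i] ? trueMean : falseMean);
            sse += residual * residual;
        }
        return sse;
    }

    /** Index of the candidate with the lowest score (first one on ties), or -1 if none splits. */
    static int bestCandidate(double[] values, boolean[][] candidates) {
        int best = -1;
        double bestScore = Double.POSITIVE_INFINITY;
        for (int c = 0; c < candidates.length; c++) {
            double score = sumSquaredError(values, candidates[c]);
            if (score < bestScore) { // strict: keeps the earliest best, like the trainer
                bestScore = score;
                best = c;
            }
        }
        return best;
    }
}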