
📄 trainer.java

📁 An implementation of the CRF (conditional random field) algorithm; hopefully it is of some help to people working on such algorithms.
💻 JAVA
📖 Page 1 of 2
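This page shows the training core of the package: building the per-position potential matrices (`Mi_YY`, `Ri_Y`) and computing the penalized log-likelihood and its gradient in log space. For orientation (standard CRF training math, stated here to match the code rather than taken from the original page), the quantity `computeFunctionGradientLL` returns is

$$ \mathcal{L}(\lambda) = \sum_k \Big( \sum_f \lambda_f F_f(x^k, y^k) - \log Z_\lambda(x^k) \Big) - \sum_f \frac{\lambda_f^2}{2\sigma^2}, $$

and the gradient it fills into `grad` is

$$ \frac{\partial \mathcal{L}}{\partial \lambda_f} = \sum_k \Big( F_f(x^k, y^k) - \mathbb{E}_{p_\lambda(y \mid x^k)}\big[F_f\big] \Big) - \frac{\lambda_f}{\sigma^2}, $$

where $F_f(x, y)$ is the total count of feature $f$ along a sequence, $Z_\lambda(x)$ is the partition function (`lZx` in the code is its log), and $1/\sigma^2$ is `params.invSigmaSquare`.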
                // (Listing begins mid-method; the opening of this gradient routine
                // is not shown on this page.)
                }
            }
            if (params.debugLvl > 2) {
                for (int f = 0; f < lambda.length; f++)
                    System.out.print(lambda[f] + " ");
                System.out.println(" :x");
                for (int f = 0; f < lambda.length; f++)
                    System.out.println(featureGenerator.featureName(f) + " " + grad[f] + " ");
                System.out.println(" :g");
            }
            if (params.debugLvl > 0)
                Util.printDbg("Iter " + icall + " log likelihood " + logli
                        + " norm(grad logli) " + norm(grad) + " norm(x) " + norm(lambda));
            if (icall == 0) {
                System.out.println("Number of training records: " + numRecord);
            }
        } catch (Exception e) {
            System.out.println("Alpha-i " + alpha_Y.toString());
            System.out.println("Ri " + Ri_Y.toString());
            System.out.println("Mi " + Mi_YY.toString());
            e.printStackTrace();
            System.exit(0);
        }
        return logli;
    }

    static void computeLogMi(FeatureGenerator featureGen, double lambda[],
            DoubleMatrix2D Mi_YY,
            DoubleMatrix1D Ri_Y, boolean takeExp) {
        computeLogMi(featureGen, lambda, Mi_YY, Ri_Y, takeExp, false, false);
    }

    static boolean computeLogMi(FeatureGenerator featureGen, double lambda[],
            DoubleMatrix2D Mi_YY,
            DoubleMatrix1D Ri_Y, boolean takeExp, boolean reuseM, boolean initMDone) {
        if (reuseM && initMDone) {
            Mi_YY = null;
        } else
            initMDone = false;
        if (Mi_YY != null) Mi_YY.assign(0);
        Ri_Y.assign(0);
        while (featureGen.hasNext()) {
            Feature feature = featureGen.next();
            int f = feature.index();
            int yp = feature.y();
            int yprev = feature.yprev();
            float val = feature.value();
            // System.out.println(feature.toString());
            if (yprev < 0) {
                // this is a single state feature.
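                // Node features (yprev < 0) accumulate into the per-position
                // vector Ri(y) = sum_f lambda_f * f(x, i, y); edge features
                // accumulate into the transition matrix Mi(yprev, y) below.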
                double oldVal = Ri_Y.getQuick(yp);
                Ri_Y.setQuick(yp, oldVal + lambda[f] * val);
            } else if (Mi_YY != null) {
                Mi_YY.setQuick(yprev, yp, Mi_YY.getQuick(yprev, yp) + lambda[f] * val);
                initMDone = true;
            }
        }
        if (takeExp) {
            for (int r = Ri_Y.size() - 1; r >= 0; r--) {
                Ri_Y.setQuick(r, expE(Ri_Y.getQuick(r)));
                if (Mi_YY != null)
                    for (int c = Mi_YY.columns() - 1; c >= 0; c--) {
                        Mi_YY.setQuick(r, c, expE(Mi_YY.getQuick(r, c)));
                    }
            }
        }
        return initMDone;
    }

    static void computeLogMi(FeatureGenerator featureGen, double lambda[],
            DataSequence dataSeq, int i,
            DoubleMatrix2D Mi_YY,
            DoubleMatrix1D Ri_Y, boolean takeExp) {
        computeLogMi(featureGen, lambda, dataSeq, i, Mi_YY, Ri_Y, takeExp, false, false);
    }

    static boolean computeLogMi(FeatureGenerator featureGen, double lambda[],
            DataSequence dataSeq, int i,
            DoubleMatrix2D Mi_YY,
            DoubleMatrix1D Ri_Y, boolean takeExp, boolean reuseM, boolean initMDone) {
        featureGen.startScanFeaturesAt(dataSeq, i);
        return computeLogMi(featureGen, lambda, Mi_YY, Ri_Y, takeExp, reuseM, initMDone);
    }

    protected double computeFunctionGradientLL(double lambda[], double grad[]) {
        double logli = 0;
        try {
            for (int f = 0; f < lambda.length; f++) {
                grad[f] = -1 * lambda[f] * params.invSigmaSquare;
                logli -= ((lambda[f] * lambda[f]) * params.invSigmaSquare) / 2;
            }
            diter.startScan();
            if (featureGenCache != null) featureGenCache.startDataScan();
            for (int numRecord = 0; diter.hasNext(); numRecord++) {
                DataSequence dataSeq = (DataSequence) diter.next();
                if (featureGenCache != null) featureGenCache.nextDataIndex();
                if (params.debugLvl > 1) {
                    Util.printDbg("Read next seq: " + numRecord + " logli " + logli);
                }
                alpha_Y.assign(0);
                for (int f = 0; f < lambda.length; f++)
                    ExpF[f] = RobustMath.LOG0;
                if ((beta_Y == null) || (beta_Y.length < dataSeq.length())) {
                    beta_Y = new DenseDoubleMatrix1D[2 * dataSeq.length()];
                    for (int i = 0; i < beta_Y.length; i++)
                        beta_Y[i] = new DenseDoubleMatrix1D(numY);
                }
                // compute beta values in a backward scan.
                // also scale beta-values to 1 to avoid numerical problems.
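                // Backward recursion, carried out in log space:
                //   beta_{i-1}(y) = logsumexp_{y'} [ Mi(y, y') + Ri(y') + beta_i(y') ]
                // beta_Y[length-1] is the zero vector (the log of all-ones).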
                beta_Y[dataSeq.length() - 1].assign(0);
                for (int i = dataSeq.length() - 1; i > 0; i--) {
                    if (params.debugLvl > 2) {
                        /*
                         Util.printDbg("Features fired");
                         featureGenerator.startScanFeaturesAt(dataSeq, i);
                         while (featureGenerator.hasNext()) {
                             Feature feature = featureGenerator.next();
                             Util.printDbg(feature.toString());
                         }
                         */
                    }
                    // compute the Mi matrix
                    initMDone = computeLogMi(featureGenerator, lambda, dataSeq, i,
                            Mi_YY, Ri_Y, false, reuseM, initMDone);
                    tmp_Y.assign(beta_Y[i]);
                    tmp_Y.assign(Ri_Y, sumFunc);
                    RobustMath.logMult(Mi_YY, tmp_Y, beta_Y[i - 1], 1, 0, false, edgeGen);
                }

                double thisSeqLogli = 0;
                for (int i = 0; i < dataSeq.length(); i++) {
                    // compute the Mi matrix
                    initMDone = computeLogMi(featureGenerator, lambda, dataSeq, i,
                            Mi_YY, Ri_Y, false, reuseM, initMDone);
                    // find features that fire at this position..
                    featureGenerator.startScanFeaturesAt(dataSeq, i);
                    if (i > 0) {
                        tmp_Y.assign(alpha_Y);
                        RobustMath.logMult(Mi_YY, tmp_Y, newAlpha_Y, 1, 0, true, edgeGen);
                        newAlpha_Y.assign(Ri_Y, sumFunc);
                    } else {
                        newAlpha_Y.assign(Ri_Y);
                    }
                    while (featureGenerator.hasNext()) {
                        Feature feature = featureGenerator.next();
                        int f = feature.index();
                        int yp = feature.y();
                        int yprev = feature.yprev();
                        float val = feature.value();
                        if ((dataSeq.y(i) == yp)
                                && (((i - 1 >= 0) && (yprev == dataSeq.y(i - 1))) || (yprev < 0))) {
                            grad[f] += val;
                            thisSeqLogli += val * lambda[f];
                            if (params.debugLvl > 2) {
                                System.out.println("Feature fired " + f + " " + feature);
                            }
                        }
                        if (yprev < 0) {
                            ExpF[f] = RobustMath.logSumExp(ExpF[f],
                                    newAlpha_Y.get(yp) + RobustMath.log(val) + beta_Y[i].get(yp));
                        } else {
                            ExpF[f] = RobustMath.logSumExp(ExpF[f],
                                    alpha_Y.get(yprev) + Ri_Y.get(yp) + Mi_YY.get(yprev, yp)
                                            + RobustMath.log(val) + beta_Y[i].get(yp));
                        }
                    }
                    alpha_Y.assign(newAlpha_Y);
                    if (params.debugLvl > 2) {
                        System.out.println("Alpha-i " + alpha_Y.toString());
                        System.out.println("Ri " + Ri_Y.toString());
                        System.out.println("Mi " + Mi_YY.toString());
                        System.out.println("Beta-i " + beta_Y[i].toString());
                    }
                }
                double lZx = RobustMath.logSumExp(alpha_Y);
                thisSeqLogli -= lZx;
                logli += thisSeqLogli;
                // update grad.
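                // ExpF[f] holds the log of the unnormalized expected count of
                // feature f under the model: a sum over positions of
                // alpha * (feature value) * beta. Subtracting log(Zx) before
                // exponentiating normalizes it, so grad[f] ends up as
                // (empirical count) - (expected count) - lambda[f]/sigma^2.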
                for (int f = 0; f < grad.length; f++) {
                    grad[f] -= RobustMath.exp(ExpF[f] - lZx);
                }
                if (params.debugLvl > 1) {
                    System.out.println("Sequence " + thisSeqLogli + " logli " + logli
                            + " log(Zx) " + lZx + " Zx " + Math.exp(lZx));
                }
            }
            if (params.debugLvl > 2) {
                for (int f = 0; f < lambda.length; f++)
                    System.out.print(lambda[f] + " ");
                System.out.println(" :x");
                for (int f = 0; f < lambda.length; f++)
                    System.out.print(grad[f] + " ");
                System.out.println(" :g");
            }
            if (params.debugLvl > 0)
                Util.printDbg("Iteration " + icall + " log-likelihood " + logli
                        + " norm(grad logli) " + norm(grad) + " norm(x) " + norm(lambda));
        } catch (Exception e) {
            System.out.println("Alpha-i " + alpha_Y.toString());
            System.out.println("Ri " + Ri_Y.toString());
            System.out.println("Mi " + Mi_YY.toString());
            e.printStackTrace();
            System.exit(0);
        }
        return logli;
    }

    static double log(double val) {
        try {
            return logE(val);
        } catch (Exception e) {
            System.out.println(e.getMessage());
            e.printStackTrace();
        }
        return -1 * Double.MAX_VALUE;
    }

    static double logE(double val) throws Exception {
        double pr = Math.log(val);
        if (Double.isNaN(pr) || Double.isInfinite(pr)) {
            throw new Exception("Overflow error when taking log of " + val);
        }
        return pr;
    }

    static double expE(double val) {
        double pr = RobustMath.exp(val);
        if (Double.isNaN(pr) || Double.isInfinite(pr)) {
            try {
                throw new Exception("Overflow error when taking exp of " + val
                        + "\n Try running the CRF with the following option \"trainer ll\" to perform computations in the log-space.");
            } catch (Exception e) {
                System.out.println(e.getMessage());
                e.printStackTrace();
                return Double.MAX_VALUE;
            }
        }
        return pr;
    }

    static double expLE(double val) {
        double pr = RobustMath.exp(val);
        if (Double.isNaN(pr) || Double.isInfinite(pr)) {
            try {
                throw new Exception("Overflow error when taking exp of " + val
                        + " you might need to redesign feature values so as to not reach such high values");
            } catch (Exception e) {
                System.out.println(e.getMessage());
                e.printStackTrace();
                return Double.MAX_VALUE;
            }
        }
        return pr;
    }
}
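For orientation, here is a minimal sketch of how the function/gradient pair above could drive training. Everything in it is illustrative: `trainByGradientAscent`, `numIters`, and `learningRate` are made-up names, and the shipped package presumably hands `computeFunctionGradientLL` to a numerical optimizer rather than this plain ascent loop; the sketch only shows the calling convention.

    // Hypothetical helper (would live in this class or a subclass, since
    // computeFunctionGradientLL is protected); NOT part of the original file.
    protected double[] trainByGradientAscent(int numFeatures, int numIters, double learningRate) {
        double[] lambda = new double[numFeatures]; // feature weights, start at 0
        double[] grad = new double[numFeatures];   // filled in by the gradient routine
        for (int iter = 0; iter < numIters; iter++) {
            // logli is the sigma^2-penalized log-likelihood; grad is its gradient.
            double logli = computeFunctionGradientLL(lambda, grad);
            for (int f = 0; f < lambda.length; f++)
                lambda[f] += learningRate * grad[f]; // ascend the penalized objective
            System.out.println("iter " + iter + " penalized log-likelihood " + logli);
        }
        return lambda;
    }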
