⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 robustmath.java

📁 这是一个CRF(条件随机域)算法的实现,希望能对从事算法的有些帮助.
💻 JAVA
字号:
package iitb.CRF;import java.io.Serializable;import java.util.*;import cern.colt.function.*;import cern.colt.matrix.DoubleMatrix1D;import cern.colt.matrix.DoubleMatrix2D;public class RobustMath {    public static double LOG0 = -1*Double.MAX_VALUE;    public static double LOG2 = 0.69314718055;    static final double MINUS_LOG_EPSILON = 30; //-1*Math.log(Double.MIN_VALUE);        static class LogExpCache {        static int CUT_OFF = 6;        static int NUM_FINE = 10000;        static int NUM_COARSE = 1000;        static boolean useCache = true;        static double vals[] = new double[CUT_OFF*NUM_FINE+((int)MINUS_LOG_EPSILON-CUT_OFF)*NUM_COARSE+1];        static {            for(int i = vals.length-1; i >= 0; vals[i--]=-1);        }        static double lookupAdd(double val) {            if (!useCache)                return Math.log(Math.exp(-1*val) + 1.0);            int index = 0;            //assert ((val < MINUS_LOG_EPSILON) && (val > 0));            if (val < CUT_OFF) {                index = (int)Math.rint(val*NUM_FINE);            } else {                index = NUM_FINE*CUT_OFF + (int)Math.rint((val-CUT_OFF)*NUM_COARSE);            }            if (vals[index] < 0)                 vals[index] = Math.log(Math.exp(-1*val) + 1.0);            return vals[index];        }    };    public static double logSumExp(double v1, double v2) {        if (Math.abs(v1-v2) < Double.MIN_VALUE)            return v1 + LOG2;        double vmin = Math.min(v1,v2);        double vmax = Math.max(v1,v2);        if ( vmax > vmin + MINUS_LOG_EPSILON ) {            return vmax;        } else {            return vmax + LogExpCache.lookupAdd(vmax-vmin);            /*            double retval = vmax + Math.log(Math.exp(vmin-vmax) + 1.0);            //System.out.println((vmax-vmin) + " " + (retval-vmax));            return retval;            */        }    }    static class LogSumExp implements DoubleDoubleFunction {        public double apply(double v1, double v2) {            return logSumExp(v1,v2);        }    };    public static LogSumExp logSumExpFunc = new LogSumExp();    static void addNoDups(TreeSet vec, double v) {        Double val = new Double(v);        if (!vec.add(val)) {            vec.remove(val);            addNoDups(vec, val.doubleValue()+LOG2);        }    }    public static double logSumExp(TreeSet logProbVector) {        while ( logProbVector.size() > 1 ) {            double lp0 = ((Double)logProbVector.first()).doubleValue();            logProbVector.remove(logProbVector.first());            double lp1 = ((Double)logProbVector.first()).doubleValue();            logProbVector.remove(logProbVector.first());            addNoDups(logProbVector,logSumExp(lp0,lp1));        }        if (logProbVector.size() > 0)            return ((Double)logProbVector.first()).doubleValue();        return RobustMath.LOG0;    }        // matrix stuff for the older version..    static double logSumExp(DoubleMatrix1D logProb) {        TreeSet logProbVector = new TreeSet();        for ( int lpx = 0; lpx < logProb.size(); lpx++ )            if (logProb.getQuick(lpx) != RobustMath.LOG0)                addNoDups(logProbVector,logProb.getQuick(lpx));        return logSumExp(logProbVector);    }    static void logSumExp(DoubleMatrix1D v1, DoubleMatrix1D v2) {        for (int i = 0; i < v1.size(); i++) {            v1.set(i,logSumExp(v1.get(i), v2.get(i)));        }    }    static class LogMult implements IntIntDoubleFunction {        DoubleMatrix2D M;        DoubleMatrix1D z;        double lalpha;        boolean transposeA;        DoubleMatrix1D y;        public double apply(int i, int j, double val) {            int r = i;            int c = j;            if (transposeA) {                r = j;                c = i;            }            z.set(r, RobustMath.logSumExp(z.get(r), M.get(i,j)+y.get(c)+lalpha));            return val;        }    };    static LogMult logMult = new LogMult();    public static DoubleMatrix1D logMult(DoubleMatrix2D M, DoubleMatrix1D y, DoubleMatrix1D z, double alpha, double beta, boolean transposeA) {        // z = alpha * A * y + beta*z        double lalpha = 0;        if (alpha != 1)            lalpha = Math.log(alpha);        if (beta != 0) {            if (beta != 1) {                double lbeta = Math.log(beta);                for (int i = 0; i < z.size(); z.set(i,z.get(i)+lbeta),i++);            }        } else {            z.assign(RobustMath.LOG0);        }        // in log domain this becomes:         logMult.M = M;        logMult.z = z;        logMult.lalpha = lalpha;        logMult.transposeA = transposeA;        logMult.y = y;        M.forEachNonZero(logMult);        return z;    }    static DoubleMatrix1D logMult(DoubleMatrix2D M, DoubleMatrix1D y, DoubleMatrix1D z, double alpha, double beta, boolean transposeA, EdgeGenerator edgeGen) {        // z = alpha * A * y + beta*z        // in log domain this becomes:                 double lalpha = 0;        if (alpha != 1)            lalpha = Math.log(alpha);        if (beta != 0) {            if (beta != 1) {                for (int i = 0; i < z.size(); z.set(i,z.get(i)+Math.log(beta)),i++);            }        } else {            z.assign(LOG0);        }        for (int j = 0; j < M.columns(); j++) {            for (int i = edgeGen.first(j); i < M.rows(); i = edgeGen.next(j,i)) {                int r = i;                int c = j;                if (transposeA) {                    r = j;                    c = i;                }                z.setQuick(r, logSumExp(z.getQuick(r), M.getQuick(i,j)+y.get(c)+lalpha));            }        }        return z;    }    static DoubleMatrix1D Mult(DoubleMatrix2D M, DoubleMatrix1D y, DoubleMatrix1D z, double alpha, double beta, boolean transposeA, EdgeGenerator edgeGen) {        // z = alpha * A * y + beta*z        for (int i = 0; i < z.size(); z.set(i,z.get(i)*beta),i++);        for (int j = 0; j < M.columns(); j++) {            for (int i = edgeGen.first(j); i < M.rows(); i = edgeGen.next(j,i)) {                int r = i;                int c = j;                if (transposeA) {                    r = j;                    c = i;                }                z.set(r, z.getQuick(r) + M.getQuick(i,j)*y.getQuick(c)*alpha);            }        }        return z;    }            public static void main(String args[]) {//      double vals[] = new double[]{10.172079, 7.452882, 2.429751, 7.452882, 10.818797, 8.573773, 19.215824};        /*double vals[] = new double[]{2.883626, 1.670196, 0.553112, 1.670196, -0.935964, 1.864568, 2.064754};        TreeSet vec = new TreeSet();        double trueSum = 0;        for (int i = 0; i < vals.length; i++) {            addNoDups(vec,vals[i]);            trueSum += Math.exp(vals[i]);        }        double sum = logSumExp(vec);	*/        System.out.println(logSumExp(Double.parseDouble(args[0]), Double.parseDouble(args[1])));    }    /**     * @param d     * @return     */    public static double exp(double d) {        if (Double.isInfinite(d) || ((d < 0) && (Math.abs(d) > MINUS_LOG_EPSILON)))            return 0;        //if ((d > 0) && (d < Double.MIN_VALUE))        //    return 1;        //System.out.println(d + " " + Math.exp(d));        return Math.exp(d);    }    /**     * @param val     * @return     */    public static double log(float val) {        return (Math.abs(val-1) < Double.MIN_VALUE)?0:Math.log(val);    }};

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -