⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 logisticregressiontest.java

📁 一个自然语言处理的Java开源工具包。LingPipe目前已有很丰富的功能
💻 JAVA
字号:
package com.aliasi.test.unit.stats;import com.aliasi.stats.AnnealingSchedule;import com.aliasi.stats.LogisticRegression;import com.aliasi.stats.RegressionPrior;import com.aliasi.matrix.DenseVector;import com.aliasi.matrix.SparseFloatVector;import com.aliasi.matrix.Vector;import com.aliasi.test.unit.BaseTestCase;import com.aliasi.util.AbstractExternalizable;import java.io.IOException;public class LogisticRegressionTest extends BaseTestCase {    // WALLET example from:    // Paul David Allison.  Logistic Regression Using the SAS System: Theory and Application.    // p. 117.  http://books.google.com/books?id=AcHB61vd-1UC    // 0 = KEEP_BOTH, 1=KEEP_MONEY, 2=RETURN_BOTH    static final int[] WALLET_OUTCOME_VECTOR        = new int[] {        1,        1,        2,        2,        0,        2,        2,        2,        2,        2,        1,        2,        2,        2,        2,        2,        2,        1,        0,        1,        1,        2,        2,        2,        2,        1,        1,        0,        2,        2,        2,        2,        0,        2,        2,        2,        2,        2,        2,        2,        2,        2,        2,        2,        2,        2,        2,        2,        2,        2,        1,        2,        2,        2,        2,        2,        2,        1,        2,        2,        2,        2,        2,        0,        2,        2,        0,        2,        1,        0,        0,        2,        2,        1,        1,        1,        2,        2,        2,        2,        2,        2,        2,        2,        1,        2,        2,        1,        2,        2,        2,        2,        2,        2,        2,        0,        0,        1,        0,        1,        0,        1,        0,        2,        2,        1,        2,        0,        2,        1,        2,        2,        1,        2,        2,        0,        1,        1,        0,        0,        2,        2,        2,        2,        2,        2,        2,        2,        2,        1,        1,        2,        1,        2,        1,        2,        2,        0,        2,        2,        2,        2,        1,        2,        1,        2,        1,        2,        2,        2,        2,        1,        2,        2,        1,        2,        2,        1,        2,        1,        2,        0,        2,        1,        0,        1,        2,        1,        2,        1,        1,        0,        1,        1,        0,        1,        1,        2,        2,        1,        0,        1,        2,        1,        2,        0,        1,        2,        1,        2,        2,        2,        2,        2,        1, };    // INTERCEPT, MALE, BUSINESS, PUNISH, EXPLAIN    static final double[][] WALLET_DATA_MATRIX        = new double[][] {        { 1, 0, 0, 2, 0 },        { 1, 0, 0, 2, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 2, 0 },        { 1, 1, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 0, 0, 2, 1 },        { 1, 0, 1, 1, 1 },        { 1, 1, 1, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 1, 0 },        { 1, 0, 0, 2, 1 },        { 1, 0, 0, 3, 0 },        { 1, 1, 1, 3, 0 },        { 1, 0, 0, 1, 1 },        { 1, 1, 0, 2, 0 },        { 1, 0, 0, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 1, 1, 1, 0 },        { 1, 1, 0, 1, 1 },        { 1, 0, 0, 1, 0 },        { 1, 1, 0, 3, 0 },        { 1, 1, 0, 2, 0 },        { 1, 1, 0, 2, 0 },        { 1, 1, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 1, 0, 2, 1 },        { 1, 1, 0, 1, 1 },        { 1, 0, 0, 2, 0 },        { 1, 1, 0, 1, 0 },        { 1, 1, 1, 2, 1 },        { 1, 0, 0, 2, 0 },        { 1, 0, 0, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 2, 0 },        { 1, 1, 1, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 1, 0 },        { 1, 0, 1, 3, 0 },        { 1, 1, 0, 2, 0 },        { 1, 0, 0, 2, 1 },        { 1, 0, 0, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 1, 0, 1, 0 },        { 1, 1, 1, 1, 0 },        { 1, 1, 0, 1, 0 },        { 1, 1, 1, 3, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 3, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 1, 0 },        { 1, 0, 0, 1, 1 },        { 1, 0, 1, 1, 1 },        { 1, 0, 0, 1, 0 },        { 1, 1, 0, 1, 1 },        { 1, 1, 0, 1, 0 },        { 1, 1, 0, 3, 1 },        { 1, 1, 0, 3, 1 },        { 1, 1, 1, 2, 1 },        { 1, 1, 0, 2, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 3, 0 },        { 1, 1, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 0, 0, 1, 0 },        { 1, 1, 1, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 1, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 2, 1 },        { 1, 1, 1, 1, 0 },        { 1, 1, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 1, 1, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 1, 1, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 1, 1, 1 },        { 1, 0, 0, 2, 1 },        { 1, 1, 1, 1, 1 },        { 1, 1, 0, 2, 0 },        { 1, 1, 0, 1, 1 },        { 1, 1, 1, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 1, 1, 3, 0 },        { 1, 1, 1, 1, 1 },        { 1, 1, 1, 3, 1 },        { 1, 0, 0, 3, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 1, 3, 0 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 1, 0 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 1, 0, 3, 0 },        { 1, 1, 0, 1, 1 },        { 1, 0, 1, 1, 1 },        { 1, 0, 0, 3, 0 },        { 1, 0, 1, 2, 0 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 1, 1, 1, 1 },        { 1, 0, 0, 1, 0 },        { 1, 0, 0, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 1, 1, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 1, 1, 1, 1 },        { 1, 0, 0, 1, 0 },        { 1, 0, 0, 1, 1 },        { 1, 1, 1, 2, 0 },        { 1, 1, 0, 1, 0 },        { 1, 1, 0, 1, 0 },        { 1, 0, 0, 2, 1 },        { 1, 1, 1, 1, 1 },        { 1, 0, 0, 3, 0 },        { 1, 0, 0, 1, 1 },        { 1, 1, 1, 1, 1 },        { 1, 0, 0, 1, 0 },        { 1, 0, 1, 1, 0 },        { 1, 0, 0, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 1, 0, 1, 0 },        { 1, 0, 0, 1, 0 },        { 1, 1, 0, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 1, 0, 2, 0 },        { 1, 1, 1, 2, 1 },        { 1, 1, 0, 1, 1 },        { 1, 1, 0, 1, 0 },        { 1, 0, 0, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 1, 1, 1, 0 },        { 1, 0, 0, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 1, 0, 1, 0 },        { 1, 0, 1, 2, 1 },        { 1, 1, 1, 2, 1 },        { 1, 0, 0, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 1, 1, 3, 1 },        { 1, 1, 0, 1, 1 },        { 1, 0, 0, 1, 0 },        { 1, 1, 0, 3, 0 },        { 1, 0, 0, 1, 1 },        { 1, 1, 1, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 0, 0, 2, 0 },        { 1, 1, 0, 1, 1 },        { 1, 1, 1, 2, 0 },        { 1, 1, 0, 2, 0 },        { 1, 0, 0, 1, 1 },        { 1, 0, 1, 3, 0 },        { 1, 1, 0, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 0, 0, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 1, 0, 2, 0 },        { 1, 0, 1, 2, 1 },        { 1, 0, 0, 2, 0 },        { 1, 1, 1, 1, 1 },        { 1, 0, 1, 2, 1 },        { 1, 0, 0, 3, 0 },        { 1, 1, 1, 1, 0 },        { 1, 0, 0, 3, 1 },        { 1, 1, 0, 2, 1 },        { 1, 0, 0, 1, 1 },        { 1, 1, 0, 3, 1 },        { 1, 0, 0, 1, 1 },        { 1, 1, 1, 1, 1 },        { 1, 1, 0, 1, 1 },        { 1, 1, 0, 1, 1 }    };    static final double[][] WALLET_EXPECTED_FEATURES        = new double[][] {        // INTERCEPT, MALE, BUSINESS, PUNISH, EXPLAIN        { -3.4712, 1.2673, 1.1804, 1.0817, -1.6006 }, // CAT 0=KEEP_BOTH        { -1.2917, 1.1699, 0.4179, 0.1957, -0.8040 }, // CAT 1=KEEP_MONEY        { 0.0, 0.0, 0.0, 0.0, 0.0 }                   // CAT 2=RETURN_BOTH    };    public void testClass() {        Vector[] weightVectors = new Vector[2];        weightVectors[0] = new DenseVector(new double[] { 1, 2, 3 });        weightVectors[1] = new DenseVector(new double[] { -2, 1, -1 });        LogisticRegression regression = new LogisticRegression(weightVectors);        Vector testCase = new DenseVector(new double[] { 1, -1, 2 });        double prod1 = (1*1) + (-1 * 2) + (2 * 3);        double prod2 = (1 * -2) + (-1 * 1) + (2 * -1);        double prod3 = 0;        double prop1 = Math.exp(prod1);        double prop2 = Math.exp(prod2);        double prop3 = Math.exp(prod3);        assertEquals(1.0,prop3,0.0001);        double p1 = prop1 / (prop1 + prop2 + prop3);        double p2 = prop2 / (prop1 + prop2 + prop3);        double p3 = prop3 / (prop1 + prop2 + prop3);        double[] expected = new double[] { p1, p2, p3};        double[] estimated = regression.classify(testCase);        assertEquals(expected.length, estimated.length);        for (int i = 0; i < expected.length; ++i)            assertEquals(expected[i],estimated[i],0.0000001);    }    static Vector[] sparseCopy(Vector[] matrix) {        Vector[] result = new Vector[matrix.length];        for (int i = 0; i < matrix.length; ++i)            result[i] = sparseCopy(matrix[i]);        return result;    }    static Vector sparseCopy(Vector v) {        int[] dims = new int[v.numDimensions()];        float[] vals = new float[v.numDimensions()];        for (int i = 0; i < dims.length; ++i) {            dims[i] = i;            vals[i] = (float) v.value(i);        }        return new SparseFloatVector(dims,vals,v.numDimensions());    }    public void testEstimation() throws IOException, ClassNotFoundException {        Vector[] data_matrix = new Vector[WALLET_DATA_MATRIX.length];        for (int i = 0; i < data_matrix.length; ++i)            data_matrix[i] = new DenseVector(WALLET_DATA_MATRIX[i]);        Vector[] sparse_data_matrix = sparseCopy(data_matrix);        // assertCorrectRegression(data_matrix);        assertCorrectRegression(sparse_data_matrix);    }    void assertCorrectRegression(Vector[] data_matrix) throws IOException, ClassNotFoundException {        LogisticRegression regression            = LogisticRegression.estimate(data_matrix,                                          WALLET_OUTCOME_VECTOR,                                          RegressionPrior.noninformative(),                                          AnnealingSchedule.inverse(0.05,100),                                          0.00000001, // min improve                                          10, // min epochs                                          500000, // max epochs                                          null);  // no print feedback        Vector[] vs = regression.weightVectors();        for (int i = 0; i < vs.length; ++i) {            // System.out.println("residuals " + i);            for (int j = 0; j < vs[i].numDimensions(); ++j) {                assertEquals(WALLET_EXPECTED_FEATURES[i][j],vs[i].value(j),0.1);                // System.out.printf("%8.5f ", vs[i].value(j) - WALLET_EXPECTED_FEATURES[i][j]);            }            // System.out.println();        }        LogisticRegression regression2            = (LogisticRegression) AbstractExternalizable.compile(regression);        assertEquals(regression.numOutcomes(),                     regression2.numOutcomes());        assertEquals(regression.numInputDimensions(),                     regression.numInputDimensions());        Vector[] vs1 = regression.weightVectors();        Vector[] vs2 = regression2.weightVectors();        assertEquals(vs1.length,vs2.length);        assertEquals(vs1.length,vs2.length);        for (int i = 0; i < vs1.length; ++i)            assertEquals(vs1[i],vs2[i]);    }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -