LogisticRegression.java
 * Setting this to a low value will lead to slow but accurate
 * coefficient estimates.
 *
 * <p>Finally, the search parameters include an instance of
 * {@link AnnealingSchedule} which implements the <code>learningRate(epoch)</code>
 * method.  See that method for concrete implementations, including
 * a standard inverse epoch annealing and exponential decay annealing.
 *
 * <h4>Serialization and Compilation</h4>
 *
 * For convenience, this class implements both the {@link Serializable}
 * and {@link Compilable} interfaces.  Serializing or compiling
 * a logistic regression model has the same effect.  The model
 * read back in from its serialized state will be an instance of
 * this class, {@link LogisticRegression}.
 *
 * <h4>References</h4>
 *
 * Logistic regression is discussed in most machine learning and
 * statistics textbooks.  These three machine learning textbooks all
 * introduce some form of stochastic gradient descent and logistic
 * regression (often not together, and often under different names as
 * listed in the AKA section above):
 *
 * <ul>
 *
 * <li>MacKay, David. 2003. <a href="http://www.inference.phy.cam.ac.uk/mackay/itprnn/book.html"><i>Information Theory, Inference, and Learning Algorithms</i></a> (includes free download links).
 * Cambridge University Press.</li>
 *
 * <li>Hastie, Trevor, Robert Tibshirani, and Jerome Friedman. 2001.
 * <i><a href="http://www-stat.stanford.edu/~tibs/ElemStatLearn/">Elements of Statistical Learning</a></i>.
 * Springer.</li>
 *
 * <li>Bishop, Christopher M. 2006. <a href="http://research.microsoft.com/~cmbishop/PRML/">Pattern Recognition and Machine Learning</a>.
 * Springer.</li>
 *
 * </ul>
 *
 * An introduction to traditional statistical modeling with logistic
 * regression may be found in:
 *
 * <ul>
 * <li>Gelman, Andrew and Jennifer Hill. 2007. <a href="http://www.stat.columbia.edu/~gelman/arm/">Data Analysis Using Regression and Multilevel/Hierarchical Models</a>. Cambridge University Press.</li>
 * </ul>
 *
 * For a discussion of text classification using regression that evaluates
 * with respect to support vector machines (SVMs) and considers
 * informative Laplace and Gaussian priors varying by dimension (which
 * this class supports), see:
 *
 * <ul>
 * <li>Genkin, Alexander, David D. Lewis, and David Madigan. 2004.
 * <a href="http://www.stat.columbia.edu/~gelman/stuff_for_blog/madigan.pdf">Large-Scale Bayesian Logistic Regression for Text Categorization</a>.
 * Rutgers University Technical Report.
 * (<a href="http://stat.rutgers.edu/~madigan/PAPERS/techno-06-09-18.pdf">alternate download</a>).
 * </li>
 * </ul>
 *
 * @author  Bob Carpenter
 * @version 3.6
 * @since   LingPipe3.5
 */
public class LogisticRegression implements Compilable, Serializable {

    private final Vector[] mWeightVectors;

    /**
     * Construct a multinomial logistic regression model with
     * the specified weight vectors.  With <code>k-1</code>
     * weight vectors, the result is a multinomial classifier
     * with <code>k</code> outcomes.
     *
     * <p>See the class definition above for more information on
     * logistic regression.
     *
     * @param weightVectors Weight vectors defining this regression
     * model.
     * @throws IllegalArgumentException If the array of weight vectors
     * does not have at least one element or if there are two weight
     * vectors with different numbers of dimensions.
     */
    public LogisticRegression(Vector[] weightVectors) {
        if (weightVectors.length < 1) {
            String msg = "Require at least one weight vector.";
            throw new IllegalArgumentException(msg);
        }
        int numDimensions = weightVectors[0].numDimensions();
        for (int k = 1; k < weightVectors.length; ++k) {
            if (numDimensions != weightVectors[k].numDimensions()) {
                String msg = "All weight vectors must be same dimensionality."
                    + " Found weightVectors[0].numDimensions()=" + numDimensions
                    + " weightVectors[" + k + "]=" + weightVectors[k].numDimensions();
                throw new IllegalArgumentException(msg);
            }
        }
        mWeightVectors = weightVectors;
    }

    /**
     * Construct a binomial logistic regression model with the
     * specified parameter vector.  See the class definition above
     * for more information on logistic regression.
     *
     * @param weightVector The weights of features defining this
     * model.
     */
    public LogisticRegression(Vector weightVector) {
        mWeightVectors = new Vector[] { weightVector };
    }

    /**
     * Returns the dimensionality of inputs for this logistic
     * regression model.
     *
     * @return The number of dimensions for this model.
     */
    public int numInputDimensions() {
        return mWeightVectors[0].numDimensions();
    }

    /**
     * Returns the number of outcomes for this logistic regression
     * model.
     *
     * @return The number of outcomes for this model.
     */
    public int numOutcomes() {
        return mWeightVectors.length + 1;
    }

    /**
     * Returns an array of views of the weight vectors used for this
     * regression model.  The returned weight vectors are immutable
     * views of the underlying vectors used by this model, so will
     * change if the vectors making up this model change.
     *
     * @return An array of views of the weight vectors for this model.
     */
    public Vector[] weightVectors() {
        Vector[] immutables = new Vector[mWeightVectors.length];
        for (int i = 0; i < immutables.length; ++i)
            immutables[i] = Matrices.unmodifiableVector(mWeightVectors[i]);
        return immutables;
    }

    /**
     * Returns an array of conditional probabilities indexed by
     * outcomes for the specified input vector.  The resulting array
     * has a value for index <code>i</code> that is equal to the
     * probability of the outcome <code>i</code> for the specified
     * input.  The sum of the returned values will be 1.0 (modulo
     * arithmetic precision).
     *
     * <p>See the class definition above for more information on
     * how the conditional probabilities are computed.
     *
     * @param x The input vector.
     * @return The array of conditional probabilities of outcomes.
     * @throws IllegalArgumentException If the specified vector is not
     * the same dimensionality as this logistic regression instance.
     */
    public double[] classify(Vector x) {
        if (numInputDimensions() != x.numDimensions()) {
            String msg = "Vector and classifier must be of same dimensionality."
                + " Regression model this.numInputDimensions()=" + numInputDimensions()
                + " Vector x.numDimensions()=" + x.numDimensions();
            throw new IllegalArgumentException(msg);
        }
        double[] ysHat = new double[numOutcomes()];
        int numOutcomesMinus1 = numOutcomes() - 1;
        double sum = 1.0; // for category k-1, which has no weight vector
        for (int k = 0; k < numOutcomesMinus1; ++k) {
            ysHat[k] = Math.exp(x.dotProduct(mWeightVectors[k]));
            sum += ysHat[k];
        }
        for (int k = 0; k < numOutcomesMinus1; ++k)
            ysHat[k] /= sum;
        ysHat[numOutcomesMinus1] = 1.0 / sum;
        return ysHat;
    }

    private Object writeReplace() {
        return new Externalizer(this);
    }

    /**
     * Compiles this model to the specified object output.  The
     * compiled model, when read back in, will remain an instance of
     * this class, {@link LogisticRegression}.
     *
     * <p>Compilation does the same thing as serialization.
     *
     * @param out Object output to which this model is compiled.
     * @throws IOException If there is an underlying I/O error during
     * serialization.
     */
    public void compileTo(ObjectOutput out) throws IOException {
        out.writeObject(new Externalizer(this));
    }

    /**
     * Estimate a logistic regression model from the specified input
     * data using the specified Gaussian prior, initial learning rate
     * and annealing rate, the minimum improvement per epoch and the
     * minimum and maximum number of estimation epochs.
     *
     * <p>See the class documentation above for more information on
     * logistic regression and the stochastic gradient descent
     * algorithm used to implement this method.
     *
     * @param xs Input vectors indexed by training case.
     * @param cs Output categories indexed by training case.
     * @param prior The prior to be used for regression.
     * @param annealingSchedule Class to compute the learning rate for each epoch.
     * @param minImprovement The minimum relative improvement in
     * log likelihood for the corpus to continue to another epoch.
     * @param minEpochs Minimum number of epochs.
     * @param maxEpochs Maximum number of epochs.
     * @param progressWriter Writer to which progress reports are written,
     * or null if no progress reports are needed.
     * @throws IllegalArgumentException If the set of input vectors
     * does not contain at least one instance, if the number of output
     * categories is not the same as the number of input vectors, if two
     * input vectors have different dimensions, or if the prior has a
     * different number of dimensions than the instances.
     */
    public static LogisticRegression estimate(Vector[] xs,
                                              int[] cs,
                                              RegressionPrior prior,
                                              AnnealingSchedule annealingSchedule,
                                              double minImprovement,
                                              int minEpochs,
                                              int maxEpochs,
                                              PrintWriter progressWriter) {
        if (xs.length < 1) {
            String msg = "Require at least one training instance.";
            throw new IllegalArgumentException(msg);
        }
        if (xs.length != cs.length) {
            String msg = "Require same number of training instances as outcomes."
                + " Found xs.length=" + xs.length
                + " cs.length=" + cs.length;
            throw new IllegalArgumentException(msg);
        }
        int numTrainingInstances = xs.length;
        int numOutcomesMinus1 = max(cs);
        int numOutcomes = numOutcomesMinus1 + 1;
        int numDimensions = xs[0].numDimensions();
        prior.verifyNumberOfDimensions(numDimensions);
        for (int i = 1; i < xs.length; ++i) {
            if (xs[i].numDimensions() != numDimensions) {
                String msg = "Number of dimensions must match for all input vectors."
                    + " Found xs[0].numDimensions()=" + numDimensions
                    + " xs[" + i + "].numDimensions()=" + xs[i].numDimensions();
                throw new IllegalArgumentException(msg);
            }
        }
        DenseVector[] weightVectors = new DenseVector[numOutcomesMinus1];
        for (int k = 0; k < numOutcomesMinus1; ++k)
            weightVectors[k] = new DenseVector(numDimensions); // values all 0.0
        boolean hasSparseInputs = isSparse(xs);
        boolean hasPrior
            = (prior != null)
            && !(prior instanceof RegressionPrior.NoninformativeRegressionPrior);
        if (progressWriter != null) {
            progressWriter.println("Logistic Regression Progress Report");
            progressWriter.println("Number of dimensions=" + numDimensions);
            progressWriter.println("Number of Outcomes=" + numOutcomes);
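The listing above cuts off inside the body of estimate(...). As a usage sketch that is not part of the original source, the snippet below shows how the two public entry points in the listing might be called: direct construction plus classify(Vector) for a binomial model, and LogisticRegression.estimate(...) for fitting by stochastic gradient descent. The package names (com.aliasi.stats, com.aliasi.matrix) and the DenseVector(double[]) constructor are assumptions about the LingPipe API that are not visible in this listing, and the prior and annealing schedule are taken as parameters rather than constructed, since their factory methods are not shown here.

import java.io.PrintWriter;

import com.aliasi.matrix.DenseVector;   // assumed package; DenseVector appears in the listing
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;   // assumed package
import com.aliasi.stats.LogisticRegression;
import com.aliasi.stats.RegressionPrior;

public class LogisticRegressionUsageSketch {

    // Binomial model from a single weight vector w: classify(x) returns
    // { exp(w.x)/(1+exp(w.x)), 1/(1+exp(w.x)) }, which sums to 1.0.
    static double[] classifyExample() {
        Vector w = new DenseVector(new double[] { 0.5, -1.0, 2.0 }); // assumed constructor
        LogisticRegression model = new LogisticRegression(w);
        Vector x = new DenseVector(new double[] { 1.0, 0.0, 1.0 });
        return model.classify(x);
    }

    // Fitting by stochastic gradient descent, with the prior and annealing
    // schedule supplied by the caller; the numeric arguments are illustrative only.
    static LogisticRegression estimateExample(Vector[] xs,
                                              int[] cs,
                                              RegressionPrior prior,
                                              AnnealingSchedule annealingSchedule) {
        return LogisticRegression.estimate(xs, cs,
                                           prior,
                                           annealingSchedule,
                                           0.000001, // minImprovement per epoch
                                           10,       // minEpochs
                                           1000,     // maxEpochs
                                           new PrintWriter(System.out, true)); // or null for no progress
    }
}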
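The "Serialization and Compilation" section of the class Javadoc states that compiling and serializing have the same effect and that the object read back in is again a LogisticRegression. The following minimal round-trip sketch of that claim writes to an in-memory byte array; it is not from the original source, and the com.aliasi.stats package name is an assumption.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

import com.aliasi.stats.LogisticRegression;  // assumed package name

public class CompileRoundTripSketch {

    // Compile a model to bytes and read it back; per the Javadoc above,
    // the deserialized result is again a LogisticRegression instance.
    static LogisticRegression compileAndRead(LogisticRegression model)
            throws IOException, ClassNotFoundException {
        ByteArrayOutputStream bytesOut = new ByteArrayOutputStream();
        ObjectOutputStream objOut = new ObjectOutputStream(bytesOut);
        model.compileTo(objOut);   // per the Javadoc, equivalent to objOut.writeObject(model)
        objOut.close();
        ObjectInputStream objIn
            = new ObjectInputStream(new ByteArrayInputStream(bytesOut.toByteArray()));
        return (LogisticRegression) objIn.readObject();
    }
}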