regression.java

来自「JfreeChart 常用图表例子」· Java 代码 · 共 229 行

JAVA
229
字号
/* ===========================================================
 * JFreeChart : a free chart library for the Java(tm) platform
 * ===========================================================
 *
 * (C) Copyright 2000-2005, by Object Refinery Limited and Contributors.
 *
 * Project Info:  http://www.jfree.org/jfreechart/index.html
 *
 * This library is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation; either version 2.1 of the License, or
 * (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
 * License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA.
 *
 * [Java is a trademark or registered trademark of Sun Microsystems, Inc.
 * in the United States and other countries.]
 *
 * ---------------
 * Regression.java
 * ---------------
 * (C) Copyright 2002-2005, by Object Refinery Limited.
 *
 * Original Author:  David Gilbert (for Object Refinery Limited);
 * Contributor(s):   -;
 *
 * $Id: Regression.java,v 1.3 2005/03/24 12:49:37 mungady Exp $
 *
 * Changes
 * -------
 * 30-Sep-2002 : Version 1 (DG);
 * 18-Aug-2003 : Added 'abstract' (DG);
 * 15-Jul-2004 : Switched getX() with getXValue() and getY() with
 *               getYValue() (DG);
 *
 */

package org.jfree.data.statistics;

import org.jfree.data.xy.XYDataset;

/**
 * A utility class for fitting regression curves to data.  All methods are
 * static; the class is declared abstract so that it is not instantiated
 * directly.
 */
public abstract class Regression {

    /**
     * Returns the parameters 'a' and 'b' for an equation y = a + bx, fitted to
     * the data using ordinary least squares regression.  The result is
     * returned as a double[], where result[0] --> a, and result[1] --> b.
     * <p>
     * If every x-value is the same the slope is undefined; no exception is
     * thrown in that case and the returned values are not meaningful
     * (typically NaN or infinite).
     *
     * @param data  the data, one row per observation with the x-value at
     *              index 0 and the y-value at index 1.
     *
     * @return The parameters.
     *
     * @throws IllegalArgumentException if there are fewer than two rows.
     */
    public static double[] getOLSRegression(double[][] data) {
        int n = data.length;
        if (n < 2) {
            throw new IllegalArgumentException("Not enough data.");
        }
        double sumX = 0;
        double sumY = 0;
        double sumXX = 0;
        double sumXY = 0;
        for (int i = 0; i < n; i++) {
            double x = data[i][0];
            double y = data[i][1];
            sumX += x;
            sumY += y;
            sumXX += x * x;
            sumXY += x * y;
        }
        return fitLine(n, sumX, sumY, sumXX, sumXY);
    }

    /**
     * Returns the parameters 'a' and 'b' for an equation y = a + bx, fitted to
     * the data using ordinary least squares regression. The result is returned
     * as a double[], where result[0] --> a, and result[1] --> b.
     * <p>
     * If every x-value is the same the slope is undefined; no exception is
     * thrown in that case and the returned values are not meaningful
     * (typically NaN or infinite).
     *
     * @param data  the data.
     * @param series  the series (zero-based index).
     *
     * @return The parameters.
     *
     * @throws IllegalArgumentException if the series has fewer than two
     *         items.
     */
    public static double[] getOLSRegression(XYDataset data, int series) {
        int n = data.getItemCount(series);
        if (n < 2) {
            throw new IllegalArgumentException("Not enough data.");
        }
        double sumX = 0;
        double sumY = 0;
        double sumXX = 0;
        double sumXY = 0;
        for (int i = 0; i < n; i++) {
            double x = data.getXValue(series, i);
            double y = data.getYValue(series, i);
            sumX += x;
            sumY += y;
            sumXX += x * x;
            sumXY += x * y;
        }
        return fitLine(n, sumX, sumY, sumXX, sumXY);
    }

    /**
     * Returns the parameters 'a' and 'b' for an equation y = ax^b, fitted to
     * the data using a power regression equation.  The result is returned as
     * an array, where result[0] --> a, and result[1] --> b.
     * <p>
     * The fit is a straight-line regression of log(y) on log(x), so the
     * x-values and y-values should all be positive: the logarithm of zero or
     * of a negative number is infinite or NaN and propagates into the result
     * without an exception being thrown.
     *
     * @param data  the data, one row per observation with the x-value at
     *              index 0 and the y-value at index 1.
     *
     * @return The parameters.
     *
     * @throws IllegalArgumentException if there are fewer than two rows.
     */
    public static double[] getPowerRegression(double[][] data) {
        int n = data.length;
        if (n < 2) {
            throw new IllegalArgumentException("Not enough data.");
        }
        double sumX = 0;
        double sumY = 0;
        double sumXX = 0;
        double sumXY = 0;
        for (int i = 0; i < n; i++) {
            double x = Math.log(data[i][0]);
            double y = Math.log(data[i][1]);
            sumX += x;
            sumY += y;
            sumXX += x * x;
            sumXY += x * y;
        }
        double[] result = fitLine(n, sumX, sumY, sumXX, sumXY);
        // The fitted intercept is log(a); undo the log transform.
        result[0] = Math.exp(result[0]);
        return result;
    }

    /**
     * Returns the parameters 'a' and 'b' for an equation y = ax^b, fitted to
     * the data using a power regression equation.  The result is returned as
     * an array, where result[0] --> a, and result[1] --> b.
     * <p>
     * The fit is a straight-line regression of log(y) on log(x), so the
     * x-values and y-values should all be positive: the logarithm of zero or
     * of a negative number is infinite or NaN and propagates into the result
     * without an exception being thrown.
     *
     * @param data  the data.
     * @param series  the series to fit the regression line against.
     *
     * @return The parameters.
     *
     * @throws IllegalArgumentException if the series has fewer than two
     *         items.
     */
    public static double[] getPowerRegression(XYDataset data, int series) {
        int n = data.getItemCount(series);
        if (n < 2) {
            throw new IllegalArgumentException("Not enough data.");
        }
        double sumX = 0;
        double sumY = 0;
        double sumXX = 0;
        double sumXY = 0;
        for (int i = 0; i < n; i++) {
            double x = Math.log(data.getXValue(series, i));
            double y = Math.log(data.getYValue(series, i));
            sumX += x;
            sumY += y;
            sumXX += x * x;
            sumXY += x * y;
        }
        double[] result = fitLine(n, sumX, sumY, sumXX, sumXY);
        // The fitted intercept is log(a); undo the log transform.
        result[0] = Math.exp(result[0]);
        return result;
    }

    /**
     * Computes the least squares straight line y = a + bx from the running
     * sums accumulated over the observations.
     *
     * @param n  the number of observations.
     * @param sumX  the sum of the x-values.
     * @param sumY  the sum of the y-values.
     * @param sumXX  the sum of the squared x-values.
     * @param sumXY  the sum of the products x * y.
     *
     * @return A new array holding the intercept 'a' at index 0 and the slope
     *         'b' at index 1.
     */
    private static double[] fitLine(int n, double sumX, double sumY,
                                    double sumXX, double sumXY) {
        double sxx = sumXX - (sumX * sumX) / n;
        double sxy = sumXY - (sumX * sumY) / n;
        double xbar = sumX / n;
        double ybar = sumY / n;
        double slope = sxy / sxx;
        double intercept = ybar - slope * xbar;
        return new double[] {intercept, slope};
    }

}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?