📄 LinearRegression.java
   * @exception Exception if any of the attributes are not numeric
   */
  private void calculateAttributeDeviations() throws Exception {

    m_StdDev = new double [m_TransformedData.numAttributes()];
    for (int i = 0; i < m_TransformedData.numAttributes(); i++) {
      m_StdDev[i] = Math.sqrt(m_TransformedData.variance(i));
    }
  }

  /**
   * Removes the attribute with the highest standardised coefficient
   * greater than 1.5 from the selected attributes.
   *
   * @param selectedAttributes an array of flags indicating which
   * attributes are included in the regression model
   * @param coefficients an array of coefficients for the regression
   * model
   * @return true if an attribute was removed
   */
  private boolean deselectColinearAttribute(boolean [] selectedAttributes,
                                            double [] coefficients) {

    double maxSC = 1.5;
    int maxAttr = -1, coeff = 0;
    for (int i = 0; i < selectedAttributes.length; i++) {
      if (selectedAttributes[i]) {
        double SC = Math.abs(coefficients[coeff] * m_StdDev[i]
                             / m_StdDev[m_ClassIndex]);
        if (SC > maxSC) {
          maxSC = SC;
          maxAttr = i;
        }
        coeff++;
      }
    }
    if (maxAttr >= 0) {
      selectedAttributes[maxAttr] = false;
      if (b_Debug) {
        System.out.println("Deselected colinear attribute:" + (maxAttr + 1)
                           + " with standardised coefficient: " + maxSC);
      }
      return true;
    }
    return false;
  }

  /**
   * Performs a greedy search for the best regression model using
   * Akaike's criterion.
   *
   * @exception Exception if regression can't be done
   */
  private void findBestModel() throws Exception {

    int numAttributes = m_TransformedData.numAttributes();
    int numInstances = m_TransformedData.numInstances();
    boolean [] selectedAttributes = new boolean[numAttributes];
    for (int i = 0; i < numAttributes; i++) {
      if (i != m_ClassIndex) {
        selectedAttributes[i] = true;
      }
    }
    if (b_Debug) {
      System.out.println((new Instances(m_TransformedData, 0)).toString());
    }

    // Perform a regression for the full model, and remove colinear attributes
    double [] coefficients;
    do {
      coefficients = doRegression(selectedAttributes);
    } while (deselectColinearAttribute(selectedAttributes, coefficients));

    double fullMSE = calculateMSE(selectedAttributes, coefficients);
    double akaike = (numInstances - numAttributes) + 2 * numAttributes;
    if (b_Debug) {
      System.out.println("Initial Akaike value: " + akaike);
    }

    boolean improved;
    int currentNumAttributes = numAttributes;
    switch (m_AttributeSelection) {

    case SELECTION_GREEDY:

      // Greedy attribute removal
      do {
        boolean [] currentSelected = (boolean []) selectedAttributes.clone();
        improved = false;
        currentNumAttributes--;

        for (int i = 0; i < numAttributes; i++) {
          if (currentSelected[i]) {

            // Calculate the akaike rating without this attribute
            currentSelected[i] = false;
            double [] currentCoeffs = doRegression(currentSelected);
            double currentMSE = calculateMSE(currentSelected, currentCoeffs);
            double currentAkaike = currentMSE / fullMSE
              * (numInstances - numAttributes) + 2 * currentNumAttributes;
            if (b_Debug) {
              System.out.println("(akaike: " + currentAkaike);
            }

            // If it is better than the current best
            if (currentAkaike < akaike) {
              if (b_Debug) {
                System.err.println("Removing attribute " + (i + 1)
                                   + " improved Akaike: " + currentAkaike);
              }
              improved = true;
              akaike = currentAkaike;
              System.arraycopy(currentSelected, 0, selectedAttributes, 0,
                               selectedAttributes.length);
              coefficients = currentCoeffs;
            }
            currentSelected[i] = true;
          }
        }
      } while (improved);
      break;

    case SELECTION_M5:

      // Step through the attributes removing the one with the smallest
      // standardised coefficient until no improvement in Akaike
      do {
        improved = false;
        currentNumAttributes--;

        // Find attribute with smallest SC
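        // (The standardised coefficient of an attribute is
        // |coefficient| * stdDev(attribute) / stdDev(class), i.e. the
        // magnitude of its coefficient measured on a scale that is
        // comparable across attributes.)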
        double minSC = 0;
        int minAttr = -1, coeff = 0;
        for (int i = 0; i < selectedAttributes.length; i++) {
          if (selectedAttributes[i]) {
            double SC = Math.abs(coefficients[coeff] * m_StdDev[i]
                                 / m_StdDev[m_ClassIndex]);
            if ((coeff == 0) || (SC < minSC)) {
              minSC = SC;
              minAttr = i;
            }
            coeff++;
          }
        }

        // See whether removing it improves the Akaike score
        if (minAttr >= 0) {
          selectedAttributes[minAttr] = false;
          double [] currentCoeffs = doRegression(selectedAttributes);
          double currentMSE = calculateMSE(selectedAttributes, currentCoeffs);
          double currentAkaike = currentMSE / fullMSE
            * (numInstances - numAttributes) + 2 * currentNumAttributes;
          if (b_Debug) {
            System.out.println("(akaike: " + currentAkaike);
          }

          // If it is better than the current best
          if (currentAkaike < akaike) {
            if (b_Debug) {
              System.err.println("Removing attribute " + (minAttr + 1)
                                 + " improved Akaike: " + currentAkaike);
            }
            improved = true;
            akaike = currentAkaike;
            coefficients = currentCoeffs;
          } else {
            selectedAttributes[minAttr] = true;
          }
        }
      } while (improved);
      break;

    case SELECTION_NONE:
      break;
    }

    m_SelectedAttributes = selectedAttributes;
    m_Coefficients = coefficients;
  }

  /**
   * Calculate the mean squared error of a regression model on the
   * training data
   *
   * @param selectedAttributes an array of flags indicating which
   * attributes are included in the regression model
   * @param coefficients an array of coefficients for the regression
   * model
   * @return the mean squared error on the training data
   * @exception Exception if there is a missing class value in the training
   * data
   */
  private double calculateMSE(boolean [] selectedAttributes,
                              double [] coefficients) throws Exception {

    double mse = 0;
    for (int i = 0; i < m_TransformedData.numInstances(); i++) {
      double prediction = regressionPrediction(m_TransformedData.instance(i),
                                               selectedAttributes,
                                               coefficients);
      double error = prediction - m_TransformedData.instance(i).classValue();
      mse += error * error;
    }
    return mse;
  }

  /**
   * Calculate the dependent value for a given instance for a
   * given regression model.
   *
   * @param transformedInstance the input instance
   * @param selectedAttributes an array of flags indicating which
   * attributes are included in the regression model
   * @param coefficients an array of coefficients for the regression
   * model
   * @return the regression value for the instance.
   * @exception Exception if the class attribute of the input instance
   * is not assigned
   */
  private double regressionPrediction(Instance transformedInstance,
                                      boolean [] selectedAttributes,
                                      double [] coefficients)
    throws Exception {

    double result = 0;
    int column = 0;
    for (int j = 0; j < transformedInstance.numAttributes(); j++) {
      if ((m_ClassIndex != j) && (selectedAttributes[j])) {
        result += coefficients[column] * transformedInstance.value(j);
        column++;
      }
    }
    result += coefficients[column];
    return result;
  }
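  // Note: the final entry of the coefficient array is the intercept term;
  // doRegression() below produces it by appending a constant column of 1.0
  // to the matrix of independent variables.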
  /**
   * Calculate a linear regression using the selected attributes
   *
   * @param selectedAttributes an array of booleans where each element
   * is true if the corresponding attribute should be included in the
   * regression.
   * @return an array of coefficients for the linear regression model.
   * @exception Exception if an error occurred during the regression.
   */
  private double [] doRegression(boolean [] selectedAttributes)
    throws Exception {

    if (b_Debug) {
      System.out.print("doRegression(");
      for (int i = 0; i < selectedAttributes.length; i++) {
        System.out.print(" " + selectedAttributes[i]);
      }
      System.out.println(" )");
    }
    int numAttributes = 1;
    for (int i = 0; i < selectedAttributes.length; i++) {
      if (selectedAttributes[i]) {
        numAttributes++;
      }
    }
    Matrix independent = new Matrix(m_TransformedData.numInstances(),
                                    numAttributes);
    Matrix dependent = new Matrix(m_TransformedData.numInstances(), 1);
    for (int i = 0; i < m_TransformedData.numInstances(); i++) {
      int column = 0;
      for (int j = 0; j < m_TransformedData.numAttributes(); j++) {
        if (j == m_ClassIndex) {
          dependent.setElement(i, 0,
                               m_TransformedData.instance(i).classValue());
        } else {
          if (selectedAttributes[j]) {
            independent.setElement(i, column,
                                   m_TransformedData.instance(i).value(j));
            column++;
          }
        }
      }
      independent.setElement(i, column, 1.0);
    }

    // Grab instance weights
    double [] weights = new double [m_TransformedData.numInstances()];
    for (int i = 0; i < weights.length; i++) {
      weights[i] = m_TransformedData.instance(i).weight();
    }

    // Compute coefficients
    return independent.regression(dependent, weights);
  }

  /**
   * Generates a linear regression function predictor.
   *
   * @param argv the options
   */
  public static void main(String argv[]) {

    try {
      System.out.println(Evaluation.evaluateModel(new LinearRegression(),
                                                  argv));
    } catch (Exception e) {
      System.out.println(e.getMessage());
    }
  }
}
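For reference, a minimal sketch of how this classifier can be run, assuming the surrounding Weka classes are on the classpath. The dataset file name below is only a placeholder, and in newer Weka releases the class lives in weka.classifiers.functions rather than weka.classifiers.

// Command-line use (mirrors what main() does via Evaluation.evaluateModel):
//   java weka.classifiers.LinearRegression -t dataset.arff

import java.io.FileReader;

import weka.classifiers.LinearRegression;
import weka.core.Instances;

public class LinearRegressionExample {

  public static void main(String [] args) throws Exception {

    // Load a dataset whose class attribute is numeric (placeholder file name)
    Instances data = new Instances(new FileReader("dataset.arff"));
    data.setClassIndex(data.numAttributes() - 1);

    // Build the regression model on the training data and print the function
    LinearRegression lr = new LinearRegression();
    lr.buildClassifier(data);
    System.out.println(lr);

    // Predict the class value of the first training instance
    System.out.println("Prediction for instance 0: "
                       + lr.classifyInstance(data.instance(0)));
  }
}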