linearsystem.java
来自「一个一元曲线多项式数值演示例子」· Java 代码 · 共 386 行
JAVA
386 行
package numbercruncher.matrix;
import numbercruncher.mathutils.*;
/**
 * Solve a system of linear equations Ax = b using LU decomposition
 * with scaled partial row pivoting, optionally followed by iterative
 * improvement of the computed solution.
 *
 * <p>The decomposition is computed lazily by the first operation that
 * needs it and cached in {@link #LU} together with {@link #permutation}.
 * Changing the matrix through the setters overridden in this class
 * invalidates the cached decomposition.
 */
public class LinearSystem
        extends SquareMatrix {

    /**
     * Relative convergence tolerance for iterative improvement
     * (taken from Epsilon.floatValue(); presumably float machine epsilon).
     */
    private static final float TOLERANCE = Epsilon.floatValue();

    /** max iters for improvement = twice # of significant digits */
    private static final int MAX_ITER;

    static {
        // Count how many factors of ten it takes to scale TOLERANCE up
        // to at least 1, i.e. the number of significant decimal digits.
        int i = 0;
        float t = TOLERANCE;
        while (t < 1) {
            ++i;
            t *= 10;
        }
        MAX_ITER = 2 * i;
    }

    /** decomposed matrix A = LU; null when not computed or invalidated */
    protected SquareMatrix LU;

    /**
     * row index permutation vector: permutation[r] is the row of LU
     * that plays the role of row r after pivoting
     */
    protected int permutation[];

    /** row exchange count of the current decomposition */
    protected int exchangeCount;

    /**
     * Constructor.
     * @param n the number of rows = the number of columns
     */
    public LinearSystem(int n) {
        super(n);
        reset();
    }

    /**
     * Constructor.
     * @param values the array of values
     */
    public LinearSystem(float values[][]) {
        super(values);
        // Mirror the other constructor: start with no cached decomposition.
        reset();
    }

    /**
     * Set the values of the matrix and invalidate the decomposition.
     * @param values the 2-d array of values
     */
    protected void set(float values[][]) {
        super.set(values);
        reset();
    }

    /**
     * Set the value of element [r,c] in the matrix and invalidate the
     * decomposition.
     * @param r the row index, 0..nRows-1
     * @param c the column index, 0..nCols-1
     * @param value the value
     * @throws MatrixException for invalid index
     */
    public void set(int r, int c, float value) throws MatrixException {
        super.set(r, c, value);
        reset();
    }

    /**
     * Set a row of this matrix from a row vector and invalidate the
     * decomposition.
     * @param rv the row vector
     * @param r the row index
     * @throws MatrixException for an invalid index or
     *                         an invalid vector size
     */
    public void setRow(RowVector rv, int r) throws MatrixException {
        super.setRow(rv, r);
        reset();
    }

    /**
     * Set a column of this matrix from a column vector and invalidate
     * the decomposition.
     * @param cv the column vector
     * @param c the column index
     * @throws MatrixException for an invalid index or
     *                         an invalid vector size
     */
    public void setColumn(ColumnVector cv, int c) throws MatrixException {
        super.setColumn(cv, c);
        reset();
    }

    /**
     * Reset. Invalidate LU and the permutation vector.
     */
    protected void reset() {
        LU = null;
        permutation = null;
        exchangeCount = 0;
    }

    /**
     * Solve Ax = b for x using the LU decomposition of A (computed by
     * Gaussian elimination with scaled partial pivoting) followed by
     * forward and back substitution.
     * @param b the right-hand-side column vector
     * @param improve true to iteratively improve the solution
     * @return the solution column vector
     * @throws MatrixException if b has the wrong size, A has a zero row
     *                         or is singular, or improvement fails to
     *                         converge
     */
    public ColumnVector solve(ColumnVector b, boolean improve)
            throws MatrixException {
        // Validate b's size.
        if (b.nRows != nRows) {
            throw new MatrixException(
                    MatrixException.INVALID_DIMENSIONS);
        }

        decompose();

        // Solve Ly = b for y by forward substitution,
        // then Ux = y for x by back substitution.
        ColumnVector y = forwardSubstitution(b);
        ColumnVector x = backSubstitution(y);

        // Improve x in place if requested.
        if (improve) {
            improve(b, x);
        }

        return x;
    }

    /**
     * Print the decomposed matrix LU, one row per line, in pivoted
     * (permuted) row order.
     * @param width the column width
     * @throws MatrixException if an error occurred
     */
    public void printDecomposed(int width) throws MatrixException {
        decompose();

        AlignRight ar = new AlignRight();

        for (int r = 0; r < nRows; ++r) {
            int pr = permutation[r]; // permuted row index

            ar.print("Row ", 0);
            ar.print(r + 1, 2);
            ar.print(":", 0);

            for (int c = 0; c < nCols; ++c) {
                ar.print(LU.values[pr][c], width);
            }
            ar.println();
        }
    }

    /**
     * Compute the upper triangular matrix U and lower triangular
     * matrix L such that A = L*U. Store L and U together in
     * matrix LU. Compute the permutation vector permutation of
     * the row indices.
     *
     * <p>If the decomposition fails, the cached state is reset so that a
     * later call recomputes (and re-reports the failure) instead of
     * silently using a half-built LU.
     *
     * @throws MatrixException for a zero row or
     *                         a singular matrix
     */
    protected void decompose() throws MatrixException {
        // Return if the decomposition is valid.
        if (LU != null) {
            return;
        }

        // Create a new LU matrix and permutation vector.
        // LU is initially just a copy of the values of this system.
        LU = new SquareMatrix(this.copyValues2D());
        permutation = new int[nRows];
        exchangeCount = 0;

        boolean succeeded = false;
        try {
            // Initially no row exchanges.
            for (int r = 0; r < nRows; ++r) {
                permutation[r] = r;
            }

            // Do forward elimination with scaled partial row pivoting.
            forwardElimination(computeRowScales());

            // Check bottom right element of the permuted matrix; the
            // other pivots were checked during forward elimination.
            if (LU.at(permutation[nRows - 1], nRows - 1) == 0) {
                throw new MatrixException(MatrixException.SINGULAR);
            }

            succeeded = true;
        }
        finally {
            // Never leave a partial decomposition cached.
            if (!succeeded) {
                reset();
            }
        }
    }

    /**
     * Compute the row-equilibration scale factors of LU: for each row,
     * 1 / (largest absolute element of that row).
     * @return the scaling vector, indexed by (unpermuted) row
     * @throws MatrixException for a zero row
     */
    private float[] computeRowScales() throws MatrixException {
        float scales[] = new float[nRows];

        for (int r = 0; r < nRows; ++r) {
            // Find the largest row element.
            float largestRowElmt = 0;
            for (int c = 0; c < nRows; ++c) {
                float elmt = Math.abs(LU.at(r, c));
                if (largestRowElmt < elmt) {
                    largestRowElmt = elmt;
                }
            }

            if (largestRowElmt == 0) {
                throw new MatrixException(MatrixException.ZERO_ROW);
            }
            scales[r] = 1 / largestRowElmt;
        }

        return scales;
    }

    /**
     * Do forward elimination with scaled partial row pivoting.
     * Row exchanges are recorded in the permutation vector (rows of LU
     * are never physically moved) and counted in exchangeCount; the
     * multipliers of L are stored below the diagonal of LU.
     * @param scales the scaling vector
     * @throws MatrixException for a singular matrix
     */
    private void forwardElimination(float scales[]) throws MatrixException {
        // Loop once per pivot row 0..nRows-2.
        for (int rPivot = 0; rPivot < nRows - 1; ++rPivot) {
            float largestScaledElmt = 0;
            int rLargest = rPivot;

            // Starting from the pivot row rPivot, look down
            // column rPivot to find the largest scaled element.
            for (int r = rPivot; r < nRows; ++r) {
                // Use the permuted row index.
                int pr = permutation[r];
                float absElmt = Math.abs(LU.at(pr, rPivot));
                float scaledElmt = absElmt * scales[pr];

                if (largestScaledElmt < scaledElmt) {
                    // The largest scaled element and
                    // its row index.
                    largestScaledElmt = scaledElmt;
                    rLargest = r;
                }
            }

            // Is the matrix singular?
            if (largestScaledElmt == 0) {
                throw new MatrixException(MatrixException.SINGULAR);
            }

            // Exchange rows if necessary to choose the best
            // pivot element by making its row the pivot row.
            if (rLargest != rPivot) {
                int temp = permutation[rPivot];
                permutation[rPivot] = permutation[rLargest];
                permutation[rLargest] = temp;
                ++exchangeCount;
            }

            // Use the permuted pivot row index.
            int prPivot = permutation[rPivot];
            float pivotElmt = LU.at(prPivot, rPivot);

            // Do the elimination below the pivot row.
            for (int r = rPivot + 1; r < nRows; ++r) {
                // Use the permuted row index.
                int pr = permutation[r];
                float multiple = LU.at(pr, rPivot) / pivotElmt;

                // Set the multiple into matrix L.
                LU.set(pr, rPivot, multiple);

                // Eliminate an unknown from matrix U.
                if (multiple != 0) {
                    for (int c = rPivot + 1; c < nCols; ++c) {
                        float elmt = LU.at(pr, c);

                        // Subtract the multiple of the pivot row.
                        elmt -= multiple * LU.at(prPivot, c);
                        LU.set(pr, c, elmt);
                    }
                }
            }
        }
    }

    /**
     * Solve Ly = b for y by forward substitution, applying the row
     * permutation to b.
     * @param b the column vector b
     * @return the column vector y
     * @throws MatrixException if an error occurred
     */
    private ColumnVector forwardSubstitution(ColumnVector b)
            throws MatrixException {
        ColumnVector y = new ColumnVector(nRows);

        // Do forward substitution (L has an implicit unit diagonal).
        for (int r = 0; r < nRows; ++r) {
            int pr = permutation[r]; // permuted row index
            float dot = 0;
            for (int c = 0; c < r; ++c) {
                dot += LU.at(pr, c) * y.at(c);
            }
            y.set(r, b.at(pr) - dot);
        }

        return y;
    }

    /**
     * Solve Ux = y for x by back substitution.
     * @param y the column vector y
     * @return the solution column vector x
     * @throws MatrixException if an error occurred
     */
    private ColumnVector backSubstitution(ColumnVector y)
            throws MatrixException {
        ColumnVector x = new ColumnVector(nRows);

        // Do back substitution.
        for (int r = nRows - 1; r >= 0; --r) {
            int pr = permutation[r]; // permuted row index
            float dot = 0;
            for (int c = r + 1; c < nRows; ++c) {
                dot += LU.at(pr, c) * x.at(c);
            }
            x.set(r, (y.at(r) - dot) / LU.at(pr, r));
        }

        return x;
    }

    /**
     * Iteratively improve the solution x to machine accuracy.
     * @param b the right-hand side column vector
     * @param x the solution column vector; improved in place
     * @throws MatrixException if failed to converge
     */
    private void improve(ColumnVector b, ColumnVector x)
            throws MatrixException {
        // Find the largest x element.
        float largestX = 0;
        for (int r = 0; r < nRows; ++r) {
            float absX = Math.abs(x.values[r][0]);
            if (largestX < absX) {
                largestX = absX;
            }
        }

        // Is x already as good as possible?
        if (largestX == 0) {
            return;
        }

        ColumnVector residuals = new ColumnVector(nRows);

        // Iterate to improve x.
        for (int iter = 0; iter < MAX_ITER; ++iter) {

            // Compute residuals = b - Ax using the original matrix A.
            // Must use double precision!
            for (int r = 0; r < nRows; ++r) {
                double dot = 0;
                for (int c = 0; c < nRows; ++c) {
                    double elmt = at(r, c);
                    dot += elmt * x.at(c);     // dbl.prec. *
                }
                double value = b.at(r) - dot;  // dbl.prec. -
                residuals.set(r, (float) value);
            }

            // Solve Az = residuals for z (reuses the cached LU).
            ColumnVector z = solve(residuals, false);

            // Set x = x + z.
            // Find the largest difference.
            float largestDiff = 0;
            for (int r = 0; r < nRows; ++r) {
                float oldX = x.at(r);
                x.set(r, oldX + z.at(r));

                float diff = Math.abs(x.at(r) - oldX);
                if (largestDiff < diff) {
                    largestDiff = diff;
                }
            }

            // Is any further improvement possible?
            if (largestDiff < largestX * TOLERANCE) {
                return;
            }
        }

        // Failed to converge because A is nearly singular.
        throw new MatrixException(MatrixException.NO_CONVERGENCE);
    }
}
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?