📄 densematrix.java
字号:
// Copy the contents (tail of a constructor whose header lies outside this view)
        for (int i = 0; i < values.length; ++i) {
            // Every row must hold exactly numColumns entries
            if (values[i].length != numColumns)
                throw new IllegalArgumentException("Array cannot be jagged");
            for (int j = 0; j < values[i].length; ++j)
                set(i, j, values[i][j]);
        }
    }

    /**
     * Returns a new {@code DenseMatrix} constructed from this matrix.
     * NOTE(review): presumably the constructor deep-copies the data — confirm.
     */
    @Override
    public DenseMatrix copy() {
        return new DenseMatrix(this);
    }

    /**
     * Overwrites entries of this matrix with every entry produced by iterating
     * {@code A}. Positions that A's iterator does not visit are left unchanged.
     *
     * @param A source matrix
     */
    @Override
    void copy(Matrix A) {
        for (MatrixEntry e : A)
            set(e.row(), e.column(), e.get());
    }

    /**
     * Computes {@code C = alpha * A * B + C}, where A is this matrix, using BLAS
     * {@code dgemm} (beta = 1) when both B and C are dense. Otherwise it delegates
     * to the superclass.
     *
     * @return C
     */
    @Override
    public Matrix multAdd(double alpha, Matrix B, Matrix C) {
        if (!(B instanceof DenseMatrix) || !(C instanceof DenseMatrix))
            return super.multAdd(alpha, B, C);

        checkMultAdd(B, C);

        double[] Bd = ((DenseMatrix) B).getData(), Cd = ((DenseMatrix) C).getData();

        BLAS.getInstance().dgemm(Transpose.NoTranspose.netlib(),
                Transpose.NoTranspose.netlib(), C.numRows(), C.numColumns(),
                numColumns, alpha, data, Math.max(1, numRows), Bd,
                Math.max(1, B.numRows()), 1, Cd, Math.max(1, C.numRows()));

        return C;
    }

    /**
     * Computes {@code C = alpha * A^T * B + C} via BLAS {@code dgemm} (beta = 1)
     * when both B and C are dense. Otherwise it delegates to the superclass.
     *
     * @return C
     */
    @Override
    public Matrix transAmultAdd(double alpha, Matrix B, Matrix C) {
        if (!(B instanceof DenseMatrix) || !(C instanceof DenseMatrix))
            return super.transAmultAdd(alpha, B, C);

        checkTransAmultAdd(B, C);

        double[] Bd = ((DenseMatrix) B).getData(), Cd = ((DenseMatrix) C).getData();

        // Inner dimension of A^T * B is numRows of A
        BLAS.getInstance().dgemm(Transpose.Transpose.netlib(),
                Transpose.NoTranspose.netlib(), C.numRows(), C.numColumns(),
                numRows, alpha, data, Math.max(1, numRows), Bd,
                Math.max(1, B.numRows()), 1, Cd, Math.max(1, C.numRows()));

        return C;
    }

    /**
     * Computes {@code C = alpha * A * B^T + C} via BLAS {@code dgemm} (beta = 1)
     * when both B and C are dense. Otherwise it delegates to the superclass.
     *
     * @return C
     */
    @Override
    public Matrix transBmultAdd(double alpha, Matrix B, Matrix C) {
        if (!(B instanceof DenseMatrix) || !(C instanceof DenseMatrix))
            return super.transBmultAdd(alpha, B, C);

        checkTransBmultAdd(B, C);

        double[] Bd = ((DenseMatrix) B).getData(), Cd = ((DenseMatrix) C).getData();

        BLAS.getInstance().dgemm(Transpose.NoTranspose.netlib(),
                Transpose.Transpose.netlib(), C.numRows(), C.numColumns(),
                numColumns, alpha, data, Math.max(1, numRows), Bd,
                Math.max(1, B.numRows()), 1, Cd, Math.max(1, C.numRows()));

        return C;
    }

    /**
     * Computes {@code C = alpha * A^T * B^T + C} via BLAS {@code dgemm} (beta = 1)
     * when both B and C are dense. Otherwise it delegates to the superclass.
     *
     * @return C
     */
    @Override
    public Matrix transABmultAdd(double alpha, Matrix B, Matrix C) {
        if (!(B instanceof DenseMatrix) || !(C instanceof DenseMatrix))
            return super.transABmultAdd(alpha, B, C);

        checkTransABmultAdd(B, C);

        double[] Bd = ((DenseMatrix) B).getData(), Cd = ((DenseMatrix) C).getData();

        BLAS.getInstance().dgemm(Transpose.Transpose.netlib(),
                Transpose.Transpose.netlib(), C.numRows(), C.numColumns(),
                numRows, alpha, data, Math.max(1, numRows), Bd,
                Math.max(1, B.numRows()), 1, Cd, Math.max(1, C.numRows()));

        return C;
    }

    /**
     * Applies the rank-one update {@code A = alpha * x * y^T + A} in place via
     * BLAS {@code dger} when both vectors are dense. Otherwise it delegates to the
     * superclass.
     *
     * @return this matrix
     */
    @Override
    public Matrix rank1(double alpha, Vector x, Vector y) {
        if (!(x instanceof DenseVector) || !(y instanceof DenseVector))
            return super.rank1(alpha, x, y);

        checkRank1(x, y);

        double[] xd = ((DenseVector) x).getData(), yd = ((DenseVector) y).getData();

        BLAS.getInstance().dger(numRows, numColumns, alpha, xd, 1, yd, 1, data,
                Math.max(1, numRows));

        return this;
    }

    /**
     * Computes {@code y = alpha * A * x + y} via BLAS {@code dgemv} (beta = 1)
     * when both vectors are dense. Otherwise it delegates to the superclass.
     *
     * @return y
     */
    @Override
    public Vector multAdd(double alpha, Vector x, Vector y) {
        if (!(x instanceof DenseVector) || !(y instanceof DenseVector))
            return super.multAdd(alpha, x, y);

        checkMultAdd(x, y);

        double[] xd = ((DenseVector) x).getData(), yd = ((DenseVector) y).getData();

        BLAS.getInstance().dgemv(Transpose.NoTranspose.netlib(), numRows,
                numColumns, alpha, data, Math.max(numRows, 1), xd, 1, 1, yd, 1);

        return y;
    }

    /**
     * Computes {@code y = alpha * A^T * x + y} via BLAS {@code dgemv} (beta = 1)
     * when both vectors are dense. Otherwise it delegates to the superclass.
     *
     * @return y
     */
    @Override
    public Vector transMultAdd(double alpha, Vector x, Vector y) {
        if (!(x instanceof DenseVector) || !(y instanceof DenseVector))
            return super.transMultAdd(alpha, x, y);

        checkTransMultAdd(x, y);

        double[] xd = ((DenseVector) x).getData(), yd = ((DenseVector) y).getData();

        BLAS.getInstance().dgemv(Transpose.Transpose.netlib(), numRows,
                numColumns, alpha, data, Math.max(numRows, 1), xd, 1, 1, yd, 1);

        return y;
    }

    /**
     * Solves {@code A X = B}. Square matrices use an LU solve. Non-square matrices
     * use a QR-based least-squares solve.
     *
     * @param B right-hand sides, {@code numRows x k}
     * @param X receives the solution, {@code numColumns x k}
     * @return X
     * @throws IllegalArgumentException if the dimensions of A, B and X disagree
     */
    @Override
    public Matrix solve(Matrix B, Matrix X) {
        // We allow non-square matrices, as we then use a least-squares solver
        if (numRows != B.numRows())
            throw new IllegalArgumentException("numRows != B.numRows() ("
                    + numRows + " != " + B.numRows() + ")");
        if (numColumns != X.numRows())
            throw new IllegalArgumentException("numColumns != X.numRows() ("
                    + numColumns + " != " + X.numRows() + ")");
        if (X.numColumns() != B.numColumns())
            throw new IllegalArgumentException(
                    "X.numColumns() != B.numColumns() (" + X.numColumns()
                            + " != " + B.numColumns() + ")");

        if (isSquare())
            return LUsolve(B, X);
        else
            return QRsolve(B, X, Transpose.NoTranspose);
    }

    /**
     * Solves {@code A x = b} by wrapping both vectors as single-column dense
     * matrices and delegating to {@link #solve(Matrix, Matrix)}.
     * NOTE(review): relies on {@code new DenseMatrix(v, false)} sharing v's
     * storage, so the solution lands in x — confirm.
     *
     * @return x
     */
    @Override
    public Vector solve(Vector b, Vector x) {
        DenseMatrix B = new DenseMatrix(b, false), X = new DenseMatrix(x, false);
        solve(B, X);
        return x;
    }

    /**
     * Solves {@code A^T X = B}. It always uses the QR-based least-squares solver.
     *
     * @param B right-hand sides, {@code numColumns x k}
     * @param X receives the solution, {@code numRows x k}
     * @return X
     * @throws IllegalArgumentException if the dimensions of A, B and X disagree
     */
    @Override
    public Matrix transSolve(Matrix B, Matrix X) {
        // We allow non-square matrices, as we then use a least-squares solver
        if (numColumns != B.numRows())
            throw new IllegalArgumentException("numColumns != B.numRows() ("
                    + numColumns + " != " + B.numRows() + ")");
        if (numRows != X.numRows())
            throw new IllegalArgumentException("numRows != X.numRows() ("
                    + numRows + " != " + X.numRows() + ")");
        if (X.numColumns() != B.numColumns())
            throw new IllegalArgumentException(
                    "X.numColumns() != B.numColumns() (" + X.numColumns()
                            + " != " + B.numColumns() + ")");

        return QRsolve(B, X, Transpose.Transpose);
    }

    /**
     * Solves {@code A^T x = b} by wrapping both vectors as single-column dense
     * matrices and delegating to {@link #transSolve(Matrix, Matrix)}.
     * NOTE(review): relies on {@code new DenseMatrix(v, false)} sharing v's
     * storage, so the solution lands in x — confirm.
     *
     * @return x
     */
    @Override
    public Vector transSolve(Vector b, Vector x) {
        DenseMatrix B = new DenseMatrix(b, false), X = new DenseMatrix(x, false);
        transSolve(B, X);
        return x;
    }

    /**
     * Solves the square system {@code A X = B} with LAPACK {@code dgesv}. The
     * factorization runs on a clone of this matrix's data, so A is left intact.
     *
     * @param B right-hand sides
     * @param X must be a {@code DenseMatrix}. It is first filled with B, then
     *          overwritten with the solution
     * @return X
     * @throws UnsupportedOperationException if X is not a {@code DenseMatrix}
     * @throws MatrixSingularException if {@code dgesv} reports a singular factor
     * @throws IllegalArgumentException if {@code dgesv} rejects an argument
     */
    Matrix LUsolve(Matrix B, Matrix X) {
        if (!(X instanceof DenseMatrix))
            throw new UnsupportedOperationException("X must be a DenseMatrix");

        double[] Xd = ((DenseMatrix) X).getData();

        // dgesv overwrites the right-hand side with the solution in place
        X.set(B);

        int[] piv = new int[numRows];

        intW info = new intW(0);
        LAPACK.getInstance().dgesv(numRows, B.numColumns(), data.clone(),
                Matrices.ld(numRows), piv, Xd, Matrices.ld(numRows), info);

        if (info.val > 0)
            throw new MatrixSingularException();
        else if (info.val < 0)
            throw new IllegalArgumentException();

        return X;
    }

    /**
     * Solves {@code op(A) X = B} in the least-squares / minimum-norm sense with
     * LAPACK {@code dgels}, where {@code op(A)} is A or A^T depending on
     * {@code trans}. A's data is cloned before factorization, so A is left intact.
     *
     * @param B     right-hand sides. Its first {@code M} rows are read, where M is
     *              numRows for NoTranspose and numColumns for Transpose
     * @param X     receives the first {@code N} rows of the solution, where N is
     *              numColumns for NoTranspose and numRows for Transpose
     * @param trans whether to solve with A or with A^T
     * @return X
     * @throws IllegalArgumentException if {@code dgels} rejects an argument
     * @throws MatrixSingularException  if {@code dgels} reports that A is not of
     *                                  full rank (no solution was computed)
     */
    Matrix QRsolve(Matrix B, Matrix X, Transpose trans) {
        int nrhs = B.numColumns();

        // dgels reads the right-hand sides from, and writes the solution into, one
        // array that must hold max(numRows, numColumns) rows
        DenseMatrix Xtmp = new DenseMatrix(Math.max(numRows, numColumns), nrhs);

        // Leading dimension of Xtmp. dgels requires LDB >= max(1, M, N). Both the
        // workspace query and the real solve must use this same value. Passing
        // ld(numColumns) alone is too small whenever numRows > numColumns.
        int ldb = Matrices.ld(numRows, numColumns);

        int M = trans == Transpose.NoTranspose ? numRows : numColumns;
        for (int j = 0; j < nrhs; ++j)
            for (int i = 0; i < M; ++i)
                Xtmp.set(i, j, B.get(i, j));

        // dgels overwrites A with its QR/LQ factorization
        double[] newData = data.clone();

        // Query optimal workspace
        double[] work = new double[1];
        intW info = new intW(0);
        LAPACK.getInstance().dgels(trans.netlib(), numRows, numColumns, nrhs,
                newData, Matrices.ld(numRows), Xtmp.getData(), ldb, work, -1,
                info);

        // Allocate workspace. If the query failed, fall back to the documented
        // minimum LWORK.
        int lwork;
        if (info.val != 0)
            lwork = Math.max(1, Math.min(numRows, numColumns)
                    + Math.max(Math.min(numRows, numColumns), nrhs));
        else
            lwork = Math.max((int) work[0], 1);
        work = new double[lwork];

        // Compute the factorization and solve
        info.val = 0;
        LAPACK.getInstance().dgels(trans.netlib(), numRows, numColumns, nrhs,
                newData, Matrices.ld(numRows), Xtmp.getData(), ldb, work, lwork,
                info);

        if (info.val < 0) {
            throw new IllegalArgumentException();
        } else if (info.val > 0) {
            // A diagonal element of the triangular factor is zero: A is
            // rank-deficient, so Xtmp holds no valid solution
            throw new MatrixSingularException();
        }

        // Extract the solution
        int N = trans == Transpose.NoTranspose ? numColumns : numRows;
        for (int j = 0; j < nrhs; ++j)
            for (int i = 0; i < N; ++i)
                X.set(i, j, Xtmp.get(i, j));

        return X;
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -