📄 comprowmatrix.java
字号:
C.zero(); // optimised a little bit to avoid zeros in rows, but not to // exploit sparsity of matrix B for (int i = 0; i < numRows; ++i) { for (int j = 0; j < C.numColumns(); ++j) { double dot = 0; for (int k = rowPointer[i]; k < rowPointer[i + 1]; ++k) { dot += data[k] * B.get(columnIndex[k], j); } if (dot != 0) { C.set(i, j, dot); } } } return C; } @Override public Vector mult(Vector x, Vector y) { // check dimensions checkMultAdd(x, y); // can't assume this, unfortunately y.zero(); if (x instanceof DenseVector) { // DenseVector optimisations double[] xd = ((DenseVector) x).getData(); for (int i = 0; i < numRows; ++i) { double dot = 0; for (int j = rowPointer[i]; j < rowPointer[i + 1]; j++) { dot += data[j] * xd[columnIndex[j]]; } if (dot != 0) { y.set(i, dot); } } return y; } // use sparsity of matrix (not vector), as get(,) is slow // TODO: additional optimisations for mult(ISparseVector, Vector) // note that this would require Sparse BLAS, e.g. BLAS_DUSDOT(,,,,) // @see http://www.netlib.org/blas/blast-forum/chapter3.pdf for (int i = 0; i < numRows; ++i) { double dot = 0; for (int j = rowPointer[i]; j < rowPointer[i + 1]; j++) { dot += data[j] * x.get(columnIndex[j]); } y.set(i, dot); } return y; } @Override public Vector multAdd(double alpha, Vector x, Vector y) { if (!(x instanceof DenseVector) || !(y instanceof DenseVector)) return super.multAdd(alpha, x, y); checkMultAdd(x, y); double[] xd = ((DenseVector) x).getData(); double[] yd = ((DenseVector) y).getData(); for (int i = 0; i < numRows; ++i) { double dot = 0; for (int j = rowPointer[i]; j < rowPointer[i + 1]; ++j) dot += data[j] * xd[columnIndex[j]]; yd[i] += alpha * dot; } return y; } @Override public Vector transMult(Vector x, Vector y) { if (!(x instanceof DenseVector) || !(y instanceof DenseVector)) return super.transMult(x, y); checkTransMultAdd(x, y); double[] xd = ((DenseVector) x).getData(); double[] yd = ((DenseVector) y).getData(); y.zero(); for (int i = 0; i < numRows; ++i) for (int j = rowPointer[i]; j < rowPointer[i + 1]; ++j) yd[columnIndex[j]] += data[j] * xd[i]; return y; } @Override public Vector transMultAdd(double alpha, Vector x, Vector y) { if (!(x instanceof DenseVector) || !(y instanceof DenseVector)) return super.transMultAdd(alpha, x, y); checkTransMultAdd(x, y); double[] xd = ((DenseVector) x).getData(); double[] yd = ((DenseVector) y).getData(); // y = 1/alpha * y y.scale(1. / alpha); // y = A'x + y for (int i = 0; i < numRows; ++i) for (int j = rowPointer[i]; j < rowPointer[i + 1]; ++j) yd[columnIndex[j]] += data[j] * xd[i]; // y = alpha*y = alpha*A'x + y return y.scale(alpha); } @Override public void set(int row, int column, double value) { check(row, column); int index = getIndex(row, column); data[index] = value; } @Override public void add(int row, int column, double value) { check(row, column); int index = getIndex(row, column); data[index] += value; } @Override public double get(int row, int column) { check(row, column); int index = no.uib.cipr.matrix.sparse.Arrays.binarySearch(columnIndex, column, rowPointer[row], rowPointer[row + 1]); if (index >= 0) return data[index]; else return 0; } /** * Finds the insertion index */ private int getIndex(int row, int column) { int i = no.uib.cipr.matrix.sparse.Arrays.binarySearch(columnIndex, column, rowPointer[row], rowPointer[row + 1]); if (i != -1 && columnIndex[i] == column) return i; else throw new IndexOutOfBoundsException("Entry (" + (row + 1) + ", " + (column + 1) + ") is not in the matrix structure"); } @Override public CompRowMatrix copy() { return new CompRowMatrix(this); } @Override public Iterator<MatrixEntry> iterator() { return new CompRowMatrixIterator(); } @Override public CompRowMatrix zero() { Arrays.fill(data, 0); return this; } @Override public Matrix set(Matrix B) { if (!(B instanceof CompRowMatrix)) return super.set(B); checkSize(B); CompRowMatrix Bc = (CompRowMatrix) B; // Reallocate matrix structure, if necessary if (Bc.columnIndex.length != columnIndex.length || Bc.rowPointer.length != rowPointer.length) { data = new double[Bc.data.length]; columnIndex = new int[Bc.columnIndex.length]; rowPointer = new int[Bc.rowPointer.length]; } System.arraycopy(Bc.data, 0, data, 0, data.length); System.arraycopy(Bc.columnIndex, 0, columnIndex, 0, columnIndex.length); System.arraycopy(Bc.rowPointer, 0, rowPointer, 0, rowPointer.length); return this; } /** * Iterator over a compressed row matrix */ private class CompRowMatrixIterator implements Iterator<MatrixEntry> { private int row, cursor; private CompRowMatrixEntry entry = new CompRowMatrixEntry(); public CompRowMatrixIterator() { // Find first non-empty row nextNonEmptyRow(); } /** * Locates the first non-empty row, starting at the current. After the * new row has been found, the cursor is also updated */ private void nextNonEmptyRow() { while (row < numRows() && rowPointer[row] == rowPointer[row + 1]) row++; cursor = rowPointer[row]; } public boolean hasNext() { return cursor < data.length; } public MatrixEntry next() { entry.update(row, cursor); // Next position is in the same row if (cursor < rowPointer[row + 1] - 1) cursor++; // Next position is at the following (non-empty) row else { row++; nextNonEmptyRow(); } return entry; } public void remove() { entry.set(0); } } /** * Entry of a compressed row matrix */ private class CompRowMatrixEntry implements MatrixEntry { private int row, cursor; /** * Updates the entry */ public void update(int row, int cursor) { this.row = row; this.cursor = cursor; } public int row() { return row; } public int column() { return columnIndex[cursor]; } public double get() { return data[cursor]; } public void set(double value) { data[cursor] = value; } }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -