📄 matrix.java~42~
字号:
package neural.matrix;public class Matrix { double[][] m; public Matrix(double[][] m) { this.m = m; } public Matrix(Matrix matrix) { m = alloc(matrix.getH(), matrix.getW()); for (int i = matrix.getH() - 1; i >= 0 ; i--) { for (int j = matrix.getW() - 1; j >= 0 ; j--) { m[i][j] = matrix.m[i][j]; } } } public Matrix(int h, int w) { this.m = alloc(h, w); } public Matrix(int h, int w, double initialValue) { this.m = alloc(h, w, initialValue); } public int getH() { return m.length; } public int getW() { return m[0].length; } public double get(int r, int c) { return m[r][c]; } public void set(int r, int c, double value) { m[r][c] = value; } public void setColVec(int c, Matrix colVec) throws SizeMismatchException { if (this.getH() != colVec.getH()) { throw new SizeMismatchException(); } for (int i = this.getH() - 1; i >= 0; i--) { this.m[i][c] = colVec.m[i][0]; } } public void setRowVec(int r, Matrix rowVec) throws SizeMismatchException { if (this.getW() != rowVec.getW()) { throw new SizeMismatchException(); } for (int i = this.getW() - 1; i >= 0; i--) { this.m[r][i] = rowVec.m[r][i]; } } public static Matrix createNumberMatrix(double value) { return new Matrix(1, 1, value); } public static Matrix createColumnVector(int length) { return new Matrix(length, 1); } public static Matrix createColumnVector(double[] v) { Matrix t = new Matrix(v.length, 1); for (int i = 0; i < v.length; i++) { t.m[i][0] = v[i]; } return t; } public static Matrix createRowVector(int length) { return new Matrix(1, length); } public static Matrix createRowVector(double[] v) { Matrix t = new Matrix(1, v.length); for (int i = 0; i < v.length; i++) { t.m[0][i] = v[i]; } return t; } public static Matrix createSquareMatrix(int size) { return new Matrix(size, size); } public static Matrix createSquareMatrix(int size, double initialValue) { return new Matrix(size, size, initialValue); } public static Matrix createEMatrix(int size) { Matrix t = new Matrix(size, size); for (int i = 0; i < size; i++) { t.m[i][i] = 1.0; } return t; } public int getNumberOfZero() { int t = 0; for (int i = this.getH() - 1; i >= 0; i--) { for (int j = this.getW() - 1; j >= 0; j--) { if (m[i][j] == 0) { t ++; } } } return t; } public static double distanceBetweenColVec(Matrix v1, Matrix v2) throws SizeMismatchException { // Matrix v = Matrix.sub(v1, v2); return v1.getH() - v.getNumberOfZero(); } public static Matrix createEMatrix(int size, double zeroRefill) { Matrix t = new Matrix(size, size, zeroRefill); for (int i = 0; i < size; i++) { t.m[i][i] = 1.0; } return t; } public static Matrix createEMatrix(double[] oneRefill) { Matrix t = new Matrix(oneRefill.length, oneRefill.length); for (int i = 0; i < oneRefill.length; i++) { t.m[i][i] = oneRefill[i]; } return t; } public static Matrix createEMatrix(double[] oneRefill, double zeroRefill) { Matrix t = new Matrix(oneRefill.length, oneRefill.length, zeroRefill); for (int i = 0; i < oneRefill.length; i++) { t.m[i][i] = oneRefill[i]; } return t; } public static Matrix add(Matrix m1, Matrix m2) throws SizeMismatchException { if (m1.getW() != m2.getW() || m1.getH() != m2.getH()) { throw new SizeMismatchException(); } Matrix t = new Matrix(m1.getH(), m1.getW()); for (int i = m1.getH() - 1; i >= 0; i--) { for (int j = m1.getW() - 1; j >= 0; j--) { t.m[i][j] = m1.m[i][j] + m2.m[i][j]; } } return t; } public static Matrix sub(Matrix m1, Matrix m2) throws SizeMismatchException { if (m1.getW() != m2.getW() || m1.getH() != m2.getH()) { throw new SizeMismatchException(); } Matrix t = new Matrix(m1.getH(), m1.getW()); for (int i = m1.getH() - 1; i >= 0; i--) { for (int j = m1.getW() - 1; j >= 0; j--) { t.m[i][j] = m1.m[i][j] - m2.m[i][j]; } } return t; } public static Matrix getReformedMatrix(Matrix m, int h, int w) throws SizeMismatchException { double size = h * w; if (size != m.getH() * m.getW()) { throw new SizeMismatchException(); } int mw = m.getW(); Matrix t = new Matrix(h, w); for (int i = 0; i < size; i++) { t.m[i / w][i % w] = m.m[i / mw][i % mw]; } return t; } public static Matrix multiply(Matrix m1, Matrix m2) throws SizeMismatchException { if (m1.getW() != m2.getH()) { throw new SizeMismatchException(); } Matrix t = new Matrix(m1.getH(), m2.getW()); for (int i = m1.getH() - 1; i >= 0; i--) { for (int j = m2.getW() - 1; j >= 0; j--) { double temp = 0; for (int k = m1.getW() - 1; k >= 0; k--) { temp += m1.m[i][k] * m2.m[k][j]; } t.m[i][j] = temp; } } return t; } public static Matrix multiply(Matrix m, double k) { Matrix t = new Matrix(m.getH(), m.getW()); for (int i = m.getH() - 1; i >= 0; i--) { for (int j = m.getW() - 1; j >= 0; j--) { t.m[i][j] = k * m.m[i][j]; } } return t; } public static Matrix transpose(Matrix m) { Matrix t = new Matrix(m.getW(), m.getH()); for (int i = t.getH() - 1; i >= 0; i--) { for (int j = t.getW() - 1; j >= 0; j--) { t.m[i][j] = m.m[j][i]; } } return t; } private static double[][] alloc(int h, int w) { double[][] t = new double[h][]; for (int i = 0; i < h; i++) { t[i] = new double[w]; } return t; } private static double[][] alloc(int h, int w, double initialValue) { double[][] t = new double[h][]; for (int i = 0; i < h; i++) { t[i] = new double[w]; for (int j = 0; j < w; j++) { t[i][j] = initialValue; } } return t; } public static double innerProduct(Matrix m1, int r, Matrix m2, int c) throws SizeMismatchException { if (m1.getW() != m2.getH()) { throw new SizeMismatchException(); } int t = 0; for (int i = m1.getW() - 1; i >= 0; i--) { t += m1.m[r][i] * m2.m[i][c]; } return t; } public static double innerProduct(Matrix m1, Matrix m2) throws SizeMismatchException { if (m1.getW() != m2.getH()) { throw new SizeMismatchException(); } int t = 0; for (int i = m1.getW() - 1; i >= 0; i--) { t += m1.m[0][i] * m2.m[i][0]; } return t; } public static double innerProduct(Matrix colVec) { int t = 0; for (int i = colVec.getH() - 1; i >= 0; i--) { t += colVec.m[i][0] * colVec.m[i][0]; } return t; } public boolean isEqual(Matrix m) throws SizeMismatchException { if (this.getW() != m.getW() || this.getH() != m.getH()) { throw new SizeMismatchException(); } for (int i = this.getH() - 1; i >= 0; i--) { for (int j = this.getW() - 1; j >= 0; j--) { if (this.m[i][j] != m.m[i][j]) { return false; } } } return true; } public String toString() { StringBuffer sb = new StringBuffer(); sb.append("<" + this.getH() + ", " + this.getW() + ">\n"); int h = this.getH(); int w = this.getW(); for (int i = 0; i < h; i++) { for (int j = 0; j < w; j++) { sb.append(this.get(i, j)); if (j == w - 1) { sb.append("\n"); } else { sb.append(", "); } } } return sb.toString(); } public String toString3() { StringBuffer sb = new StringBuffer(); int h = this.getH(); int w = this.getW(); for (int i = 0; i < h; i++) { sb.append("["); for (int j = 0; j < w; j++) { sb.append((int)this.get(i, j)); if (j == w - 1) { sb.append(""); } else { sb.append(", "); } } sb.append("]\n"); } return sb.toString(); } public String toString2() { StringBuffer sb = new StringBuffer(); //sb.append("<" + this.getH() + ", " + this.getW() + ">\n"); int h = this.getH(); int w = this.getW(); for (int i = 0; i < h; i++) { for (int j = 0; j < w; j++) { double v = this.get(i, j); sb.append(" "); if (v > 0) { sb.append('*'); } else { sb.append('.'); } if (j == w - 1) { sb.append("\n"); } else { sb.append(""); } } } return sb.toString(); }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -