📄 matrix4.java
字号:
package rmn;import java.util.*;import java.lang.reflect.*;public class Matrix4 implements Matrix { public double[][][][] m_matrix; public Matrix4() { } public Matrix4(int nCard1, int nCard2, int nCard3, int nCard4) { m_matrix = new double[nCard1][nCard2][nCard3][nCard4]; } public Matrix4(double[][][][] matrix) { m_matrix = matrix; } public Matrix4(Matrix4 matrix4) { m_matrix = new double[matrix4.m_matrix.length][][][]; for (int i = 0; i < m_matrix.length; i++) { m_matrix[i] = new double[matrix4.m_matrix[i].length][][]; for (int j = 0; j < m_matrix[i].length; j++) { m_matrix[i][j] = new double[matrix4.m_matrix[i][j].length][]; for (int k = 0; k < m_matrix[i][j].length; k++) { m_matrix[i][j][k] = new double[matrix4.m_matrix[i][j][k].length]; System.arraycopy(matrix4.m_matrix[i][j][k], 0, m_matrix[i][j][k], 0, m_matrix[i][j][k].length); } } } } public Matrix getCopy() { return new Matrix4(this); } public Matrix newMatrix() { Matrix4 m = new Matrix4(); m.m_matrix = new double[m_matrix.length][][][]; for (int i = 0; i < m_matrix.length; i++) { m.m_matrix[i] = new double[m_matrix[i].length][][]; for (int j = 0; j < m_matrix[i].length; j++) { m.m_matrix[i][j] = new double[m_matrix[i][j].length][]; for (int k = 0; k < m_matrix[i][j].length; k++) m.m_matrix[i][j][k] = new double[m_matrix[i][j][k].length]; } } return m; } public void fill(double val) { for (int i = 0; i < m_matrix.length; i++) for (int j = 0; j < m_matrix[i].length; j++) for (int k = 0; k < m_matrix[i][j].length; k++) Arrays.fill(m_matrix[i][j][k], val); } public int size() { return 4; } public int[] getDimensions() { int[] dims = new int[size()]; dims[0] = m_matrix.length; dims[1] = m_matrix[0].length; dims[2] = m_matrix[0][0].length; dims[3] = m_matrix[0][0][0].length; return dims; } public void inc(int[] pos) { assert pos.length == size() : pos.length; m_matrix[pos[0]][pos[1]][pos[2]][pos[3]]++; } public void add_sub(Matrix matrix1, Matrix matrix2, double rate) { Matrix4 m1 = (Matrix4) matrix1; Matrix4 m2 = (Matrix4) matrix2; int[] dims = getDimensions(); for (int i = 0; i < dims[0]; i++) for (int j = 0; j < dims[1]; j++) for (int k = 0; k < dims[2]; k++) for (int l = 0; l < dims[3]; l++) { double grad = (m1.m_matrix[i][j][k][l] - m2.m_matrix[i][j][k][l]) * rate; m_matrix[i][j][k][l] = m_matrix[i][j][k][l] * Math.exp(grad); } } public void add_log(Matrix matrix, int delta) { Matrix4 m = (Matrix4) matrix; int[] dims = getDimensions(); for (int i = 0; i < dims[0]; i++) for (int j = 0; j < dims[1]; j++) for (int k = 0; k < dims[2]; k++) for (int l = 0; l < dims[3]; l++) { m_matrix[i][j][k][l] = m_matrix[i][j][k][l] + delta * Math.log(m.m_matrix[i][j][k][l]); } } public Matrix exp_avg(int n) { Matrix4 m = (Matrix4) newMatrix(); int[] dims = getDimensions(); for (int i = 0; i < dims[0]; i++) for (int j = 0; j < dims[1]; j++) for (int k = 0; k < dims[2]; k++) for (int l = 0; l < dims[3]; l++) { m.m_matrix[i][j][k][l] = Math.exp(m_matrix[i][j][k][l] / n); } return m; } public void dotProduct(Matrix matrix) { Matrix4 matrix4 = (Matrix4) matrix; int[] dims = getDimensions(); int[] dimsm = matrix4.getDimensions(); // dimensions should match for (int i = 0; i < dims.length; i++) assert dimsm[i] == dims[i] : dimsm[i]; for (int i = 0; i < dims[0]; i++) for (int j = 0; j < dims[1]; j++) for (int k = 0; k < dims[2]; k++) for (int l = 0; l < dims[3]; l++) m_matrix[i][j][k][l] = m_matrix[i][j][k][l] * matrix4.m_matrix[i][j][k][l]; } public void dotProduct(double[] vector, int dim) { assert dim < size() : dim; int[] dims = getDimensions(); assert vector.length == dims[dim] : vector.length; int idx[] = {0, 0, 0, 0}; for (idx[0] = 0; idx[0] < dims[0]; idx[0]++) for (idx[1] = 0; idx[1] < dims[1]; idx[1]++) for (idx[2] = 0; idx[2] < dims[2]; idx[2]++) for (idx[3] = 0; idx[3] < dims[3]; idx[3]++) m_matrix[idx[0]][idx[1]][idx[2]][idx[3]] = m_matrix[idx[0]][idx[1]][idx[2]][idx[3]] * vector[idx[dim]]; } public void dotQuotient(double[] vector, int dim) { assert dim < size() : dim; int[] dims = getDimensions(); assert vector.length == dims[dim] : vector.length; int idx[] = {0, 0, 0, 0}; for (idx[0] = 0; idx[0] < dims[0]; idx[0]++) for (idx[1] = 0; idx[1] < dims[1]; idx[1]++) for (idx[2] = 0; idx[2] < dims[2]; idx[2]++) for (idx[3] = 0; idx[3] < dims[3]; idx[3]++) m_matrix[idx[0]][idx[1]][idx[2]][idx[3]] = m_matrix[idx[0]][idx[1]][idx[2]][idx[3]] / vector[idx[dim]]; } public double[] marginalize(int dim, boolean bMaximize) { assert dim < size() : dim; // Method sumOrMax = MathUtils.getSumOrMax(bMaximize); int[] dims = getDimensions(); double[] margin = new double[dims[dim]]; // assume positive potentials Arrays.fill(margin, 0); try { int idx[] = {0, 0, 0, 0}; for (idx[0] = 0; idx[0] < dims[0]; idx[0]++) for (idx[1] = 0; idx[1] < dims[1]; idx[1]++) for (idx[2] = 0; idx[2] < dims[2]; idx[2]++) for (idx[3] = 0; idx[3] < dims[3]; idx[3]++) { /* Object[] params = {new Double(margin[idx[dim]]), new Double(m_matrix[idx[0]][idx[1]][idx[2]][idx[3]])}; margin[idx[dim]] = ((Double) sumOrMax.invoke(null, params)).doubleValue(); */ if (bMaximize) margin[idx[dim]] = Math.max(margin[idx[dim]], m_matrix[idx[0]][idx[1]][idx[2]][idx[3]]); else margin[idx[dim]] += m_matrix[idx[0]][idx[1]][idx[2]][idx[3]]; } } catch (Exception e) { System.err.println(e); System.exit(1); } return margin; } public String toString() { String strRes = new String(); for (int i = 0; i < m_matrix.length; i++) for (int j = 0; j < m_matrix[i].length; j++) for (int k = 0; k < m_matrix[i][j].length; k++) for (int l = 0; l < m_matrix[i][j][k].length; l++) strRes += String.valueOf(i) + " " + String.valueOf(j) + " " + String.valueOf(k) + " " + String.valueOf(l) + " ---> " + String.valueOf(m_matrix[i][j][k][l]) + "\n"; return strRes; } }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -