📄 matrix.java
字号:
for (i = 0; i < s.getRowDimension(); i++)
for (n = 0; n < s.getColumnDimension(); n++)
s.set(i, n, StrictMath.sqrt(d.get(i, n)));
// to calculate:
// result = V*S/V
//
// with X = B/A
// and B/A = (A'\B')'
// and V=A and V*S=B
// we get
// result = (V'\(V*S)')'
//
// A*X = B
// X = A\B
// which is
// X = A.solve(B)
//
// with A=V' and B=(V*S)'
// we get
// X = V'.solve((V*S)')
// or
// result = X'
//
// which is in full length
// result = (V'.solve((V*S)'))'
a = v.inverse();
b = v.times(s).inverse();
v = null;
d = null;
evd = null;
s = null;
result = a.solve(b).inverse();
return result;
}
/**
 * Runs a (ridge-regularized) linear regression of the given target on this
 * matrix.
 *
 * @param y the dependent variable vector
 * @param ridge the ridge parameter
 * @return a LinearRegression built from this matrix, {@code y} and
 *         {@code ridge}
 * @throws IllegalArgumentException if not successful (presumably raised by
 *         the LinearRegression constructor)
 * @author FracPete, taken from old weka.core.Matrix class
 */
public LinearRegression regression(Matrix y, double ridge) {
  final LinearRegression model = new LinearRegression(this, y, ridge);
  return model;
}
/**
 * Runs a weighted (ridge-regularized) linear regression of the given target
 * on this matrix, one weight per data point.
 *
 * @param y the dependent variable vector
 * @param w the array of data point weights
 * @param ridge the ridge parameter
 * @return a LinearRegression built from this matrix, {@code y}, {@code w}
 *         and {@code ridge}
 * @throws IllegalArgumentException if the wrong number of weights were
 *         provided.
 * @author FracPete, taken from old weka.core.Matrix class
 */
public final LinearRegression regression(Matrix y, double[] w, double ridge) {
  final LinearRegression model = new LinearRegression(this, y, w, ridge);
  return model;
}
/**
 * Computes the determinant of this matrix from its LU decomposition.
 *
 * @return determinant
 */
public double det() {
  final LUDecomposition lu = new LUDecomposition(this);
  return lu.det();
}
/**
 * Computes the rank of this matrix.
 *
 * @return effective numerical rank, as reported by a singular value
 *         decomposition of this matrix.
 */
public int rank() {
  final SingularValueDecomposition svd = new SingularValueDecomposition(this);
  return svd.rank();
}
/**
 * Computes the 2-norm condition number of this matrix.
 *
 * @return ratio of largest to smallest singular value, as reported by a
 *         singular value decomposition of this matrix.
 */
public double cond() {
  final SingularValueDecomposition svd = new SingularValueDecomposition(this);
  return svd.cond();
}
/**
 * Matrix trace. For non-square matrices only the leading
 * min(rows, columns) diagonal entries exist and are summed.
 *
 * @return sum of the diagonal elements.
 */
public double trace() {
  final int diagonalLength = Math.min(m, n);
  double sum = 0;
  int k = 0;
  while (k < diagonalLength) {
    sum += A[k][k];
    k++;
  }
  return sum;
}
/**
 * Creates a matrix filled with values from {@link Math#random()}, drawn in
 * row-major order.
 *
 * @param m Number of rows.
 * @param n Number of columns.
 * @return An m-by-n matrix with uniformly distributed random elements.
 */
public static Matrix random(int m, int n) {
  Matrix result = new Matrix(m, n);
  double[][] cells = result.getArray();
  for (int r = 0; r < m; r++) {
    double[] row = cells[r];
    for (int c = 0; c < n; c++)
      row[c] = Math.random();
  }
  return result;
}
/**
 * Creates an identity-like matrix: 1.0 where the row index equals the
 * column index, 0.0 everywhere else.
 *
 * @param m Number of rows.
 * @param n Number of columns.
 * @return An m-by-n matrix with ones on the diagonal and zeros elsewhere.
 */
public static Matrix identity(int m, int n) {
  Matrix result = new Matrix(m, n);
  double[][] cells = result.getArray();
  for (int r = 0; r < m; r++) {
    double[] row = cells[r];
    for (int c = 0; c < n; c++) {
      if (r == c)
        row[c] = 1.0;
      else
        row[c] = 0.0;
    }
  }
  return result;
}
/**
 * Prints the matrix to stdout with each element in a Fortran-like
 * 'Fw.d' style format, aligned in columns.
 *
 * @param w Column width.
 * @param d Number of digits after the decimal.
 */
public void print(int w, int d) {
  final PrintWriter stdout = new PrintWriter(System.out, true);
  print(stdout, w, d);
}
/**
 * Prints the matrix to the given writer in a Fortran-like 'Fw.d' style:
 * exactly {@code d} fraction digits, at least one integer digit, no
 * grouping separators and US-locale symbols. Each column is {@code w + 2}
 * characters wide (wider if an element needs more room).
 *
 * @param output Output stream.
 * @param w Column width.
 * @param d Number of digits after the decimal.
 */
public void print(PrintWriter output, int w, int d) {
  final DecimalFormat fmt = new DecimalFormat();
  fmt.setGroupingUsed(false);
  fmt.setDecimalFormatSymbols(new DecimalFormatSymbols(Locale.US));
  fmt.setMinimumIntegerDigits(1);
  fmt.setMaximumFractionDigits(d);
  fmt.setMinimumFractionDigits(d);
  final int columnWidth = w + 2;
  print(output, fmt, columnWidth);
}
/**
 * Prints the matrix to stdout, formatting each element with the given
 * format object and right-justifying it in a column of {@code width}
 * characters.
 * Note that if the matrix is to be read back in, you probably will want
 * to use a NumberFormat that is set to US Locale.
 *
 * @param format A Formatting object for individual elements.
 * @param width Field width for each column.
 * @see java.text.DecimalFormat#setDecimalFormatSymbols
 */
public void print(NumberFormat format, int width) {
  final PrintWriter stdout = new PrintWriter(System.out, true);
  print(stdout, format, width);
}
// NumberFormat/DecimalFormat does not left-pad its output, so elements would
// come out with differing widths. The caller therefore passes the desired
// column width and the padding is done here.
/**
 * Prints the matrix to the given writer, one row per line, preceded by an
 * empty line and followed by a blank line. Each element is formatted with
 * {@code format} and right-justified in a column of {@code width}
 * characters; at least one space always precedes an element.
 * Note that if the matrix is to be read back in, you probably will want
 * to use a NumberFormat that is set to US Locale.
 *
 * @param output the output stream.
 * @param format A formatting object to format the matrix elements
 * @param width Column width.
 * @see java.text.DecimalFormat#setDecimalFormatSymbols
 */
public void print(PrintWriter output, NumberFormat format, int width) {
  output.println();
  for (int row = 0; row < m; row++) {
    for (int col = 0; col < n; col++) {
      String cell = format.format(A[row][col]);
      int pad = width - cell.length();
      if (pad < 1)
        pad = 1;  // never glue two elements together
      while (pad-- > 0)
        output.print(' ');
      output.print(cell);
    }
    output.println();
  }
  output.println();
}
/**
 * Reads a matrix from a stream in the format written by the print methods,
 * so printed matrices can be read back in (provided they were printed using
 * US Locale). Elements are separated by whitespace and all the elements of
 * a row appear on a single line. Leading blank lines are skipped; the matrix
 * ends at the first blank line after its first row, or at end of stream.
 * <p/>
 * Note: This format differs from the one that can be read via the
 * Matrix(Reader) constructor! For that format, the write(Writer) method
 * is used (from the original weka.core.Matrix class).
 *
 * @param input the input stream.
 * @return a new matrix; its column count is taken from the first row
 * @throws java.io.IOException if the stream ends before any element is
 *         read, if a later row is longer or shorter than the first row,
 *         or if reading from {@code input} fails
 * @throws NumberFormatException if a token is not accepted by
 *         {@link Double#valueOf(String)}
 * @see #Matrix(Reader)
 * @see #write(Writer)
 */
public static Matrix read(BufferedReader input) throws java.io.IOException {
StreamTokenizer tokenizer= new StreamTokenizer(input);
// Although StreamTokenizer will parse numbers, it doesn't recognize
// scientific notation (E or D); however, Double.valueOf does.
// The strategy here is to disable StreamTokenizer's number parsing.
// We'll only get whitespace delimited words, EOL's and EOF's.
// These words should all be numbers, for Double.valueOf to parse.
tokenizer.resetSyntax();
// every char 0..255 is a word char, then 0..' ' (control chars and blank)
// are re-declared as whitespace, so each blank-separated run is one word
tokenizer.wordChars(0,255);
tokenizer.whitespaceChars(0, ' ');
// report line ends as TT_EOL: they delimit rows and blank lines end input
tokenizer.eolIsSignificant(true);
// holds the Double values of the first row, later the double[] rows
java.util.Vector v = new java.util.Vector();
// Ignore initial empty lines (the loop body is intentionally empty)
while (tokenizer.nextToken() == StreamTokenizer.TT_EOL);
if (tokenizer.ttype == StreamTokenizer.TT_EOF)
throw new java.io.IOException("Unexpected EOF on matrix read.");
// the current token is a word here; read words up to the end of the line
do {
v.addElement(Double.valueOf(tokenizer.sval)); // Read & store 1st row.
} while (tokenizer.nextToken() == StreamTokenizer.TT_WORD);
// NOTE: the locals n, m and A below shadow the instance fields of the same
// names; this is a static method, so no field is touched.
int n = v.size(); // Now we've got the number of columns!
double row[] = new double[n];
for (int j=0; j<n; j++) // extract the elements of the 1st row.
row[j]=((Double)v.elementAt(j)).doubleValue();
v.removeAllElements();
v.addElement(row); // Start storing rows instead of columns.
// each further row must start with a word; an EOL (blank line) or EOF stops
while (tokenizer.nextToken() == StreamTokenizer.TT_WORD) {
// While non-empty lines
v.addElement(row = new double[n]);
int j = 0;
do {
// v.size() already counts the current row, so row numbers are 1-based
if (j >= n) throw new java.io.IOException
("Row " + v.size() + " is too long.");
row[j++] = Double.valueOf(tokenizer.sval).doubleValue();
} while (tokenizer.nextToken() == StreamTokenizer.TT_WORD);
if (j < n) throw new java.io.IOException
("Row " + v.size() + " is too short.");
}
int m = v.size(); // Now we've got the number of rows.
double[][] A = new double[m][];
v.copyInto(A); // copy the rows out of the vector
return new Matrix(A);
}
/**
 * Verifies that {@code B} has the same number of rows and columns as this
 * matrix.
 *
 * @param B the matrix to compare against
 * @throws IllegalArgumentException if the shapes differ
 */
private void checkMatrixDimensions(Matrix B) {
  final boolean sameShape = (B.m == m) && (B.n == n);
  if (!sameShape)
    throw new IllegalArgumentException("Matrix dimensions must agree.");
}
/**
 * Writes out the matrix in the text format read by the Matrix(Reader)
 * constructor. The output has a comment line, a line with the row and
 * column counts, a second comment line, and then one line per row with
 * every element followed by a tab. The writer is flushed at the end but
 * not closed.
 *
 * @param w the output Writer
 * @throws Exception if an error occurs
 * @see #Matrix(Reader)
 * @author FracPete, taken from old weka.core.Matrix class
 */
public void write(Writer w) throws Exception {
  final int rows = getRowDimension();
  final int cols = getColumnDimension();

  w.write("% Rows\tColumns\n");
  w.write(rows + "\t" + cols + "\n");
  w.write("% Matrix elements\n");
  int r = 0;
  while (r < rows) {
    for (int c = 0; c < cols; c++) {
      final double value = get(r, c);
      w.write(value + "\t");
    }
    w.write("\n");
    r++;
  }
  w.flush();
}
/**
 * Converts the matrix to a string, one row per line, each element preceded
 * by a blank.
 * <p/>
 * All elements share a common width derived from the largest magnitude in
 * the matrix. If any element differs from its nearest integer by at least
 * 0.01, every element is printed with two decimals; otherwise with none.
 *
 * @return the converted string
 * @author FracPete, taken from old weka.core.Matrix class
 */
public String toString() {
  // Determine the width required for the maximum element,
  // and check for fractional display requirement.
  double maxval = 0;
  boolean fractional = false;
  for (int i = 0; i < getRowDimension(); i++) {
    for (int j = 0; j < getColumnDimension(); j++) {
      double current = get(i, j);
      double magnitude = Math.abs(current);
      // Negative elements need one more column for the minus sign: scaling
      // the magnitude by 11 adds just over one digit to log10(maxval) below.
      double widthValue = (current < 0) ? magnitude * 11 : magnitude;
      if (widthValue > maxval)
        maxval = widthValue;
      // Judge the fractional part on the real magnitude. Using the scaled
      // value (as before) gave wrong answers, e.g. -1/11 scales to 1.0 and
      // was therefore printed as an integer.
      double fract = Math.abs(magnitude - Math.rint(magnitude));
      if (!fractional
          && ((Math.log(fract) / Math.log(10)) >= -2)) {
        fractional = true;
      }
    }
  }
  // log10 of the widest value plus 1 for the integer digits, or plus 4 to
  // also fit '.' and two decimals. For an all-zero matrix maxval is 0, so
  // log() is -Infinity and width becomes a large negative number.
  int width = (int)(Math.log(maxval) / Math.log(10)
                    + (fractional ? 4 : 1));
  StringBuffer text = new StringBuffer();
  for (int i = 0; i < getRowDimension(); i++) {
    for (int j = 0; j < getColumnDimension(); j++)
      text.append(" ").append(Utils.doubleToString(get(i, j),
                    width, (fractional ? 2 : 0)));
    text.append("\n");
  }
  return text.toString();
}
/**
 * Converts the matrix into a single-line Matlab string: the matrix is
 * enclosed in square brackets, rows are separated by "; " and the cells of
 * a row by single blanks, e.g., {@code [1.0 2.0; 3.0 4.0]}. Cells are
 * rendered with {@link Double#toString(double)}; a matrix without rows
 * yields {@code []}.
 *
 * @return the matrix in Matlab single line format
 * @see #parseMatlab(String)
 */
public String toMatlab() {
  StringBuffer result = new StringBuffer();
  result.append("[");
  // loop indices named row/col so they do not shadow the fields m and n
  for (int row = 0; row < getRowDimension(); row++) {
    if (row > 0)
      result.append("; ");
    for (int col = 0; col < getColumnDimension(); col++) {
      if (col > 0)
        result.append(" ");
      result.append(Double.toString(get(row, col)));
    }
  }
  result.append("]");
  return result.toString();
}
/**
 * Creates a matrix from the given Matlab string, e.g. {@code [1 2; 3 4]}.
 * <p/>
 * The content between '[' and the following ']' is used (if there is no
 * '[', the content starts at the beginning of the string). Rows are
 * separated by ';'. Cells within a row are separated by blanks, tabs, line
 * breaks or commas. Rows that contain no cells are ignored, as in Matlab,
 * and {@code []} yields a 0-by-0 matrix.
 *
 * @param matlab the matrix in matlab format
 * @return the matrix represented by the given string
 * @throws IllegalArgumentException if the closing ']' is missing or the
 *         rows do not all have the same number of cells
 * @throws NumberFormatException if a cell is not a valid double
 * @throws Exception declared for backward compatibility
 * @see #toMatlab()
 */
public static Matrix parseMatlab(String matlab) throws Exception {
  // locate the bracketed content
  int start = matlab.indexOf('[');
  int end = matlab.indexOf(']', start + 1);
  if (end < 0)
    throw new IllegalArgumentException(
        "Matlab matrix string lacks a closing ']': " + matlab);
  String cells = matlab.substring(start + 1, end).trim();

  // parse the ';'-separated rows; the first non-empty row fixes the width
  StringTokenizer tokRow = new StringTokenizer(cells, ";");
  double[][] data = new double[tokRow.countTokens()][];
  int rows = 0;
  int cols = -1;
  while (tokRow.hasMoreTokens()) {
    StringTokenizer tokCol = new StringTokenizer(tokRow.nextToken(), " \t\r\n,");
    int count = tokCol.countTokens();
    if (count == 0)
      continue;  // blank row, e.g. "[1 2; ;3 4]"
    if (cols < 0)
      cols = count;
    else if (count != cols)
      throw new IllegalArgumentException(
          "Row " + (rows + 1) + " has " + count + " elements, expected "
          + cols + ": " + matlab);
    double[] row = new double[cols];
    for (int j = 0; j < cols; j++)
      row[j] = Double.parseDouble(tokCol.nextToken());
    data[rows++] = row;
  }

  // "[]" (or only blank rows) is the string toMatlab() emits for a matrix
  // without rows
  if (rows == 0)
    return new Matrix(0, 0);

  // fill matrix
  Matrix result = new Matrix(rows, cols);
  for (int i = 0; i < rows; i++)
    for (int j = 0; j < cols; j++)
      result.set(i, j, data[i][j]);
  return result;
}
/**
 * Main method for testing this class. Prints the results of a series of
 * matrix operations to stdout: identity, basic arithmetic, sqrt,
 * determinant, eigenvalue and LU decompositions, (weighted) regression,
 * a Writer/Reader round trip and Matlab-format conversion. Any exception
 * is printed to stderr.
 *
 * @param args ignored
 */
public static void main(String[] args) {
  Matrix I;
  Matrix A;
  Matrix B;

  try {
    // Identity
    System.out.println("\nIdentity\n");
    I = Matrix.identity(3, 5);
    System.out.println("I(3,5)\n" + I);

    // basic operations - square
    System.out.println("\nbasic operations - square\n");
    A = Matrix.random(3, 3);
    B = Matrix.random(3, 3);
    System.out.println("A\n" + A);
    System.out.println("B\n" + B);
    // A' would denote the transpose (Matlab notation); this is the inverse
    System.out.println("A^-1\n" + A.inverse());
    System.out.println("A^T\n" + A.transpose());
    System.out.println("A+B\n" + A.plus(B));
    System.out.println("A*B\n" + A.times(B));
    System.out.println("X from A*X=B\n" + A.solve(B));

    // basic operations - non square
    System.out.println("\nbasic operations - non square\n");
    A = Matrix.random(2, 3);
    B = Matrix.random(3, 4);
    System.out.println("A\n" + A);
    System.out.println("B\n" + B);
    System.out.println("A*B\n" + A.times(B));

    // sqrt
    System.out.println("\nsqrt (1)\n");
    A = new Matrix(new double[][]{{5,-4,1,0,0},{-4,6,-4,1,0},{1,-4,6,-4,1},{0,1,-4,6,-4},{0,0,1,-4,5}});
    System.out.println("A\n" + A);
    System.out.println("sqrt(A)\n" + A.sqrt());

    // sqrt
    System.out.println("\nsqrt (2)\n");
    A = new Matrix(new double[][]{{7, 10},{15, 22}});
    System.out.println("A\n" + A);
    System.out.println("sqrt(A)\n" + A.sqrt());
    System.out.println("det(A)\n" + A.det() + "\n");

    // eigenvalue decomp.
    System.out.println("\nEigenvalue Decomposition\n");
    EigenvalueDecomposition evd = A.eig();
    System.out.println("[V,D] = eig(A)");
    System.out.println("- V\n" + evd.getV());
    System.out.println("- D\n" + evd.getD());

    // LU decomp.
    System.out.println("\nLU Decomposition\n");
    LUDecomposition lud = A.lu();
    System.out.println("[L,U,P] = lu(A)");
    System.out.println("- L\n" + lud.getL());
    System.out.println("- U\n" + lud.getU());
    System.out.println("- P\n" + Utils.arrayToString(lud.getPivot()) + "\n");

    // regression
    System.out.println("\nRegression\n");
    B = new Matrix(new double[][]{{3},{2}});
    double ridge = 0.5;
    double[] weights = new double[]{0.3, 0.7};
    // computed once and printed below (previously computed twice)
    LinearRegression lr = A.regression(B, ridge);
    System.out.println("A\n" + A);
    System.out.println("B\n" + B);
    System.out.println("ridge = " + ridge + "\n");
    System.out.println("weights = " + Utils.arrayToString(weights) + "\n");
    System.out.println("A.regression(B, ridge)\n"
        + lr + "\n");
    System.out.println("A.regression(B, weights, ridge)\n"
        + A.regression(B, weights, ridge) + "\n");

    // writer/reader
    System.out.println("\nWriter/Reader\n");
    StringWriter writer = new StringWriter();
    A.write(writer);
    System.out.println("A.write(Writer)\n" + writer);
    A = new Matrix(new StringReader(writer.toString()));
    System.out.println("A = new Matrix(Reader)\n" + A);

    // Matlab
    System.out.println("\nMatlab-Format\n");
    String matlab = "[ 1 2;3 4 ]";
    System.out.println("Matlab: " + matlab);
    System.out.println("from Matlab:\n" + Matrix.parseMatlab(matlab));
    System.out.println("to Matlab:\n" + Matrix.parseMatlab(matlab).toMatlab());
    matlab = "[1 2 3 4;3 4 5 6;7 8 9 10]";
    System.out.println("Matlab: " + matlab);
    System.out.println("from Matlab:\n" + Matrix.parseMatlab(matlab));
    System.out.println("to Matlab:\n" + Matrix.parseMatlab(matlab).toMatlab() + "\n");
  }
  catch (Exception e) {
    e.printStackTrace();
  }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -