📄 matlabnmf.java
字号:
"r=param(1);\n"+ "maxiter=param(2);\n"+ "obj=param(3);\n"+ "[n m]=size(V);\n"+ "W=rand(n,r);\n"+ "H=rand(r,m);\n"+ "eps=1e-9;\n"+ "if obj == 1\n"+ "W=W./(ones(n,1)*sum(W));\n"+ "end\n"+ "for iter=1:maxiter\n"+ "switch obj\n"+ "case 1\n"+ "H=H.*(W'*((V+eps)./(W*H+eps)));\n"+ "W=W.*(((V+eps)./(W*H+eps))*H');\n"+ "W=W./(ones(n,1)*sum(W));\n"+ "case 2\n"+ "H=H.*((W'*V+eps)./(W'*W*H+eps));\n"+ "W=W.*((V*H'+eps)./(W*H*H'+eps));\n"+ "case 3\n"+ "H=H.*((W'*((V+eps)./(W*H+eps)))./((sum(W))'*ones(1,m)));\n"+ "W=W.*((((V+eps)./(W*H+eps))*H')./(ones(n,1)*(sum(H,2))'));\n"+ "end\n"+ "end\n"+ "save '" + m_basisFilename + "' n r W -ASCII -DOUBLE\n"+ "save '" + m_encodingFilename + "' r m H -ASCII -DOUBLE\n"); nmf.close(); } catch (Exception e) { System.err.println("Could not create temporary files for dumping the scripts: " + e); } } /** * Dump covariance matrix into a file */ private void dumpInstances(String tempFile) { try { PrintWriter writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(tempFile))); for (int j = 0; j < m_numAttribs; j++) { for (int k = 0; k < m_numInstances; k++) { Instance instance = m_trainInstances.instance(k); writer.print(instance.value(j) + " "); } writer.println(); } writer.close(); } catch (Exception e) { System.err.println("Could not create a temporary file for dumping the covariance matrix: " + e); } } /** * Dump a column vector of size n into a file */ private void dumpInstance(String tempFile, Instance instance) { try { PrintWriter writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(tempFile))); for (int j = 0; j < m_numAttribs; j++) { writer.println(instance.value(j)); } writer.close(); } catch (Exception e) { System.err.println("Could not create a temporary file for dumping the column vector: " + e); } } /** * Dump an integer array into a file */ private void dumpVector(String tempFile, int[] a, int n) { try { PrintWriter writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(tempFile))); for (int j = 0; j < n; j++) { writer.println(a[j]); } writer.close(); } catch (Exception e) { System.err.println("Could not create a temporary file for dumping the column vector: " + e); } } /** Run matlab in command line with a given argument * @param inFile file to be input to Matlab */ public static void runMatlab(String inFile) { // call matlab to do the dirty work try { int exitValue; do { Process proc = Runtime.getRuntime().exec("matlab -tty < " + inFile); exitValue = proc.waitFor(); if (exitValue != 0) { System.err.println("Abnormal termination, trying again later!"); Thread.sleep(300000); } } while (exitValue != 0); } catch (Exception e) { System.err.println("Problems running matlab: " + e); } } /** * Return a summary of the analysis * @return a summary of the analysis. */ private String NMFSummary() { StringBuffer result = new StringBuffer(); Instances output = null; int numVectors=0; try { output = setOutputFormat(); numVectors = (output.classIndex() < 0) ? output.numAttributes() : output.numAttributes()-1; } catch (Exception ex) { } System.out.println("Sanity check: numVectors=" + numVectors); result.append("Basis vectors:\n"); for (int j = 1;j <= numVectors;j++) { result.append(" V"+j+"\t"); } result.append("\n"); for (int j = 0; j < m_numAttribs; j++) { for (int i = 0; i < numVectors; i++) result.append(Utils.doubleToString(m_basis[j][i],7,4)+"\t"); result.append(m_trainInstances.attribute(j).name()+"\n"); } result.append("\nAttribute ranking filter:\n"); result.append(m_eval.toString()); return result.toString(); } /** * Returns a description of this attribute transformer * @return a String describing this attribute transformer */ public String toString() { if (m_basis == null) { return "Basis hasn't been formed yet!"; } else { return "\tNMF Attribute Transformer\n\n" + NMFSummary(); } } /** * Return a matrix as a String * @param matrix that is decribed as a string * @return a String describing a matrix */ private String matrixToString(double [][] matrix) { StringBuffer result = new StringBuffer(); int last = matrix.length - 1; for (int i = 0; i <= last; i++) { for (int j = 0; j <= last; j++) { result.append(Utils.doubleToString(matrix[i][j],6,2)+" "); if (j == last) { result.append('\n'); } } } return result.toString(); } /** * Transform an instance in original (unnormalized) format. Convert back * to the original space if requested. * @param instance an instance in the original (unnormalized) format * @return a transformed instance * @exception Exception if instance cant be transformed */ public Instance convertInstance(Instance instance) throws Exception { if (m_basis == null) { throw new Exception("convertInstance: Basis not formed yet"); } Instance tempInst = (Instance)instance.copy(); if (!instance.equalHeaders(m_trainCopy.instance(0))) { throw new Exception("Can't convert instance: headers don't match: " +"MatlabNMF"); } m_replaceMissingFilter.input(tempInst); m_replaceMissingFilter.batchFinished(); tempInst = m_replaceMissingFilter.output(); if (m_normalize) { m_normalizeFilter.input(tempInst); m_normalizeFilter.batchFinished(); tempInst = m_normalizeFilter.output(); } if (m_attributeFilter != null) { m_attributeFilter.input(tempInst); m_attributeFilter.batchFinished(); tempInst = m_attributeFilter.output(); } // Wanted to do it on Matlab but it's too expensive double[] v = new double[m_numAttribs]; double eps = 1e-9; for (int i = 0; i < m_numAttribs; ++i) v[i] = tempInst.value(i); double[][] W = m_basis; double[] h = null; if (m_hasClass) h = new double[m_rank + 1]; else h = new double[m_rank]; for (int i = 0; i < m_rank; ++i) h[i] = Math.random(); double[] t1 = new double[m_numAttribs]; double[] t2 = new double[m_rank]; double[] t3 = new double[m_rank]; for (int i = 0; i < m_iter; ++i) { switch (m_obj) { case 1: // t1 = W*h+eps for (int j = 0; j < m_numAttribs; ++j) { t1[j] = eps; for (int k = 0; k < m_rank; ++k) t1[j] += W[j][k] * h[k]; } // t1 = (v+eps)./t1 for (int j = 0; j < m_numAttribs; ++j) t1[j] = (v[j]+eps) / t1[j]; // t2 = W'*t1 for (int j = 0; j < m_rank; ++j) { t2[j] = 0.0; for (int k = 0; k < m_numAttribs; ++k) t2[j] += W[k][j] * t1[k]; } break; case 2: // t2 = W'*v for (int j = 0; j < m_rank; ++j) { t2[j] = 0.0; for (int k = 0; k < m_numAttribs; ++k) t2[j] += W[k][j] * v[k]; } // t1 = W*h for (int j = 0; j < m_numAttribs; ++j) { t1[j] = 0.0; for (int k = 0; k < m_rank; ++k) t1[j] += W[j][k] * h[k]; } // t3 = W'*t1 for (int j = 0; j < m_rank; ++j) { t3[j] = 0.0; for (int k = 0; k < m_numAttribs; ++k) t3[j] += W[k][j] * t1[k]; } // t2 = (t2+eps)./(t3+eps) for (int j = 0; j < m_rank; ++j) t2[j] = (t2[j]+eps) / (t3[j]+eps); break; case 3: // t1 = W*h for (int j = 0; j < m_numAttribs; ++j) { t1[j] = 0.0; for (int k = 0; k < m_rank; ++k) t1[j] += W[j][k] * h[k]; } // t1 = (v+eps)./(t1+eps) for (int j = 0; j < m_numAttribs; ++j) t1[j] = (v[j]+eps) / (t1[j]+eps); // t2 = W'*t1 for (int j = 0; j < m_rank; ++j) { t2[j] = 0.0; for (int k = 0; k < m_numAttribs; ++k) t2[j] += W[k][j] * t1[k]; } // t3 = (sum(W))' for (int j = 0; j < m_rank; ++j) { t3[j] = 0.0; for (int k = 0; k < m_numAttribs; ++k) t3[j] += W[k][j]; } // t2 = t2./t3 for (int j = 0; j < m_rank; ++j) t2[j] /= t3[j]; break; } // h = h.*t2 for (int j = 0; j < m_rank; ++j) h[j] *= t2[j]; } if (m_hasClass) { h[m_rank] = instance.value(instance.classIndex()); } System.err.print("."); if (instance instanceof SparseInstance) { return new SparseInstance(instance.weight(), h); } else { return new Instance(instance.weight(), h); } } /** * Set the format for the transformed data * @return a set of empty Instances (header only) in the new format * @exception Exception if the output format can't be set */ private Instances setOutputFormat() throws Exception { if (m_basis == null) { return null; } FastVector attributes = new FastVector(); // add attribute names for (int i = 1; i <= m_basis[0].length; ++i) { attributes.addElement(new Attribute("enc-" + Integer.toString(i))); } if (m_hasClass) { attributes.addElement(m_trainCopy.classAttribute().copy()); } Instances outputFormat = new Instances(m_trainInstances.relationName()+"->NMF", attributes, 0); // set the class to be the last attribute if necessary if (m_hasClass) { outputFormat.setClassIndex(outputFormat.numAttributes()-1); } return outputFormat; } /** * Main method for testing this class * @param argv should contain the command line arguments to the * evaluator/transformer (see AttributeSelection) */ public static void main(String [] argv) { try { System.out.println(AttributeSelection. SelectAttributes(new MatlabNMF(), argv)); } catch (Exception e) { e.printStackTrace(); System.out.println(e.getMessage()); } } }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -