📄 lpassigner.java
字号:
while( pairItr.hasNext() ){ InstancePair pair = (InstancePair) pairItr.next(); addPairPenalties(pair, idx, objCoeffs); idx++; } } } /** accumulate penalties associated with a given constraint */ protected void addPairPenalties(InstancePair pair, int idx, double[] objCoeffs) throws Exception { int instance1Idx = pair.first; int instance2Idx = pair.second; Instance instance1 = m_instances.instance(instance1Idx); Instance instance2 = m_instances.instance(instance2Idx); int linkType = ((Integer) m_constraintHash.get(pair)).intValue(); double cost = 0; if (linkType == InstancePair.MUST_LINK) { cost = m_clusterer.getMustLinkWeight(); } else if (linkType == InstancePair.CANNOT_LINK) { cost = m_clusterer.getCannotLinkWeight(); } // if a single metric is used, we don't need to calculate separately for each cluster if (!m_useMultipleMetrics) { // MAJOR KLUDGE. TODO: create penalty(InstancePair) method in MPCKMeans; use both internally and here; // avoid iterating through constraints inside individual calculateConstraintPenalties methods double penalty = 0; // add the penalty for different types of metrics if (m_metric instanceof WeightedDotP) { double sim = m_metric.similarity(instance1, instance2); if (linkType == InstancePair.MUST_LINK) { penalty = -cost * (1 - sim); } else if (linkType == InstancePair.CANNOT_LINK) { penalty = -cost * sim; } } else if (m_metric instanceof KL) { double distance = ((KL) m_metric).distanceJS(instance1, instance2); if (linkType == InstancePair.MUST_LINK) { penalty = cost * distance; } else if (linkType == InstancePair.CANNOT_LINK) { penalty = cost * (2.0 - distance); } } else if (m_metric instanceof WeightedEuclidean || m_metric instanceof WeightedMahalanobis) { double distance = m_metric.distance(instance1, instance2); if (linkType == InstancePair.MUST_LINK) { penalty = cost * distance * distance; } else if (linkType == InstancePair.CANNOT_LINK) { penalty = cost * (m_maxCLDistances[0] * m_maxCLDistances[0] - distance * distance); } } else { throw new Exception("Unknown metric: " + m_metric.getClass().getName()); } // y_m = 0.5 sum_j (y_{mj}) if (linkType == InstancePair.MUST_LINK) { penalty = 0.5 * penalty; } else { // penalty = -0.5 * penalty; } int offset = m_numLabelVars; for (int centroidIdx = 0; centroidIdx < m_numClusters; centroidIdx++) { objCoeffs[offset + idx * m_numClusters + centroidIdx] += penalty; } } else { // MULTIPLE METRICS // KLUDGE - TODO - CURRENTLY WRONG!// for (int centroidIdx1 = 0; centroidIdx1 < m_numClusters; centroidIdx1++) {// for (int centroidIdx2 = 0; centroidIdx2 < m_numClusters; centroidIdx2++) {// double penalty = 0;// if (m_metric instanceof WeightedDotP) {// double sim1 = m_metrics[centroidIdx1].similarity(instance1, instance2);// double sim2 = m_metrics[centroidIdx2].similarity(instance1, instance2);// penalty = 0.5 * cost * (1 - sim2) + 0.5 * cost * (1 - sim1);// } else if (m_metric instanceof KL) {// double penalty1 = ((KL) m_metrics[centroidIdx1]).distanceJS(instance1, instance2);// double penalty2 = ((KL) m_metrics[centroidIdx2]).distanceJS(instance1, instance2);// penalty = 0.5 * cost * (penalty1 + penalty2);// } else if (m_metric instanceof WeightedEuclidean || m_metric instanceof WeightedMahalanobis) {// double distance1 = m_metrics[centroidIdx1].distance(instance1, instance2);// double distance2 = m_metrics[centroidIdx2].distance(instance1, instance2);// penalty = 0.5 * cost * (distance1*distance1 + distance2*distance2);// } else {// throw new Exception("Unknown metric: " + m_metric.getClass().getName());// }// objCoeffs[centroidIdx1 * m_numInstances + instance1Idx] += penalty;// objCoeffs[centroidIdx1 * m_numInstances + instance2Idx] += penalty;// objCoeffs[centroidIdx2 * m_numInstances + instance1Idx] += penalty;// objCoeffs[centroidIdx2 * m_numInstances + instance2Idx] += penalty;// }// } } } /** * Dump data matrix into a file */ protected void dumpData(double[] objCoeffs, double[][] A_eq, double[] b_eq, double[][] A, double[] b) { if (m_engineType == ENGINE_TOMLAB) { dumpDataTomLab(objCoeffs,A_eq, b_eq, A,b); } else { try { File dataFile = File.createTempFile(m_dataFilenameBase, ".m", m_tempDirFile); m_dataFilename = dataFile.getPath(); if (!m_clusterer.getVerbose()) { dataFile.deleteOnExit(); } PrintWriter writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(dataFile))); // dump f writer.print("f = ["); for (int i = 0; i < objCoeffs.length; i++) { writer.print(objCoeffs[i] + "; "); } writer.println("];"); // dump Aeq if (m_engineType != ENGINE_OCTAVE) { writer.print("Aeq = ["); for (int i = 0; i < A_eq.length; i++) { for (int j = 0; j < A_eq[i].length; j++) { writer.print(A_eq[i][j] + ", "); } writer.flush(); writer.println(";"); } writer.println("];"); } else { // for octave, we dump into a separate file... PrintWriter writerAeq = new PrintWriter(new BufferedOutputStream(new FileOutputStream(m_tempDirPath + "Aeq"))); for (int i = 0; i < A_eq.length; i++) { for (int j = 0; j < A_eq[i].length; j++) { writerAeq.print(A_eq[i][j] + " "); } writerAeq.flush(); writerAeq.println(); } writerAeq.close(); } // dump b writer.print("beq = ["); for (int i = 0; i < b_eq.length; i++) { writer.print(b_eq[i] + "; "); } writer.println("];"); // dump A PrintWriter writerA = new PrintWriter(new BufferedOutputStream(new FileOutputStream(m_dataFilename + ".A"))); for (int i = 0; i < A.length; i++) { for (int j = 0; j < A[i].length; j++) { writerA.print(A[i][j] + " "); } writerA.println(); } writerA.flush(); writerA.close(); // dump b writer.print("b = ["); for (int i = 0; i < b.length; i++) { writer.print(b[i] + "; "); } writer.println("];"); writer.close(); } catch (Exception e) { System.err.println("Could not create temporary file \'" + m_dataFilename + "\' for dumping the LP: " + e); } } }/** * Dump data matrix into a file */ protected void dumpDataTomLab(double[] objCoeffs, double[][] A_eq, double[] b_eq, double[][] A, double[] b) { try { File dataFile = File.createTempFile(m_dataFilenameBase, ".m", m_tempDirFile); m_dataFilename = dataFile.getPath(); if (!m_clusterer.getVerbose()) { dataFile.deleteOnExit(); } PrintWriter writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(dataFile))); // dump f writer.print("f = ["); for (int i = 0; i < objCoeffs.length; i++) { writer.print(objCoeffs[i] + "; "); } writer.println("];"); // dump xl and xu writer.println("xl = zeros(" + m_numVars + ",1);"); writer.println("xu = ones(" + m_numVars + ",1);"); // dump bu writer.print("bu = [ones(" + m_numInstances + ",1); "); for (int i = m_numVars; i < b.length; i++) { writer.print(b[i] + "; "); } writer.println("];"); // dump bl writer.println("bl = [ones(" + m_numInstances + ",1); zeros(" + b.length + "-" + m_numVars + ",1)];"); writer.println("bl(" + (m_numInstances+1) + ":" + (m_numInstances + b.length - m_numVars) + ",:)=-Inf;"); writer.close(); // dump A into a separate file PrintWriter writerA = new PrintWriter(new BufferedOutputStream(new FileOutputStream(m_dataFilename + ".A"))); File aFile = new File(m_dataFilename + ".A"); aFile.deleteOnExit(); // first, dump Aeq for (int i = 0; i < A_eq.length; i++) { for (int j = 0; j < A_eq[i].length; j++) { writerA.print(A_eq[i][j] + " "); } writerA.flush(); writerA.println(); } // next, dump constraints from A for (int i = m_numVars; i < A.length; i++) { for (int j = 0; j < A[i].length; j++) { writerA.print(A[i][j] + " "); } writerA.flush(); writerA.println(); } writerA.close(); } catch (Exception e) { System.err.println("Could not create temporary file \'" + m_dataFilename + "\' for dumping the LP: " + e); } } /** Read the solution from the output file of Octave */protected double[][] getSolution() { double[][] probs = new double[m_numLabelVars][1]; try { BufferedReader r = new BufferedReader(new FileReader(m_outFilename)); String s = null; int i = 0; while ((s = r.readLine()) != null && i < m_numLabelVars) { probs[i++][0] = Double.parseDouble(s); } } catch (Exception e) { System.out.println("Problems reading the solution from the engine: " + e); e.printStackTrace(); } File aFile = new File(m_dataFilename + ".A"); aFile.delete(); File dataFile = new File(m_dataFilename); dataFile.delete(); return probs; } /** Create octave m-file * @param filename file where the script is created */ public void prepareEngine() { try{ PrintWriter writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(m_progFilename))); writer.println("cd " + m_tempDirPath + ";"); String dataFilename = Utils.removeSubstring(m_dataFilename, m_tempDirPath); dataFilename = Utils.removeSubstring(dataFilename, ".m"); writer.println(dataFilename + ";"); switch (m_engineType) { case ENGINE_MATLAB: writer.println("A = load(\'" + m_dataFilename + ".A" + "\');"); writer.println("x = linprog(f,A,b,Aeq,beq);"); break; case ENGINE_TOMLAB: writer.println("cd /u/ml/software/tomlab;"); writer.println("startup;"); writer.println("A = load(\'" + m_dataFilename + ".A" + "\');"); writer.println("Prob = lpAssign(f,A,bl,bu,xl,xu,[],'test');"); writer.println("Result = tomRun('pdco', Prob,[],1);"); writer.println("x = Result.x_k;"); break; case ENGINE_OCTAVE: // load A and Aeq stored in auxiliary files writer.println("load A;"); writer.println("load Aeq;"); } File outFile = File.createTempFile(m_outFilenameBase, ".out", m_tempDirFile); m_outFilename = outFile.getPath(); if (!m_clusterer.getVerbose()) { outFile.deleteOnExit(); File outFileDump = new File(m_outFilename + ".dump"); outFileDump.deleteOnExit(); } writer.println("x"); writer.println("save " + m_outFilename + " x -ascii;"); writer.close(); } catch (Exception e) { System.err.println("Could not create script file \'" + m_progFilename + "\': " + e); } } /** Run octave in command line with a given argument * @param inFile file to be input * @param outFile file where results are stored */ public int runEngine() { int exitValue = -1; try { String cmd = ""; if (m_engineType == ENGINE_OCTAVE) { cmd = "octave " + m_progFilename + " > " + m_outFilename; } else if (m_engineType == ENGINE_MATLAB || m_engineType == ENGINE_TOMLAB) { cmd = "matlab -nodesktop -nosplash < " + m_progFilename + " > " + m_outFilename + ".dump"; } System.out.println("Starting to run engine " + m_engineType + cmd); Process proc = Runtime.getRuntime().exec(cmd); System.out.println("Waiting for process ..."); // read the error if (proc != null){ BufferedReader procError = new BufferedReader(new InputStreamReader(proc.getErrorStream())); try { String line; while ((line = procError.readLine()) != null){ System.out.println("ERROR: " + line); } } catch (Exception e) { System.err.println("Problems trapping error stream in debug mode: " + e); e.printStackTrace(); } } // read the output if (proc != null){ BufferedReader procOutput = new BufferedReader(new InputStreamReader(proc.getInputStream())); try { String line; while ((line = procOutput.readLine()) != null){ System.out.println("OUTPUT: " + line); } } catch (Exception e) { System.err.println("Problems trapping output in debug mode: " + e); e.printStackTrace(); } } exitValue = proc.waitFor(); System.out.println("End of running engine, exitValue = " + exitValue); } catch (Exception e) { System.err.println("Problems running engine: " + e); e.printStackTrace(); } return exitValue; } protected double[] calculateMaxDistances(Instance maxCLPoints[][]) throws Exception { double [] maxCLDistances = new double[maxCLPoints.length]; for (int i = 0; i < maxCLDistances.length; i++) { if (m_useMultipleMetrics) { maxCLDistances[i] = m_metrics[i].distance(maxCLPoints[i][0], maxCLPoints[i][1]); } else { maxCLDistances[i] = m_metric.distance(maxCLPoints[0][0], maxCLPoints[0][1]); } } return maxCLDistances; } /** Set the engine type * @param type one of the kernel types */ public void setEngineType(SelectedTag engineType) { if (engineType.getTags() == TAGS_ENGINE_TYPE) { m_engineType = engineType.getSelectedTag().getID(); } } /** Get the engine type * @return engine type */ public SelectedTag getEngineType() { return new SelectedTag(m_engineType, TAGS_ENGINE_TYPE); } public void setOptions (String[] options) throws Exception { // TODO } public Enumeration listOptions () { // TODO return null; } public String [] getOptions () { String[] options = new String[1]; int current = 0; switch (m_engineType) { case ENGINE_JMATLINK: options[current++] = "jmatlink"; break; case ENGINE_OCTAVE: options[current++] = "octave"; break; case ENGINE_MATLAB: options[current++] = "matlab"; break; case ENGINE_TOMLAB: options[current++] = "tomlab"; break; default: options[current++] = "unknown"; } return options; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -