gridsearch.java
来自「Weka」· Java 代码 · 共 2,258 行 · 第 1/5 页
JAVA
2,258 行
/**
 * Returns the value of the requested performance measure.
 *
 * @param evaluation the type of measure to return (one of the EVALUATION_*
 *                   constants)
 * @return the performance measure
 * @throws IllegalArgumentException if the evaluation type is not supported
 */
public double getPerformance(int evaluation) {
  switch (evaluation) {
    case EVALUATION_CC:
      return m_CC;
    case EVALUATION_RMSE:
      return m_RMSE;
    case EVALUATION_RRSE:
      return m_RRSE;
    case EVALUATION_MAE:
      return m_MAE;
    case EVALUATION_RAE:
      return m_RAE;
    case EVALUATION_COMBINED:
      // combined measure: (1 - |CC|) + RRSE + RAE
      return (1 - StrictMath.abs(m_CC)) + m_RRSE + m_RAE;
    case EVALUATION_ACC:
      return m_ACC;
    default:
      throw new IllegalArgumentException(
          "Evaluation type '" + evaluation + "' not supported!");
  }
}

/**
 * Returns the grid point (x/y values-pair) this performance belongs to.
 *
 * @return the values-pair
 */
public PointDouble getValues() {
  return m_Values;
}

/**
 * Returns a one-line description of a single performance measure of this
 * object, including the grid point and the readable measure tag.
 *
 * @param evaluation the type of performance to describe
 * @return a string representation
 */
public String toString(int evaluation) {
  return "Performance (" + getValues() + "): "
      + getPerformance(evaluation)
      + " (" + new SelectedTag(evaluation, TAGS_EVALUATION) + ")";
}

/**
 * Returns one tab-separated Gnuplot data row for this performance.
 *
 * @param evaluation the type of performance to output as z value
 * @return the gnuplot row "x TAB y TAB z"
 */
public String toGnuplot(int evaluation) {
  PointDouble point = getValues();
  return point.getX() + "\t" + point.getY() + "\t" + getPerformance(evaluation);
}

/**
 * Returns a description listing every supported performance measure of
 * this object, separated by ", ".
 *
 * @return a string representation
 */
public String toString() {
  StringBuilder text = new StringBuilder("Performance (" + getValues() + "): ");
  for (int idx = 0; idx < TAGS_EVALUATION.length; idx++) {
    int id = TAGS_EVALUATION[idx].getID();
    if (idx > 0)
      text.append(", ");
    text.append(getPerformance(id))
        .append(" (")
        .append(new SelectedTag(id, TAGS_EVALUATION))
        .append(")");
  }
  return text.toString();
}
}

/**
 * A concrete Comparator for the
Performance class. * * @see Performance */ protected class PerformanceComparator implements Comparator<Performance>, Serializable { /** for serialization */ private static final long serialVersionUID = 6507592831825393847L; /** the performance measure to use for comparison * @see GridSearch#TAGS_EVALUATION */ protected int m_Evaluation; /** * initializes the comparator with the given performance measure * * @param evaluation the performance measure to use * @see GridSearch#TAGS_EVALUATION */ public PerformanceComparator(int evaluation) { super(); m_Evaluation = evaluation; } /** * returns the performance measure that's used to compare the objects * * @return the performance measure * @see GridSearch#TAGS_EVALUATION */ public int getEvaluation() { return m_Evaluation; } /** * Compares its two arguments for order. Returns a negative integer, * zero, or a positive integer as the first argument is less than, * equal to, or greater than the second. * * @param o1 the first performance * @param o2 the second performance * @return the order */ public int compare(Performance o1, Performance o2) { int result; double p1; double p2; p1 = o1.getPerformance(getEvaluation()); p2 = o2.getPerformance(getEvaluation()); if (Utils.sm(p1, p2)) result = -1; else if (Utils.gr(p1, p2)) result = 1; else result = 0; // only correlation coefficient and accuracy obey to this order, for the // errors (and the combination of all three), the smaller the number the // better -> hence invert them if ( (getEvaluation() != EVALUATION_CC) && (getEvaluation() != EVALUATION_ACC) ) result = -result; return result; } /** * Indicates whether some other object is "equal to" this Comparator. 
* * @param obj the object to compare with * @return true if the same evaluation type is used */ public boolean equals(Object obj) { if (!(obj instanceof PerformanceComparator)) throw new IllegalArgumentException("Must be PerformanceComparator!"); return (m_Evaluation == ((PerformanceComparator) obj).m_Evaluation); } } /** * Generates a 2-dim array for the performances from a grid for a certain * type. x-min/y-min is in the bottom-left corner, i.e., getTable()[0][0] * returns the performance for the x-min/y-max pair. * <pre> * x-min x-max * |-------------| * - y-max * | * | * - y-min * </pre> */ protected class PerformanceTable implements Serializable { /** for serialization */ private static final long serialVersionUID = 5486491313460338379L; /** the corresponding grid */ protected Grid m_Grid; /** the performances */ protected Vector<Performance> m_Performances; /** the type of performance the table was generated for */ protected int m_Type; /** the table with the values */ protected double[][] m_Table; /** the minimum performance */ protected double m_Min; /** the maximum performance */ protected double m_Max; /** * initializes the table * * @param grid the underlying grid * @param performances the performances * @param type the type of performance */ public PerformanceTable(Grid grid, Vector<Performance> performances, int type) { super(); m_Grid = grid; m_Type = type; m_Performances = performances; generate(); } /** * generates the table */ protected void generate() { Performance perf; int i; PointInt location; m_Table = new double[getGrid().height()][getGrid().width()]; m_Min = 0; m_Max = 0; for (i = 0; i < getPerformances().size(); i++) { perf = (Performance) getPerformances().get(i); location = getGrid().getLocation(perf.getValues()); m_Table[getGrid().height() - (int) location.getY() - 1][(int) location.getX()] = perf.getPerformance(getType()); // determine min/max if (i == 0) { m_Min = perf.getPerformance(m_Type); m_Max = m_Min; } else { if 
(perf.getPerformance(m_Type) < m_Min) m_Min = perf.getPerformance(m_Type); if (perf.getPerformance(m_Type) > m_Max) m_Max = perf.getPerformance(m_Type); } } } /** * returns the corresponding grid * * @return the underlying grid */ public Grid getGrid() { return m_Grid; } /** * returns the underlying performances * * @return the underlying performances */ public Vector<Performance> getPerformances() { return m_Performances; } /** * returns the type of performance * * @return the type of performance */ public int getType() { return m_Type; } /** * returns the generated table * * @return the performance table * @see #m_Table * @see #generate() */ public double[][] getTable() { return m_Table; } /** * the minimum performance * * @return the performance */ public double getMin() { return m_Min; } /** * the maximum performance * * @return the performance */ public double getMax() { return m_Max; } /** * returns the table as string * * @return the table as string */ public String toString() { String result; int i; int n; result = "Table (" + new SelectedTag(getType(), TAGS_EVALUATION).getSelectedTag().getReadable() + ") - " + "X: " + getGrid().getLabelX() + ", Y: " + getGrid().getLabelY() + ":\n"; for (i = 0; i < getTable().length; i++) { if (i > 0) result += "\n"; for (n = 0; n < getTable()[i].length; n++) { if (n > 0) result += ","; result += getTable()[i][n]; } } return result; } /** * returns a string containing a gnuplot script+data file * * @return the data in gnuplot format */ public String toGnuplot() { StringBuffer result; Tag type; int i; result = new StringBuffer(); type = new SelectedTag(getType(), TAGS_EVALUATION).getSelectedTag(); result.append("Gnuplot (" + type.getReadable() + "):\n"); result.append("# begin 'gridsearch.data'\n"); result.append("# " + type.getReadable() + "\n"); for (i = 0; i < getPerformances().size(); i++) result.append(getPerformances().get(i).toGnuplot(type.getID()) + "\n"); result.append("# end 'gridsearch.data'\n\n"); 
result.append("# begin 'gridsearch.plot'\n"); result.append("# " + type.getReadable() + "\n"); result.append("set data style lines\n"); result.append("set contour base\n"); result.append("set surface\n"); result.append("set title '" + m_Data.relationName() + "'\n"); result.append("set xrange [" + getGrid().getMinX() + ":" + getGrid().getMaxX() + "]\n"); result.append("set xlabel 'x (" + getFilter().getClass().getName() + ": " + getXProperty() + ")'\n"); result.append("set yrange [" + getGrid().getMinY() + ":" + getGrid().getMaxY() + "]\n"); result.append("set ylabel 'y - (" + getClassifier().getClass().getName() + ": " + getYProperty() + ")'\n"); result.append("set zrange [" + (getMin() - (getMax() - getMin())*0.1) + ":" + (getMax() + (getMax() - getMin())*0.1) + "]\n"); result.append("set zlabel 'z - " + type.getReadable() + "'\n"); result.append("set dgrid3d " + getGrid().height() + "," + getGrid().width() + ",1\n"); result.append("show contour\n"); result.append("splot 'gridsearch.data'\n"); result.append("pause -1\n"); result.append("# end 'gridsearch.plot'"); return result.toString(); } } /** * Represents a simple cache for performance objects. 
*/ protected class PerformanceCache implements Serializable { /** for serialization */ private static final long serialVersionUID = 5838863230451530252L; /** the cache for points in the grid that got calculated */ protected Hashtable m_Cache = new Hashtable(); /** * returns the ID string for a cache item * * @param cv the number of folds in the cross-validation * @param values the point in the grid * @return the ID string */ protected String getID(int cv, PointDouble values) { return cv + "\t" + values.getX() + "\t" + values.getY(); } /** * checks whether the point was already calculated ones * * @param cv the number of folds in the cross-validation * @param values the point in the grid * @return true if the value is already cached */ public boolean isCached(int cv, PointDouble values) { return (get(cv, values) != null); } /** * returns a cached performance object, null if not yet in the cache * * @param cv the number of folds in the cross-validation * @param values the point in the grid * @return the cached performance item, null if not in cache */ public Performance get(int cv, PointDouble values) { return (Performance) m_Cache.get(getID(cv, values)); } /** * adds the performance to the cache * * @param cv the number of folds in the cross-validation * @param p the performance object to store
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?