📄 gridsearch.java
字号:
* @return a string representation */ public String toString() { String result; int i; result = "Performance (" + getValues() + "): "; for (i = 0; i < TAGS_EVALUATION.length; i++) { if (i > 0) result += ", "; result += getPerformance(TAGS_EVALUATION[i].getID()) + " (" + new SelectedTag(TAGS_EVALUATION[i].getID(), TAGS_EVALUATION) + ")"; } return result; } } /** * A concrete Comparator for the Performance class. * * @see Performance */ protected class PerformanceComparator implements Comparator<Performance>, Serializable { /** for serialization */ private static final long serialVersionUID = 6507592831825393847L; /** the performance measure to use for comparison * @see GridSearch#TAGS_EVALUATION */ protected int m_Evaluation; /** * initializes the comparator with the given performance measure * * @param evaluation the performance measure to use * @see GridSearch#TAGS_EVALUATION */ public PerformanceComparator(int evaluation) { super(); m_Evaluation = evaluation; } /** * returns the performance measure that's used to compare the objects * * @return the performance measure * @see GridSearch#TAGS_EVALUATION */ public int getEvaluation() { return m_Evaluation; } /** * Compares its two arguments for order. Returns a negative integer, * zero, or a positive integer as the first argument is less than, * equal to, or greater than the second. 
* * @param o1 the first performance * @param o2 the second performance * @return the order */ public int compare(Performance o1, Performance o2) { int result; double p1; double p2; p1 = o1.getPerformance(getEvaluation()); p2 = o2.getPerformance(getEvaluation()); if (Utils.sm(p1, p2)) result = -1; else if (Utils.gr(p1, p2)) result = 1; else result = 0; // only correlation coefficient and accuracy obey to this order, for the // errors (and the combination of all three), the smaller the number the // better -> hence invert them if ( (getEvaluation() != EVALUATION_CC) && (getEvaluation() != EVALUATION_ACC) ) result = -result; return result; } /** * Indicates whether some other object is "equal to" this Comparator. * * @param obj the object to compare with * @return true if the same evaluation type is used */ public boolean equals(Object obj) { if (!(obj instanceof PerformanceComparator)) throw new IllegalArgumentException("Must be PerformanceComparator!"); return (m_Evaluation == ((PerformanceComparator) obj).m_Evaluation); } } /** * Generates a 2-dim array for the performances from a grid for a certain * type. x-min/y-min is in the bottom-left corner, i.e., getTable()[0][0] * returns the performance for the x-min/y-max pair. 
* <pre>
   *   x-min         x-max
   *   |-------------|
   *                 - y-max
   *                 |
   *                 |
   *                 - y-min
   * </pre>
   */
  protected class PerformanceTable
    implements Serializable {

    /** for serialization */
    private static final long serialVersionUID = 5486491313460338379L;

    /** the corresponding grid */
    protected Grid m_Grid;

    /** the performances */
    protected Vector<Performance> m_Performances;

    /** the type of performance the table was generated for */
    protected int m_Type;

    /** the table with the values */
    protected double[][] m_Table;

    /** the minimum performance */
    protected double m_Min;

    /** the maximum performance */
    protected double m_Max;

    /**
     * initializes the table
     *
     * @param grid		the underlying grid
     * @param performances	the performances
     * @param type		the type of performance
     */
    public PerformanceTable(Grid grid, Vector<Performance> performances, int type) {
      super();

      m_Grid         = grid;
      m_Type         = type;
      m_Performances = performances;

      generate();
    }

    /**
     * generates the table and determines the minimum/maximum performance;
     * both stay 0 if there are no performances
     */
    protected void generate() {
      Performance	perf;
      PointInt		location;
      double		value;
      int		i;

      m_Table = new double[getGrid().height()][getGrid().width()];
      m_Min   = 0;
      m_Max   = 0;

      for (i = 0; i < getPerformances().size(); i++) {
        perf     = getPerformances().get(i);
        location = getGrid().getLocation(perf.getValues());
        value    = perf.getPerformance(getType());

        // rows are stored top-down (row 0 = y-max), hence the flipped y index
        m_Table[getGrid().height() - (int) location.getY() - 1][(int) location.getX()] = value;

        // determine min/max
        if (i == 0) {
          m_Min = value;
          m_Max = value;
        }
        else {
          if (value < m_Min)
            m_Min = value;
          if (value > m_Max)
            m_Max = value;
        }
      }
    }

    /**
     * returns the corresponding grid
     *
     * @return		the underlying grid
     */
    public Grid getGrid() {
      return m_Grid;
    }

    /**
     * returns the underlying performances
     *
     * @return		the underlying performances
     */
    public Vector<Performance> getPerformances() {
      return m_Performances;
    }

    /**
     * returns the type of performance
     *
     * @return		the type of performance
     */
    public int getType() {
      return m_Type;
    }

    /**
     * returns the generated table
     *
     * @return 		the performance table
     * @see		#m_Table
     * @see		#generate()
     */
    public double[][] getTable() {
      return m_Table;
    }

    /**
     * the minimum performance
     *
     * @return		the performance
     */
    public double getMin() {
      return m_Min;
    }

    /**
     * the maximum performance
     *
     * @return		the performance
     */
    public double getMax() {
      return m_Max;
    }

    /**
     * returns the table as string: a header line followed by one
     * comma-separated line per table row
     *
     * @return		the table as string
     */
    public String toString() {
      StringBuilder	result;
      double[][]	table;
      int		i;
      int		n;

      table  = getTable();
      result = new StringBuilder();
      result.append(
          "Table ("
          + new SelectedTag(getType(), TAGS_EVALUATION).getSelectedTag().getReadable()
          + ") - "
          + "X: " + getGrid().getLabelX() + ", Y: " + getGrid().getLabelY()
          + ":\n");

      for (i = 0; i < table.length; i++) {
        if (i > 0)
          result.append("\n");

        for (n = 0; n < table[i].length; n++) {
          if (n > 0)
            result.append(",");
          result.append(table[i][n]);
        }
      }

      return result.toString();
    }

    /**
     * returns a string containing a gnuplot script+data file
     *
     * @return		the data in gnuplot format
     */
    public String toGnuplot() {
      StringBuilder	result;
      Tag		type;
      double		margin;
      int		i;

      result = new StringBuilder();
      type   = new SelectedTag(getType(), TAGS_EVALUATION).getSelectedTag();

      // 10% padding around the observed z values; if all performances are
      // identical the range would be empty, which gnuplot cannot plot, so a
      // non-zero fallback margin is used in that case
      margin = (getMax() - getMin()) * 0.1;
      if (margin == 0)
        margin = (getMax() == 0) ? 1.0 : Math.abs(getMax()) * 0.1;

      result.append("Gnuplot (" + type.getReadable() + "):\n");
      result.append("# begin 'gridsearch.data'\n");
      result.append("# " + type.getReadable() + "\n");
      for (i = 0; i < getPerformances().size(); i++)
        result.append(getPerformances().get(i).toGnuplot(type.getID()) + "\n");
      result.append("# end 'gridsearch.data'\n\n");

      result.append("# begin 'gridsearch.plot'\n");
      result.append("# " + type.getReadable() + "\n");
      // "set data style" is obsolete gnuplot syntax; "set style data" is the supported form
      result.append("set style data lines\n");
      result.append("set contour base\n");
      result.append("set surface\n");
      result.append("set title '" + m_Data.relationName() + "'\n");
      result.append("set xrange [" + getGrid().getMinX() + ":" + getGrid().getMaxX() + "]\n");
      result.append("set xlabel 'x (" + getFilter().getClass().getName() + ": " + getXProperty() + ")'\n");
      result.append("set yrange [" + getGrid().getMinY() + ":" + getGrid().getMaxY() + "]\n");
      result.append("set ylabel 'y - (" + getClassifier().getClass().getName() + ": " + getYProperty() + ")'\n");
      result.append("set zrange [" + (getMin() - margin) + ":" + (getMax() + margin) + "]\n");
      result.append("set zlabel 'z - " + type.getReadable() + "'\n");
      result.append("set dgrid3d " + getGrid().height() + "," + getGrid().width() + ",1\n");
      result.append("show contour\n");
      result.append("splot 'gridsearch.data'\n");
      result.append("pause -1\n");
      result.append("# end 'gridsearch.plot'");

      return result.toString();
    }
  }

  /**
   * Represents a simple cache for performance objects, keyed by the number
   * of cross-validation folds and the point in the grid.
   */
  protected class PerformanceCache
    implements Serializable {

    /** for serialization */
    private static final long serialVersionUID = 5838863230451530252L;

    /** the cache for points in the grid that got calculated */
    protected Hashtable<String, Performance> m_Cache = new Hashtable<String, Performance>();

    /**
     * returns the ID string for a cache item
     *
     * @param cv		the number of folds in the cross-validation
     * @param values	the point in the grid
     * @return		the ID string
     */
    protected String getID(int cv, PointDouble values) {
      return cv + "\t" + values.getX() + "\t" + values.getY();
    }

    /**
     * checks whether the point was already calculated once
     *
     * @param cv		the number of folds in the cross-validation
     * @param values	the point in the grid
     * @return		true if the value is already cached
     */
    public boolean isCached(int cv, PointDouble values) {
      return (get(cv, values) != null);
    }

    /**
     * returns a cached performance object, null if not yet in the cache
     *
     * @param cv		the number of folds in the cross-validation
     * @param values	the point in the grid
     * @return		the cached performance item, null if not in cache
     */
    public Performance get(int cv, PointDouble values) {
      return m_Cache.get(getID(cv, values));
    }

    /**
     * adds the performance to the cache, replacing any previous entry for
     * the same folds/point
     *
     * @param cv		the number of folds in the cross-validation
     * @param p		the performance object to store
     */
    public void add(int cv, Performance p) {
      m_Cache.put(getID(cv, p.getValues()), p);
    }

    /**
     * returns a string representation of the cache
     *
     * @return		the string representation of the cache
     */
    public String toString() {
      return m_Cache.toString();
    }
  }

  /** for serialization */
  private static final long serialVersionUID = -3034773968581595348L;

  /** evaluation via: Correlation coefficient */
  public static final int EVALUATION_CC = 0;
  /** evaluation via: Root mean squared error */
  public static final int EVALUATION_RMSE = 1;
  /** evaluation via: Root relative squared error */
  public static final int EVALUATION_RRSE = 2;
  /** evaluation via: Mean absolute error */
  public static final int EVALUATION_MAE = 3;
  /** evaluation via: Relative absolute error */
  public static final int EVALUATION_RAE = 4;
  /** evaluation via: Combined = (1-CC) + RRSE + RAE
   * NOTE(review): the COMB tag below describes this as (1-abs(CC)); confirm
   * against the evaluation code which of the two is actually computed. */
  public static final int EVALUATION_COMBINED = 5;
  /** evaluation via: Accuracy */
  public static final int EVALUATION_ACC = 6;
  /** evaluation */
  public static final Tag[] TAGS_EVALUATION = {
    new Tag(EVALUATION_CC, "CC", "Correlation coefficient"),
    new Tag(EVALUATION_RMSE, "RMSE", "Root mean squared error"),
    new Tag(EVALUATION_RRSE, "RRSE", "Root relative squared error"),
    new Tag(EVALUATION_MAE, "MAE", "Mean absolute error"),
    new Tag(EVALUATION_RAE, "RAE", "Relative absolute error"),
    new Tag(EVALUATION_COMBINED, "COMB", "Combined = (1-abs(CC)) + RRSE + RAE"),
    new Tag(EVALUATION_ACC, "ACC", "Accuracy")
  };

  /** row-wise grid traversal */
  public static final int TRAVERSAL_BY_ROW = 0;
  /** column-wise grid traversal */
  public static final int TRAVERSAL_BY_COLUMN = 1;
  /** traversal */
  public static final Tag[] TAGS_TRAVERSAL = {
    new Tag(TRAVERSAL_BY_ROW, "row-wise", "row-wise"),
    new Tag(TRAVERSAL_BY_COLUMN, "column-wise", "column-wise")
  };

  /** the prefix to indicate that the option is for the classifier */
  public final static String PREFIX_CLASSIFIER = "classifier.";

  /** the prefix to indicate that the option is for the filter */
  public final static String PREFIX_FILTER = "filter.";
/** the filter; options prefixed with 'filter.' (see PREFIX_FILTER) are set on it */
  protected Filter m_Filter;

  /** the Filter with the best setup found by the search */
  protected Filter m_BestFilter;

  /** the Classifier with the best setup found by the search */
  protected Classifier m_BestClassifier;

  /** the best values, i.e., the x/y point in the grid of the best setup;
   * null until determined */
  protected PointDouble m_Values = null;

  /** the type of evaluation, one of the EVALUATION_* constants
   * (see TAGS_EVALUATION); defaults to EVALUATION_CC */
  protected int m_Evaluation = EVALUATION_CC;

  /** the option the Y axis works on, given without the leading dash. A
   * 'classifier.' prefix means the option is set on the classifier, a
   * 'filter.' prefix means it is set on the filter. Defaults to the
   * classifier's "ridge" option. */
  protected String m_Y_Property = PREFIX_CLASSIFIER + "ridge";

  /** the minimum of Y */
  protected double m_Y_Min = -10;

  /** the maximum of Y */
  protected double m_Y_Max = +5;

  /** the step size of Y */
  protected double m_Y_Step = 1;

  /** the base for Y (NOTE(review): presumably consumed by the Y expression
   * documented below, e.g. as pow(BASE,I) - confirm) */
  protected double m_Y_Base = 10;

  /**
   * The expression for the Y property. Available parameters for the
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -