gridsearch.java

来自「Weka」· Java 代码 · 共 2,258 行 · 第 1/5 页

JAVA
2,258
字号
        /**
     * Looks up the value of one performance measure of this object.
     *
     * @param evaluation	the type of measure to return
     * @return 			the stored value of the requested measure; for
     * 				EVALUATION_COMBINED the sum
     * 				(1 - |CC|) + RRSE + RAE
     * @throws IllegalArgumentException	if the evaluation type is none of the
     * 				EVALUATION_* constants handled here
     */
    public double getPerformance(int evaluation) {
      switch (evaluation) {
        case EVALUATION_CC:
          return m_CC;
        case EVALUATION_RMSE:
          return m_RMSE;
        case EVALUATION_RRSE:
          return m_RRSE;
        case EVALUATION_MAE:
          return m_MAE;
        case EVALUATION_RAE:
          return m_RAE;
        case EVALUATION_COMBINED:
          return (1 - StrictMath.abs(m_CC)) + m_RRSE + m_RAE;
        case EVALUATION_ACC:
          return m_ACC;
        default:
          throw new IllegalArgumentException("Evaluation type '" + evaluation + "' not supported!");
      }
    }

    /**
     * Gives access to the values-pair this performance belongs to.
     *
     * @return the values-pair
     */
    public PointDouble getValues() {
      return m_Values;
    }

    /**
     * Builds a textual description of a single measure of this performance
     * object: the values-pair, the measure's value and the measure's tag.
     *
     * @param evaluation	the type of performance to return
     * @return 			a string representation
     * @throws IllegalArgumentException	if the evaluation type is not supported
     */
    public String toString(int evaluation) {
      PointDouble	point;
      double		measure;
      SelectedTag	tag;

      point   = getValues();
      measure = getPerformance(evaluation);
      tag     = new SelectedTag(evaluation, TAGS_EVALUATION);

      return "Performance (" + point + "): " + measure + " (" + tag + ")";
    }

    /**
     * Builds one tab-separated Gnuplot data row for this performance object.
     *
     * @param evaluation	the type of performance to return
     * @return 			the gnuplot string (x, y, z)
     * @throws IllegalArgumentException	if the evaluation type is not supported
     */
    public String toGnuplot(int evaluation) {
      PointDouble	point;

      point = getValues();

      return point.getX() + "\t" + point.getY() + "\t" + getPerformance(evaluation);
    }

    /**
     * returns a string representation of this performance object
     *
     * @return a string representation
     */
    public
String toString() {      String	result;      int	i;            result = "Performance (" + getValues() + "): ";            for (i = 0; i < TAGS_EVALUATION.length; i++) {	if (i > 0)	  result += ", ";        result +=   getPerformance(TAGS_EVALUATION[i].getID())         	  + " (" + new SelectedTag(TAGS_EVALUATION[i].getID(), TAGS_EVALUATION) + ")";      }            return result;    }  }    /**   * A concrete Comparator for the Performance class.   *    * @see Performance   */  protected class PerformanceComparator    implements Comparator<Performance>, Serializable {        /** for serialization */    private static final long serialVersionUID = 6507592831825393847L;        /** the performance measure to use for comparison      * @see GridSearch#TAGS_EVALUATION */    protected int m_Evaluation;        /**     * initializes the comparator with the given performance measure     *      * @param evaluation	the performance measure to use     * @see GridSearch#TAGS_EVALUATION     */    public PerformanceComparator(int evaluation) {      super();            m_Evaluation = evaluation;    }        /**     * returns the performance measure that's used to compare the objects     *      * @return the performance measure     * @see GridSearch#TAGS_EVALUATION     */    public int getEvaluation() {      return m_Evaluation;    }        /**     * Compares its two arguments for order. Returns a negative integer,      * zero, or a positive integer as the first argument is less than,      * equal to, or greater than the second.     
*      * @param o1 	the first performance     * @param o2 	the second performance     * @return 		the order     */    public int compare(Performance o1, Performance o2) {      int	result;      double	p1;      double	p2;            p1 = o1.getPerformance(getEvaluation());      p2 = o2.getPerformance(getEvaluation());            if (Utils.sm(p1, p2))	result = -1;      else if (Utils.gr(p1, p2))	result = 1;      else	result = 0;	      // only correlation coefficient and accuracy obey to this order, for the      // errors (and the combination of all three), the smaller the number the      // better -> hence invert them      if (    (getEvaluation() != EVALUATION_CC)            && (getEvaluation() != EVALUATION_ACC) )	result = -result;	      return result;    }        /**     * Indicates whether some other object is "equal to" this Comparator.     *      * @param obj	the object to compare with     * @return		true if the same evaluation type is used     */    public boolean equals(Object obj) {      if (!(obj instanceof PerformanceComparator))	throw new IllegalArgumentException("Must be PerformanceComparator!");            return (m_Evaluation == ((PerformanceComparator) obj).m_Evaluation);    }  }    /**   * Generates a 2-dim array for the performances from a grid for a certain    * type. x-min/y-min is in the bottom-left corner, i.e., getTable()[0][0]   * returns the performance for the x-min/y-max pair.   
* <pre>   * x-min     x-max   * |-------------|   *                - y-max   *                |   *                |   *                - y-min   * </pre>   */  protected class PerformanceTable     implements Serializable {        /** for serialization */    private static final long serialVersionUID = 5486491313460338379L;    /** the corresponding grid */    protected Grid m_Grid;        /** the performances */    protected Vector<Performance> m_Performances;        /** the type of performance the table was generated for */    protected int m_Type;        /** the table with the values */    protected double[][] m_Table;        /** the minimum performance */    protected double m_Min;        /** the maximum performance */    protected double m_Max;        /**     * initializes the table     *      * @param grid		the underlying grid     * @param performances	the performances     * @param type		the type of performance     */    public PerformanceTable(Grid grid, Vector<Performance> performances, int type) {      super();            m_Grid         = grid;      m_Type         = type;      m_Performances = performances;            generate();    }        /**     * generates the table     */    protected void generate() {      Performance 	perf;      int 		i;      PointInt 		location;            m_Table = new double[getGrid().height()][getGrid().width()];      m_Min   = 0;      m_Max   = 0;            for (i = 0; i < getPerformances().size(); i++) {	perf     = (Performance) getPerformances().get(i);	location = getGrid().getLocation(perf.getValues());	m_Table[getGrid().height() - (int) location.getY() - 1][(int) location.getX()] = perf.getPerformance(getType());		// determine min/max	if (i == 0) {	  m_Min = perf.getPerformance(m_Type);	  m_Max = m_Min;	}	else {	  if (perf.getPerformance(m_Type) < m_Min)	    m_Min = perf.getPerformance(m_Type);	  if (perf.getPerformance(m_Type) > m_Max)	    m_Max = perf.getPerformance(m_Type);	}      }    }        /**     * returns the 
corresponding grid     *      * @return		the underlying grid     */    public Grid getGrid() {      return m_Grid;    }    /**     * returns the underlying performances     *      * @return		the underlying performances     */    public Vector<Performance> getPerformances() {      return m_Performances;    }        /**     * returns the type of performance     *      * @return		the type of performance     */    public int getType() {      return m_Type;    }        /**     * returns the generated table     *      * @return 		the performance table     * @see		#m_Table     * @see		#generate()     */    public double[][] getTable() {      return m_Table;    }        /**     * the minimum performance     *      * @return		the performance     */    public double getMin() {      return m_Min;    }        /**     * the maximum performance     *      * @return		the performance     */    public double getMax() {      return m_Max;    }        /**     * returns the table as string     *      * @return		the table as string     */    public String toString() {      String	result;      int	i;      int	n;            result =   "Table (" 	       + new SelectedTag(getType(), TAGS_EVALUATION).getSelectedTag().getReadable()                + ") - "               + "X: " + getGrid().getLabelX() + ", Y: " + getGrid().getLabelY()               + ":\n";            for (i = 0; i < getTable().length; i++) {	if (i > 0)	  result += "\n";		for (n = 0; n < getTable()[i].length; n++) {	  if (n > 0)	    result += ",";	  result += getTable()[i][n];	}      }            return result;    }        /**     * returns a string containing a gnuplot script+data file     *      * @return		the data in gnuplot format     */    public String toGnuplot() {      StringBuffer	result;      Tag		type;      int		i;            result = new StringBuffer();      type   = new SelectedTag(getType(), TAGS_EVALUATION).getSelectedTag();            result.append("Gnuplot (" + type.getReadable() + "):\n");      
result.append("# begin 'gridsearch.data'\n");      result.append("# " + type.getReadable() + "\n");      for (i = 0; i < getPerformances().size(); i++)        result.append(getPerformances().get(i).toGnuplot(type.getID()) + "\n");      result.append("# end 'gridsearch.data'\n\n");            result.append("# begin 'gridsearch.plot'\n");      result.append("# " + type.getReadable() + "\n");      result.append("set data style lines\n");      result.append("set contour base\n");      result.append("set surface\n");      result.append("set title '" + m_Data.relationName() + "'\n");      result.append("set xrange [" + getGrid().getMinX() + ":" + getGrid().getMaxX() + "]\n");      result.append("set xlabel 'x (" + getFilter().getClass().getName() + ": " + getXProperty() + ")'\n");      result.append("set yrange [" + getGrid().getMinY() + ":" + getGrid().getMaxY() + "]\n");      result.append("set ylabel 'y - (" + getClassifier().getClass().getName() + ": " + getYProperty() + ")'\n");      result.append("set zrange [" + (getMin() - (getMax() - getMin())*0.1) + ":" + (getMax() + (getMax() - getMin())*0.1) + "]\n");      result.append("set zlabel 'z - " + type.getReadable() + "'\n");      result.append("set dgrid3d " + getGrid().height() + "," + getGrid().width() + ",1\n");      result.append("show contour\n");      result.append("splot 'gridsearch.data'\n");      result.append("pause -1\n");      result.append("# end 'gridsearch.plot'");      return result.toString();    }  }    /**   * Represents a simple cache for performance objects.   
*/
  protected class PerformanceCache
    implements Serializable {

    /** for serialization */
    private static final long serialVersionUID = 5838863230451530252L;

    /**
     * the cache for points in the grid that got calculated; entries are
     * looked up by the ID string built in getID and read back as
     * Performance objects.
     * NOTE(review): raw Hashtable - presumably String keys and Performance
     * values; confirm against the code that fills the cache before adding
     * type parameters.
     */
    protected Hashtable m_Cache = new Hashtable();

    /**
     * returns the ID string for a cache item: the number of folds and the
     * x and y value of the point, separated by tabs
     *
     * @param cv		the number of folds in the cross-validation
     * @param values	the point in the grid
     * @return		the ID string
     */
    protected String getID(int cv, PointDouble values) {
      return cv + "\t" + values.getX() + "\t" + values.getY();
    }

    /**
     * checks whether the point was already calculated once, i.e., whether
     * a lookup for it returns a non-null performance
     *
     * @param cv	the number of folds in the cross-validation
     * @param values	the point in the grid
     * @return		true if the value is already cached
     */
    public boolean isCached(int cv, PointDouble values) {
      return (get(cv, values) != null);
    }

    /**
     * returns a cached performance object, null if not yet in the cache
     *
     * @param cv	the number of folds in the cross-validation
     * @param values	the point in the grid
     * @return		the cached performance item, null if not in cache
     */
    public Performance get(int cv, PointDouble values) {
      return (Performance) m_Cache.get(getID(cv, values));
    }

    /**
     * adds the performance to the cache
     *
     * @param cv	the number of folds in the cross-validation
     * @param p		the performance object to store

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?