⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 gridsearch.java

📁 代码是一个分类器的实现，其中使用了部分 Weka 的源代码。可以将项目导入 Eclipse 运行。
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
     * @return a string representation     */    public String toString() {      String	result;      int	i;            result = "Performance (" + getValues() + "): ";            for (i = 0; i < TAGS_EVALUATION.length; i++) {	if (i > 0)	  result += ", ";        result +=   getPerformance(TAGS_EVALUATION[i].getID())         	  + " (" + new SelectedTag(TAGS_EVALUATION[i].getID(), TAGS_EVALUATION) + ")";      }            return result;    }  }    /**   * A concrete Comparator for the Performance class.   *    * @see Performance   */  protected class PerformanceComparator    implements Comparator<Performance>, Serializable {        /** for serialization */    private static final long serialVersionUID = 6507592831825393847L;        /** the performance measure to use for comparison      * @see GridSearch#TAGS_EVALUATION */    protected int m_Evaluation;        /**     * initializes the comparator with the given performance measure     *      * @param evaluation	the performance measure to use     * @see GridSearch#TAGS_EVALUATION     */    public PerformanceComparator(int evaluation) {      super();            m_Evaluation = evaluation;    }        /**     * returns the performance measure that's used to compare the objects     *      * @return the performance measure     * @see GridSearch#TAGS_EVALUATION     */    public int getEvaluation() {      return m_Evaluation;    }        /**     * Compares its two arguments for order. Returns a negative integer,      * zero, or a positive integer as the first argument is less than,      * equal to, or greater than the second.     
*      * @param o1 	the first performance     * @param o2 	the second performance     * @return 		the order     */    public int compare(Performance o1, Performance o2) {      int	result;      double	p1;      double	p2;            p1 = o1.getPerformance(getEvaluation());      p2 = o2.getPerformance(getEvaluation());            if (Utils.sm(p1, p2))	result = -1;      else if (Utils.gr(p1, p2))	result = 1;      else	result = 0;	      // only correlation coefficient and accuracy obey to this order, for the      // errors (and the combination of all three), the smaller the number the      // better -> hence invert them      if (    (getEvaluation() != EVALUATION_CC)            && (getEvaluation() != EVALUATION_ACC) )	result = -result;	      return result;    }        /**     * Indicates whether some other object is "equal to" this Comparator.     *      * @param obj	the object to compare with     * @return		true if the same evaluation type is used     */    public boolean equals(Object obj) {      if (!(obj instanceof PerformanceComparator))	throw new IllegalArgumentException("Must be PerformanceComparator!");            return (m_Evaluation == ((PerformanceComparator) obj).m_Evaluation);    }  }    /**   * Generates a 2-dim array for the performances from a grid for a certain    * type. x-min/y-min is in the bottom-left corner, i.e., getTable()[0][0]   * returns the performance for the x-min/y-max pair.   
* <pre>   * x-min     x-max   * |-------------|   *                - y-max   *                |   *                |   *                - y-min   * </pre>   */  protected class PerformanceTable     implements Serializable {        /** for serialization */    private static final long serialVersionUID = 5486491313460338379L;    /** the corresponding grid */    protected Grid m_Grid;        /** the performances */    protected Vector<Performance> m_Performances;        /** the type of performance the table was generated for */    protected int m_Type;        /** the table with the values */    protected double[][] m_Table;        /** the minimum performance */    protected double m_Min;        /** the maximum performance */    protected double m_Max;        /**     * initializes the table     *      * @param grid		the underlying grid     * @param performances	the performances     * @param type		the type of performance     */    public PerformanceTable(Grid grid, Vector<Performance> performances, int type) {      super();            m_Grid         = grid;      m_Type         = type;      m_Performances = performances;            generate();    }        /**     * generates the table     */    protected void generate() {      Performance 	perf;      int 		i;      PointInt 		location;            m_Table = new double[getGrid().height()][getGrid().width()];      m_Min   = 0;      m_Max   = 0;            for (i = 0; i < getPerformances().size(); i++) {	perf     = (Performance) getPerformances().get(i);	location = getGrid().getLocation(perf.getValues());	m_Table[getGrid().height() - (int) location.getY() - 1][(int) location.getX()] = perf.getPerformance(getType());		// determine min/max	if (i == 0) {	  m_Min = perf.getPerformance(m_Type);	  m_Max = m_Min;	}	else {	  if (perf.getPerformance(m_Type) < m_Min)	    m_Min = perf.getPerformance(m_Type);	  if (perf.getPerformance(m_Type) > m_Max)	    m_Max = perf.getPerformance(m_Type);	}      }    }        /**     * returns the 
corresponding grid     *      * @return		the underlying grid     */    public Grid getGrid() {      return m_Grid;    }    /**     * returns the underlying performances     *      * @return		the underlying performances     */    public Vector<Performance> getPerformances() {      return m_Performances;    }        /**     * returns the type of performance     *      * @return		the type of performance     */    public int getType() {      return m_Type;    }        /**     * returns the generated table     *      * @return 		the performance table     * @see		#m_Table     * @see		#generate()     */    public double[][] getTable() {      return m_Table;    }        /**     * the minimum performance     *      * @return		the performance     */    public double getMin() {      return m_Min;    }        /**     * the maximum performance     *      * @return		the performance     */    public double getMax() {      return m_Max;    }        /**     * returns the table as string     *      * @return		the table as string     */    public String toString() {      String	result;      int	i;      int	n;            result =   "Table (" 	       + new SelectedTag(getType(), TAGS_EVALUATION).getSelectedTag().getReadable()                + ") - "               + "X: " + getGrid().getLabelX() + ", Y: " + getGrid().getLabelY()               + ":\n";            for (i = 0; i < getTable().length; i++) {	if (i > 0)	  result += "\n";		for (n = 0; n < getTable()[i].length; n++) {	  if (n > 0)	    result += ",";	  result += getTable()[i][n];	}      }            return result;    }        /**     * returns a string containing a gnuplot script+data file     *      * @return		the data in gnuplot format     */    public String toGnuplot() {      StringBuffer	result;      Tag		type;      int		i;            result = new StringBuffer();      type   = new SelectedTag(getType(), TAGS_EVALUATION).getSelectedTag();            result.append("Gnuplot (" + type.getReadable() + "):\n");      
result.append("# begin 'gridsearch.data'\n");      result.append("# " + type.getReadable() + "\n");      for (i = 0; i < getPerformances().size(); i++)        result.append(getPerformances().get(i).toGnuplot(type.getID()) + "\n");      result.append("# end 'gridsearch.data'\n\n");            result.append("# begin 'gridsearch.plot'\n");      result.append("# " + type.getReadable() + "\n");      result.append("set data style lines\n");      result.append("set contour base\n");      result.append("set surface\n");      result.append("set title '" + m_Data.relationName() + "'\n");      result.append("set xrange [" + getGrid().getMinX() + ":" + getGrid().getMaxX() + "]\n");      result.append("set xlabel 'x (" + getFilter().getClass().getName() + ": " + getXProperty() + ")'\n");      result.append("set yrange [" + getGrid().getMinY() + ":" + getGrid().getMaxY() + "]\n");      result.append("set ylabel 'y - (" + getClassifier().getClass().getName() + ": " + getYProperty() + ")'\n");      result.append("set zrange [" + (getMin() - (getMax() - getMin())*0.1) + ":" + (getMax() + (getMax() - getMin())*0.1) + "]\n");      result.append("set zlabel 'z - " + type.getReadable() + "'\n");      result.append("set dgrid3d " + getGrid().height() + "," + getGrid().width() + ",1\n");      result.append("show contour\n");      result.append("splot 'gridsearch.data'\n");      result.append("pause -1\n");      result.append("# end 'gridsearch.plot'");      return result.toString();    }  }    /**   * Represents a simple cache for performance objects.   
*/
  protected class PerformanceCache
    implements Serializable {

    /** for serialization */
    private static final long serialVersionUID = 5838863230451530252L;

    /** the cache for points in the grid that got calculated, keyed by
     * the string produced by getID(int, PointDouble) */
    protected Hashtable<String, Performance> m_Cache = new Hashtable<String, Performance>();

    /**
     * returns the ID string for a cache item: the number of folds and the
     * x/y coordinates, separated by tabs
     *
     * @param cv		the number of folds in the cross-validation
     * @param values	the point in the grid
     * @return		the ID string
     */
    protected String getID(int cv, PointDouble values) {
      return cv + "\t" + values.getX() + "\t" + values.getY();
    }

    /**
     * checks whether the point was already calculated once
     *
     * @param cv	the number of folds in the cross-validation
     * @param values	the point in the grid
     * @return		true if the value is already cached
     */
    public boolean isCached(int cv, PointDouble values) {
      return (get(cv, values) != null);
    }

    /**
     * returns a cached performance object, null if not yet in the cache
     *
     * @param cv	the number of folds in the cross-validation
     * @param values	the point in the grid
     * @return		the cached performance item, null if not in cache
     */
    public Performance get(int cv, PointDouble values) {
      return m_Cache.get(getID(cv, values));
    }

    /**
     * adds the performance to the cache, replacing any entry stored for the
     * same fold count and grid point
     *
     * @param cv	the number of folds in the cross-validation
     * @param p		the performance object to store
     */
    public void add(int cv, Performance p) {
      m_Cache.put(getID(cv, p.getValues()), p);
    }

    /**
     * returns a string representation of the cache
     *
     * @return		the string representation of the cache
     */
    public String toString() {
      return m_Cache.toString();
    }
  }

  /** for serialization */
  private static final long serialVersionUID = -3034773968581595348L;

  /** evaluation via: Correlation coefficient */
  public static final int EVALUATION_CC = 0;
  /** evaluation via: Root mean squared error */
  public static final int EVALUATION_RMSE = 1;
  /** evaluation via: Root relative squared error */
  public static final int EVALUATION_RRSE = 2;
  /** evaluation via: Mean absolute error */
  public static final int EVALUATION_MAE = 3;
  /** evaluation via: Relative absolute error */
  public static final int EVALUATION_RAE = 4;
  /** evaluation via: Combined = (1-CC) + RRSE + RAE
   * NOTE(review): the tag below describes this as (1-abs(CC)); confirm
   * against the actual computation in Performance. */
  public static final int EVALUATION_COMBINED = 5;
  /** evaluation via: Accuracy */
  public static final int EVALUATION_ACC = 6;
  /** evaluation */
  public static final Tag[] TAGS_EVALUATION = {
    new Tag(EVALUATION_CC, "CC", "Correlation coefficient"),
    new Tag(EVALUATION_RMSE, "RMSE", "Root mean squared error"),
    new Tag(EVALUATION_RRSE, "RRSE", "Root relative squared error"),
    new Tag(EVALUATION_MAE, "MAE", "Mean absolute error"),
    new Tag(EVALUATION_RAE, "RAE", "Relative absolute error"),
    new Tag(EVALUATION_COMBINED, "COMB", "Combined = (1-abs(CC)) + RRSE + RAE"),
    new Tag(EVALUATION_ACC, "ACC", "Accuracy")
  };

  /** row-wise grid traversal */
  public static final int TRAVERSAL_BY_ROW = 0;
  /** column-wise grid traversal */
  public static final int TRAVERSAL_BY_COLUMN = 1;
  /** traversal */
  public static final Tag[] TAGS_TRAVERSAL = {
    new Tag(TRAVERSAL_BY_ROW, "row-wise", "row-wise"),
    new Tag(TRAVERSAL_BY_COLUMN, "column-wise", "column-wise")
  };

  /** the prefix to indicate that the option is for the classifier */
  public final static String PREFIX_CLASSIFIER = "classifier.";

  /** the prefix to indicate that the option is for the filter */
  public final static String PREFIX_FILTER = "filter.";

  /** the Filter */
  protected Filter m_Filter;

  /** the Filter with the best setup */
  protected Filter m_BestFilter;

  /** the Classifier with the best setup */
  protected Classifier m_BestClassifier;

  /** the best values */
  protected PointDouble m_Values = null;

  /** the type of evaluation */
  protected int m_Evaluation = EVALUATION_CC;

  /** the Y option to work on (without the leading dash); a 'classifier.'
   * prefix means the option is set on the classifier, a 'filter.' prefix
   * means it is set on the filter */
  protected String m_Y_Property = PREFIX_CLASSIFIER + "ridge";

  /** the minimum of Y */
  protected double m_Y_Min = -10;

  /** the maximum of Y */
  protected double m_Y_Max = +5;

  /** the step size of Y */
  protected double m_Y_Step = 1;

  /** the base for Y */
  protected double m_Y_Base = 10;

  /**
   * The expression for the Y property. Available parameters for the

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -