📄 database.java
字号:
import java.awt.*;
import java.util.*;

/**
 * Maintains the database of points under classification.
 *
 * <p>Points are kept in insertion order in {@code points}; {@code np} is kept
 * equal to the number of stored points by every method of this class.
 */
class Database {
    /** Side length, in pixels, of the square drawn for each point. */
    final int PtSize = 5;

    /**
     * Factor applied to a point's x value when painting. Presumably x lies in
     * [0, 1] and the drawing area is 600 pixels wide — confirm against callers.
     */
    private static final int DISPLAY_WIDTH = 600;

    /**
     * Factor applied to a point's y value when painting. Presumably y lies in
     * [0, 1] and the drawing area is 300 pixels high — confirm against callers.
     */
    private static final int DISPLAY_HEIGHT = 300;

    /** The stored points, in insertion order (elements are {@code Point}). */
    Vector points;

    /** Number of points currently stored; always equals {@code points.size()}. */
    int np;

    /** Color used by {@link #paint(Graphics)}. */
    Color col;

    // NOTE(review): not read or written inside this class; presumably used by callers.
    boolean lockflag = false;

    /** Generator used by {@link #randomPoints(int)}. */
    PointGenerator pgen;

    /**
     * Creates a database holding the first {@code count} points of another
     * database. The new database also takes over the source's point generator
     * and display color. The point objects themselves are shared, not cloned.
     *
     * @param source the database to copy points from
     * @param count  the number of points to copy from the beginning of {@code source}
     * @throws IllegalArgumentException if {@code count} is negative or larger
     *                                  than {@code source.nPoints()}
     */
    public Database(Database source, int count) {
        if (count < 0 || count > source.nPoints()) {
            throw new IllegalArgumentException(
                    "count must be between 0 and " + source.nPoints() + ", got " + count);
        }
        // Carry over generator and color so paint()/randomPoints() work on the copy.
        pgen = source.pgen;
        col = source.col;
        points = new Vector(count);
        for (int i = 0; i < count; i++) {
            points.add(source.getPoint(i));
        }
        np = count;
    }

    /**
     * Creates an empty database with the given automatic point generator and
     * display color.
     *
     * @param p   the generator used by {@link #randomPoints(int)}
     * @param Col the color used when painting the database
     */
    public Database(PointGenerator p, Color Col) {
        pgen = p;
        col = Col;
        np = 0;
        points = new Vector();
    }

    /**
     * Paints every point as a small filled square centred on its scaled
     * (x, y) position.
     *
     * @param g the Graphics object to draw to
     */
    public void paint(Graphics g) {
        g.setColor(col);
        int half = PtSize / 2;
        for (int i = 0; i < np; i++) {
            Point pt = getPoint(i);
            int x = (int) (pt.xVal() * DISPLAY_WIDTH) - half;
            int y = (int) (pt.yVal() * DISPLAY_HEIGHT) - half;
            g.fillRect(x, y, PtSize, PtSize);
        }
    }

    /**
     * Appends a point to the end of the database.
     *
     * @param p the point to add
     */
    public void push(Point p) {
        points.add(p);
        np++;
    }

    /** Removes all points from the database. */
    public void clearPoints() {
        np = 0;
        points.removeAllElements();
    }

    /**
     * Appends {@code n} points obtained from the point generator.
     *
     * @param n the number of points to generate
     */
    public void randomPoints(int n) {
        for (int i = 0; i < n; i++) {
            push(pgen.RandomPoint());
        }
    }

    /**
     * Returns the number of points currently in the database.
     *
     * @return the number of stored points (0 when empty)
     */
    public int nPoints() {
        return np;
    }

    /**
     * Returns the point at the given position.
     *
     * @param i the index of the required point
     * @return the requested point
     * @throws ArrayIndexOutOfBoundsException if {@code i} is out of range
     */
    public Point getPoint(int i) {
        return (Point) points.elementAt(i);
    }

    /**
     * Returns the number of attributes (dimensions) of each point.
     *
     * @return the value of {@code Point.nAttribs()}
     */
    public int nAttribs() {
        return Point.nAttribs();
    }

    /**
     * Creates the test set for one fold of a cross-validation on this database.
     * Folds are contiguous; the first {@code nPoints() % numFolds} folds hold
     * one extra point.
     *
     * @param numFolds the number of folds; must be at least 2 and at most {@code nPoints()}
     * @param numFold  0 for the first fold, 1 for the second, ... up to {@code numFolds - 1}
     * @return a new database containing only the points of fold {@code numFold}
     * @throws Exception (an {@link IllegalArgumentException}) if the arguments are invalid
     */
    public Database testCV(int numFolds, int numFold) throws Exception {
        checkFoldArguments(numFolds, numFold);
        int numInstForFold = foldSize(numFolds, numFold);
        int first = foldStart(numFolds, numFold);
        // Must start empty: the (Database, int) constructor would pre-fill the
        // result with the leading records of this database.
        Database test = new Database(pgen, col);
        copyInstances(first, test, numInstForFold);
        return test;
    }

    /**
     * Creates the training set for one fold of a cross-validation on this
     * database, i.e. every point except those in fold {@code numFold}, in
     * their original order.
     *
     * @param numFolds the number of folds; must be at least 2 and at most {@code nPoints()}
     * @param numFold  0 for the first fold, 1 for the second, ... up to {@code numFolds - 1}
     * @return a new database containing all points outside fold {@code numFold}
     * @throws Exception (an {@link IllegalArgumentException}) if the arguments are invalid
     */
    public Database trainCV(int numFolds, int numFold) throws Exception {
        checkFoldArguments(numFolds, numFold);
        int numInstForFold = foldSize(numFolds, numFold);
        int first = foldStart(numFolds, numFold);
        // Must start empty: the (Database, int) constructor would pre-fill the
        // result with the leading records of this database.
        Database train = new Database(pgen, col);
        copyInstances(0, train, first);
        copyInstances(first + numInstForFold, train, np - first - numInstForFold);
        return train;
    }

    /** Validates the cross-validation arguments shared by testCV and trainCV. */
    private void checkFoldArguments(int numFolds, int numFold) {
        if (numFolds < 2) {
            throw new IllegalArgumentException("Number of folds must be at least 2!");
        }
        if (numFolds > np) {
            throw new IllegalArgumentException("Can't have more folds than instances!");
        }
        if (numFold < 0 || numFold >= numFolds) {
            throw new IllegalArgumentException(
                    "Fold index must be between 0 and " + (numFolds - 1) + ", got " + numFold);
        }
    }

    /** Returns the size of fold numFold; the first np % numFolds folds get one extra point. */
    private int foldSize(int numFolds, int numFold) {
        int size = np / numFolds;
        if (numFold < np % numFolds) {
            size++;
        }
        return size;
    }

    /** Returns the index of the first point belonging to fold numFold. */
    private int foldStart(int numFolds, int numFold) {
        // Each earlier fold holds np / numFolds points, plus one extra for every
        // earlier fold that is among the first np % numFolds folds.
        return numFold * (np / numFolds) + Math.min(numFold, np % numFolds);
    }

    /** Appends {@code count} points, starting at index {@code start}, to {@code toDb}. */
    private void copyInstances(int start, Database toDb, int count) {
        for (int i = 0; i < count; i++) {
            toDb.push(this.getPoint(start + i));
        }
    }

    /**
     * Randomly permutes the order of the points in place, using a uniform
     * Fisher-Yates shuffle driven by the supplied generator.
     *
     * @param random the source of randomness
     */
    public final void randomize(Random random) {
        Collections.shuffle(points, random);
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -