📄 classifierpanel.java
字号:
/** * The evaluation object that will eventually contain all of the * information gathered during cross-validation. */ private Evaluation evaluation; /** Holds points for a graph used in displaying the results */ private FastVector plotShape; /** Holds points for a graph used in displaying the results */ private FastVector plotSize; /** Holds predictions calculated during cross-validation */ private FastVector predictions; /** * Holds the names of other computers that contributed to the * parallelization */ private StringBuffer otherComputers; /** * Initializes the GUIDistributedClient object with all of the * necessary variables. * * @param classifier the classifier used for cross-validation * @param eval the object where the final results will be stored * @param predictions a fastvector to add the prediction to * @param predInstances a set of plottable instances * @param inst the data the cross-validation will be done on * @param plotShape additional plotting information (shape) * @param plotSize additional plotting information (size) * @param otherComputers holds the names of other computers that * contributed to the parallelization */ public GUIDistributedClient(Classifier classifier, Evaluation eval, FastVector predictions, Instances predInstances, Instances inst, FastVector plotShape, FastVector plotSize, StringBuffer otherComputers) { this.numFolds = Integer.parseInt(m_CVText.getText()); this.foldsCompleted = 0; this.classifier = classifier; this.evaluation = eval; this.predInstances = predInstances; this.inst = inst; this.predictions = predictions; this.plotShape = plotShape; this.plotSize = plotSize; this.otherComputers = otherComputers; port = -1; computers = new LinkedList(); status = new int[numFolds]; lastIndexSent = numFolds - 1; for(int i = 0; i < numFolds; i++) status[i] = Status.NOT_DONE; } /** * Determines which fold should start to be calculated next. * A Round-Robin algorithm is used to try and maximize efficiency. * * @return the index of the fold that should be started next or * -1 if all of the folds have already been calculated */ private int determineIndex() { synchronized(status) { for(int i = 0; i < numFolds; i++) { lastIndexSent = (lastIndexSent + 1) % numFolds; if(status[lastIndexSent] != Status.DONE) { return lastIndexSent; } } } return -1; } /** * Initialize the client by finding out what port to connect on * and what computers to connect to. This information is found * inside of ~/.weka-parallel */ public void initialize() { String tempString; Integer tempInteger; File inFile; try { // Get the location of the config file String configLocation = System.getProperty("user.home"); if(System.getProperty("os.name").charAt(0) == 'W') { configLocation = configLocation.concat("\\.weka-parallel"); } else { configLocation = configLocation.concat("/.weka-parallel"); } inFile = new File(configLocation); // Places a wrapper around the input stream from the file BufferedReader fileInStream = new BufferedReader(new FileReader(inFile)); // Finds the port number from the config file tempString = fileInStream.readLine().trim(); tempString = tempString.substring(5, tempString.length()).trim(); tempInteger = new Integer(tempString); port = tempInteger.intValue(); // Finds the list of server names tempString = fileInStream.readLine().trim(); while(tempString != null) { computers.add(tempString); tempString = fileInStream.readLine().trim(); } } catch (Exception e) { e.printStackTrace(); return; } } /** * Starts the client by starting up one thread for each server * and one thread for the client to do calculations. */ public void start() { ConnectionService [] ss = new ConnectionService[computers.size()]; ClientSideComputations cs = new ClientSideComputations(); // Create a thread for each server for(int i = 0; i < computers.size(); i++) { ss[i] = new ConnectionService((String)computers.get(i)); ss[i].start(); } // Start the client-side thread cs.start(); // Have the main thread wait until one of the other threads // signals that all of the folds have been calculated try { synchronized(evaluation) { evaluation.wait(); } } catch (InterruptedException e) { e.printStackTrace(); } } /** * Performs part of the cross-fold validation on the computer running * Weka. This yields a speed increase since this computer will * be calculating data at the same time the other computers are doing * so. * * @author Dave Musicant (dmusican@mathcs.carleton.edu) * @author Sebastian Celis (celiss@mathcs.carleton.edu) */ private class ClientSideComputations extends Thread { /** * This is the actual thread which performs the necessary * computations. */ public void run() { int index; FastVector hv = new FastVector(); Attribute predictedClass; Attribute classAt = inst.attribute(inst.classIndex()); FastVector tempPredictions = null; FastVector tempPlotShape; FastVector tempPlotSize; Instances tempPredInstances; try { // Set up the necessary variables if (inst.classAttribute().isNominal() && classifier instanceof DistributionClassifier) { tempPredictions = new FastVector(); } if (classAt.isNominal()) { FastVector attVals = new FastVector(); for (int i = 0; i < classAt.numValues(); i++) { attVals.addElement(classAt.value(i)); } predictedClass = new Attribute("predicted"+classAt.name(), attVals); } else { predictedClass = new Attribute("predicted"+classAt.name()); } for (int i = 0; i < inst.numAttributes(); i++) { if (i == inst.classIndex()) { hv.addElement(predictedClass); } hv.addElement(inst.attribute(i).copy()); } // Determine which section we should train on index = determineIndex(); while(index != -1) { // Run the fold described by index // Set up the necessary variables. Evaluation tempEvaluation = new Evaluation(inst, null); Instances train = inst.trainCV(numFolds, index); Instances test = inst.testCV(numFolds, index); classifier.buildClassifier(train); if (inst.classAttribute().isNominal() && classifier instanceof DistributionClassifier) { tempPredictions = new FastVector(); } tempPlotShape = new FastVector(); tempPlotSize = new FastVector(); tempPredInstances = new Instances(inst.relationName() + "_predicted", hv, inst.numInstances()); tempPredInstances.setClassIndex(inst.classIndex()+1); // For each instance in the fold, process it for (int jj=0;jj<test.numInstances();jj++) { processClassifierPrediction(test.instance(jj), classifier, tempEvaluation, tempPredictions, tempPredInstances, tempPlotShape, tempPlotSize); } // Set the status of this fold to DONE and aggregate // all of the data accumulated from it synchronized(status) { if(status[index] != Status.DONE) { status[index] = Status.DONE; synchronized(evaluation) { evaluation.aggregate(tempEvaluation); } if(predictions != null) { synchronized(predictions) { for(int i=0; i < tempPredictions.size(); i++) { predictions.addElement( tempPredictions.elementAt(i)); } } } synchronized(predInstances) { for(int i=0; i < tempPredInstances.numInstances(); i++) { predInstances.add( tempPredInstances.instance(i)); } } synchronized(plotShape) { for(int i=0; i < tempPlotShape.size(); i++) { plotShape.addElement( tempPlotShape.elementAt(i)); } } synchronized(plotSize) { for(int i=0; i < tempPlotSize.size(); i++) { plotSize.addElement( tempPlotSize.elementAt(i)); } } foldsCompleted++; m_Log.statusMessage(foldsCompleted + "/" + numFolds + " folds completed"); } } // Determine which section we should train on index = determineIndex(); } // When there are no folds left to be done, notify the // main thread that it can continue synchronized(evaluation) { evaluation.notifyAll(); } } catch (Exception e) { e.printStackTrace(); return; } } } /** * One instance of this class will connect to a single computer * and send it the information necessary for computing cross-validations. * This class instructs the other computer concerning which folds it * should work on. When the other computer finishes a fold, it sends * the results back to this class, which then compiles that information. * * @author Dave Musicant (dmusican@mathcs.carleton.edu) * @author Sebastian Celis (celiss@mathcs.carleton.edu) */
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -