📄 distributedserver.java
字号:
attVals.addElement(classAt.value(i)); predictedClass = new Attribute("predicted"+classAt.name(), attVals); } else predictedClass = new Attribute("predicted"+classAt.name()); for (int i = 0; i < data.numAttributes(); i++) { if (i == data.classIndex()) hv.addElement(predictedClass); hv.addElement(data.attribute(i).copy()); } predInstances = new Instances(data.relationName() + "_predicted", hv, data.numInstances()); predInstances.setClassIndex(data.classIndex()+1); // Run the fold Instances train = data.trainCV(numFolds, index); Instances test = data.testCV(numFolds, index); classifier.buildClassifier(train); for(int jj = 0; jj < test.numInstances(); jj++) processClassifierPrediction(test.instance(jj), classifier, evaluation, predictions, predInstances, plotShape, plotSize); // Open the ObjectOutputStream oos = new ObjectOutputStream( new BufferedOutputStream( sock.getOutputStream())); // Send the results back to the client oos.writeObject(evaluation); oos.writeObject(predictions); oos.writeObject(predInstances); oos.writeObject(plotShape); oos.writeObject(plotSize); oos.flush(); while(true) { // Find out which fold we should run index = dis.readInt(); // Create a new evaluation object evaluation = new Evaluation(data, null); // Initialize the necessary variables plotShape = new FastVector(); plotSize = new FastVector(); if (data.classAttribute().isNominal() && classifier instanceof DistributionClassifier) { predictions = new FastVector(); } predInstances = new Instances(data.relationName() + "_predicted", hv, data.numInstances()); predInstances.setClassIndex(data.classIndex()+1); // Run the fold train = data.trainCV(numFolds, index); test = data.testCV(numFolds, index); classifier.buildClassifier(train); for(int jj = 0; jj < test.numInstances(); jj++) processClassifierPrediction(test.instance(jj), classifier, evaluation, predictions, predInstances, plotShape, plotSize); // Send the results back to the client oos.writeObject(evaluation); oos.writeObject(predictions); oos.writeObject(predInstances); oos.writeObject(plotShape); oos.writeObject(plotSize); oos.flush(); } } /** * Process a classifier's prediction for an instance and update a * set of plotting instances and additional plotting info. plotInfo * for nominal class datasets holds shape types (actual data points * have automatic shape type assignment; classifer error data points * have box shape type). For numeric class datasets, the actual data * points are stored in plotInstances and plotInfo stores the error * (which is later converted to shape size values) * @param toPredict the actual data point * @param classifier the classifier * @param eval the evaluation object to use for evaluating the * classifer on the instance to predict * @param predictions a fastvector to add the prediction to * @param plotInstances a set of plottable instances * @param plotShape additional plotting information (shape) * @param plotSize additional plotting information (size) */ private void processClassifierPrediction(Instance toPredict, Classifier classifier, Evaluation eval, FastVector predictions, Instances plotInstances, FastVector plotShape, FastVector plotSize) { try { double pred; // classifier is a distribution classifer and class is nominal if (predictions != null) { DistributionClassifier dc = (DistributionClassifier)classifier; double [] dist = dc.distributionForInstance(toPredict); pred = eval.evaluateModelOnce(dist, toPredict); int actual = (int)toPredict.classValue(); predictions.addElement(new NominalPrediction(actual, dist, toPredict.weight())); } else { pred = eval.evaluateModelOnce(classifier, toPredict); } double [] values = new double[plotInstances.numAttributes()]; for (int i = 0; i < plotInstances.numAttributes(); i++) { if (i < toPredict.classIndex()) { values[i] = toPredict.value(i); } else if (i == toPredict.classIndex()) { values[i] = pred; values[i+1] = toPredict.value(i); /* // if the class value of the instances to predict // is missing then set it to the predicted value if (toPredict.isMissing(i)) { values[i+1] = pred; } */ i++; } else { values[i] = toPredict.value(i-1); } } plotInstances.add(new Instance(1.0, values)); if (toPredict.classAttribute().isNominal()) { if (toPredict.isMissing(toPredict.classIndex())) { plotShape.addElement( new Integer(Plot2D.MISSING_SHAPE)); } else if (pred != toPredict.classValue()) { // set to default error point shape plotShape.addElement(new Integer(Plot2D.ERROR_SHAPE)); } else { // otherwise set to constant // (automatically assigned) point shape plotShape.addElement( new Integer(Plot2D.CONST_AUTOMATIC_SHAPE)); } plotSize.addElement( new Integer(Plot2D.DEFAULT_SHAPE_SIZE)); } else { // store the error (to be converted to a point size later) Double errd = null; if (!toPredict.isMissing(toPredict.classIndex())) { errd = new Double(pred - toPredict.classValue()); plotShape.addElement( new Integer(Plot2D.CONST_AUTOMATIC_SHAPE)); } else { // missing shape if actual class not present plotShape.addElement( new Integer(Plot2D.MISSING_SHAPE)); } plotSize.addElement(errd); } } catch (Exception ex) { ex.printStackTrace(); } } } /** * Start's up the server by listening on the specified port for any * connections. */ private void start() { ServerSocket listener; try { // Instantiate the listener listener = new ServerSocket(port); System.out.print(new Date() + ": "); System.out.println("Server started on port " + port); System.out.print(new Date() + ": "); System.out.println("Waiting for connections..."); try { while(true) { // Start listening for connections Socket sock = listener.accept(); System.out.print(new Date() + ": "); System.out.println("Processed request from " + sock.getInetAddress().getHostName()); // Process this connection in a thread of its own // so that we can process multiple connections // concurrently. ConnectionThread ct = new ConnectionThread(sock); ct.start(); } } finally { // Close the listener no matter what exceptions occurred listener.close(); } } catch(BindException e) { System.out.println("ERROR: Can't bind to port " + port); } catch(IllegalArgumentException e) { System.out.println("ERROR: Illegal port number specified."); } catch(IOException e) { e.printStackTrace(); return; } } /** * The main method for this class. Creates a DistributedServer and starts * it. * * @param args args[0] should be the port number to listen on */ public static void main(String[] args) throws IOException { int port = 0; try { port = Integer.parseInt(args[0]); } catch(Exception e) { System.out.println("Usage: java.weka.classifiers.DistributedServer "+ "<port>"); return; } // Create a server and start it up. DistributedServer ds = new DistributedServer(port); ds.start(); }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -