📄 clusterevaluation.java
字号:
* Returns a "confusion" style matrix of classes to clusters assignments * @param counts the counts of classes for each cluster * @param clusterTotals total number of examples in each cluster * @param inst the training instances (with class) * @return the "confusion" style matrix as string * @throws Exception if matrix can't be generated */ private String toMatrixString(int[][] counts, int[] clusterTotals, Instances inst) throws Exception { StringBuffer ms = new StringBuffer(); int maxval = 0; for (int i = 0; i < m_numClusters; i++) { for (int j = 0; j < counts[i].length; j++) { if (counts[i][j] > maxval) { maxval = counts[i][j]; } } } int Cwidth = 1 + Math.max((int)(Math.log(maxval) / Math.log(10)), (int)(Math.log(m_numClusters) / Math.log(10))); ms.append("\n"); for (int i = 0; i < m_numClusters; i++) { if (clusterTotals[i] > 0) { ms.append(" ").append(Utils.doubleToString((double)i, Cwidth, 0)); } } ms.append(" <-- assigned to cluster\n"); for (int i = 0; i< counts[0].length; i++) { for (int j = 0; j < m_numClusters; j++) { if (clusterTotals[j] > 0) { ms.append(" ").append(Utils.doubleToString((double)counts[j][i], Cwidth, 0)); } } ms.append(" | ").append(inst.classAttribute().value(i)).append("\n"); } return ms.toString(); } /** * Finds the minimum error mapping of classes to clusters. Recursively * considers all possible class to cluster assignments. * * @param numClusters the number of clusters * @param lev the cluster being processed * @param counts the counts of classes in clusters * @param clusterTotals the total number of examples in each cluster * @param current the current path through the class to cluster assignment * tree * @param best the best assignment path seen * @param error accumulates the error for a particular path */ public static void mapClasses(int numClusters, int lev, int[][] counts, int[] clusterTotals, double[] current, double[] best, int error) { // leaf if (lev == numClusters) { if (error < best[numClusters]) { best[numClusters] = error; for (int i = 0; i < numClusters; i++) { best[i] = current[i]; } } } else { // empty cluster -- ignore if (clusterTotals[lev] == 0) { current[lev] = -1; // cluster ignored mapClasses(numClusters, lev+1, counts, clusterTotals, current, best, error); } else { // first try no class assignment to this cluster current[lev] = -1; // cluster assigned no class (ie all errors) mapClasses(numClusters, lev+1, counts, clusterTotals, current, best, error+clusterTotals[lev]); // now loop through the classes in this cluster for (int i = 0; i < counts[0].length; i++) { if (counts[lev][i] > 0) { boolean ok = true; // check to see if this class has already been assigned for (int j = 0; j < lev; j++) { if ((int)current[j] == i) { ok = false; break; } } if (ok) { current[lev] = i; mapClasses(numClusters, lev+1, counts, clusterTotals, current, best, (error + (clusterTotals[lev] - counts[lev][i]))); } } } } } } /** * Evaluates a clusterer with the options given in an array of * strings. It takes the string indicated by "-t" as training file, the * string indicated by "-T" as test file. * If the test file is missing, a stratified ten-fold * cross-validation is performed (distribution clusterers only). * Using "-x" you can change the number of * folds to be used, and using "-s" the random seed. * If the "-p" option is present it outputs the classification for * each test instance. If you provide the name of an object file using * "-l", a clusterer will be loaded from the given file. If you provide the * name of an object file using "-d", the clusterer built from the * training data will be saved to the given file. * * @param clusterer machine learning clusterer * @param options the array of string containing the options * @throws Exception if model could not be evaluated successfully * @return a string describing the results */ public static String evaluateClusterer(Clusterer clusterer, String[] options) throws Exception { int seed = 1, folds = 10; boolean doXval = false; Instances train = null; Random random; String trainFileName, testFileName, seedString, foldsString; String objectInputFileName, objectOutputFileName, attributeRangeString; String graphFileName; String[] savedOptions = null; boolean printClusterAssignments = false; Range attributesToOutput = null; StringBuffer text = new StringBuffer(); int theClass = -1; // class based evaluation of clustering boolean updateable = (clusterer instanceof UpdateableClusterer); DataSource source = null; Instance inst; try { if (Utils.getFlag('h', options)) { throw new Exception("Help requested."); } // Get basic options (options the same for all clusterers //printClusterAssignments = Utils.getFlag('p', options); objectInputFileName = Utils.getOption('l', options); objectOutputFileName = Utils.getOption('d', options); trainFileName = Utils.getOption('t', options); testFileName = Utils.getOption('T', options); graphFileName = Utils.getOption('g', options); // Check -p option try { attributeRangeString = Utils.getOption('p', options); } catch (Exception e) { throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. " + "It now expects a parameter specifying a range of attributes " + "to list with the predictions. Use '-p 0' for none."); } if (attributeRangeString.length() != 0) { printClusterAssignments = true; if (!attributeRangeString.equals("0")) attributesToOutput = new Range(attributeRangeString); } if (trainFileName.length() == 0) { if (objectInputFileName.length() == 0) { throw new Exception("No training file and no object " + "input file given."); } if (testFileName.length() == 0) { throw new Exception("No training file and no test file given."); } } else { if ((objectInputFileName.length() != 0) && (printClusterAssignments == false)) { throw new Exception("Can't use both train and model file " + "unless -p specified."); } } seedString = Utils.getOption('s', options); if (seedString.length() != 0) { seed = Integer.parseInt(seedString); } foldsString = Utils.getOption('x', options); if (foldsString.length() != 0) { folds = Integer.parseInt(foldsString); doXval = true; } } catch (Exception e) { throw new Exception('\n' + e.getMessage() + makeOptionString(clusterer)); } try { if (trainFileName.length() != 0) { source = new DataSource(trainFileName); train = source.getStructure(); String classString = Utils.getOption('c',options); if (classString.length() != 0) { if (classString.compareTo("last") == 0) theClass = train.numAttributes(); else if (classString.compareTo("first") == 0) theClass = 1; else theClass = Integer.parseInt(classString); if (theClass != -1) { if (doXval || testFileName.length() != 0) throw new Exception("Can only do class based evaluation on the " +"training data"); if (objectInputFileName.length() != 0) throw new Exception("Can't load a clusterer and do class based " +"evaluation"); if (objectOutputFileName.length() != 0) throw new Exception( "Can't do class based evaluation and save clusterer"); } } else { // if the dataset defines a class attribute, use it if (train.classIndex() != -1) { theClass = train.classIndex() + 1; System.err.println( "Note: using class attribute from dataset, i.e., attribute #" + theClass); } } if (theClass != -1) { if (theClass < 1 || theClass > train.numAttributes()) throw new Exception("Class is out of range!"); if (!train.attribute(theClass - 1).isNominal()) throw new Exception("Class must be nominal!"); train.setClassIndex(theClass - 1); } } } catch (Exception e) { throw new Exception("ClusterEvaluation: " + e.getMessage() + '.'); } // Save options if (options != null) { savedOptions = new String[options.length]; System.arraycopy(options, 0, savedOptions, 0, options.length); } if (objectInputFileName.length() != 0) Utils.checkForRemainingOptions(options); // Set options for clusterer if (clusterer instanceof OptionHandler) ((OptionHandler)clusterer).setOptions(options); Utils.checkForRemainingOptions(options); if (objectInputFileName.length() != 0) { // Load the clusterer from file clusterer = (Clusterer) SerializationHelper.read(objectInputFileName); } else { // Build the clusterer if no object file provided if (theClass == -1) { if (updateable) { clusterer.buildClusterer(source.getStructure()); while (source.hasMoreElements(train)) { inst = source.nextElement(train); ((UpdateableClusterer) clusterer).updateClusterer(inst); } ((UpdateableClusterer) clusterer).updateFinished(); } else { clusterer.buildClusterer(source.getDataSet()); } } else { Remove removeClass = new Remove(); removeClass.setAttributeIndices("" + theClass); removeClass.setInvertSelection(false); removeClass.setInputFormat(train); if (updateable) { Instances clusterTrain = Filter.useFilter(train, removeClass); clusterer.buildClusterer(clusterTrain); while (source.hasMoreElements(train)) { inst = source.nextElement(train); removeClass.input(inst); removeClass.batchFinished(); Instance clusterTrainInst = removeClass.output(); ((UpdateableClusterer) clusterer).updateClusterer(clusterTrainInst); } ((UpdateableClusterer) clusterer).updateFinished(); } else { Instances clusterTrain = Filter.useFilter(source.getDataSet(), removeClass); clusterer.buildClusterer(clusterTrain); } ClusterEvaluation ce = new ClusterEvaluation(); ce.setClusterer(clusterer); ce.evaluateClusterer(train, trainFileName); return "\n\n=== Clustering stats for training data ===\n\n" + ce.clusterResultsToString(); } } /* Output cluster predictions only (for the test data if specified, otherwise for the training data */ if (printClusterAssignments) { return printClusterings(clusterer, trainFileName, testFileName, attributesToOutput); } text.append(clusterer.toString()); text.append("\n\n=== Clustering stats for training data ===\n\n" + printClusterStats(clusterer, trainFileName)); if (testFileName.length() != 0) text.append("\n\n=== Clustering stats for testing data ===\n\n" + printClusterStats(clusterer, testFileName)); if ((clusterer instanceof DensityBasedClusterer) && (doXval == true) && (testFileName.length() == 0) && (objectInputFileName.length() == 0)) { // cross validate the log likelihood on the training data random = new Random(seed); random.setSeed(seed); train = source.getDataSet(); train.randomize(random); text.append( crossValidateModel( clusterer.getClass().getName(), train, folds, savedOptions, random)); } // Save the clusterer if an object output file is provided if (objectOutputFileName.length() != 0) SerializationHelper.write(objectOutputFileName, clusterer); // If classifier is drawable output string describing graph if ((clusterer instanceof Drawable) && (graphFileName.length() != 0)) { BufferedWriter writer = new BufferedWriter(new FileWriter(graphFileName)); writer.write(((Drawable) clusterer).graph()); writer.newLine(); writer.flush(); writer.close(); } return text.toString(); } /** * Perform a cross-validation for DensityBasedClusterer on a set of instances. * * @param clusterer the clusterer to use * @param data the training data * @param numFolds number of folds of cross validation to perform * @param random random number seed for cross-validation * @return the cross-validated log-likelihood * @throws Exception if an error occurs */ public static double crossValidateModel(DensityBasedClusterer clusterer, Instances data, int numFolds, Random random) throws Exception { Instances train, test; double foldAv = 0;; data = new Instances(data); data.randomize(random); // double sumOW = 0; for (int i = 0; i < numFolds; i++) { // Build and test clusterer train = data.trainCV(numFolds, i, random);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -