📄 clusterevaluation.java
字号:
/**
 * Formats the cluster-versus-class count matrix used by class based
 * evaluation. There is one column per cluster whose total is non-zero and
 * one row per value of the class attribute of {@code inst}; each row ends
 * with " | " followed by the class value name.
 *
 * @param counts instance counts indexed [cluster][class]; every row is
 * assumed to have the same length as counts[0]
 * @param clusterTotals per-cluster totals; clusters with a total of 0 are
 * left out of the matrix
 * @param inst instances whose class attribute supplies the row labels
 * @return the matrix text, beginning with a newline and a header line of
 * cluster indices
 * @exception Exception declared by the signature; kept for callers
 */
private String toMatrixString(int [][] counts, int [] clusterTotals,
                              Instances inst)
  throws Exception {
  StringBuffer ms = new StringBuffer();

  // Largest count in the matrix; it bounds the width of the cell values.
  int maxval = 0;
  for (int i = 0; i < m_numClusters; i++) {
    for (int j = 0; j < counts[i].length; j++) {
      if (counts[i][j] > maxval) {
        maxval = counts[i][j];
      }
    }
  }

  // Every cell is padded to the digit count of the larger of maxval and
  // m_numClusters. Counting characters is exact; the former
  // 1 + (int) (Math.log(x) / Math.log(10)) underestimates exact powers of
  // ten (for x = 1000 the quotient is 2.9999999999999996), so a 4-digit
  // value did not fit the computed width.
  int columnWidth = Math.max(Integer.toString(maxval).length(),
                             Integer.toString(m_numClusters).length());

  // Header: the index of each non-empty cluster.
  ms.append("\n");
  for (int i = 0; i < m_numClusters; i++) {
    if (clusterTotals[i] > 0) {
      ms.append(" ").append(Utils.doubleToString((double)i, columnWidth, 0));
    }
  }
  ms.append(" <-- assigned to cluster\n");

  // One row per class value: its count in each non-empty cluster.
  for (int i = 0; i < counts[0].length; i++) {
    for (int j = 0; j < m_numClusters; j++) {
      if (clusterTotals[j] > 0) {
        ms.append(" ").append(Utils.doubleToString((double)counts[j][i],
                                                    columnWidth, 0));
      }
    }
    ms.append(" | ").append(inst.classAttribute().value(i)).append("\n");
  }
  return ms.toString();
}
/**
 * Recursively enumerates every class-to-cluster assignment and keeps the
 * one with the smallest error. Clusters are decided in index order; each
 * one is either left without a class (-1) or labelled with a class that
 * has a non-zero count in it and is not already used by an earlier cluster.
 *
 * @param lev index of the cluster decided at this recursion level
 * @param counts the counts of classes in clusters, indexed [cluster][class]
 * @param clusterTotals the total number of examples in each cluster
 * @param current the assignment being built: entry k is the class chosen
 * for cluster k, or -1 for "no class"
 * @param best the best assignment seen so far in entries
 * 0..m_numClusters-1, with its error in entry m_numClusters; it is only
 * replaced by a strictly smaller error
 * @param error the error accumulated by the decisions for clusters 0..lev-1
 */
private void mapClasses(int lev, int [][] counts, int [] clusterTotals,
                        double [] current, double [] best, int error) {
  // Every cluster has been decided: keep this path if it is strictly better.
  if (lev == m_numClusters) {
    if (error < best[m_numClusters]) {
      best[m_numClusters] = error;
      System.arraycopy(current, 0, best, 0, m_numClusters);
    }
    return;
  }

  int clusterSize = clusterTotals[lev];
  // First option in every case: this cluster gets no class.
  current[lev] = -1;
  if (clusterSize == 0) {
    // An empty cluster is simply skipped and costs nothing.
    mapClasses(lev + 1, counts, clusterTotals, current, best, error);
    return;
  }
  // Leaving a non-empty cluster unlabelled makes all of its members errors.
  mapClasses(lev + 1, counts, clusterTotals, current, best,
             error + clusterSize);

  candidates:
  for (int cls = 0; cls < counts[0].length; cls++) {
    int hits = counts[lev][cls];
    if (hits <= 0) {
      continue;
    }
    // A class may label at most one cluster.
    for (int earlier = 0; earlier < lev; earlier++) {
      if ((int) current[earlier] == cls) {
        continue candidates;
      }
    }
    current[lev] = cls;
    // Members of this cluster not belonging to cls are misclassified.
    mapClasses(lev + 1, counts, clusterTotals, current, best,
               error + (clusterSize - hits));
  }
}
/**
 * Evaluates a clusterer with the options given in an array of strings and
 * returns the textual results.
 * <p>
 * Options read here (whatever remains is passed to the clusterer when it is
 * an {@link OptionHandler}):
 * <ul>
 * <li>{@code -h}: request help; fails with "Help requested." followed by
 * the text from makeOptionString(clusterer).</li>
 * <li>{@code -t <file>}: training file.</li>
 * <li>{@code -T <file>}: test file; cluster statistics are also reported
 * for it.</li>
 * <li>{@code -l <file>}: load a serialized clusterer instead of building
 * one. Only combinable with {@code -t} when {@code -p} is given; without
 * {@code -t} a test file is required.</li>
 * <li>{@code -d <file>}: save the clusterer to this file. The file is
 * opened before the clusterer is built, but the model is only written on
 * the normal statistics path (not with {@code -p} or {@code -c}).</li>
 * <li>{@code -p <range>}: instead of statistics, output the cluster
 * assignment of each instance (test data if given, otherwise training
 * data) together with the attributes in the range; {@code -p 0} lists
 * none.</li>
 * <li>{@code -c <index|first|last>}: 1-based nominal class attribute for
 * class based evaluation on the training data; the class attribute is
 * removed before the clusterer is built. Not allowed with {@code -T},
 * {@code -x} or {@code -l}.</li>
 * <li>{@code -x <folds>}: cross-validate the log likelihood on the
 * training data; only for a {@link DensityBasedClusterer}, and only when
 * neither {@code -T} nor {@code -l} is given.</li>
 * <li>{@code -s <seed>}: random seed for cross-validation (default 1).</li>
 * </ul>
 * Every file opened here is closed before the method returns or throws.
 *
 * @param clusterer machine learning clusterer (replaced by the loaded one
 * when {@code -l} is given)
 * @param options the array of string containing the options
 * @exception Exception if the options are invalid or the model could not be
 * evaluated successfully; the underlying exception is attached as the cause
 * @return a string describing the results
 */
public static String evaluateClusterer (Clusterer clusterer,
                                        String[] options)
  throws Exception {
  int seed = 1, folds = 10;
  boolean doXval = false;
  Instances train = null;
  String trainFileName, testFileName, seedString, foldsString,
    objectInputFileName, objectOutputFileName, attributeRangeString;
  String[] savedOptions = null;
  boolean printClusterAssignments = false;
  Range attributesToOutput = null;
  ObjectInputStream objectInputStream = null;
  ObjectOutputStream objectOutputStream = null;
  StringBuffer text = new StringBuffer();
  int theClass = -1; // class based evaluation of clustering

  // ---- Generic options (the same for all clusterers) ----
  try {
    if (Utils.getFlag('h', options)) {
      throw new Exception("Help requested.");
    }
    objectInputFileName = Utils.getOption('l', options);
    objectOutputFileName = Utils.getOption('d', options);
    trainFileName = Utils.getOption('t', options);
    testFileName = Utils.getOption('T', options);

    // -p used to be a flag; it now takes an attribute range, so explain
    // the change when parsing it fails.
    try {
      attributeRangeString = Utils.getOption('p', options);
    }
    catch (Exception e) {
      throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. " +
                          "It now expects a parameter specifying a range of attributes " +
                          "to list with the predictions. Use '-p 0' for none.", e);
    }
    if (attributeRangeString.length() != 0) {
      printClusterAssignments = true;
      if (!attributeRangeString.equals("0")) {
        attributesToOutput = new Range(attributeRangeString);
      }
    }

    if (trainFileName.length() == 0) {
      if (objectInputFileName.length() == 0) {
        throw new Exception("No training file and no object "
                            + "input file given.");
      }
      if (testFileName.length() == 0) {
        throw new Exception("No training file and no test file given.");
      }
    }
    else if ((objectInputFileName.length() != 0)
             && !printClusterAssignments) {
      throw new Exception("Can't use both train and model file "
                          + "unless -p specified.");
    }

    seedString = Utils.getOption('s', options);
    if (seedString.length() != 0) {
      seed = Integer.parseInt(seedString);
    }
    foldsString = Utils.getOption('x', options);
    if (foldsString.length() != 0) {
      folds = Integer.parseInt(foldsString);
      doXval = true;
    }
  }
  catch (Exception e) {
    throw new Exception('\n' + e.getMessage()
                        + makeOptionString(clusterer), e);
  }

  try {
    // ---- Load the data and open the model files ----
    try {
      if (trainFileName.length() != 0) {
        BufferedReader trainReader =
          new BufferedReader(new FileReader(trainFileName));
        try {
          train = new Instances(trainReader);
        } finally {
          // The reader is not needed once the instances have been read.
          trainReader.close();
        }
        String classString = Utils.getOption('c', options);
        if (classString.length() != 0) {
          if (classString.compareTo("last") == 0) {
            theClass = train.numAttributes();
          } else if (classString.compareTo("first") == 0) {
            theClass = 1;
          } else {
            theClass = Integer.parseInt(classString);
          }
          if (doXval || testFileName.length() != 0) {
            throw new Exception("Can only do class based evaluation on the "
                                + "training data");
          }
          if (objectInputFileName.length() != 0) {
            throw new Exception("Can't load a clusterer and do class based "
                                + "evaluation");
          }
        }
        if (theClass != -1) {
          if (theClass < 1
              || theClass > train.numAttributes()) {
            throw new Exception("Class is out of range!");
          }
          if (!train.attribute(theClass - 1).isNominal()) {
            throw new Exception("Class must be nominal!");
          }
          train.setClassIndex(theClass - 1);
        }
      }
      if (objectInputFileName.length() != 0) {
        // NOTE(security): native Java deserialization -- only load model
        // files that come from a trusted source.
        objectInputStream =
          new ObjectInputStream(new FileInputStream(objectInputFileName));
      }
      if (objectOutputFileName.length() != 0) {
        objectOutputStream =
          new ObjectOutputStream(new FileOutputStream(objectOutputFileName));
      }
    }
    catch (Exception e) {
      throw new Exception("ClusterEvaluation: " + e.getMessage() + '.', e);
    }

    // Save options (used to rebuild clusterers during cross-validation)
    if (options != null) {
      savedOptions = new String[options.length];
      System.arraycopy(options, 0, savedOptions, 0, options.length);
    }
    // A loaded clusterer takes no options: reject any leftovers before
    // setOptions could consume them.
    if (objectInputFileName.length() != 0) {
      Utils.checkForRemainingOptions(options);
    }
    // Set options for clusterer
    if (clusterer instanceof OptionHandler) {
      ((OptionHandler)clusterer).setOptions(options);
    }
    Utils.checkForRemainingOptions(options);

    if (objectInputFileName.length() != 0) {
      // Load the clusterer from file
      clusterer = (Clusterer)objectInputStream.readObject();
      objectInputStream.close();
      objectInputStream = null; // already closed; nothing left for finally
    }
    else if (theClass != -1) {
      // Class based evaluation: build on the data without the class
      // attribute, then evaluate against the class on the training data.
      // NOTE: -d is not honoured here; the output file was opened above
      // but no model is written to it.
      Remove removeClass = new Remove();
      removeClass.setAttributeIndices("" + theClass);
      removeClass.setInvertSelection(false);
      removeClass.setInputFormat(train);
      Instances clusterTrain = Filter.useFilter(train, removeClass);
      clusterer.buildClusterer(clusterTrain);
      ClusterEvaluation ce = new ClusterEvaluation();
      ce.setClusterer(clusterer);
      ce.evaluateClusterer(train);
      return "\n\n=== Clustering stats for training data ===\n\n" +
        ce.clusterResultsToString();
    }
    else {
      // Build the clusterer if no object file provided
      clusterer.buildClusterer(train);
    }

    // Output cluster predictions only (for the test data if specified,
    // otherwise for the training data). As above, -d is not honoured here.
    if (printClusterAssignments) {
      return printClusterings(clusterer, train, testFileName, attributesToOutput);
    }

    text.append(clusterer.toString());
    text.append("\n\n=== Clustering stats for training data ===\n\n"
                + printClusterStats(clusterer, trainFileName));
    if (testFileName.length() != 0) {
      text.append("\n\n=== Clustering stats for testing data ===\n\n"
                  + printClusterStats(clusterer, testFileName));
    }
    if ((clusterer instanceof DensityBasedClusterer) &&
        doXval &&
        (testFileName.length() == 0) &&
        (objectInputFileName.length() == 0)) {
      // cross validate the log likelihood on the training data
      Random random = new Random(seed);
      train.randomize(random);
      text.append(crossValidateModel(clusterer.getClass().getName()
                                     , train, folds, savedOptions, random));
    }

    // Save the clusterer if an object output file is provided
    if (objectOutputFileName.length() != 0) {
      objectOutputStream.writeObject(clusterer);
      objectOutputStream.flush();
      objectOutputStream.close();
      objectOutputStream = null; // already closed; nothing left for finally
    }
    return text.toString();
  }
  finally {
    // Release any model stream still open after an early return or an
    // exception. Close failures are ignored so they cannot mask the
    // original result or exception.
    if (objectInputStream != null) {
      try {
        objectInputStream.close();
      } catch (Exception ignored) {
        // best effort cleanup
      }
    }
    if (objectOutputStream != null) {
      try {
        objectOutputStream.close();
      } catch (Exception ignored) {
        // best effort cleanup
      }
    }
  }
}
/**
* Perform a cross-validation for DensityBasedClusterer on a set of instances.
*
* @param clusterer the clusterer to use
* @param data the training data
* @param numFolds number of folds of cross validation to perform
* @param random random number seed for cross-validation
* @return the cross-validated log-likelihood
* @exception Exception if an error occurs
*/
public static double crossValidateModel(DensityBasedClusterer clusterer,
Instances data,
int numFolds,
Random random) throws Exception {
Instances train, test;
double foldAv = 0;;
double[] tempDist;
data = new Instances(data);
data.randomize(random);
// double sumOW = 0;
for (int i = 0; i < numFolds; i++) {
// Build and test clusterer
train = data.trainCV(numFolds, i, random);
clusterer.buildClusterer(train);
test = data.testCV(numFolds, i);
for (int j = 0; j < test.numInstances(); j++) {
try {
foldAv += ((DensityBasedClusterer)clusterer).
logDensityForInstance(test.instance(j));
// sumOW += test.instance(j).weight();
// double temp = Utils.sum(tempDist);
} catch (Exception ex) {
// unclustered instances
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -