📄 em.java
字号:
}
for (l = 0; l < inst.numInstances(); l++) {
m = Utils.maxIndex(m_weights[l]);
System.out.print("Inst " + Utils.doubleToString((double)l, 5, 0)
+ " Class " + m + "\t");
for (j = 0; j < m_num_clusters; j++) {
System.out.print(Utils.doubleToString(m_weights[l][j], 7, 5) + " ");
}
System.out.println();
}
}
/**
 * Estimates the number of clusters by cross validation on the training
 * data and stores the result in m_num_clusters.
 *
 * Starting with one cluster, the candidate number of clusters is increased
 * by one for as long as the mean log likelihood over the test folds keeps
 * improving. At most 10 folds are used (fewer if there are fewer than 10
 * training instances).
 *
 * If fitting or evaluating a fold fails, the same candidate is tried again
 * with the next random seed. After more than 5 consecutive failures the
 * search stops and the last candidate that was evaluated successfully is
 * kept. The final number of clusters is never smaller than 1.
 *
 * @exception Exception if the initialisation of a fold fails
 */
private void CVClusters ()
  throws Exception {
  double CVLogLikely = -Double.MAX_VALUE;
  double templl, tll;
  boolean CVincreased = true;
  m_num_clusters = 1;
  int num_clusters = m_num_clusters;
  int i;
  Random cvr;
  Instances trainCopy;
  int numFolds = (m_theInstances.numInstances() < 10)
    ? m_theInstances.numInstances()
    : 10;
  boolean ok = true;
  int seed = m_rseed;
  int restartCount = 0;

  CLUSTER_SEARCH: while (CVincreased) {
    CVincreased = false;
    // every pass starts clean; a failure in an earlier pass must not
    // prevent this pass from being scored
    ok = true;

    // the folds are always built from m_rseed, so a retry sees the same
    // folds and only the EM initialisation (m_rr) changes
    cvr = new Random(m_rseed);
    trainCopy = new Instances(m_theInstances);
    trainCopy.randomize(cvr);
    templl = 0.0;

    for (i = 0; i < numFolds; i++) {
      Instances cvTrain = trainCopy.trainCV(numFolds, i, cvr);
      if (num_clusters > cvTrain.numInstances()) {
        // more clusters requested than training instances available
        break CLUSTER_SEARCH;
      }
      Instances cvTest = trainCopy.testCV(numFolds, i);

      m_rr = new Random(seed);
      // throw away the first numbers produced from the seed
      for (int z = 0; z < 10; z++) {
        m_rr.nextDouble();
      }
      m_num_clusters = num_clusters;
      EM_Init(cvTrain);

      try {
        iterate(cvTrain, false);
        tll = E(cvTest, false);
      } catch (Exception ex) {
        // catch any problems - i.e. empty clusters occurring
        ex.printStackTrace();
        seed++;
        restartCount++;
        ok = false;
        if (restartCount > 5) {
          // give up; keep the last candidate that worked
          break CLUSTER_SEARCH;
        }
        // abandon this pass and run another one for the same number of
        // clusters, now using the incremented seed
        CVincreased = true;
        break;
      }

      if (m_verbose) {
        System.out.println("# clust: " + num_clusters + " Fold: " + i
                           + " Loglikely: " + tll);
      }
      templl += tll;
    }

    if (ok) {
      restartCount = 0;
      seed = m_rseed;
      templl /= (double)numFolds;
      if (m_verbose) {
        System.out.println("==================================="
                           + "==============\n# clust: "
                           + num_clusters
                           + " Mean Loglikely: "
                           + templl
                           + "\n================================"
                           + "=================");
      }
      if (templl > CVLogLikely) {
        CVLogLikely = templl;
        CVincreased = true;
        num_clusters++;
      }
    }
  }

  // num_clusters is one past the last candidate that improved the
  // likelihood. If not even one cluster could be scored (failure, or a
  // NaN / -Infinity likelihood) this would be 0, so fall back to 1.
  int chosen = Math.max(1, num_clusters - 1);
  if (m_verbose) {
    System.out.println("Number of clusters: " + chosen);
  }
  m_num_clusters = chosen;
}
/**
 * Reports how many clusters this clusterer currently holds.
 *
 * @return the value of m_num_clusters
 * @exception Exception if m_num_clusters is still -1, i.e. no clusters
 * have been generated
 */
public int numberOfClusters ()
  throws Exception {
  if (m_num_clusters != -1) {
    return m_num_clusters;
  }
  throw new Exception("Haven't generated any clusters!");
}
/**
 * Widens the per-attribute ranges in m_minValues / m_maxValues so that
 * they include the values of the given instance. Missing values are
 * skipped. A range whose minimum is still NaN is treated as empty and is
 * set to the instance's value.
 *
 * @param instance the new instance
 */
private void updateMinMax(Instance instance) {
  final int attributeCount = m_theInstances.numAttributes();
  for (int att = 0; att < attributeCount; att++) {
    if (instance.isMissing(att)) {
      continue;
    }
    final double current = instance.value(att);
    if (Double.isNaN(m_minValues[att])) {
      // first value seen for this attribute
      m_minValues[att] = current;
      m_maxValues[att] = current;
    } else if (current < m_minValues[att]) {
      m_minValues[att] = current;
    } else if (current > m_maxValues[att]) {
      m_maxValues[att] = current;
    }
  }
}
/**
 * Builds the clusterer from the given training data.
 *
 * The data is copied, its class index is unset, and missing values are
 * replaced with a ReplaceMissingValues filter (kept in m_replaceMissing).
 * The per-attribute minimum and maximum values are then collected and the
 * EM procedure is run. Afterwards only an empty copy of the training data
 * (created with capacity 0) is retained.
 *
 * @param data set of instances serving as training data
 * @exception Exception if the data contains string attributes or the
 * clusterer has not been generated successfully
 */
public void buildClusterer (Instances data)
  throws Exception {
  if (data.checkForStringAttributes()) {
    throw new Exception("Can't handle string attributes!");
  }

  m_replaceMissing = new ReplaceMissingValues();

  // work on a copy without a class attribute
  Instances unfiltered = new Instances(data);
  unfiltered.setClassIndex(-1);
  m_replaceMissing.setInputFormat(unfiltered);
  m_theInstances = weka.filters.Filter.useFilter(unfiltered, m_replaceMissing);
  // the unfiltered copy is no longer needed
  unfiltered = null;

  // calculate min and max values for attributes; NaN marks "nothing seen yet"
  final int attributeCount = m_theInstances.numAttributes();
  m_minValues = new double [attributeCount];
  m_maxValues = new double [attributeCount];
  for (int att = 0; att < attributeCount; att++) {
    m_minValues[att] = Double.NaN;
    m_maxValues[att] = Double.NaN;
  }
  final int instanceCount = m_theInstances.numInstances();
  for (int n = 0; n < instanceCount; n++) {
    updateMinMax(m_theInstances.instance(n));
  }

  doEM();

  // save memory
  m_theInstances = new Instances(m_theInstances, 0);
}
/**
 * Returns the cluster priors.
 *
 * @return a fresh copy of m_priors; changes to the returned array do not
 * affect the clusterer
 */
public double[] clusterPriors() {
  return (double[]) m_priors.clone();
}
/**
 * Computes the log of the conditional density (per cluster) for a given
 * instance.
 *
 * The instance is first passed through the missing-value filter. For each
 * cluster the result is the sum, over all non-missing attributes, of the
 * log probability from the cluster's estimator (nominal attributes) or of
 * logNormalDens evaluated with the two values stored in m_modelNormal for
 * that cluster and attribute (all other attributes).
 *
 * @param inst the instance to compute the density for
 * @return an array with one log density per cluster
 * @exception Exception if the density could not be computed
 * successfully
 */
public double[] logDensityPerClusterForInstance(Instance inst) throws Exception {
  double[] logDensities = new double[m_num_clusters];

  m_replaceMissing.input(inst);
  inst = m_replaceMissing.output();

  for (int cluster = 0; cluster < m_num_clusters; cluster++) {
    double sum = 0.0;
    for (int att = 0; att < m_num_attribs; att++) {
      if (inst.isMissing(att)) {
        continue;
      }
      final double value = inst.value(att);
      if (inst.attribute(att).isNominal()) {
        sum += Math.log(m_model[cluster][att].getProbability(value));
      } else {
        // numeric attribute
        sum += logNormalDens(value,
                             m_modelNormal[cluster][att][0],
                             m_modelNormal[cluster][att][1]);
      }
    }
    logDensities[cluster] = sum;
  }
  return logDensities;
}
/**
 * Runs the EM algorithm on m_theInstances.
 *
 * Seeds the random number generator, records the number of instances and
 * attributes, optionally selects the number of clusters (when
 * m_initialNumClusters is -1), and then fits the model on the full
 * training set, storing the resulting log likelihood in m_loglikely.
 *
 * @exception Exception if cluster selection, initialisation or fitting
 * fails
 */
private void doEM ()
  throws Exception {
  if (m_verbose) {
    System.out.println("Seed: " + m_rseed);
  }

  m_rr = new Random(m_rseed);
  // throw away numbers to avoid problem of similar initial numbers
  // from a similar seed
  for (int skipped = 0; skipped < 10; skipped++) {
    m_rr.nextDouble();
  }

  m_num_instances = m_theInstances.numInstances();
  m_num_attribs = m_theInstances.numAttributes();
  if (m_verbose) {
    System.out.println("Number of instances: "
                       + m_num_instances
                       + "\nNumber of atts: "
                       + m_num_attribs
                       + "\n");
  }

  // cross validate to determine number of clusters?
  final boolean selectNumClusters = (m_initialNumClusters == -1);
  if (selectNumClusters) {
    if (m_theInstances.numInstances() <= 9) {
      // too few instances to cross validate
      m_num_clusters = 1;
    } else {
      CVClusters();
      // start the final fit from a freshly seeded generator
      m_rr = new Random(m_rseed);
      for (int skipped = 0; skipped < 10; skipped++) {
        m_rr.nextDouble();
      }
    }
  }

  // fit full training set
  EM_Init(m_theInstances);
  m_loglikely = iterate(m_theInstances, m_verbose);
}
/**
 * Iterates the E and M steps until the log likelihood of the data
 * converges (improves by less than 1e-6) or m_max_iterations is reached.
 *
 * If a step throws an exception, the random number generator is reseeded
 * with the next seed, the model is re-initialised from the same instances
 * and the iteration starts again. After more than 5 consecutive failures
 * the number of clusters is reduced by one before retrying. If that would
 * leave no clusters at all, the last failure is reported to the caller.
 *
 * @param inst the training instances.
 * @param report be verbose.
 * @return the log likelihood of the data
 * @exception Exception if re-initialisation fails, or if EM keeps failing
 * even with a single cluster
 */
private double iterate (Instances inst, boolean report)
  throws Exception {
  int i;
  double llkold = 0.0;
  double llk = 0.0;
  if (report) {
    EM_Report(inst);
  }
  boolean ok = false;
  int seed = m_rseed;
  int restartCount = 0;
  while (!ok) {
    try {
      for (i = 0; i < m_max_iterations; i++) {
        llkold = llk;
        llk = E(inst, true);
        if (report) {
          System.out.println("Loglikely: " + llk);
        }
        // no previous value to compare with on the first iteration
        if (i > 0) {
          if ((llk - llkold) < 1e-6) {
            break;
          }
        }
        M(inst);
      }
      ok = true;
    } catch (Exception ex) {
      ex.printStackTrace();
      seed++;
      restartCount++;
      m_rr = new Random(seed);
      for (int z = 0; z < 10; z++) {
        m_rr.nextDouble(); m_rr.nextInt();
      }
      if (restartCount > 5) {
        if (m_num_clusters <= 1) {
          // nothing left to reduce: a model with zero clusters is
          // meaningless, so stop instead of retrying
          throw new Exception("EM failed repeatedly with a single cluster", ex);
        }
        m_num_clusters--;
        restartCount = 0;
      }
      // restart from the instances being fitted; when this method is run
      // on a cross-validation training fold, using the full training set
      // here would initialise the model from held-out data as well
      EM_Init(inst);
    }
  }
  if (report) {
    EM_Report(inst);
  }
  return llk;
}
// ============
// Test method.
// ============
/**
 * Command-line entry point for testing this class. Evaluates a new EM
 * clusterer with the given options and prints the result; on failure the
 * exception message and stack trace are printed instead.
 *
 * @param argv should contain the following arguments: <p>
 * -t training file [-T test file] [-N number of clusters] [-S random seed]
 */
public static void main (String[] argv) {
  try {
    final EM clusterer = new EM();
    System.out.println(ClusterEvaluation.evaluateClusterer(clusterer, argv));
  } catch (Exception failure) {
    System.out.println(failure.getMessage());
    failure.printStackTrace();
  }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -