📄 em.java
字号:
String[] options = new String[9];
int current = 0;
if (m_verbose) {
options[current++] = "-V";
}
options[current++] = "-I";
options[current++] = "" + m_max_iterations;
options[current++] = "-N";
options[current++] = "" + getNumClusters();
options[current++] = "-S";
options[current++] = "" + m_rseed;
options[current++] = "-M";
options[current++] = ""+getMinStdDev();
while (current < options.length) {
options[current++] = "";
}
return options;
}
/**
 * Initialise estimators and storage.
 *
 * <p>Runs SimpleKMeans ten times, each with a seed drawn from m_rr, and keeps
 * the run with the lowest squared error. Its centroids, per-cluster standard
 * deviations, nominal counts and cluster sizes become the starting cluster
 * models and priors. m_num_clusters is overwritten with the number of
 * clusters that k-means actually produced, and m_weights, m_model,
 * m_modelNormal and m_priors are (re)allocated for that many clusters.
 *
 * @param inst the instances
 * @throws Exception if a k-means run cannot be built (propagated from
 * SimpleKMeans.buildClusterer)
 **/
private void EM_Init (Instances inst)
  throws Exception {
  int i, j, k;

  // run k means 10 times and choose the solution with the lowest squared error
  SimpleKMeans firstK = null;
  SimpleKMeans bestK = null;
  double bestSqE = Double.MAX_VALUE;
  for (i = 0; i < 10; i++) {
    SimpleKMeans sk = new SimpleKMeans();
    sk.setSeed(m_rr.nextInt());
    sk.setNumClusters(m_num_clusters);
    sk.buildClusterer(inst);
    if (firstK == null) {
      firstK = sk;
    }
    double sqE = sk.getSquaredError();
    if (sqE < bestSqE) {
      bestSqE = sqE;
      bestK = sk;
    }
  }
  if (bestK == null) {
    // No run's error was below Double.MAX_VALUE (e.g. every error was NaN).
    // Fall back to the first run rather than dereferencing null below.
    bestK = firstK;
  }

  // initialize with best k-means solution
  m_num_clusters = bestK.numberOfClusters();
  m_weights = new double[inst.numInstances()][m_num_clusters];
  m_model = new DiscreteEstimator[m_num_clusters][m_num_attribs];
  m_modelNormal = new double[m_num_clusters][m_num_attribs][3];
  m_priors = new double[m_num_clusters];
  Instances centers = bestK.getClusterCentroids();
  Instances stdD = bestK.getClusterStandardDevs();
  int [][][] nominalCounts = bestK.getClusterNominalCounts();
  int [] clusterSizes = bestK.getClusterSizes();

  for (i = 0; i < m_num_clusters; i++) {
    Instance center = centers.instance(i);
    for (j = 0; j < m_num_attribs; j++) {
      if (inst.attribute(j).isNominal()) {
        // seed the discrete estimator with the k-means value counts
        m_model[i][j] = new DiscreteEstimator(m_theInstances.
                                              attribute(j).numValues()
                                              , true);
        for (k = 0; k < inst.attribute(j).numValues(); k++) {
          m_model[i][j].addValue(k, nominalCounts[i][j][k]);
        }
      } else {
        double minStdD = (m_minStdDevPerAtt != null)
          ? m_minStdDevPerAtt[j]
          : m_minStdDev;
        double mean = (center.isMissing(j))
          ? inst.meanOrMode(j)
          : center.value(j);
        m_modelNormal[i][j][0] = mean;
        // no std dev from k-means: guess (max - min) / (2 * numClusters)
        double stdv = (stdD.instance(i).isMissing(j))
          ? ((m_maxValues[j] - m_minValues[j]) / (2 * m_num_clusters))
          : stdD.instance(i).value(j);
        // NaN compares false with everything, so it must be tested
        // explicitly or it would bypass the minimum check.
        if (Double.isNaN(stdv) || stdv < minStdD) {
          // too small or undefined: use the attribute's overall std dev
          stdv = inst.attributeStats(j).numericStats.stdDev;
          if (Double.isNaN(stdv) || Double.isInfinite(stdv)
              || stdv < minStdD) {
            stdv = minStdD;
          }
        }
        if (stdv <= 0) {
          stdv = m_minStdDev;
        }
        m_modelNormal[i][j][1] = stdv;
        m_modelNormal[i][j][2] = 1.0;
      }
    }
  }

  // priors proportional to the k-means cluster sizes
  for (j = 0; j < m_num_clusters; j++) {
    m_priors[j] = clusterSizes[j];
  }
  Utils.normalize(m_priors);
}
/**
 * Re-estimate the cluster prior probabilities.
 *
 * <p>Each prior is the sum, over all instances, of the instance weight times
 * that instance's membership weight for the cluster (held in m_weights).
 * The resulting vector is then passed to Utils.normalize.
 *
 * @param inst the instances whose membership weights are in m_weights
 * @exception Exception if priors can't be calculated
 **/
private void estimate_priors (Instances inst)
  throws Exception {
  // clear the accumulators
  for (int c = 0; c < m_num_clusters; c++) {
    m_priors[c] = 0.0;
  }
  int numInst = inst.numInstances();
  for (int row = 0; row < numInst; row++) {
    double instWeight = inst.instance(row).weight();
    for (int c = 0; c < m_num_clusters; c++) {
      m_priors[c] += instWeight * m_weights[row][c];
    }
  }
  Utils.normalize(m_priors);
}
/**
 * log(sqrt(2 * pi)), the logarithm of the normal density's normalising
 * constant, subtracted in logNormalDens. Declared final because it is a
 * fixed mathematical constant (assumed never reassigned elsewhere in the file).
 */
private static final double m_normConst = Math.log(Math.sqrt(2*Math.PI));
/**
 * Log of the normal probability density function.
 *
 * <p>Evaluates -(x - mean)^2 / (2 * stdDev^2) - log(sqrt(2 * pi)) - log(stdDev).
 *
 * @param x input value
 * @param mean mean of distribution
 * @param stdDev standard deviation of distribution
 * @return the natural logarithm of the normal density at x
 */
private double logNormalDens (double x, double mean, double stdDev) {
  double deviation = x - mean;
  double twoVariance = 2 * stdDev * stdDev;
  double exponent = deviation * deviation / twoVariance;
  return -exponent - m_normConst - Math.log(stdDev);
}
/**
 * Reset every cluster model before an M step.
 *
 * <p>Each nominal attribute of each cluster gets a fresh DiscreteEstimator
 * sized to the attribute's number of values. The constructor's second
 * argument is true, presumably enabling Laplace smoothing; confirm against
 * DiscreteEstimator. Each numeric attribute has its three m_modelNormal
 * accumulators set back to zero.
 */
private void new_estimators () {
  for (int cl = 0; cl < m_num_clusters; cl++) {
    for (int att = 0; att < m_num_attribs; att++) {
      if (!m_theInstances.attribute(att).isNominal()) {
        // numeric: clear the weighted sum, sum of squares and weight total
        double[] acc = m_modelNormal[cl][att];
        acc[2] = 0.0;
        acc[1] = 0.0;
        acc[0] = 0.0;
        continue;
      }
      int numValues = m_theInstances.attribute(att).numValues();
      m_model[cl][att] = new DiscreteEstimator(numValues, true);
    }
  }
}
/**
 * The M step of the EM algorithm.
 *
 * <p>Re-estimates every cluster's attribute models from the membership
 * weights in m_weights.
 * <ul>
 * <li>Nominal attributes receive weighted value counts in a fresh
 * DiscreteEstimator.</li>
 * <li>Numeric attributes receive a weighted mean and standard deviation.</li>
 * <li>A standard deviation at or below the minimum is replaced by the
 * attribute's overall standard deviation.</li>
 * <li>If that is still unusable (at or below the minimum, infinite or NaN),
 * the per-attribute minimum itself is used, matching the rule applied
 * during initialisation.</li>
 * </ul>
 *
 * @param inst the training instances
 * @throws Exception declared by the method signature
 */
private void M (Instances inst)
  throws Exception {
  new_estimators();

  // Accumulate weighted sufficient statistics, skipping missing values.
  // Numeric slots: [0] = sum w*x, [1] = sum w*x^2, [2] = sum w,
  // where w = instance weight * membership weight.
  for (int i = 0; i < m_num_clusters; i++) {
    for (int j = 0; j < m_num_attribs; j++) {
      boolean nominal = inst.attribute(j).isNominal(); // invariant over l
      for (int l = 0; l < inst.numInstances(); l++) {
        Instance in = inst.instance(l);
        if (in.isMissing(j)) {
          continue;
        }
        if (nominal) {
          m_model[i][j].addValue(in.value(j),
                                 in.weight() * m_weights[l][i]);
        } else {
          m_modelNormal[i][j][0] += (in.value(j) * in.weight() *
                                     m_weights[l][i]);
          m_modelNormal[i][j][2] += in.weight() * m_weights[l][i];
          m_modelNormal[i][j][1] += (in.value(j) *
                                     in.value(j) * in.weight() * m_weights[l][i]);
        }
      }
    }
  }

  // calculate mean and std deviation for numeric attributes
  for (int j = 0; j < m_num_attribs; j++) {
    if (inst.attribute(j).isNominal()) {
      continue;
    }
    double minStdD = (m_minStdDevPerAtt != null)
      ? m_minStdDevPerAtt[j]
      : m_minStdDev;
    for (int i = 0; i < m_num_clusters; i++) {
      double[] normal = m_modelNormal[i][j];
      if (normal[2] <= 0) {
        // No weight on this attribute for this cluster. A huge std dev
        // presumably makes the component nearly uninformative; the mean
        // is a placeholder value kept from the original code.
        normal[1] = Double.MAX_VALUE;
        normal[0] = m_minStdDev;
        continue;
      }
      // variance from the raw sums; must precede the mean division below
      double variance = (normal[1] -
                         (normal[0] * normal[0] / normal[2])) /
        (normal[2]);
      if (variance < 0) {
        variance = 0; // guard against round-off going slightly negative
      }
      double stdv = Math.sqrt(variance);
      if (stdv <= minStdD) {
        stdv = inst.attributeStats(j).numericStats.stdDev;
        if (stdv <= minStdD) {
          stdv = minStdD;
        }
      }
      // NaN/infinite fall back to the per-attribute minimum, as in init
      if (Double.isNaN(stdv) || Double.isInfinite(stdv)) {
        stdv = minStdD;
      }
      if (stdv <= 0) {
        stdv = m_minStdDev;
      }
      normal[1] = stdv;
      // mean
      normal[0] /= normal[2];
    }
  }
}
/**
 * The E step of the EM algorithm.
 *
 * <p>Computes the weighted average log likelihood of the instances. When
 * requested, it also stores each instance's cluster membership distribution
 * in m_weights and then re-estimates the priors.
 *
 * @param inst the training instances
 * @param change_weights if true, refresh m_weights and the cluster priors
 * @return the average log likelihood: the sum of weight * log density divided
 * by the total weight (NaN if the total weight is 0)
 * @throws Exception propagated from the density / distribution computations
 * or from estimate_priors
 */
private double E (Instances inst, boolean change_weights)
  throws Exception {
  double weightedLogLik = 0.0;
  double totalWeight = 0.0;
  int numInst = inst.numInstances();
  for (int idx = 0; idx < numInst; idx++) {
    Instance current = inst.instance(idx);
    double w = current.weight();
    weightedLogLik += w * logDensityForInstance(current);
    totalWeight += w;
    if (change_weights) {
      m_weights[idx] = distributionForInstance(current);
    }
  }
  if (change_weights) {
    // reestimate priors from the refreshed membership weights
    estimate_priors(inst);
  }
  return weightedLogLik / totalWeight;
}
/**
 * Constructor. Puts every option at its default value via resetOptions().
 *
 * NOTE(review): resetOptions() is protected, so a subclass override runs
 * here before the subclass's own field initialisers.
 **/
public EM () {
  this.resetOptions();
}
/**
 * Reset to default options.
 *
 * <p>Defaults are:
 * <ul>
 * <li>verbose output off</li>
 * <li>random seed 100</li>
 * <li>at most 100 iterations</li>
 * <li>minimum standard deviation 1e-6</li>
 * <li>m_initialNumClusters and m_num_clusters both -1</li>
 * </ul>
 * When m_initialNumClusters is -1, toString reports the cluster count as
 * selected by cross validation.
 */
protected void resetOptions () {
  m_verbose = false;
  m_rseed = 100;
  m_max_iterations = 100;
  m_minStdDev = 1e-6;
  m_initialNumClusters = -1;
  m_num_clusters = -1;
}
/**
 * Return the normal distributions for the cluster models.
 *
 * <p>Indexed as [cluster][attribute][k]:
 * <ul>
 * <li>k = 0 is the mean.</li>
 * <li>k = 1 is the standard deviation.</li>
 * <li>k = 2 is a weight: 1.0 after initialisation, the summed membership
 * weight after an M step.</li>
 * </ul>
 * Only entries for numeric attributes are meaningful. The internal array is
 * returned, not a copy.
 *
 * @return a <code>double[][][]</code> value
 */
public double [][][] getClusterModelsNumericAtts() {
  return this.m_modelNormal;
}
/**
 * Return the priors for the clusters.
 *
 * <p>One entry per cluster, normalised via Utils.normalize. The internal
 * array is returned, not a copy. It may be null before the clusterer has
 * been built; toString treats a null value that way.
 *
 * @return a <code>double[]</code> value
 */
public double [] getClusterPriors() {
  return this.m_priors;
}
/**
 * Outputs the generated clusters into a string.
 *
 * <p>For each cluster, prints its prior probability. Then, per attribute, it
 * prints either the nominal estimator's own string form or the mean and
 * standard deviation of the numeric model, each to 4 decimal places.
 *
 * @return a textual description of the model, or "No clusterer built yet!"
 * when no priors exist
 */
public String toString () {
  if (m_priors == null) {
    return "No clusterer built yet!";
  }
  StringBuffer buf = new StringBuffer();
  buf.append("\nEM\n==\n");
  String header = (m_initialNumClusters == -1)
    ? "\nNumber of clusters selected by cross validation: " + m_num_clusters + "\n"
    : "\nNumber of clusters: " + m_num_clusters + "\n";
  buf.append(header);
  for (int cl = 0; cl < m_num_clusters; cl++) {
    buf.append("\nCluster: " + cl + " Prior probability: "
               + Utils.doubleToString(m_priors[cl], 4) + "\n\n");
    for (int att = 0; att < m_num_attribs; att++) {
      buf.append("Attribute: " + m_theInstances.attribute(att).name() + "\n");
      if (!m_theInstances.attribute(att).isNominal()) {
        double[] normal = m_modelNormal[cl][att];
        buf.append("Normal Distribution. Mean = "
                   + Utils.doubleToString(normal[0], 4)
                   + " StdDev = "
                   + Utils.doubleToString(normal[1], 4)
                   + "\n");
      } else if (m_model[cl][att] != null) {
        buf.append(m_model[cl][att].toString());
      }
    }
  }
  return buf.toString();
}
/**
* verbose output for debugging
* @param inst the training instances
*/
private void EM_Report (Instances inst) {
int i, j, l, m;
System.out.println("======================================");
for (j = 0; j < m_num_clusters; j++) {
for (i = 0; i < m_num_attribs; i++) {
System.out.println("Clust: " + j + " att: " + i + "\n");
if (m_theInstances.attribute(i).isNominal()) {
if (m_model[j][i] != null) {
System.out.println(m_model[j][i].toString());
}
}
else {
System.out.println("Normal Distribution. Mean = "
+ Utils.doubleToString(m_modelNormal[j][i][0]
, 8, 4)
+ " StandardDev = "
+ Utils.doubleToString(m_modelNormal[j][i][1]
, 8, 4)
+ " WeightSum = "
+ Utils.doubleToString(m_modelNormal[j][i][2]
, 8, 4));
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -