seededkmeans.java
来自「wekaUT是 university texas austin 开发的基于wek」· Java 代码 · 共 2,032 行 · 第 1/5 页
JAVA
2,032 行
String metricName = metricSpec[0]; metricSpec[0] = ""; if (m_Verbose) { System.out.println("Metric name: " + metricName + "\nMetric parameters: " + concatStringArray(metricSpec)); } setMetric((LearnableMetric) LearnableMetric.forName(metricName, metricSpec)); } } /** A little helper to create a single String from an array of Strings * @param strings an array of strings * @returns a single concatenated string, separated by commas */ public static String concatStringArray(String[] strings) { String result = new String(); for (int i = 0; i < strings.length; i++) { result = result + "\"" + strings[i] + "\" "; } return result; } /** * return a string describing this clusterer * * @return a description of the clusterer as a string */ public String toString() { StringBuffer temp = new StringBuffer(); temp.append("\nkMeans\n======\n"); temp.append("\nNumber of iterations: " + m_Iterations+"\n"); temp.append("\nCluster centroids:\n"); for (int i = 0; i < m_NumClusters; i++) { temp.append("\nCluster "+i+"\n\t"); /* temp.append(m_ClusterCentroids.instance(i)); for (int j = 0; j < m_ClusterCentroids.numAttributes(); j++) { if (m_ClusterCentroids.attribute(j).isNominal()) { temp.append(" "+m_ClusterCentroids.attribute(j). value((int)m_ClusterCentroids.instance(i).value(j))); } else { temp.append(" "+m_ClusterCentroids.instance(i).value(j)); } } */ } temp.append("\n"); return temp.toString(); } /** * set the verbosity level of the clusterer * @param verbose messages on(true) or off (false) */ public void setVerbose (boolean verbose) { m_Verbose = verbose; } /** * get the verbosity level of the clusterer * @return messages on(true) or off (false) */ public boolean getVerbose () { return m_Verbose; } /** * Train the clusterer using specified parameters * * @param instances Instances to be used for training */ public void trainClusterer (Instances instances) throws Exception { if (m_metric instanceof LearnableMetric) { if (((LearnableMetric)m_metric).getTrainable()) { ((LearnableMetric)m_metric).learnMetric(instances); } else { throw new Exception ("Metric is not trainable"); } } else { throw new Exception ("Metric is not trainable"); } } /** Normalizes Instance or SparseInstance * * @author Sugato Basu * @param inst Instance to be normalized */ public void normalize(Instance inst) throws Exception { if (inst instanceof SparseInstance) { normalizeSparseInstance(inst); } else { ((LearnableMetric) m_metric).normalizeInstanceWeighted(inst); } } /** Normalizes the values of a normal Instance * * @author Sugato Basu * @param inst Instance to be normalized */ public void normalizeInstance(Instance inst) throws Exception{ double norm = 0; double values [] = inst.toDoubleArray(); if (inst instanceof SparseInstance) { throw new Exception("Use normalizeSparseInstance function"); } for (int i=0; i<values.length; i++) { if (i != inst.classIndex()) { // don't normalize the class index norm += values[i] * values[i]; } } norm = Math.sqrt(norm); for (int i=0; i<values.length; i++) { if (i != inst.classIndex()) { // don't normalize the class index if (norm == 0) { values[i]= 0; } else { values[i] /= norm; } } } inst.setValueArray(values); } /** Normalizes the values of a SparseInstance * * @author Sugato Basu * @param inst SparseInstance to be normalized */ public void normalizeSparseInstance(Instance inst) throws Exception{ double norm=0; int length = inst.numValues(); if (!(inst instanceof SparseInstance)) { throw new Exception("Use normalizeInstance function"); } for (int i=0; i<length; i++) { if (inst.index(i) != inst.classIndex()) { // don't normalize the class index norm += inst.valueSparse(i) * inst.valueSparse(i); } } norm = Math.sqrt(norm); for (int i=0; i<length; i++) { // don't normalize the class index if (inst.index(i) != inst.classIndex()) { inst.setValueSparse(i, inst.valueSparse(i)/norm); } } } /** Fast version of meanOrMode - streamlined from Instances.meanOrMode for efficiency * Does not check for missing attributes, assumes numeric attributes, assumes Sparse instances */ protected double[] meanOrMode(Instances insts) { int numAttributes = insts.numAttributes(); double [] value = new double[numAttributes]; double weight = 0; for (int i=0; i<numAttributes; i++) { value[i] = 0; } for (int j=0; j<insts.numInstances(); j++) { SparseInstance inst = (SparseInstance) (insts.instance(j)); weight += inst.weight(); for (int i=0; i<inst.numValues(); i++) { int indexOfIndex = inst.index(i); value[indexOfIndex] += inst.weight() * inst.valueSparse(i); } } if (Utils.eq(weight, 0)) { for (int k=0; k<numAttributes; k++) { value[k] = 0; } } else { for (int k=0; k<numAttributes; k++) { value[k] = value[k] / weight; } } return value; } /** * Main method for testing this class. * */ public static void main (String[] args) { try { String dataSet = new String("news"); //String dataSet = new String("iris"); if (dataSet.equals("iris")) { //////// Iris data String datafile = "/u/ml/software/weka-latest/data/iris.arff"; // set up the data FileReader reader = new FileReader (datafile); Instances data = new Instances (reader); // Make the last attribute be the class int theClass = data.numAttributes(); data.setClassIndex(theClass-1); // starts with 0 // Remove the class labels before clustering Instances clusterData = new Instances(data); clusterData.deleteClassAttribute(); // #clusters = #classes int num_clusters = data.numClasses(); // cluster with seeding Instances seeds = new Instances(data,0,5); seeds.add(data.instance(50)); seeds.add(data.instance(51)); seeds.add(data.instance(52)); seeds.add(data.instance(53)); seeds.add(data.instance(54)); seeds.add(data.instance(100)); seeds.add(data.instance(101)); seeds.add(data.instance(102)); seeds.add(data.instance(103)); seeds.add(data.instance(104)); data.delete(104); data.delete(103); data.delete(102); data.delete(101); data.delete(100); data.delete(54); data.delete(53); data.delete(52); data.delete(51); data.delete(50); data.delete(4); data.delete(3); data.delete(2); data.delete(1); data.delete(0); System.out.println("\nClustering the iris data with seeding, using seeded KMeans...\n"); WeightedEuclidean euclidean = new WeightedEuclidean(); SeededKMeans kmeans = new SeededKMeans (euclidean); kmeans.resetClusterer(); kmeans.setVerbose(false); kmeans.setSeedingMethod(new SelectedTag(SEEDING_SEEDED, TAGS_SEEDING)); kmeans.setAlgorithm(new SelectedTag(ALGORITHM_SIMPLE, TAGS_ALGORITHM)); euclidean.setExternal(false); euclidean.setTrainable(false); // phase 1 test kmeans.setSeedable(false); kmeans.buildClusterer(null, clusterData, theClass, data, 150); // phase 2 test //kmeans.setSeedable(true); //kmeans.buildClusterer(seeds, clusterData, theClass, data, 150); kmeans.getIndexClusters(); kmeans.printIndexClusters(); // kmeans.setVerbose(true); kmeans.bestInstancesForActiveLearning(50); } else if (dataSet.equals("news")) { //////// Text data - 3000 documents String datafile = "/u/ml/data/CCSfiles/arffFromCCS/cmu-newsgroup-clean-1000_fromCCS.arff"; System.out.println("\nClustering complete newsgroup data with seeding, using constrained KMeans...\n"); // set up the data FileReader reader = new FileReader (datafile); Instances data = new Instances (reader); System.out.println("Initial data has size: " + data.numInstances()); // Make the last attribute be the class int theClass = data.numAttributes(); data.setClassIndex(theClass-1); // starts with 0 int num_clusters = data.numClasses(); // cluster with seeding Instances seeds = new Instances(data, 0); /* seeds.add(data.instance(994)); seeds.add(data.instance(1431)); seeds.add(data.instance(1612)); seeds.add(data.instance(1747)); seeds.add(data.instance(2205)); seeds.add(data.instance(2736)); data.delete(2736); data.delete(2205); data.delete(1747); data.delete(1612); data.delete(1431); data.delete(994); seeds.add(data.instance(1000)); seeds.add(data.instance(1001)); seeds.add(data.instance(1002)); seeds.add(data.instance(1003)); seeds.add(data.instance(1004)); seeds.add(data.instance(2000)); seeds.add(data.instance(2001)); seeds.add(data.instance(2002)); seeds.add(data.instance(2003)); seeds.add(data.instance(2004)); // System.out.println("Labeled data has size: " + seeds.numInstances() + ", number of attributes: " + data.numAttributes()); data.delete(2004); data.delete(2003); data.delete(2002); data.delete(2001); data.delete(2000); data.delete(1004); data.delete(1003); data.delete(1002); data.delete(1001); data.delete(1000); data.delete(4); data.delete(3); data.delete(2); data.delete(1); data.delete(0); */ System.out.println("Unlabeled data has size: " + data.numInstances()); // Remove the class labels before clustering Instances clusterData = new Instances(data); clusterData.deleteClassAttribute(); WeightedDotP dotp = new WeightedDotP(); dotp.setExternal(false); dotp.setTrainable(false); dotp.setLengthNormalized(false); SeededKMeans kmeans = new SeededKMeans(dotp); kmeans.setVerbose(false); kmeans.setSeedingMethod(new SelectedTag(SEEDING_SEEDED, TAGS_SEEDING)); kmeans.setAlgorithm(new SelectedTag(ALGORITHM_SPHERICAL, TAGS_ALGORITHM)); kmeans.setNumClusters(3); // phase 1 test kmeans.setSeedable(false); kmeans.buildClusterer(null, clusterData, theClass, data, data.numInstances()); // phase 2 test //kmeans.setSeedable(true); //kmeans.buildClusterer(seeds, clusterData, theClass, data, 3000); kmeans.getIndexClusters(); kmeans.printIndexClusters(); // kmeans.setVerbose(true); //kmeans.bestInstancesForActiveLearning(50); // // cluster with seeding for small newsgroup// seeds = new Instances(data, 0, 3); // seeds.add(data.instance(100)); // seeds.add(data.instance(101));// seeds.add(data.instance(102));// seeds.add(data.instance(200));// seeds.add(data.instance(201));// seeds.add(data.instance(202));// seeds.add(data.instance(300));// seeds.add(data.instance(301));// seeds.add(data.instance(302));// seeds.add(data.instance(400));// seeds.add(data.instance(401));// seeds.add(data.instance(402));// seeds.add(data.instance(500));// seeds.add(data.instance(501));// seeds.add(data.instance(502));// seeds.add(data.instance(600));// seeds.add(data.instance(601));// seeds.add(data.instance(602));// seeds.add(data.instance(700));// seeds.add(data.instance(701));// seeds.add(data.instance(702));// seeds.add(data.instance(800));// seeds.add(data.instance(801));// seeds.add(data.instance(802));// seeds.add(data.instance(900));// seeds.add(data.instance(901)); // seeds.add(data.instance(902));// seeds.add(data.instance(1000)); // seeds.add(da
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?