seededkmeans.java
来自「wekaUT是 university texas austin 开发的基于wek」· Java 代码 · 共 2,032 行 · 第 1/5 页
JAVA
2,032 行
// if (m_Verbose) { System.out.println("Active learning point number: " + numPointsSelected + " is: " + nextPoint + ", with class label: " + classLabel); // } allClustersFound = setAllClustersFound(clusterSizes); if (allClustersFound != true) { throw new Exception("Something went wrong - all clusters should be set in phase 2!!"); } } } return activeLearningPoints; } /** Returns true if all clusterSizes are non-zero */ boolean setAllClustersFound(int [] clusterSizes) { boolean found = true; for (int i=0; i<m_NumClusters; i++) { if (clusterSizes[i] == 0) { found = false; } //if (m_Verbose) { System.out.println("Cluster " + i + " has size: " + clusterSizes[i]); //} } return found; } /** Finds the sum of instance sum with instance inst */ Instance sumWithInstance(Instance sum, Instance inst) throws Exception { Instance newSum; if (sum == null) { if (isSparseInstance) { newSum = new SparseInstance(inst); newSum.setDataset(m_Instances); } else { newSum = new Instance(inst); newSum.setDataset(m_Instances); } } else { newSum = sumInstances(sum, inst); } return newSum; } /** Finds point which has max min-distance from set visitedPoints */ int farthestFromSet(HashSet visitedPoints, HashSet eliminationSet) throws Exception { // implements farthest-first search algorithm: /* for (each datapoint x not in visitedPoints) { distance of x to visitedPoints = min{d(x,f):f \in visitedPoints} } select the point x with maximum distance as new center; */ if (visitedPoints.size() == 0) { Random rand = new Random(m_randomSeed); int point = rand.nextInt(m_StartingIndexOfTest); // Note - no need to check for labeled data now, since we have no visitedPoints // => no labeled data System.out.println("First point selected: " + point); return point; } else { if (m_Verbose) { Iterator iter = visitedPoints.iterator(); while(iter.hasNext()) { System.out.println("In visitedPoints set: " + ((Integer) iter.next()).intValue()); } if (eliminationSet != null) { iter = eliminationSet.iterator(); while(iter.hasNext()) { System.out.println("In elimination set: " + ((Integer) iter.next()).intValue()); } } } } double minSimilaritySoFar = Double.POSITIVE_INFINITY; double maxDistanceSoFar = Double.NEGATIVE_INFINITY; ArrayList bestPoints = new ArrayList(); for (int i=0; i<m_Instances.numInstances() && i<m_StartingIndexOfTest; i++) { // point should not belong to test set if (!visitedPoints.contains(new Integer(i))) { if (eliminationSet == null || !eliminationSet.contains(new Integer(i))) { // point should not belong to visitedPoints Instance inst = m_Instances.instance(i); Iterator iter = visitedPoints.iterator(); double minDistanceFromSet = Double.POSITIVE_INFINITY; double maxSimilarityFromSet = Double.NEGATIVE_INFINITY; while (iter.hasNext()) { Instance pointInSet = m_Instances.instance(((Integer) iter.next()).intValue()); if (!m_objFunDecreasing) { double sim = m_metric.similarity(inst, pointInSet); if (sim > maxSimilarityFromSet) { maxSimilarityFromSet = sim; // if (m_Verbose) { // System.out.println("Max similarity of " + i + " from set is: " + maxSimilarityFromSet); // } } } else { double dist = m_metric.distance(inst, pointInSet); if (dist < minDistanceFromSet) { minDistanceFromSet = dist; // if (m_Verbose) { // System.out.println("Min distance of " + i + " from set is: " + minDistanceFromSet); // } } } } if (m_Verbose) { System.out.println(i + " has sim: " + maxSimilarityFromSet + ", best: " + minSimilaritySoFar); } if (!m_objFunDecreasing) { if (maxSimilarityFromSet == minSimilaritySoFar) { minSimilaritySoFar = maxSimilarityFromSet; bestPoints.add(new Integer(i)); if (m_Verbose) { System.out.println("Additional point added: " + i + " with similarity: " + minSimilaritySoFar); } } else if (maxSimilarityFromSet < minSimilaritySoFar) { minSimilaritySoFar = maxSimilarityFromSet; bestPoints.clear(); bestPoints.add(new Integer(i)); if (m_Verbose) { System.out.println("Farthest point from set is: " + i + " with similarity: " + minSimilaritySoFar); } } } else { if (minDistanceFromSet == maxDistanceSoFar) { minDistanceFromSet = maxDistanceSoFar; bestPoints.add(new Integer(i)); if (m_Verbose) { System.out.println("Additional point added: " + i + " with similarity: " + minSimilaritySoFar); } } else if (minDistanceFromSet > maxDistanceSoFar) { maxDistanceSoFar = minDistanceFromSet; bestPoints.clear(); bestPoints.add(new Integer(i)); if (m_Verbose) { System.out.println("Farthest point from set is: " + i + " with distance: " + maxDistanceSoFar); } } } } } } int bestPoint = -1; if (bestPoints.size() > 1) { // multiple points, get random from whole set Random random = new Random(m_randomSeed); bestPoint = random.nextInt(m_StartingIndexOfTest); while ((visitedPoints != null && visitedPoints.contains(new Integer(bestPoint))) || (eliminationSet != null && eliminationSet.contains(new Integer(bestPoint)))) { bestPoint = random.nextInt(m_StartingIndexOfTest); } System.out.println("Randomly selected " + bestPoint + " with similarity: " + minSimilaritySoFar); } else { // only 1 point, fine bestPoint = ((Integer)bestPoints.get(0)).intValue(); System.out.println("Deterministically selected " + bestPoint + " with similarity: " + minSimilaritySoFar); } if (m_Verbose) { if (!m_objFunDecreasing) { System.out.println("Randomly selected " + bestPoint + " with similarity: " + minSimilaritySoFar); } else { System.out.println("Randomly selected " + bestPoint + " with similarity: " + maxDistanceSoFar); } } return bestPoint; } /** Finds point which is nearest to center. This point should not be * a test point and should not belong to visitedPoints */ int nearestFromPoint(Instance center, HashSet visitedPoints) throws Exception { double maxSimilarity = Double.NEGATIVE_INFINITY; double minDistance = Double.POSITIVE_INFINITY; int bestPoint = -1; for (int i=0; i<m_Instances.numInstances() && i<m_StartingIndexOfTest; i++) { // bestPoint should not be a test point if (!visitedPoints.contains(new Integer(i))) { // bestPoint should not belong to visitedPoints Instance inst = m_Instances.instance(i); if (!m_objFunDecreasing) { double sim = m_metric.similarity(inst, center); if (sim > maxSimilarity) { bestPoint = i; maxSimilarity = sim; if (m_Verbose) { System.out.println("Nearest point is: " + bestPoint + " with sim: " + maxSimilarity); } } } else { double dist = m_metric.distance(inst, center); if (dist < minDistance) { bestPoint = i; minDistance = dist; if (m_Verbose) { System.out.println("Nearest point is: " + bestPoint + " with dist: " + minDistance); } } } } } return bestPoint; } /** Finds sum of instances (handles sparse and non-sparse) */ protected Instance sumInstances(Instance inst1, Instance inst2) throws Exception { int numAttributes = inst1.numAttributes(); if (inst2.numAttributes() != numAttributes) { throw new Exception ("Error!! inst1 and inst2 should have same number of attributes."); } if (m_Verbose) { // System.out.println("Instance 1 is: " + inst1 + ", instance 2 is: " + inst2); } double weight1 = inst1.weight(), weight2 = inst2.weight(); double [] values = new double[numAttributes]; for (int i=0; i<numAttributes; i++) { values[i] = 0; } if (inst1 instanceof SparseInstance && inst2 instanceof SparseInstance) { for (int i=0; i<inst1.numValues(); i++) { int indexOfIndex = inst1.index(i); values[indexOfIndex] = inst1.valueSparse(i); } for (int i=0; i<inst2.numValues(); i++) { int indexOfIndex = inst2.index(i); values[indexOfIndex] += inst2.valueSparse(i); } SparseInstance newInst = new SparseInstance(weight1+weight2, values); newInst.setDataset(m_Instances); if (m_Verbose) { // System.out.println("Sum instance is: " + newInst); } return newInst; } else if (!(inst1 instanceof SparseInstance) && !(inst2 instanceof SparseInstance)){ for (int i=0; i<numAttributes; i++) { values[i] = inst1.value(i) + inst2.value(i); } } else { throw new Exception ("Error!! inst1 and inst2 should be both of same type -- sparse or non-sparse"); } Instance newInst = new Instance(weight1+weight2, values); newInst.setDataset(m_Instances); if (m_Verbose) { // System.out.println("Sum instance is: " + newInst); } return newInst; } /** This function divides every attribute value in an instance by * the instance weight -- useful to find the mean of a cluster in * Euclidean space * @param inst Instance passed in for normalization (destructive update) */ protected void normalizeByWeight(Instance inst) { double weight = inst.weight(); if (m_Verbose) { // System.out.println("Before weight normalization: " + inst); } if (inst instanceof SparseInstance) { for (int i=0; i<inst.numValues(); i++) { inst.setValueSparse(i, inst.valueSparse(i)/weight); } } else if (!(inst instanceof SparseInstance)) { for (int i=0; i<inst.numAttributes(); i++) { inst.setValue(i, inst.value(i)/weight); } } if (m_Verbose) { // System.out.println("After weight normalization: " + inst); } } public int[] oldBestInstancesForActiveLearning(int numActive) throws Exception{ int numInstances = m_Instances.numInstances(); double [] scores = new double [numInstances]; int numLabeledData = 0; if (m_SeedHash != null) { numLabeledData = m_SeedHash.size(); } // Remember: order of data -- labeled, then unlabeled, then test for (int i=0; i<numLabeledData; i++) { scores[i] = -1; } for (int i=numLabeledData; i<numInstances; i++) { double score = 0, normalizer = 0; Instance inst = m_Instances.instance(i); double[] prob = new double[m_NumClusters]; for (int j=0; j<m_NumClusters; j++) { if (!m_objFunDecreasing) { double sim = m_metric.similarity(inst, m_ClusterCentroids.instance(j)); prob[j] = Math.exp(sim * m_Concentration); // P(x|h) } else { double dist = m_metric.distance(inst, m_ClusterCentroids.instance(j)); prob[j] = Math.exp(-dist*dist * m_Concentration); // P(x|h) } normalizer += prob[j]; // P(x)/P(h) = Sum_h P(x|h) [uniform priors P(h)] } for (int j=0; j<m_NumClusters; j++) { prob[j] /= normalizer; // P(h|x) = P(x|h)*P(h)/P(x) score -= prob[j] * Math.log(prob[j]); } scores[i] = score * normalizer; // InfoGain = H(C|x).P(x) [with a constant factor of 1/P(h)] } System.out.println("NumInstances: "+ numInstances + ", starting index of unlabeled train: " + numLabeledData + ", starting index of test: " + m_StartingIndexOfTest); int [] indices = Utils.sort(scores); int [] mostConfused = new int [numActive]; for (int i=0,num=0; i<numInstances && num<numActive; i++) { int index = numInstances-1-i; if ((indices[index]<m_StartingIndexOfTest) && (scores[indices[index]]!=-1)) { // makes sure that labeled or test instances are not asked to be active labeled mostConfused[num] = (indices[index]); num++; } } for (int i=0; i<numActive; i++) { // System.out.println("Value: " + scores[mostConfused[i]] + ", index: " + mostConfused[i]); } return mostConfused; } /** * Checks if instance has to be normalized and classifies the * instance using the current clustering * * @param instance the instance to be assigned to a cluster * @return the number of the assigned cluster as an integer * if the class is enumerated, otherwise the predicted value * @exception Exception if instance could not be classified * successfully */ public int clusterInstance(Instance instance) throws Exception { if (m_Algorithm == ALGORITHM_SPHERICAL) { // check here, since evaluateModel calls this function on test data normalize(instance); } return assignClusterToInstance(instance); } /** * Classifies the instance using the current clustering * * @param instance the instance to be assigned to a cluster * @return the number of the assigned cluster as an integer * if the class is enumerated, otherwise the predicted value * @exception Exception if instance could not be classified * successfully */ public int assignClusterToInstance(Instance instance) throws Exception { int bestCluster = 0; double bestDistance = Double.POSITIVE_INFINITY; double bestSimilarity = Double.NEGATIVE_INFINITY; for (int i = 0; i < m_NumClusters; i++) { double distance = 0, similarity = 0; if (!m_objFunDecreasing) { similarity = m_metric.similarity(instance, m_ClusterCentroids.instance(i)); if (similarity > bestSimilarity) { bestSimilarity = similarity; bestCluster = i; } } else { distance = m_metric.distance(instance, m_ClusterCentroids.instance(i)); if (distance < bestDistance) { bestDistance = distance; bestCluster = i; } } } if (bestSimilarity == 0) { System.out.println("Note!! bestSimilarity is 0 for instance " + m_currIdx + ", assigned to cluster: " + bestCluster + " ... instance is: " + instance); } return bestCluster; } /** Return the number of clusters */
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?