📄 weuclideanlearner.java
字号:
// System.out.println();// } else {// System.out.println("Constraints satisfied");// } return raphsonWeights; }// OLD CODE FOR MULTIPLE:// /** M-step of the KMeans clustering algorithm -- updates metric// * weights for the individual metrics. Invoked only whe metric is trainable// */// protected boolean updateMultipleMetricWeightsEuclidean() throws Exception {// if (m_regularizeWeights) {// System.out.println("Regularized version, calling GD version of updateMultipleMetricWeightsEuclidean!");// updateMultipleMetricWeightsEuclideanGD();// }// int numAttributes = m_Instances.numAttributes();// double[][] weights = new double[m_NumClusters][numAttributes];// int []counts = new int[m_NumClusters]; // count how many instances are in each cluster// Instance diffInstance;// //begin debugging variance// boolean debugVariance = true; // double[][] trueWeights = new double[m_NumClusters][numAttributes];// int [] majorityClasses = new int[m_NumClusters];// int [][] classCounts = new int[m_NumClusters][m_TotalTrainWithLabels.numClasses()];// // get the majority counts// // NB: m_TotalTrainWithLabels does *not* include unlabeled data, counts here are undersampled!// // assuming unlabeled data came from same distribution as m_TotalTrainWithLabels, counts are still valid...// for (int instIdx=0; instIdx<m_TotalTrainWithLabels.numInstances(); instIdx++) {// Instance fullInstance = m_TotalTrainWithLabels.instance(instIdx);// classCounts[m_ClusterAssignments[instIdx]][(int)(fullInstance.classValue())]++;// }// for (int i = 0; i < m_NumClusters; i++){// int majorityClass = 0;// System.out.print("Cluster" + i + "\t" + classCounts[i][0]);// for (int j = 1; j < m_TotalTrainWithLabels.numClasses(); j++) {// System.out.print("\t" + classCounts[i][j]);// if (classCounts[i][j] > classCounts[i][majorityClass]) {// majorityClass = j;// }// }// System.out.println();// majorityClasses[i] = majorityClass;// }// class MajorityChecker {// int [] m_majorityClasses = null; // public 
MajorityChecker(int [] majClasses) { m_majorityClasses = majClasses;}// public boolean belongsToMajority(Instances instances, int instIdx, int centroidIdx) {// // silly, must pass instance since can't access outer class fields otherwise from a local inner class// Instance fullInstance = instances.instance(instIdx); // int classValue = (int) fullInstance.classValue();// if (classValue == m_majorityClasses[centroidIdx]) {// return true;// } else {// return false;// }// }// }// MajorityChecker majChecker = new MajorityChecker(majorityClasses);// //end debugging variance // int violatedConstraints = 0; // for (int instIdx=0; instIdx<m_Instances.numInstances(); instIdx++) {// int centroidIdx = m_ClusterAssignments[instIdx];// diffInstance = m_metrics[centroidIdx].createDiffInstance(m_Instances.instance(instIdx), m_ClusterCentroids.instance(centroidIdx));// for (int attr=0; attr<numAttributes; attr++) {// weights[centroidIdx][attr] += diffInstance.value(attr); // Mahalanobis components// if (debugVariance && instIdx < m_TotalTrainWithLabels.numInstances()) {// if (majChecker.belongsToMajority(m_TotalTrainWithLabels, instIdx, centroidIdx)) {// trueWeights[centroidIdx][attr] += diffInstance.value(attr);// } // }// }// counts[centroidIdx]++;// Object list = m_instanceConstraintHash.get(new Integer(instIdx));// if (list != null) { // there are constraints associated with this instance// ArrayList constraintList = (ArrayList) list;// for (int i = 0; i < constraintList.size(); i++) {// InstancePair pair = (InstancePair) constraintList.get(i);// int firstIdx = pair.first;// int secondIdx = pair.second;// double cost = 0;// if (pair.linkType == InstancePair.MUST_LINK) {// cost = m_MLweight;// } else if (pair.linkType == InstancePair.CANNOT_LINK) {// cost = m_CLweight;// }// Instance instance1 = m_Instances.instance(firstIdx);// Instance instance2 = m_Instances.instance(secondIdx);// int otherIdx = (firstIdx == instIdx) ? 
m_ClusterAssignments[secondIdx] : m_ClusterAssignments[firstIdx];// // check whether the constraint is violated// if (otherIdx != -1) { // if (otherIdx != centroidIdx && pair.linkType == InstancePair.MUST_LINK) { // violated must-link// if (m_verbose) {// System.out.println("Found violated must link between: " + firstIdx + " and " + secondIdx);// }// // we penalize weights for both clusters involved, splitting the penalty in half// Instance diffInstance1 = m_metrics[otherIdx].createDiffInstance(instance1, instance2);// Instance diffInstance2 = m_metrics[centroidIdx].createDiffInstance(instance1, instance2); // for (int attr=0; attr<numAttributes; attr++) { // double-counting constraints, hence 0.5*0.5// weights[otherIdx][attr] += 0.25 * cost * diffInstance1.value(attr);// weights[centroidIdx][attr] += 0.25 * cost * diffInstance2.value(attr);// }// violatedConstraints++; // } // else if (otherIdx == centroidIdx && pair.linkType == InstancePair.CANNOT_LINK) { //violated cannot-link// if (m_verbose) {// System.out.println("Found violated cannot link between: " + firstIdx + " and " + secondIdx);// }// // we penalize weights for just one cluster involved// diffInstance = m_metrics[centroidIdx].createDiffInstance(instance1, instance2);// Instance cannotDiffInstance = m_metrics[otherIdx].createDiffInstance(m_maxCLPoints[centroidIdx][0],// m_maxCLPoints[centroidIdx][1]);// for (int attr=0; attr<numAttributes; attr++) { // double-counting constraints, hence 0.5// weights[centroidIdx][attr] += 0.5 * cost * cannotDiffInstance.value(attr);// weights[centroidIdx][attr] -= 0.5 * cost * diffInstance.value(attr); // }// violatedConstraints++; // }// } // end while// }// }// }// System.out.println(" Total constraints violated: " + violatedConstraints/2 + "; per-cluster weights are:"); // // check if NR needed// double [][] newWeights = new double[m_NumClusters][numAttributes];// double [][] currentWeights = new double[m_NumClusters][numAttributes];// for (int i=0; i<m_NumClusters; 
i++) {// currentWeights[i] = ((LearnableMetric) m_metrics[i]).getWeights();// }// for (int i=0; i<m_NumClusters; i++) {// boolean needNewtonRaphson = false;// for (int attr=0; attr<numAttributes; attr++) {// if (weights[i][attr] < 0) { // check to avoid divide by 0// System.out.println("WARNING! Cluster " + i + ", attribute " + attr + " weight=" + weights[i][attr]);// Cluster currentCluster = (Cluster) getClusters().get(i);// System.out.println("\nCluster " + i + ": " + currentCluster.size() + " instances");// if (currentCluster == null) {// System.out.println("(empty)");// }// else {// for (int j=0; j<currentCluster.size(); j++) {// Instance instance = (Instance) currentCluster.get(j); // System.out.println("Instance: " + instance);// }// } // needNewtonRaphson = true;// break;// } else if (weights[i][attr] == 0) {// newWeights[i][attr] = currentWeights[i][attr];// System.out.println("WARNING! Cluster " + i + ", attribute " + attr + " has 0 weight; keeping it as " + weights[i][attr]);// } else {// newWeights[i][attr] = m_logTermWeight * counts[i]/weights[i][attr]; // invert weights// if (debugVariance) {// trueWeights[i][attr] = counts[i]/trueWeights[i][attr];// }// }// } // // uncomment next line for debugging NR// // needNewtonRaphson = true;// // do NR if needed// if (needNewtonRaphson) {// // weights not inverted here -- done in NR routine// newWeights[i] = updateWeightsUsingNewtonRaphson(currentWeights[i], weights[i]); // System.out.println(" (NR) ");// } // // PRINT routine// // System.out.print("\t" + i + "(" + counts[i] + "): ");// // for (int attr=0; attr<numAttributes; attr++) {// // if (debugVariance) {// // System.out.print(((float)trueWeights[i][attr]) + "/~/");// // } // // System.out.print(((float)newWeights[i][attr]) + "\t");// // }// // System.out.println();// // System.out.println("\t\tMean: " + m_ClusterCentroids.instance(i));// // end PRINT routine// ((LearnableMetric) m_metrics[i]).setWeights(newWeights[i]);// }// return true;// } /** * Gets 
/**
 * Gets the current option settings.
 *
 * <p>No options are defined yet, so the result is a one-element array
 * holding an empty string.
 * NOTE(review): the original Javadoc read "settings of KL", which looks
 * copied from another class -- confirm the intended wording.
 *
 * @return an array of strings suitable for passing to setOptions()
 */
public String[] getOptions() {
  String[] options = new String[1];
  int current = 0;

  // Pad every slot that no real option has filled with the empty string.
  while (current < options.length) {
    options[current++] = "";
  }
  return options;
}

/**
 * Sets options from the given array. Not implemented yet: the array is
 * ignored and no state is changed.
 *
 * @param options the option strings; currently unused
 * @throws Exception never thrown by this implementation
 */
public void setOptions(String[] options) throws Exception {
  // TODO: add later
}

/**
 * Lists the available options. None are defined yet.
 *
 * @return an empty enumeration (never null), so callers can iterate the
 *         result without a null check
 */
public Enumeration listOptions() {
  // TODO: add later
  // Return an empty enumeration instead of null: callers that loop with
  // hasMoreElements() would otherwise fail with a NullPointerException.
  return new java.util.Vector().elements();
}
} // closes the enclosing class, whose header lies outside this excerpt
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -