📄 mahalanobislearner.java
字号:
// // store the constant part of the gradient:// double[][] gradientConst = new double[numAttributes][numAttributes];// for (int instIdx = 0; instIdx < m_Instances.numInstances(); instIdx++) {// // the (x-m)(x-m)' part// int centroidIdx = m_ClusterAssignments[instIdx];// Instance centroid = m_ClusterCentroids.instance(centroidIdx);// diffInstance = metric.createDiffInstance(m_Instances.instance(instIdx),// centroid);// for (int i = 0; i < numAttributes; i++) {// for (int j = 0; j <= i; j++) {// gradientConst[i][j] =// gradientConst[j][i] = diffInstance.value(i) * diffInstance.value(j); // }// }// // the violated constraints// Object list = m_instanceConstraintHash.get(new Integer(instIdx));// if (list != null) { // there are constraints associated with this instance// ArrayList constraintList = (ArrayList) list;// for (int constrIdx = 0; constrIdx < constraintList.size(); constrIdx++) {// InstancePair pair = (InstancePair) constraintList.get(constrIdx);// int firstIdx = pair.first;// int secondIdx = pair.second;// double cost = 0;// if (pair.linkType == InstancePair.MUST_LINK) {// cost = m_MLweight;// } else if (pair.linkType == InstancePair.CANNOT_LINK) {// cost = m_CLweight;// }// Instance instance1 = m_Instances.instance(firstIdx);// Instance instance2 = m_Instances.instance(secondIdx);// int otherIdx = (firstIdx == instIdx) ? 
m_ClusterAssignments[secondIdx]// : m_ClusterAssignments[firstIdx];// if (otherIdx == -1) {// throw new Exception("One of the instances is unassigned in "// + "updateMetricWeightsMahalanobisGD"); // }// // check whether the constraint is violated// if (otherIdx != centroidIdx &&// pair.linkType == InstancePair.MUST_LINK) {// diffInstance = metric.createDiffInstance(instance1, instance2);// for (int i = 0; i < numAttributes; i++) {// for (int j = 0; j <= i; j++) {// gradientConst[i][j] =// gradientConst[j][i] =// 0.5 * cost * diffInstance.value(i) * diffInstance.value(j);// }// }// violatedConstraints++; // } else if (otherIdx == centroidIdx &&// pair.linkType == InstancePair.CANNOT_LINK) {// diffInstance = metric.createDiffInstance(instance1, instance2);// for (int i = 0; i < numAttributes; i++) {// for (int j = 0; j <= i; j++) {// gradientConst[i][j] =// gradientConst[j][i] =// 0.5 * cost *// (maxCLUpdate[i][j] -// diffInstance.value(i) * diffInstance.value(j)); // }// }// violatedConstraints++; // }// }// }// }// Matrix constUpdate = new Matrix(gradientConst); // while (iteration < m_maxGDIterations && !converged) {// // calculate the gradient// Matrix update = constUpdate.copy(); // // factor in the A^-1 // Matrix Ai = newWeights.inverse();// Ai.timesEquals(m_logTermWeight); // update.minusEquals(Ai); // // regularization (-1/sum(a_ij)^2)// double regularizer = 0; // for (int i = 0; i < numAttributes; i++) {// for (int j = 0; j <= i; j++) {// regularizer += 2.0/(newWeights.get(i, j) * newWeights.get(i, j));// }// }// // correct for double-counted diagonal// for (int i = 0; i < numAttributes; i++) {// regularizer -= 1.0/newWeights.get(i, i);// }// regularizer *= m_currregularizerTermWeight; // for (int i = 0; i < numAttributes; i++) {// for (int j = 0; j < numAttributes; j++) {// update.set(i, j, update.get(i,j) - regularizer);// }// }// // update// update.timesEquals(m_currEta); // newWeights.minusEquals(update);// // anneal if necessary and check for 
convergence// m_currEta = m_currEta * m_etaDecayRate;// // check for convergence// double norm = update.norm1();// System.out.println(iteration + ": norm=" + norm); // if (norm < 0.0001) {// converged = true;// }// iteration++; // }// // We're done, set the weights to newWeights // }// MULTIPLE:// /** M-step of the KMeans clustering algorithm -- updates metric// * weights. Invoked only when metric is an instance of Mahalanobis// * @return value true if everything was alright; false if there was// miserable failure and clustering needs to be restarted */// protected boolean updateMultipleMetricWeightsMahalanobis() throws Exception {// if (m_regularizeWeights) {// System.out.println("Regularized version, calling GD version of updateMultipleMetricWeightsMahalanobisGD!");// updateMultipleMetricWeightsMahalanobisGD();// }// int numAttributes = m_Instances.numAttributes();// if (m_Instances.classIndex() >= 0) {// numAttributes--;// }// Matrix [] updateMatrices = new Matrix[m_metrics.length];// for (int i = 0; i < updateMatrices.length; i++) { // updateMatrices[i] = new Matrix(numAttributes, numAttributes);// }// int violatedConstraints = 0;// int [] counts = new int[updateMatrices.length];// for (int instIdx=0; instIdx<m_Instances.numInstances(); instIdx++) {// int centroidIdx = m_ClusterAssignments[instIdx];// Matrix diffMatrix = ((WeightedMahalanobis) m_metrics[centroidIdx]).createDiffMatrix(m_Instances.instance(instIdx),// m_ClusterCentroids.instance(centroidIdx));// updateMatrices[centroidIdx] = updateMatrices[centroidIdx].plus(diffMatrix);// counts[centroidIdx]++;// // go through violated constraints// Object list = m_instanceConstraintHash.get(new Integer(instIdx));// if (list != null) { // there are constraints associated with this instance// ArrayList constraintList = (ArrayList) list;// for (int i = 0; i < constraintList.size(); i++) {// InstancePair pair = (InstancePair) constraintList.get(i);// int firstIdx = pair.first;// int secondIdx = pair.second;// 
Instance instance1 = m_Instances.instance(firstIdx);// Instance instance2 = m_Instances.instance(secondIdx);// int otherIdx = (firstIdx == instIdx) ? m_ClusterAssignments[secondIdx] : m_ClusterAssignments[firstIdx];// // check whether the constraint is violated// if (otherIdx != -1) { // if (otherIdx != centroidIdx && pair.linkType == InstancePair.MUST_LINK) {// Matrix diffMatrix1 = ((WeightedMahalanobis) m_metrics[centroidIdx]).createDiffMatrix(instance1, instance2);// diffMatrix1 = diffMatrix1.times(0.25);// Matrix diffMatrix2 = ((WeightedMahalanobis) m_metrics[otherIdx]).createDiffMatrix(instance1, instance2);// diffMatrix2 = diffMatrix2.times(0.25);// updateMatrices[centroidIdx] = updateMatrices[centroidIdx].plus(diffMatrix1); // updateMatrices[otherIdx] = updateMatrices[otherIdx].plus(diffMatrix2);// violatedConstraints++; // } else if (otherIdx == centroidIdx && pair.linkType == InstancePair.CANNOT_LINK) {// diffMatrix = ((WeightedMahalanobis) m_metrics[centroidIdx]).createDiffMatrix(instance1, instance2);// Matrix maxMatrix = ((WeightedMahalanobis) m_metrics[centroidIdx]).createDiffMatrix(m_maxCLPoints[centroidIdx][0],// m_maxCLPoints[centroidIdx][1]);// diffMatrix = diffMatrix.times(0.5);// maxMatrix = maxMatrix.times(0.5);// updateMatrices[centroidIdx] = updateMatrices[centroidIdx].plus(maxMatrix); // updateMatrices[centroidIdx] = updateMatrices[centroidIdx].minus(diffMatrix);// violatedConstraints++; // }// } // end if (otherIdx != -1)// }// }// } // int [][] classCounts = new int[m_NumClusters][m_TotalTrainWithLabels.numClasses()];// // NB: m_TotalTrainWithLabels does *not* include unlabeled data, counts here are undersampled!// // assuming unlabeled data came from same distribution as m_TotalTrainWithLabels, counts are still valid...// for (int instIdx=0; instIdx<m_TotalTrainWithLabels.numInstances(); instIdx++) {// Instance fullInstance = m_TotalTrainWithLabels.instance(instIdx);// classCounts[m_ClusterAssignments[instIdx]][(int)(fullInstance.classValue())]++;// }// 
for (int i = 0; i < m_NumClusters; i++){// System.out.print("Cluster " + i + "(" + counts[i] + ")\t" + classCounts[i][0]);// for (int j = 1; j < m_TotalTrainWithLabels.numClasses(); j++) {// System.out.print("\t" + classCounts[i][j]);// }// System.out.println();// }// // now update the actual weight matrices// for (int i = 0; i < updateMatrices.length; i++) {// int maxIterations = 100;// if (counts[i] == 0) {// //System.out.println("Cluster " + i + " has lost all instances; leaving weights as is");// updateMatrices[i] = Matrix.identity(numAttributes, numAttributes);// counts[i] = 1;// //System.err.println("IRREPAIRABLE COVARIANCE MATRIX, RESTARTING");// //return false;// }// updateMatrices[i] = updateMatrices[i].times(1.0/counts[i]);// double updateDet = updateMatrices[i].det();// int currIteration = 0;// Matrix newWeights = null; // // check that the update matrix is non-singular// while (Math.abs(updateDet) < m_NRConvergenceDifference && currIteration++ < maxIterations) {// Matrix regularizer = Matrix.identity(numAttributes, numAttributes);// regularizer = regularizer.times(updateMatrices[i].trace() * 0.01);// updateMatrices[i] = updateMatrices[i].plus(regularizer);// System.out.print(i + "\tsingular UPDATE matrix, DET=" + ((float)updateDet));// updateDet = updateMatrices[i].det();// System.out.println("; after regularization DET=" + ((float)updateDet));// // System.out.println("ACTUAL weights: ");// // double[][] m_weights = updateMatrices[i].getArray();// // for (int l = 0; l < m_weights.length; l++) {// // for (int j = 0; j < m_weights[l].length; j++) {// // System.out.print(((float)m_weights[l][j]) + "\t");// // }// // System.out.println();// // }// }// if (currIteration >= maxIterations) { // if the matrix is irreparable, return to identity matrix// newWeights = Matrix.identity(numAttributes, numAttributes);// System.err.println("IRREPAIRABLE UPDATE MATRIX, RESTARTING");// } else { // newWeights = updateMatrices[i].inverse();// } // ((WeightedMahalanobis) 
m_metrics[i]).setWeights(newWeights);// // project all the instances for subsequent calculation of max-points for cannot-link penalties// // TODO: we are projecting ALL instances just in case... possibly can optimize in the future// for (int instIdx=0; instIdx<m_Instances.numInstances(); instIdx++) {// ((WeightedMahalanobis) m_metrics[i]).projectInstance(m_Instances.instance(instIdx));// }// }// return true; // }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -