📄 weightedmahalanobis.java
字号:
} } if (maxDistance == -Double.MIN_VALUE) { // System.out.println("ACTUAL weights det=" + ((new Matrix(m_weightsMatrix)).det()));// for (int i = 0; i < m_weightsMatrix.length; i++) {// for (int j = 0; j < m_weightsMatrix[i].length; j++) {// System.out.print(((float)m_weightsMatrix[i][j]) + "\t");// }// System.out.println();// }// System.out.println("\n\nsq weights");// for (int i = 0; i < m_weightsMatrixSquare.length; i++) {// for (int j = 0; j < m_weightsMatrixSquare[i].length; j++) {// System.out.print(((float)m_weightsMatrixSquare[i][j]) + "\t");// }// System.out.println();// } if (m_weightsMatrixSquare != null) { m_weightsMatrixSquare = null; System.out.println("recursing"); return getMaxPoints(constraintMap, instances); } else { iterator = constraintMap.entrySet().iterator(); while (iterator.hasNext()) { Map.Entry entry = (Map.Entry) iterator.next(); int type = ((Integer) entry.getValue()).intValue(); if (type == InstancePair.CANNOT_LINK) { maxConstraint = (InstancePair) entry.getKey(); break; } } } } int firstIdx = maxConstraint.first; int secondIdx = maxConstraint.second; Instance instance1 = instances.instance(firstIdx); Instance instance2 = instances.instance(secondIdx); for (int i = 0; i < m_weightsMatrix.length; i++) { m_maxPoints[0][i] = instance1.value(i); m_maxPoints[1][i] = instance2.value(i); } return m_maxPoints; } /** * Returns a non-weighted distance value between two instances. * @param instance1 First instance. * @param instance2 Second instance. * @exception Exception if distance could not be estimated. 
*/ public double distanceNonWeighted(Instance instance1, Instance instance2) throws Exception { double value1, value2, diff, distance = 0; double [] values1 = instance1.toDoubleArray(); double [] values2 = instance2.toDoubleArray(); // Go through all attributes for (int i = 0; i < values1.length; i++) { if (i != m_classIndex) { diff = values1[i] - values2[i]; distance += diff * diff; } } distance = Math.sqrt(distance); return distance; }; /** * Returns a similarity estimate between two instances. Similarity is obtained by * inverting the distance value using one of three methods: * CONVERSION_LAPLACIAN, CONVERSION_EXPONENTIAL, CONVERSION_UNIT. * @param instance1 First instance. * @param instance2 Second instance. * @exception Exception if similarity could not be estimated. */ public double similarity(Instance instance1, Instance instance2) throws Exception { switch (m_conversionType) { case CONVERSION_LAPLACIAN: return 1 / (1 + distance(instance1, instance2)); case CONVERSION_UNIT: return 2 * (1 - distance(instance1, instance2)); case CONVERSION_EXPONENTIAL: return Math.exp(-distance(instance1, instance2)); default: throw new Exception ("Unknown distance to similarity conversion method"); } } /** * Returns a similarity estimate between two instances without using the weights. * @param instance1 First instance. * @param instance2 Second instance. * @exception Exception if similarity could not be estimated. 
*/
  public double similarityNonWeighted(Instance instance1, Instance instance2) throws Exception {
    switch (m_conversionType) {
    case CONVERSION_LAPLACIAN:
      return 1 / (1 + distanceNonWeighted(instance1, instance2));
    case CONVERSION_UNIT:
      return 2 * (1 - distanceNonWeighted(instance1, instance2));
    case CONVERSION_EXPONENTIAL:
      return Math.exp(-distanceNonWeighted(instance1, instance2));
    default:
      throw new Exception ("Unknown distance to similarity conversion method");
    }
  }

  /**
   * Get the values of the partial derivatives for the metric components
   * for a particular instance pair. Each component is
   * (val2 - val1)^2 / (2 * distance); all components are zero when the
   * distance between the two instances is zero.
   * @param instance1 the first instance
   * @param instance2 the second instance
   * @return an array of m_numAttributes gradient values
   * @exception Exception if the distance between the instances could not be computed
   */
  public double[] getGradients(Instance instance1, Instance instance2) throws Exception {
    double[] gradients = new double[m_numAttributes];
    double distance = distance(instance1, instance2);

    // gradients are zero for 0-distance instances
    if (distance == 0) {
      return gradients;
    }

    // take care of SparseInstances by enumerating over the values of the first instance
    // NOTE(review): the loop bound is m_numAttributes while the accessors are the
    // sparse-position ones (valueSparse/attributeSparse), and the result is stored at
    // the sparse position i rather than at the attribute's own index. This presumably
    // only lines up when instance1 stores a value for every attribute -- confirm
    // against callers before relying on it with sparse data.
    double scale = 1.0 / (2 * distance);
    for (int i = 0; i < m_numAttributes; i++) {
      double val1 = instance1.valueSparse(i);
      Attribute attr = instance1.attributeSparse(i);
      double val2 = instance2.value(attr);
      gradients[i] = scale * (val2 - val1) * (val2 - val1);
    }
    return gradients;
  }

  /**
   * Train the metric. This implementation does no learning; it only reports
   * on standard error that the metric has to be learned externally.
   * @param data the training data (unused here)
   */
  public void learnMetric (Instances data) throws Exception {
    System.err.println("WeightedMahalanobis can only be learned by the outside algorithm");
  }

  /**
   * Set the weights.
   *
   * The supplied matrix is first made positive semi-definite: negative
   * eigenvalues are replaced by zero and the matrix is rebuilt as V * E * V'.
   * If the result is (near-)singular, it is repeatedly regularized by adding
   * 0.01 * trace(W) * I until |det| reaches 1.0e-8 or 1000 passes have been made.
   * m_weightsMatrixSquare is then set to the Cholesky factor of the final
   * matrix, or, if regularization did not succeed, to Q * sqrt(Lambda+)
   * built from the eigen-decomposition.
   *
   * Cached projections and max points are reset, and the normalizer and
   * regularizer are recomputed.
   *
   * @param weights the new weight matrix
   */
  public void setWeights(Matrix weights) {
    // check that the matrix is positive semi-definite
    boolean isPSD = true;
    EigenvalueDecomposition ed = weights.eig();
    Matrix eigenVectorsMatrix = ed.getV();
    double[][] eigenVectors = eigenVectorsMatrix.getArray();
    double[] evalues = ed.getRealEigenvalues();
    double[][] evaluesM = new double[evalues.length][evalues.length];
    for (int i = 0; i < evalues.length; i++) {
      if (evalues[i] < 0) {
        evaluesM[i][i] = 0;
        isPSD = false;
      } else {
        evaluesM[i][i] = evalues[i];
      }
    }

    if (!isPSD) {
      // update the weights: A = V * E * V'
      Matrix eigenValuesMatrix = new Matrix(evaluesM);
      weights = (eigenVectorsMatrix.times(eigenValuesMatrix)).times(eigenVectorsMatrix.transpose());
    }
    m_weightsMatrix = weights.getArray();
    int dim = m_weightsMatrix.length;

    // if the matrix is singular, need to be careful with m_weightsMatrixSquare and not use Cholesky
    double det = weights.det();
    if (!isPSD || det < Math.pow(10, -2 * dim)) {
      System.out.println("Singular weight matrix! det=" + det);
      int maxIterations = 1000;
      int currIteration = 0;
      while (Math.abs(det) < 1.0e-8 && currIteration++ < maxIterations) {
        Matrix regularizer = Matrix.identity(dim, dim);
        regularizer = regularizer.times(weights.trace() * 0.01);
        weights = weights.plus(regularizer); // W = W + 0.01tr(W) * I
        double newDet = weights.det();
        System.out.println("\t" + currIteration + ". det=" + ((float)det) + "\tafter FIXING AND REGULARIZATION det=" + newDet);
        det = newDet;
      }

      // Decide on the determinant itself rather than on the iteration counter:
      // a matrix that became non-singular on the very last allowed pass leaves
      // the counter equal to maxIterations and must not be treated as irreparable.
      if (Math.abs(det) < 1.0e-8) {
        // if the matrix is irrepairable, use alternate factorization
        System.out.println("IRREPAIRABLE MATRIX, using an alternate factorization:");
        // sqWeights = Q * (Lambda+)^.5, i.e. column j of the eigenvector matrix
        // scaled by the square root of the j-th (non-negative) eigenvalue.
        // m_weightsMatrix is not reassigned in this branch.
        for (int i = 0; i < evaluesM.length; i++) {
          evaluesM[i][i] = Math.sqrt(evaluesM[i][i]);
        }
        m_weightsMatrixSquare = new double[dim][dim];
        for (int i = 0; i < m_weightsMatrixSquare.length; i++) {
          for (int j = 0; j < m_weightsMatrixSquare[i].length; j++) {
            m_weightsMatrixSquare[i][j] = eigenVectors[i][j] * evaluesM[j][j];
          }
        }
      } else {
        // the matrix is positive definite, can do Cholesky
        m_weightsMatrix = weights.getArray();
        m_weightsMatrixSquare = weights.chol().getL().getArray();
      }
    } else {
      m_weightsMatrix = weights.getArray();
      m_weightsMatrixSquare = weights.chol().getL().getArray();
    }

    // anything derived from the previous weights is now stale
    m_projectedInstanceHash = new HashMap();
    m_maxPoints = null;
    m_maxProjPoints = null;
    recomputeNormalizer();
    recomputeRegularizer();
  }

  /**
   * Override the parent class methods: set the weights from a vector by
   * placing it on the diagonal of an otherwise zero square matrix and
   * delegating to setWeights(Matrix).
   * @param weights the diagonal entries of the weight matrix
   */
  public void setWeights(double[] weights) {
    int numAttributes = weights.length;
    m_weightsMatrix = new double[numAttributes][numAttributes];
    for (int i = 0; i < numAttributes; i++) {
      m_weightsMatrix[i][i] = weights[i];
    }
    setWeights(new Matrix(m_weightsMatrix));
  }

  /**
   * Override the parent class methods: get the weights as a vector.
   * @return the diagonal of the current weight matrix
   */
  public double[] getWeights() {
    double[] weights = new double[m_weightsMatrix.length];
    for (int i = 0; i < weights.length; i++) {
      weights[i] = m_weightsMatrix[i][i];
    }
    return weights;
  }

  /**
   * Override the parent class methods: get the full weight matrix.
   * @return a Matrix constructed from the current weight array
   */
  public Matrix getWeightsMatrix() {
    return new Matrix(m_weightsMatrix);
  }

  /** Computes the regularizer: currently the 1-norm of the weight matrix. */
  public void recomputeRegularizer() {
    Matrix weightMatrix = new Matrix(m_weightsMatrix);
    // TODONOW: implement regularization
    m_regularizerVal = weightMatrix.norm1();
  }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -