⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 mahalanobislearner.java

📁 wekaUT是 university texas austin 开发的基于weka的半指导学习(semi supervised learning)的分类器
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
/* *    This program is free software; you can redistribute it and/or modify *    it under the terms of the GNU General Public License as published by *    the Free Software Foundation; either version 2 of the License, or *    (at your option) any later version. * *    This program is distributed in the hope that it will be useful, *    but WITHOUT ANY WARRANTY; without even the implied warranty of *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *    GNU General Public License for more details. * *    You should have received a copy of the GNU General Public License *    along with this program; if not, write to the Free Software *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* *    MahalanobisLearner.java *    Copyright (C) 2004 Mikhail Bilenko and Sugato Basu * */package weka.clusterers.metriclearners; import java.util.*;import weka.core.*;import weka.core.metrics.*;import weka.clusterers.MPCKMeans;import weka.clusterers.InstancePair;import Jama.Matrix; /**  * A closed-form based learner for Mahalanobis * * @author Mikhail Bilenko (mbilenko@cs.utexas.edu) and Sugato Basu * (sugato@cs.utexas.edu) * @version $Revision: 1.5 $ */public class MahalanobisLearner extends MPCKMeansMetricLearner {  /** min difference of objective function values for convergence*/  protected double m_minDet = 1e-5;  public void resetLearner() {  }   /** if clusterIdx is -1, all instances are used   * (a single metric for all clusters is used) */     public boolean trainMetric(int clusterIdx) throws Exception {    Init(clusterIdx);    Matrix updateMatrix = new Matrix(m_numAttributes, m_numAttributes);    int violatedConstraints = 0;    int numInstances = 0;    WeightedMahalanobis metric = (WeightedMahalanobis) m_metric;    Matrix maxMatrix = null;    if (m_instanceConstraintMap.size() > 0) {      if (clusterIdx == -1) { 	maxMatrix = metric.createDiffMatrix(m_kmeans.m_maxCLPoints[0][0],					    m_kmeans.m_maxCLPoints[0][1]);      } else {	maxMatrix = metric.createDiffMatrix(m_kmeans.m_maxCLPoints[clusterIdx][0],					    m_kmeans.m_maxCLPoints[clusterIdx][1]);      }       maxMatrix = maxMatrix.times(0.5);    }    for (int instIdx = 0; instIdx < m_instances.numInstances(); instIdx++) {      int assignment = m_clusterAssignments[instIdx];      // only instances assigned to this cluster are of importance      if (assignment == clusterIdx || clusterIdx == -1) {	numInstances++;	if (clusterIdx < 0) {	  m_centroid = m_kmeans.getClusterCentroids().instance(assignment); 	}	Instance instance = m_instances.instance(instIdx); 	Matrix diffMatrix = metric.createDiffMatrix(instance, m_centroid); 	updateMatrix = updateMatrix.plus(diffMatrix);	// go through violated constraints	Object list =  m_instanceConstraintMap.get(new Integer(instIdx));	if (list != null) {   // there are constraints associated with this instance	  ArrayList constraintList = (ArrayList) list;	  for (int i = 0; i < constraintList.size(); i++) {	    InstancePair pair = (InstancePair) constraintList.get(i);	    int linkType = pair.linkType;	    int firstIdx = pair.first;	    int secondIdx = pair.second;	    Instance instance1 = m_instances.instance(firstIdx);	    Instance instance2 = m_instances.instance(secondIdx);	    int otherIdx = (firstIdx == instIdx) ? m_clusterAssignments[secondIdx]	      : m_clusterAssignments[firstIdx];	    // check whether the constraint is violated	    if (otherIdx != -1 ) {  	      if (otherIdx != assignment && linkType == InstancePair.MUST_LINK) {		diffMatrix = metric.createDiffMatrix(instance1, instance2);		diffMatrix = diffMatrix.times(0.5);		updateMatrix = updateMatrix.plus(diffMatrix); 		violatedConstraints++; 	      } else if (otherIdx == assignment && linkType == InstancePair.CANNOT_LINK) {		diffMatrix = metric.createDiffMatrix(instance1, instance2);		diffMatrix = diffMatrix.times(0.5);		updateMatrix = updateMatrix.plus(maxMatrix); 		updateMatrix = updateMatrix.minus(diffMatrix);		violatedConstraints++; 	      }	    } // end while	  }	}      }    }    updateMatrix = updateMatrix.times(1.0/numInstances);    double updateDet = updateMatrix.det();    int maxIterations = 1000;    int currIteration = 1;    Matrix newWeights = null;//      System.out.println("UPDATE weights: " + " (violated constraints: " + violatedConstraints + ")");//      for (int i = 0; i < updateMatrix.getArray().length; i++) {//        for (int j = 0; j < updateMatrix.getArray()[i].length; j++) {//    	System.out.print((float)updateMatrix.getArray()[i][j] + "\t");//        }//        System.out.println();//      }    // check that the update matrix is non-singular    while (Math.abs(updateDet) < m_minDet && currIteration++ < maxIterations) {      Matrix regularizer = Matrix.identity(m_numAttributes, m_numAttributes);      regularizer = regularizer.times(updateMatrix.trace() * 0.01);      updateMatrix = updateMatrix.plus(regularizer);      System.out.print("\t" + currIteration + ". Singular update matrix, DET=" + (float)updateDet);      updateDet = updateMatrix.det();      System.out.println("; after regularization DET=" + (float)updateDet);    }        if (currIteration >= maxIterations) {      // if the matrix is irrepairable, return to identity matrix      System.out.println("\n\nCOULDN'T REGULARIZE; GOING TO IDENTITY\n\n");      newWeights = Matrix.identity(m_numAttributes, m_numAttributes);    } else {       newWeights = updateMatrix.inverse();    }     //      // check that matrix is positive semi-definite    //      currIteration = 0;    //      double det = newWeights.det();    //      Matrix weightsSquare = newWeights.chol().getL();    //      double sqDet = weightsSquare.det();    //      while ((det < 0 || Math.abs(det) < m_ObjFunConvergenceDifference    //  	    || Math.abs(sqDet) < m_ObjFunConvergenceDifference || Double.isNaN(sqDet))    //  	   && currIteration++ < maxIterations) {    //        // make sure the the matrix is symmetric positive definite    //        if (det < 0) {    //  	EigenvalueDecomposition ed = newWeights.eig();    //  	Matrix eigenVectorsMatrix = ed.getV();    //  	double[] evalues = ed.getRealEigenvalues();    //  	double [][] evaluesM = new double[evalues.length][evalues.length];    //  	for (int i = 0; i < evalues.length; i++) {    //  	  if (evalues[i] < 0) {    //  	    evalues[i] = -evalues[i];    //  	  } else {    //  	    evaluesM[i][i] = evalues[i];    //  	  }    //  	}    //  	Matrix eigenValuesMatrix = new Matrix(evaluesM);     //  	// update the weights:  A' = V' * E * V    //  	newWeights = ((eigenVectorsMatrix.transpose()).times(eigenValuesMatrix)).times(eigenVectorsMatrix);    //  	System.out.println("\tNegative determinant; projecting for subsequent regularization");    //        }    //        // the weights matrix may end up singular (if determinant was negative, or det(updateMatrix) was very large    //        sqDet = newWeights.chol().getL().det();    //        det = newWeights.det();    //        if (Math.abs(det) < m_ObjFunConvergenceDifference || Math.abs(sqDet) < m_ObjFunConvergenceDifference    //  	  || Double.isNaN(sqDet)) {    //  	Matrix regularizer = Matrix.identity(m_numAttributes, m_numAttributes);    //  	regularizer = regularizer.times(newWeights.trace() * 0.01);    //  	newWeights = newWeights.plus(regularizer);  // W = W + 0.01tr(W) * I    //  	System.out.println("\tsingular matrix, det=" + ((float)det) + ", sqDet=" + ((float)sqDet) +    //     "\tafter FIXING AND REGULARIZATION det=" + newWeights.det());    //  	det = newWeights.det();    //  	sqDet = newWeights.chol().getL().det();    //        }    //      }    //      // if the matrix is irrepairable, return to identity matrix    //      if (currIteration >= maxIterations) {     //        newWeights = Matrix.identity(m_numAttributes, m_numAttributes);    //      }        metric.setWeights(newWeights);    // project all the instances for subsequent calculation of max-points for cannot-link penalties    for (int instIdx=0; instIdx<m_instances.numInstances(); instIdx++) {      if (clusterIdx < 0 || m_clusterAssignments[instIdx] == clusterIdx) { 	metric.projectInstance(m_instances.instance(instIdx));      }    }    return true;   }  /**   * Gets the current settings of KL   *   * @return an array of strings suitable for passing to setOptions()   */  public String [] getOptions() {    String [] options = new String [1];    int current = 0;    while (current < options.length) {      options[current++] = "";    }    return options;  }  public void setOptions(String[] options) throws Exception {    // TODO: add later   }  public Enumeration listOptions() {    // TODO: add later     return null;  }}//    protected void updateMetricWeightsMahalanobisGD() throws Exception {//      WeightedMahalanobis metric = (WeightedMahalanobis) m_metric;		//      int numAttributes = m_Instances.numAttributes();//      Instance diffInstance;//      int violatedConstraints = 0;//      Matrix newWeights = metric.getWeightsMatrix().copy();//      // Do the GD//      int iteration = 0;//      boolean converged = false;//      // precompute the update matrix for maxCannotLinkInstance//      double[][] maxCLUpdate = new double[numAttributes][numAttributes];//      Instance maxCLDiffInstance = null; //      if (m_maxCLPoints != null) { //        maxCLDiffInstance = metric.createDiffInstance(m_maxCLPoints[0][0],//  						    m_maxCLPoints[0][1]);//        for (int i = 0; i < numAttributes; i++) {//  	for (int j = 0; j <=i; j++) {//  	  maxCLUpdate[i][j] =//  	    maxCLUpdate[j][i] =//  	    maxCLDiffInstance.value(i) *maxCLDiffInstance.value(j);//  	}//        }//      }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -