⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 weuclideanlearner.java

📁 WekaUT 是 University of Texas at Austin 开发的、基于 Weka 的半监督学习(semi-supervised learning)分类器
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
/* *    This program is free software; you can redistribute it and/or modify *    it under the terms of the GNU General Public License as published by *    the Free Software Foundation; either version 2 of the License, or *    (at your option) any later version. * *    This program is distributed in the hope that it will be useful, *    but WITHOUT ANY WARRANTY; without even the implied warranty of *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *    GNU General Public License for more details. * *    You should have received a copy of the GNU General Public License *    along with this program; if not, write to the Free Software *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* *    WEuclideanLearner.java *    Copyright (C) 2004 Mikhail Bilenko and Sugato Basu * */package weka.clusterers.metriclearners; import java.util.*;import weka.core.*;import weka.core.metrics.*;import weka.clusterers.MPCKMeans;import weka.clusterers.InstancePair;/**  * A closed-form learner for WeightedEuclidean * * @author Mikhail Bilenko (mbilenko@cs.utexas.edu) and Sugato Basu * (sugato@cs.utexas.edu * @version $Revision: 1.5 $ */public class WEuclideanLearner extends  MPCKMeansMetricLearner {    public void resetLearner() {  }     /** if clusterIdx is -1, all instances are used   * (a single metric for all clusters is used) */     public boolean trainMetric(int clusterIdx) throws Exception {    Init(clusterIdx);     double[] weights = new double[m_numAttributes];    int violatedConstraints = 0;    int numInstances = 0;     for (int instIdx = 0; instIdx < m_instances.numInstances(); instIdx++) {      int assignment = m_clusterAssignments[instIdx];      // only instances assigned to this cluster are of importance      if (assignment == clusterIdx || clusterIdx == -1) {	numInstances++;	if (clusterIdx < 0) {	  m_centroid = m_kmeans.getClusterCentroids().instance(assignment); 	} 		// accumulate variance	Instance instance = m_instances.instance(instIdx);	
Instance diffInstance = m_metric.createDiffInstance(instance, m_centroid); 	for (int attr = 0; attr < m_numAttributes; attr++) {	  weights[attr] += diffInstance.value(attr); 	}	// check all constraints for this instance	Object list =  m_instanceConstraintMap.get(new Integer(instIdx));	if (list != null) {   // there are constraints associated with this instance	  ArrayList constraintList = (ArrayList) list;	  for (int i = 0; i < constraintList.size(); i++) {	    InstancePair pair = (InstancePair) constraintList.get(i);	    int linkType = pair.linkType;	    int firstIdx = pair.first;	    int secondIdx = pair.second;	    Instance instance1 = m_instances.instance(firstIdx);	    Instance instance2 = m_instances.instance(secondIdx);	    int otherIdx = (firstIdx == instIdx) ? m_clusterAssignments[secondIdx]	      : m_clusterAssignments[firstIdx];	    if (otherIdx != -1) {  // check whether the constraint is violated	      if (otherIdx != assignment && linkType == InstancePair.MUST_LINK) {		diffInstance = m_metric.createDiffInstance(instance1, instance2);		for (int attr = 0; attr < m_numAttributes; attr++) {  		  weights[attr] += 0.5 * m_MLweight * diffInstance.value(attr);		}	      }	      else if (otherIdx == assignment && linkType == InstancePair.CANNOT_LINK){ 		diffInstance = m_metric.createDiffInstance(instance1, instance2);		for (int attr = 0; attr < m_numAttributes; attr++) {		  // this constraint will be counted twice, hence 0.5		  weights[attr] += 0.5 * m_CLweight * m_maxCLDiffInstance.value(attr);		  weights[attr] -= 0.5 * m_CLweight * diffInstance.value(attr); 		}	      }	    } 	  }	}      }    }//      System.out.println("Updating cluster " + clusterIdx//  		       + " containing " + numInstances);     // check the weights    double [] newWeights = new double[m_numAttributes];    double [] currentWeights = m_metric.getWeights();    boolean needNewtonRaphson = false;    for (int attr = 0; attr < m_numAttributes; attr++) {      if (weights[attr] <= 0) { // check 
to avoid divide by 0 - TODO!	System.out.println("Negative weight " + weights[attr]			   + " for clusterIdx=" + clusterIdx			   + "; using prev value=" + currentWeights[attr]); 	newWeights[attr] = currentWeights[attr];	//  	needNewtonRaphson = true;	//	break;      } else {	if (m_regularize) { // solution of quadratic equation - TODO!	  int n = m_instances.numInstances();	  double ratio =  (m_logTermWeight * n) / (2 * weights[attr]);	  newWeights[attr] = ratio + Math.sqrt(ratio*ratio +					       (m_regularizerTermWeight*n)					       /weights[attr]);	} else {	  newWeights[attr] = m_logTermWeight * numInstances  / weights[attr];	}      }    }        // do NR if needed    if (needNewtonRaphson) {      System.out.println("GOING TO NEWTON-RAPHSON!!!\n");       newWeights = updateWeightsUsingNewtonRaphson(currentWeights, weights);    }     // PRINT routine    //      System.out.println("Total constraints violated: " + violatedConstraints/2 + "; weights are:");     //      for (int attr=0; attr<numAttributes; attr++) {    //        System.out.print(newWeights[attr] + "\t");    //      }    //      System.out.println();    // end PRINT routine    m_metric.setWeights(newWeights);    return true;  }    /** calculates weights using Newton Raphson, to satisfy the      positivity constraint of each attribute weight, returns learned      attribute weights. Note: currentAttrWeights is the inverted version      of the current m_metric weights.  
*/
  protected double [] updateWeightsUsingNewtonRaphson
    (double [] currentAttrWeights, double [] invUnconstrainedAttrWeights)
    throws Exception {
    // The iterative Newton-Raphson update below is disabled, so this method
    // currently hands back the very array it was given (same reference, not
    // a copy) and does not read invUnconstrainedAttrWeights at all.
    double [] iterAttrWeights = currentAttrWeights;

    //      System.out.println("Updating Weights Using NewtonRaphson");
    //      do {
    //        // sets new attribute weights using NR with line search for alpha
    //        iterAttrWeights = nrWithLineSearchForAlpha(iterAttrWeights,
    //  						 invUnconstrainedAttrWeights);
    //        // set current attribute weight to m_metric, recalculate obj. fn.
    //        m_OldObjective = m_Objective;
    //        ((LearnableM_metric) m_m_metric).setWeights(iterAttrWeights);
    //        calculateObjectiveFunction();
    //      } while (!convergenceCheck(m_OldObjective, m_Objective, false)); // objective function not guaranteed to monotonically decrease across NR iterations, so don't do convergence check

    return iterAttrWeights;
  }

  /** Does one NR step, calculates the alpha (using line search) that      does not violate positivity constraint of each attribute weight,      returns new values of attribute weights */  protected double [] nrWithLineSearchForAlpha(double [] currAttrWeights,					       double [] invUnconstrainedAttrWeights)    throws Exception {    int numAttributes = currAttrWeights.length;    double [] raphsonWeights = new double[numAttributes];    double top = 1, bottom = 0, alpha = 1;    boolean satisfiesConstraints = true;        //      // initial check for alpha = top//      System.out.println("Evaluating at alpha=1");//      for (int attr = 0; attr < numAttributes; attr++) {//        raphsonWeights[attr] = currAttrWeights[attr] * (1 - alpha * (currAttrWeights[attr] * invUnconstrainedAttrWeights[attr] - 1));//        if (raphsonWeights[attr] < 0) {//  	satisfiesConstraints = false;//  	System.out.println("Negative raphsonWeight for attr: " + attr + ", exiting loop");//  	break;//        }//        //        
System.out.println("Curr weights: " + currAttrWeights[attr] + ", alpha: " + alpha + ", m_Objective: " + m_Objective);//        //        System.out.println("Raphson weights[" + attr +"] = " + raphsonWeights[attr]);//      }//      if (!satisfiesConstraints) {//        // line search for alpha between bottom and top//        // satisfiesConstraints is false at top, true at bottom//        // we want max. alpha in [0,1] for which satisfiesConstraints is true//        System.out.println("Starting line search for alpha");//        while ((top-bottom) > m_NRConvergenceDifference && bottom <= top) {//  	alpha = (bottom + top)/2;//  	satisfiesConstraints = true;//  	for (int attr = 0; attr < numAttributes; attr++) {//  	  raphsonWeights[attr] = currAttrWeights[attr] * (1 - alpha * (currAttrWeights[attr] * invUnconstrainedAttrWeights[attr] - 1));//  	  if (raphsonWeights[attr] < 0) {//  	    satisfiesConstraints = false;//  	    System.out.println("Negative raphsonWeight for attr: " + attr + ", exiting loop");//  	    break;//  	  }//  	  //  	  System.out.println("In line search ... curr weights: " + currAttrWeights[attr] + ", alpha: " + alpha + ", m_Objective: " + m_Objective);//  	  //  	  System.out.println("In line search ... raphson weights[" + attr +"] = " + raphsonWeights[attr]);//  	}//  	if (!satisfiesConstraints) {//  	  top = alpha;//  	} else {//  	  bottom = alpha;//  	}//  	System.out.println("Top: " + top + ", Bottom: " + bottom);//        }//        alpha = bottom;    //        System.out.println("Final alpha: " + alpha + ", final objective: " + m_Objective);//        System.out.print("Final weights: ");//        for (int attr = 0; attr < numAttributes; attr++) {//  	raphsonWeights[attr] = currAttrWeights[attr] * (1 - alpha * (currAttrWeights[attr] * invUnconstrainedAttrWeights[attr] - 1));//  	System.out.print(raphsonWeights[attr] + "\t");//        }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -