
📄 weuclideanlearner.java

📁 wekaUT is a collection of semi-supervised learning classifiers built on top of Weka, developed at the University of Texas at Austin.
💻 JAVA
📖 Page 1 of 2
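The classifiers in this package learn from limited supervision given as pairwise must-link / cannot-link constraints over otherwise unlabeled instances, which is what the InstancePair objects in the listing below encode. As a rough, self-contained illustration only (the ConstraintPair class and buildConstraintIndex helper are hypothetical names invented for this sketch, not part of wekaUT or Weka), this is one common way such constraints can be represented and indexed per instance:

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Illustrative only: a minimal pairwise-constraint container in the spirit of
 *  wekaUT's InstancePair. Names here are hypothetical, not the library's API. */
public class ConstraintPair {
    public static final int MUST_LINK = 0;
    public static final int CANNOT_LINK = 1;

    public final int first;    // index of the first instance
    public final int second;   // index of the second instance
    public final int linkType; // MUST_LINK or CANNOT_LINK

    public ConstraintPair(int first, int second, int linkType) {
        this.first = first;
        this.second = second;
        this.linkType = linkType;
    }

    /** Index constraints by instance id so that, while processing an instance,
     *  every constraint touching it can be looked up directly. */
    public static Map<Integer, List<ConstraintPair>> buildConstraintIndex(List<ConstraintPair> pairs) {
        Map<Integer, List<ConstraintPair>> index = new HashMap<>();
        for (ConstraintPair p : pairs) {
            index.computeIfAbsent(p.first, k -> new ArrayList<>()).add(p);
            index.computeIfAbsent(p.second, k -> new ArrayList<>()).add(p);
        }
        return index;
    }
}

Indexing constraints by instance id mirrors how the listing below looks up constraints through m_instanceConstraintHash while iterating over instances.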
//        System.out.println();
//      } else {
//        System.out.println("Constraints satisfied");
//      }
    return raphsonWeights;
  }

// OLD CODE FOR MULTIPLE:
//    /** M-step of the KMeans clustering algorithm -- updates metric
//     *  weights for the individual metrics. Invoked only whe metric is trainable
//     */
//    protected boolean updateMultipleMetricWeightsEuclidean() throws Exception {
//      if (m_regularizeWeights) {
//        System.out.println("Regularized version, calling GD version of updateMultipleMetricWeightsEuclidean!");
//        updateMultipleMetricWeightsEuclideanGD();
//      }
//      int numAttributes = m_Instances.numAttributes();
//      double[][] weights = new double[m_NumClusters][numAttributes];
//      int []counts = new int[m_NumClusters]; // count how many instances are in each cluster
//      Instance diffInstance;
//      //begin debugging variance
//      boolean debugVariance = true;
//      double[][] trueWeights = new double[m_NumClusters][numAttributes];
//      int [] majorityClasses = new int[m_NumClusters];
//      int [][] classCounts  = new int[m_NumClusters][m_TotalTrainWithLabels.numClasses()];
//      // get the majority counts
//      // NB:  m_TotalTrainWithLabels does *not* include unlabeled data, counts here are undersampled!
//      // assuming unlabeled data came from same distribution as m_TotalTrainWithLabels, counts are still valid...
//      for (int instIdx=0; instIdx<m_TotalTrainWithLabels.numInstances(); instIdx++) {
//        Instance fullInstance = m_TotalTrainWithLabels.instance(instIdx);
//        classCounts[m_ClusterAssignments[instIdx]][(int)(fullInstance.classValue())]++;
//      }
//      for (int i = 0; i < m_NumClusters; i++){
//        int majorityClass = 0;
//        System.out.print("Cluster" + i + "\t" + classCounts[i][0]);
//        for (int j = 1; j < m_TotalTrainWithLabels.numClasses(); j++) {
//          System.out.print("\t" + classCounts[i][j]);
//          if (classCounts[i][j] > classCounts[i][majorityClass]) {
//            majorityClass = j;
//          }
//        }
//        System.out.println();
//        majorityClasses[i] = majorityClass;
//      }
//      class MajorityChecker {
//        int [] m_majorityClasses  = null;
//        public MajorityChecker(int [] majClasses) { m_majorityClasses = majClasses;}
//        public  boolean belongsToMajority(Instances instances, int instIdx, int centroidIdx) {
//          // silly, must pass instance since can't access outer class fields otherwise from a local inner class
//          Instance fullInstance = instances.instance(instIdx);
//          int classValue = (int) fullInstance.classValue();
//          if (classValue == m_majorityClasses[centroidIdx]) {
//            return true;
//          } else {
//            return false;
//          }
//        }
//      }
//      MajorityChecker majChecker = new MajorityChecker(majorityClasses);
//      //end debugging variance
//      int violatedConstraints = 0;
//      for (int instIdx=0; instIdx<m_Instances.numInstances(); instIdx++) {
//        int centroidIdx = m_ClusterAssignments[instIdx];
//        diffInstance = m_metrics[centroidIdx].createDiffInstance(m_Instances.instance(instIdx), m_ClusterCentroids.instance(centroidIdx));
//        for (int attr=0; attr<numAttributes; attr++) {
//          weights[centroidIdx][attr] += diffInstance.value(attr); // Mahalanobis components
//          if (debugVariance && instIdx < m_TotalTrainWithLabels.numInstances()) {
//            if (majChecker.belongsToMajority(m_TotalTrainWithLabels, instIdx, centroidIdx)) {
//              trueWeights[centroidIdx][attr] += diffInstance.value(attr);
//            }
//          }
//        }
//        counts[centroidIdx]++;
//        Object list =  m_instanceConstraintHash.get(new Integer(instIdx));
//        if (list != null) {   // there are constraints associated with this instance
//          ArrayList constraintList = (ArrayList) list;
//          for (int i = 0; i < constraintList.size(); i++) {
//            InstancePair pair = (InstancePair) constraintList.get(i);
//            int firstIdx = pair.first;
//            int secondIdx = pair.second;
//            double cost = 0;
//            if (pair.linkType == InstancePair.MUST_LINK) {
//              cost = m_MLweight;
//            } else if (pair.linkType == InstancePair.CANNOT_LINK) {
//              cost = m_CLweight;
//            }
//            Instance instance1 = m_Instances.instance(firstIdx);
//            Instance instance2 = m_Instances.instance(secondIdx);
//            int otherIdx = (firstIdx == instIdx) ? m_ClusterAssignments[secondIdx] : m_ClusterAssignments[firstIdx];
//            // check whether the constraint is violated
//            if (otherIdx != -1) {
//              if (otherIdx != centroidIdx && pair.linkType == InstancePair.MUST_LINK) { // violated must-link
//                if (m_verbose) {
//                  System.out.println("Found violated must link between: " + firstIdx + " and " + secondIdx);
//                }
//                // we penalize weights for both clusters involved, splitting the penalty in half
//                Instance diffInstance1 = m_metrics[otherIdx].createDiffInstance(instance1, instance2);
//                Instance diffInstance2 = m_metrics[centroidIdx].createDiffInstance(instance1, instance2);
//                for (int attr=0; attr<numAttributes; attr++) {  // double-counting constraints, hence 0.5*0.5
//                  weights[otherIdx][attr] += 0.25 * cost * diffInstance1.value(attr);
//                  weights[centroidIdx][attr] += 0.25 * cost * diffInstance2.value(attr);
//                }
//                violatedConstraints++;
//              }
//              else if (otherIdx == centroidIdx && pair.linkType == InstancePair.CANNOT_LINK) { //violated cannot-link
//                if (m_verbose) {
//                  System.out.println("Found violated cannot link between: " + firstIdx + " and " + secondIdx);
//                }
//                // we penalize weights for just one cluster involved
//                diffInstance = m_metrics[centroidIdx].createDiffInstance(instance1, instance2);
//                Instance cannotDiffInstance = m_metrics[otherIdx].createDiffInstance(m_maxCLPoints[centroidIdx][0],
//                                                                                     m_maxCLPoints[centroidIdx][1]);
//                for (int attr=0; attr<numAttributes; attr++) {  // double-counting constraints, hence 0.5
//                  weights[centroidIdx][attr] += 0.5 * cost * cannotDiffInstance.value(attr);
//                  weights[centroidIdx][attr] -= 0.5 * cost * diffInstance.value(attr);
//                }
//                violatedConstraints++;
//              }
//            } // end while
//          }
//        }
//      }
//      System.out.println("   Total constraints violated: " + violatedConstraints/2 + "; per-cluster weights are:");
//      // check if NR needed
//      double [][] newWeights = new double[m_NumClusters][numAttributes];
//      double [][] currentWeights = new double[m_NumClusters][numAttributes];
//      for (int i=0; i<m_NumClusters; i++) {
//        currentWeights[i] = ((LearnableMetric) m_metrics[i]).getWeights();
//      }
//      for (int i=0; i<m_NumClusters; i++) {
//        boolean needNewtonRaphson = false;
//        for (int attr=0; attr<numAttributes; attr++) {
//          if (weights[i][attr] < 0) { // check to avoid divide by 0
//            System.out.println("WARNING!  Cluster " + i + ", attribute " + attr + " weight=" + weights[i][attr]);
//            Cluster currentCluster = (Cluster) getClusters().get(i);
//            System.out.println("\nCluster " + i + ": " + currentCluster.size() + " instances");
//            if (currentCluster == null) {
//              System.out.println("(empty)");
//            }
//            else {
//              for (int j=0; j<currentCluster.size(); j++) {
//                Instance instance = (Instance) currentCluster.get(j);
//                System.out.println("Instance: " + instance);
//              }
//            }
//            needNewtonRaphson = true;
//            break;
//          } else if (weights[i][attr] == 0) {
//            newWeights[i][attr] = currentWeights[i][attr];
//            System.out.println("WARNING!  Cluster " + i + ", attribute " + attr + " has 0 weight; keeping it as " + weights[i][attr]);
//          } else {
//            newWeights[i][attr] = m_logTermWeight * counts[i]/weights[i][attr]; // invert weights
//            if (debugVariance) {
//              trueWeights[i][attr] = counts[i]/trueWeights[i][attr];
//            }
//          }
//        }
//        // uncomment next line for debugging NR
//        // needNewtonRaphson = true;
//        // do NR if needed
//        if (needNewtonRaphson) {
//          // weights not inverted here -- done in NR routine
//          newWeights[i] = updateWeightsUsingNewtonRaphson(currentWeights[i], weights[i]);
//          System.out.println(" (NR) ");
//        }
//        // PRINT routine
//        //        System.out.print("\t" + i + "(" + counts[i] + "): ");
//        //        for (int attr=0; attr<numAttributes; attr++) {
//        //          if (debugVariance) {
//        //            System.out.print(((float)trueWeights[i][attr]) + "/~/");
//        //          }
//        //          System.out.print(((float)newWeights[i][attr]) + "\t");
//        //        }
//        //        System.out.println();
//        //        System.out.println("\t\tMean: " + m_ClusterCentroids.instance(i));
//        // end PRINT routine
//        ((LearnableMetric) m_metrics[i]).setWeights(newWeights[i]);
//      }
//      return true;
//    }

  /**
   * Gets the current settings of KL
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String [] getOptions() {
    String [] options = new String [1];
    int current = 0;
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  public void setOptions(String[] options) throws Exception {
    // TODO: add later
  }

  public Enumeration listOptions() {
    // TODO: add later
    return null;
  }
}
