
📄 mahalanobislearner.java

📁 wekaUT is a collection of semi-supervised learning classifiers built on top of Weka, developed at the University of Texas at Austin; a brief sketch of the pairwise must-link / cannot-link constraints this kind of code works with appears just before the listing.
💻 JAVA
📖 Page 1 of 2
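The listing below (page 1 of 2 of the file) is the metric-update (M-step) portion of a constrained, semi-supervised K-Means-style clusterer: alongside each cluster's scatter it folds in penalty terms for violated must-link and cannot-link constraints. As context, here is a minimal sketch of a pairwise-constraint structure such code consumes; the class name PairConstraint, its fields, and isViolated are illustrative assumptions modeled on the InstancePair type referenced in the listing, not the actual wekaUT class.

/**
 * Illustrative sketch (not the wekaUT InstancePair class): a pairwise
 * semi-supervised constraint between two training instances, identified
 * by their indices. MUST_LINK pairs should end up in the same cluster,
 * CANNOT_LINK pairs in different clusters; the M-step in the listing
 * below adds a penalty term whenever such a constraint is violated.
 */
public class PairConstraint {
  public static final int MUST_LINK = 0;
  public static final int CANNOT_LINK = 1;

  public final int first;      // index of the first instance
  public final int second;     // index of the second instance
  public final int linkType;   // MUST_LINK or CANNOT_LINK

  public PairConstraint(int first, int second, int linkType) {
    this.first = first;
    this.second = second;
    this.linkType = linkType;
  }

  /** A constraint is violated when a must-link pair is split across
   *  clusters or a cannot-link pair lands in the same cluster. */
  public boolean isViolated(int[] clusterAssignments) {
    int c1 = clusterAssignments[first];
    int c2 = clusterAssignments[second];
    if (c1 == -1 || c2 == -1) return false;   // unassigned: cannot judge yet
    return (linkType == MUST_LINK) ? (c1 != c2) : (c1 == c2);
  }
}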
//      // store the constant part of the gradient:
//      double[][] gradientConst = new double[numAttributes][numAttributes];
//      for (int instIdx = 0; instIdx < m_Instances.numInstances(); instIdx++) {
//        // the (x-m)(x-m)' part
//        int centroidIdx = m_ClusterAssignments[instIdx];
//        Instance centroid = m_ClusterCentroids.instance(centroidIdx);
//        diffInstance = metric.createDiffInstance(m_Instances.instance(instIdx),
//                                                 centroid);
//        for (int i = 0; i < numAttributes; i++) {
//          for (int j = 0; j <= i; j++) {
//            gradientConst[i][j] =
//              gradientConst[j][i] = diffInstance.value(i) * diffInstance.value(j);
//          }
//        }
//        // the violated constraints
//        Object list = m_instanceConstraintHash.get(new Integer(instIdx));
//        if (list != null) {   // there are constraints associated with this instance
//          ArrayList constraintList = (ArrayList) list;
//          for (int constrIdx = 0; constrIdx < constraintList.size(); constrIdx++) {
//            InstancePair pair = (InstancePair) constraintList.get(constrIdx);
//            int firstIdx = pair.first;
//            int secondIdx = pair.second;
//            double cost = 0;
//            if (pair.linkType == InstancePair.MUST_LINK) {
//              cost = m_MLweight;
//            } else if (pair.linkType == InstancePair.CANNOT_LINK) {
//              cost = m_CLweight;
//            }
//            Instance instance1 = m_Instances.instance(firstIdx);
//            Instance instance2 = m_Instances.instance(secondIdx);
//            int otherIdx = (firstIdx == instIdx) ? m_ClusterAssignments[secondIdx]
//              : m_ClusterAssignments[firstIdx];
//            if (otherIdx == -1) {
//              throw new Exception("One of the instances is unassigned in "
//                                  + "updateMetricWeightsMahalanobisGD");
//            }
//            // check whether the constraint is violated
//            if (otherIdx != centroidIdx &&
//                pair.linkType == InstancePair.MUST_LINK) {
//              diffInstance = metric.createDiffInstance(instance1, instance2);
//              for (int i = 0; i < numAttributes; i++) {
//                for (int j = 0; j <= i; j++) {
//                  gradientConst[i][j] =
//                    gradientConst[j][i] =
//                    0.5 * cost * diffInstance.value(i) * diffInstance.value(j);
//                }
//              }
//              violatedConstraints++;
//            } else if (otherIdx == centroidIdx &&
//                       pair.linkType == InstancePair.CANNOT_LINK) {
//              diffInstance = metric.createDiffInstance(instance1, instance2);
//              for (int i = 0; i < numAttributes; i++) {
//                for (int j = 0; j <= i; j++) {
//                  gradientConst[i][j] =
//                    gradientConst[j][i] =
//                    0.5 * cost *
//                    (maxCLUpdate[i][j] -
//                     diffInstance.value(i) * diffInstance.value(j));
//                }
//              }
//              violatedConstraints++;
//            }
//          }
//        }
//      }
//      Matrix constUpdate = new Matrix(gradientConst);
//      while (iteration < m_maxGDIterations && !converged) {
//        // calculate the gradient
//        Matrix update = constUpdate.copy();
//        // factor in the A^-1
//        Matrix Ai = newWeights.inverse();
//        Ai.timesEquals(m_logTermWeight);
//        update.minusEquals(Ai);
//        // regularization  (-1/sum(a_ij)^2)
//        double regularizer = 0;
//        for (int i = 0; i < numAttributes; i++) {
//          for (int j = 0; j <= i; j++) {
//            regularizer += 2.0/(newWeights.get(i, j) * newWeights.get(i, j));
//          }
//        }
//        // correct for double-counted diagonal
//        for (int i = 0; i < numAttributes; i++) {
//          regularizer -= 1.0/newWeights.get(i, i);
//        }
//        regularizer *= m_currregularizerTermWeight;
//        for (int i = 0; i < numAttributes; i++) {
//          for (int j = 0; j < numAttributes; j++) {
//            update.set(i, j, update.get(i, j) - regularizer);
//          }
//        }
//        // update
//        update.timesEquals(m_currEta);
//        newWeights.minusEquals(update);
//        // anneal if necessary and check for convergence
//        m_currEta = m_currEta * m_etaDecayRate;
//        // check for convergence
//        double norm = update.norm1();
//        System.out.println(iteration + ":  norm=" + norm);
//        if (norm < 0.0001) {
//          converged = true;
//        }
//        iteration++;
//      }
//      // We're done, set the weights to newWeights
//    }

// MULTIPLE:
//    /** M-step of the KMeans clustering algorithm -- updates metric
//     *  weights. Invoked only when metric is an instance of Mahalanobis
//     *  @return value true if everything was alright; false if there was
//     *  a miserable failure and clustering needs to be restarted */
//    protected boolean updateMultipleMetricWeightsMahalanobis() throws Exception {
//      if (m_regularizeWeights) {
//        System.out.println("Regularized version, calling GD version of updateMultipleMetricWeightsMahalanobisGD!");
//        updateMultipleMetricWeightsMahalanobisGD();
//      }
//      int numAttributes = m_Instances.numAttributes();
//      if (m_Instances.classIndex() >= 0) {
//        numAttributes--;
//      }
//      Matrix[] updateMatrices = new Matrix[m_metrics.length];
//      for (int i = 0; i < updateMatrices.length; i++) {
//        updateMatrices[i] = new Matrix(numAttributes, numAttributes);
//      }
//      int violatedConstraints = 0;
//      int[] counts = new int[updateMatrices.length];
//      for (int instIdx = 0; instIdx < m_Instances.numInstances(); instIdx++) {
//        int centroidIdx = m_ClusterAssignments[instIdx];
//        Matrix diffMatrix = ((WeightedMahalanobis) m_metrics[centroidIdx]).createDiffMatrix(m_Instances.instance(instIdx),
//                                                                                            m_ClusterCentroids.instance(centroidIdx));
//        updateMatrices[centroidIdx] = updateMatrices[centroidIdx].plus(diffMatrix);
//        counts[centroidIdx]++;
//        // go through violated constraints
//        Object list = m_instanceConstraintHash.get(new Integer(instIdx));
//        if (list != null) {   // there are constraints associated with this instance
//          ArrayList constraintList = (ArrayList) list;
//          for (int i = 0; i < constraintList.size(); i++) {
//            InstancePair pair = (InstancePair) constraintList.get(i);
//            int firstIdx = pair.first;
//            int secondIdx = pair.second;
//            Instance instance1 = m_Instances.instance(firstIdx);
//            Instance instance2 = m_Instances.instance(secondIdx);
//            int otherIdx = (firstIdx == instIdx) ? m_ClusterAssignments[secondIdx] : m_ClusterAssignments[firstIdx];
//            // check whether the constraint is violated
//            if (otherIdx != -1) {
//              if (otherIdx != centroidIdx && pair.linkType == InstancePair.MUST_LINK) {
//                Matrix diffMatrix1 = ((WeightedMahalanobis) m_metrics[centroidIdx]).createDiffMatrix(instance1, instance2);
//                diffMatrix1 = diffMatrix1.times(0.25);
//                Matrix diffMatrix2 = ((WeightedMahalanobis) m_metrics[otherIdx]).createDiffMatrix(instance1, instance2);
//                diffMatrix2 = diffMatrix2.times(0.25);
//                updateMatrices[centroidIdx] = updateMatrices[centroidIdx].plus(diffMatrix1);
//                updateMatrices[otherIdx] = updateMatrices[otherIdx].plus(diffMatrix2);
//                violatedConstraints++;
//              } else if (otherIdx == centroidIdx && pair.linkType == InstancePair.CANNOT_LINK) {
//                diffMatrix = ((WeightedMahalanobis) m_metrics[centroidIdx]).createDiffMatrix(instance1, instance2);
//                Matrix maxMatrix = ((WeightedMahalanobis) m_metrics[centroidIdx]).createDiffMatrix(m_maxCLPoints[centroidIdx][0],
//                                                                                                   m_maxCLPoints[centroidIdx][1]);
//                diffMatrix = diffMatrix.times(0.5);
//                maxMatrix = maxMatrix.times(0.5);
//                updateMatrices[centroidIdx] = updateMatrices[centroidIdx].plus(maxMatrix);
//                updateMatrices[centroidIdx] = updateMatrices[centroidIdx].minus(diffMatrix);
//                violatedConstraints++;
//              }
//            } // end if (otherIdx != -1)
//          }
//        }
//      }
//      int[][] classCounts = new int[m_NumClusters][m_TotalTrainWithLabels.numClasses()];
//      // NB:  m_TotalTrainWithLabels does *not* include unlabeled data, counts here are undersampled!
//      // assuming unlabeled data came from same distribution as m_TotalTrainWithLabels, counts are still valid...
//      for (int instIdx = 0; instIdx < m_TotalTrainWithLabels.numInstances(); instIdx++) {
//        Instance fullInstance = m_TotalTrainWithLabels.instance(instIdx);
//        classCounts[m_ClusterAssignments[instIdx]][(int) (fullInstance.classValue())]++;
//      }
//      for (int i = 0; i < m_NumClusters; i++) {
//        System.out.print("Cluster " + i + "(" + counts[i] + ")\t" + classCounts[i][0]);
//        for (int j = 1; j < m_TotalTrainWithLabels.numClasses(); j++) {
//          System.out.print("\t" + classCounts[i][j]);
//        }
//        System.out.println();
//      }
//      // now update the actual weight matrices
//      for (int i = 0; i < updateMatrices.length; i++) {
//        int maxIterations = 100;
//        if (counts[i] == 0) {
//          //System.out.println("Cluster " + i + " has lost all instances; leaving weights as is");
//          updateMatrices[i] = Matrix.identity(numAttributes, numAttributes);
//          counts[i] = 1;
//          //System.err.println("IRREPAIRABLE COVARIANCE MATRIX, RESTARTING");
//          //return false;
//        }
//        updateMatrices[i] = updateMatrices[i].times(1.0/counts[i]);
//        double updateDet = updateMatrices[i].det();
//        int currIteration = 0;
//        Matrix newWeights = null;
//        // check that the update matrix is non-singular
//        while (Math.abs(updateDet) < m_NRConvergenceDifference && currIteration++ < maxIterations) {
//          Matrix regularizer = Matrix.identity(numAttributes, numAttributes);
//          regularizer = regularizer.times(updateMatrices[i].trace() * 0.01);
//          updateMatrices[i] = updateMatrices[i].plus(regularizer);
//          System.out.print(i + "\tsingular UPDATE matrix, DET=" + ((float) updateDet));
//          updateDet = updateMatrices[i].det();
//          System.out.println("; after regularization DET=" + ((float) updateDet));
//          //  System.out.println("ACTUAL weights: ");
//          //  double[][] m_weights = updateMatrices[i].getArray();
//          //  for (int l = 0; l < m_weights.length; l++) {
//          //    for (int j = 0; j < m_weights[l].length; j++) {
//          //      System.out.print(((float) m_weights[l][j]) + "\t");
//          //    }
//          //    System.out.println();
//          //  }
//        }
//        if (currIteration >= maxIterations) {   // if the matrix is irreparable, revert to the identity matrix
//          newWeights = Matrix.identity(numAttributes, numAttributes);
//          System.err.println("IRREPAIRABLE UPDATE MATRIX, RESTARTING");
//        } else {
//          newWeights = updateMatrices[i].inverse();
//        }
//        ((WeightedMahalanobis) m_metrics[i]).setWeights(newWeights);
//        // project all the instances for subsequent calculation of max-points for cannot-link penalties
//        // TODO:  we are projecting ALL instances just in case...  possibly can optimize in the future
//        for (int instIdx = 0; instIdx < m_Instances.numInstances(); instIdx++) {
//          ((WeightedMahalanobis) m_metrics[i]).projectInstance(m_Instances.instance(instIdx));
//        }
//      }
//      return true;
//    }
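To summarize what the (commented-out) M-step above computes: each cluster's Mahalanobis weight matrix is obtained by averaging the outer products (x - mu)(x - mu)' of its members, adding penalty terms for violated constraints, ridging the result toward the identity when it is near-singular, and inverting it. Below is a minimal, self-contained sketch of that core update for a single cluster, without the constraint terms, written against the JAMA Matrix class (whose plus / times / trace / det / inverse / identity calls match those in the listing); the class and method names and the 1e-6 singularity threshold are assumptions for illustration, not wekaUT code.

import Jama.Matrix;

/**
 * Illustrative sketch (not wekaUT code): the core of the Mahalanobis
 * M-step shown in the listing, without the constraint-penalty terms.
 * For one cluster it averages the outer products (x - mu)(x - mu)^T,
 * regularizes the scatter matrix toward the identity if it is nearly
 * singular, and inverts it to obtain the metric weight matrix A.
 */
public class MahalanobisMStepSketch {

  /** Assumed singularity threshold, standing in for m_NRConvergenceDifference. */
  static final double SINGULARITY_EPS = 1e-6;

  static Matrix updateWeights(double[][] points, double[] centroid) {
    int d = centroid.length;
    Matrix scatter = new Matrix(d, d);

    // Accumulate (x - mu)(x - mu)^T over the cluster's members.
    for (double[] x : points) {
      double[][] diff = new double[d][1];
      for (int i = 0; i < d; i++) diff[i][0] = x[i] - centroid[i];
      Matrix v = new Matrix(diff);
      scatter = scatter.plus(v.times(v.transpose()));
    }
    scatter = scatter.times(1.0 / points.length);

    // Regularize toward the identity while the matrix is near-singular,
    // mirroring the 0.01 * trace ridge term used in the listing.
    int iterations = 0;
    while (Math.abs(scatter.det()) < SINGULARITY_EPS && iterations++ < 100) {
      Matrix ridge = Matrix.identity(d, d).times(scatter.trace() * 0.01);
      scatter = scatter.plus(ridge);
    }

    // The Mahalanobis weight matrix is the inverse of the scatter matrix;
    // fall back to the identity if regularization never succeeded.
    return (Math.abs(scatter.det()) < SINGULARITY_EPS)
        ? Matrix.identity(d, d)
        : scatter.inverse();
  }

  public static void main(String[] args) {
    double[][] cluster = { {1.0, 2.0}, {1.2, 1.8}, {0.8, 2.3} };
    double[] mu = {1.0, 2.0333333};                 // mean of the three points
    Matrix A = updateWeights(cluster, mu);
    A.print(10, 4);   // JAMA's formatted print: column width 10, 4 decimals
  }
}

In the full wekaUT code the per-cluster scatter additionally receives weighted difference matrices for violated must-link and cannot-link pairs (the 0.25- and 0.5-scaled terms in the listing) before the inversion step.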
