📄 lwl.java
字号:
return options; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String KNNTipText() { return "How many neighbours are used to determine the width of the " + "weighting function (<= 0 means all neighbours)."; } /** * Sets the number of neighbours used for kernel bandwidth setting. * The bandwidth is taken as the distance to the kth neighbour. * * @param knn the number of neighbours included inside the kernel * bandwidth, or 0 to specify using all neighbors. */ public void setKNN(int knn) { m_kNN = knn; if (knn <= 0) { m_kNN = 0; m_UseAllK = true; } else { m_UseAllK = false; } } /** * Gets the number of neighbours used for kernel bandwidth setting. * The bandwidth is taken as the distance to the kth neighbour. * * @return the number of neighbours included inside the kernel * bandwidth, or 0 for all neighbours */ public int getKNN() { return m_kNN; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String weightingKernelTipText() { return "Determines weighting function. [0 = Linear, 1 = Epnechnikov,"+ "2 = Tricube, 3 = Inverse, 4 = Gaussian and 5 = Constant. "+ "(default 0 = Linear)]."; } /** * Sets the kernel weighting method to use. Must be one of LINEAR, * EPANECHNIKOV, TRICUBE, INVERSE, GAUSS or CONSTANT, other values * are ignored. * * @param kernel the new kernel method to use. Must be one of LINEAR, * EPANECHNIKOV, TRICUBE, INVERSE, GAUSS or CONSTANT. */ public void setWeightingKernel(int kernel) { if ((kernel != LINEAR) && (kernel != EPANECHNIKOV) && (kernel != TRICUBE) && (kernel != INVERSE) && (kernel != GAUSS) && (kernel != CONSTANT)) { return; } m_WeightKernel = kernel; } /** * Gets the kernel weighting method to use. * * @return the new kernel method to use. Will be one of LINEAR, * EPANECHNIKOV, TRICUBE, INVERSE, GAUSS or CONSTANT. 
*/ public int getWeightingKernel() { return m_WeightKernel; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String nearestNeighbourSearchAlgorithmTipText() { return "The nearest neighbour search algorithm to use (Default: LinearNN)."; } /** * Returns the current nearestNeighbourSearch algorithm in use. * @return the NearestNeighbourSearch algorithm currently in use. */ public NearestNeighbourSearch getNearestNeighbourSearchAlgorithm() { return m_NNSearch; } /** * Sets the nearestNeighbourSearch algorithm to be used for finding nearest * neighbour(s). * @param nearestNeighbourSearchAlgorithm - The NearestNeighbourSearch class. */ public void setNearestNeighbourSearchAlgorithm(NearestNeighbourSearch nearestNeighbourSearchAlgorithm) { m_NNSearch = nearestNeighbourSearchAlgorithm; } /** * Returns default capabilities of the classifier. * * @return the capabilities of this classifier */ public Capabilities getCapabilities() { Capabilities result; if (m_Classifier != null) result = m_Classifier.getCapabilities(); else result = super.getCapabilities(); result.setMinimumNumberInstances(0); // set dependencies for (Capability cap: Capability.values()) result.enableDependency(cap); return result; } /** * Generates the classifier. * * @param instances set of instances serving as training data * @throws Exception if the classifier has not been generated successfully */ public void buildClassifier(Instances instances) throws Exception { if (!(m_Classifier instanceof WeightedInstancesHandler)) { throw new IllegalArgumentException("Classifier must be a " + "WeightedInstancesHandler!"); } // can classifier handle the data? 
getCapabilities().testWithFail(instances); // remove instances with missing class instances = new Instances(instances); instances.deleteWithMissingClass(); m_Train = new Instances(instances, 0, instances.numInstances()); m_NNSearch.setInstances(m_Train); } /** * Adds the supplied instance to the training set * * @param instance the instance to add * @throws Exception if instance could not be incorporated * successfully */ public void updateClassifier(Instance instance) throws Exception { if (m_Train.numInstances() == 0) { throw new Exception("No training instances!"); } else if (m_Train.equalHeaders(instance.dataset()) == false) { throw new Exception("Incompatible instance types"); } if (!instance.classIsMissing()) { m_NNSearch.update(instance); m_Train.add(instance); } } /** * Calculates the class membership probabilities for the given test instance. * * @param instance the instance to be classified * @return preedicted class probability distribution * @throws Exception if distribution can't be computed successfully */ public double[] distributionForInstance(Instance instance) throws Exception { if (m_Train.numInstances() == 0) { throw new Exception("No training instances!"); } m_NNSearch.addInstanceInfo(instance); int k = m_Train.numInstances(); if( (!m_UseAllK && (m_kNN < k)) && !(m_WeightKernel==INVERSE || m_WeightKernel==GAUSS) ) { k = m_kNN; } Instances neighbours = m_NNSearch.kNearestNeighbours(instance, k); double distances[] = m_NNSearch.getDistances(); if (m_Debug) { System.out.println("Kept " + neighbours.numInstances() + " out of " + m_Train.numInstances() + " instances"); } if (m_Debug) { System.out.println("Instance Distances"); for (int i = 0; i < distances.length; i++) { System.out.println("" + distances[i]); } } // Determine the bandwidth double bandwidth = distances[k-1]; // Check for bandwidth zero if (bandwidth <= 0) { //if the kth distance is zero than give all instances the same weight for(int i=0; i < distances.length; i++) distances[i] = 1; 
} else { // Rescale the distances by the bandwidth for (int i = 0; i < distances.length; i++) distances[i] = distances[i] / bandwidth; } // Pass the distances through a weighting kernel for (int i = 0; i < distances.length; i++) { switch (m_WeightKernel) { case LINEAR: distances[i] = 1.0001 - distances[i]; break; case EPANECHNIKOV: distances[i] = 3/4D*(1.0001 - distances[i]*distances[i]); break; case TRICUBE: distances[i] = Math.pow( (1.0001 - Math.pow(distances[i], 3)), 3 ); break; case CONSTANT: //System.err.println("using constant kernel"); distances[i] = 1; break; case INVERSE: distances[i] = 1.0 / (1.0 + distances[i]); break; case GAUSS: distances[i] = Math.exp(-distances[i] * distances[i]); break; } } if (m_Debug) { System.out.println("Instance Weights"); for (int i = 0; i < distances.length; i++) { System.out.println("" + distances[i]); } } // Set the weights on the training data double sumOfWeights = 0, newSumOfWeights = 0; for (int i = 0; i < distances.length; i++) { double weight = distances[i]; Instance inst = (Instance) neighbours.instance(i); sumOfWeights += inst.weight(); newSumOfWeights += inst.weight() * weight; inst.setWeight(inst.weight() * weight); //weightedTrain.add(newInst); } // Rescale weights for (int i = 0; i < neighbours.numInstances(); i++) { Instance inst = neighbours.instance(i); inst.setWeight(inst.weight() * sumOfWeights / newSumOfWeights); } // Create a weighted classifier m_Classifier.buildClassifier(neighbours); if (m_Debug) { System.out.println("Classifying test instance: " + instance); System.out.println("Built base classifier:\n" + m_Classifier.toString()); } // Return the classifier's predictions return m_Classifier.distributionForInstance(instance); } /** * Returns a description of this classifier. * * @return a description of this classifier as a string. 
*/
  public String toString() {
    // no data seen yet -> nothing meaningful to describe
    if (m_Train == null) {
      return "Locally weighted learning: No model built yet.";
    }
    StringBuilder description = new StringBuilder();
    description.append("Locally weighted learning\n");
    description.append("===========================\n");
    description.append("Using classifier: ")
      .append(m_Classifier.getClass().getName()).append("\n");
    description.append(kernelDescription());
    description.append("Using ")
      .append(m_UseAllK ? "all" : "" + m_kNN)
      .append(" neighbours");
    return description.toString();
  }

  /**
   * Returns the description line for the active weighting kernel.
   * Unrecognised kernel codes contribute nothing.
   *
   * @return the kernel description, newline-terminated, or "" if unknown
   */
  private String kernelDescription() {
    switch (m_WeightKernel) {
      case LINEAR:
        return "Using linear weighting kernels\n";
      case EPANECHNIKOV:
        return "Using epanechnikov weighting kernels\n";
      case TRICUBE:
        return "Using tricube weighting kernels\n";
      case INVERSE:
        return "Using inverse-distance weighting kernels\n";
      case GAUSS:
        return "Using gaussian weighting kernels\n";
      case CONSTANT:
        return "Using constant weighting kernels\n";
      default:
        return "";
    }
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String[] argv) {
    runClassifier(new LWL(), argv);
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -