📄 citationknn.java
字号:
train = new Instances(train); train.deleteWithMissingClass(); m_TrainBags = train; m_ClassIndex = train.classIndex(); m_IdIndex = 0; m_NumClasses = train.numClasses(); m_Classes = new int [train.numInstances()]; // Class values m_Attributes = train.instance(0).relationalValue(1).stringFreeStructure(); m_Citers = new int[train.numClasses()]; m_References = new int[train.numClasses()]; m_Diffs = new double[m_Attributes.numAttributes()]; m_Min = new double[m_Attributes.numAttributes()]; m_Max = new double[m_Attributes.numAttributes()]; preprocessData(); buildCNN(); if(m_CNNDebug){ System.out.println("########################################### "); System.out.println("###########CITATION######################## "); System.out.println("########################################### "); for(int i = 0; i < m_CNN.length; i++){ System.out.println("Bag: " + i); m_CNN[i].printReducedList(); } } } /** * generates all the variables associated to the citation * classifier * * @throws Exception if generation fails */ public void buildCNN() throws Exception { int numCiters = 0; if((m_NumCiters >= m_TrainBags.numInstances()) || (m_NumCiters < 0)) throw new Exception("Number of citers is out of the range [0, numInstances)"); else numCiters = m_NumCiters; m_CNN = new NeighborList[m_TrainBags.numInstances()]; Instance bag; for(int i = 0; i< m_TrainBags.numInstances(); i++){ bag = m_TrainBags.instance(i); //first we find its neighbors NeighborList neighborList = findNeighbors(bag, numCiters, m_TrainBags); m_CNN[i] = neighborList; } } /** * calculates the citers associated to a bag * @param bag the bag cited */ public void countBagCiters(Instance bag){ //Initialization of the vector for(int i = 0; i < m_TrainBags.numClasses(); i++) m_Citers[i] = 0; // if(m_CitersDebug == true) System.out.println("-------CITERS--------"); NeighborList neighborList; NeighborNode current; boolean stopSearch = false; int index; // compute the distance between the test bag and each training bag. Update // the bagCiter count in case it be a neighbour double bagDistance = 0; for(int i = 0; i < m_TrainBags.numInstances(); i++){ //measure the distance bagDistance = distanceSet(bag, m_TrainBags.instance(i)); if(m_CitersDebug == true){ System.out.print("bag - bag(" + i + "): " + bagDistance); System.out.println(" <" + m_TrainBags.instance(i).classValue() + ">"); } //compare the distance to see if it would belong to the // neighborhood of each training exemplar neighborList = m_CNN[i]; current = neighborList.mFirst; while((current != null) && (!stopSearch)) { if(m_CitersDebug == true) System.out.println("\t\tciter Distance: " + current.mDistance); if(current.mDistance < bagDistance){ current = current.mNext; } else{ stopSearch = true; if(m_CitersDebug == true){ System.out.println("\t***"); } } } if(stopSearch == true){ stopSearch = false; index = (int)(m_TrainBags.instance(i)).classValue(); m_Citers[index] += 1; } } if(m_CitersDebug == true){ for(int i= 0; i < m_Citers.length; i++){ System.out.println("[" + i + "]: " + m_Citers[i]); } } } /** * Calculates the references of the exemplar bag * @param bag the exemplar to which the nearest references * will be calculated */ public void countBagReferences(Instance bag){ int index = 0, referencesIndex = 0; if(m_TrainBags.numInstances() < m_NumReferences) referencesIndex = m_TrainBags.numInstances() - 1; else referencesIndex = m_NumReferences; if(m_CitersDebug == true){ System.out.println("-------References (" + referencesIndex+ ")--------"); } //Initialization of the vector for(int i = 0; i < m_References.length; i++) m_References[i] = 0; if(referencesIndex > 0){ //first we find its neighbors NeighborList neighborList = findNeighbors(bag, referencesIndex, m_TrainBags); if(m_ReferencesDebug == true){ System.out.println("Bag: " + bag + " Neighbors: "); neighborList.printReducedList(); } NeighborNode current = neighborList.mFirst; while(current != null){ index = (int) current.mBag.classValue(); m_References[index] += 1; current = current.mNext; } } if(m_ReferencesDebug == true){ System.out.println("References:"); for(int j = 0; j < m_References.length; j++) System.out.println("[" + j + "]: " + m_References[j]); } } /** * Build the list of nearest k neighbors to the given test instance. * @param bag the bag to search for neighbors of * @param kNN the number of nearest neighbors * @param bags the data * @return a list of neighbors */ protected NeighborList findNeighbors(Instance bag, int kNN, Instances bags){ double distance; int index = 0; if(kNN > bags.numInstances()) kNN = bags.numInstances() - 1; NeighborList neighborList = new NeighborList(kNN); for(int i = 0; i < bags.numInstances(); i++){ if(bag != bags.instance(i)){ // for hold-one-out cross-validation distance = distanceSet(bag, bags.instance(i)) ; //mDistanceSet.distance(bag, mInstances, bags.exemplar(i), mInstances); if(m_NeighborListDebug) System.out.println("distance(bag, " + i + "): " + distance); if(neighborList.isEmpty() || (index < kNN) || (distance <= neighborList.mLast.mDistance)) neighborList.insertSorted(distance, bags.instance(i), i); index++; } } if(m_NeighborListDebug){ System.out.println("bag neighbors:"); neighborList.printReducedList(); } return neighborList; } /** * Calculates the distance between two instances * @param first instance * @param second instance * @return the distance value */ public double distanceSet(Instance first, Instance second){ double[] h_f = new double[first.relationalValue(1).numInstances()]; double distance; //initilization for(int i = 0; i < h_f.length; i++) h_f[i] = Double.MAX_VALUE; int rank; if(m_HDRank >= first.relationalValue(1).numInstances()) rank = first.relationalValue(1).numInstances(); else if(m_HDRank < 1) rank = 1; else rank = m_HDRank; if(m_HDistanceDebug){ System.out.println("-------HAUSDORFF DISTANCE--------"); System.out.println("rank: " + rank + "\nset of instances:"); System.out.println("\tset 1:"); for(int i = 0; i < first.relationalValue(1).numInstances(); i++) System.out.println(first.relationalValue(1).instance(i)); System.out.println("\n\tset 2:"); for(int i = 0; i < second.relationalValue(1).numInstances(); i++) System.out.println(second.relationalValue(1).instance(i)); System.out.println("\n"); } //for each instance in bag first for(int i = 0; i < first.relationalValue(1).numInstances(); i++){ // calculate the distance to each instance in // bag second if(m_HDistanceDebug){ System.out.println("\nDistances:"); } for(int j = 0; j < second.relationalValue(1).numInstances(); j++){ distance = distance(first.relationalValue(1).instance(i), second.relationalValue(1).instance(j)); if(distance < h_f[i]) h_f[i] = distance; if(m_HDistanceDebug){ System.out.println("\tdist(" + i + ", "+ j + "): " + distance + " --> h_f[" + i + "]: " + h_f[i]); } } } int[] index_f = Utils.stableSort(h_f); if(m_HDistanceDebug){ System.out.println("\nRanks:\n"); for(int i = 0; i < index_f.length; i++) System.out.println("\trank " + (i + 1) + ": " + h_f[index_f[i]]); System.out.println("\n\t\t>>>>> rank " + rank + ": " + h_f[index_f[rank - 1]] + " <<<<<"); } return h_f[index_f[rank - 1]]; } /** * distance between two instances * @param first the first instance * @param second the other instance * @return the distance in double precision */ public double distance(Instance first, Instance second){ double sum = 0, diff; for(int i = 0; i < m_Attributes.numAttributes(); i++){ diff = (first.value(i) - m_Min[i])/ m_Diffs[i] - (second.value(i) - m_Min[i])/ m_Diffs[i]; sum += diff * diff; } return sum = Math.sqrt(sum); } /** * Computes the distribution for a given exemplar * * @param bag the exemplar for which distribution is computed * @return the distribution * @throws Exception if the distribution can't be computed successfully */ public double[] distributionForInstance(Instance bag) throws Exception { if(m_TrainBags.numInstances() == 0) throw new Exception("No training bags!"); updateNormalization(bag); //build references (R nearest neighbors) countBagReferences(bag); //build citers countBagCiters(bag); return makeDistribution(); } /** * Updates the normalization of each attribute. * * @param bag the exemplar to update the normalization for */ public void updateNormalization(Instance bag){ int i, k; double min, max; Instances instances; Instance instance; // compute the min/max of each feature for (i = 0; i < m_TrainBags.attribute(1).relation().numAttributes(); i++) { min = m_Min[i] / m_MinNorm; max = m_Max[i] / m_MaxNorm; instances = bag.relationalValue(1); for (k=0;k<instances.numInstances();k++) { instance = instances.instance(k); if(instance.value(i) < min) min = instance.value(i); if(instance.value(i) > max) max = instance.value(i); } m_Min[i] = min * m_MinNorm; m_Max[i] = max * m_MaxNorm; m_Diffs[i]= max * m_MaxNorm - min * m_MinNorm; } } /** * Wether the instances of two exemplars are or are not equal * @param exemplar1 first exemplar * @param exemplar2 second exemplar * @return if the instances of the exemplars are equal or not */ public boolean equalExemplars(Instance exemplar1, Instance exemplar2){ if(exemplar1.relationalValue(1).numInstances() == exemplar2.relationalValue(1).numInstances()){ Instances instances1 = exemplar1.relationalValue(1); Instances instances2 = exemplar2.relationalValue(1); for(int i = 0; i < instances1.numInstances(); i++){ Instance instance1 = instances1.instance(i); Instance instance2 = instances2.instance(i); for(int j = 0; j < instance1.numAttributes(); j++){ if(instance1.value(j) != instance2.value(j)){ return false; } } } return true; } return false; }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -