
📄 rakel.java

📁 Multi-label classification and Weka integration
💻 JAVA
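
		// For each test instance: add the current model's votes via
		// updatePrediction and store one BinaryPrediction (predicted value,
		// actual value, confidence) for every label.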
		for(int i = 0; i < testData.numInstances(); i++)
		{
			Instance instance = testData.instance(i);
			Prediction result = updatePrediction(instance, i, model);
//			Prediction result = makePrediction(instance);
			//System.out.println(java.util.Arrays.toString(result.getConfidences()));
			for(int j = 0; j < numLabels; j++)
			{
				int classIdx = testData.numAttributes() - numLabels + j;
				boolean actual = Utils.eq(1, instance.value(classIdx));
				predictions[i][j] = new BinaryPrediction(
							result.getPrediction(j), 
							actual, 
							result.getConfidence(j));
			}
		}		
	}
	
        
        
	/**
	 * Builds the RAkEL ensemble: if cross-validation based parameter selection
	 * is enabled, the subset size, number of models and threshold are tuned
	 * first; afterwards one classifier is trained per random label subset via
	 * updateClassifier.
	 */
	public void buildClassifier(Instances trainData) throws Exception {
		if (cvParamSelection) {
			paramSelectionViaCV(trainData);
			System.out.println("Selected Parameters\n" +
					"Subset size     : " + getSizeOfSubset() + "\n" +
					"Number of models: " + getNumModels() + "\n" +
					"Threshold       : " + getThreshold());
		}

		// need a structure to hold different combinations
		combinations = new HashSet<String>();

		for (int i = 0; i < numOfModels; i++)
			updateClassifier(trainData, i);
	}
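
	/**
	 * Trains ensemble member 'model': draws a random, previously unused label
	 * subset of size sizeOfSubset, removes the remaining label attributes from
	 * the training data and builds a LabelPowersetClassifier on the reduced
	 * dataset. The emptied training header is kept in metadataTest for
	 * transforming test instances later.
	 */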
	
	public void updateClassifier(Instances trainData, int model) throws Exception {
		if (combinations == null)
			combinations = new HashSet<String>();
		
		Random rnd = new Random();	

		// --select a random subset of classes not seen before
		boolean[] selected;
		do {
			selected = new boolean[numLabels];
			for (int j = 0; j < sizeOfSubset; j++) {
				int randomLabel = rnd.nextInt(numLabels);
				while (selected[randomLabel]) {
					randomLabel = rnd.nextInt(numLabels);
				}
				selected[randomLabel] = true;
				//System.out.println("label: " + randomLabel);
				classIndicesPerSubset[model][j] = randomLabel;
			}
			Arrays.sort(classIndicesPerSubset[model]);
		} while (!combinations.add(Arrays.toString(classIndicesPerSubset[model])));
		System.out.println("Building model " + model + ", subset: " + Arrays.toString(classIndicesPerSubset[model]));
		
		// --remove the unselected labels
		int numPredictors = trainData.numAttributes()-numLabels;
		absoluteIndicesToRemove[model] = new int[numLabels-sizeOfSubset]; 
		int k=0;
		for (int j = 0; j < numLabels; j++)
			if (!selected[j]) {
				absoluteIndicesToRemove[model][k] = numPredictors + j;
				k++;
			}
		Remove remove = new Remove();
		remove.setAttributeIndicesArray(absoluteIndicesToRemove[model]);
		remove.setInputFormat(trainData);
		remove.setInvertSelection(false);
		Instances trainSubset = Filter.useFilter(trainData, remove);
		//System.out.println(trainSubset.toSummaryString());
			
		// build a LabelPowersetClassifier for the selected label subset;
		subsetClassifiers[model] = new LabelPowersetClassifier(Classifier.makeCopy(getBaseClassifier()), sizeOfSubset);
		subsetClassifiers[model].buildClassifier(trainSubset);

		// keep the header of the training data for testing
		trainSubset.delete();
		metadataTest[model] = trainSubset;
	}
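
	/**
	 * Adds the votes of ensemble member 'model' for the given test instance to
	 * the incremental vote counts and returns the prediction obtained by
	 * averaging all votes collected so far; labels whose average confidence is
	 * at least 0.5 are assigned.
	 */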
	
	public Prediction updatePrediction(Instance instance, int instanceNumber, int model) throws Exception {	
		int numPredictors = instance.numAttributes()-numLabels;

		// transform instance
		//// new2 solution
		
		Instance newInstance;
		if (instance instanceof SparseInstance) {
			newInstance = new SparseInstance(instance);
			// drop the trailing label attributes so that, like the dense branch,
			// the copy ends up with numPredictors + sizeOfSubset attributes
			for (int i = 0; i < numLabels - sizeOfSubset; i++)
				newInstance.deleteAttributeAt(newInstance.numAttributes() - 1);
		} else {
			double[] vals = new double[numPredictors + sizeOfSubset];
			for (int j = 0; j < vals.length - sizeOfSubset; j++)
				vals[j] = instance.value(j);
			newInstance = new Instance(instance.weight(), vals);
		}
		
		
		//// new solution
		/*
		double[] vals = new double[numPredictors+sizeOfSubset];
		for (int j=0; j<vals.length-sizeOfSubset; j++)
			vals[j] = instance.value(j);
		Instance newInstance = (instance instanceof SparseInstance)
		? new SparseInstance(instance.weight(), vals)
		: new Instance(instance.weight(), vals);
		*/
		
		//// old solution
		/*
		Instance newInstance = new Instance(numPredictors+sizeOfSubset);
		for (int j=0; j<newInstance.numAttributes(); j++)
			newInstance.setValue(j, instance.value(j));
		*/
		
		newInstance.setDataset(metadataTest[model]);
			
		double[] predictions = subsetClassifiers[model].makePrediction(newInstance).getPredictedLabels();
		for (int j=0; j<sizeOfSubset; j++) {
			sumVotesIncremental[instanceNumber][classIndicesPerSubset[model][j]] += predictions[j];
			lengthVotesIncremental[instanceNumber][classIndicesPerSubset[model][j]]++;
		}
		/*
		for (int i=0; i<numLabels; i++)
			System.out.print(instance.value(numPredictors+i) + " ");
		System.out.println("");
		System.out.println(Arrays.toString(sumVotesIncremental[instanceNumber]));
		System.out.println(Arrays.toString(lengthVotesIncremental[instanceNumber]));
		//*/
		
		double[] confidence = new double[numLabels];
		double[] labels = new double[numLabels];
		for (int i=0; i<numLabels; i++) {
			confidence[i] = sumVotesIncremental[instanceNumber][i]/lengthVotesIncremental[instanceNumber][i];
			if (confidence[i] >= 0.5)
				labels[i] = 1;
			else
				labels[i] = 0;
		}
		
		Prediction pred = new Prediction(labels, confidence);

		return pred;
	}
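
	/**
	 * Queries every trained ensemble member for the given instance, averages
	 * the votes per label and assigns each label whose average confidence is
	 * at least 0.5.
	 */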
	
	
	public Prediction makePrediction(Instance instance) throws Exception {		
		int numPredictors = instance.numAttributes()-numLabels;
		Arrays.fill(sumVotes, 0);
		Arrays.fill(lengthVotes, 0);
		for (int i=0; i<numOfModels; i++) {
			if (subsetClassifiers[i] == null)
				continue;
			
			// transform instance
			//// new solution
			double[] vals = new double[numPredictors+sizeOfSubset];
			for (int j=0; j<vals.length-sizeOfSubset; j++)
				vals[j] = instance.value(j);
			Instance newInstance = (instance instanceof SparseInstance)
			? new SparseInstance(instance.weight(), vals)
			: new Instance(instance.weight(), vals);

			//// old solution 
			/*                         
			//System.out.println("old instance: " + instance.toString());
			Instance newInstance = new Instance(numPredictors+sizeOfSubset);
			for (int j=0; j<newInstance.numAttributes(); j++)
				newInstance.setValue(j, instance.value(j));
			//*/
			
			newInstance.setDataset(metadataTest[i]);
			//System.out.println("new instance: " + newInstance.toString());
			
			double[] predictions = subsetClassifiers[i].makePrediction(newInstance).getPredictedLabels();
			for (int j=0; j<sizeOfSubset; j++) {
				sumVotes[classIndicesPerSubset[i][j]] += predictions[j];
				lengthVotes[classIndicesPerSubset[i][j]]++;
			}
		}
		/*
		for (int i=0; i<numLabels; i++)
			System.out.print(instance.value(numPredictors+i) + " ");
		System.out.println("");
		System.out.println(Arrays.toString(sumVotes));
		System.out.println(Arrays.toString(lengthVotes));
		//*/
		
		double[] confidence = new double[numLabels];
		double[] labels = new double[numLabels];
		for (int i=0; i<numLabels; i++) {
			confidence[i] = sumVotes[i]/lengthVotes[i];
			if (confidence[i] >= 0.5)
				labels[i] = 1;
			else
				labels[i] = 0;
		}
		
		Prediction pred = new Prediction(labels, confidence);
		
		return pred;
	}
        
	public void nullSubsetClassifier(int i) {
		subsetClassifiers[i] = null;
	}

	public String getRevision() {
		throw new UnsupportedOperationException("Not supported yet.");
	}
}
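
/*
 * Illustration only (not part of the original file): a self-contained sketch
 * of the vote-averaging scheme used in updatePrediction/makePrediction above.
 * Each model votes only for the labels in its random subset; votes are
 * averaged per label and thresholded at 0.5. The subsets, votes and the class
 * name below are made-up placeholders.
 */
class RakelVotingSketch {

	public static void main(String[] args) {
		int numLabels = 4;

		// absolute label indices covered by each (hypothetical) model
		int[][] subsets = { {0, 1}, {1, 2}, {2, 3} };
		// binary votes of each model for the labels of its subset
		double[][] votes = { {1, 0}, {0, 1}, {0, 0} };

		double[] sumVotes = new double[numLabels];
		double[] lengthVotes = new double[numLabels];

		// accumulate votes per absolute label index, as in makePrediction
		for (int m = 0; m < subsets.length; m++) {
			for (int j = 0; j < subsets[m].length; j++) {
				sumVotes[subsets[m][j]] += votes[m][j];
				lengthVotes[subsets[m][j]]++;
			}
		}

		// average and threshold at 0.5 to obtain the final label assignment
		for (int i = 0; i < numLabels; i++) {
			double confidence = sumVotes[i] / lengthVotes[i];
			int label = (confidence >= 0.5) ? 1 : 0;
			System.out.println("label " + i + ": confidence = " + confidence
					+ " -> " + label);
		}
	}
}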
