crossvalidationexperiment.java
来自「Multi-label classification 和weka集成」· Java 代码 · 共 82 行
JAVA
82 行
package mulan.examples;/** * * @author greg */import mulan.classifier.BinaryRelevanceClassifier;import mulan.classifier.AbstractMultiLabelClassifier.*;import mulan.evaluation.Evaluator;import weka.core.Instances;import java.io.*;import mulan.*;import mulan.classifier.LabelPowersetClassifier;import mulan.classifier.MLkNN;import mulan.classifier.RAKEL;import mulan.evaluation.IntegratedCrossvalidation;import weka.classifiers.trees.J48;public class CrossValidationExperiment { /** * Creates a new instance of this class */ public CrossValidationExperiment() { } public static void main(String[] args) throws Exception { String path = "d:/work/datasets/multilabel/scene/"; String filename = "scene.arff"; int numLabels = 6; FileReader frData = new FileReader(path + filename); Instances data = new Instances(frData); Evaluator eval = new Evaluator(5); IntegratedCrossvalidation results; //* Binary Relevance Classifier System.out.println("BR"); BinaryRelevanceClassifier br = new BinaryRelevanceClassifier(); J48 brBaseClassifier = new J48(); br.setBaseClassifier(brBaseClassifier); br.setNumLabels(numLabels); results = eval.crossValidateAll(br, data, 10); System.out.println(results.toString()); System.gc(); //*/ //* Label Powerset Classifier System.out.println("LP"); J48 lpBaseClassifier = new J48(); LabelPowersetClassifier lp = new LabelPowersetClassifier(lpBaseClassifier, numLabels); results = eval.crossValidateAll(lp, data, 10); System.out.println(results.toString()); System.gc(); //*/ //* RAKEL System.out.println("RAKEL"); RAKEL rakel = new RAKEL(numLabels, 10, 3); J48 rakelBaseClassifier = new J48(); rakel.setBaseClassifier(rakelBaseClassifier); rakel.setParamSelectionViaCV(true); rakel.setParamSets(3, 2, numLabels-1, 1, 500, 0.1, 0.1, 9); results = eval.crossValidateAll(rakel, data, 10); System.out.println(results.toString()); System.gc(); //*/ //* ML-kNN System.out.println("ML-kNN"); int numNeighbours = 10; MLkNN mlknn = new MLkNN(numLabels, numNeighbours, 1); results = eval.crossValidateAll(mlknn, data, 10); System.out.println(results.toString()); System.gc(); //*/ } }
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?