predictionappender.java

来自「Weka」· Java 代码 · 共 857 行 · 第 1/2 页

JAVA
857
字号
          }
          return;
        } catch (Exception ex) {
          ex.printStackTrace();
        }
      }
      if (m_appendProbabilities) {
        try {
          Instances newTestSetInstances =
            makeDataSetProbabilities(testSet,
                                     classifier, relationNameModifier);
          Instances newTrainingSetInstances =
            makeDataSetProbabilities(trainSet,
                                     classifier, relationNameModifier);
          if (m_trainingSetListeners.size() > 0) {
            TrainingSetEvent tse = new TrainingSetEvent(this,
                new Instances(newTrainingSetInstances, 0));
            tse.m_setNumber = setNum;
            tse.m_maxSetNumber = maxNum;
            notifyTrainingSetAvailable(tse);
            // fill in predicted probabilities
            for (int i = 0; i < trainSet.numInstances(); i++) {
              double [] preds = classifier.
                distributionForInstance(trainSet.instance(i));
              for (int j = 0; j < trainSet.classAttribute().numValues(); j++) {
                newTrainingSetInstances.instance(i).setValue(trainSet.numAttributes() + j,
                    preds[j]);
              }
            }
            tse = new TrainingSetEvent(this,
                newTrainingSetInstances);
            tse.m_setNumber = setNum;
            tse.m_maxSetNumber = maxNum;
            notifyTrainingSetAvailable(tse);
          }
          if (m_testSetListeners.size() > 0) {
            TestSetEvent tse = new TestSetEvent(this,
                new Instances(newTestSetInstances, 0));
            tse.m_setNumber = setNum;
            tse.m_maxSetNumber = maxNum;
            notifyTestSetAvailable(tse);
          }
          if (m_dataSourceListeners.size() > 0) {
            notifyDataSetAvailable(new DataSetEvent(this, new Instances(newTestSetInstances, 0)));
          }
          if (e.getTestSet().isStructureOnly()) {
            m_format = newTestSetInstances;
          }
          if (m_dataSourceListeners.size() > 0 || m_testSetListeners.size() > 0) {
            // fill in predicted probabilities
            for (int i = 0; i < testSet.numInstances(); i++) {
              double [] preds = classifier.
                distributionForInstance(testSet.instance(i));
              for (int j = 0; j < testSet.classAttribute().numValues(); j++) {
                newTestSetInstances.instance(i).setValue(testSet.numAttributes() + j,
                    preds[j]);
              }
            }
          }

          // notify listeners
          if (m_testSetListeners.size() > 0) {
            TestSetEvent tse = new TestSetEvent(this, newTestSetInstances);
            tse.m_setNumber = setNum;
            tse.m_maxSetNumber = maxNum;
            notifyTestSetAvailable(tse);
          }
          if (m_dataSourceListeners.size() > 0) {
            notifyDataSetAvailable(new DataSetEvent(this, newTestSetInstances));
          }
        } catch (Exception ex) {
          ex.printStackTrace();
        }
      }
    }
  }

  /**
   * Accept and process a batch clusterer event.
   * <p>
   * Does nothing unless at least one data source listener is registered,
   * and ignores events whose test set is structure-only. The output
   * relation is named after the input relation plus the suffix
   * {@code _<test|training>_<setNumber>_of_<maxSetNumber>}.
   * <p>
   * When probabilities were requested and the clusterer is a
   * {@link DensityBasedClusterer}, one {@code prob_cluster<i>} attribute per
   * cluster is appended and filled from {@code distributionForInstance}.
   * Otherwise a single nominal attribute holding the cluster returned by
   * {@code clusterInstance} is appended. If probabilities were requested
   * but the clusterer cannot supply them, a warning is printed to
   * {@code System.err} and, when a logger is set, logged.
   * In both cases data source listeners first receive an empty copy of the
   * new header (capacity 0) and then the filled data set. Exceptions are
   * caught and their stack traces printed.
   *
   * @param e a <code>BatchClustererEvent</code> value
   */
  public void acceptClusterer(BatchClustererEvent e) {
    if (m_dataSourceListeners.size() > 0) {
      if (e.getTestSet().isStructureOnly()) {
        return;
      }
      Instances testSet = e.getTestSet().getDataSet();
      weka.clusterers.Clusterer clusterer = e.getClusterer();
      // a test-or-train flag of 0 denotes the test set, anything else training
      String test = (e.getTestOrTrain() == 0) ? "test" : "training";
      String relationNameModifier = "_" + test + "_" + e.getSetNumber()
        + "_of_" + e.getMaxSetNumber();
      boolean densityBased = clusterer instanceof DensityBasedClusterer;

      if (!m_appendProbabilities || !densityBased) {
        if (m_appendProbabilities) {
          // probabilities were requested, but only density based clusterers
          // provide a cluster distribution, so fall back to cluster labels
          String msg = "Only density based clusterers can append probabilities. "
            + "Instead cluster will be assigned for each instance.";
          System.err.println(msg);
          if (m_logger != null) {
            m_logger.logMessage(msg);
          }
        }
        try {
          Instances newInstances =
            makeClusterDataSetClass(testSet, clusterer, relationNameModifier);
          notifyDataSetAvailable(new DataSetEvent(this, new Instances(newInstances, 0)));

          // fill in predicted values; the cluster attribute was added last
          int clusterIndex = newInstances.numAttributes() - 1;
          for (int i = 0; i < testSet.numInstances(); i++) {
            double predCluster = clusterer.clusterInstance(testSet.instance(i));
            newInstances.instance(i).setValue(clusterIndex, predCluster);
          }
          // notify listeners
          notifyDataSetAvailable(new DataSetEvent(this, newInstances));
        } catch (Exception ex) {
          ex.printStackTrace();
        }
      } else {
        try {
          Instances newInstances =
            makeClusterDataSetProbabilities(testSet, clusterer, relationNameModifier);
          notifyDataSetAvailable(new DataSetEvent(this, new Instances(newInstances, 0)));

          // fill in predicted probabilities; the probability attributes
          // start right after the original attributes
          int numClusters = clusterer.numberOfClusters();
          int firstProbIndex = testSet.numAttributes();
          for (int i = 0; i < testSet.numInstances(); i++) {
            double [] probs = clusterer.distributionForInstance(testSet.instance(i));
            for (int j = 0; j < numClusters; j++) {
              newInstances.instance(i).setValue(firstProbIndex + j, probs[j]);
            }
          }
          // notify listeners
          notifyDataSetAvailable(new DataSetEvent(this, newInstances));
        } catch (Exception ex) {
          ex.printStackTrace();
        }
      }
    }
  }

  /**
   * Copy {@code format} and append one attribute per class value, named
   * {@code <ClassifierName>_prob_<classValue>}, where {@code <ClassifierName>}
   * is the classifier's class name without its package. The new attributes
   * follow the original ones in class-value order and are left for the
   * caller to fill with predicted probabilities.
   *
   * @param format the data set to copy (expected to have its class attribute set)
   * @param classifier the classifier whose name labels the new attributes
   * @param relationNameModifier suffix appended to the relation name
   * @return the augmented copy of {@code format}
   * @throws Exception if applying the Add filter fails
   */
  private Instances
    makeDataSetProbabilities(Instances format,
                             weka.classifiers.Classifier classifier,
                             String relationNameModifier)
    throws Exception {
    String classifierName = classifier.getClass().getName();
    classifierName = classifierName.substring(classifierName.lastIndexOf('.') + 1);
    int numClassValues = format.classAttribute().numValues();
    // work on a copy so renaming the relation never touches format itself
    Instances newInstances = new Instances(format);
    for (int i = 0; i < numClassValues; i++) {
      weka.filters.unsupervised.attribute.Add addF =
        new weka.filters.unsupervised.attribute.Add();
      addF.setAttributeIndex("last");
      addF.setAttributeName(classifierName + "_prob_" + format.classAttribute().value(i));
      addF.setInputFormat(newInstances);
      newInstances = weka.filters.Filter.useFilter(newInstances, addF);
    }
    newInstances.setRelationName(format.relationName() + relationNameModifier);
    return newInstances;
  }

  /**
   * Copy {@code format} with one attribute appended at the end, named
   * {@code class_predicted_by: <ClassifierName>}, to hold the predicted
   * class. If the class attribute is nominal, the new attribute gets the
   * same labels in the same order; otherwise no nominal labels are set.
   *
   * @param format the data set to copy (expected to have its class attribute set)
   * @param classifier the classifier whose name labels the new attribute
   * @param relationNameModifier suffix appended to the relation name
   * @return the augmented copy of {@code format}
   * @throws IllegalArgumentException if the class attribute is nominal but
   * has no values
   * @throws Exception if applying the Add filter fails
   */
  private Instances makeDataSetClass(Instances format,
                                     weka.classifiers.Classifier classifier,
                                     String relationNameModifier)
    throws Exception {
    weka.filters.unsupervised.attribute.Add addF =
      new weka.filters.unsupervised.attribute.Add();
    addF.setAttributeIndex("last");
    String classifierName = classifier.getClass().getName();
    classifierName = classifierName.substring(classifierName.lastIndexOf('.') + 1);
    addF.setAttributeName("class_predicted_by: " + classifierName);
    if (format.classAttribute().isNominal()) {
      int numValues = format.classAttribute().numValues();
      if (numValues == 0) {
        // fail with a clear message instead of reading past the value list
        throw new IllegalArgumentException("Nominal class attribute '"
          + format.classAttribute().name() + "' has no values");
      }
      // the labels are handed to Add as one comma-separated string
      StringBuffer classLabels = new StringBuffer(format.classAttribute().value(0));
      for (int i = 1; i < numValues; i++) {
        classLabels.append(',').append(format.classAttribute().value(i));
      }
      addF.setNominalLabels(classLabels.toString());
    }
    addF.setInputFormat(format);
    Instances newInstances = weka.filters.Filter.useFilter(format, addF);
    newInstances.setRelationName(format.relationName() + relationNameModifier);
    return newInstances;
  }

  /**
   * Copy {@code format} and append one attribute per cluster, named
   * {@code prob_cluster<i>} for i = 0 .. numberOfClusters()-1, to be
   * filled by the caller with the clusterer's distribution.
   *
   * @param format the data set to copy
   * @param clusterer the clusterer whose cluster count decides how many
   * attributes are added
   * @param relationNameModifier suffix appended to the relation name
   * @return the augmented copy of {@code format}
   * @throws Exception if the clusterer or the Add filter fails
   */
  private Instances
    makeClusterDataSetProbabilities(Instances format,
                                    weka.clusterers.Clusterer clusterer,
                                    String relationNameModifier)
    throws Exception {
    int numClusters = clusterer.numberOfClusters();
    // work on a copy so renaming the relation never touches format itself
    Instances newInstances = new Instances(format);
    for (int i = 0; i < numClusters; i++) {
      weka.filters.unsupervised.attribute.Add addF =
        new weka.filters.unsupervised.attribute.Add();
      addF.setAttributeIndex("last");
      addF.setAttributeName("prob_cluster" + i);
      addF.setInputFormat(newInstances);
      newInstances = weka.filters.Filter.useFilter(newInstances, addF);
    }
    newInstances.setRelationName(format.relationName() + relationNameModifier);
    return newInstances;
  }

  /**
   * Copy {@code format} with one nominal attribute appended at the end,
   * named {@code assigned_cluster: <ClustererName>}, whose labels are the
   * cluster indices {@code 0,1,...,numberOfClusters()-1}.
   *
   * @param format the data set to copy
   * @param clusterer the clusterer whose name and cluster count define the
   * new attribute
   * @param relationNameModifier suffix appended to the relation name
   * @return the augmented copy of {@code format}
   * @throws Exception if the clusterer or the Add filter fails
   */
  private Instances makeClusterDataSetClass(Instances format,
                                            weka.clusterers.Clusterer clusterer,
                                            String relationNameModifier)
    throws Exception {
    weka.filters.unsupervised.attribute.Add addF =
      new weka.filters.unsupervised.attribute.Add();
    addF.setAttributeIndex("last");
    String clustererName = clusterer.getClass().getName();
    clustererName = clustererName.substring(clustererName.lastIndexOf('.') + 1);
    addF.setAttributeName("assigned_cluster: " + clustererName);
    // label "0" is always present, even if fewer than one cluster is reported
    int numClusters = clusterer.numberOfClusters();
    StringBuffer clusterLabels = new StringBuffer("0");
    for (int i = 1; i < numClusters; i++) {
      clusterLabels.append(',').append(i);
    }
    addF.setNominalLabels(clusterLabels.toString());
    addF.setInputFormat(format);
    Instances newInstances = weka.filters.Filter.useFilter(format, addF);
    newInstances.setRelationName(format.relationName() + relationNameModifier);
    return newInstances;
  }

  /**
   * Notify all instance listeners that an instance is available. The
   * listener list is copied while holding this object's monitor; the
   * listeners themselves are called outside it.
   *
   * @param e an <code>InstanceEvent</code> value
   */
  protected void notifyInstanceAvailable(InstanceEvent e) {
    Vector l;
    synchronized (this) {
      l = (Vector)m_instanceListeners.clone();
    }
    for (int i = 0; i < l.size(); i++) {
      ((InstanceListener)l.elementAt(i)).acceptInstance(e);
    }
  }

  /**
   * Notify all data source listeners that a data set is available. The
   * listener list is copied while holding this object's monitor; the
   * listeners themselves are called outside it.
   *
   * @param e a <code>DataSetEvent</code> value
   */
  protected void notifyDataSetAvailable(DataSetEvent e) {
    Vector l;
    synchronized (this) {
      l = (Vector)m_dataSourceListeners.clone();
    }
    for (int i = 0; i < l.size(); i++) {
      ((DataSourceListener)l.elementAt(i)).acceptDataSet(e);
    }
  }

  /**
   * Notify all test set listeners that a test set is available. The
   * listener list is copied while holding this object's monitor; the
   * listeners themselves are called outside it.
   *
   * @param e a <code>TestSetEvent</code> value
   */
  protected void notifyTestSetAvailable(TestSetEvent e) {
    Vector l;
    synchronized (this) {
      l = (Vector)m_testSetListeners.clone();
    }
    for (int i = 0; i < l.size(); i++) {
      ((TestSetListener)l.elementAt(i)).acceptTestSet(e);
    }
  }

  /**
   * Notify all training set listeners that a training set is available.
   * The listener list is copied while holding this object's monitor; the
   * listeners themselves are called outside it.
   *
   * @param e a <code>TrainingSetEvent</code> value
   */
  protected void notifyTrainingSetAvailable(TrainingSetEvent e) {
    Vector l;
    synchronized (this) {
      l = (Vector)m_trainingSetListeners.clone();
    }
    for (int i = 0; i < l.size(); i++) {
      ((TrainingSetListener)l.elementAt(i)).acceptTrainingSet(e);
    }
  }

  /**
   * Set the logger that receives status/warning messages (for example the
   * warning issued by acceptClusterer when probabilities are unavailable).
   *
   * @param logger a <code>weka.gui.Logger</code> value
   */
  public void setLog(weka.gui.Logger logger) {
    m_logger = logger;
  }

  /**
   * Stop any processing. This bean has nothing to stop.
   */
  public void stop() {
    // can't really do anything meaningful here
  }

  /**
   * Returns true if, at this time, the object will accept a connection
   * according to the supplied event name. Only one source may be connected
   * at a time; the event name itself is not inspected.
   *
   * @param eventName the event
   * @return true if the object will accept a connection
   */
  public boolean connectionAllowed(String eventName) {
    return (m_listenee == null);
  }

  /**
   * Returns true if, at this time, the object will accept a connection
   * according to the supplied EventSetDescriptor. Delegates to
   * {@link #connectionAllowed(String)} with the descriptor's name.
   *
   * @param esd the EventSetDescriptor
   * @return true if the object will accept a connection
   */
  public boolean connectionAllowed(EventSetDescriptor esd) {
    return connectionAllowed(esd.getName());
  }

  /**
   * Notify this object that it has been registered as a listener with
   * a source with respect to the supplied event name. The source is
   * remembered only if a connection is currently allowed.
   *
   * @param eventName the event name
   * @param source the source with which this object has been registered as
   * a listener
   */
  public synchronized void connectionNotification(String eventName,
                                                  Object source) {
    if (connectionAllowed(eventName)) {
      m_listenee = source;
    }
  }

  /**
   * Notify this object that it has been deregistered as a listener with
   * a source with respect to the supplied event name. If that source is the
   * currently connected one, the connection and any cached output format
   * are cleared.
   *
   * @param eventName the event name
   * @param source the source with which this object has been registered as
   * a listener
   */
  public synchronized void disconnectionNotification(String eventName,
                                                     Object source) {
    if (m_listenee == source) {
      m_listenee = null;
      m_format = null; // any previously calculated instance format is now invalid
    }
  }

  /**
   * Returns true, if at the current time, the named event could
   * be generated. Assumes that supplied event names are names of
   * events that could be generated by this bean.
   * <p>
   * Without a connected source nothing can be generated. If the source
   * imposes {@link EventConstraints}: "instance" requires the source to be
   * able to generate "incrementalClassifier", and "dataSet", "trainingSet"
   * and "testSet" require it to be able to generate "batchClassifier" or
   * "batchClusterer". Every other case answers true.
   *
   * @param eventName the name of the event in question
   * @return true if the named event could be generated at this point in
   * time
   */
  public boolean eventGeneratable(String eventName) {
    if (m_listenee == null) {
      return false;
    }
    if (!(m_listenee instanceof EventConstraints)) {
      return true;
    }
    EventConstraints constraints = (EventConstraints)m_listenee;
    if (eventName.equals("instance")
        && !constraints.eventGeneratable("incrementalClassifier")) {
      return false;
    }
    if (eventName.equals("dataSet")
        || eventName.equals("trainingSet")
        || eventName.equals("testSet")) {
      return constraints.eventGeneratable("batchClassifier")
        || constraints.eventGeneratable("batchClusterer");
    }
    return true;
  }
}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?