⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 neuralnetwork.java

📁 :<<数据挖掘--实用机器学习技术及java实现>>一书的配套源程序
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
      //get the values.       m_outputs[noc].outputValue(true);    }  }  /**   * This will cause the error values to be calculated for all nodes.   * Note that the m_currentInstance is used to calculate these values.   * Also the output values should have been calculated first.   * @return The squared error.   */  private double calculateErrors() throws Exception {    double ret = 0, temp = 0;     for (int noc = 0; noc < m_numAttributes; noc++) {      //get the errors.      m_inputs[noc].errorValue(true);          }    for (int noc = 0; noc < m_numClasses; noc++) {      temp = m_outputs[noc].errorValue(false);      ret += temp * temp;    }        return ret;      }  /**   * This will cause the weight values to be updated based on the learning   * rate, momentum and the errors that have been calculated for each node.   * @param l The learning rate to update with.   * @param m The momentum to update with.   */  private void updateNetworkWeights(double l, double m) {    for (int noc = 0; noc < m_numClasses; noc++) {      //update weights      m_outputs[noc].updateWeights(l, m);    }  }    /**   * This creates the required input units.   */  private void setupInputs() throws Exception {    m_inputs = new NeuralEnd[m_numAttributes];    int now = 0;    for (int noa = 0; noa < m_numAttributes+1; noa++) {      if (m_instances.classIndex() != noa) {	m_inputs[noa - now] = new NeuralEnd(m_instances.attribute(noa).name());		m_inputs[noa - now].setX(.1);	m_inputs[noa - now].setY((noa - now + 1.0) / (m_numAttributes + 1));	m_inputs[noa - now].setLink(true, noa);      }          else {	now = 1;      }    }  }  /**   * This creates the required output units.   */  private void setupOutputs() throws Exception {      m_outputs = new NeuralEnd[m_numClasses];    for (int noa = 0; noa < m_numClasses; noa++) {      if (m_numeric) {	m_outputs[noa] = new NeuralEnd(m_instances.classAttribute().name());      }      else {	m_outputs[noa]= new NeuralEnd(m_instances.classAttribute().value(noa));      }            m_outputs[noa].setX(.9);      m_outputs[noa].setY((noa + 1.0) / (m_numClasses + 1));      m_outputs[noa].setLink(false, noa);      NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random,				       m_sigmoidUnit);      m_nextId++;      temp.setX(.75);      temp.setY((noa + 1.0) / (m_numClasses + 1));      addNode(temp);      NeuralConnection.connect(temp, m_outputs[noa]);    }   }    /**   * Call this function to automatically generate the hidden units   */  private void setupHiddenLayer()  {    StringTokenizer tok = new StringTokenizer(m_hiddenLayers, ",");    int val = 0;  //num of nodes in a layer    int prev = 0; //used to remember the previous layer    int num = tok.countTokens(); //number of layers    String c;    for (int noa = 0; noa < num; noa++) {      //note that I am using the Double to get the value rather than the      //Integer class, because for some reason the Double implementation can      //handle leading white space and the integer version can't!?!      c = tok.nextToken().trim();      if (c.equals("a")) {	val = (m_numAttributes + m_numClasses) / 2;      }      else if (c.equals("i")) {	val = m_numAttributes;      }      else if (c.equals("o")) {	val = m_numClasses;      }      else if (c.equals("t")) {	val = m_numAttributes + m_numClasses;      }      else {	val = Double.valueOf(c).intValue();      }      for (int nob = 0; nob < val; nob++) {	NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random,					 m_sigmoidUnit);	m_nextId++;	temp.setX(.5 / (num) * noa + .25);	temp.setY((nob + 1.0) / (val + 1));	addNode(temp);	if (noa > 0) {	  //then do connections	  for (int noc = m_neuralNodes.length - nob - 1 - prev;	       noc < m_neuralNodes.length - nob - 1; noc++) {	    NeuralConnection.connect(m_neuralNodes[noc], temp);	  }	}      }            prev = val;    }    tok = new StringTokenizer(m_hiddenLayers, ",");    c = tok.nextToken();    if (c.equals("a")) {      val = (m_numAttributes + m_numClasses) / 2;    }    else if (c.equals("i")) {      val = m_numAttributes;    }    else if (c.equals("o")) {      val = m_numClasses;    }    else if (c.equals("t")) {      val = m_numAttributes + m_numClasses;    }    else {      val = Double.valueOf(c).intValue();    }        if (val == 0) {      for (int noa = 0; noa < m_numAttributes; noa++) {	for (int nob = 0; nob < m_numClasses; nob++) {	  NeuralConnection.connect(m_inputs[noa], m_neuralNodes[nob]);	}      }    }    else {      for (int noa = 0; noa < m_numAttributes; noa++) {	for (int nob = m_numClasses; nob < m_numClasses + val; nob++) {	  NeuralConnection.connect(m_inputs[noa], m_neuralNodes[nob]);	}      }      for (int noa = m_neuralNodes.length - prev; noa < m_neuralNodes.length;	   noa++) {	for (int nob = 0; nob < m_numClasses; nob++) {	  NeuralConnection.connect(m_neuralNodes[noa], m_neuralNodes[nob]);	}      }    }      }    /**   * This will go through all the nodes and check if they are connected   * to a pure output unit. If so they will be set to be linear units.   * If not they will be set to be sigmoid units.   */  private void setEndsToLinear() {    for (int noa = 0; noa < m_neuralNodes.length; noa++) {      if ((m_neuralNodes[noa].getType() & NeuralConnection.OUTPUT) ==	  NeuralConnection.OUTPUT) {	((NeuralNode)m_neuralNodes[noa]).setMethod(m_linearUnit);      }      else {	((NeuralNode)m_neuralNodes[noa]).setMethod(m_sigmoidUnit);      }    }  }    /**   * Call this function to build and train a neural network for the training   * data provided.   * @param i The training data.   * @exception Throws exception if can't build classification properly.   */  public void buildClassifier(Instances i) throws Exception {    if (i.checkForStringAttributes()) {      throw new IllegalArgumentException("Can't handle string attributes!");    }    if (i.numInstances() == 0) {      throw new IllegalArgumentException("No training instances.");    }    m_epoch = 0;    m_error = 0;    m_instances = null;    m_currentInstance = null;    m_controlPanel = null;    m_nodePanel = null;            m_outputs = new NeuralEnd[0];    m_inputs = new NeuralEnd[0];    m_numAttributes = 0;    m_numClasses = 0;    m_neuralNodes = new NeuralConnection[0];        m_selected = new FastVector(4);    m_graphers = new FastVector(2);    m_nextId = 0;    m_stopIt = true;    m_stopped = true;    m_accepted = false;        m_instances = new Instances(i);    m_instances.deleteWithMissingClass();    if (m_instances.numInstances() == 0) {      m_instances = null;      throw new IllegalArgumentException("All class values missing.");    }    m_random = new Random(m_randomSeed);    m_instances.randomize(m_random);    if (m_useNomToBin) {      m_nominalToBinaryFilter = new NominalToBinaryFilter();      m_nominalToBinaryFilter.setInputFormat(m_instances);      m_instances = Filter.useFilter(m_instances,				     m_nominalToBinaryFilter);    }    m_numAttributes = m_instances.numAttributes() - 1;    m_numClasses = m_instances.numClasses();         setClassType(m_instances);           //this sets up the validation set.    Instances valSet = null;    //numinval is needed later    int numInVal = (int)(m_valSize / 100.0 * m_instances.numInstances());    if (m_valSize > 0) {      if (numInVal == 0) {	numInVal = 1;      }      valSet = new Instances(m_instances, 0, numInVal);    }    ///////////    setupInputs();          setupOutputs();        if (m_autoBuild) {      setupHiddenLayer();    }        /////////////////////////////    //this sets up the gui for usage    if (m_gui) {      m_win = new JFrame();            m_win.addWindowListener(new WindowAdapter() {	  public void windowClosing(WindowEvent e) {	    boolean k = m_stopIt;	    m_stopIt = true;	    int well =JOptionPane.showConfirmDialog(m_win, 						    "Are You Sure...\n"						    + "Click Yes To Accept"						    + " The Neural Network" 						    + "\n Click No To Return",						    "Accept Neural Network", 						    JOptionPane.YES_NO_OPTION);	    	    if (well == 0) {	      m_win.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);	      m_accepted = true;	      blocker(false);	    }	    else {	      m_win.setDefaultCloseOperation(JFrame.DO_NOTHING_ON_CLOSE);	    }	    m_stopIt = k;	  }	});            m_win.getContentPane().setLayout(new BorderLayout());      m_win.setTitle("Neural Network");      m_nodePanel = new NodePanel();      JScrollPane sp = new JScrollPane(m_nodePanel,				       JScrollPane.VERTICAL_SCROLLBAR_ALWAYS, 				       JScrollPane.HORIZONTAL_SCROLLBAR_NEVER);      m_controlPanel = new ControlPanel();                 m_win.getContentPane().add(sp, BorderLayout.CENTER);      m_win.getContentPane().add(m_controlPanel, BorderLayout.SOUTH);      m_win.setSize(640, 480);      m_win.show();    }       //This sets up the initial state of the gui    if (m_gui) {      blocker(true);      m_controlPanel.m_changeEpochs.setEnabled(false);      m_controlPanel.m_changeLearning.setEnabled(false);      m_controlPanel.m_changeMomentum.setEnabled(false);    }         //For silly situations in which the network gets accepted before training    //commenses    if (m_numeric) {      setEndsToLinear();    }    if (m_accepted) {      m_win.dispose();      m_controlPanel = null;      m_nodePanel = null;      m_instances = new Instances(m_instances, 0);      return;    }    //connections done.    double right = 0;    double driftOff = 0;    double lastRight = Double.POSITIVE_INFINITY;    double tempRate;    double totalWeight = 0;    double totalValWeight = 0;    double origRate = m_learningRate; //only used for when reset        //ensure that at least 1 instance is trained through.    if (numInVal == m_instances.numInstances()) {      numInVal--;    }    if (numInVal < 0) {      numInVal = 0;    }    for (int noa = numInVal; noa < m_instances.numInstances(); noa++) {      if (!m_instances.instance(noa).classIsMissing()) {	totalWeight += m_instances.instance(noa).weight();      }    }    if (m_valSize != 0) {      for (int noa = 0; noa < valSet.numInstances(); noa++) {	if (!valSet.instance(noa).classIsMissing()) {	  totalValWeight += valSet.instance(noa).weight();	}      }    }    m_stopped = false;         for (int noa = 1; noa < m_numEpochs + 1; noa++) {      right = 0;      for (int nob = numInVal; nob < m_instances.numInstances(); nob++) {	m_currentInstance = m_instances.instance(nob);		if (!m_currentInstance.classIsMissing()) {	   	  //this is where the network updating (and training occurs, for the	  //training set	  resetNetwork();	  calculateOutputs();	  tempRate = m_learningRate * m_currentInstance.weight();  	  if (m_decay) {	    tempRate /= noa;	  }	  right += (calculateErrors() / m_instances.numClasses()) *	    m_currentInstance.weight();	  updateNetworkWeights(tempRate, m_momentum);	  	}	      }      right /= totalWeight;      if (Double.isInfinite(right) || Double.isNaN(right)) {	if (!m_reset) {	  m_instances = null;	  throw new Exception("Network cannot train. Try restarting with a" +			      " smaller learning rate.");	}	else {	  //reset the network	  m_learningRate /= 2;	  buildClassifier(i);	  m_learningRate = origRate;	  m_instances = new Instances(m_instances, 0);	  	  return;	}      }      ////////////////////////do validation testing if applicable      if (m_valSize != 0) {	right = 0;	for (int nob = 0; nob < valSet.numInstances(); nob++) {	  m_currentInstance = valSet.instance(nob);	  if (!m_currentInstance.classIsMissing()) {	    //this is where the network updating occurs, for the validation set	    resetNetwork();	    calculateOutputs();	    right += (calculateErrors() / valSet.numClasses()) 	      * m_currentInstance.weight();	    //note 'right' could be calculated here just using	    //the calculate output values. This would be faster.	    //be less modular	  }	  	}		if (right < lastRight) {	  driftOff = 0;	}	else {	  driftOff++;	}	lastRight = right;	if (driftOff > m_driftThreshold || noa + 1 >= m_numEpochs) {	  m_accepted = true;	}	right /= totalValWeight;      }      m_epoch = noa;      m_error = right;      //shows what the neuralnet is upto if a gui exists.       updateDisplay();      //This junction controls what state the gui is in at the end of each      //epoch, Such as if it is paused, if it is resumable etc...      if (m_gui) {	while ((m_stopIt || (m_epoch >= m_numEpochs && m_valSize == 0)) && 		!m_accepted) {	  m_stopIt = true;	  m_stopped = true;	  if (m_epoch >= m_numEpochs && m_valSize == 0) {	    	    m_controlPanel.m_startStop.setEnabled(false);	  }	  else {	    m_controlPanel.m_startStop.setEnabled(true);	  }	  m_controlPanel.m_startStop.setText("Start");

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -