📄 multilayerperceptron.java
字号:
selection(tmp, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK , true); return; } } for (int noa = 0; noa < m_numClasses; noa++) { if (m_outputs[noa].onUnit(g, x, y, w, h)) { tmp.addElement(m_outputs[noa]); selection(tmp, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK , true); return; } } for (int noa = 0; noa < m_neuralNodes.length; noa++) { if (m_neuralNodes[noa].onUnit(g, x, y, w, h)) { tmp.addElement(m_neuralNodes[noa]); selection(tmp, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK , true); return; } } NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random, m_sigmoidUnit); m_nextId++; temp.setX((double)e.getX() / w); temp.setY((double)e.getY() / h); tmp.addElement(temp); addNode(temp); selection(tmp, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK , true); } else { //then right click Graphics g = NodePanel.this.getGraphics(); int x = e.getX(); int y = e.getY(); int w = NodePanel.this.getWidth(); int h = NodePanel.this.getHeight(); FastVector tmp = new FastVector(4); for (int noa = 0; noa < m_numAttributes; noa++) { if (m_inputs[noa].onUnit(g, x, y, w, h)) { tmp.addElement(m_inputs[noa]); selection(tmp, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK , false); return; } } for (int noa = 0; noa < m_numClasses; noa++) { if (m_outputs[noa].onUnit(g, x, y, w, h)) { tmp.addElement(m_outputs[noa]); selection(tmp, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK , false); return; } } for (int noa = 0; noa < m_neuralNodes.length; noa++) { if (m_neuralNodes[noa].onUnit(g, x, y, w, h)) { tmp.addElement(m_neuralNodes[noa]); selection(tmp, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK , false); return; } } selection(null, (e.getModifiers() & MouseEvent.CTRL_MASK) == MouseEvent.CTRL_MASK , false); } } }); } /** * This function gets called when the user has clicked something * It will amend the current selection or connect the current 
selection * to the new selection. * Or if nothing was selected and the right button was used it will * delete the node. * @param v The units that were selected. * @param ctrl True if ctrl was held down. * @param left True if it was the left mouse button. */ private void selection(FastVector v, boolean ctrl, boolean left) { if (v == null) { //then unselect all. m_selected.removeAllElements(); repaint(); return; } //then exclusive or the new selection with the current one. if ((ctrl || m_selected.size() == 0) && left) { boolean removed = false; for (int noa = 0; noa < v.size(); noa++) { removed = false; for (int nob = 0; nob < m_selected.size(); nob++) { if (v.elementAt(noa) == m_selected.elementAt(nob)) { //then remove that element m_selected.removeElementAt(nob); removed = true; break; } } if (!removed) { m_selected.addElement(v.elementAt(noa)); } } repaint(); return; } if (left) { //then connect the current selection to the new one. for (int noa = 0; noa < m_selected.size(); noa++) { for (int nob = 0; nob < v.size(); nob++) { NeuralConnection .connect((NeuralConnection)m_selected.elementAt(noa) , (NeuralConnection)v.elementAt(nob)); } } } else if (m_selected.size() > 0) { //then disconnect the current selection from the new one. for (int noa = 0; noa < m_selected.size(); noa++) { for (int nob = 0; nob < v.size(); nob++) { NeuralConnection .disconnect((NeuralConnection)m_selected.elementAt(noa) , (NeuralConnection)v.elementAt(nob)); NeuralConnection .disconnect((NeuralConnection)v.elementAt(nob) , (NeuralConnection)m_selected.elementAt(noa)); } } } else { //then remove the selected node. (it was right clicked while //no other units were selected for (int noa = 0; noa < v.size(); noa++) { ((NeuralConnection)v.elementAt(noa)).removeAllInputs(); ((NeuralConnection)v.elementAt(noa)).removeAllOutputs(); removeNode((NeuralConnection)v.elementAt(noa)); } } repaint(); } /** * This will paint the nodes ontot the panel. * @param g The graphics context. 
*/ public void paintComponent(Graphics g) { super.paintComponent(g); int x = getWidth(); int y = getHeight(); if (25 * m_numAttributes > 25 * m_numClasses && 25 * m_numAttributes > y) { setSize(x, 25 * m_numAttributes); } else if (25 * m_numClasses > y) { setSize(x, 25 * m_numClasses); } else { setSize(x, y); } y = getHeight(); for (int noa = 0; noa < m_numAttributes; noa++) { m_inputs[noa].drawInputLines(g, x, y); } for (int noa = 0; noa < m_numClasses; noa++) { m_outputs[noa].drawInputLines(g, x, y); m_outputs[noa].drawOutputLines(g, x, y); } for (int noa = 0; noa < m_neuralNodes.length; noa++) { m_neuralNodes[noa].drawInputLines(g, x, y); } for (int noa = 0; noa < m_numAttributes; noa++) { m_inputs[noa].drawNode(g, x, y); } for (int noa = 0; noa < m_numClasses; noa++) { m_outputs[noa].drawNode(g, x, y); } for (int noa = 0; noa < m_neuralNodes.length; noa++) { m_neuralNodes[noa].drawNode(g, x, y); } for (int noa = 0; noa < m_selected.size(); noa++) { ((NeuralConnection)m_selected.elementAt(noa)).drawHighlight(g, x, y); } } } /** * This provides the basic controls for working with the neuralnetwork * @author Malcolm Ware (mfw4@cs.waikato.ac.nz) * @version $Revision: 1.4 $ */ class ControlPanel extends JPanel { /** for serialization */ static final long serialVersionUID = 7393543302294142271L; /** The start stop button. */ public JButton m_startStop; /** The button to accept the network (even if it hasn't done all epochs. */ public JButton m_acceptButton; /** A label to state the number of epochs processed so far. */ public JPanel m_epochsLabel; /** A label to state the total number of epochs to be processed. */ public JLabel m_totalEpochsLabel; /** A text field to allow the changing of the total number of epochs. */ public JTextField m_changeEpochs; /** A label to state the learning rate. */ public JLabel m_learningLabel; /** A label to state the momentum. */ public JLabel m_momentumLabel; /** A text field to allow the changing of the learning rate. 
*/ public JTextField m_changeLearning; /** A text field to allow the changing of the momentum. */ public JTextField m_changeMomentum; /** A label to state roughly the accuracy of the network.(because the accuracy is calculated per epoch, but the network is changing throughout each epoch train). */ public JPanel m_errorLabel; /** The constructor. */ public ControlPanel() { setBorder(BorderFactory.createTitledBorder("Controls")); m_totalEpochsLabel = new JLabel("Num Of Epochs "); m_epochsLabel = new JPanel(){ public void paintComponent(Graphics g) { super.paintComponent(g); g.setColor(m_controlPanel.m_totalEpochsLabel.getForeground()); g.drawString("Epoch " + m_epoch, 0, 10); } }; m_epochsLabel.setFont(m_totalEpochsLabel.getFont()); m_changeEpochs = new JTextField(); m_changeEpochs.setText("" + m_numEpochs); m_errorLabel = new JPanel(){ public void paintComponent(Graphics g) { super.paintComponent(g); g.setColor(m_controlPanel.m_totalEpochsLabel.getForeground()); if (m_valSize == 0) { g.drawString("Error per Epoch = " + Utils.doubleToString(m_error, 7), 0, 10); } else { g.drawString("Validation Error per Epoch = " + Utils.doubleToString(m_error, 7), 0, 10); } } }; m_errorLabel.setFont(m_epochsLabel.getFont()); m_learningLabel = new JLabel("Learning Rate = "); m_momentumLabel = new JLabel("Momentum = "); m_changeLearning = new JTextField(); m_changeMomentum = new JTextField(); m_changeLearning.setText("" + m_learningRate); m_changeMomentum.setText("" + m_momentum); setLayout(new BorderLayout(15, 10)); m_stopIt = true; m_accepted = false; m_startStop = new JButton("Start"); m_startStop.setActionCommand("Start"); m_acceptButton = new JButton("Accept"); m_acceptButton.setActionCommand("Accept"); JPanel buttons = new JPanel(); buttons.setLayout(new BoxLayout(buttons, BoxLayout.Y_AXIS)); buttons.add(m_startStop); buttons.add(m_acceptButton); add(buttons, BorderLayout.WEST); JPanel data = new JPanel(); data.setLayout(new BoxLayout(data, BoxLayout.Y_AXIS)); Box ab = new 
Box(BoxLayout.X_AXIS); ab.add(m_epochsLabel); data.add(ab); ab = new Box(BoxLayout.X_AXIS); Component b = Box.createGlue(); ab.add(m_totalEpochsLabel); ab.add(m_changeEpochs); m_changeEpochs.setMaximumSize(new Dimension(200, 20)); ab.add(b); data.add(ab); ab = new Box(BoxLayout.X_AXIS); ab.add(m_errorLabel); data.add(ab); add(data, BorderLayout.CENTER); data = new JPanel(); data.setLayout(new BoxLayout(data, BoxLayout.Y_AXIS)); ab = new Box(BoxLayout.X_AXIS); b = Box.createGlue(); ab.add(m_learningLabel); ab.add(m_changeLearning); m_changeLearning.setMaximumSize(new Dimension(200, 20)); ab.add(b); data.add(ab); ab = new Box(BoxLayout.X_AXIS); b = Box.createGlue(); ab.add(m_momentumLabel); ab.add(m_changeMomentum); m_changeMomentum.setMaximumSize(new Dimension(200, 20)); ab.add(b); data.add(ab); add(data, BorderLayout.EAST); m_startStop.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { if (e.getActionCommand().equals("Start")) { m_stopIt = false; m_startStop.setText("Stop"); m_startStop.setActionCommand("Stop"); int n = Integer.valueOf(m_changeEpochs.getText()).intValue(); m_numEpochs = n; m_changeEpochs.setText("" + m_numEpochs); double m=Double.valueOf(m_changeLearning.getText()). doubleValue(); setLearningRate(m); m_changeLearning.setText("" + m_learningRate); m = Double.valueOf(m_changeMomentum.getText()).doubleValue(); setMomentum(m); m_changeMomentum.setText("" + m_momentum); blocker(false); } else if (e.getActionCommand().equals("Stop")) { m_stopIt = true; m_startStop.setText("Start"); m_startStop.setActionCommand("Start"); } } }); m_acceptButton.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { m_accepted = true; blocker(false); } }); m_changeEpochs.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { int n = Integer.valueOf(m_changeEpochs.getText()).intValue(); if (n > 0) { m_numEpochs = n; blocker(false); } } }); } } /** The training instances. 
*/
  private Instances m_instances;

  /** The current instance running through the network. */
  private Instance m_currentInstance;

  /** A flag to say that the class attribute is numeric. */
  private boolean m_numeric;

  /** The ranges for all the attributes. */
  private double[] m_attributeRanges;

  /** The base values for all the attributes. */
  private double[] m_attributeBases;

  /** The output units (they only feed the errors; they do no calculations). */
  private NeuralEnd[] m_outputs;

  /** The input units (they only feed the inputs; they do no calculations). */
  private NeuralEnd[] m_inputs;

  /** All the nodes that actually comprise the logical neural net. */
  private NeuralConnection[] m_neuralNodes;

  /** The number of classes (the node panel draws this many output units). */
  private int m_numClasses = 0;

  /** The number of attributes; note that this count does not include the class. */
  private int m_numAttributes = 0;

  /** The panel the nodes are displayed on. */
  private NodePanel m_nodePanel;

  /** The control panel. */
  private ControlPanel m_controlPanel;

  /** The next id number available for default naming of newly added nodes. */
  private int m_nextId;

  /** A list of the units currently selected in the node panel. */
  private FastVector m_selected;

  /** A list of the graphers. */
  private FastVector m_graphers;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -