📄 neuralnetwork.java
字号:
} } repaint(); return; } if (left) { //then connect the current selection to the new one. for (int noa = 0; noa < m_selected.size(); noa++) { for (int nob = 0; nob < v.size(); nob++) { NeuralConnection .connect((NeuralConnection)m_selected.elementAt(noa) , (NeuralConnection)v.elementAt(nob)); } } } else if (m_selected.size() > 0) { //then disconnect the current selection from the new one. for (int noa = 0; noa < m_selected.size(); noa++) { for (int nob = 0; nob < v.size(); nob++) { NeuralConnection .disconnect((NeuralConnection)m_selected.elementAt(noa) , (NeuralConnection)v.elementAt(nob)); NeuralConnection .disconnect((NeuralConnection)v.elementAt(nob) , (NeuralConnection)m_selected.elementAt(noa)); } } } else { //then remove the selected node. (it was right clicked while //no other units were selected for (int noa = 0; noa < v.size(); noa++) { ((NeuralConnection)v.elementAt(noa)).removeAllInputs(); ((NeuralConnection)v.elementAt(noa)).removeAllOutputs(); removeNode((NeuralConnection)v.elementAt(noa)); } } repaint(); } /** * This will paint the nodes ontot the panel. * @param g The graphics context. 
*/ public void paintComponent(Graphics g) { super.paintComponent(g); int x = getWidth(); int y = getHeight(); if (25 * m_numAttributes > 25 * m_numClasses && 25 * m_numAttributes > y) { setSize(x, 25 * m_numAttributes); } else if (25 * m_numClasses > y) { setSize(x, 25 * m_numClasses); } else { setSize(x, y); } y = getHeight(); for (int noa = 0; noa < m_numAttributes; noa++) { m_inputs[noa].drawInputLines(g, x, y); } for (int noa = 0; noa < m_numClasses; noa++) { m_outputs[noa].drawInputLines(g, x, y); m_outputs[noa].drawOutputLines(g, x, y); } for (int noa = 0; noa < m_neuralNodes.length; noa++) { m_neuralNodes[noa].drawInputLines(g, x, y); } for (int noa = 0; noa < m_numAttributes; noa++) { m_inputs[noa].drawNode(g, x, y); } for (int noa = 0; noa < m_numClasses; noa++) { m_outputs[noa].drawNode(g, x, y); } for (int noa = 0; noa < m_neuralNodes.length; noa++) { m_neuralNodes[noa].drawNode(g, x, y); } for (int noa = 0; noa < m_selected.size(); noa++) { ((NeuralConnection)m_selected.elementAt(noa)).drawHighlight(g, x, y); } } } /** * This provides the basic controls for working with the neuralnetwork * @author Malcolm Ware (mfw4@cs.waikato.ac.nz) * @version $Revision: 1.4.2.2 $ */ class ControlPanel extends JPanel { /** The start stop button. */ public JButton m_startStop; /** The button to accept the network (even if it hasn't done all epochs. */ public JButton m_acceptButton; /** A label to state the number of epochs processed so far. */ public JPanel m_epochsLabel; /** A label to state the total number of epochs to be processed. */ public JLabel m_totalEpochsLabel; /** A text field to allow the changing of the total number of epochs. */ public JTextField m_changeEpochs; /** A label to state the learning rate. */ public JLabel m_learningLabel; /** A label to state the momentum. */ public JLabel m_momentumLabel; /** A text field to allow the changing of the learning rate. */ public JTextField m_changeLearning; /** A text field to allow the changing of the momentum. 
*/ public JTextField m_changeMomentum; /** A label to state roughly the accuracy of the network.(because the accuracy is calculated per epoch, but the network is changing throughout each epoch train). */ public JPanel m_errorLabel; /** The constructor. */ public ControlPanel() { setBorder(BorderFactory.createTitledBorder("Controls")); m_totalEpochsLabel = new JLabel("Num Of Epochs "); m_epochsLabel = new JPanel(){ public void paintComponent(Graphics g) { super.paintComponent(g); g.setColor(m_controlPanel.m_totalEpochsLabel.getForeground()); g.drawString("Epoch " + m_epoch, 0, 10); } }; m_epochsLabel.setFont(m_totalEpochsLabel.getFont()); m_changeEpochs = new JTextField(); m_changeEpochs.setText("" + m_numEpochs); m_errorLabel = new JPanel(){ public void paintComponent(Graphics g) { super.paintComponent(g); g.setColor(m_controlPanel.m_totalEpochsLabel.getForeground()); if (m_valSize == 0) { g.drawString("Error per Epoch = " + Utils.doubleToString(m_error, 7), 0, 10); } else { g.drawString("Validation Error per Epoch = " + Utils.doubleToString(m_error, 7), 0, 10); } } }; m_errorLabel.setFont(m_epochsLabel.getFont()); m_learningLabel = new JLabel("Learning Rate = "); m_momentumLabel = new JLabel("Momentum = "); m_changeLearning = new JTextField(); m_changeMomentum = new JTextField(); m_changeLearning.setText("" + m_learningRate); m_changeMomentum.setText("" + m_momentum); setLayout(new BorderLayout(15, 10)); m_stopIt = true; m_accepted = false; m_startStop = new JButton("Start"); m_startStop.setActionCommand("Start"); m_acceptButton = new JButton("Accept"); m_acceptButton.setActionCommand("Accept"); JPanel buttons = new JPanel(); buttons.setLayout(new BoxLayout(buttons, BoxLayout.Y_AXIS)); buttons.add(m_startStop); buttons.add(m_acceptButton); add(buttons, BorderLayout.WEST); JPanel data = new JPanel(); data.setLayout(new BoxLayout(data, BoxLayout.Y_AXIS)); Box ab = new Box(BoxLayout.X_AXIS); ab.add(m_epochsLabel); data.add(ab); ab = new Box(BoxLayout.X_AXIS); 
Component b = Box.createGlue(); ab.add(m_totalEpochsLabel); ab.add(m_changeEpochs); m_changeEpochs.setMaximumSize(new Dimension(200, 20)); ab.add(b); data.add(ab); ab = new Box(BoxLayout.X_AXIS); ab.add(m_errorLabel); data.add(ab); add(data, BorderLayout.CENTER); data = new JPanel(); data.setLayout(new BoxLayout(data, BoxLayout.Y_AXIS)); ab = new Box(BoxLayout.X_AXIS); b = Box.createGlue(); ab.add(m_learningLabel); ab.add(m_changeLearning); m_changeLearning.setMaximumSize(new Dimension(200, 20)); ab.add(b); data.add(ab); ab = new Box(BoxLayout.X_AXIS); b = Box.createGlue(); ab.add(m_momentumLabel); ab.add(m_changeMomentum); m_changeMomentum.setMaximumSize(new Dimension(200, 20)); ab.add(b); data.add(ab); add(data, BorderLayout.EAST); m_startStop.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { if (e.getActionCommand().equals("Start")) { m_stopIt = false; m_startStop.setText("Stop"); m_startStop.setActionCommand("Stop"); int n = Integer.valueOf(m_changeEpochs.getText()).intValue(); m_numEpochs = n; m_changeEpochs.setText("" + m_numEpochs); double m=Double.valueOf(m_changeLearning.getText()). doubleValue(); setLearningRate(m); m_changeLearning.setText("" + m_learningRate); m = Double.valueOf(m_changeMomentum.getText()).doubleValue(); setMomentum(m); m_changeMomentum.setText("" + m_momentum); blocker(false); } else if (e.getActionCommand().equals("Stop")) { m_stopIt = true; m_startStop.setText("Start"); m_startStop.setActionCommand("Start"); } } }); m_acceptButton.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { m_accepted = true; blocker(false); } }); m_changeEpochs.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { int n = Integer.valueOf(m_changeEpochs.getText()).intValue(); if (n > 0) { m_numEpochs = n; blocker(false); } } }); } } /** The training instances. */ private Instances m_instances; /** The current instance running through the network. 
*/
  private Instance m_currentInstance;

  /** A flag to say that the class attribute is numeric. */
  private boolean m_numeric;

  /** The ranges for all the attributes. */
  private double[] m_attributeRanges;

  /** The base values for all the attributes. */
  private double[] m_attributeBases;

  /** The output units. (They only feed the errors back; they do no calculations.) */
  private NeuralEnd[] m_outputs;

  /** The input units. (They only feed the inputs in; they do no calculations.) */
  private NeuralEnd[] m_inputs;

  /** All the nodes that actually comprise the logical neural net. */
  private NeuralConnection[] m_neuralNodes;

  /** The number of classes. */
  private int m_numClasses = 0;

  /** The number of attributes. Note that this count does not include the class. */
  private int m_numAttributes = 0;

  /** The panel the nodes are displayed on. */
  private NodePanel m_nodePanel;

  /** The control panel. */
  private ControlPanel m_controlPanel;

  /** The next id number available for default naming. */
  private int m_nextId;

  /** A list of the units currently selected. */
  private FastVector m_selected;

  /** A list of the graphers. */
  private FastVector m_graphers;

  /** The number of epochs to train through. */
  private int m_numEpochs;

  /** A flag to state whether the network should be running or stopped. */
  private boolean m_stopIt;

  /** A flag to state that the network has in fact stopped. */
  private boolean m_stopped;

  /** A flag to state that the network should be accepted the way it is. */
  private boolean m_accepted;

  /** The window for the network. */
  private JFrame m_win;

  /** A flag to tell the build classifier to automatically build a neural net. */
  private boolean m_autoBuild;

  /**
   * A flag to state that the gui for the network should be brought up, to
   * allow interaction while training.
   */
  private boolean m_gui;

  /**
   * How big the validation set should be. A value of 0 means no validation
   * set; the control panel then reports the plain error per epoch.
   */
  private int m_valSize;

  /** The number used to decide when to quit training during validation testing. */
  private int m_driftThreshold;

  /** The number used to seed the random number generator. */
  private long m_randomSeed;

  /** The actual random number generator. */
  private Random m_random;

  /** A flag to state that a nominal to binary filter should be used. */
  private boolean m_useNomToBin;

  /** The actual nominal to binary filter. */
  private NominalToBinaryFilter m_nominalToBinaryFilter;

  /** The string that defines the hidden layers. */
  private String m_hiddenLayers;

  /** This flag states that the user wants the input values normalized. */
  private boolean m_normalizeAttributes;

  /** This flag states that the user wants the learning rate to decay. */
  private boolean m_decay;

  /** The learning rate for the network. */
  private double m_learningRate;

  /** The momentum for the network. */
  private double m_momentum;

  /** The number of the epoch that the network just finished. */
  private int m_epoch;

  /** The error of the epoch that the network just finished. */
  private double m_error;

  /**
   * This flag states that the user wants the network to restart if it is
   * found to be generating infinity or NaN for the error value. The network
   * would be restarted with the current options, except that the learning
   * rate would be smaller than before (perhaps half of its current value).
   * This option is not available if the gui is chosen: with the gui open the
   * user can fix the network themselves, and resetting the network while the
   * gui is open is an architectural minefield.
   */
  private boolean m_reset;

  /**
   * This flag states that the user wants the class to be normalized while
   * the network is processing (the final answer will be in the original
   * range regardless). This option is only used when the class is numeric.
   */
  private boolean m_normalizeClass;

  /**
   * A sigmoid unit.
   */
  private SigmoidUnit m_sigmoidUnit;

  /**
   * A linear unit.
   */
  private LinearUnit m_linearUnit;

  /**
   * The constructor.
*/ public NeuralNetwork() { m_instances = null; m_currentInstance = null; m_controlPanel = null; m_nodePanel = null; m_epoch = 0; m_error = 0; m_outputs = new NeuralEnd[0]; m_inputs = new NeuralEnd[0]; m_numAttributes = 0; m_numClasses = 0; m_neuralNodes = new NeuralConnection[0]; m_selected = new FastVector(4); m_graphers = new FastVector(2); m_nextId = 0; m_stopIt = true; m_stopped = true; m_accepted = false; m_numeric = false; m_random = null; m_nominalToBinaryFilter = new NominalToBinaryFilter(); m_sigmoidUnit = new SigmoidUnit(); m_linearUnit = new LinearUnit(); //setting all the options to their defaults. To completely change these //defaults they will also need to be changed down the bottom in the //setoptions function (the text info in the accompanying functions should //also be changed to reflect the new defaults m_normalizeClass = true; m_normalizeAttributes = true; m_autoBuild = true; m_gui = false; m_useNomToBin = true; m_driftThreshold = 20; m_numEpochs = 500; m_valSize = 0; m_randomSeed = 0; m_hiddenLayers = "a"; m_learningRate = .3; m_momentum = .2;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -