📄 network.java
字号:
package com.digiburo.backprop1;import java.io.Serializable;/** * Memento class to capture state of network. * * @author G.S. Cole (gsc@acm.org) * @version $Id: Network.java,v 1.2 2002/02/01 02:49:20 gsc Exp $ *//* * Development Environment: * Linux 2.2.14-5.0 (Red Hat 6.2) * Java Developers Kit 1.3.1 * * Legalise: * Copyright (C) 2002 Digital Burro, INC. * * Maintenance History: * $Log: Network.java,v $ * Revision 1.2 2002/02/01 02:49:20 gsc * Work In Progress * * Revision 1.1 2002/01/22 06:34:58 gsc * Work In Progress */public class Network implements Serializable { /** * Index of first input node (in Nodez) */ private final int first_input_node_ndx; /** * Index of first middle node (in Nodez) */ private final int first_middle_node_ndx; /** * Index of first output node (in Nodez) */ private final int first_output_node_ndx; /** * learning rate */ private double learning_rate; /** * momentum term */ private double momentum; /** * Arcs connect nodes */ private final Arc[] arcz; /** * Nodes contain values */ private final AbstractNode[] nodez; /** * Constructor for standard 3 layer backpropagation network. * * @param input_population input node count * @param middle_population middle node count * @param output_population output node count * @param learning_rate * @param momentum */ public Network(int input_population, int middle_population, int output_population, double learning_rate, double momentum) { this.momentum = momentum; this.learning_rate = learning_rate; // // create nodez // int total_node_population = input_population + middle_population + output_population; nodez = new AbstractNode[total_node_population]; int node_ndx = 0; // // define input layer // first_input_node_ndx = node_ndx; for (int ii = 0; ii < input_population; ii++) { nodez[node_ndx++] = new InputNode(); } // // define middle layer // first_middle_node_ndx = node_ndx; for (int ii = 0; ii < middle_population; ii++) { nodez[node_ndx++] = new MiddleNode(learning_rate, momentum); } // // define output layer // first_output_node_ndx = node_ndx; for (int ii = 0; ii < output_population; ii++) { nodez[node_ndx++] = new OutputNode(learning_rate, momentum); } // // create arcs // int total_arc_population = input_population * middle_population; total_arc_population += middle_population * output_population; arcz = new Arc[total_arc_population]; for (int ii = 0; ii < total_arc_population; ii++) { arcz[ii] = new Arc(); } // // connect nodes and arcs // int arc_ndx = 0; for (int ii = 0; ii < input_population; ii++) { for (int jj = 0; jj < middle_population; jj++) { nodez[ii].connect(nodez[first_middle_node_ndx+jj], arcz[arc_ndx++]); } } for (int ii = 0; ii < middle_population; ii++) { for (int jj = 0; jj < output_population; jj++) { nodez[first_middle_node_ndx+ii].connect(nodez[first_output_node_ndx+jj], arcz[arc_ndx++]); } } } /** * Return index of first input node (in Nodez) * @return index of first input node (in Nodez) */ public int getFirstInputNodeIndex() { return(first_input_node_ndx); } /** * Return index of first middle node (in Nodez) * @return index of first middle node (in Nodez) */ public int getFirstMiddleNodeIndex() { return(first_middle_node_ndx); } /** * Return index of first output node (in Nodez) * @return index of first output node (in Nodez) */ public int getFirstOutputNodeIndex() { return(first_output_node_ndx); } /** * Return network learning rate * @return network learning rate */ public double getLearningRate() { return(learning_rate); } /** * Return network momentum * @return network momentum */ public double getMomentum() { return(momentum); } /** * Return all arcs * @return all arcs */ public Arc[] getArcs() { return(arcz); } /** * Return all nodes * @return all nodes */ public AbstractNode[] getNodes() { return(nodez); } /** * Write a report about node connections */ public void dumpState() { System.out.println("Total Nodes:" + nodez.length); System.out.println("Input Node Index:" + first_input_node_ndx); System.out.println("Middle Node Index:" + first_middle_node_ndx); System.out.println("Output Node Index:" + first_output_node_ndx); for (int ii = 0; ii < nodez.length; ii++) { System.out.println(ii + " " + nodez[ii]); } } /** * Write a report about arc connections */ public void dumpConnection() { for (int ii = 0; ii < arcz.length; ii++) { System.out.println(ii + " " + arcz[ii]); } } /** * Driver */ public static void main(String[] args) { System.out.println("begin"); Network nn = new Network(2, 3, 1, 1.23, 3.21); nn.dumpState(); nn.dumpConnection(); System.out.println("end"); }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -