📄 trainer.java
字号:
package com.digiburo.demo2;import java.io.File;import java.io.IOException; import java.io.FileNotFoundException;import com.digiburo.backprop1.Pattern;import com.digiburo.backprop1.PatternList;/** * Train for demo2, XOR pattern. * Since there are only four patterns, I generate the training * patterns manually. * * @author G.S. Cole (gsc@acm.org) * @version $Id: Trainer.java,v 1.4 2002/02/03 20:31:41 gsc Exp $ *//* * Development Environment: * Linux 2.2.14-5.0 (Red Hat 6.2) * Java Developers Kit 1.3.1 * * Legalise: * Copyright (C) 2002 Digital Burro, INC. * * Maintenance History: * $Log: Trainer.java,v $ * Revision 1.4 2002/02/03 20:31:41 gsc * Format tweaks * * Revision 1.3 2002/02/02 20:53:53 gsc * More testing tweaks * * Revision 1.2 2002/02/01 06:14:07 gsc * Work In Progress * * Revision 1.1 2002/02/01 02:48:56 gsc * Initial Check In */public class Trainer { public static final double ONE = 0.9999999999; public static final double ZERO = 0.0000000001; private static final String NETWORK_FILENAME = "demo2.serial"; private BpDemo2 bp; private PatternList pl; /** * Create network */ public Trainer() { bp = new BpDemo2(2, 3, 1, 0.25, 0.5); } /** * Create training datum */ public int loadTraining() throws IOException, FileNotFoundException, ClassNotFoundException { pl = new PatternList(); double[] input = new double[2]; double[] output = new double[1]; input[0] = ZERO; input[1] = ZERO; output[0] = ZERO; pl.add(input, output); input = new double[2]; output = new double[1]; input[0] = ZERO; input[1] = ONE; output[0] = ONE; pl.add(input, output); input = new double[2]; output = new double[1]; input[0] = ONE; input[1] = ZERO; output[0] = ONE; pl.add(input, output); input = new double[2]; output = new double[1]; input[0] = ONE; input[1] = ONE; output[0] = ZERO; pl.add(input, output); return(pl.size()); } /** * Train the network on these patterns */ public void performTraining() { int counter = 0; int success = 0; do { success = 0; for (int ii = 0; ii < pl.size(); ii++) { Pattern pn = pl.get(ii); // System.out.println(pn); bp.runNetwork(pn); bp.trainNetwork(pn); double truth[] = pn.getOutput(); double results[] = bp.getOutputPattern(); boolean failed = false; for (int jj = 0; jj < results.length; jj++) { if (round1(results[jj]) != round2(truth[jj])) { failed = true; } } if (!failed) { ++success; } } if ((++counter % 100) == 0) { System.out.println(counter + " success:" + success + " needed:" + pl.size()); } } while (success < pl.size()); System.out.println("Training complete in " + counter + " cycles"); } /** * Map an answer from the network to a value suitable for truth comparison * @param candidate value from network * @return value for comparison w/truth */ private int round1(double candidate) { if (candidate > 0.95) { return(1); } else if (candidate < 0.05) { return(0); } return(-1); } /** * Map a truth value to a value suitable for comparison * @param candidate value from truth pattern * @return value for comparison w/truth */ private int round2(double candidate) { if (candidate > 0.5) { return(1); } return(0); } /** * Save this network for later use. * @param datum file to save as */ public void saveTraining(File datum) throws IOException, FileNotFoundException { bp.writer(datum); } /** * */ public static void main(String args[]) throws Exception { System.out.println("begin"); Trainer tr = new Trainer(); int population = tr.loadTraining(); System.out.println("PatternList loaded w/" + population + " patterns"); tr.performTraining(); tr.saveTraining(new File(NETWORK_FILENAME)); System.out.println("end"); }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -