📄 trainer.java
字号:
package com.digiburo.demo1;import java.io.File;import java.io.IOException; import java.io.FileNotFoundException;import com.digiburo.backprop1.Pattern;import com.digiburo.backprop1.PatternList;/** * Train a backpropagation network for demo1. * * @author G.S. Cole (gsc@acm.org) * @version $Id: Trainer.java,v 1.5 2002/02/02 20:53:53 gsc Exp $ *//* * Development Environment: * Linux 2.2.14-5.0 (Red Hat 6.2) * Java Developers Kit 1.3.1 * * Legalise: * Copyright (C) 2002 Digital Burro, INC. * * Maintenance History: * $Log: Trainer.java,v $ * Revision 1.5 2002/02/02 20:53:53 gsc * More testing tweaks * * Revision 1.4 2002/02/02 08:27:27 gsc * Work In Progress * * Revision 1.3 2002/02/01 05:09:59 gsc * Tweaks from Unit Testing * * Revision 1.2 2002/02/01 02:48:08 gsc * Work In Progress * * Revision 1.1 2002/01/22 08:19:35 gsc * Initial Check In */public class Trainer { private static final String TRAIN_FILENAME = "demo1.trn"; private static final String NETWORK_FILENAME = "demo1.serial"; private BpDemo1 bp; private boolean[] flags; private PatternList pl; /** * Create network */ public Trainer() { //bp = new BpDemo1(2, 7, 1, 0.45, 0.9); bp = new BpDemo1(2, 7, 1, 0.25, 0.9); } /** * Load training datum * @param datum training file */ public int loadTraining(File datum) throws IOException, FileNotFoundException, ClassNotFoundException { pl = new PatternList(); pl.reader(datum); flags = new boolean[pl.size()]; return(pl.size()); } /** * Train the network on these patterns */ public void performTraining() { int counter = 0; int success = 0; int max_success = 0; do { if (success > max_success) { max_success = success; } success = 0; for (int ii = 0; ii < pl.size(); ii++) { Pattern pn = pl.get(ii); bp.runNetwork(pn); bp.trainNetwork(pn); double[] truth = pn.getOutput(); double[] results = bp.getOutputPattern(); boolean failed = false; for (int jj = 0; jj < results.length; jj++) {// System.out.print(ii + " results:" + round1(results[jj]) + " truth:" + truth[jj]);// if (round1(results[jj]) == round2(truth[jj])) {// System.out.println(" true");// } else {// System.out.println(" false");// } if (round1(results[jj]) == round2(truth[jj])) { flags[ii] = true; } else { flags[ii] = false; failed = true; break; } } if (!failed) { ++success; } } if ((++counter % 10000) == 0) { System.out.println(counter + " success:" + success + " needed:" + pl.size() + " best run:" + max_success); for (int jj = 0; jj < flags.length; jj++) { if (flags[jj] == false) { System.out.print(jj + " "); } } System.out.println(); max_success = 0; } } while (success < pl.size()); System.out.println("Training complete in " + counter + " cycles"); } /** * Map an answer from the network to a value suitable for truth comparison * @param candidate value from network * @return value for comparison w/truth */ private int round1(double candidate) { if (candidate > 0.8) { return(1); } else if (candidate < 0.2) { return(0); } return(-1); } /** * Map a truth value to a value suitable for comparison * @param candidate value from truth pattern * @return value for comparison w/truth */ private int round2(double candidate) { if (candidate > 0.5) { return(1); } return(0); } /** * Save this network for later use. * @param datum file to save as */ public void saveTraining(File datum) throws IOException, FileNotFoundException { bp.writer(datum); } /** * Driver. */ public static void main(String args[]) throws Exception { System.out.println("begin"); int population = 0; Trainer tr = new Trainer(); if (args.length == 0) { population = tr.loadTraining(new File(TRAIN_FILENAME)); } else { population = tr.loadTraining(new File(args[0])); } System.out.println("PatternList loaded w/" + population + " patterns"); tr.performTraining(); tr.saveTraining(new File(NETWORK_FILENAME)); System.out.println("end"); }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -