⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 trainer.java

📁 BP算法JAVA源程序
💻 JAVA
字号:
package com.digiburo.demo1;import java.io.File;import java.io.IOException; import java.io.FileNotFoundException;import com.digiburo.backprop1.Pattern;import com.digiburo.backprop1.PatternList;/** * Train a backpropagation network for demo1. * * @author G.S. Cole (gsc@acm.org) * @version $Id: Trainer.java,v 1.5 2002/02/02 20:53:53 gsc Exp $ *//* * Development Environment: *   Linux 2.2.14-5.0 (Red Hat 6.2) *   Java Developers Kit 1.3.1 * * Legalise:   *   Copyright (C) 2002 Digital Burro, INC. * * Maintenance History: *   $Log: Trainer.java,v $ *   Revision 1.5  2002/02/02 20:53:53  gsc *   More testing tweaks * *   Revision 1.4  2002/02/02 08:27:27  gsc *   Work In Progress * *   Revision 1.3  2002/02/01 05:09:59  gsc *   Tweaks from Unit Testing * *   Revision 1.2  2002/02/01 02:48:08  gsc *   Work In Progress * *   Revision 1.1  2002/01/22 08:19:35  gsc *   Initial Check In */public class Trainer {    private static final String TRAIN_FILENAME = "demo1.trn";    private static final String NETWORK_FILENAME = "demo1.serial";    private BpDemo1 bp;    private boolean[] flags;    private PatternList pl;    /**     * Create network     */    public Trainer() {	//bp = new BpDemo1(2, 7, 1, 0.45, 0.9);	bp = new BpDemo1(2, 7, 1, 0.25, 0.9);    }    /**     * Load training datum     * @param datum training file     */    public int loadTraining(File datum) throws IOException, FileNotFoundException, ClassNotFoundException {	pl = new PatternList();	pl.reader(datum);	flags = new boolean[pl.size()];	return(pl.size());    }    /**     * Train the network on these patterns     */    public void performTraining() {	int counter = 0;	int success = 0;	int max_success = 0;		do {	    if (success > max_success) {		max_success = success;	    }	    success = 0;	    	    for (int ii = 0; ii < pl.size(); ii++) {		Pattern pn = pl.get(ii);		bp.runNetwork(pn);		bp.trainNetwork(pn);				double[] truth = pn.getOutput();				double[] results = bp.getOutputPattern();				boolean failed = false;		for (int jj = 0; jj < results.length; jj++) {//  		    System.out.print(ii + " results:" + round1(results[jj]) + " truth:" + truth[jj]);//  		    if (round1(results[jj]) == round2(truth[jj])) {//  			System.out.println(" true");//  		    } else {//  			System.out.println(" false");//  		    }		    if (round1(results[jj]) == round2(truth[jj])) {			flags[ii] = true;		    } else {			flags[ii] = false;			failed = true;			break;		    }		}				if (!failed) {		    ++success;		} 	    }	    	    if ((++counter % 10000) == 0) {		System.out.println(counter + " success:" + success + " needed:" + pl.size() + " best run:" + max_success);		for (int jj = 0; jj < flags.length; jj++) {		    if (flags[jj] == false) {			System.out.print(jj + " ");		    }		}		System.out.println();		max_success = 0;	    }	} while (success < pl.size());		System.out.println("Training complete in " + counter + " cycles");    }    /**     * Map an answer from the network to a value suitable for truth comparison     * @param candidate value from network     * @return value for comparison w/truth     */    private int round1(double candidate) {	if (candidate > 0.8) {	    return(1);	} else if (candidate < 0.2) {	    return(0);	}	        return(-1);    }    /**     * Map a truth value to a value suitable for comparison     * @param candidate value from truth pattern     * @return value for comparison w/truth     */    private int round2(double candidate) {	if (candidate > 0.5) {	    return(1);	}	return(0);    }    /**     * Save this network for later use.     * @param datum file to save as     */    public void saveTraining(File datum) throws IOException, FileNotFoundException {	bp.writer(datum);    }    /**     * Driver.     */    public static void main(String args[]) throws Exception {	System.out.println("begin");	int population = 0;	Trainer tr = new Trainer();	if (args.length == 0) {	    population = tr.loadTraining(new File(TRAIN_FILENAME));	} else {	    population = tr.loadTraining(new File(args[0]));	}	System.out.println("PatternList loaded w/" + population + " patterns");	tr.performTraining();	tr.saveTraining(new File(NETWORK_FILENAME));	System.out.println("end");    }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -