
📄 ExtendableLearner.java

📁 Neural network source code written in pure Java
💻 JAVA
📖 Page 1 of 2
/*
 * ExtendableLearner.java
 *
 * Created on September 14, 2004, 8:30 AM
 */

package org.joone.engine;

import java.util.*;

import org.joone.engine.extenders.*;

/**
 * Learners that extend this class are forced to implement certain functions, a
 * so-called skeleton. The good thing is, because learners extend this class,
 * certain plug-ins can be added, for example plug-ins that change the objective
 * function or the delta-update rule. Still, learners that do not fit into this
 * skeleton have the opportunity to implement Learner directly (or extend
 * AbstractLearner), but they won't be able to use the extra plug-ins (unless
 * the functionality is built into the learner by the programmer).
 *
 * Basically, this class is the BasicLearner, but by adding extenders it can
 * provide totally different learning algorithms.
 *
 * @author Boris Jansen
 */
public class ExtendableLearner extends AbstractLearner {

    /** The list with delta rule extenders, extenders that change the
     * delta w, e.g. momentum term, etc. */
    protected List theDeltaRuleExtenders = new ArrayList();

    /** The list with gradient extenders, extenders that change the gradient. */
    protected List theGradientExtenders = new ArrayList();

    /** The update weight extender, that is, the way to update
     * the weights: online, batch mode, etc. */
    protected UpdateWeightExtender theUpdateWeightExtender;

    /** Creates a new instance of ExtendableLearner */
    public ExtendableLearner() {
    }

    public final void requestBiasUpdate(double[] currentGradientOuts) {
        double myDelta;
        preBiasUpdate(currentGradientOuts);

        for(int x = 0; x < getLayer().getRows(); x++) {
            myDelta = getDelta(currentGradientOuts, x);
            updateBias(x, myDelta);
        }

        postBiasUpdate(currentGradientOuts);
    }

    public final void requestWeightUpdate(double[] currentPattern, double[] currentInps) {
        double myDelta;
        preWeightUpdate(currentPattern, currentInps);
        boolean[][] isEnabled = getSynapse().getWeights().getEnabled();
        boolean[][] isFixed = getSynapse().getWeights().getFixed();
        for(int x = 0; x < getSynapse().getInputDimension(); x++) {
            for(int y = 0; y < getSynapse().getOutputDimension(); y++) {
                if (!isFixed[x][y] && isEnabled[x][y]) {
                    myDelta = getDelta(currentInps, x, currentPattern, y);
                    updateWeight(x, y, myDelta);
                }
            }
        }

        postWeightUpdate(currentPattern, currentInps);
    }

    /**
     * Updates a bias with the calculated delta value.
     *
     * @param j the index of the bias to update.
     * @param aDelta the calculated delta value.
     */
    protected void updateBias(int j, double aDelta) {
        theUpdateWeightExtender.updateBias(j, aDelta);
    }

    /**
     * Updates a weight with the calculated delta value.
     *
     * @param j the input index of the weight to update.
     * @param k the output index of the weight to update.
     * @param aDelta the calculated delta value.
     */
    protected void updateWeight(int j, int k, double aDelta) {
        theUpdateWeightExtender.updateWeight(j, k, aDelta);
    }

    /**
     * Computes the delta value for a bias.
     *
     * @param currentGradientOuts the back propagated gradients.
     * @param j the index of the bias.
     */
    protected double getDelta(double[] currentGradientOuts, int j) {
        // if this method is overridden, make sure that no delta extenders can be set
        // by throwing an exception from setDeltaExtender()

        // more than one delta extender might be set; this variable is used to pass on
        // the delta value calculated by the previous delta extender to the next one
        double myDelta = getDefaultDelta(currentGradientOuts, j);

        for(int i = 0; i < theDeltaRuleExtenders.size(); i++) {
            if(((DeltaRuleExtender)theDeltaRuleExtenders.get(i)).isEnabled()) {
                myDelta = ((DeltaRuleExtender)theDeltaRuleExtenders.get(i)).
                        getDelta(currentGradientOuts, j, myDelta);
            }
        }
        return myDelta;
    }

    /**
     * Gets the default (normal calculation of) delta.
     *
     * @param currentGradientOuts the back propagated gradients.
     * @param j the index of the bias.
     */
    public double getDefaultDelta(double[] currentGradientOuts, int j) {
        return getLearningRate(j) * getGradientBias(currentGradientOuts, j);
    }

    /**
     * Computes the delta value for a weight.
     *
     * @param currentInps the forwarded input.
     * @param j the input index of the weight.
     * @param currentPattern the back propagated gradients.
     * @param k the output index of the weight.
     */
    protected double getDelta(double[] currentInps, int j, double[] currentPattern, int k) {
        // if this method is overridden, make sure that no delta extenders can be set
        // by throwing an exception from setDeltaExtender()

        // more than one delta extender might be set; this variable is used to pass on
        // the delta value calculated by the previous delta extender to the next one
        double myDelta = getDefaultDelta(currentInps, j, currentPattern, k);

        for(int i = 0; i < theDeltaRuleExtenders.size(); i++) {
            if(((DeltaRuleExtender)theDeltaRuleExtenders.get(i)).isEnabled()) {
                myDelta = ((DeltaRuleExtender)theDeltaRuleExtenders.get(i)).
                        getDelta(currentInps, j, currentPattern, k, myDelta);
            }
        }
        return myDelta;
    }

    /**
     * Gets the default (normal calculation of) delta.
     *
     * @param currentInps the forwarded input.
     * @param j the input index of the weight.
     * @param currentPattern the back propagated gradients.
     * @param k the output index of the weight.
     */
    public double getDefaultDelta(double[] currentInps, int j, double[] currentPattern, int k) {
        return getLearningRate(j, k) * getGradientWeight(currentInps, j, currentPattern, k);
    }

    /**
     * Gets the learning rate.
     *
     * @param j the index of the bias (for which we should get the learning rate).
     * @return the learning rate for a bias.
     */
    protected double getLearningRate(int j) {
        // in future we could add learning rate extenders...

        return getMonitor().getLearningRate();
    }

    /**
     * Gets the learning rate.
     *
     * @param j the input index of the weight (for which we should get the learning rate).
     * @param k the output index of the weight (for which we should get the learning rate).
     * @return the learning rate for a weight.
     */
    protected double getLearningRate(int j, int k) {
        // in future we could add learning rate extenders...

        return getMonitor().getLearningRate();
    }

    /**
     * Gets the gradient for biases.
     *
     * @param currentGradientOuts the back propagated gradients.
     * @param j the index of the bias.
     * @return the gradient for bias b_j.
     */
    public double getGradientBias(double[] currentGradientOuts, int j) {
        double myGradient = getDefaultGradientBias(currentGradientOuts, j);

        for(int i = 0; i < theGradientExtenders.size(); i++) {
            if(((GradientExtender)theGradientExtenders.get(i)).isEnabled()) {
                myGradient = ((GradientExtender)theGradientExtenders.get(i)).
                        getGradientBias(currentGradientOuts, j, myGradient);
            }
        }
        return myGradient;
    }

    /**
     * Gets the default (normal calculation of the) gradient for biases.
     *
     * @param currentGradientOuts the back propagated gradients.
     * @param j the index of the bias.
     * @return the gradient for bias b_j.
     */
    public double getDefaultGradientBias(double[] currentGradientOuts, int j) {
        return currentGradientOuts[j];
    }

    /**
     * Gets the gradient for weights.
     *
     * @param currentInps the forwarded input.
     * @param j the input index of the weight.
     * @param currentPattern the back propagated gradients.
     * @param k the output index of the weight.
     *
     * @return the gradient for the weight w_j_k
     */
    public double getGradientWeight(double[] currentInps, int j, double[] currentPattern, int k) {
        double myGradient = getDefaultGradientWeight(currentInps, j, currentPattern, k);

        for(int i = 0; i < theGradientExtenders.size(); i++) {
            if(((GradientExtender)theGradientExtenders.get(i)).isEnabled()) {
                myGradient = ((GradientExtender)theGradientExtenders.get(i)).
                        getGradientWeight(currentInps, j, currentPattern, k, myGradient);
            }
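The heart of this class is the chained-extender pattern visible in getDelta, getGradientBias and getGradientWeight: a default value is computed first and is then passed through every enabled extender in turn, each one receiving the value produced by its predecessor. The sketch below illustrates only that chaining idea in a self-contained form; the names in it (ChainingDemo, SimpleDeltaExtender, MomentumExtender) are hypothetical and are not the Joone org.joone.engine.extenders API.

import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical, stripped-down illustration of ExtendableLearner's
 * delta-extender chaining. The types below are NOT part of Joone;
 * they only show how a default delta is threaded through a list of
 * enabled extenders, each seeing the value produced by the previous one.
 */
public class ChainingDemo {

    /** Minimal stand-in for a delta-rule extender. */
    interface SimpleDeltaExtender {
        boolean isEnabled();
        /** Receives the delta computed so far and returns a modified delta. */
        double getDelta(int j, double previousDelta);
    }

    /** Example extender: adds a momentum term based on the last delta it produced. */
    static class MomentumExtender implements SimpleDeltaExtender {
        private final double momentum;
        private final double[] lastDelta; // one slot per bias/weight index

        MomentumExtender(double momentum, int size) {
            this.momentum = momentum;
            this.lastDelta = new double[size];
        }

        public boolean isEnabled() {
            return true;
        }

        public double getDelta(int j, double previousDelta) {
            double myDelta = previousDelta + momentum * lastDelta[j];
            lastDelta[j] = myDelta; // remember for the next pass
            return myDelta;
        }
    }

    private final List<SimpleDeltaExtender> extenders = new ArrayList<>();
    private final double learningRate;

    ChainingDemo(double learningRate) {
        this.learningRate = learningRate;
    }

    void addExtender(SimpleDeltaExtender e) {
        extenders.add(e);
    }

    /** Mirrors ExtendableLearner.getDelta: default delta first, then the chain. */
    double getDelta(double[] gradientOuts, int j) {
        double myDelta = learningRate * gradientOuts[j]; // the "default" delta
        for (SimpleDeltaExtender e : extenders) {
            if (e.isEnabled()) {
                myDelta = e.getDelta(j, myDelta); // each extender refines the previous value
            }
        }
        return myDelta;
    }

    public static void main(String[] args) {
        ChainingDemo learner = new ChainingDemo(0.1);
        learner.addExtender(new MomentumExtender(0.9, 3));

        double[] gradients = {0.5, -0.2, 0.8};
        // First pass: no momentum history yet, so delta == learningRate * gradient.
        // Second pass: the momentum extender adds 0.9 * the previous delta.
        for (int pass = 0; pass < 2; pass++) {
            for (int j = 0; j < gradients.length; j++) {
                System.out.printf("pass %d, j=%d, delta=%.4f%n",
                        pass, j, learner.getDelta(gradients, j));
            }
        }
    }
}

This pass-the-previous-value convention is what lets a learner stack, say, a momentum term on top of the plain delta rule without ExtendableLearner knowing which extenders are installed; the real DeltaRuleExtender and GradientExtender calls additionally receive the input and gradient arrays, as the casts in the listing above show.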
