📄 extendablelearner.java
字号:
} return myGradient; }

    /**
     * Gets the default (normal calculation of the) gradient for weights.
     *
     * @param currentInps the forwarded input.
     * @param j the input index of the weight.
     * @param currentPattern the back propagated gradients.
     * @param k the output index of the weight.
     *
     * @return the gradient for the weight w_j_k, i.e.
     *         {@code currentInps[j] * currentPattern[k]}.
     */
    public double getDefaultGradientWeight(double[] currentInps, int j, double[] currentPattern, int k) {
        return currentInps[j] * currentPattern[k];
    }

    /**
     * Gives learners and extenders a chance to do some pre-computing before the
     * biases are updated.
     * <p>
     * Order: the learner's own {@link #preBiasUpdateImpl(double[])} first, then
     * the enabled update weight extender, then the enabled delta rule extenders,
     * then the enabled gradient extenders (each list in insertion order).
     *
     * @param currentGradientOuts the back propagated gradients.
     */
    protected final void preBiasUpdate(double[] currentGradientOuts) {
        preBiasUpdateImpl(currentGradientOuts);

        // update weight extender...
        if (theUpdateWeightExtender != null && theUpdateWeightExtender.isEnabled()) {
            theUpdateWeightExtender.preBiasUpdate(currentGradientOuts);
        }

        // delta rule extenders...
        for (int i = 0; i < theDeltaRuleExtenders.size(); i++) {
            DeltaRuleExtender myExtender = (DeltaRuleExtender) theDeltaRuleExtenders.get(i);
            if (myExtender.isEnabled()) {
                myExtender.preBiasUpdate(currentGradientOuts);
            }
        }

        // gradient extenders...
        for (int i = 0; i < theGradientExtenders.size(); i++) {
            GradientExtender myExtender = (GradientExtender) theGradientExtenders.get(i);
            if (myExtender.isEnabled()) {
                myExtender.preBiasUpdate(currentGradientOuts);
            }
        }
    }

    /**
     * Gives learners a chance to do some pre-computing before the biases are
     * updated. The default implementation does nothing.
     *
     * @param currentGradientOuts the back propagated gradients.
     */
    protected void preBiasUpdateImpl(double[] currentGradientOuts) {
    }

    /**
     * Gives learners and extenders a chance to do some pre-computing before the
     * weights are updated.
     * <p>
     * Order: the learner's own
     * {@link #preWeightUpdateImpl(double[], double[])} first, then the enabled
     * update weight extender, then the enabled delta rule extenders, then the
     * enabled gradient extenders (each list in insertion order).
     *
     * @param currentPattern the back propagated gradients.
     * @param currentInps the forwarded input.
     */
    protected final void preWeightUpdate(double[] currentPattern, double[] currentInps) {
        preWeightUpdateImpl(currentPattern, currentInps);

        // Note: the extender hooks take the inputs first, the pattern second.

        // update weight extender...
        if (theUpdateWeightExtender != null && theUpdateWeightExtender.isEnabled()) {
            theUpdateWeightExtender.preWeightUpdate(currentInps, currentPattern);
        }

        // delta rule extenders...
        for (int i = 0; i < theDeltaRuleExtenders.size(); i++) {
            DeltaRuleExtender myExtender = (DeltaRuleExtender) theDeltaRuleExtenders.get(i);
            if (myExtender.isEnabled()) {
                myExtender.preWeightUpdate(currentInps, currentPattern);
            }
        }

        // gradient extenders...
        for (int i = 0; i < theGradientExtenders.size(); i++) {
            GradientExtender myExtender = (GradientExtender) theGradientExtenders.get(i);
            if (myExtender.isEnabled()) {
                myExtender.preWeightUpdate(currentInps, currentPattern);
            }
        }
    }

    /**
     * Gives learners a chance to do some pre-computing before the weights are
     * updated. The default implementation does nothing.
     *
     * @param currentPattern the back propagated gradients.
     * @param currentInps the forwarded input.
     */
    protected void preWeightUpdateImpl(double[] currentPattern, double[] currentInps) {
    }

    /**
     * Gives learners and extenders a chance to do some post-computing after the
     * biases are updated.
     * <p>
     * The groups are visited in the opposite order to
     * {@link #preBiasUpdate(double[])}: enabled gradient extenders, enabled
     * delta rule extenders, the enabled update weight extender, and finally the
     * learner's own {@link #postBiasUpdateImpl(double[])}.
     *
     * @param currentGradientOuts the back propagated gradients.
     */
    protected final void postBiasUpdate(double[] currentGradientOuts) {
        // gradient extenders...
        for (int i = 0; i < theGradientExtenders.size(); i++) {
            GradientExtender myExtender = (GradientExtender) theGradientExtenders.get(i);
            if (myExtender.isEnabled()) {
                myExtender.postBiasUpdate(currentGradientOuts);
            }
        }

        // delta rule extenders...
        for (int i = 0; i < theDeltaRuleExtenders.size(); i++) {
            DeltaRuleExtender myExtender = (DeltaRuleExtender) theDeltaRuleExtenders.get(i);
            if (myExtender.isEnabled()) {
                myExtender.postBiasUpdate(currentGradientOuts);
            }
        }

        // update weight extenders...
        if (theUpdateWeightExtender != null && theUpdateWeightExtender.isEnabled()) {
            theUpdateWeightExtender.postBiasUpdate(currentGradientOuts);
        }

        postBiasUpdateImpl(currentGradientOuts);
    }

    /**
     * Gives learners a chance to do some post-computing after the biases are
     * updated. The default implementation does nothing.
     *
     * @param currentGradientOuts the back propagated gradients.
     */
    protected void postBiasUpdateImpl(double[] currentGradientOuts) {
    }

    /**
     * Gives learners and extenders a chance to do some post-computing after the
     * weights are updated.
     * <p>
     * The groups are visited in the opposite order to
     * {@link #preWeightUpdate(double[], double[])}: enabled gradient extenders,
     * enabled delta rule extenders, the enabled update weight extender, and
     * finally the learner's own
     * {@link #postWeightUpdateImpl(double[], double[])}.
     *
     * @param currentPattern the back propagated gradients.
     * @param currentInps the forwarded input.
     */
    protected final void postWeightUpdate(double[] currentPattern, double[] currentInps) {
        // Note: the extender hooks take the inputs first, the pattern second.

        // gradient extenders...
        for (int i = 0; i < theGradientExtenders.size(); i++) {
            GradientExtender myExtender = (GradientExtender) theGradientExtenders.get(i);
            if (myExtender.isEnabled()) {
                myExtender.postWeightUpdate(currentInps, currentPattern);
            }
        }

        // delta extenders...
        for (int i = 0; i < theDeltaRuleExtenders.size(); i++) {
            DeltaRuleExtender myExtender = (DeltaRuleExtender) theDeltaRuleExtenders.get(i);
            if (myExtender.isEnabled()) {
                myExtender.postWeightUpdate(currentInps, currentPattern);
            }
        }

        // update weight extenders...
        if (theUpdateWeightExtender != null && theUpdateWeightExtender.isEnabled()) {
            theUpdateWeightExtender.postWeightUpdate(currentInps, currentPattern);
        }

        // The impl hook uses this method's own (pattern, inputs) order, so the
        // back propagated gradients must be passed first - not the inputs twice.
        postWeightUpdateImpl(currentPattern, currentInps);
    }

    /**
     * Gives learners a chance to do some post-computing after the weights are
     * updated. The default implementation does nothing.
     *
     * @param currentPattern the back propagated gradients.
     * @param currentInps the forwarded input.
     */
    protected void postWeightUpdateImpl(double[] currentPattern, double[] currentInps) {
    }

    /**
     * Adds a delta rule extender and registers this learner with it.
     *
     * @param aDeltaRuleExtender the delta rule extender to add; must not be
     *        {@code null}.
     * @throws NullPointerException if {@code aDeltaRuleExtender} is {@code null}.
     */
    public void addDeltaRuleExtender(DeltaRuleExtender aDeltaRuleExtender) {
        // Note one needs to be careful about the order of the extenders,
        // also note that basic and batch learner add a delta (momentum)
        // extender in their constructor

        // Reject null before touching the list: a null entry would make every
        // later pre/post update hook throw.
        if (aDeltaRuleExtender == null) {
            throw new NullPointerException("aDeltaRuleExtender must not be null");
        }
        theDeltaRuleExtenders.add(aDeltaRuleExtender);
        aDeltaRuleExtender.setLearner(this);
    }

    /**
     * Adds a gradient extender and registers this learner with it.
     *
     * @param aGradientExtender the gradient extender to add; must not be
     *        {@code null}.
     * @throws NullPointerException if {@code aGradientExtender} is {@code null}.
     */
    public void addGradientExtender(GradientExtender aGradientExtender) {
        // Reject null before touching the list: a null entry would make every
        // later pre/post update hook throw.
        if (aGradientExtender == null) {
            throw new NullPointerException("aGradientExtender must not be null");
        }
        theGradientExtenders.add(aGradientExtender);
        aGradientExtender.setLearner(this);
    }

    /**
     * Sets an update weight extender and registers this learner with it.
     *
     * @param anUpdateWeightExtender the update weight extender to set, or
     *        {@code null} to remove the current one (the pre/post update hooks
     *        skip a {@code null} update weight extender).
     */
    public void setUpdateWeightExtender(UpdateWeightExtender anUpdateWeightExtender) {
        theUpdateWeightExtender = anUpdateWeightExtender;
        if (theUpdateWeightExtender != null) {
            theUpdateWeightExtender.setLearner(this);
        }
    }

    /**
     * Gets the update weight extender.
     *
     * @return the update weight extender, or {@code null} if none is set.
     */
    public UpdateWeightExtender getUpdateWeightExtender() {
        return theUpdateWeightExtender;
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -