📄 ebm.h
/***************************************************************************
 * Copyright (C) 2008 by Yann LeCun
 * yann@cs.nyu.edu
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *   * Redistributions of source code must retain the above copyright
 *     notice, this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above copyright
 *     notice, this list of conditions and the following disclaimer in the
 *     documentation and/or other materials provided with the distribution.
 *   * Redistribution under a license not approved by the Open Source
 *     Initiative (http://www.opensource.org) must display the
 *     following acknowledgement in all advertising material:
 *     This product includes software developed at the Courant
 *     Institute of Mathematical Sciences (http://cims.nyu.edu).
 *   * The names of the authors may not be used to endorse or promote
 *     products derived from this software without specific prior written
 *     permission.
 *
 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 ***************************************************************************/

#ifndef EBM_H_
#define EBM_H_

#include <vector>

#include "Idx.h"
#include "Blas.h"

namespace ebl {

//! see numerics.h for description
extern bool drand_ini;

void err_not_implemented();

class infer_param {
};

//! class that contains all the parameters
//! for a stochastic gradient descent update,
//! including step sizes, regularizer coefficients...
class gd_param: public infer_param {
public:
  //! global step size
  double eta;
  //! time at which to start using decay values
  int decay_time;
  //! L2 regularizer coefficient
  double decay_l2;
  //! L1 regularizer coefficient
  double decay_l1;
  //! stopping criterion threshold
  double n;
  //! momentum term
  double inertia;
  //! annealing coefficient for the learning rate
  double anneal_value;
  //! number of iterations between two annealings
  double anneal_time;
  //! threshold on square norm of gradient for stopping
  double gradient_threshold;
  //! for debugging purposes
  int niter_done;

  gd_param(double leta, double ln, double l1, double l2, int dtime,
           double iner, double a_v, double a_t, double g_t);
};

////////////////////////////////////////////////////////////////
//! abstract class for randomization parameters
class forget_param {
};

class forget_param_linear: public forget_param {
public:
  //! each random value will be drawn uniformly
  //! from [-value/(fanin**exponent), +value/(fanin**exponent)]
  double value;
  double exponent;

  //! constructor.
  //! each random value will be drawn uniformly
  //! from [-v/(fanin**exponent), +v/(fanin**exponent)]
  forget_param_linear(double v, double e);
};
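//! Usage sketch (illustrative only, not part of the original header; the
//! numeric values are hypothetical). Arguments follow the constructor
//! order declared above: eta, n, l1, l2, decay_time, inertia,
//! anneal_value, anneal_time, gradient_threshold.
//!
//!   gd_param gdp(0.0001, 0.0, 0.0, 0.0, 0, 0.0, 0.0, 0.0, 0.0);
//!   // weights drawn uniformly from [-1/fanin^0.5, +1/fanin^0.5]
//!   forget_param_linear fgp(1.0, 0.5);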
////////////////////////////////////////////////////////////////
//! abstract class that stores a state.
//! it must support the following methods:
//! clear (clear values), clear_dx (clear gradients),
//! clear_ddx (clear hessians), and update_gd(arg) (update
//! with gradient descent).
class state {
public:
  virtual void clear();
  virtual void clear_dx();
  virtual void clear_ddx();
  virtual void update_gd(gd_param &arg);

  state();
  virtual ~state();
};

class parameter;

//! class that stores a vector/tensor state
class state_idx: public state {
public:
  //! state itself
  Idx<double> x;
  //! gradient of loss with respect to state
  Idx<double> dx;
  //! diag hessian of loss with respect to state
  Idx<double> ddx;

  //! Constructs a state_idx of order 0
  state_idx();
  //! Constructs a state_idx of order 1
  state_idx(intg s0);
  //! Constructs a state_idx of order 2
  state_idx(intg s0, intg s1);
  //! Constructs a state_idx of order 3
  state_idx(intg s0, intg s1, intg s2);
  //! Constructor. A state_idx can have up to 8 dimensions.
  state_idx(intg s0, intg s1, intg s2, intg s3, intg s4 = -1, intg s5 = -1,
            intg s6 = -1, intg s7 = -1);
  //! these constructors append the state_idx into the same Srg as the
  //! state_idx passed as argument. This is useful for
  //! allocating multiple state_idx inside a parameter.
  //! This replaces the Lush function alloc_state_idx.
  state_idx(parameter *st);
  state_idx(parameter *st, intg s0);
  state_idx(parameter *st, intg s0, intg s1);
  state_idx(parameter *st, intg s0, intg s1, intg s2);
  state_idx(parameter *st, intg s0, intg s1, intg s2, intg s3, intg s4 = -1,
            intg s5 = -1, intg s6 = -1, intg s7 = -1);
  virtual ~state_idx();

  //! clear x
  virtual void clear();
  //! clear gradients dx
  virtual void clear_dx();
  //! clear diag hessians ddx
  virtual void clear_ddx();
  //! return number of elements
  virtual intg nelements();
  //! return footprint in storages
  virtual intg footprint();
  //! same as footprint
  virtual intg size();
  //! update with gradient descent
  virtual void update_gd(gd_param &arg);
  //! resize. The order cannot be changed with this.
  virtual void resize(intg s0 = -1, intg s1 = -1, intg s2 = -1, intg s3 = -1,
                      intg s4 = -1, intg s5 = -1, intg s6 = -1, intg s7 = -1);
  virtual void resizeAs(state_idx &s);
  virtual void resize(const intg* dimsBegin, const intg* dimsEnd);
  //! make a new copy of self
  virtual state_idx make_copy();
};

////////////////////////////////////////////////////////////////
//! parameter: the main class for a trainable
//! parameter vector.
class parameter: public state_idx {
public:
  Idx<double> gradient;
  Idx<double> deltax;
  Idx<double> epsilons;
  Idx<double> ddeltax;

  //! constructor
  parameter(intg initial_size = 100);
  virtual ~parameter();

  virtual void resize(intg s0);
  void update_gd(gd_param &arg);
  virtual void update(gd_param &arg);
  void clear_deltax();
  void update_deltax(double knew, double kold);
  void clear_ddeltax();
  void update_ddeltax(double knew, double kold);
  void set_epsilons(double m);
  void compute_epsilons(double mu);
  bool load(const char *s);
  void save(const char *s);
};

////////////////////////////////////////////////////////////////
// templates for generic modules

//! abstract class for a module with one input and one output.
template<class Tin, class Tout> class module_1_1 {
public:
  virtual ~module_1_1() {
  }
  virtual void fprop(Tin *in, Tout *out);
  virtual void bprop(Tin *in, Tout *out);
  virtual void bbprop(Tin *in, Tout *out);
  virtual void forget(forget_param_linear& fp);
  virtual void normalize();
};
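//! Usage sketch (illustrative only, not part of the original header; the
//! sizes are hypothetical). state_idx instances constructed with a
//! parameter pointer are appended into that parameter's storage, so a
//! single update_gd call steps every weight in the model:
//!
//!   parameter theparam(60000);        // initial storage, grows as needed
//!   state_idx w(&theparam, 100, 28);  // appended into theparam's Srg
//!   w.clear_dx();                     // zero this state's gradient buffer
//!   theparam.update_gd(gdp);          // gradient step on all parameters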
////////////////////////////////////////////////////////////////
//! abstract class for a module with two inputs and one output.
template<class Tin1, class Tin2, class Tout> class module_2_1 {
public:
  virtual ~module_2_1() {
  }
  virtual void fprop(Tin1 *in1, Tin2 *in2, Tout *out);
  virtual void bprop(Tin1 *in1, Tin2 *in2, Tout *out);
  virtual void bbprop(Tin1 *in1, Tin2 *in2, Tout *out);
  virtual void forget(forget_param &fp);
  virtual void normalize();
};

////////////////////////////////////////////////////////////////
//! abstract class for a module with one input and one energy output.
template<class Tin> class ebm_1 {
public:
  virtual ~ebm_1() {
  }
  virtual void fprop(Tin *in, state_idx *energy);
  virtual void bprop(Tin *in, state_idx *energy);
  virtual void bbprop(Tin *in, state_idx *energy);
  virtual void forget(forget_param &fp);
  virtual void normalize();
};

////////////////////////////////////////////////////////////////
//! abstract class for a module with two inputs and one energy output.
template<class Tin1, class Tin2> class ebm_2 {
public:
  virtual ~ebm_2() {
  }
  //! fprop: compute output from input
  virtual void fprop(Tin1 *i1, Tin2 *i2, state_idx *energy);
  //! bprop: compute gradient wrt inputs, given gradient wrt output
  virtual void bprop(Tin1 *i1, Tin2 *i2, state_idx *energy);
  //! bbprop: compute diag hessian wrt inputs, given diag hessian wrt output
  virtual void bbprop(Tin1 *i1, Tin2 *i2, state_idx *energy);
  virtual void bprop1_copy(Tin1 *i1, Tin2 *i2, state_idx *energy);
  virtual void bprop2_copy(Tin1 *i1, Tin2 *i2, state_idx *energy);
  virtual void bbprop1_copy(Tin1 *i1, Tin2 *i2, state_idx *energy);
  virtual void bbprop2_copy(Tin1 *i1, Tin2 *i2, state_idx *energy);
  virtual void forget(forget_param &fp);
  virtual void normalize();
  //! compute value of in1 that minimizes the energy, given in2
  virtual double infer1(Tin1 *i1, Tin2 *i2, state_idx *energy,
                        infer_param *ip) {
    return 0;
  }
  //! compute value of in2 that minimizes the energy, given in1
  virtual double infer2(Tin1 *i1, Tin2 *i2, state_idx *energy,
                        infer_param *ip) {
    return 0;
  }
};

////////////////////////////////////////////////////////////////
// generic architectures

template<class Tin, class Thid, class Tout> class layers_2
  : public module_1_1<Tin, Tout> {
public:
  module_1_1<Tin, Thid> *layer1;
  Thid *hidden;
  module_1_1<Thid, Tout> *layer2;

  layers_2(module_1_1<Tin, Thid> *l1, Thid *h, module_1_1<Thid, Tout> *l2);
  virtual ~layers_2();
  void fprop(Tin *in, Tout *out);
  void bprop(Tin *in, Tout *out);
  void bbprop(Tin *in, Tout *out);
  void forget(forget_param &fp);
  void normalize();
};

template<class T> class layers_n: public module_1_1<T, T> {
public:
  std::vector< module_1_1<T, T>* > *modules;
  std::vector< T* > *hiddens;

  layers_n();
  layers_n(bool oc);
  virtual ~layers_n();
  void addModule(module_1_1<T, T>* module, T* hidden);
  void fprop(T *in, T *out);
  void bprop(T *in, T *out);
  void bbprop(T *in, T *out);
  void forget(forget_param_linear &fp);
  void normalize();

private:
  bool own_contents;
};

////////////////////////////////////////////////////////////////
//! standard 1 input EBM with one module-1-1, and one ebm-1 on top.
//! fc stands for "function+cost".
template<class Tin, class Thid> class fc_ebm1: public ebm_1<Tin> {
public:
  module_1_1<Tin, Thid> *fmod;
  Thid *fout;
  ebm_1<Thid> *fcost;

  fc_ebm1(module_1_1<Tin, Thid> *fm, Thid *fo, ebm_1<Thid> *fc);
  virtual ~fc_ebm1();
  void fprop(Tin *in, state_idx *energy);
  void bprop(Tin *in, state_idx *energy);
  void bbprop(Tin *in, state_idx *energy);
  void forget(forget_param &fp);
};
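//! Usage sketch (illustrative only, not part of the original header; the
//! sizes are hypothetical, and nn_layer_full is declared further below).
//! layers_2 chains two module_1_1 instances through a hidden state:
//!
//!   nn_layer_full l1(&theparam, 784, 300), l2(&theparam, 300, 10);
//!   state_idx input(784), hid(300), output(10);
//!   layers_2<state_idx, state_idx, state_idx> net(&l1, &hid, &l2);
//!   net.fprop(&input, &output);  // l1 fills hid, l2 fills output
//!   net.bprop(&input, &output);  // gradients flow back through hid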
////////////////////////////////////////////////////////////////
//! standard 2 input EBM with one module-1-1, and one ebm-2 on top.
//! fc stands for "function+cost".
template<class Tin1, class Tin2, class Thid> class fc_ebm2
  : public ebm_2<Tin1, Tin2> {
public:
  module_1_1<Tin1, Thid> *fmod;
  Thid *fout;
  ebm_2<Thid, Tin2> *fcost;

  fc_ebm2(module_1_1<Tin1, Thid> *fm, Thid *fo, ebm_2<Thid, Tin2> *fc);
  virtual ~fc_ebm2();
  void fprop(Tin1 *in1, Tin2 *in2, state_idx *energy);
  void bprop(Tin1 *in1, Tin2 *in2, state_idx *energy);
  void bbprop(Tin1 *in1, Tin2 *in2, state_idx *energy);
  void forget(forget_param &fp);
};

////////////////////////////////////////////////////////////////
// linear module
// It differs from f_layer in that it is
// not spatially replicable and does not operate
// on 3D state_idx.
class linear_module: public module_1_1<state_idx, state_idx> {
public:
  state_idx *w;

  linear_module(parameter *p, intg in0, intg out0);
  virtual ~linear_module();
  void fprop(state_idx *in, state_idx *out);
  void bprop(state_idx *in, state_idx *out);
  void bbprop(state_idx *in, state_idx *out);
  void forget(forget_param_linear &fp);
  void normalize();
};

////////////////////////////////////////////////////////////////
//! a slab of standard Lush sigmoids
class stdsigmoid_module: public module_1_1<state_idx, state_idx> {
public:
  //! empty constructor
  stdsigmoid_module();
  virtual ~stdsigmoid_module();
  //! fprop from in to out
  virtual void fprop(state_idx *in, state_idx *out);
  //! bprop
  virtual void bprop(state_idx *in, state_idx *out);
  //! bbprop
  virtual void bbprop(state_idx *in, state_idx *out);
};

////////////////////////////////////////////////////////////////
//! a slab of tanh
class tanh_module: public module_1_1<state_idx, state_idx> {
public:
  //! fprop from in to out
  void fprop(state_idx *in, state_idx *out);
  //! bprop
  void bprop(state_idx *in, state_idx *out);
  //! bbprop
  void bbprop(state_idx *in, state_idx *out);
  void forget(forget_param_linear &fp);
  void normalize();
};

////////////////////////////////////////////////////////////////
//! constant add
class addc_module: public module_1_1<state_idx, state_idx> {
public:
  // coefficients
  state_idx* bias;

  addc_module(parameter *p, intg size);
  ~addc_module();
  //! fprop from in to out
  void fprop(state_idx *in, state_idx *out);
  //! bprop
  void bprop(state_idx *in, state_idx *out);
  //! bbprop
  void bbprop(state_idx *in, state_idx *out);
  void forget(forget_param_linear &fp);
  void normalize();
};

////////////////////////////////////////////////////////////////
//! a simple fully-connected neural net layer: linear + tanh non-linearity.
//! Unlike the f-layer class, this one is not spatially replicable.
class nn_layer_full: public module_1_1<state_idx, state_idx> {
public:
  //! linear module for weight matrix
  linear_module *linear;
  //! bias vector
  state_idx *bias;
  //! weighted sum
  state_idx *sum;
  //! the non-linear function
  tanh_module *sigmoid;

  //! constructor. Arguments are a pointer to a parameter
  //! in which the trainable weights will be appended,
  //! the number of inputs, and the number of outputs.
  nn_layer_full(parameter *p, intg ninputs, intg noutputs);
  virtual ~nn_layer_full();
  //! fprop from in to out
  void fprop(state_idx *in, state_idx *out);
  //! bprop
  void bprop(state_idx *in, state_idx *out);
  //! bbprop
  void bbprop(state_idx *in, state_idx *out);
  //! initialize the weights with random values
  void forget(forget_param_linear &fp);
};
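//! Usage sketch (illustrative only, not part of the original header; the
//! cost handling is elided and all values are hypothetical). One
//! stochastic gradient-descent step through a fully-connected layer:
//!
//!   parameter theparam(60000);
//!   nn_layer_full layer(&theparam, 784, 100);
//!   state_idx in(784), out(100);
//!   forget_param_linear fgp(1.0, 0.5);
//!   layer.forget(fgp);        // random weight initialization
//!   layer.fprop(&in, &out);   // forward pass
//!   theparam.clear_dx();      // reset parameter gradients
//!   // ... set out.dx to the gradient of some cost wrt out ...
//!   layer.bprop(&in, &out);   // backward pass accumulates gradients
//!   theparam.update_gd(gdp);  // weight update with a gd_param gdp

} // namespace ebl

#endif /* EBM_H_ */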