⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 artmap.h

📁 ARTMAP source code. include: Fuzzy ARTMAP, Default ARTMAP, Instance-Counting ARTMAP, Distributed ART
💻 H
字号:
/**
 * \file
 * artmap.h
 *
 * Copyleft (C) - Boston University CELEST Technology Laboratory
 */

#ifndef ARTMAP_H
#define ARTMAP_H

#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <valarray>
#include <vector>

#include "Logger.h"
#include "MsgException.h"
#include "util.h"

// Header-wide aliases used by the declarations in this file. ofstream is
// listed explicitly because the artmap class names it; without it the
// header only compiles if some other header happens to leak std::ofstream.
using std::ifstream;
using std::ofstream;
using std::vector;
using std::string;
using std::istringstream;

/* ARTMAP_DECLSPEC decorates the artmap class for Windows DLL builds.
   - ARTMAP_DLL undefined:                 expands to nothing (static/normal build)
   - ARTMAP_DLL and ARTMAP_IMPORT defined: __declspec(dllimport) (client side)
   - ARTMAP_DLL defined, no ARTMAP_IMPORT: __declspec(dllexport) (building the DLL) */
#ifndef ARTMAP_DLL
#define ARTMAP_DECLSPEC
#else
#ifdef ARTMAP_IMPORT
#define ARTMAP_DECLSPEC __declspec(dllimport)
#else
#define ARTMAP_DECLSPEC __declspec(dllexport)
#endif
#endif

/**
  Artmap class - Implements the Distributed ARTMAP model.
	         Depending on the NetworkType setting, can emulate
					 fuzzy ARTMAP, Default ARTMAP, or the instance
					 counting and distributed varieties. A flowchart 
					 of the training process is shown below:
					 \image html TrainingFlowchart.png "ARTMAP Training Flowchart"

 */

class ARTMAP_DECLSPEC artmap {
 public:
	 /** The available ARTMAP models (See main page for differences between models). */
  typedef enum RunModeType { FUZZY, DEFAULT, IC, DISTRIB };

 private:
  RunModeType NetworkType;  ///< Controls the algorithm used

	int     M;          ///< Number of inputs (before complement-coding) 
  int     L;	        ///< Number of output classes ([1-L], not [0-(L-1)]) 
  
  float   RhoBar;     ///< Baseline vigilance - training
  float   RhoBarTest; ///< Baseline vigilance - testing
  float   Alpha;      ///< Signal rule parameter
  float   Beta;       ///< Learning rate
  float   Eps;        ///< Match tracking parameter
  float   P;          ///< CAM rule power parameter
  
  int     C;          ///< Number of committed nodes 
  
  int     J;          ///< In WTA mode, index of the winning node 
  int     K;          ///< The target class (1-L, not 0-(L-1)) 

  float   rho;        ///< Current vigilance 

  /** Index ranges - i: 1-M, j: 1-C; k: 1-L */
  float * A;          ///< Indexed by i - Complement-coded input 
  float * x;          ///< Indexed by i - F1, matching 
  float * y;          ///< Indexed by j - F2, coding 
  float * Y;          ///< Indexed by j - F3, counting 
  float * T;          ///< Indexed by j - Total F0->F2 
  float * S;          ///< Indexed by j - Phasic F0->F2 
  float * H;          ///< Indexed by j - Tonic F0->F2 (Capital Theta) 
  float * c;          ///< Indexed by j - F2->F3 
  bool  * lambda;     ///< Indexed by j - T if node is eligible, F otherwise 
  float * sigma_i;    ///< Indexed by i - F3->F1 
  float * sigma_k;    ///< Indexed by k - F3->F0ab 
  int   * kap;        ///< Indexed by j - F3->Fab (small kappa) 
	float * dKap;				///< Distributed version of kap 
  float * tIj;        ///< Indexed by i&j - F0->F2 (tau sub ij) 
  float * tJi;        ///< Indexed by j&i - F3->F1 (tau sub ji) 

	bool    dMapWeights;  ///< if true, use dKap, else use kap 

  /* Utility variables (not in algorithm) */
  float  Tu;          ///< Uncommitted node activation 
  float  sum_x;       ///< To avoid recomputing norm 
  int    _2M;         ///< To keep from repeatedly calculating 2*M 
  int    N;           ///< Growable upper bound on coding nodes 
  int    i, j, k;     ///< Indices i, j and k, so we don't have to declare 'em everywhere 

  void complementCode(float  *a);
  int  F0_to_F2_signal();
  void newNode();
  void CAM_distrib();
  void CAM_WTA();
  void F1signal_WTA();
  void F1signal_distrib();
  bool passesVigilance();
  int  prediction_distrib();
  int  prediction_WTA();
  void matchTracking();
  void creditAssignment();
  void resonance_distrib();
  void resonance_WTA();
  void growF2 (float  factor);

	/** 
	 This cost function takes the input signal \f$T_j\f$ to an F2 node,
	 and rescales the metric so that nodes that match the training/test 
	 sample being evaluated well have low cost. It reaches a minimum of zero 
	 when the argument \f$T_j\f$ is equal to \f$(2-\alpha)M\f$, which 
	 corresponds to the training/test sample falling within a point category box.
	 @param x The input signal \f$T_j\f$ to a category node.
	 @return A measure of the 'cost' of the category node with respect to a particular training/test sample.
	 */
  float cost(float x) { return ((2-Alpha)*M - x); }
	ofstream *ostCategoryActivations;
  void toStr();
  void toStr_dimensions();
  void toStr_A();
  void toStr_nodeJTSH(int j);
  void toStr_nodeJdetails(int j);
  void toStr_nodeJtauIj(int j);
  void toStr_nodeJtauJi(int j);
  void toStr_x();
  void toStr_sigma_i();
  void toStr_sigma_k();
		
 public:
				 artmap    (int M, int L);
		    ~artmap();
  void   train     (float  *a, int K);
  void   test      (float  *a);
	/** Returns the k-th output (distributed prediction).
	    @param k The index of the output to retrieve 
			@return The predicted likelihood that the input is of class k
			        If no choice is appropriate, then all output values are 1.0
	 */
  float  getOutput (int k) { return sigma_k[k]; }
	/** 
   Returns the index of the largest output prediction, which
	 in a winner-take-all situation (fuzzy ARTMAP) is the predicted class.
	 @return The index of the predicted class, or -1 if there's a tie.
	 */
	int    getMaxOutputIndex () { 
		std::valarray<float> outs = std::valarray<float> (sigma_k, L);
		return getIndexOfMaxElt (outs);
	}

  void   fwrite    (ofstream &ofs);
  void   fread     (ifstream &ifs, string &specialRequest);

  void setParam (const string &name, const string &value);
	/** Returns the number of category nodes (aka templates learned by the network) */
  int         getC()            { return C;           }
	/** Returns the output class associated with a category node with the given index */
	int         getNodeClass (int j) { if ((j < 0) || (j > C) || dMapWeights) { return -1; } else { return kap[j]; } }
	/** Returns the number of bytes required to store the weights for the network */
  int         getLtmRequired () { return C * M * 2 * sizeof (float ); }
  float  &tauIj (int i, int j);
  float  &tauJi (int i, int j);
  int     getOutputType (const string &name);
  int     getInt        (const string &name);
  float   getFloat      (const string &name);
  string &getString     (const string &name);

	void    requestOutput (const string &name, ofstream *ost);
	void    closeStreams  ();

  void setNetworkType (RunModeType v) { NetworkType = v; } ///< Accessor method
  void setM           (int         v) { M	    = v; }       ///< Accessor method
  void setL           (int         v) { L	    = v; }       ///< Accessor method 
  void setRhoBar      (float       v) { RhoBar      = v; } ///< Accessor method
  void setRhoBarTest  (float       v) { RhoBarTest  = v; } ///< Accessor method
  void setAlpha       (float       v) { Alpha       = v; } ///< Accessor method
  void setBeta        (float       v) { Beta        = v; } ///< Accessor method
  void setEps         (float       v) { Eps         = v; } ///< Accessor method
  void setP           (float       v) { P           = v; } ///< Accessor method

  RunModeType getNetworkType() { return NetworkType; }     ///< Accessor method
  int 	      getM()           { return M;           }     ///< Accessor method
  int 	      getL()           { return L;           }     ///< Accessor method
  float       getRhoBar()      { return RhoBar;      }     ///< Accessor method
  float       getRhoBarTest()  { return RhoBarTest;  }     ///< Accessor method
  float       getAlpha()       { return Alpha;       }     ///< Accessor method
  float       getBeta()        { return Beta;        }     ///< Accessor method
  float       getEps()         { return Eps;         }     ///< Accessor method
  float       getP()           { return P;           }     ///< Accessor method
};

/*
 * Loop headers over the "official" index ranges. They assign to the bare
 * names i, j and k and bound them by _2M, C, L and N. Those names are the
 * artmap data members, so the macros are meant to be used inside artmap
 * member functions:
 *   foreach_i : i from 0 while i < _2M  (complement-coded input components)
 *   foreach_j : j from 0 while j < C    (committed coding nodes)
 *   foreach_k : k from 0 while k < L    (output classes)
 *   forall_j  : j from 0 while j < N    (every allocated coding-node slot)
 */
#define foreach_i for (i = 0; i < _2M; ++i)
#define foreach_j for (j = 0; j < C;   ++j)
#define foreach_k for (k = 0; k < L;   ++k)

#define forall_j  for (j = 0; j < N;   ++j)

#endif


⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -