📄 classifiers.hpp
字号:
//====
// Classifiers.hpp
// - defines a few popular vector-space model classifiers
//   - Classifier (abstract base class)
//   - RuleBasedClassifier
//   - CaseBasedClassifier
// Notes
// - this package is for experiment and demo purposes only,
//   because it operates on full feature-document matrices,
//   which can be very inefficient in production especially
//   when the matrices are sparse and of high dimensions,
//   in which case software engineering techniques such as
//   inverted index posting or CLUTO form may be considered
// - this package is provided as is with no warranty
// - the authors are not responsible for any damage caused
//   either directly or indirectly by using this package
// - anybody is free to do whatever he/she wants with this
//   package as long as this header section is preserved
// Created on 2005-07-18 by
// - Caliope Sandiford
// - Chitra Attaluri
// - Naureen Nizam
// - Roger Zhang (rogerz@cs.dal.ca)
// Modifications
// - Roger Zhang on 2005-07-29
//   - added storage for classifier dimension
//   - added case based classifier definition
// -
// Last compiled under Linux with gcc 3.4
//====

#ifndef _VSM_CLASSIFIERS_
#define _VSM_CLASSIFIERS_

#include <cstddef> // NULL
#include <string>
#include <vector>

namespace CNR
{
    class Classifier // abstract base
    {
        //====
        // data members

    protected:

        int dim; // vector dimension (maintained by the load and train routines)
        mutable std::string *buf; // system memo buffer (recording is optional)

    public:

        static const int UNKNOWN_CLASS = -1; // change the number if conflicting

        //====
        // construction and destruction

        // use the "fast" switch to turn off memo recording (buf stays NULL)
        explicit Classifier(bool fast)
            : dim(0), buf(fast ? NULL : new std::string(""))
        {}

        virtual ~Classifier() { delete buf; } // deleting NULL is a no-op

    protected:

        // Deep copy: buf is owned by this object, so the implicit member-wise
        // copy (sharing the pointer) would delete the same string twice.
        // Kept protected so a derived classifier cannot be sliced through a
        // base-class reference; derived classes still get working copies.
        Classifier(const Classifier &other)
            : dim(other.dim),
              buf(other.buf ? new std::string(*other.buf) : NULL)
        {}

        Classifier &operator=(const Classifier &other)
        {
            if (this != &other)
            {
                // allocate first so a throwing new leaves *this unchanged
                std::string *copy = other.buf ? new std::string(*other.buf) : NULL;
                delete buf;
                buf = copy;
                dim = other.dim;
            }
            return *this;
        }

    public:

        //====
        // accessors

        // true whenever a positive dimension has been recorded
        bool trained() const { return dim > 0; }

        int dimension() const { return dim; }

        // for operation info or error logging; when memo recording was turned
        // off a fixed notice is returned instead (the pointer is only valid
        // until the memo is next modified or the classifier is destroyed)
        const char *memo() const
        {
            return buf ? buf->c_str() : "Classifier memo turned off";
        }

        //====
        // a child class must implement the following interface

        virtual void load(char *file) = 0; // load a previously saved classifier
        virtual void save(char *file) const = 0; // store a trained classifier

        // presumably items holds n vectors of d features each and labels the
        // n corresponding class labels -- confirm against implementations
        virtual void train(double **items, int n, int d, int *labels) = 0;

        virtual int label(double *item, int d) const = 0; // classify an item
    }; // Classifier

    //============================================================

    class RuleBasedClassifier : public Classifier
    {
        class Rule // internal representation of a rule
        {
        public:

            // for example, if the 3rd feature has a weight greater than .768
            // then with 95% confidence this item should belong to category 2
            int i; // feature index = 2 (starting from 0)
            double t; // threshold = .768
            int p; // prediction = 2
            double c; // confidence = .95

            // for sorting support: higher confidence first, ties broken by
            // higher threshold first
            bool operator <(const Rule &r) const
            {
                return c > r.c || (c == r.c && t > r.t);
            }
        };

        std::vector<Rule> rules;

    public:

        explicit RuleBasedClassifier(bool fast = false) : Classifier(fast) {}

        // total number of rules
        int size() const { return static_cast<int>(rules.size()); }

        void load(char *file);
        void save(char *file) const;

        void train(double **items, int n, int d, int *labels)
        {
            train(items, n, d, labels, 100 /* arbitrary number */);
        }

        // the additional size parameter below provides a desired number
        // of rules for the training routine to consider, although the
        // actual number generated may not be the same given that for
        // certain training corpora it is simply not possible/realistic
        void train(double **items, int n, int d, int *labels, int size);

        int label(double *item, int d) const
        {
            return label(item, d, NULL);
        }

        // the additional conf parameter below provides a return address
        // for the confidence level of a classification when interested
        int label(double *item, int d, double *conf) const;
    }; // RuleBasedClassifier

    //============================================================

    class CaseBasedClassifier : public Classifier
    {
        struct Case
        {
            int c; // class label
            double *v; // centroid (arithmetic mean of all class c vectors), owned, new[]-allocated
        };

        std::vector<Case> cases;

        // release every owned centroid and empty the case list; this no
        // longer depends on dim, so arrays are never leaked when dim == 0
        void clear()
        {
            for (std::vector<Case>::size_type k = 0; k < cases.size(); ++k)
            {
                delete [] cases[k].v;
            }
            cases.clear();
        }

        // non-copyable: each Case owns its centroid array, so a member-wise
        // copy would delete the same arrays twice (declared, never defined)
        CaseBasedClassifier(const CaseBasedClassifier &);
        CaseBasedClassifier &operator=(const CaseBasedClassifier &);

    public:

        explicit CaseBasedClassifier(bool fast = false) : Classifier(fast) {}

        ~CaseBasedClassifier() { clear(); }

        void load(char *file);
        void save(char *file) const;
        void train(double **items, int n, int d, int *labels);
        int label(double *item, int d) const;
    }; // CaseBasedClassifier
} // namespace CNR

#endif // _VSM_CLASSIFIERS_
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -