📄 classifiers.cpp
字号:
//====// Classifiers.cpp// - implements a few methods defined in Classifiers.hpp// Notes// - this package is for experiment and demo purposes only,// because it operates on full feature-document matrices,// which can be very inefficient in production especially// when the matrices are sparse and of high dimensions,// in which case software engineering techniques such as// inverted index posting or CLUTO form may be considered// - this package is provided as is with no warranty// - the authors are not responsible for any damage caused// either directly or indirectly by using this package// - anybody is free to do whatever he/she wants with this// package as long as this header section is preserved// Created on 2005-07-18 by// - Caliope Sandiford// - Chitra Attaluri// - Naureen Nizam// - Roger Zhang (rogerz@cs.dal.ca)// Modifications// - Roger Zhang on 2005-07-25// - added rule compounding in the label function// - Roger Zhang on 2005-07-29// - added case based classifier functions// -// Last compiled under Linux with gcc 3.4//====#include "Classifiers.hpp"#include <fstream>#include <sstream>#include <cassert>#include <map>namespace CNR{ //==== // local helper functions static std::string numToStr(double x) // atof reversed :) { std::ostringstream oss; oss << x; return oss.str(); } static double distance(double *u, double *v, int n) // squared distance { assert(u && v && n > 0); double d = 0; while (n--) { double x = u[n] - v[n]; d += x * x; } return d; } //==== // rule basec classifier functions void RuleBasedClassifier::load(char *file) { rules.clear(); std::ifstream fin(file); if (!(fin >> dim)) { if (buf) { *buf = "Loading classifier failed\n"; } return; } for (Rule r; fin >> r.i >> r.t >> r.p >> r.c; rules.push_back(r)); fin.close(); if (rules.empty()) { if (buf) { *buf = "Loading classifier failed\n"; } dim = 0; return; } if (buf) { *buf = "Rule based classifier loaded\n"; *buf += " number of rules = " + numToStr(rules.size()) + "\n"; *buf += " vector dimension = " + numToStr(dim) + "\n"; } } void RuleBasedClassifier::save(char *file) const { std::ofstream fout(file, std::ios::out | std::ios::trunc); fout << dim << "\n"; for (int k = 0; k < rules.size(); k++) { fout << rules[k].i << " " << rules[k].t << " " << rules[k].p << " " << rules[k].c << "\n"; } fout.close(); if (buf) { *buf = std::string("Rule based classifier saved to ") + file + "\n"; } } void RuleBasedClassifier::train( double **items, int n, int d, int *labels, int size) { assert(items && labels && n > 0 && d > 0 && size > 0); dim = d; rules.clear(); // delete existing rules if any std::map<int /* class label */, int /* item count */> classSizes; std::map<int, std::pair<int /* DF */, double /* weight */>*> concepts; std::map<int, std::pair<int, double>*>::iterator iter; for (int j = 0; j < n; j++) { // build the concept table if (!classSizes[labels[j]]++) { // initialize an array of pairs concepts[labels[j]] = new std::pair<int, double>[d]; } for (int k = 0; k < d; k++) { if (items[j][k]) { // item j has feature k concepts[labels[j]][k].first++; concepts[labels[j]][k].second += items[j][k]; } } } for (int j = 0; j < d; j++) { // find the top 2 classes where feature j has largest DF int c1 = UNKNOWN_CLASS, c2 = UNKNOWN_CLASS; for (iter = concepts.begin(); iter != concepts.end(); iter++) { int count = (iter->second)[j].first; if (c1 == UNKNOWN_CLASS || count > concepts[c1][j].first) { c2 = c1; c1 = iter->first; } else if (c2 == UNKNOWN_CLASS || count > concepts[c2][j].first) { c2 = iter->first; } } Rule r; // create a new rule and add it to the database r.i = j, r.p = c1; r.t = concepts[c1][j].second / concepts[c1][j].first; // take average // r.c = (DF(t, c1) - DF(t, c2)) / DF(t, c1) * DF(t, c1) / size of c1 r.c = double(concepts[c1][j].first - concepts[c2][j].first) / classSizes[c1]; rules.push_back(r); } for (iter = concepts.begin(); iter != concepts.end(); iter++) { delete [] iter->second; // free up memory as soon as possible } std::sort(rules.begin(), rules.end()); // higher confidence rules first while (rules.size() > size) { rules.pop_back(); // if there are more rules than needed, remove them } if (buf) { *buf = "Rule based classifier trained\n"; *buf += " size of training corpus = " + numToStr(n) + "\n"; *buf += " dimension of data items = " + numToStr(d) + "\n"; *buf += " desired number of rules = " + numToStr(size) + "\n"; *buf += " number of rules created = " + numToStr(rules.size()) + "\n"; } } // RuleBasedClassifier::train int RuleBasedClassifier::label(double *item, int d, double *conf) const { assert(item && d > 0); if (dim != d) { if (buf) { if (!dim) { *buf = "Classifier not trained\n"; } else { *buf = "Target item incompatible\n"; *buf += " classifier dimension = " + numToStr(dim) + "\n"; *buf += " target item dimension = " + numToStr(d) + "\n"; } } if (conf && (*conf = 0)); return UNKNOWN_CLASS; } double confidence = .5; int prediction = UNKNOWN_CLASS; for (int k = 0; k < rules.size(); k++) { if (item[rules[k].i] >= rules[k].t) { // a rule is met if (prediction == UNKNOWN_CLASS) { // the first rule prediction = rules[k].p; confidence *= (1 + rules[k].c); if (buf) { *buf = "Target item passed rule test\n"; } } else { if (rules[k].p == prediction) { // a supporting rule confidence *= (1 + rules[k].c); } else if ((confidence *= (1 - rules[k].c)) < .5) { prediction = rules[k].p; // because of too many conflicts } if (buf) { *buf += " &&\n"; } } if (buf) { *buf += " feature index = " + numToStr(rules[k].i) + "\n"; *buf += " min threshold = " + numToStr(rules[k].t) + "\n"; } } } if (conf && (*conf = confidence > 1 ? 1 : confidence)); if (buf) { if (prediction == UNKNOWN_CLASS) { *buf = "target item failed all rule tests\n"; } else { *buf += " =>\n"; *buf += " prediction = " + numToStr(prediction) + "\n"; *buf += " confidence = " + numToStr(confidence) + "\n"; } } return prediction; } // RuleBasedClassifier::label //==== // case based classifier functions void CaseBasedClassifier::load(char *file) { clear(); std::ifstream fin(file); if (!(fin >> dim)) { // what else if (buf) { *buf = "Loading classifier failed\n"; } return; } for (Case c; fin >> c.c && (c.v = new double[dim]); cases.push_back(c)) { // note: memory leak will occur if the following fails for (int i = 0; i < dim; fin >> c.v[i++]); } fin.close(); if (cases.empty()) { // no records read if (buf) { *buf = "Loading classifier failed\n"; } dim = 0; return; } if (buf) { *buf = "Case based classifier loaded\n"; *buf += " number of classes = " + numToStr(cases.size()) + "\n"; *buf += " vector dimension = " + numToStr(dim) + "\n"; } } void CaseBasedClassifier::save(char *file) const { std::ofstream fout(file, std::ios::out | std::ios::trunc); fout << dim << "\n"; for (int i = 0; i < cases.size(); i++) { fout << cases[i].c; for (int j = 0; j < dim; fout << " " << cases[i].v[j++]); fout << "\n"; } fout.close(); if (buf) { *buf = std::string("Case based classifier saved to ") + file + "\n"; } } void CaseBasedClassifier::train(double **items, int n, int d, int *labels) { assert(items && labels && n > 0 && d > 0); clear(); dim = d; // (class label => (member count => position index)) as in perl syntax std::map<int, std::pair<int, int> > m; for (int i = 0; i < n; i++) { if (m[labels[i]].first++ == 0) { // first member of a new class Case c; c.c = labels[i]; c.v = new double[d]; for (int j = d; j-- > 0; c.v[j] = items[i][j]); m[c.c].second = cases.size(); cases.push_back(c); } else { // member of an existing class int j = m[labels[i]].second; for (int k = d; k-- > 0; cases[j].v[k] += items[i][k]); } } for (int i = cases.size() - 1; i >= 0; i--) { // calculate centroids for (int j = d; j-- > 0; cases[i].v[j] /= m[cases[i].c].first); } if (buf) { *buf = "Case based classifier trained\n"; *buf += " training corpus size = " + numToStr(n) + "\n"; *buf += " number of classes = " + numToStr(cases.size()) + "\n"; *buf += " vector dimension = " + numToStr(d) + "\n"; } } // CaseBasedClassifier::train int CaseBasedClassifier::label(double *item, int d) const { assert(item && d > 0); if (dim != d) { if (buf) { if (!dim) { *buf = "Classifier not trained\n"; } else { *buf = "Target item incompatible\n"; *buf += " classifier dimension = " + numToStr(dim) + "\n"; *buf += " target item dimension = " + numToStr(d) + "\n"; } } return UNKNOWN_CLASS; } double min; int pre = UNKNOWN_CLASS; if (buf) { *buf = "Searching for nearest centroid\n"; } for (int i = 0; i < cases.size(); i++) { double x = distance(item, cases[i].v, d); if (pre == UNKNOWN_CLASS || x < min) { min = x, pre = cases[i].c; if (buf) { *buf += " class " + numToStr(pre) + " has average "; *buf += "squared distance " + numToStr(min) + " to target\n"; } } } if (buf) { *buf += " Searching finished (prediction = " + numToStr(pre) + ")\n"; } return pre; } // CaseBasedClassifier::label} // namespace CNR
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -