glearner.h
来自「一个由Mike Gashler完成的机器学习方面的includes neural」· C头文件 代码 · 共 81 行
H
81 行
/*
	Copyright (C) 2006, Mike Gashler

	This library is free software; you can redistribute it and/or
	modify it under the terms of the GNU Lesser General Public
	License as published by the Free Software Foundation; either
	version 2.1 of the License, or (at your option) any later version.

	see http://www.gnu.org/copyleft/lesser.html
*/

// NOTE(review): identifiers containing a double underscore are reserved for
// the implementation. The guard name is kept unchanged so that any existing
// code testing this macro keeps working.
#ifndef __GLEARNER_H__
#define __GLEARNER_H__

// Only pointers to these types appear in this header, so forward
// declarations are sufficient.
class GArffRelation;
class GArffData;

// Abstract interface for a supervised learner. Subclasses implement Reset,
// Train and Eval; the accuracy/error measurements and cross-validation
// declared here are shared by all of them.
class GSupervisedLearner
{
protected:
	// The relation this learner was constructed with (see GetRelation).
	// NOTE(review): ownership of this pointer is not visible in this
	// header -- confirm against the implementation before deleting it.
	GArffRelation* m_pRelation;

public:
	// Constructs a learner for the given relation.
	GSupervisedLearner(GArffRelation* pRelation);
	virtual ~GSupervisedLearner();

	// Returns the relation used to construct this learner
	GArffRelation* GetRelation() { return m_pRelation; }

	// Discards any training (but not any settings) so the learner can be
	// trained again
	virtual void Reset() = 0;

	// Trains with the provided data
	virtual void Train(GArffData* pData) = 0;

	// Evaluates the input values in the provided vector and deduces the
	// output values.
	// NOTE(review): presumably the outputs are written back into pVector at
	// the output attribute positions -- confirm against an implementation.
	virtual void Eval(double* pVector) = 0;

	// Computes predictive accuracy (the ratio of samples that are correctly
	// classified to total samples). If there is more than one output
	// attribute, each output attribute is evaluated independently. If there
	// are continuous output values, it uses 1-1/(1+(squared error)) as an
	// estimate so that a small squared error will be close to 1 (correct)
	// and a large squared error will be close to 0 (incorrect).
	double MeasurePredictiveAccuracy(GArffData* pData);

	// Computes the mean squared error. If there are multiple output
	// attributes, each one is considered independently. If there are
	// discrete output attributes, a correct classification is considered to
	// be a squared error of 0 and an incorrect classification is a squared
	// error of 1.
	double MeasureMeanSquaredError(GArffData* pData);

	// Performs n-fold cross validation on pData. If bRegression is true, it
	// returns the average mean squared error. If bRegression is false, it
	// returns the average predictive accuracy.
	double CrossValidate(GArffData* pData, int nFolds, bool bRegression);
};

// Always outputs the mean (for continuous values) and the most common
// class (for enumerated values)
class GBaselineLearner : public GSupervisedLearner
{
protected:
	// NOTE(review): presumably holds the one baseline value per output
	// attribute that Eval emits; allocation and size are not visible in
	// this header -- confirm against the implementation.
	double* m_pOutputs;

public:
	// Constructs a baseline learner for the given relation.
	GBaselineLearner(GArffRelation* pRelation);
	virtual ~GBaselineLearner();

	// Implementations of the GSupervisedLearner interface.
	virtual void Reset();
	virtual void Train(GArffData* pData);
	virtual void Eval(double* pVector);
};

#endif // __GLEARNER_H__
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?