glearner.h

From "A machine learning project by Mike Gashler that includes neural" · C header file · 81 lines

```cpp
/*
	Copyright (C) 2006, Mike Gashler

	This library is free software; you can redistribute it and/or
	modify it under the terms of the GNU Lesser General Public
	License as published by the Free Software Foundation; either
	version 2.1 of the License, or (at your option) any later version.

	see http://www.gnu.org/copyleft/lesser.html
*/

#ifndef __GLEARNER_H__
#define __GLEARNER_H__

class GArffRelation;
class GArffData;

class GSupervisedLearner
{
protected:
	GArffRelation* m_pRelation;

public:
	GSupervisedLearner(GArffRelation* pRelation);
	virtual ~GSupervisedLearner();

	// Returns the relation used to construct this learner
	GArffRelation* GetRelation() { return m_pRelation; }

	// Discard any training (but not any settings) so it can be trained again
	virtual void Reset() = 0;

	// Train with the provided data
	virtual void Train(GArffData* pData) = 0;

	// Evaluates the input values in the provided vector and
	// deduces the output values
	virtual void Eval(double* pVector) = 0;

	// Computes predictive accuracy (the ratio of samples that
	// are correctly classified to total samples). If there is
	// more than one output attribute, each output attribute
	// is evaluated independently. If there are continuous output
	// values, it uses 1/(1+(squared error)) as an estimate
	// so that a small squared error will be close to 1 (correct)
	// and a large squared error will be close to 0 (incorrect).
	double MeasurePredictiveAccuracy(GArffData* pData);

	// Computes the mean squared error. If there are multiple
	// output attributes, each one is considered independently.
	// If there are discrete output attributes, a correct
	// classification is considered to be a squared error of
	// 0 and an incorrect classification is a squared error of 1.
	double MeasureMeanSquaredError(GArffData* pData);

	// Perform n-fold cross validation on pData. If bRegression is true,
	// it will return the average mean squared error. If bRegression is
	// false, it will return the average predictive accuracy.
	double CrossValidate(GArffData* pData, int nFolds, bool bRegression);
};

// Always outputs the mean (for continuous values) and the most common
// class (for enumerated values)
class GBaselineLearner : public GSupervisedLearner
{
protected:
	double* m_pOutputs;

public:
	GBaselineLearner(GArffRelation* pRelation);
	virtual ~GBaselineLearner();

	virtual void Reset();
	virtual void Train(GArffData* pData);
	virtual void Eval(double* pVector);
};

#endif // __GLEARNER_H__
```
