📄 rulelearner.cpp
字号:
while (nRules < width && ri != re) {
if (!inRules(wFilteredRules,*ri)) {
wFilteredRules->push_back(*ri);
nRules++;
}
ri++;
}
rules = wFilteredRules;
}
}
// Decides whether 'rule' should replace 'bestRule'.
//
// 'rule' wins outright (and the tie counter 'wins' is reset to 1) when it has
// a strictly higher quality or a strictly lower complexity than 'bestRule'.
// When neither holds and both complexities are equal, the counter is
// incremented and rgen.randbool(wins) decides whether 'rule' takes over
// (presumably with probability 1/wins, so that all tied rules are equally
// likely to be kept -- confirm against TRandomGenerator::randbool).
// In every other case 'bestRule' and 'wins' are left untouched.
inline void _selectBestRule(PRule &rule, PRule &bestRule, int &wins, TRandomGenerator &rgen)
{
  const bool higherQuality = rule->quality > bestRule->quality;
  const bool simpler = rule->complexity < bestRule->complexity;

  if (higherQuality || simpler) {
    bestRule = rule;
    wins = 1;
    return;
  }

  // Not strictly better: only an exact complexity tie is eligible for the draw.
  if (rule->complexity != bestRule->complexity)
    return;

  if (rgen.randbool(++wins))
    bestRule = rule;
}
// Builds the initial list of rules for the beam search.
//
// If 'baseRules' is given and non-empty, every base rule is copied, filtered
// on 'data' (unless it already carries examples), evaluated with 'evaluator'
// and appended to the result; 'bestRule' is updated to the best of them, with
// quality ties resolved by _selectBestRule.
// Otherwise the result holds a single new rule with an empty TFilter_values
// filter over data->domain and complexity 0; this rule also becomes
// 'bestRule'. Its quality is not computed here.
//
// Parameters:
//   data, weightID, targetClass - passed on to filterAndStore and the evaluator
//   baseRules                   - optional rules to start from (may be null)
//   evaluator                   - required; checked with checkProperty
//   apriori                     - passed on to the evaluator
//   bestRule                    - in/out: best rule seen so far
// Returns the newly created list of rules.
PRuleList TRuleBeamInitializer_Default::operator()(PExampleTable data, const int &weightID, const int &targetClass, PRuleList baseRules, PRuleEvaluator evaluator, PDistribution apriori, PRule &bestRule)
{
  checkProperty(evaluator);

  TRuleList *ruleList = mlnew TRuleList();
  PRuleList wruleList = ruleList;

  TRandomGenerator rgen(data->numberOfExamples());
  // Must start at 1: if the caller hands in a non-null 'bestRule' and the very
  // first base rule ties with it, _selectBestRule increments this counter
  // before anything else has assigned it.
  int wins = 1;

  if (baseRules && baseRules->size()) {
    PITERATE(TRuleList, ri, baseRules) {
      TRule *newRule = mlnew TRule((*ri).getReference(), true);
      PRule wNewRule = newRule;
      ruleList->push_back(wNewRule);

      if (!newRule->examples)
        newRule->filterAndStore(data,weightID,targetClass);
      newRule->quality = evaluator->call(wNewRule, data, weightID, targetClass, apriori);

      if (!bestRule || (newRule->quality > bestRule->quality)) {
        bestRule = wNewRule;
        wins = 1;
      }
      else if (newRule->quality == bestRule->quality)
        _selectBestRule(wNewRule, bestRule, wins, rgen);
    }
  }
  else {
    // No base rules: start from a single rule without conditions.
    TRule *ubestRule = mlnew TRule();
    bestRule = ubestRule;
    ruleList->push_back(bestRule);
    ubestRule->filter = mlnew TFilter_values();
    ubestRule->filter->domain = data->domain;
    ubestRule->filterAndStore(data, weightID,targetClass);
    ubestRule->complexity = 0;
  }

  return wruleList;
}
// Produces all one-condition refinements of the rule 'wrule'.
//
// * Discrete attributes (varType == TValue::INTVAR) that do not yet appear in
//   the rule's conditions yield one new rule per attribute value that has a
//   positive count in the distribution of the rule's examples.
// * Continuous attributes (varType == TValue::FLOATVAR) are discretized on the
//   rule's examples; each cut-off yields a "<=" and a ">" condition. Except for
//   the final "> last cut-off" rule, a continuous refinement is kept only if it
//   covers fewer cases than the parent rule.
//
// Every new rule is a copy of 'wrule' with complexity increased by one, the
// extra condition appended, examples re-filtered from the parent's examples,
// and parentRule set to 'wrule'.
//
// If no 'discretization' is set, a TEntropyDiscretization with forceAttribute
// enabled is installed, and it stays installed after the call.
//
// Raises an error if the rule's filter is not a TFilter_values or if the
// discretized variable does not expose a recognizable discretizer.
// The 'data' argument is not used.
PRuleList TRuleBeamRefiner_Selector::operator()(PRule wrule, PExampleTable data, const int &weightID, const int &targetClass)
{
  if (!discretization) {
    discretization = mlnew TEntropyDiscretization();
    dynamic_cast<TEntropyDiscretization *>(discretization.getUnwrappedPtr())->forceAttribute = true;
  }

  TRule &rule = wrule.getReference();
  TFilter_values *filter = wrule->filter.AS(TFilter_values);
  if (!filter)
    raiseError("a filter of type 'Filter_values' expected");

  TRuleList *ruleList = mlnew TRuleList;
  PRuleList wRuleList = ruleList;

  TDomainDistributions ddist(wrule->examples, wrule->weightID);

  const TVarList &attributes = rule.examples->domain->attributes.getReference();

  // Mark attribute positions that already occur in one of the rule's conditions.
  vector<bool> used(attributes.size(), false);
  PITERATE(TValueFilterList, vfi, filter->conditions)
    used[(*vfi)->position] = true;

  vector<bool>::const_iterator ui(used.begin());
  TDomainDistributions::const_iterator di(ddist.begin());
  TVarList::const_iterator vi(attributes.begin()), ve(attributes.end());
  int pos = 0;
  for(; vi != ve; vi++, ui++, pos++, di++) {
    if ((*vi)->varType == TValue::INTVAR) {
      if (!*ui) {
        TDiscDistribution *valueDist = (*di).AS(TDiscDistribution);
        // The distribution may hold fewer entries than the attribute has
        // values; values without an entry have no examples and are skipped.
        const int nCounts = static_cast<int>(valueDist->size());
        vector<float>::const_iterator idi(valueDist->begin());
        // The iterator has to advance together with v so that each value is
        // tested against its own count (it used to stay on value 0).
        for(int v = 0, e = (*vi)->noOfValues(); v != e && v < nCounts; v++, idi++)
          if (*idi>0) {
            TRule *newRule = mlnew TRule(rule, false);
            ruleList->push_back(newRule);
            newRule->complexity++;

            filter = newRule->filter.AS(TFilter_values);

            TValueFilter_discrete *newCondition = mlnew TValueFilter_discrete(pos, *vi, 0);
            filter->conditions->push_back(newCondition);

            TValue value = TValue(v);
            newCondition->values->push_back(value);

            newRule->filterAndStore(rule.examples, rule.weightID,targetClass);
            newRule->parentRule = wrule;
          }
      }
    }
    else if (((*vi)->varType == TValue::FLOATVAR)) {
      // 'discretization' is always set at this point (see the top of the function).
      PVariable discretized;
      try {
        discretized = discretization->call(rule.examples, *vi, weightID);
      } catch(...) {
        // Deliberate best effort: an attribute that cannot be discretized
        // contributes no refinements.
        continue;
      }

      TClassifierFromVar *cfv = discretized->getValueFrom.AS(TClassifierFromVar);
      TDiscretizer *discretizer = cfv ? cfv->transformer.AS(TDiscretizer) : NULL;
      if (!discretizer)
        raiseError("invalid or unrecognized discretizer");

      vector<float> cutoffs;
      discretizer->getCutoffs(cutoffs);
      if (cutoffs.size()) {
        // attribute <= first cut-off
        TRule *newRule;
        newRule = mlnew TRule(rule, false);
        PRule wnewRule = newRule;
        newRule->complexity++;
        newRule->parentRule = wrule;

        newRule->filter.AS(TFilter_values)->conditions->push_back(mlnew TValueFilter_continuous(pos, TValueFilter_continuous::LessEqual, cutoffs.front(), 0, 0));
        newRule->filterAndStore(rule.examples, rule.weightID,targetClass);
        if (wrule->classDistribution->cases > wnewRule->classDistribution->cases)
          ruleList->push_back(newRule);

        for(vector<float>::const_iterator ci(cutoffs.begin()), ce(cutoffs.end()-1); ci != ce; ci++) {
          // attribute > current cut-off
          newRule = mlnew TRule(rule, false);
          wnewRule = newRule;
          newRule->complexity++;
          newRule->parentRule = wrule;
          filter = newRule->filter.AS(TFilter_values);
          filter->conditions->push_back(mlnew TValueFilter_continuous(pos, TValueFilter_continuous::Greater, *ci, 0, 0));
          newRule->filterAndStore(rule.examples, rule.weightID,targetClass);
          if (wrule->classDistribution->cases > wnewRule->classDistribution->cases)
            ruleList->push_back(newRule);

          // attribute <= next cut-off
          newRule = mlnew TRule(rule, false);
          wnewRule = newRule;
          newRule->complexity++;
          newRule->parentRule = wrule;
          filter = newRule->filter.AS(TFilter_values);
          filter->conditions->push_back(mlnew TValueFilter_continuous(pos, TValueFilter_continuous::LessEqual, *(ci+1), 0, 0));
          newRule->filterAndStore(rule.examples, rule.weightID,targetClass);
          if (wrule->classDistribution->cases > wnewRule->classDistribution->cases)
            ruleList->push_back(newRule);
        }

        // attribute > last cut-off; added without the coverage check
        newRule = mlnew TRule(rule, false);
        ruleList->push_back(newRule);
        newRule->complexity++;

        newRule->filter.AS(TFilter_values)->conditions->push_back(mlnew TValueFilter_continuous(pos, TValueFilter_continuous::Greater, cutoffs.back(), 0, 0));
        newRule->filterAndStore(rule.examples, rule.weightID,targetClass);
        newRule->parentRule = wrule;
      }
    }
  }

  return wRuleList;
}
// Takes every existing rule as a candidate: returns a copy of the list
// 'existingRules' and then empties 'existingRules' itself.
// The example table and weight ID arguments are ignored.
PRuleList TRuleBeamCandidateSelector_TakeAll::operator()(PRuleList &existingRules, PExampleTable, const int &)
{
  TRuleList &beam = existingRules.getReference();

  // Copy first, then empty the source list.
  PRuleList snapshot = mlnew TRuleList(beam);
  beam.erase(beam.begin(), beam.end());

  return snapshot;
}
// Beam search for the best rule on 'data'.
//
// Components that are not set (initializer, candidateSelector, refiner,
// evaluator, ruleFilter) are replaced by defaults for the duration of the call
// and reset to empty before returning. 'validator' and 'ruleStoppingValidator'
// are optional and are simply not applied when unset.
//
// Procedure:
//   1. The initializer produces the starting rule list and the initial best rule.
//   2. Rules lacking examples or having quality ILLEGAL_FLOAT are filtered and
//      evaluated.
//   3. While the rule list is not empty: the candidate selector picks
//      candidates, each candidate is refined, every refinement is evaluated,
//      refinements accepted by ruleStoppingValidator are put back on the list
//      and may replace the best rule if their quality is at least as high and
//      they pass 'validator'; finally ruleFilter prunes the list.
//
// Returns the best rule found.
PRule TRuleBeamFinder::operator()(PExampleTable data, const int &weightID, const int &targetClass, PRuleList baseRules)
{
  // set default values if value not set
  bool tempInitializer = !initializer;
  if (tempInitializer)
    initializer = mlnew TRuleBeamInitializer_Default;
  bool tempCandidateSelector = !candidateSelector;
  if (tempCandidateSelector)
    candidateSelector = mlnew TRuleBeamCandidateSelector_TakeAll;
  bool tempRefiner = !refiner;
  if (tempRefiner)
    refiner = mlnew TRuleBeamRefiner_Selector;
/*  bool tempValidator = !validator;
  if (tempValidator)
    validator = mlnew TRuleValidator_LRS((float)0.01);
  bool tempRuleStoppingValidator = !ruleStoppingValidator;
  if (tempRuleStoppingValidator)
    ruleStoppingValidator = mlnew TRuleValidator_LRS((float)0.05); */
  bool tempEvaluator = !evaluator;
  if (tempEvaluator)
    evaluator = mlnew TRuleEvaluator_Entropy;
  bool tempRuleFilter = !ruleFilter;
  if (tempRuleFilter)
    ruleFilter = mlnew TRuleBeamFilter_Width;

  checkProperty(initializer);
  checkProperty(candidateSelector);
  checkProperty(refiner);
  checkProperty(evaluator);
  checkProperty(ruleFilter);

  PDistribution apriori = getClassDistribution(data, weightID);

  TRandomGenerator rgen(data->numberOfExamples());
  int wins = 1;

  PRule bestRule;
  PRuleList ruleList = initializer->call(data, weightID, targetClass, baseRules, evaluator, apriori, bestRule);

  // Make sure every initial rule carries examples and a quality.
  {
    PITERATE(TRuleList, ri, ruleList) {
      if (!(*ri)->examples)
        (*ri)->filterAndStore(data, weightID,targetClass);
      if ((*ri)->quality == ILLEGAL_FLOAT)
        (*ri)->quality = evaluator->call(*ri, data, weightID, targetClass, apriori);
    }
  }

  // Same for the initial best rule, which need not be a member of ruleList.
  if (!bestRule->examples)
    bestRule->filterAndStore(data, weightID,targetClass);
  if (bestRule->quality == ILLEGAL_FLOAT)
    bestRule->quality = evaluator->call(bestRule, data, weightID, targetClass, apriori);

  while(ruleList->size()) {
    PRuleList candidateRules = candidateSelector->call(ruleList, data, weightID);
    PITERATE(TRuleList, ri, candidateRules) {
      PRuleList newRules = refiner->call(*ri, data, weightID, targetClass);
      PITERATE(TRuleList, ni, newRules) {
        (*ni)->quality = evaluator->call(*ni, data, weightID, targetClass, apriori);
        // The stopping validator compares the refinement with its parent
        // candidate (the parent's examples and class distribution).
        if (!ruleStoppingValidator || ruleStoppingValidator->call(*ni, (*ri)->examples, weightID, targetClass, (*ri)->classDistribution)) {
          ruleList->push_back(*ni);
          if ((*ni)->quality >= bestRule->quality && (!validator || validator->call(*ni, data, weightID, targetClass, apriori)))
            _selectBestRule(*ni, bestRule, wins, rgen);
        }
      }
    }
    ruleFilter->call(ruleList,data,weightID);
  }

  // set empty values if value was not set (used default)
  if (tempInitializer)
    initializer = PRuleBeamInitializer();
  if (tempCandidateSelector)
    candidateSelector = PRuleBeamCandidateSelector();
  if (tempRefiner)
    refiner = PRuleBeamRefiner();
/*  if (tempValidator)
    validator = PRuleValidator();
  if (tempRuleStoppingValidator)
    ruleStoppingValidator = PRuleValidator(); */
  if (tempEvaluator)
    evaluator = PRuleEvaluator();
  if (tempRuleFilter)
    ruleFilter = PRuleBeamFilter();

  return bestRule;
}
// Constructs a rule learner; the arguments initialize the members
// storeExamples (se), targetClass (tc) and baseRules (rl).
TRuleLearner::TRuleLearner(bool se, int tc, PRuleList rl)
: storeExamples(se),
targetClass(tc),
baseRules(rl)
{}
// Learns a classifier from 'gen' using the learner's own 'targetClass' and
// 'baseRules' members; forwards to the four-argument call.
PClassifier TRuleLearner::operator()(PExampleGenerator gen, const int &weightID)
{
  PClassifier learned = this->call(gen, weightID, targetClass, baseRules);
  return learned;
}
// Learns a list of rules from 'gen' and wraps it in a classifier.
//
// The loop repeats until 'dataStopping' (if set) reports that the remaining
// data should not be processed further:
//   - 'ruleFinder' finds a rule on the remaining data,
//   - 'ruleStopping' (if set) may reject the rule, which ends learning,
//   - 'coverAndRemove' produces the data (and weight ID) for the next round,
//   - the rule is appended to the result.
//
// Unset components get defaults for the duration of the call and are reset
// afterwards: dataStopping (only when ruleStopping is unset as well),
// ruleFinder and coverAndRemove. The classifier is built by
// 'classifierConstructor' or, if unset, by TRuleClassifierConstructor_firstRule.
//
// If 'progressCallback' is set it is called with 0.0 at the start, with the
// fraction of the initial weight that has been removed after every rule, and
// with 1.0 at the end. For targetClass == -1 the weight of all examples is
// used, otherwise the class distribution entry for 'targetClass'.
//
// Raises an error if the rule finder returns no rule.
PClassifier TRuleLearner::operator()(PExampleGenerator gen, const int &weightID, const int &targetClass, PRuleList baseRules)
{
  // Initialize default values if values not set
  bool tempDataStopping = !dataStopping && !ruleStopping;
  if (tempDataStopping)
    dataStopping = mlnew TRuleDataStoppingCriteria_NoPositives;

  bool tempRuleFinder = !ruleFinder;
  if (tempRuleFinder)
    ruleFinder = mlnew TRuleBeamFinder;

  bool tempCoverAndRemove = !coverAndRemove;
  if (tempCoverAndRemove)
    coverAndRemove = mlnew TRuleCovererAndRemover_Default;

  checkProperty(ruleFinder);
  checkProperty(coverAndRemove);

  TExampleTable *data = mlnew TExampleTable(gen);
  PExampleTable wdata = data;

  // Defensive only: a default dataStopping was installed above when both
  // criteria were missing.
  if (!dataStopping && !ruleStopping)
    raiseError("no stopping criteria; set 'dataStopping' and/or 'ruleStopping'");

  TRuleList *ruleList = mlnew TRuleList;
  PRuleList wruleList = ruleList;

  int currWeightID = weightID;

  float beginwe=0.0, currentwe;
  if (progressCallback) {
    if (targetClass==-1)
      beginwe = wdata->weightOfExamples(weightID);
    else {
      PDistribution classDist = getClassDistribution(wdata, weightID);
      TDiscDistribution *ddist = classDist.AS(TDiscDistribution);
      beginwe = ddist->atint(targetClass);
    }
    progressCallback->call(0.0);
  }

  while (!dataStopping || !dataStopping->call(wdata, currWeightID, targetClass)) {
    PRule rule = ruleFinder->call(wdata, currWeightID, targetClass, baseRules);
    if (!rule)
      raiseError("'ruleFinder' didn't return a rule");

    if (ruleStopping && ruleStopping->call(ruleList, rule, wdata, currWeightID))
      break;

    // coverAndRemove may replace the weight ID through its fourth argument.
    wdata = coverAndRemove->call(rule, wdata, currWeightID, currWeightID, targetClass);
    ruleList->push_back(rule);

    if (progressCallback) {
      // Both branches measure the remaining weight with the current weight ID.
      if (targetClass==-1)
        currentwe = wdata->weightOfExamples(currWeightID);
      else {
        PDistribution classDist = getClassDistribution(wdata, currWeightID);
        TDiscDistribution *ddist = classDist.AS(TDiscDistribution);
        currentwe = ddist->atint(targetClass);
      }
      // With no initial weight there is no meaningful fraction to report.
      if (beginwe > 0.0)
        progressCallback->call(1-currentwe/beginwe);
    }
  }

  if (progressCallback)
    progressCallback->call(1.0);

  // Restore values
  if (tempDataStopping)
    dataStopping = PRuleDataStoppingCriteria();
  if (tempRuleFinder)
    ruleFinder = PRuleFinder();
  if (tempCoverAndRemove)
    coverAndRemove = PRuleCovererAndRemover();

  PRuleClassifierConstructor clConstructor =
    classifierConstructor ? classifierConstructor :
    PRuleClassifierConstructor(mlnew TRuleClassifierConstructor_firstRule());
  return clConstructor->call(ruleList, gen, weightID);
}
// Returns true when the class distribution of 'data' has nothing left to
// cover: for targetClass >= 0 the entry for that class is zero, otherwise the
// distribution's 'abs' member is zero.
bool TRuleDataStoppingCriteria_NoPositives::operator()(PExampleTable data, const int &weightID, const int &targetClass) const
{
  PDistribution classDist = getClassDistribution(data, weightID);
  TDiscDistribution *ddist = classDist.AS(TDiscDistribution);

  float remaining;
  if (targetClass >= 0)
    remaining = ddist->atint(targetClass);
  else
    remaining = ddist->abs;

  return remaining == 0.0;
}
// Returns true (stop learning) when the rule predicts its default class with
// a lower relative frequency than that class has in 'data'.
//
// The rule's relative frequency is taken from the default distribution of its
// TDefaultClassifier, the a-priori one from the class distribution of 'data'.
// Returns false whenever the comparison cannot be made: no rule, no
// classifier, a classifier that is not a TDefaultClassifier, a distribution
// that is not discrete, or a default value outside the rule's distribution.
// The 'ruleList' argument is not used.
bool TRuleStoppingCriteria_NegativeDistribution::operator()(PRuleList ruleList, PRule rule, PExampleTable data, const int &weightID) const
{
  if (!rule || !rule->classifier)
    return false;

  PDistribution aprioriDist = getClassDistribution(data, weightID);
  TDiscDistribution *apriori = aprioriDist.AS(TDiscDistribution);

  const TDefaultClassifier *clsf = rule->classifier.AS(TDefaultClassifier);
  if (!clsf)
    return false;

  const TDiscDistribution *dist = dynamic_cast<const TDiscDistribution *>(clsf->defaultDistribution.getUnwrappedPtr());
  // Either cast yields NULL for a non-discrete distribution; there is
  // nothing to compare in that case.
  if (!dist || !apriori)
    return false;

  const int classVal = clsf->defaultVal.intV;
  if (classVal<0 || classVal>=static_cast<int>(dist->size()))
    return false;

  const float acc = dist->atint(classVal)/dist->abs;
  const float accApriori = apriori->atint(classVal)/apriori->abs;
  return accApriori>acc;
}
// Returns a new table with the examples of 'data' that remain after removing
// those covered by 'rule'.
//
// For targetClass < 0 every example matched by the rule's filter is dropped.
// Otherwise only matched examples whose class equals 'targetClass' are
// dropped; matched examples of other classes are kept.
// 'newWeight' is set to 'weightID'.
PExampleTable TRuleCovererAndRemover_Default::operator()(PRule rule, PExampleTable data, const int &weightID, int &newWeight, const int &targetClass) const
{
  TExampleTable *remaining = mlnew TExampleTable(data, 1);
  PExampleGenerator wremaining = remaining;

  TFilter &covers = rule->filter.getReference();
  const bool restrictToTarget = !(targetClass < 0);

  PEITERATE(ei, data) {
    // Keep what the rule does not cover; with a target class, also keep
    // covered examples that belong to a different class.
    if (!covers(*ei) || (restrictToTarget && (*ei).getClass().intV != targetClass))
      remaining->addExample(*ei);
  }

  newWeight = weightID;
  return wremaining;
}
// classifiers
// Creates a TRuleClassifier_firstRule from the given rules, example table and
// weight ID.
PRuleClassifier TRuleClassifierConstructor_firstRule::operator ()(PRuleList rules, PExampleTable table, const int &weightID)
{
  PRuleClassifier classifier = mlnew TRuleClassifier_firstRule(rules, table, weightID);
  return classifier;
}
// Constructs a rule classifier; the arguments initialize the members
// rules, examples and weightID.
TRuleClassifier::TRuleClassifier(PRuleList arules, PExampleTable anexamples, const int &aweightID)
: rules(arules),
examples(anexamples),
weightID(aweightID)
{}
// Default constructor; initializes no members explicitly.
TRuleClassifier::TRuleClassifier()
{}
// Constructs a first-rule classifier: passes the rules, examples and weight ID
// to the TRuleClassifier base and stores the class distribution of the
// examples in 'prior'.
TRuleClassifier_firstRule::TRuleClassifier_firstRule(PRuleList arules, PExampleTable anexamples, const int &aweightID)
: TRuleClassifier(arules, anexamples, aweightID)
{
// 'prior' is what classDistribution() returns when no rule matches.
prior = getClassDistribution(examples, weightID);
}
// Default constructor; only default-constructs the TRuleClassifier base, so
// 'prior' is not computed here.
TRuleClassifier_firstRule::TRuleClassifier_firstRule()
: TRuleClassifier()
{}
// Returns the class distribution of the first rule in 'rules' that matches
// 'ex'; if no rule matches, returns 'prior'.
// Both 'rules' and 'prior' must be set (verified with checkProperty).
PDistribution TRuleClassifier_firstRule::classDistribution(const TExample &ex)
{
  checkProperty(rules);
  checkProperty(prior);

  PDistribution answer = prior;
  PITERATE(TRuleList, ri, rules) {
    if ((*ri)->call(ex)) {
      // Rules are tried in list order; the first match decides.
      answer = (*ri)->classDistribution;
      break;
    }
  }
  return answer;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -