📄 lookup.cpp
字号:
const TVarList &attributes = ogen->domain->attributes.getReference();
const int nattrs = attributes.size();
PVariable classVar = ogen->domain->classVar;
bool alreadyWarned = false;
// we shall use ClassifierByLookupTable if the number of attributes
// is <= 3 and the are all discrete
if (allowFastLookups && (nattrs <= 3)) {
TVarList::const_iterator vi(attributes.begin()), ve(attributes.end());
for(; (vi!=ve) && ((*vi)->varType == TValue::INTVAR); vi++);
if (vi==ve) {
if (!nattrs) {
PDistribution classDist = getClassDistribution(ogen, weightID);
return mlnew TDefaultClassifier(classVar, classDist->highestProbValue(), classDist);
}
else if (nattrs == 1) {
TClassifierByLookupTable1 *cblt = mlnew TClassifierByLookupTable1(classVar, attributes[0]);
PClassifier wcblt = cblt;
TDiscDistribution valDist(attributes[0]);
TDiscDistribution unkDist(attributes[0]);
PEITERATE(ei, ogen) {
if ((*ei).getClass().isSpecial())
UNKNOWN_CLASS_WARNING
else {
const TValue val = (*ei)[0];
const float weight = WEIGHT(*ei);
if (val.isSpecial()) {
if (unknownsHandling)
unkDist.addint((*ei)[1], weight);
}
else {
cblt->distributions->at(val.intV)->addint((*ei)[1], weight);
valDist.addint(val.intV, weight);
}
}
}
if (unkDist.abs && valDist.abs) {
TDistributionList::iterator dli(cblt->distributions->begin());
TDiscDistribution::const_iterator vdi(valDist.begin()), vde(valDist.end());
for(; vdi!=vde; (dynamic_cast<TDiscDistribution &>((*dli++).getReference())).adddist(unkDist, *vdi++));
}
cblt->replaceDKs(valDist);
cblt->valuesFromDistributions();
return wcblt;
}
else {
TClassifierByLookupTable *cblt =
nattrs == 2 ? (TClassifierByLookupTable *)mlnew TClassifierByLookupTable2(classVar, attributes[0], attributes[1])
: (TClassifierByLookupTable *)mlnew TClassifierByLookupTable3(classVar, attributes[0], attributes[1], attributes[2]);
PClassifier wcblt = cblt;
TExampleIterator ei(ogen->begin());
for(; ei; ++ei) {
if ((*ei).getClass().isSpecial())
UNKNOWN_CLASS_WARNING
else {
const int idx = cblt->getIndex(*ei);
if (idx<0) {
raiseWarning("unknown attribute values detected: constructing ClassifierByExampleTable instead of LookupClassifier");
break;
}
cblt->distributions->at(idx)->addint((*ei)[nattrs].intV, WEIGHT(*ei));
}
}
if (!ei) { // have we finished prematurely due to unknown values?
if (nattrs==2)
dynamic_cast<TClassifierByLookupTable2 *>(cblt)->replaceDKs(ogen);
else
dynamic_cast<TClassifierByLookupTable3 *>(cblt)->replaceDKs(ogen);
cblt->valuesFromDistributions();
return wcblt;
}
// else fallthrough
}
}
}
PExampleGenerator gen = fixedExamples(ogen);
TExampleTable examplePtrs(gen, false);
examplePtrs.sort();
TExampleTable unknowns(gen->domain);
TEFMDataDescription *efmdata = mlnew TEFMDataDescription(gen->domain, mlnew TDomainDistributions(gen), weightID, getMetaID());
PEFMDataDescription wefmdata = efmdata;
TClassifierByExampleTable *classifier = mlnew TClassifierByExampleTable(examplePtrs.domain);
PClassifier wclassifier = PClassifier(classifier);
classifier->dataDescription = wefmdata;
TFilter_hasSpecial hasSpecial;
for (TExampleIterator bi(examplePtrs.begin()), bbi(bi); bi; bi = bbi) {
PDistribution classDist = TDistribution::create(examplePtrs.domain->classVar);
TDistribution &tcv = classDist.getReference();
if ((*bbi).getClass().isSpecial()) {
UNKNOWN_CLASS_WARNING
continue;
}
int diff;
do {
tcv.add((*bbi).getClass(), WEIGHT2(*bbi, weightID));
if (!++bbi)
break;
TExample::iterator bii((*bi).begin()), bbii((*bbi).begin());
for(diff = nattrs; diff && (*(bii++)==*(bbii++)); diff--);
} while (!diff);
bool hasUnknowns = hasSpecial(*bi);
if (classDist->abs == 0.0 || hasUnknowns && !unknownsHandling)
continue;
TExample ex = *bi;
ex.setClass(classVar->DK());
ex.getClass().svalV = classDist;
if (hasUnknowns) {
if (unknownsHandling == UnknownsDistribute) {
unknowns.addExample(ex);
dynamic_cast<TDistribution &>(ex.getClass().svalV.getReference()) *= efmdata->getExampleWeight(ex);
continue;
}
else
classifier->containsUnknowns = true;
}
classifier->sortedExamples->addExample(ex);
}
if (unknowns.size()) {
const int missWeight = getMetaID();
efmdata = mlnew TEFMDataDescription(gen->domain, mlnew TDomainDistributions(gen), weightID, missWeight);
wefmdata = efmdata;
TExampleTable additionalExamples(gen->domain);
EITERATE(ui, unknowns) {
TExampleForMissing imputedExample(*ui, wefmdata);
imputedExample.resetExample();
do {
additionalExamples.addExample(imputedExample);
TExample &justAdded = additionalExamples.back();
dynamic_cast<TDistribution &>(justAdded.getClass().svalV.getReference()) *= imputedExample.getMeta(missWeight).floatV;
justAdded.removeMeta(missWeight);
}
while (imputedExample.nextExample());
}
PExampleGenerator wadde = PExampleGenerator(additionalExamples);
TExampleTable sortedAdd(wadde, false);
sortedAdd.sort();
PExampleGenerator oldSortedExamples = classifier->sortedExamples;
TExampleTable *sortedExamples = mlnew TExampleTable(gen->domain);
classifier->sortedExamples = sortedExamples;
for(TExampleIterator osi(oldSortedExamples->begin()), nsi(sortedAdd.begin()); osi && nsi; ) {
int cmp = (*osi).compare(*nsi);
if (cmp <= 0) {
sortedExamples->addExample(*osi);
++osi;
}
else {
TExample *lastAdded = sortedExamples->size() ? &sortedExamples->back() : NULL;
if (lastAdded && !(*nsi).compare(*lastAdded))
dynamic_cast<TDistribution &>(lastAdded->getClass().svalV.getReference()) += dynamic_cast<TDistribution &>((*nsi).getClass().svalV.getReference());
else
sortedExamples->addExample(*nsi);
++nsi;
}
}
}
if (learnerForUnknown)
classifier->classifierForUnknown = learnerForUnknown->operator()(ogen, weightID);
return wclassifier;
}
/* Constructs an empty example-table classifier over the given domain.

   dom  the domain passed to TClassifierFD and used for the (initially empty)
        table of sorted examples.

   containsUnknowns is explicitly set to false: classDistributionLow() reads
   the flag to choose between the exhaustive scan and the binary search, so it
   must not be left indeterminate. */
TClassifierByExampleTable::TClassifierByExampleTable(PDomain dom)
: TClassifierFD(dom),
  sortedExamples(mlnew TExampleTable(dom)),
  containsUnknowns(false)
{}
/* Constructs the classifier from an example generator.

   gen  source of the examples; its domain is passed to TClassifierFD and the
        examples are stored in a new TExampleTable.
   unk  classifier stored as classifierForUnknown.

   After the table is built, the examples are scanned until the first one for
   which TFilter_hasSpecial returns true; containsUnknowns records whether
   such an example was found. */
TClassifierByExampleTable::TClassifierByExampleTable(PExampleGenerator gen, PClassifier unk)
: TClassifierFD(gen->domain),
  sortedExamples(mlnew TExampleTable(gen)),
  containsUnknowns(false),
  classifierForUnknown(unk)
{
  TFilter_hasSpecial specialFilter;
  TExampleIterator it(sortedExamples->begin());

  // Stop at the first example flagged by the filter.
  while (it && !containsUnknowns) {
    containsUnknowns = specialFilter(*it);
    ++it;
  }
}
/* Looks up the class distribution stored for an example.

   exam  the example to look up; it is first converted to this classifier's
         domain.

   Returns
     - if the table contains unknowns or the converted example has special
       values: the normalized sum of the distributions (or class values) of
       all table examples compatible with the converted example, or a null
       PDistribution if that sum is empty;
     - otherwise: a copy of the distribution attached to the table example
       that compares equal to the converted example (found by binary search),
       or a one-value distribution built from its class value; a null
       PDistribution if no table example compares equal.
   The binary-search result is returned as stored, without normalization. */
PDistribution TClassifierByExampleTable::classDistributionLow(const TExample &exam)
{
  TExample convertedEx(domain, exam);

  if (containsUnknowns || TFilter_hasSpecial()(convertedEx)) {
    // Weighted summation is used only when a data description is available
    // and the class variable's type is TValue::INTVAR; that is the case in
    // which distsum is treated as a TDiscDistribution below.
    bool weightUnknowns = dataDescription && (classVar->varType == TValue::INTVAR);

    TDistribution *distsum = TDistribution::create(classVar);
    PDistribution res = distsum;   // wrapper that owns distsum

    PEITERATE(ei, sortedExamples)
      if (convertedEx.compatible(*ei)) {
        TValue &classVal = (*ei).getClass();
        TDistribution *dist = classVal.svalV.AS(TDistribution);
        if (dist) {
          if (weightUnknowns)
            ((TDiscDistribution *)(distsum))->adddist(*dist, dataDescription->getExampleWeight(*ei));
          else
            *distsum += *dist;
        }
        else if (!classVal.isSpecial())
          distsum->addint(classVal.intV);
      }

    if (distsum->abs) {
      distsum->normalize();
      return res;
    }
    else
      return PDistribution();
  }

  // No unknowns anywhere: binary search over the half-open range [L, H).
  int L = 0, H = sortedExamples->size();
  while(L<H) {
    const int M = L + (H-L)/2;   // same midpoint as (L+H)/2 without the overflow risk
    int cmp = convertedEx.compare(sortedExamples->at(M));
    if (cmp > 0)
      L = M+1;
    else if (cmp < 0)
      // The sought example sorts before M. This test used to repeat `cmp > 0`,
      // so a smaller example was wrongly treated as an exact match.
      H = M;
    else {
      TValue &classVal = sortedExamples->at(M).getClass();
      TDistribution *dist = classVal.svalV.AS(TDistribution);
      if (dist)
        // Return a copy so that callers cannot modify the stored distribution.
        return CLONE(TDistribution, dist);
      else {
        dist = TDistribution::create(classVar);
        dist->add(classVal);
        return dist;
      }
    }
  }

  return PDistribution();
}
/* Predicts the class value for an example.

   exam  the example to classify.

   Returns the highest-probability value of the distribution found by
   classDistributionLow(). If no distribution is found, the prediction is
   delegated to classifierForUnknown when one is set; otherwise the class
   variable's DK() value is returned. */
TValue TClassifierByExampleTable::operator()(const TExample &exam)
{
  PDistribution lookedUp = classDistributionLow(exam);

  if (lookedUp)
    return lookedUp->highestProbValue(exam);

  if (classifierForUnknown)
    return classifierForUnknown->operator()(exam);

  return domain->classVar->DK();
}
/* Returns the normalized class distribution for an example.

   exam  the example to classify.

   Returns, in order of preference:
     - a normalized copy of the distribution found by classDistributionLow();
     - the distribution given by classifierForUnknown, if one is set;
     - a freshly created distribution over the class variable, normalized.
       (NOTE(review): what normalize() produces for an empty distribution is
       defined by TDistribution, not visible here.)

   The last case used to build and normalize the distribution and then return
   a null PDistribution, discarding the work; it now returns the distribution
   it built. */
PDistribution TClassifierByExampleTable::classDistribution(const TExample &exam)
{
  PDistribution dval = classDistributionLow(exam);
  if (dval) {
    // Normalize and return the copy (it used to be created and then ignored),
    // in the same way as predictionAndDistribution() does.
    PDistribution dd = CLONE(TDistribution, dval);
    dd->normalize();
    return dd;
  }

  if (classifierForUnknown)
    return classifierForUnknown->classDistribution(exam);

  dval = TDistribution::create(domain->classVar);
  dval->normalize();
  return dval;
}
/* Computes the predicted class value and the class distribution in one call.

   exam  the example to classify.
   pred  (out) the predicted value.
   dist  (out) the class distribution.

   If classDistributionLow() finds a distribution, pred is its
   highest-probability value and dist a normalized copy of it. Otherwise the
   call is delegated to classifierForUnknown when one is set. With neither
   available, pred is the class variable's DK() value and dist is a freshly
   created, normalized distribution over the class variable.

   The last case used to store the new distribution in a local variable only,
   leaving the caller's dist untouched; it is now written to dist. */
void TClassifierByExampleTable::predictionAndDistribution(const TExample &exam, TValue &pred, PDistribution &dist)
{
  PDistribution dval = classDistributionLow(exam);
  if (dval) {
    pred = dval->highestProbValue(exam);
    dist = CLONE(TDistribution, dval);
    dist->normalize();
  }
  else if (classifierForUnknown)
    classifierForUnknown->predictionAndDistribution(exam, pred, dist);
  else {
    pred = domain->classVar->DK();
    dist = TDistribution::create(domain->classVar);
    dist->normalize();
  }
}
void TClassifierByExampleTable::afterSet(const char *name)
{
if (!strcmp(name, "sortedExamples")) {
domain = sortedExamples->domain;
classVar = sortedExamples->domain->classVar;
}
TClassifierFD::afterSet(name);
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -