// ann1dn.cpp
if (vldrec.entries.size()) {
validate(&vldrec, TH, &acur, &pacur);
wprintf(L" \n validation set: %d %d\n sensitivity: %.2f\n specificity: %.2f\n +predictive: %.2f\n -predictive: %.2f\n accuracy: %.2f\n", vldrec.indices[0].size(), vldrec.indices[1].size(), pacur.se, pacur.sp, pacur.pp, pacur.np, pacur.ac);
}
if (tstrec.entries.size()) {
validate(&tstrec, TH, &acur, &pacur);
wprintf(L" \n test set: %d %d\n sensitivity: %.2f\n specificity: %.2f\n +predictive: %.2f\n -predictive: %.2f\n accuracy: %.2f\n", tstrec.indices[0].size(), tstrec.indices[1].size(), pacur.se, pacur.sp, pacur.pp, pacur.np, pacur.ac);
}
} else
wprintf(L"failed to load maxacur.nn for classification\n");
ann = new ANNetwork(argv[2]); //validate(...) uses *ann network
if (!ann->status()) { //classification results for trained network
wprintf(L"\nclassification results: %s\n", argv[2]);
validate(&trnrec, TH, &acur, &pacur);
wprintf(L" \n train set: %d %d\n sensitivity: %.2f\n specificity: %.2f\n +predictive: %.2f\n -predictive: %.2f\n accuracy: %.2f\n", trnrec.indices[0].size(), trnrec.indices[1].size(), pacur.se, pacur.sp, pacur.pp, pacur.np, pacur.ac);
if (vldrec.entries.size()) {
validate(&vldrec, TH, &acur, &pacur);
wprintf(L" \n validation set: %d %d\n sensitivity: %.2f\n specificity: %.2f\n +predictive: %.2f\n -predictive: %.2f\n accuracy: %.2f\n", vldrec.indices[0].size(), vldrec.indices[1].size(), pacur.se, pacur.sp, pacur.pp, pacur.np, pacur.ac);
}
if (tstrec.entries.size()) {
validate(&tstrec, TH, &acur, &pacur);
wprintf(L" \n test set: %d %d\n sensitivity: %.2f\n specificity: %.2f\n +predictive: %.2f\n -predictive: %.2f\n accuracy: %.2f\n", tstrec.indices[0].size(), tstrec.indices[1].size(), pacur.se, pacur.sp, pacur.pp, pacur.np, pacur.ac);
}
} else
wprintf(L"failed to load %s for classification\n", argv[2]);
}
////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////
// 2-class version only; the N-class variant is in the SOM project
// Carve a validation subset out of the training set (2-class version).
// p is the percentage of each training class to move into vld. Moved
// entries are zeroed inside trn (downstream code skips zero entries)
// and their indices are erased from trn's per-class index lists.
void set_validation(PREC vld, PREC trn, float p)
{
    const float frac = p / 100.0f;
    const int c1 = int(frac * (float)trn->indices[0].size());
    const int c2 = int(frac * (float)trn->indices[1].size());
    wprintf(L" validaton size: %d %d\n", c1, c2);
    if (c1 < 1 || c2 < 1) {
        wprintf(L" validaton is not set, one of the vld class of 0 lenght\n");
        return;
    }
    vld->entries.resize(c1 + c2);
    vld->clsnum.push_back(trn->clsnum[0]);
    vld->clsnum.push_back(trn->clsnum[1]);
    //shuffle so the first c1/c2 indices form a random sample of each class
    random_shuffle(trn->indices[0].begin(), trn->indices[0].end());
    random_shuffle(trn->indices[1].begin(), trn->indices[1].end());
    //move the first counts[cls] entries of each class into vld
    const int counts[2] = { c1, c2 };
    int dst = 0; //running position inside vld->entries
    for (int cls = 0; cls < 2; cls++) {
        vld->indices.push_back(vector<int>(counts[cls]));
        for (int i = 0; i < counts[cls]; i++, dst++) {
            const int src = trn->indices[cls].at(i);
            vld->indices[cls].at(i) = dst;
            vld->entries[dst] = trn->entries[ src ];
            trn->entries[ src ] = 0; //leave a hole; validate()/dump_sets() skip zeros
        }
        trn->indices[cls].erase(trn->indices[cls].begin(),
                                trn->indices[cls].begin() + counts[cls]);
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////
/*
  Run every entry of rec through the global *ann network and report results.
  acur  - single fitness value, selected by the global validation_type
  pacur - classification results Se,Sp,+P,-P,Ac in percent
*/
void validate(PREC rec, float TH, float *acur, PACUR pacur)
{
    int TP = 0, FN = 0, TN = 0, FP = 0;
    float ovec[1] = {0.0f};
    float mse = 0.0f, b;
    int size = (int)rec->entries.size();
    ////////////////////TESTING////////////////////////////////////////////////////////////////
    for (int i = 0; i < (int)rec->entries.size(); i++) {
        //train set may contain zeroed entries after set_validation()
        if (rec->entries[i] == 0) {
            size--;
            continue;
        }
        float *ivec = rec->entries[i]->vec;
        ann->classify(ivec, ovec);
        const int predicted = (ovec[0] > TH) ? 1 : 2; //output above TH -> class 1
        const int actual = rec->entries[i]->cls;      //1 or 2, 0 if no class info
        if (actual == 1) {
            mse += (0.9f - ovec[0]) * (0.9f - ovec[0]); //class-1 target output is 0.9
            if (predicted == 1)
                TP++;
            else
                FN++; //ill defined as healthy
        } else if (actual == 2) {
            mse += (0.1f - ovec[0]) * (0.1f - ovec[0]); //class-2 target output is 0.1
            if (predicted == 2)
                TN++;
            else
                FP++; //healthy defined as ill
        }
    }
    mse /= (float)size; //size excludes the zeroed entries
    ///////////////////////////////////////////////////////////////////////////////////////////
    //each rate stays 0.0 when its numerator is 0 (avoids 0/0)
    const float se = TP ? float(TP) / float(TP + FN) : 0.0f;
    const float sp = TN ? float(TN) / float(TN + FP) : 0.0f;
    const float pp = TP ? float(TP) / float(TP + FP) : 0.0f;
    const float np = TN ? float(TN) / float(TN + FN) : 0.0f;
    const float ac = (TP || FP || FN || TN) ? float(TP + TN) / float(TP + FN + TN + FP) : 0.0f;
    pacur->se = 100.0f * se;
    pacur->sp = 100.0f * sp;
    pacur->pp = 100.0f * pp;
    pacur->np = 100.0f * np;
    pacur->ac = 100.0f * ac;
    switch (validation_type) {
    default:
    case 0: //inverse mse
        *acur = 1.0f / mse;
        break;
    case 1: //accuracy
        *acur = ac;
        break;
    case 2: //geometric mean se,sp
        *acur = gmean(se * sp, 2);
        break;
    case 3: //geometric mean se,pp
        *acur = gmean(se * pp, 2);
        break;
    case 4: //geometric mean se,sp,ac
        *acur = gmean(se * sp * ac, 3);
        break;
    case 5: //geometric mean se,sp,pp,np,ac
        *acur = gmean(se * sp * pp * np * ac, 5);
        break;
    case 6: //F-measure b=1: ((b^2+1)*Se*PP)/(b^2*PP+Se)
        b = 1.0f;
        *acur = (pp && se) ? ((b * b + 1) * se * pp) / (b * b * pp + se) : 0.0f;
        break;
    case 7: //F-measure b=1.5
        b = 1.5f;
        *acur = (pp && se) ? ((b * b + 1) * se * pp) / (b * b * pp + se) : 0.0f;
        break;
    case 8: //F-measure b=3
        b = 3.0f;
        *acur = (pp && se) ? ((b * b + 1) * se * pp) / (b * b * pp + se) : 0.0f;
        break;
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////
// Geometric mean of n factors whose product is m: the n-th root of m.
float gmean(float m, int n)
{
    const float exponent = 1.0f / float(n);
    return pow(m, exponent);
}
////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////
// Write the contents of the train/validation/test sets to dbgsets.txt for debugging.
// Any of the three pointers may be null. Per-entry listings are emitted only for
// sets smaller than 1000 entries. Train-set slots zeroed by set_validation() are
// skipped and excluded from the printed count.
void dump_sets(PREC trn, PREC vld, PREC tst)
{
    FILE* fp = _wfopen(L"dbgsets.txt", L"wt");
    if (!fp) { //fix: fp was dereferenced without a null check
        wprintf(L"failed to open dbgsets.txt for writing\n");
        return;
    }
    if (trn != 0) {
        size_t s = 0; //count live (non-zeroed) training entries
        for (size_t i = 0; i < trn->entries.size(); i++) {
            if (trn->entries[i] != 0) s++;
        }
        fwprintf(fp, L"TRAINING SET: %d\n", (int)s); //fix: %d expects int, not size_t
        if (trn->entries.size() < 1000) {
            for (size_t i = 0; i < trn->entries.size(); i++) {
                if (trn->entries[i] != 0) //in train set might be 0 entries after setvld()
                    fwprintf(fp, L"%s %d\n", trn->entries[i]->fname, trn->entries[i]->cls);
            }
        }
    }
    if (vld != 0) {
        fwprintf(fp, L"\n\nVALIDATION SET: %d\n", (int)vld->entries.size());
        if (vld->entries.size() < 1000) {
            for (size_t i = 0; i < vld->entries.size(); i++)
                fwprintf(fp, L"%s %d\n", vld->entries[i]->fname, vld->entries[i]->cls);
        }
    }
    if (tst != 0) {
        fwprintf(fp, L"\n\nTEST SET: %d\n", (int)tst->entries.size());
        if (tst->entries.size() < 1000) {
            for (size_t i = 0; i < tst->entries.size(); i++)
                fwprintf(fp, L"%s %d\n", tst->entries[i]->fname, tst->entries[i]->cls);
        }
    }
    fclose(fp);
}
////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////
///////////////////////////TESTING//////////////////////////////////////////////////////////////
// Classify every entry listed in the class file argv[3] with the trained network
// argv[2]; print per-entry network output and predicted class, then overall
// Se/Sp/+P/-P/accuracy totals for the labeled entries.
//   argv[4] - optional decision threshold TH (default 0.5)
//   argv[5] - optional normalization mode (stored in the global 'normalization')
// Exits the process on load/validation failures.
void test(int argc, wchar_t* argv[])
{
    REC tstrec;
    ////parse optional arguments 4,5////////////////////////////////////
    float TH = 0.5f;
    if (argc >= 4 + 1)
        TH = (float)_wtof(argv[4]);
    if (argc >= 5 + 1)
        normalization = _wtoi(argv[5]);
    ////////////////////////////////////////////////////////////////////
    wprintf(L"loading data...\n");
    FILE *cls1 = _wfopen(argv[3], L"rt");
    if (!cls1) {
        wprintf(L"failed to open file: %s\n", argv[3]);
        exit(1);
    }
    read_class(cls1, &tstrec);
    fclose(cls1); //fix: file handle was leaked
    if (!tstrec.entries.size()) {
        wprintf(L" no files loaded from: %s.\n", argv[3]);
        exit(1);
    }
    //fix: cast size() for %d (size_t is not int on 64-bit targets)
    wprintf(L" %d files loaded. size: %d samples\n", (int)tstrec.entries.size(), tstrec.entries[0]->size);
    ann = new ANNetwork(argv[2]); //global network; also used by validate()
    if (ann->status()) { //nonzero status means the network failed to load
        wprintf(L"failed to load network: %s\n", argv[2]);
        exit(1);
    }
    //input layer must be able to accept the data vectors
    if (ann->get_layer(0)->get_neurons_number() != vector_length) {
        if (ann->get_layer(0)->get_neurons_number() > vector_length) {
            wprintf(L" input layer neurons %d are more than data dimension %d", ann->get_layer(0)->get_neurons_number(), vector_length);
            exit(1);
        } else //fewer inputs than dimensions: warn, but proceed on truncated vectors
            wprintf(L" input layer neurons %d are less than data dimension %d\n", ann->get_layer(0)->get_neurons_number(), vector_length);
    }
    wprintf(L"%s\n", argv[2]);
    wchar_t name[_MAX_PATH] = L"";
    wchar_t dir[_MAX_PATH] = L"";
    float *ivec;
    float ovec[1] = {0.0f};
    int TP = 0, FN = 0, TN = 0, FP = 0;
    ////////////////////TESTING////////////////////////////////////////////////////////////////
    for (int i = 0; i < (int)tstrec.entries.size(); i++) {
        ivec = tstrec.entries[i]->vec;
        ann->classify(ivec, ovec);
        if (parse_path(tstrec.entries[i]->fname, dir, name))
            wprintf(L" %s\n", dir);
        int clstype = (ovec[0] > TH) ? 1 : 2 ; //output above TH -> class 1, else class 2
        wprintf(L" %s %f cls %d ", name, ovec[0], clstype);
        if (tstrec.entries[i]->cls) { //if 1 or 2, 0 if no class info
            if (tstrec.entries[i]->cls == clstype) {
                if (clstype == 1)
                    TP++;
                else if (clstype == 2)
                    TN++;
                wprintf(L"+\n");
            } else { /////error//////////
                wprintf(L"-\n");
                if (clstype == 2 && tstrec.entries[i]->cls == 1) //ill defined as healthy
                    FN++;
                else if (clstype == 1 && tstrec.entries[i]->cls == 2) //healthy defined as ill
                    FP++;
            }
        } else
            wprintf(L"\n");
    }
    ///////////////////////////////////////////////////////////////////////////////////////////
    //print only the rates whose numerator is nonzero (avoids 0/0)
    if (TP)
        wprintf(L" sensitivity: %.2f\n", 100.0f * float(TP) / float(TP + FN));
    if (TN)
        wprintf(L" specificity: %.2f\n", 100.0f * float(TN) / float(TN + FP));
    if (TP)
        wprintf(L" +predictive: %.2f\n", 100.0f * float(TP) / float(TP + FP));
    if (TN)
        wprintf(L" -predictive: %.2f\n", 100.0f * float(TN) / float(TN + FN));
    if (TP || FP || FN || TN)
        wprintf(L" accuracy: %.2f\n", 100.0f * float(TP + TN) / float(TP + FN + TN + FP));
    //NOTE: 'ann' is deliberately not deleted here; it is a global owned by the program
}
////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////get normalization params////////////////////////////////////////////////////////////
void set_normalization(REC *rec, ANNetwork *pann)
{
int N = int(rec->entries.size());
int I = -1; //first nonzero entry
for (int i = 0; i < (int)rec->entries.size(); i++) {
if (rec->entries[i] == 0)
// NOTE(review): set_normalization() is truncated here — the remainder of the
// function was replaced by web-viewer UI residue during extraction; recover the
// missing body from the original source before building.