// ann1dn.cpp — extracted source (web-viewer page chrome removed)
// ann1d.cpp : Defines the entry point for the console application.
//
#include "stdafx.h"
#include "trace.h"
#include "lib\signal.h"
#include "libnn\network.h"
#include "libnn\neuron.h"
// Classification accuracy metrics collected by validate(): these map to the
// sensitivity / specificity / +predictive / -predictive / accuracy figures
// printed in the final classification report.
struct _acur {
	float se; // sensitivity
	float sp; // specificity
	float pp; // positive predictive value
	float np; // negative predictive value
	float ac; // overall accuracy
};
typedef struct _acur ACUR;
typedef struct _acur *PACUR;
// Input-normalization modes; the global `normalization` below holds one of
// these values as a plain int parsed from the command line.
enum NORMALIZATION {NONE, MINMAX, ZSCORE, SIGMOIDAL, ENERGY};
// Display names indexed by (normalization - 1); NONE has no name entry.
wchar_t normalization_type[][20] = {L"minmax", L"zscore", L"sigmoidal", L"energy"};
int normalization = 0; //selected NORMALIZATION mode (0 = none)
int validation_type = 0; //validation metric selector (0 = mse; see metric list in _tmain usage)
int vector_length = 0; //size of the first vector read from read_class() trn,vld or tst set
class ANNetwork *ann = 0; //network under training/testing (project class from libnn\network.h)
vector<CSignal *> signals; //file mapping classes of signals
//if used energy or minmax for a vector, filemapping is overwritten with new data
// ---- forward declarations; REC/PREC are project record types not shown in this file ----
void read_class(FILE *fp, PREC rec, int c = 0); //load a class-list file into *rec (c = class mark); closes fp handle
int read_line(FILE *f, wchar_t *buff, int *c = 0); //read one text line into buff — TODO confirm return semantics
void get_file_name(wchar_t *path, wchar_t *name); //presumably extracts the file name component of path — verify
int parse_path(wchar_t *path, wchar_t *dir, wchar_t *name); //presumably splits path into dir + name — verify
void msec_to_time(int msec, int& h, int& m, int& s, int& ms); //convert milliseconds to h:m:s:ms
void train(int argc, wchar_t* argv[]); //training mode entry point ("t")
void validate(PREC rec, float TH, float *acur, PACUR pacur); //compute accuracy metrics on *rec at threshold TH
float gmean(float m, int n); //geometric mean pow(m,1/n);
void set_validation(PREC vld, PREC trn, float p); //carve ~p% of the training set into *vld (see usage in train())
void dump_sets(PREC trn, PREC vld, PREC tst); //report composition of the three data sets
void test(int argc, wchar_t* argv[]); //run/test mode entry point ("r")
void set_normalization(REC *rec, ANNetwork *pann); //save normalization params to ANN input layer
/*
//training mode
1 t //train
2 cnf.nn //network conf file
3 cls1.txt //files for class1 [0.9]
4 cls2.txt //files for class2 [0.1]
5 epochs //epochs num
6 [val.txt] //validation set
7 [tst.txt] //test set
8 [val TH] //validation threshold decision
9 [val type] //validation type (mse/ac/...)
10 [norm] //normalize input data [0-none],1-minmax,2-zscore,3-softmax,4-energy,5-minimaxscale
11 [error] //default 0.05
//test mode
1 r //run
2 cnf.nn //network file
3 cls.txt //files to test
4 [TH 0.5] //threshold optional only for 4-energy norm
5 [norm] //normalize input data [0-none],1-minmax,2-zscore,3-softmax,4-energy
*/
// Console entry point. Dispatches on argv[1]: "t" runs training mode,
// "r" runs test mode; with no arguments it prints the usage summary.
// Always returns 0 (train()/test() exit(1) on their own errors).
// Fix: the one-line training example previously showed [tst.txt][val.txt]
// swapped — argv[6] is the validation set and argv[7] the test set, matching
// the per-argument help below and train()'s parsing.
int _tmain(int argc, wchar_t* argv[])
{
	srand((unsigned int)time(0)); //seed RNG used for index shuffling / set splits
	if (argc == 1) {
		//usage: per-argument help for both modes, then one-line examples
		wprintf(L"\n argv[1] t-train\n");
		wprintf(L" argv[2] network conf file\n");
		wprintf(L" argv[3] cls1 files [0.9]\n");
		wprintf(L" argv[4] cls2 files [0.1]\n");
		wprintf(L" argv[5] epochs num\n");
		wprintf(L" argv[6] [validation class]\n");
		wprintf(L" argv[7] [test class]\n");
		wprintf(L" argv[8] [validation TH 0.5]\n");
		wprintf(L" argv[9] [vld metric mse]\n");
		wprintf(L" argv[10] [norm]: [0-no], 1-minmax, 2-zscore, 3-softmax, 4-energy\n");
		wprintf(L" argv[11] [error tolerance cls] +- 0.05 default\n\n");
		wprintf(L" argv[1] r-run\n");
		wprintf(L" argv[2] network conf file\n");
		wprintf(L" argv[3] cls files\n");
		wprintf(L" argv[4] [validation TH 0.5]\n");
		wprintf(L" argv[5] [norm]: [0-no], 1-minmax, 2-zscore, 3-softmax, 4-energy\n\n");
		//corrected argument order: validation set (argv[6]) before test set (argv[7])
		wprintf(L" ann1dn.exe t net.nn cls1 cls2 3000 [val.txt][tst.txt][TH [0.5]][val type [mse]] [norm [0]] [err [0.05]] \n");
		wprintf(L" ann1dn.exe r net.nn testcls [TH [0.5]] [norm [0]]\n\n");
		//validation metrics selectable via argv[9]
		wprintf(L" metrics: [0 - mse]\n");
		wprintf(L" 1 - AC\n");
		wprintf(L" 2 - sqrt(SE*SP)\n");
		wprintf(L" 3 - sqrt(SE*PP)\n");
		wprintf(L" 4 - sqrt(SE*SP*AC)\n");
		wprintf(L" 5 - sqrt(SE*SP*PP*NP*AC)\n");
		wprintf(L" 6 - F-measure b=1\n");
		wprintf(L" 7 - F-measure b=1.5\n");
		wprintf(L" 8 - F-measure b=3\n");
	} else if (!wcscmp(argv[1], L"t"))
		train(argc, argv);
	else if (!wcscmp(argv[1], L"r"))
		test(argc, argv);
	else
		wprintf(L"argv[1] t-train, r-run\n");
	return 0;
}
///////////////////////////TRAINING/////////////////////////////////////////////////////////////
void train(int argc, wchar_t* argv[])
{
//REC trncopy; //for overall classification accuracy
REC trnrec; //class 1,2 records
REC vldrec; //validation records
REC tstrec; //test records
bool vld = false; //is there validation file?
bool tst = false; //is there test file
////parse optional arguments 6,7,8,9////////////////////////////////
float TH = 0.5f;
float error = 0.05f;
if (argc >= 6 + 1) {
if (wcslen(argv[6]) > 1) { // 6,7,8,9 test,validation class; TH; validation_type
if (argc >= 10 + 1)
normalization = _wtoi(argv[10]);
if (argc >= 11 + 1)
error = float(_wtof(argv[11]));
TH = float(_wtof(argv[8]));
validation_type = _wtoi(argv[9]);
//check validation set
FILE *val = _wfopen(argv[6], L"rt");
if (val) {
read_class(val, &vldrec);
if (vldrec.entries.size())
wprintf(L" validation size: %d files, TH = %.2f\n", vldrec.entries.size(), TH);
else
vld = true;
} else {
wprintf(L" failed to open %s\n", argv[6]);
exit(1);
}
//check test set
FILE *test = _wfopen(argv[7], L"rt");
if (test) {
read_class(test, &tstrec);
if (tstrec.entries.size())
wprintf(L" test size: %d files\n", tstrec.entries.size());
else
tst = true;
} else {
wprintf(L" failed to open %s\n", argv[7]);
exit(1);
}
} else {
normalization = _wtoi(argv[6]);
if (argc >= 7 + 1)
error = float(_wtof(argv[7]));
}
}
////////////////////////////////////////////////////////////////////
wprintf(L"loading data...\n");
FILE *cls1 = _wfopen(argv[3], L"rt");
FILE *cls2 = _wfopen(argv[4], L"rt");
if (!cls1 || !cls2) {
wprintf(L"failed to open files %s %s\n", argv[3], argv[4]);
exit(1);
} else {
read_class(cls1, &trnrec, 1); //by default put 1 class mark
read_class(cls2, &trnrec, 2); //by default put 2 class mark
}
if (!trnrec.entries.size()) {
wprintf(L"no files loaded to training set.\n");
exit(1);
} else if (trnrec.clsnum.size() != 2) {
wprintf(L"%d classes loaded. works only for 2 classes.\n", trnrec.clsnum.size());
exit(1);
} else
wprintf(L" cls%d: %d cls%d: %d files loaded. size: %d samples\n", trnrec.clsnum[0], trnrec.indices[0].size(), trnrec.clsnum[1], trnrec.indices[1].size(), trnrec.entries[0]->size);
//arrange 25% from train set to validation/test sets
if (vld && tst) {
set_validation(&vldrec, &trnrec, 25.0f);
set_validation(&tstrec, &trnrec, 35.0f);
} else if (vld && !tst)
set_validation(&vldrec, &trnrec, 50.0f);
else if (!vld && tst)
set_validation(&tstrec, &trnrec, 50.0f);
dump_sets(&trnrec, &vldrec, &tstrec);
//load network
ann = new ANNetwork(argv[2]);
if (ann->status() < 0) {
wprintf(L"failed to load network: %s", argv[2]);
exit(1);
}
if (ann->get_layer(0)->get_neurons_number() != vector_length) {
if (ann->get_layer(0)->get_neurons_number() > vector_length) {
wprintf(L" input layer neurons %d are more than data dimension %d", ann->get_layer(0)->get_neurons_number(), vector_length);
exit(1);
} else
wprintf(L" input layer neurons %d are less than data dimension %d\n", ann->get_layer(0)->get_neurons_number(), vector_length);
}
if (normalization && normalization != 4) { //energy normalization per vector
wprintf(L"normalizing %s...\n", normalization_type[normalization-1]);
set_normalization(&trnrec, ann); //get normalization params add,mult to ANN from training set
}
int msecs = GetTickCount();
wprintf(L"training...\n");
float dvec[1] = {0.0f};
float *ivec;
float ovec[1] = {0.0f};
float ovec1[1] = {0.0f}, ovec2[1] = {0.0f};
bool prv = false;
int x = 0, y = 0, ii = 0;
int quit = 0;
float acur = 0.0f, tmpacur;
ACUR pacur, tmppacur;
memset(&pacur, 0, sizeof(ACUR));
memset(&tmppacur, 0, sizeof(ACUR));
CTRACE trc(L"macurtrace.txt");
////////////////////TRAINING////////////////////////////////////////////////////////////////
int step = (trnrec.indices[0].size() > trnrec.indices[1].size()) ? 2 * (int)trnrec.indices[0].size() : 2 * (int)trnrec.indices[1].size();
int EPOCHS = _wtoi(argv[5]);
int e = EPOCHS * step;
int maxepoch = 0;
while (e) {
if (x > 1) {
x = 0;
ii++;
}
if (x == 0) //1st class
y = ii % (int)trnrec.indices[x].size();
else if (x == 1) //2nd class
y = ii % (int)trnrec.indices[x].size();
int ind = trnrec.indices[x].at(y);
ivec = trnrec.entries[ind]->vec;
int cls = trnrec.entries[ind]->cls;
if (cls == 1) dvec[0] = 0.9f;
else if (cls == 2) dvec[0] = 0.1f;
ann->train(ivec, ovec, dvec, error);
if (cls == 1) ovec1[0] += ovec[0];
else if (cls == 2) ovec2[0] += ovec[0];
x++;
e--;
////////////////////////////////////////////////////////////////////////////////////////////////
if (!(e % step)) { //one epochs is expired
float mout1 = ovec1[0] / (float(step) / 2.0f); //mean out1
float mout2 = ovec2[0] / (float(step) / 2.0f); //mean out2
ovec1[0] = 0.0f;
ovec2[0] = 0.0f;
if (quit == 10) //no more error
break;
if (fabsl(mout1 - 0.9f) > error || fabsl(mout2 - 0.1f) > error)
quit = 0;
else
quit++;
//ann->save(L"temp.nn");
wprintf(L" epoch: %d out: %f %f ", EPOCHS - (e / step), mout1, mout2);
////validate/////////////////////////////////////////////////////////////////////////////////////////////
if (vldrec.entries.size() && (mout1 > TH && mout2 < TH)) {
validate(&vldrec, TH, &tmpacur, &tmppacur);
trc.dump(tmpacur);
if (tmpacur >= acur) {
maxepoch = EPOCHS - (e / step);
acur = tmpacur;
memcpy(&pacur, &tmppacur, sizeof(ACUR));
if (!ann->save(L"maxacur.nn"))
wprintf(L" failed to save maxacur.nn ");
}
wprintf(L" max acur: %.2f (epoch %d) se:%.2f sp:%.2f ac:%.2f\n", acur, maxepoch, pacur.se, pacur.sp, pacur.ac);
} else
wprintf(L"\n");
//////////////////////////////////////////////////////////////////////////////////////////////////////////
for (int i = 0; i < (int)trnrec.indices.size(); i++) //shuffle indices to entries array
random_shuffle(trnrec.indices[i].begin(), trnrec.indices[i].end());
if (kbhit() && _getwch() == 'q') //quit program ?
e = 0;
}//one epoch is expired/////////////////////////////////////////////////////////
}
////////while(epochs)/////////////////////////////////////////////////////////////////////////////////
if (e)
wprintf(L"training done.\n");
int hour, min, sec, msec;
msec_to_time(GetTickCount() - msecs, hour, min, sec, msec);
wprintf(L"training time: %02d:%02d:%02d:%03d\n", hour, min, sec, msec);
if (!ann->save(argv[2]))
wprintf(L"failed to save %s\n", argv[2]);
//testing on maxacur and trained network/////////////////////////////////////////////////////////////////////////////////////////////
ann = new ANNetwork(L"maxacur.nn"); //validate(...) uses *ann network
if (!ann->status()) { //classification results for maxacur.nn network
wprintf(L"\nclassification results: maxacur.nn\n");
validate(&trnrec, TH, &acur, &pacur);
wprintf(L" \n train set: %d %d\n sensitivity: %.2f\n specificity: %.2f\n +predictive: %.2f\n -predictive: %.2f\n accuracy: %.2f\n", trnrec.indices[0].size(), trnrec.indices[1].size(), pacur.se, pacur.sp, pacur.pp, pacur.np, pacur.ac);
// [NOTE: extract truncated here — the remainder of train() (its closing
//  braces and test-set report) and the subsequent definitions (validate,
//  test, read_class and other helpers) are not part of this chunk]