⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ann1dn.cpp

📁 该程序是在vc环境下编写的bp神经网络c++类库
💻 CPP
📖 第 1 页 / 共 3 页
字号:

// ann1d.cpp : Defines the entry point for the console application.
//


#include "stdafx.h"
#include "trace.h"

#include "lib\signal.h"
#include "libnn\network.h"
#include "libnn\neuron.h"



// Classification-accuracy metrics filled in by validate() and printed as
// sensitivity/specificity/+predictive/-predictive/accuracy (see the train-set
// report in train()).
typedef struct _acur {
        float se;   // sensitivity (true-positive rate)
        float sp;   // specificity (true-negative rate)
        float pp;   // positive predictive value
        float np;   // negative predictive value
        float ac;   // overall accuracy
} ACUR, *PACUR;


// Input-data normalization schemes selectable from the command line;
// NONE (0) disables normalization entirely.
enum NORMALIZATION {
        NONE,        // 0 - leave input vectors untouched
        MINMAX,      // 1
        ZSCORE,      // 2
        SIGMOIDAL,   // 3
        ENERGY       // 4 - applied per vector rather than per training set
};

// Printable scheme names, indexed by (normalization - 1); NONE has no entry.
wchar_t normalization_type[][20] = {
        L"minmax",
        L"zscore",
        L"sigmoidal",
        L"energy"
};


int normalization = 0;        // NORMALIZATION scheme picked from argv (0-none, 1-minmax, 2-zscore, 3-sigmoidal/softmax, 4-energy)
int validation_type = 0;      // validation metric selector from argv[9]; see the metrics list printed in _tmain (0 - mse, 1 - AC, ...)
int vector_length = 0;     //size of the first vector read from read_class() trn,vld or tst set

class ANNetwork *ann = 0;  // the network being trained/tested; (re)allocated with new ANNetwork(...) in train()/test()


vector<CSignal *> signals;     //file mapping classes of signals
//if used energy or minmax for a vector, filemapping is overwritten with new data





// Data loading / path helpers ////////////////////////////////////////////////
// read_class: loads the signal files listed in fp into rec, tagging them with
// class mark c (0 = keep class info from the listing); also sets vector_length.
void read_class(FILE *fp, PREC rec, int c = 0);   //closes fp handle
int read_line(FILE *f, wchar_t *buff, int *c = 0);              // NOTE(review): presumably reads one listing line and optional class mark c - confirm in definition
void get_file_name(wchar_t *path, wchar_t *name);               // extracts file name from a full path (name semantics per definition, not visible here)
int parse_path(wchar_t *path, wchar_t *dir, wchar_t *name);     // splits path into directory and name parts
void msec_to_time(int msec, int& h, int& m, int& s, int& ms);   // converts a millisecond count to h:m:s:ms (used for the training-time report)


// Training / testing pipeline ////////////////////////////////////////////////
void train(int argc, wchar_t* argv[]);                          // "t" mode entry point (dispatched from _tmain)
void validate(PREC rec, float TH, float *acur, PACUR pacur);    // classifies rec with *ann at threshold TH; fills scalar accuracy and ACUR metrics
float gmean(float m, int n);    //geometric mean pow(m,1/n);
void set_validation(PREC vld, PREC trn, float p);               // moves p percent of the training records trn into the vld set
void dump_sets(PREC trn, PREC vld, PREC tst);                   // NOTE(review): presumably reports/dumps the three sets - confirm in definition
void test(int argc, wchar_t* argv[]);                           // "r" mode entry point (dispatched from _tmain)
void set_normalization(REC *rec, ANNetwork *pann);             //save normalization params to ANN input layer




/*
  //training mode
   1 t          //train
   2 cnf.nn     //network conf file
   3 cls1.txt   //files for class1  [0.9]
   4 cls2.txt   //files for class2  [0.1]
   5 epochs     //epochs num
    6 [val.txt]  //validation set
    7 [tst.txt]  //test set
    8 [val TH]   //validation threshold decision
    9 [val type] //validation type (mse/ac/...)
   10 [norm]     //normalize input data [0-none],1-minmax,2-zscore,3-softmax,4-energy,5-minimaxscale
   11 [error]    //default 0.05

  //test mode
   1 r          //run
   2 cnf.nn     //network file
   3 cls.txt    //files to test
   4 [TH 0.5]   //threshold optional only for 4-energy norm
   5 [norm]     //normalize input data [0-none],1-minmax,2-zscore,3-softmax,4-energy

                                                                                              */
// Prints the full command-line usage for both training ("t") and run ("r")
// modes, including the list of selectable validation metrics.
static void print_usage(void)
{
        wprintf(L"\n argv[1] t-train\n");
        wprintf(L" argv[2] network conf file\n");
        wprintf(L" argv[3] cls1 files [0.9]\n");
        wprintf(L" argv[4] cls2 files [0.1]\n");
        wprintf(L" argv[5] epochs num\n");
        wprintf(L"  argv[6] [validation class]\n");
        wprintf(L"  argv[7] [test class]\n");
        wprintf(L"  argv[8] [validation TH 0.5]\n");
        wprintf(L"  argv[9] [vld metric mse]\n");
        wprintf(L" argv[10] [norm]: [0-no], 1-minmax, 2-zscore, 3-softmax, 4-energy\n");
        wprintf(L" argv[11] [error tolerance cls] +- 0.05 default\n\n");

        wprintf(L" argv[1] r-run\n");
        wprintf(L" argv[2] network conf file\n");
        wprintf(L" argv[3] cls files\n");
        wprintf(L" argv[4] [validation TH 0.5]\n");
        wprintf(L" argv[5] [norm]: [0-no], 1-minmax, 2-zscore, 3-softmax, 4-energy\n\n");

        wprintf(L"  ann1dn.exe t net.nn cls1 cls2 3000 [tst.txt][val.txt][TH [0.5]][val type [mse]] [norm [0]] [err [0.05]] \n");
        wprintf(L"  ann1dn.exe r net.nn testcls [TH [0.5]] [norm [0]]\n\n");

        wprintf(L" metrics: [0 - mse]\n");
        wprintf(L"           1 - AC\n");
        wprintf(L"           2 - sqrt(SE*SP)\n");
        wprintf(L"           3 - sqrt(SE*PP)\n");
        wprintf(L"           4 - sqrt(SE*SP*AC)\n");
        wprintf(L"           5 - sqrt(SE*SP*PP*NP*AC)\n");
        wprintf(L"           6 - F-measure b=1\n");
        wprintf(L"           7 - F-measure b=1.5\n");
        wprintf(L"           8 - F-measure b=3\n");
}

// Program entry point: seeds the RNG, then dispatches on argv[1] to the
// training ("t") or run ("r") pipeline; with no arguments it prints usage.
int _tmain(int argc, wchar_t* argv[])
{
        srand((unsigned int)time(0));

        if (argc == 1)
                print_usage();
        else if (wcscmp(argv[1], L"t") == 0)
                train(argc, argv);
        else if (wcscmp(argv[1], L"r") == 0)
                test(argc, argv);
        else
                wprintf(L"argv[1] t-train, r-run\n");

        return 0;
}




///////////////////////////TRAINING/////////////////////////////////////////////////////////////
void train(int argc, wchar_t* argv[])
{
        //REC trncopy;          //for overall classification accuracy

        REC trnrec;           //class 1,2 records
        REC vldrec;           //validation records
        REC tstrec;           //test records
        bool vld = false;     //is there validation file?
        bool tst = false;     //is there test file


////parse optional arguments 6,7,8,9////////////////////////////////
        float TH = 0.5f;
        float error = 0.05f;
        if (argc >= 6 + 1) {
                if (wcslen(argv[6]) > 1) { // 6,7,8,9  test,validation class; TH; validation_type
                        if (argc >= 10 + 1)
                                normalization = _wtoi(argv[10]);
                        if (argc >= 11 + 1)
                                error = float(_wtof(argv[11]));

                        TH = float(_wtof(argv[8]));
                        validation_type = _wtoi(argv[9]);


                        //check validation set
                        FILE *val = _wfopen(argv[6], L"rt");
                        if (val) {
                                read_class(val, &vldrec);
                                if (vldrec.entries.size())
                                        wprintf(L" validation size: %d files, TH = %.2f\n", vldrec.entries.size(), TH);
                                else
                                        vld = true;
                        } else {
                                wprintf(L" failed to open %s\n", argv[6]);
                                exit(1);
                        }

                        //check test set
                        FILE *test = _wfopen(argv[7], L"rt");
                        if (test) {
                                read_class(test, &tstrec);
                                if (tstrec.entries.size())
                                        wprintf(L" test size: %d files\n", tstrec.entries.size());
                                else
                                        tst = true;
                        } else {
                                wprintf(L" failed to open %s\n", argv[7]);
                                exit(1);
                        }
                } else {
                        normalization = _wtoi(argv[6]);
                        if (argc >= 7 + 1)
                                error = float(_wtof(argv[7]));
                }
        }
////////////////////////////////////////////////////////////////////




        wprintf(L"loading data...\n");
        FILE *cls1 = _wfopen(argv[3], L"rt");
        FILE *cls2 = _wfopen(argv[4], L"rt");

        if (!cls1 || !cls2) {
                wprintf(L"failed to open files %s %s\n", argv[3], argv[4]);
                exit(1);
        } else {
                read_class(cls1, &trnrec, 1);   //by default put 1 class mark
                read_class(cls2, &trnrec, 2);   //by default put 2 class mark
        }

        if (!trnrec.entries.size()) {
                wprintf(L"no files loaded to training set.\n");
                exit(1);
        } else if (trnrec.clsnum.size() != 2) {
                wprintf(L"%d classes loaded. works only for 2 classes.\n", trnrec.clsnum.size());
                exit(1);
        } else
                wprintf(L" cls%d: %d  cls%d: %d  files loaded.  size: %d samples\n", trnrec.clsnum[0], trnrec.indices[0].size(), trnrec.clsnum[1], trnrec.indices[1].size(), trnrec.entries[0]->size);


        //arrange 25% from train set to validation/test sets
        if (vld && tst) {
                set_validation(&vldrec, &trnrec, 25.0f);
                set_validation(&tstrec, &trnrec, 35.0f);
        } else if (vld && !tst)
                set_validation(&vldrec, &trnrec, 50.0f);
        else if (!vld && tst)
                set_validation(&tstrec, &trnrec, 50.0f);

        dump_sets(&trnrec, &vldrec, &tstrec);



        //load network
        ann = new ANNetwork(argv[2]);
        if (ann->status() < 0) {
                wprintf(L"failed to load network: %s", argv[2]);
                exit(1);
        }
        if (ann->get_layer(0)->get_neurons_number() != vector_length) {
                if (ann->get_layer(0)->get_neurons_number() > vector_length) {
                        wprintf(L" input layer neurons %d are more than data dimension %d", ann->get_layer(0)->get_neurons_number(), vector_length);
                        exit(1);
                } else
                        wprintf(L" input layer neurons %d are less than data dimension %d\n", ann->get_layer(0)->get_neurons_number(), vector_length);
        }


        if (normalization && normalization != 4) { //energy normalization per vector
                wprintf(L"normalizing %s...\n", normalization_type[normalization-1]);
                set_normalization(&trnrec, ann);   //get normalization params  add,mult to ANN  from training set
        }



        int msecs = GetTickCount();
        wprintf(L"training...\n");

        float dvec[1] = {0.0f};
        float *ivec;
        float ovec[1] = {0.0f};
        float ovec1[1] = {0.0f}, ovec2[1] = {0.0f};
        bool prv = false;
        int x = 0, y = 0, ii = 0;
        int quit = 0;

        float acur = 0.0f, tmpacur;
        ACUR pacur, tmppacur;
        memset(&pacur, 0, sizeof(ACUR));
        memset(&tmppacur, 0, sizeof(ACUR));

        CTRACE trc(L"macurtrace.txt");
        ////////////////////TRAINING////////////////////////////////////////////////////////////////
        int step = (trnrec.indices[0].size() > trnrec.indices[1].size()) ? 2 * (int)trnrec.indices[0].size() : 2 * (int)trnrec.indices[1].size();
        int EPOCHS = _wtoi(argv[5]);
        int e = EPOCHS * step;
        int maxepoch = 0;
        while (e) {
                if (x > 1) {
                        x = 0;
                        ii++;
                }
                if (x == 0) //1st class
                        y = ii % (int)trnrec.indices[x].size();
                else if (x == 1) //2nd class
                        y = ii % (int)trnrec.indices[x].size();


                int ind = trnrec.indices[x].at(y);
                ivec = trnrec.entries[ind]->vec;

                int cls = trnrec.entries[ind]->cls;
                if (cls == 1) dvec[0] = 0.9f;
                else if (cls == 2) dvec[0] = 0.1f;

                ann->train(ivec, ovec, dvec, error);

                if (cls == 1) ovec1[0] += ovec[0];
                else if (cls == 2) ovec2[0] += ovec[0];

                x++;
                e--;

                ////////////////////////////////////////////////////////////////////////////////////////////////
                if (!(e % step)) { //one epochs is expired
                        float mout1 = ovec1[0] / (float(step) / 2.0f);   //mean out1
                        float mout2 = ovec2[0] / (float(step) / 2.0f);   //mean out2
                        ovec1[0] = 0.0f;
                        ovec2[0] = 0.0f;

                        if (quit == 10)  //no more error
                                break;

                        if (fabsl(mout1 - 0.9f) > error || fabsl(mout2 - 0.1f) > error)
                                quit = 0;
                        else
                                quit++;


                        //ann->save(L"temp.nn");
                        wprintf(L"  epoch: %d   out: %f %f ", EPOCHS - (e / step), mout1, mout2);

                        ////validate/////////////////////////////////////////////////////////////////////////////////////////////
                        if (vldrec.entries.size() && (mout1 > TH && mout2 < TH)) {
                                validate(&vldrec, TH, &tmpacur, &tmppacur);

                                trc.dump(tmpacur);
                                if (tmpacur >= acur) {
                                        maxepoch = EPOCHS - (e / step);
                                        acur = tmpacur;
                                        memcpy(&pacur, &tmppacur, sizeof(ACUR));
                                        if (!ann->save(L"maxacur.nn"))
                                                wprintf(L"  failed to save maxacur.nn  ");
                                }

                                wprintf(L"  max acur: %.2f (epoch %d)   se:%.2f sp:%.2f ac:%.2f\n", acur, maxepoch, pacur.se, pacur.sp, pacur.ac);
                        } else
                                wprintf(L"\n");
                        //////////////////////////////////////////////////////////////////////////////////////////////////////////


                        for (int i = 0; i < (int)trnrec.indices.size(); i++)  //shuffle indices to entries array
                                random_shuffle(trnrec.indices[i].begin(), trnrec.indices[i].end());


                        if (kbhit() && _getwch() == 'q') //quit program ?
                                e = 0;
                }//one epoch is expired/////////////////////////////////////////////////////////


        }
        ////////while(epochs)/////////////////////////////////////////////////////////////////////////////////

        if (e)
                wprintf(L"training done.\n");

        int hour, min, sec, msec;
        msec_to_time(GetTickCount() - msecs, hour, min, sec, msec);
        wprintf(L"training time: %02d:%02d:%02d:%03d\n", hour, min, sec, msec);

        if (!ann->save(argv[2]))
                wprintf(L"failed to save %s\n", argv[2]);        





//testing on maxacur and trained network/////////////////////////////////////////////////////////////////////////////////////////////
        ann = new ANNetwork(L"maxacur.nn");       //validate(...)  uses *ann network
        if (!ann->status()) {                     //classification results for maxacur.nn network
                wprintf(L"\nclassification results: maxacur.nn\n");
                validate(&trnrec, TH, &acur, &pacur);
                wprintf(L" \n train set: %d %d\n   sensitivity: %.2f\n   specificity: %.2f\n   +predictive: %.2f\n   -predictive: %.2f\n      accuracy: %.2f\n", trnrec.indices[0].size(), trnrec.indices[1].size(), pacur.se, pacur.sp, pacur.pp, pacur.np, pacur.ac);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -