
ann1dn.cpp

A BP neural network C++ class library written for the VC (Visual C++) environment.
                if (vldrec.entries.size()) {
                        validate(&vldrec, TH, &acur, &pacur);
                        wprintf(L" \n validation set: %d %d\n   sensitivity: %.2f\n   specificity: %.2f\n   +predictive: %.2f\n   -predictive: %.2f\n      accuracy: %.2f\n", vldrec.indices[0].size(), vldrec.indices[1].size(), pacur.se, pacur.sp, pacur.pp, pacur.np, pacur.ac);
                }
                if (tstrec.entries.size()) {
                        validate(&tstrec, TH, &acur, &pacur);
                        wprintf(L" \n test set: %d %d\n   sensitivity: %.2f\n   specificity: %.2f\n   +predictive: %.2f\n   -predictive: %.2f\n      accuracy: %.2f\n", tstrec.indices[0].size(), tstrec.indices[1].size(), pacur.se, pacur.sp, pacur.pp, pacur.np, pacur.ac);
                }
        } else
                wprintf(L"failed to load maxacur.nn for classification\n");

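        //reload the final trained network from argv[2] and print the same statistics for the
        //train/validation/test sets (the block above reported results for maxacur.nn)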
        ann = new ANNetwork(argv[2]);         //validate(...)  uses *ann network
        if (!ann->status()) {                 //classification results for trained network
                wprintf(L"\nclassification results: %s\n", argv[2]);
                validate(&trnrec, TH, &acur, &pacur);
                wprintf(L" \n train set: %d %d\n   sensitivity: %.2f\n   specificity: %.2f\n   +predictive: %.2f\n   -predictive: %.2f\n      accuracy: %.2f\n", trnrec.indices[0].size(), trnrec.indices[1].size(), pacur.se, pacur.sp, pacur.pp, pacur.np, pacur.ac);
                if (vldrec.entries.size()) {
                        validate(&vldrec, TH, &acur, &pacur);
                        wprintf(L" \n validation set: %d %d\n   sensitivity: %.2f\n   specificity: %.2f\n   +predictive: %.2f\n   -predictive: %.2f\n      accuracy: %.2f\n", vldrec.indices[0].size(), vldrec.indices[1].size(), pacur.se, pacur.sp, pacur.pp, pacur.np, pacur.ac);
                }
                if (tstrec.entries.size()) {
                        validate(&tstrec, TH, &acur, &pacur);
                        wprintf(L" \n test set: %d %d\n   sensitivity: %.2f\n   specificity: %.2f\n   +predictive: %.2f\n   -predictive: %.2f\n      accuracy: %.2f\n", tstrec.indices[0].size(), tstrec.indices[1].size(), pacur.se, pacur.sp, pacur.pp, pacur.np, pacur.ac);
                }
        } else
                wprintf(L"failed to load %s for classification\n", argv[2]);


}
////////////////////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////////////////////
// for 2 classes only; for N classes see the SOM project
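// set_validation() moves a random p% of each of the two classes from the training record
// into the validation record: it shuffles the per-class index lists, copies the first c1/c2
// entries into vld, marks the moved slots in trn->entries as 0 (they are skipped later),
// and erases the consumed indices from trn.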
void set_validation(PREC vld, PREC trn, float p)
{
        int c1 = int((p / 100.0f) * (float)trn->indices[0].size());
        int c2 = int((p / 100.0f) * (float)trn->indices[1].size());
        wprintf(L" validation size: %d %d\n", c1, c2);
        if (c1 < 1 || c2 < 1) {
                wprintf(L" validation is not set: one of the validation classes has zero length\n");
                return;
        }

        vld->entries.resize(c1 + c2);

        vld->clsnum.push_back(trn->clsnum[0]);
        vld->clsnum.push_back(trn->clsnum[1]);


        //randomly shuffle the indices and take the first c1, c2 of them//////////////////////
        random_shuffle(trn->indices[0].begin(), trn->indices[0].end());
        random_shuffle(trn->indices[1].begin(), trn->indices[1].end());
        ////////////////////////////////////////////////////////////////////////


        //class1////////////////////////////////////////////
        vector<int> indices;
        indices.resize(c1);
        vld->indices.push_back(indices);
        //get random % from trn set
        for (int i = 0; i < c1; i++) {
                int ind = trn->indices[0].at(i);

                vld->indices[0].at(i) = i;
                vld->entries[i] = trn->entries[ ind ];
                trn->entries[ ind ] = 0;
        }
        trn->indices[0].erase(trn->indices[0].begin(), trn->indices[0].begin() + c1);

        //class2////////////////////////////////////////////
        indices.resize(c2);
        vld->indices.push_back(indices);
        //get random % from trn set
        for (int i = 0; i < c2; i++) {
                int ind = trn->indices[1].at(i);

                vld->indices[1].at(i) = i + c1;
                vld->entries[i+c1] = trn->entries[ ind ];
                trn->entries[ ind ] = 0;
        }
        trn->indices[1].erase(trn->indices[1].begin(), trn->indices[1].begin() + c2);
}
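// Example call (a sketch; the 25% figure is illustrative, the records match the caller above):
//     set_validation(&vldrec, &trnrec, 25.0f);   //move 25% of each class into the validation set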

////////////////////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////////////////////
/*
    acur - metric results
    pacur - classification results Se,Sp,...
*/
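// Confusion-matrix metrics computed below (class 1 = positive/"ill", class 2 = negative/"healthy"):
//   sensitivity  Se = TP / (TP + FN)
//   specificity  Sp = TN / (TN + FP)
//   +predictive  PP = TP / (TP + FP)
//   -predictive  NP = TN / (TN + FN)
//   accuracy     Ac = (TP + TN) / (TP + FN + TN + FP)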
void validate(PREC rec, float TH, float *acur, PACUR pacur)
{
        float mse = 0.0f, se = 0.0f, sp = 0.0f, pp = 0.0f, np = 0.0f, ac = 0.0f, b;
        float *ivec;
        float ovec[1] = {0.0f};


        int TP = 0, FN = 0, TN = 0, FP = 0;
        int size = (int)rec->entries.size();
        ////////////////////TESTING////////////////////////////////////////////////////////////////
        for (int i = 0; i < (int)rec->entries.size(); i++) {
                //the train set may contain null (0) entries after set_validation()
                if (rec->entries[i] == 0) {
                        size--;
                        continue;
                }

                ivec = rec->entries[i]->vec;
                ann->classify(ivec, ovec);

                int clstype = (ovec[0] > TH) ? 1 : 2;
                int vcls = rec->entries[i]->cls;

                if (vcls) { //vcls is 1 or 2; 0 means no class info
                        //mse against the training targets: 0.9 for class 1, 0.1 for class 2
                        if (vcls == 1)
                                mse += (0.9f - ovec[0]) * (0.9f - ovec[0]);
                        else if (vcls == 2)
                                mse += (0.1f - ovec[0]) * (0.1f - ovec[0]);

                        //se,sp,...
                        if (vcls == clstype) {
                                if (clstype == 1)
                                        TP++;
                                else if (clstype == 2)
                                        TN++;
                        } else { /////error//////////
                                if (clstype == 2 && vcls == 1)  //ill case classified as healthy (false negative)
                                        FN++;
                                else if (clstype == 1 && vcls == 2)  //healthy case classified as ill (false positive)
                                        FP++;
                        }
                }
        }
        mse /= (float)size;   //size excludes the null (0) entries skipped above
        ///////////////////////////////////////////////////////////////////////////////////////////

        if (TP)
                se = float(TP) / float(TP + FN);
        if (TN)
                sp = float(TN) / float(TN + FP);
        if (TP)
                pp = float(TP) / float(TP + FP);
        if (TN)
                np = float(TN) / float(TN + FN);
        if (TP || FP || FN || TN)
                ac = float(TP + TN) / float(TP + FN + TN + FP);

        pacur->se = 100.0f * se;
        pacur->sp = 100.0f * sp;
        pacur->pp = 100.0f * pp;
        pacur->np = 100.0f * np;
        pacur->ac = 100.0f * ac;

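        //validation_type selects the scalar score written to *acur (larger is better);
        //presumably the training code uses it to keep the best-scoring network (cf. maxacur.nn)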
        switch (validation_type) {
        default:
        case 0:
                *acur = 1.0f / mse;    //mse
                break;
        case 1:
                *acur = ac;            //acur
                break;
        case 2:                 //geometric mean se,sp
                *acur = gmean(se * sp, 2);
                break;
        case 3:                 //geometric mean se,pp
                *acur = gmean(se * pp, 2);
                break;
        case 4:                 //geometric mean se,sp,ac
                *acur = gmean(se * sp * ac, 3);
                break;
        case 5:                 //geometric mean se,sp,pp,np,ac
                *acur = gmean(se * sp * pp * np * ac, 5);
                break;
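        //cases 6-8: F-measure with weight b,  F_b = ((b*b + 1) * Se * PP) / (b*b * PP + Se)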
        case 6:                 //F-measure  b=1
                b = 1.0f;
                if (pp && se)
                        *acur = ((b * b + 1) * se * pp) / (b * b * pp + se);
                else
                        *acur = 0;
                break;
        case 7:                 //F-measure  b=1.5
                b = 1.5f;
                if (pp && se)
                        *acur = ((b * b + 1) * se * pp) / (b * b * pp + se);
                else
                        *acur = 0;
                break;
        case 8:                 //F-measure  b=3
                b = 3.0f;
                if (pp && se)
                        *acur = ((b * b + 1) * se * pp) / (b * b * pp + se);
                else
                        *acur = 0;
                break;
        }

}
////////////////////////////////////////////////////////////////////////////////////////////////
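// n-th root of the pre-multiplied product m, e.g. gmean(se * sp, 2) == sqrt(se * sp)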
float gmean(float m, int n)
{
        return pow(m, 1.0f / (float)n);
}
////////////////////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////////////////////
void dump_sets(PREC trn, PREC vld, PREC tst)
{
        wchar_t name[_MAX_PATH] = L"";
        wchar_t dir[_MAX_PATH] = L"";

        FILE* fp = _wfopen(L"dbgsets.txt", L"wt");
        if (!fp) return;   //bail out if the dump file cannot be created

        if (trn != 0) {
                size_t s = 0;
                for (size_t i = 0; i < trn->entries.size(); i++) {
                        if (trn->entries[i] != 0) s++;
                }

                fwprintf(fp, L"TRAINING SET: %d\n", s);
                if (trn->entries.size() < 1000) {
                        for (size_t i = 0; i < trn->entries.size(); i++) {
                                if (trn->entries[i] != 0)  //the train set may contain null (0) entries after set_validation()
                                        fwprintf(fp, L"%s  %d\n", trn->entries[i]->fname, trn->entries[i]->cls);
                        }
                }
        }

        if (vld != 0) {
                fwprintf(fp, L"\n\nVALIDATION SET: %d\n", vld->entries.size());
                if (vld->entries.size() < 1000) {
                        for (size_t i = 0; i < vld->entries.size(); i++)
                                fwprintf(fp, L"%s  %d\n", vld->entries[i]->fname, vld->entries[i]->cls);
                }
        }

        if (tst != 0) {
                fwprintf(fp, L"\n\nTEST SET: %d\n", tst->entries.size());
                if (tst->entries.size() < 1000) {
                        for (size_t i = 0; i < tst->entries.size(); i++)
                                fwprintf(fp, L"%s  %d\n", tst->entries[i]->fname, tst->entries[i]->cls);
                }
        }

        fclose(fp);
}
////////////////////////////////////////////////////////////////////////////////////////////////






////////////////////////////////////////////////////////////////////////////////////////////////
///////////////////////////TESTING//////////////////////////////////////////////////////////////
void test(int argc, wchar_t* argv[])
{
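        //arguments as used below: argv[2] - trained network file, argv[3] - list of labeled
        //test vectors, argv[4] (optional) - decision threshold TH, argv[5] (optional) -
        //normalization type; an output above TH is reported as class 1, otherwise class 2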

        REC tstrec;


////parse optional arguments 4,5////////////////////////////////////
        float TH = 0.5f;
        if (argc >= 4 + 1)
                TH = (float)_wtof(argv[4]);
        if (argc >= 5 + 1)
                normalization = _wtoi(argv[5]);
////////////////////////////////////////////////////////////////////


        wprintf(L"loading data...\n");
        FILE *cls1 = _wfopen(argv[3], L"rt");

        if (!cls1) {
                wprintf(L"failed to open file: %s\n", argv[3]);
                exit(1);
        } else
                read_class(cls1, &tstrec);



        if (!tstrec.entries.size()) {
                wprintf(L" no files loaded from: %s.\n", argv[3]);
                exit(1);
        } else
                wprintf(L" %d  files loaded.  size: %d samples\n", tstrec.entries.size(), tstrec.entries[0]->size);



        ann = new ANNetwork(argv[2]);
        if (ann->status()) {
                wprintf(L"failed to load network: %s\n", argv[2]);
                exit(1);
        }
        if (ann->get_layer(0)->get_neurons_number() != vector_length) {
                if (ann->get_layer(0)->get_neurons_number() > vector_length) {
                        wprintf(L" input layer has %d neurons, more than the data dimension %d\n", ann->get_layer(0)->get_neurons_number(), vector_length);
                        exit(1);
                } else
                        wprintf(L" input layer has %d neurons, fewer than the data dimension %d\n", ann->get_layer(0)->get_neurons_number(), vector_length);
        }
        wprintf(L"%s\n", argv[2]);


        wchar_t name[_MAX_PATH] = L"";
        wchar_t dir[_MAX_PATH] = L"";
        float *ivec;
        float ovec[1] = {0.0f};

        int TP = 0, FN = 0, TN = 0, FP = 0;
        ////////////////////TESTING////////////////////////////////////////////////////////////////
        for (int i = 0; i < (int)tstrec.entries.size(); i++) {

                ivec = tstrec.entries[i]->vec;
                ann->classify(ivec, ovec);


                if (parse_path(tstrec.entries[i]->fname, dir, name))
                        wprintf(L" %s\n", dir);
                int clstype = (ovec[0] > TH) ? 1 : 2 ;
                wprintf(L"  %s   %f   cls %d  ", name, ovec[0], clstype);

                if (tstrec.entries[i]->cls) { //cls is 1 or 2; 0 means no class info
                        if (tstrec.entries[i]->cls == clstype) {
                                if (clstype == 1)
                                        TP++;
                                else if (clstype == 2)
                                        TN++;
                                wprintf(L"+\n");
                        } else { /////error//////////
                                wprintf(L"-\n");

                                if (clstype == 2 && tstrec.entries[i]->cls == 1)  //ill case classified as healthy (false negative)
                                        FN++;
                                else if (clstype == 1 && tstrec.entries[i]->cls == 2)  //healthy case classified as ill (false positive)
                                        FP++;
                        }
                } else
                        wprintf(L"\n");
        }
        ///////////////////////////////////////////////////////////////////////////////////////////
        if (TP)
                wprintf(L"   sensitivity: %.2f\n", 100.0f * float(TP) / float(TP + FN));
        if (TN)
                wprintf(L"   specificity: %.2f\n", 100.0f * float(TN) / float(TN + FP));
        if (TP)
                wprintf(L"   +predictive: %.2f\n", 100.0f * float(TP) / float(TP + FP));
        if (TN)
                wprintf(L"   -predictive: %.2f\n", 100.0f * float(TN) / float(TN + FN));
        if (TP || FP || FN || TN)
                wprintf(L"      accuracy: %.2f\n", 100.0f * float(TP + TN) / float(TP + FN + TN + FP));

}
////////////////////////////////////////////////////////////////////////////////////////////////





/////////////////////get normalization params////////////////////////////////////////////////////////////
void set_normalization(REC *rec, ANNetwork *pann)
{
        int N = int(rec->entries.size());
        int I = -1;  //first nonzero entry
        for (int i = 0; i < (int)rec->entries.size(); i++) {
                if (rec->entries[i] == 0)
