📄 mm_02.cc
字号:
// reset indentation // if (level_a > Integral::NONE) { Console::decreaseIndention(); } //--------------------------------------------------------------------------- // // 3. class-specific public methods: // computation methods // //--------------------------------------------------------------------------- // set indentation // if (level_a > Integral::NONE) { Console::put(L"testing class-specific public methods: computation methods...\n"); Console::increaseIndention(); } // test the set and get methods // VectorFloat new_weights(L"79, 5"); mix7.setWeights(new_weights); VectorFloat test_weights; if (!mix7.getWeights(test_weights) || !test_weights.almostEqual(new_weights)) { new_weights.debug(L"expected weights"); test_weights.debug(L"actual weights"); return Error::handle(name(), L"setWeights/getWeights", Error::TEST, __FILE__, __LINE__); } // now change mix7 to NONE mode and make sure the results hold // mix7.setMode(NONE); if (!mix7.getWeights(test_weights) || !test_weights.almostEqual(new_weights)) { new_weights.debug(L"expected weights"); test_weights.debug(L"actual weights"); return Error::handle(name(), L"setWeights/getWeights", Error::TEST, __FILE__, __LINE__); } // test the isNormalized method // if (mix7.isNormalized()) { return Error::handle(name(), L"isNormalized", Error::TEST, __FILE__, __LINE__); } // now normalize the data // mix7.normalizeWeights(); if (!mix7.isNormalized()) { mix7.debug(L"mix7 not normalized"); return Error::handle(name(), L"normalizeWeights", Error::TEST, __FILE__, __LINE__); } // now normalize to a different normalization factor which should // give us the original weights back // mix7.normalizeWeights(84); if (!mix7.isNormalized(84) || !mix7.getWeights(test_weights) || !test_weights.almostEqual(new_weights)) { mix7.debug(L"mix7 not normalized to 84"); return Error::handle(name(), L"normalizeWeights", Error::TEST, __FILE__, __LINE__); } // test all of the normalization in PRE_COMPUTE mode // mix7.setMode(PRECOMPUTE); if (mix7.isNormalized()) { return Error::handle(name(), L"isNormalized", Error::TEST, __FILE__, __LINE__); } // now normalize the data // mix7.normalizeWeights(); if (!mix7.isNormalized()) { mix7.debug(L"mix7 not normalized"); return Error::handle(name(), L"normalizeWeights", Error::TEST, __FILE__, __LINE__); } // now normalize to a different normalization factor which should // give us the original weights back // mix7.normalizeWeights(84); if (!mix7.isNormalized(84) || !mix7.getWeights(test_weights) || !test_weights.almostEqual(new_weights)) { mix7.debug(L"mix7 not normalized to 84"); return Error::handle(name(), L"normalizeWeights", Error::TEST, __FILE__, __LINE__); } // test the getModels and setModels methods // MixtureModel mix8; mix8.setModels(mix7.getModels()); mix7.getWeights(test_weights); mix8.setWeights(test_weights); if (!mix7.eq(mix8)) { mix7.debug(L"expected mixture"); mix8.debug(L"actual mixture"); return Error::handle(name(), L"setModels/getModels", Error::TEST, __FILE__, __LINE__); } // test the likelihood methods // VectorFloat mix_mean0(L"-3.639462, 1.653617, 1.787559, 4.262538, 1.369591, 2.418731, 0.2277095, 1.344989, 0.7442688, 0.8819447, 0.5029327, 0.6623194, -7.385917, -0.01989461, 0.03300723, 3.038317e-04, 7.731465e-02, 2.036626e-02, 1.392890e-02, 1.471747e-02, 2.386167e-02, -1.337681e-02, 1.569910e-02, -7.796203e-05, 9.159852e-03, -3.887357e-02, 1.891592e-02, 3.389170e-03, 6.476890e-03, -9.673509e-03, -3.964308e-03, -1.162331e-02, -5.141579e-03, -8.838848e-03, -3.854857e-03, -2.433208e-05, -3.092395e-03, -5.406926e-04, 1.159509e-02 "); MatrixFloat mix_cov0(39, 39, L"1.433564e+01, 1.975537e+01, 1.671448e+01, 2.977089e+01, 2.330977e+01, 2.712510e+01, 2.764725e+01, 3.032939e+01, 3.006948e+01, 2.652304e+01, 2.420893e+01, 2.028848e+01, 5.845720e+00, 2.440532e-01, 4.785362e-01, 7.022072e-01, 1.055732e+00, 1.272656e+00, 1.545274e+00, 1.733616e+00, 1.836289e+00, 1.909358e+00, 1.861139e+00, 1.752789e+00, 1.551828e+00, 2.726359e-02, 4.366880e-02, 8.581409e-02, 1.317306e-01, 1.942055e-01, 2.418336e-01, 2.950188e-01, 3.333847e-01, 3.536546e-01, 3.684548e-01, 3.604752e-01, 3.396829e-01, 3.016612e-01, 4.171541e-03", Integral::DIAGONAL); VectorFloat in_vec0(L"-1.05023813258004e+00,-7.77674896710878e+00,1.43041098327926e+01,-6.21318461516028e+00,-8.64903872637757e+00,1.02604097385083e+01,-5.57177352692909e+00,5.09477660171566e+00,-4.09609275292737e+00,5.96100219381628e+00,3.76151348252777e-01,-2.97277346115335e+00,-2.30355781300591e+00,7.66914363607964e-01,1.27304216160438e-01,-3.20960700374856e-01,-1.93367994699623e+00,-4.25686341458929e-01,4.12861672859121e-01,-3.16990107434180e-01,1.81895735084098e+00,1.97179456310064e+00,-5.23509954876174e-01,-2.63924777075424e-01,3.76853360820597e-03,6.07853048112972e-04,9.75499698349679e-02,5.80349863491104e-01,-4.96041065023962e-01,-5.45823432246360e-02,-1.61513044481162e-01,2.12215102626400e-01,2.76088189007932e-01,4.25378696913581e-01,-3.04615852820205e-02,-2.56536949288540e-02,2.44496596888947e-01,6.41653297530483e-01,-6.79771984752767e-02"); Float res_score0 = -67.155225; GaussianModel gauss4; gauss4.setMean(mix_mean0); gauss4.setCovariance(mix_cov0); // add the gaussian to the mixture model // MixtureModel mix_model0; mix_model0.add(gauss4); mix_model0.add(gauss4); // set the mixture model weights // // new_weights.assign(L"2, 2"); new_weights.assign(L"7.389056099, 7.389056099"); mix_model0.setWeights(new_weights); // generate the log-likelihood // Float score0 = mix_model0.getLogLikelihood(in_vec0); if (!score0.almostEqual(res_score0)) { mix_model0.debug(L"mix_model0"); res_score0.debug(L"expected score"); score0.debug(L"actual score"); return Error::handle(name(), L"getLogLikelihood", Error::TEST, __FILE__, __LINE__); } // change to NONE mode and normalize the weights so that they sum to 2.0. // since the weights are equal, this should put them both at 1.0 // mix_model0.setMode(NONE); mix_model0.normalizeWeights(2.0); Float score1 = mix_model0.getLogLikelihood(in_vec0); Float res_score1 = -69.155225; if (!score1.almostEqual(res_score1)) { mix_model0.debug(L"mix_model0"); res_score1.debug(L"expected score"); score1.debug(L"actual score"); return Error::handle(name(), L"getLogLikelihood", Error::TEST, __FILE__, __LINE__); } Float score2 = mix_model0.getLikelihood(in_vec0); if (!score2.almostEqual(Integral::exp(res_score1))) { mix_model0.debug(L"mix_model0"); res_score1.exp(); res_score1.debug(L"expected score"); score1.debug(L"actual score"); return Error::handle(name(), L"getLogLikelihood", Error::TEST, __FILE__, __LINE__); } // test the getMean // VectorFloat input_mean_mix1(L"2, 4"); VectorFloat input_mean_mix2(L"-2, -4"); VectorFloat exp_mean(L"0, 0"); VectorFloat mix_weight(L"0.5, 0.5"); VectorFloat mean; MatrixFloat input_cov_mix1(2, 2, L"1, 0.5", Integral::DIAGONAL); MatrixFloat input_cov_mix2(2, 2, L"1, 0.5", Integral::DIAGONAL); GaussianModel gauss5; gauss5.setMean(input_mean_mix1); gauss5.setCovariance(input_cov_mix1); GaussianModel gauss6; gauss6.setMean(input_mean_mix2); gauss6.setCovariance(input_cov_mix2); // add the gaussian to the mixture model // MixtureModel mix_model9; mix_model9.add(gauss5); mix_model9.add(gauss6); mix_model9.setWeights(mix_weight); mix_model9.getMean(mean); if (!mean.almostEqual(exp_mean)) { exp_mean.debug(L"expected mean"); mean.debug(L"actual mean"); return Error::handle(name(), L"getMean", Error::TEST, __FILE__, __LINE__); } // test the getCovariance for single mixture Gaussian distribution // VectorFloat input_mean_mix3(L"2, 3"); MatrixFloat input_cov_mix3(2, 2, L"1.33333, 2.444", Integral::DIAGONAL); MatrixFloat exp_cov(2, 2, L"1.33333, 2.444", Integral::DIAGONAL); MatrixFloat cov; GaussianModel gauss7; gauss7.setMean(input_mean_mix3); gauss7.setCovariance(input_cov_mix3); // add the gaussian to the mixture model // MixtureModel mix_model10; mix_model10.add(gauss7); mix_model10.getCovariance(cov); if (!cov.almostEqual(exp_cov)) { exp_cov.debug(L"expected covariance"); cov.debug(L"actual covariance"); return Error::handle(name(), L"getCovariance", Error::TEST, __FILE__, __LINE__); } // reset indentation // if (level_a > Integral::NONE) { Console::decreaseIndention(); } //--------------------------------------------------------------------------- // // 4. print completion message // //--------------------------------------------------------------------------- // reset indentation // if (level_a > Integral::NONE) { Console::decreaseIndention(); } if (level_a > Integral::NONE) { SysString output(L"diagnostics passed for class "); output.concat(name()); output.concat(L"\n"); Console::put(output); } // exit gracefully // return true;}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -