📄 lpboost.cpp
字号:
/** @file * $Id: lpboost.cpp 2537 2006-01-08 08:40:36Z ling $ */#include <assert.h>#include <cmath>#include <iostream>#include "lpboost.h"extern "C"{#include <glpk.h>}REGISTER_CREATOR(lemga::LPBoost);namespace lemga {#define U(i) ((i)+1) //U(0) to U(n_samples-1)#define R(t) ((t)+1) //R(0) to R(T-1)REAL LPBoost::train () { assert(n_in_agg == 0 && empty()); assert(ptd != 0 && ptw != 0); assert(lm_base != 0); // we need lm_base to create new hypotheses assert(!grad_desc_view); // Construct inner problem LPX* lp = lpx_create_prob(); lpx_add_cols(lp, n_samples); // u_i for (UINT i = 0; i < n_samples; ++i) { lpx_set_col_bnds(lp, U(i), LPX_DB, 0.0, RegC * (*ptw)[i] * n_samples); // 0 <= u_i <= C_i lpx_set_obj_coef(lp, U(i), -1); // obj: -sum u_i } lpx_set_obj_dir(lp, LPX_MIN); // min obj // For adding columns int* ndx = new int[n_samples+1]; double* val = new double[n_samples+1]; REAL besterr = HUGE_VAL; pDataWgt pdw = ptw; for (UINT t = 0; t < max_n_model; ++t) { const pLearnModel p = train_with_smpwgt(pdw); REAL err = 0; for (UINT i = 0; i < n_samples; ++i) { if (p->c_error(p->get_output(i), ptd->y(i)) > 0.1) err += (*pdw)[i]; } if (err >= besterr - EPSILON) // Cannot find better hypotheses break; // Add one more constraint R(t) = -sum u_i y_i h_t(x_i) >= -1 lpx_add_rows(lp, 1); for (UINT i = 0; i < n_samples; ++i) { ndx[i+1] = U(i); val[i+1] = - p->get_output(i)[0] * ptd->y(i)[0]; } lpx_set_mat_row(lp, R(t), n_samples, ndx, val); lpx_set_row_bnds(lp, R(t), LPX_LO, -1.0, 0.0); // R(t) >= -1 // Solve inner problem lpx_simplex(lp); REAL sumu = -lpx_get_obj_val(lp); if (sumu < EPSILON) { // we do not expect this to happen std::cerr << "Warning: sum u is " << sumu << "; quit earlier.\n"; break; } besterr = (1.0 - 1.0 / sumu) / 2.0; lm.push_back(p); lm_wgt.push_back(0); ++n_in_agg; // Update sample weights DataWgt* sample_wgt = new DataWgt(n_samples); for (UINT i = 0; i < n_samples; ++i) { double wgt; wgt = lpx_get_col_prim(lp, U(i)); assert(wgt >= -EPSILON); if (wgt < 0) wgt = 0; (*sample_wgt)[i] = wgt / sumu; } pdw = sample_wgt; // Update hypothesis coefficients for (UINT k = 0; k <= t; ++k) { lm_wgt[k] = lpx_get_row_dual(lp, R(k)); assert(lm_wgt[k] > -EPSILON); if (lm_wgt[k] < 0) lm_wgt[k] = 0.0; } } delete[] ndx; delete[] val; lpx_delete_prob(lp); REAL err = 0; for (UINT i = 0; i < n_samples; ++i) { err += (get_output(i)[0]*ptd->y(i)[0] <= 0); } return err / n_samples;}} // namespace lemga
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -