📄 smo.cpp
字号:
max_i = (max_i+1)%n; }; // problem solved? if((max_error<=max_allowed_error) && (iteration>2)){ error=0; break; }; //////////////////////////////////////////////////////////// // new!!! find element with maximal diff to max_i // loop would be better SVMFLOAT max_diff = -1; SVMFLOAT this_diff; int n_up; // not at upper bound int n_lo; if(x[max_i] <= qp->l[max_i]){ // at lower bound n_lo = 0; } else{ n_lo = 1; }; if(x[max_i] >= qp->u[max_i]){ // at lower bound n_up=0; } else{ n_up=1; }; min_i = (max_i+1)%n; for(i=0;i<n;i++){ if((i != max_i) && (n_up || (x[i] < qp->u[i])) && (n_lo || (x[i] > qp->l[i]))){ if(x[i] <= qp->l[i]){ // at lower bound this_error = -sum[i]; if(qp->A[i]<0){ this_error = -this_error; }; } else{ // between bounds this_error = sum[i]; if(qp->A[i]>0){ this_error = -this_error; }; }; this_diff = abs(this_error - max_lambda_eq); if(this_diff>max_diff){ max_diff = this_diff; min_i = i; }; }; }; //////////////////////////////////////////////////////////// // optimize SVMINT it=1; while((0 == minimize_ij(min_i,max_i)) && (it<n)){ it++; min_i = (min_i+1)%n; if(min_i == max_i){ min_i = (min_i+1)%n; }; }; if(it==n){ error=1; } else{ error=0; }; // time up? iteration++; if(iteration>max_iteration){ calc_lambda_eq(); error+=1; break; }; }; return error;};SVMFLOAT smo_c::get_lambda_eq(){ return lambda_eq;};/** * * unbiased SVM * **/int smo_c::minimize_i(const SVMINT i){ // minimize xi with simple_solve SVMFLOAT sum_i; // sum_{k\ne i} Hik x_k +c[i] // init sum_i,j sum_i=sum[i]; sum_i -= qp->H[i*(n+1)]*x[i]; SVMFLOAT old_xi = x[i]; x[i] = -sum_i/(qp->H[i*(n+1)]); if(x[i] < qp->l[i]){ x[i] = qp->l[i]; } else if(x[i] > qp->u[i]){ x[i] = qp->u[i]; }; int ok; SVMFLOAT target; target = (old_xi-x[i])*(qp->H[i*(n+1)]/2*(old_xi+x[i])+sum_i); if(target < 0){ // cout<<"increase on SMO: "<<target<<endl; x[i] = old_xi; old_xi=0; ok=0; } else{ old_xi-=x[i]; SVMINT k; for(k=0;k<n;k++){ sum[k]-=qp->H[i*n+k]*old_xi; }; ok=1; }; if(abs(old_xi) > is_zero){ ok =1; } else{ ok=0; }; return ok;};int smo_c::smo_solve_single(quadratic_program* the_qp,SVMFLOAT* the_x){ int error=0; x = the_x; set_qp(the_qp); SVMINT i; SVMINT j; for(i=0;i<n;i++){ sum[i] = qp->c[i]; for(j=0;j<n;j++){ sum[i] += qp->H[i*n+j]*x[j]; }; }; SVMINT iteration=0; SVMFLOAT this_error; SVMFLOAT max_error = -infinity; SVMINT max_i = 0; SVMINT old_max_i=-1; lambda_eq = 0.0; while(1){ // get i with largest KKT error if(! error){ // cout<<"l"; max_error = -infinity; max_i = 0; // heuristic for i for(i=0;i<n;i++){ if(x[i] <= qp->l[i]){ // at lower bound this_error = -sum[i]; } else if(x[i] >= qp->u[i]){ // at upper bound this_error = sum[i]; } else{ // between bounds this_error = sum[i]; if(this_error<0) this_error = -this_error; } if((this_error>max_error) && (old_max_i != i)){ max_i = i; max_error = this_error; }; }; old_max_i = max_i; } else{ // heuristic didn't work max_i = (max_i+1)%n; }; // problem solved? if((max_error<=max_allowed_error) && (iteration>2)){ error=0; break; }; //////////////////////////////////////////////////////////// // optimize SVMINT it=minimize_i(max_i); if(it != 0){ error=1; } else{ error=0; }; // time up? iteration++; if(iteration>max_iteration){ error+=1; break; }; }; return error;};/** * * nuSVM * **/SVMFLOAT smo_c::get_lambda_nu(){ return lambda_nu;};void smo_c::calc_lambda_nu(){ SVMFLOAT lambda_pos_sum = 0; SVMFLOAT lambda_neg_sum = 0; SVMINT countpos = 0; SVMINT countneg = 0; SVMINT i; for(i=0;i<qp->n;i++){ if((x[i] > qp->l[i]) && (x[i]<qp->u[i])){ if(qp->A[i]>0){ lambda_pos_sum += sum[i]; countpos++; } else{ lambda_neg_sum += sum[i]; countneg++; }; }; }; if((countpos>0) && (countneg>0)){ lambda_pos_sum /= (SVMFLOAT)countpos; lambda_neg_sum /= (SVMFLOAT)countneg; lambda_eq = -(lambda_pos_sum-lambda_neg_sum)/2; lambda_nu = -(lambda_pos_sum+lambda_neg_sum)/2; } else{ if(countpos>0){ lambda_eq = -lambda_pos_sum / (SVMFLOAT)countpos; lambda_eq /= 2; lambda_nu = lambda_eq; } else if(countneg>0){ lambda_eq = -lambda_neg_sum / (SVMFLOAT)countneg; lambda_eq /= 2; lambda_nu = lambda_eq; } else{ calc_lambda_eq(); lambda_nu=0; }; };};int smo_c::smo_solve_const_sum(quadratic_program* the_qp,SVMFLOAT* the_x){ // solve optimization problem keeping sum x_i fixed int error=0; x = the_x; set_qp(the_qp); SVMFLOAT target=0; SVMINT i; SVMINT j; for(i=0;i<n;i++){ sum[i] = 0; for(j=0;j<n;j++){ sum[i] += qp->H[i*n+j]*x[j]; }; target += x[i]*sum[i]/2; target += qp->c[i]*x[i]; sum[i] += qp->c[i]; }; SVMINT iteration=0; SVMFLOAT this_error; SVMFLOAT max_error = -infinity; SVMFLOAT min_error_pos = infinity; SVMFLOAT min_error_neg = infinity; SVMINT max_i = 0; SVMINT min_i = 1; SVMINT min_i_pos = 1; SVMINT min_i_neg = 1; SVMINT old_min_i=-1; SVMINT old_max_i=-1; int use_sign=1; while(1){ // get i with largest KKT error if(! error){ use_sign = -use_sign; calc_lambda_nu(); max_error = -infinity; min_error_pos = infinity; min_error_neg = infinity; max_i = (old_max_i+1)%n; // heuristic for i for(i=0;i<n;i++){ if(x[i] <= qp->l[i]){ // at lower bound this_error = -sum[i]-lambda_nu; if(qp->A[i]>0){ this_error -= lambda_eq; } else{ this_error += lambda_eq; }; } else if(x[i] >= qp->u[i]){ // at upper bound this_error = sum[i]+lambda_nu; if(qp->A[i]>0){ this_error += lambda_eq; } else{ this_error -= lambda_eq; }; } else{ // between bounds this_error = sum[i]+lambda_nu; if(qp->A[i]>0){ this_error += lambda_eq; } else{ this_error -= lambda_eq; }; if(this_error<0) this_error = -this_error; } if(this_error>max_error){ if((old_max_i != i) && (qp->A[i] == use_sign)){ // look for specific sign max_i = i; }; max_error = this_error; }; if((qp->A[i]>0) && (this_error<=min_error_pos) && (i != old_min_i)){ min_i_pos = i; min_error_pos = this_error; }; if((qp->A[i]<0) && (this_error<=min_error_neg) && (i != old_min_i)){ min_i_neg = i; min_error_neg = this_error; }; }; old_max_i = max_i; // look for minimal error with same sign as max_i if(qp->A[max_i]>0){ min_i = min_i_pos; } else{ min_i = min_i_neg; }; old_min_i = min_i; } else{ // heuristic didn't work max_i = (max_i+1)%n; min_i = (max_i+1)%n; }; // problem solved? if((max_error<=max_allowed_error) && (iteration > 2)){ error=0; break; }; // optimize SVMINT it=1; // n-1 iterations error=1; while((error) && (it<n)){ if(qp->A[min_i] == qp->A[max_i]){ error = ! minimize_ij(min_i,max_i); }; it++; min_i = (min_i+1)%n; if(min_i == max_i){ min_i = (min_i+1)%n; }; }; // time up? iteration++; if(iteration>max_iteration){ calc_lambda_nu(); error+=1; break; }; }; SVMFLOAT ntarget=0; for(i=0;i<qp->n;i++){ for(j=0;j<qp->n;j++){ ntarget += x[i]*qp->H[i*qp->n+j]*x[j]/2; }; ntarget += qp->c[i]*x[i]; }; if(target<ntarget){ error++; }; return error;};
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -