📄 svm.cpp
字号:
return data; } Qfloat get_non_cached(int i, int j) const { return (Qfloat) y[i]*y[j]*(this->*kernel_function)(i,j); } inline bool is_cached(const int i) const { return cache->is_cached(i); } void swap_index(int i, int j) const { cache->swap_index(i,j); Kernel::swap_index(i,j); swap(y[i],y[j]); swap(QD[i],QD[j]); } ~SVC_Q() { delete[] y; delete cache; delete[] QD; }protected: schar *y; Cache *cache; Qfloat *QD; int l;};class ONE_CLASS_Q: public Kernel{public: ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param) :Kernel(prob.l, prob.x, prob.nz_idx, prob.x_len, prob.max_idx, param) { this->l = prob.l; cache = new Cache(prob.l,(int)(param.cache_size*(1<<20))); QD = new Qfloat[prob.l]; for(int i=0;i<prob.l;i++) QD[i]= (Qfloat)(this->*kernel_function)(i,i); } Qfloat *get_Q(int i, int len) const { Qfloat *data; int start; if((start = cache->get_data(i,&data,len)) < len) { for(int j=start;j<len;j++) data[j] = (Qfloat)(this->*kernel_function)(i,j); } return data; } Qfloat *get_QD() const { return QD; } Qfloat *get_Q_subset(int i, int *idxs, int n) const { Qfloat *data; int start = cache->get_data(i,&data,l); if(start == 0) // Initialize cache row { for(int j=0; j<l; ++j) data[j] = NAN; } for(int j=0; j<n; ++j) { if(isnan(data[idxs[j]])) data[idxs[j]] = (Qfloat)(this->*kernel_function)(i,idxs[j]); } return data; } Qfloat get_non_cached(int i, int j) const { return (Qfloat) (this->*kernel_function)(i,j); } inline bool is_cached(const int i) const { return cache->is_cached(i); } void swap_index(int i, int j) const { cache->swap_index(i,j); Kernel::swap_index(i,j); swap(QD[i],QD[j]); } ~ONE_CLASS_Q() { delete cache; delete[] QD; }private: Cache *cache; Qfloat *QD; int l;};class SVR_Q: public Kernel{ public: SVR_Q(const svm_problem& prob, const svm_parameter& param) :Kernel(prob.l, prob.x, prob.nz_idx, prob.x_len, prob.max_idx, param) { l = prob.l; cache = new Cache(l,(int)(param.cache_size*(1<<20))); QD = new Qfloat[2*l]; sign = new schar[2*l]; index = new int[2*l]; for(int k=0;k<l;k++) { sign[k] = 1; sign[k+l] = -1; index[k] = k; index[k+l] = k; QD[k]= (Qfloat)(this->*kernel_function)(k,k); QD[k+l]=QD[k]; } buffer[0] = new Qfloat[2*l]; buffer[1] = new Qfloat[2*l]; next_buffer = 0; } void swap_index(int i, int j) const { swap(sign[i],sign[j]); swap(index[i],index[j]); swap(QD[i],QD[j]); } Qfloat *get_Q(int i, int len) const { Qfloat *data; int real_i = index[i]; if(cache->get_data(real_i,&data,l) < l) { for(int j=0;j<l;j++) data[j] = (Qfloat)(this->*kernel_function)(real_i,j); } // reorder and copy Qfloat *buf = buffer[next_buffer]; next_buffer = 1 - next_buffer; schar si = sign[i]; for(int j=0;j<len;j++) buf[j] = si * sign[j] * data[index[j]]; return buf; } Qfloat *get_QD() const { return QD; } Qfloat *get_Q_subset(int i, int *idxs, int n) const { Qfloat *data; int real_i = index[i]; int start = cache->get_data(real_i,&data,l); if(start == 0) // Initialize cache row { for(int j=0; j<l; ++j) { data[j] = NAN; } } for(int j=0; j<n; ++j) { int real_j = index[idxs[j]]; if(isnan(data[real_j])) data[real_j] = (Qfloat)(this->*kernel_function)(real_i,real_j); } // reorder and copy Qfloat *buf = buffer[next_buffer]; next_buffer = 1 - next_buffer; schar si = sign[i]; for(int j=0; j<n; ++j) buf[idxs[j]] = si * sign[idxs[j]] * data[index[idxs[j]]]; return buf; } Qfloat get_non_cached(int i, int j) const { int real_i = index[i]; int real_j = index[j]; return (Qfloat) sign[i]*sign[j]*(this->*kernel_function)(real_i,real_j); } inline bool is_cached(const int i) const { return cache->is_cached(i); } ~SVR_Q() { delete cache; delete[] sign; delete[] index; delete[] buffer[0]; delete[] buffer[1]; delete[] QD; }protected: int l; Cache *cache; schar *sign; int *index; mutable int next_buffer; Qfloat *buffer[2]; Qfloat *QD;};int Solver_NU::select_working_set(int &out_i, int &out_j){ // Always select the maximal violating pair. Old fashion. // Does the same as LibSVM v2.36. double Gmin1 = INF; double Gmin2 = INF; double Gmax1 = -INF; double Gmax2 = -INF; int min1 = -1; int min2 = -1; int max1 = -1; int max2 = -1;// printf("G = ");// for(int t=0; t<l; ++t)// printf(" %g",G[t]);// printf("\n"); for(int t=0; t<active_size; ++t) { if(y[t] == +1) { if(!is_upper_bound(t)) { if(G[t] < Gmin1) { Gmin1 = G[t]; min1 = t; } } if(!is_lower_bound(t)) { if(G[t] > Gmax1) { Gmax1 = G[t]; max1 = t; } } } else { if(!is_upper_bound(t)) { if(G[t] < Gmin2) { Gmin2 = G[t]; min2 = t; } } if(!is_lower_bound(t)) { if(G[t] > Gmax2) { Gmax2 = G[t]; max2 = t; } } } } if(max(Gmax1-Gmin1,Gmax2-Gmin2) < eps) return 1; if(Gmax1-Gmin1 > Gmax2-Gmin2) { out_i = max1; out_j = min1; } else { out_i = max2; out_j = min2; }// printf("Selected (%d,%d)\n",out_i,out_j); return 0;}// return 1 if already optimal, return 0 otherwise// int Solver_NU::select_working_set(int &out_i, int &out_j)// {// // return i,j such that y_i = y_j and// // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)// // j: minimizes the decrease of obj value// // (if quadratic coefficeint <= 0, replace it with tau)// // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)// double Gmaxp = -INF;// int Gmaxp_idx = -1;// double Gmaxn = -INF;// int Gmaxn_idx = -1;// int Gmin_idx = -1;// double obj_diff_min = INF;// for(int t=0;t<active_size;t++)// if(y[t]==+1)// {// if(!is_upper_bound(t))// if(-G[t] >= Gmaxp)// {// Gmaxp = -G[t];// Gmaxp_idx = t;// }// }// else// {// if(!is_lower_bound(t))// if(G[t] >= Gmaxn)// {// Gmaxn = G[t];// Gmaxn_idx = t;// }// }// int ip = Gmaxp_idx;// int in = Gmaxn_idx;// const Qfloat *Q_ip = NULL;// const Qfloat *Q_in = NULL;// if(ip != -1) // NULL Q_ip not accessed: Gmaxp=-INF if ip=-1// Q_ip = Q->get_Q(ip,active_size);// if(in != -1)// Q_in = Q->get_Q(in,active_size);// for(int j=0;j<active_size;j++)// {// if(y[j]==+1)// {// if (!is_lower_bound(j)) // {// double grad_diff=Gmaxp+G[j];// if (grad_diff >= eps)// {// double obj_diff; // double quad_coef = Q_ip[ip]+QD[j]-2*Q_ip[j];// if (quad_coef > 0)// obj_diff = -(grad_diff*grad_diff)/quad_coef;// else// obj_diff = -(grad_diff*grad_diff)/TAU;// if (obj_diff <= obj_diff_min)// {// Gmin_idx=j;// obj_diff_min = obj_diff;// }// }// }// }// else// {// if (!is_upper_bound(j))// {// double grad_diff=Gmaxn-G[j];// if (grad_diff >= eps)// {// double obj_diff; // double quad_coef = Q_in[in]+QD[j]-2*Q_in[j];// if (quad_coef > 0)// obj_diff = -(grad_diff*grad_diff)/quad_coef;// else// obj_diff = -(grad_diff*grad_diff)/TAU;// if (obj_diff <= obj_diff_min)// {// Gmin_idx=j;// obj_diff_min = obj_diff;// }// }// }// }// }// if(Gmin_idx == -1)// return 1;// if (y[Gmin_idx] == +1)// out_i = Gmaxp_idx;// else// out_i = Gmaxn_idx;// out_j = Gmin_idx;// return 0;// }void Solver_NU::do_shrinking(){ double Gmax1 = -INF; // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) } double Gmax2 = -INF; // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) } double Gmax3 = -INF; // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) } double Gmax4 = -INF; // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) } // find maximal violating pair first int k; for(k=0;k<active_size;k++) { if(!is_upper_bound(k)) { if(y[k]==+1) { if(-G[k] > Gmax1) Gmax1 = -G[k]; } else if(-G[k] > Gmax3) Gmax3 = -G[k]; } if(!is_lower_bound(k)) { if(y[k]==+1) { if(G[k] > Gmax2) Gmax2 = G[k]; } else if(G[k] > Gmax4) Gmax4 = G[k]; } } // shrinking double Gm1 = -Gmax2; double Gm2 = -Gmax1; double Gm3 = -Gmax4; double Gm4 = -Gmax3; for(k=0;k<active_size;k++) { if(is_lower_bound(k)) { if(y[k]==+1) { if(-G[k] >= Gm1) continue; } else if(-G[k] >= Gm3) continue; } else if(is_upper_bound(k)) { if(y[k]==+1) { if(G[k] >= Gm2) continue; } else if(G[k] >= Gm4) continue; } else continue; --active_size; swap_index(k,active_size); --k; // look at the newcomer } // unshrink, check all variables again before final iterations if(unshrinked || max(-(Gm1+Gm2),-(Gm3+Gm4)) > eps*10) return; unshrinked = true; reconstruct_gradient(); for(k=l-1;k>=active_size;k--) { if(is_lower_bound(k)) { if(y[k]==+1) { if(-G[k] < Gm1) continue; } else if(-G[k] < Gm3) continue; } else if(is_upper_bound(k)) { if(y[k]==+1) { if(G[k] < Gm2) continue; } else if(G[k] < Gm4) continue; } else continue; swap_index(k,active_size); active_size++; ++k; // look at the newcomer }}double Solver_NU::calculate_rho(){ int nr_free1 = 0,nr_free2 = 0; double ub1 = INF, ub2 = INF; double lb1 = -INF, lb2 = -INF; double sum_free1 = 0, sum_free2 = 0;// printf("alpha = ");// for(int i=0; i<l; ++i)// printf(" %g",alpha[i]);// printf("\n"); for(int i=0;i<active_size;i++) { if(y[i]==+1) { if(is_lower_bound(i)) ub1 = min(ub1,G[i]); else if(is_upper_bound(i)) lb1 = max(lb1,G[i]); else { ++nr_free1; sum_free1 += G[i]; } } else { if(is_lower_bound(i)) ub2 = min(ub2,G[i]); else if(is_upper_bound(i)) lb2 = max(lb2,G[i]); else { ++nr_free2; sum_free2 += G[i]; } } } printf("nr_free1 = %d\n", nr_free1); printf("sum_free1 = %g\n",sum_free1); printf("nr_free2 = %d\n", nr_free2); printf("sum_freee = %g\n",sum_free2); double r1,r2; if(nr_free1 > 0) r1 = sum_free1/nr_free1; else r1 = (ub1+lb1)/2; if(nr_free2 > 0) r2 = sum_free2/nr_free2; else r2 = (ub2+lb2)/2; si->r = (r1+r2)/2; printf("(r1+r2)/2 = %g\n", (r1+r2)/2); printf("(r1+r2)/2 = %g\n", (r1-r2)/2); return (r1-r2)/2;}//// construct and solve various formulations//static void solve_c_svc(const svm_problem *prob, const svm_parameter* param, double *alpha, Solver::SolutionInfo* si, double Cp, double Cn){ int l = prob->l; double *minus_ones = new double[l]; schar *y = new schar[l]; int i; for(i=0;i<l;i++) { alpha[i] = 0; minus_ones[i] = -1; if(prob->y[i] > 0) y[i] = +1; else y[i]=-1; }// Solver_GPM sgpm(param->o, param->q);// sgpm.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y,// alpha, Cp, Cn, param->eps, si, param->shrinking);// Solver_LOQO sl(param->o, param->q, 1);// sl.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y,// alpha, Cp, Cn, param->eps, si, param->shrinking); int size; MPI_Comm_size(MPI_COMM_WORLD, &size);#ifdef SOLVER_PSMO Solver_Parallel_SMO sps(param->o, param->q, MPI_COMM_WORLD); sps.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y, alpha, Cp, Cn, param->eps, si, param->shrinking);#endif#ifdef SOLVER_LOQO Solver_Parallel_LOQO spl(param->o, param->q, 1, MPI_COMM_WORLD, size, 1, param->o/size); spl.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y, alpha, Cp, Cn, param->eps, si, param->shrinking);#endif
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -