📄 smo.cc
字号:
int a1 = sparse_points[i1].id[p1]; int a2 = sparse_points[i2].id[p2]; if (a1 == a2) { dot += sparse_points[i1].val[p1] * sparse_points[i2].val[p2]; p1++; p2++; } else if (a1 > a2) p2++; else p1++; } return (float)dot;}#line 1415 "smo.w"float dot_product_dense(int i1, int i2){ float dot = 0.; for (int i=0; i<d; i++) dot += dense_points[i1][i] * dense_points[i2][i]; return dot;}#line 1465 "smo.w"float rbf_kernel(int i1, int i2){ float s = dot_product_func(i1,i2); s *= -2; s += precomputed_self_dot_product[i1] + precomputed_self_dot_product[i2]; return exp(-s/two_sigma_squared);}#line 1656 "smo.w"int read_data(istream& is){ string s; int n_lines; for (n_lines = 0; getline(is, s, '\n'); n_lines++) { istrstream line(s.c_str()); vector<float> v; float t; while (line >> t) v.push_back(t); target.push_back(v.back()); v.pop_back(); int n = v.size(); if (is_sparse_data && is_binary) { sparse_binary_vector x; for (int i=0; i<n; i++) { if (v[i] < 1 || v[i] > d) { cerr << "error: line " << n_lines+1 << ": attribute index " << int(v[i]) << " out of range."<<endl; exit(1); } x.id.push_back(int(v[i])-1); } sparse_binary_points.push_back(x); } else if (is_sparse_data && !is_binary) { sparse_vector x; for (int i=0; i<n; i+=2) { if (v[i] < 1 || v[i] > d) { cerr << "data file error: line " << n_lines+1 << ": attribute index " << int(v[i]) << " out of range." << endl; exit(1); } x.id.push_back(int(v[i])-1); x.val.push_back(v[i+1]); } sparse_points.push_back(x); } else { if (v.size() != d) { cerr << "data file error: line " << n_lines+1 << " has " << v.size() << " attributes; should be d=" << d <<endl; exit(1); } dense_points.push_back(v); } } return n_lines;}#line 1741 "smo.w"void write_svm(ostream& os) { os << d << endl; os << is_sparse_data << endl; os << is_binary << endl; os << is_linear_kernel << endl; os << b << endl; if (is_linear_kernel) { for (int i=0; i<d; i++) os << w[i] << endl; } else { os << two_sigma_squared << endl; int n_support_vectors=0; for (int i=0; i<end_support_i; i++) if (alph[i] > 0) n_support_vectors++; os << n_support_vectors << endl; for (int i=0; i<end_support_i; i++) if (alph[i] > 0) os << alph[i] << endl; for (int i=0; i<end_support_i; i++) if (alph[i] > 0) { if (is_sparse_data && is_binary) { for (int j=0; j<sparse_binary_points[i].id.size(); j++) os << (sparse_binary_points[i].id[j]+1) << ' '; } else if (is_sparse_data && !is_binary) { for (int j=0; j<sparse_points[i].id.size(); j++) os << (sparse_points[i].id[j]+1) << ' ' << sparse_points[i].val[j] << ' '; } else { for (int j=0; j<d; j++) os << dense_points[i][j] << ' '; } os << target[i]; os << endl; } }}#line 1784 "smo.w"int read_svm(istream& is) { is >> d; is >> is_sparse_data; is >> is_binary; is >> is_linear_kernel; is >> b; if (is_linear_kernel) { w.resize(d); for (int i=0; i<d; i++) is >> w[i]; } else { is >> two_sigma_squared; int n_support_vectors; is >> n_support_vectors; alph.resize(n_support_vectors, 0.); for (int i=0; i<n_support_vectors; i++) is >> alph[i]; string dummy_line_to_skip_newline; getline(is, dummy_line_to_skip_newline, '\n'); return read_data(is); } return 0;}#line 1821 "smo.w"floaterror_rate(){ int n_total = 0; int n_error = 0; for (int i=first_test_i; i<N; i++) { if (learned_func(i) > 0 != target[i] > 0) n_error++; n_total++; } return float(n_error)/float(n_total);}#line 828 "smo.w"#line 846 "smo.w"int main(int argc, char *argv[]) { #line 1579 "smo.w" char *data_file_name = "svm.data"; char *svm_file_name = "svm.model"; char *output_file_name = "svm.output"; #line 847 "smo.w" int numChanged; int examineAll; #line 1501 "smo.w" { extern char *optarg; extern int optind; int c; int errflg = 0; while ((c = getopt (argc, argv, "n:d:c:t:e:p:f:m:o:r:lsba")) != EOF) switch (c) { case 'n': N = atoi(optarg); break; case 'd': d = atoi(optarg); break; case 'c': C = atof (optarg); break; case 't': tolerance = atof(optarg); break; case 'e': eps = atof (optarg); break; case 'p': two_sigma_squared = atof (optarg); break; case 'f': data_file_name = optarg; break; case 'm': svm_file_name = optarg; break; case 'o': output_file_name = optarg; break; case 'r': srand48 (atoi (optarg)); break; case 'l': is_linear_kernel = true; break; case 's': is_sparse_data = true; break; case 'b': is_binary = true; break; case 'a': is_test_only = true; break; case '?': errflg++; } if (errflg || optind < argc) { cerr << "usage: " << argv[0] << " " << "-f data_file_name\n" "-m svm_file_name\n" "-o output_file_name\n" "-n N\n" "-d d\n" "-c C\n" "-t tolerance\n" "-e epsilon\n" "-p two_sigma_squared\n" "-r random_seed\n" "-l (is_linear_kernel)\n" "-s (is_sparse_data)\n" "-b (is_binary)\n" "-a (is_test_only)\n" ; exit (2); } }#line 851 "smo.w" #line 1615 "smo.w" { int n; if (is_test_only) { ifstream svm_file(svm_file_name); end_support_i = first_test_i = n = read_svm(svm_file); N += n; } if (N > 0) { target.reserve(N); if (is_sparse_data && is_binary) sparse_binary_points.reserve(N); else if (is_sparse_data && !is_binary) sparse_points.reserve(N); else dense_points.reserve(N); } ifstream data_file(data_file_name); n = read_data(data_file); if (is_test_only) { N = first_test_i + n; } else { N = n; first_test_i = 0; end_support_i = N; } }#line 852 "smo.w" if (!is_test_only) { alph.resize(end_support_i, 0.); /* initialize threshold to zero */ b = 0.; /* E_i = u_i - y_i = 0 - y_i = -y_i */ error_cache.resize(N); if (is_linear_kernel) w.resize(d,0.); } #line 1334 "smo.w" if (is_linear_kernel && is_sparse_data && is_binary) learned_func = learned_func_linear_sparse_binary; if (is_linear_kernel && is_sparse_data && !is_binary) learned_func = learned_func_linear_sparse_nonbinary; if (is_linear_kernel && !is_sparse_data) learned_func = learned_func_linear_dense; if (!is_linear_kernel) learned_func = learned_func_nonlinear; #line 1356 "smo.w" if (is_sparse_data && is_binary) dot_product_func = dot_product_sparse_binary; if (is_sparse_data && !is_binary) dot_product_func = dot_product_sparse_nonbinary; if (!is_sparse_data) dot_product_func = dot_product_dense; #line 1433 "smo.w" if (is_linear_kernel) kernel_func = dot_product_func; if (!is_linear_kernel) kernel_func = rbf_kernel; #line 1479 "smo.w" if (!is_linear_kernel) { precomputed_self_dot_product.resize(N); for (int i=0; i<N; i++) precomputed_self_dot_product[i] = dot_product_func(i,i); } #line 867 "smo.w" if (!is_test_only) { numChanged = 0; examineAll = 1; while (numChanged > 0 || examineAll) { numChanged = 0; if (examineAll) { for (int k = 0; k < N; k++) numChanged += examineExample (k); } else { for (int k = 0; k < N; k++) if (alph[k] != 0 && alph[k] != C) numChanged += examineExample (k); } if (examineAll == 1) examineAll = 0; else if (numChanged == 0) examineAll = 1; //cerr << error_rate() << endl; #line 1846 "smo.w" /* L_D */ { #if 0 float s = 0.; for (int i=0; i<N; i++) s += alph[i]; float t = 0.; for (int i=0; i<N; i++) for (int j=0; j<N; j++) t += alph[i]*alph[j]*target[i]*target[j]*kernel_func(i,j); cerr << "Objective function=" << (s - t/2.) << endl; for (int i=0; i<N; i++) if (alph[i] < 0) cerr << "alph[" << i << "]=" << alph[i] << " < 0" << endl; s = 0.; for (int i=0; i<N; i++) s += alph[i] * target[i]; cerr << "s=" << s << endl; cerr << "error_rate=" << error_rate() << '\t'; #endif int non_bound_support =0; int bound_support =0; for (int i=0; i<N; i++) if (alph[i] > 0) { if (alph[i] < C) non_bound_support++; else bound_support++; } cerr << "non_bound=" << non_bound_support << '\t'; cerr << "bound_support=" << bound_support << endl; } #line 889 "smo.w" } #line 1811 "smo.w" { if (!is_test_only && svm_file_name != NULL) { ofstream svm_file(svm_file_name); write_svm(svm_file); } }#line 891 "smo.w" cerr << "threshold=" << b << endl; } cout << error_rate() << endl; #line 1839 "smo.w" { ofstream output_file(output_file_name); for (int i=first_test_i; i<N; i++) output_file << learned_func(i) << endl; }#line 895 "smo.w"}#line 829 "smo.w"
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -