📄 kernel.cpp
字号:
#include "stdafx.h"
#include "kernel.h"

namespace {

/**
 * Owns a heap-allocated character buffer for the duration of a scope.
 *
 * The buffer is released in the destructor, so it is also released when an
 * exception leaves the scope. Written without C++11 features to match the
 * rest of this file.
 */
class scoped_char_buffer_c {
 public:
  explicit scoped_char_buffer_c(const unsigned long size)
      : data(new char[size]) {
    data[0] = '\0';
  }
  ~scoped_char_buffer_c() { delete[] data; }

  /** Returns the raw buffer; ownership stays with this object. */
  char* get() const { return data; }

 private:
  char* data;
  scoped_char_buffer_c(const scoped_char_buffer_c&);             // not copyable
  scoped_char_buffer_c& operator=(const scoped_char_buffer_c&);  // not copyable
};

/**
 * Reads one whitespace-delimited token into buffer.
 *
 * buffer must provide room for MAXCHAR characters. The stream width is set
 * before the extraction so that at most MAXCHAR-1 characters are stored.
 * The buffer is emptied first, so a failed extraction leaves an empty
 * string rather than stale contents.
 */
void read_token(istream& data_stream, char* buffer) {
  buffer[0] = '\0';
  data_stream.width(MAXCHAR);
  data_stream >> buffer;
}

/**
 * Creates the kernel object that belongs to a "type" keyword.
 *
 * Returns a newly allocated kernel owned by the caller, or 0 if type_name
 * is not one of the known keywords.
 */
kernel_c* new_kernel_for_type(const char* type_name) {
  if (0 == strcmp("dot", type_name)) return new kernel_dot_c();
  if (0 == strcmp("lin_dot", type_name)) return new kernel_lin_dot_c();
  if (0 == strcmp("polynomial", type_name)) return new kernel_polynomial_c();
  if (0 == strcmp("radial", type_name)) return new kernel_radial_c();
  if (0 == strcmp("neural", type_name)) return new kernel_neural_c();
  if (0 == strcmp("anova", type_name)) return new kernel_anova_c();
  if (0 == strcmp("fourier", type_name)) return new kernel_fourier_c();
  if (0 == strcmp("reg_fourier", type_name)) return new kernel_reg_fourier_c();
  if (0 == strcmp("exponential", type_name)) return new kernel_exponential_c();
  // NOTE(review): "complete_matrix"/"comp" create the same class as
  // "regularized"/"reg", exactly as the original code did - confirm intent.
  if ((0 == strcmp("complete_matrix", type_name)) || (0 == strcmp("comp", type_name))) {
    return new kernel_regularized_c();
  }
  if ((0 == strcmp("regularized", type_name)) || (0 == strcmp("reg", type_name))) {
    return new kernel_regularized_c();
  }
  if ((0 == strcmp("aggregation", type_name)) || (0 == strcmp("sum_aggregation", type_name))) {
    return new kernel_aggregation_c();
  }
  if (0 == strcmp("prod_aggregation", type_name)) return new kernel_prod_aggregation_c();
  if (0 == strcmp("zero", type_name)) return new kernel_zero_c();
  if (0 == strcmp("lintransform", type_name)) return new kernel_lintransform_c();
  if (0 == strcmp("user", type_name)) return new kernel_user_c();
  if (0 == strcmp("user2", type_name)) return new kernel_user2_c();
  // insert code for other kernels here
  return 0;
}

}  // namespace

/**
 *
 * kernel_container_c
 *
 **/

/**
 * Forgets the kernel pointer. The kernel object itself is not deleted here.
 */
kernel_container_c::~kernel_container_c() {
  kernel = 0;
}

/**
 * Returns the contained kernel. If none is set, a kernel_dot_c is created
 * first and is_linear is set to 1.
 */
kernel_c* kernel_container_c::get_kernel() {
  if (kernel == 0) {
    kernel = new kernel_dot_c();
    is_linear = 1;
  }
  return kernel;
}

/**
 * Resets the kernel pointer to 0 without deleting the kernel.
 */
void kernel_container_c::clear() {
  // do not delete kernel, for reading of aggregation kernels
  kernel = 0;
}

/**
 * Reads a kernel definition from data_stream into the_kernel.
 *
 * If the stream is at its end or the next character is '@', no definition is
 * present and a kernel_dot_c is installed as the default. Otherwise comment
 * lines ('#') and empty lines are skipped, the keyword "type" followed by a
 * kernel name is expected, the matching kernel is created and its parameters
 * are read from the stream. Any previously held kernel is deleted when it is
 * replaced.
 *
 * @param data_stream stream to read from
 * @param the_kernel  container that receives the new kernel
 * @return data_stream
 * @throws read_exception if the first token is not "type" or the kernel name
 *         is unknown
 */
istream& operator>>(istream& data_stream, kernel_container_c& the_kernel) {
  // Scope-owned buffer: released on return and on every throw below.
  scoped_char_buffer_c buffer(MAXCHAR);
  char* s = buffer.get();

  if (data_stream.eof() || ('@' == data_stream.peek())) {
    // no kernel definition, take dot as default
    if (0 != the_kernel.kernel) {
      delete the_kernel.kernel;
    }
    the_kernel.kernel = new kernel_dot_c();
    // keep the flag in step with get_kernel() and the explicit "dot" case
    the_kernel.is_linear = 1;
    // throw read_exception("No kernel definition found");
    return data_stream;
  }

  while ((!data_stream.eof()) &&
         (('#' == data_stream.peek()) || ('\n' == data_stream.peek()))) {
    // ignore comment & newline
    data_stream.getline(s, MAXCHAR);
  }

  read_token(data_stream, s);
  if (0 != strcmp("type", s)) {
    cout << "read: " << s << endl;
    throw read_exception("kernel type has to be defined first");
  }

  the_kernel.is_linear = 0;
  read_token(data_stream, s);

  kernel_c* new_kernel = new_kernel_for_type(s);
  if (0 == new_kernel) {
    // The message buffer is handed to the exception and deliberately not
    // freed here: whether read_exception copies the text or keeps the pointer
    // is not visible in this file.
    char* t = new char[MAXCHAR];
    strcpy(t, "Unknown kernel type: ");
    // bounded append so a long token cannot overrun t
    strncat(t, s, MAXCHAR - strlen(t) - 1);
    throw read_exception(t);
  }

  if (0 != the_kernel.kernel) {
    delete the_kernel.kernel;
  }
  the_kernel.kernel = new_kernel;
  if (0 == strcmp("dot", s)) {
    the_kernel.is_linear = 1;
  }
  // let the kernel read its own parameters
  data_stream >> *(the_kernel.kernel);

  return data_stream;
}

/**
 * Writes the contained kernel to data_stream, or the line "Empty kernel" if
 * the container holds none.
 */
ostream& operator<<(ostream& data_stream, kernel_container_c& the_kernel) {
  if (0 != the_kernel.kernel) {
    data_stream << *(the_kernel.kernel);
  } else {
    data_stream << "Empty kernel" << endl;
  }
  return data_stream;
}

/**
 *
 * kernel_c
 *
 **/

/**
 * Creates a kernel with an empty cache: all sizes and counters are zero and
 * all cache arrays are null.
 */
kernel_c::kernel_c() {
  dim = 0;
  cache_size = 0;
  examples_size = 0;
  rows = 0;
  last_used = 0;
  index = 0;
  counter = 0;
  // cache profiling:
  // cache_misses = 0;
  // cache_access = 0;
}

/**
 * Releases the cache.
 */
kernel_c::~kernel_c() {
  // cout<<"destructor"<<endl;
  // cache profiling:
  // cout<<cache_access<<" access to the cache"<<endl;
  // cout<<cache_misses<<" cache misses ("<<100.0*(SVMFLOAT)cache_misses/((SVMFLOAT)cache_access)<<"%)"<<endl;
  clean_cache();
}

/**
 * Kernel value for the examples with indices i and j of the current example
 * set. The examples are fetched from the_examples and passed to the
 * svm_example overload; the cache is not consulted.
 */
SVMFLOAT kernel_c::calculate_K(const SVMINT i, const SVMINT j) {
  // cout<<"K("<<i<<","<<j<<")"<<endl;
  // if(cached(i) && cached(j)){
  //   // both are cached -> not shrinked
  //   return(rows[lookup(i)][lookup(j)]);
  // };
  // SVMINT pos_x = lookup(i);
  // SVMINT pos_y = lookup(j);
  // if((index[pos_x] == i) && (index[pos_y] == j)
  //    && (last_used[pos_x] != 0) && (last_used[pos_y] != 0)){
  //   return rows[pos_x][j];
  // };
  svm_example x = the_examples->get_example(i);
  svm_example y = the_examples->get_example(j);
  return (calculate_K(x, y));
}

/**
 * Kernel value for two examples. The base class uses the inner product.
 */
inline SVMFLOAT kernel_c::calculate_K(const svm_example x, const svm_example y) {
  // default is inner product
  return innerproduct(x, y);
}

/**
 * Returns x*y for two sparse examples.
 *
 * Both attribute arrays are walked in parallel; only attributes whose index
 * occurs in both examples contribute. The walk advances whichever side has
 * the smaller index, so it relies on the indices being in ascending order.
 */
inline SVMFLOAT kernel_c::innerproduct(const svm_example x, const svm_example y) {
  SVMFLOAT result = 0;
  svm_attrib* att_x = x.example;
  svm_attrib* att_y = y.example;
  svm_attrib* length_x = &(att_x[x.length]);  // one past the last attribute of x
  svm_attrib* length_y = &(att_y[y.length]);  // one past the last attribute of y
  while ((att_x < length_x) && (att_y < length_y)) {
    if (att_x->index == att_y->index) {
      result += (att_x->att) * (att_y->att);
      att_x++;
      att_y++;
    } else if (att_x->index < att_y->index) {
      att_x++;
    } else {
      att_y++;
    }
  }
  return result;
}

/**
 * Returns 1 if the cache slot found for example i holds index i and has a
 * positive last_used stamp, 0 otherwise.
 */
int kernel_c::cached(const SVMINT i) {
  const SVMINT pos = lookup(i);
  if ((index[pos] == i) && (last_used[pos] > 0)) {
    return 1;
  }
  return 0;
}

/**
 * Returns ||x-y||^2 for two sparse examples.
 *
 * Attributes present in only one example are treated as differing from zero.
 */
SVMFLOAT kernel_c::norm2(const svm_example x, const svm_example y) {
  SVMFLOAT result = 0;
  SVMINT length_x = x.length;
  SVMINT length_y = y.length;
  svm_attrib* att_x = x.example;
  svm_attrib* att_y = y.example;
  SVMINT pos_x = 0;
  SVMINT pos_y = 0;
  SVMFLOAT dummy;
  while ((pos_x < length_x) && (pos_y < length_y)) {
    if (att_x[pos_x].index == att_y[pos_y].index) {
      dummy = att_x[pos_x++].att - att_y[pos_y++].att;
      result += dummy * dummy;
    } else if (att_x[pos_x].index < att_y[pos_y].index) {
      dummy = att_x[pos_x++].att;
      result += dummy * dummy;
    } else {
      dummy = att_y[pos_y++].att;
      result += dummy * dummy;
    }
  }
  // remaining attributes of the longer example
  while (pos_x < length_x) {
    dummy = att_x[pos_x++].att;
    result += dummy * dummy;
  }
  while (pos_y < length_y) {
    dummy = att_y[pos_y++].att;
    result += dummy * dummy;
  }
  return result;
}

/**
 * Checks cache integrity, for debugging. Findings are printed to cout.
 *
 * Verifies that no row pointer is null, that index[] is non-negative, not
 * larger than examples_size and non-descending, and that every last_used
 * stamp lies in [0, counter]. Finally every allocated row is copied into a
 * fresh array of examples_size entries, which touches the whole row.
 *
 * @return 1 if no inconsistency was found, 0 otherwise
 */
int kernel_c::check() {
  int result = 1;
  cout << "Checking cache" << endl;
  SVMINT i;

  // rows != 0
  for (i = 0; i < cache_size; i++) {
    if (rows[i] == 0) {
      cout << "ERROR: row[" << i << "] = 0" << endl;
      result = 0;
    }
  }
  cout << "rows[i] checked" << endl;

  // 0 <= index <= examples_size
  if (index != 0) {
    SVMINT last_i = index[0];
    // index has cache_size+1 entries, hence "<="
    for (i = 0; i <= cache_size; i++) {
      if (index[i] < 0) {
        cout << "ERROR: index[" << i << "] = " << index[i] << endl;
        result = 0;
      }
      if (index[i] > examples_size) {
        cout << "ERROR: index[" << i << "] = " << index[i] << endl;
        result = 0;
      }
      if (index[i] < last_i) {
        cout << "ERROR: index[" << i << "] descending" << endl;
        result = 0;
      }
      last_i = index[i];
    }
  }
  cout << "index[i] checked" << endl;

  // 0 <= last_used <= counter
  for (i = 0; i < cache_size; i++) {
    if (last_used[i] < 0) {
      cout << "ERROR: last_used[" << i << "] = " << last_used[i] << endl;
      result = 0;
    }
    if (last_used[i] > counter) {
      cout << "ERROR: last_used[" << i << "] = " << last_used[i] << endl;
      result = 0;
    }
  }
  cout << "last_used[i] checked" << endl;

  cout << "complete cache test" << endl;
  for (SVMINT i2 = 0; i2 < cache_size; i2++) {
    cout << i2 << " ";
    cout.flush();
    if (rows[i2] == 0) {
      // already reported above; nothing to copy
      continue;
    }
    SVMFLOAT* adummy = new SVMFLOAT[examples_size];
    for (SVMINT ai = 0; ai < examples_size; ai++) {
      adummy[ai] = (rows[i2])[ai];
    }
    delete[] (rows[i2]);
    // store the copy in the slot it was made from
    rows[i2] = adummy;
  }
  cout << "cache test succeeded" << endl;
  return result;
}

/**
 * Binds the kernel to an example set and sets up an empty cache.
 *
 * @param cache_MB     cache budget in megabytes; raised if it cannot hold a
 *                     single row plus its bookkeeping entries
 * @param new_examples example set the kernel works on from now on
 */
void kernel_c::init(SVMINT cache_MB, example_set_c* new_examples) {
  // cout<<"init"<<endl;
  // cout<<"cache_size = "<<cache_size<<endl;
  // cout<<"examples_size = "<<examples_size<<endl;
  // cout<<"rows = "<<rows<<endl;
  // if(rows != 0)cout<<"rows[0] = "<<rows[0]<<endl;
  clean_cache();
  the_examples = new_examples;
  dim = the_examples->get_dim();
  cache_mem = cache_MB * 1048576;
  // check if reserved memory big enough
  if (cache_mem < (SVMINT)(sizeof(SVMFLOAT) * the_examples->size() + sizeof(SVMFLOAT*) + 2 * sizeof(SVMINT))) {
    // not enough space for one example, increase
    cache_mem = sizeof(SVMFLOAT) * the_examples->size() + sizeof(SVMFLOAT*) + 2 * sizeof(SVMINT);
  }
  set_examples_size(the_examples->size());
}

/**
 * Frees all cached rows and the bookkeeping arrays and resets counter,
 * cache_size and examples_size to 0.
 */
void kernel_c::clean_cache() {
  counter = 0;
  if (rows != 0) {
    for (SVMINT i = 0; i < cache_size; i++) {
      if (0 != rows[i]) {
        delete[] (rows[i]);
        rows[i] = 0;
      }
    }
    delete[] rows;
  }
  if (last_used != 0) {
    delete[] last_used;
  }
  if (index != 0) {
    delete[] index;
  }
  rows = 0;
  last_used = 0;
  index = 0;
  cache_size = 0;
  examples_size = 0;
}

/**
 * Finds row i in the cache by binary search over index[0..cache_size).
 *
 * @return position of element i if i is in the cache, otherwise the position
 *         of the smallest element larger than i (cache_size if there is none)
 */
inline SVMINT kernel_c::lookup(const SVMINT i) {
  SVMINT low = 0;
  SVMINT high = cache_size;
  // binary search
  while (low < high) {
    const SVMINT med = (low + high) / 2;
    if (index[med] >= i) {
      high = med;
    } else {
      low = med + 1;
    }
  }
  return high;
}

/**
 * Overwrites entry i with entry j.
 * WARNING: only to be used for shrinking!
 *
 * If both rows are cached, their row pointers are exchanged and the
 * last_used stamp of j is copied to i. Otherwise a cached row for i (or,
 * failing that, for j) is marked invalid. In every cached row, column i then
 * receives the value of column j.
 */
void kernel_c::overwrite(const SVMINT i, const SVMINT j) {
  // i in cache?
  SVMINT pos_i = lookup(i);
  SVMINT pos_j = lookup(j);

  if ((index[pos_i] == i) && (index[pos_j] == j)) {
    // exchange the row pointers of pos_i and pos_j
    SVMFLOAT* dummy = rows[pos_i];
    rows[pos_i] = rows[pos_j];
    rows[pos_j] = dummy;
    last_used[pos_i] = last_used[pos_j];
  } else {
    // mark rows as invalid
    if (index[pos_i] == i) {
      last_used[pos_i] = 0;
    } else if (index[pos_j] == j) {
      last_used[pos_j] = 0;
    }
  }

  // copy column j to column i in all rows
  SVMFLOAT* my_row;
  for (pos_i = 0; pos_i < cache_size; pos_i++) {
    my_row = rows[pos_i];
    if (my_row != 0) {
      my_row[i] = my_row[j];
    }
  }
}

// NOTE(review): the definition below runs past the end of the visible source.
// Its visible part is reproduced unchanged apart from whitespace.
void kernel_c::set_examples_size(const SVMINT new_examples_size){
  // cache row with new_examples_size entries only

  // cout<<"shrinking from "<<examples_size<<" to "<<new_examples_size<<endl;

  if(new_examples_size>examples_size){
    clean_cache();
    examples_size = new_examples_size;
    cache_size = cache_mem/(sizeof(SVMFLOAT)*examples_size+sizeof(SVMFLOAT*)+2*sizeof(SVMINT));
    if(cache_size>examples_size){
      cache_size = examples_size;
    };
    // init
    rows = new SVMFLOAT*[cache_size];
    last_used = new SVMINT[cache_size];
    index = new SVMINT[cache_size+1];
    SVMINT i;
    for(i=0;i<cache_size;i++){
      rows[i] = 0; // new SVMFLOAT[new_examples_size];
      last_used[i] = 0;
      index[i] = new_examples_size;
    };
    index[cache_size] = new_examples_size;
  }
  else if(new_examples_size<examples_size){
    // copy as much rows into new cache as possible
    SVMINT old_cache_size=cache_size;
    cache_size = cache_mem/(sizeof(SVMFLOAT)*new_examples_size+sizeof(SVMFLOAT*)+2*sizeof(SVMINT));
    if(cache_size > new_examples_size){
      cache_size = new_examples_size;
    };
    if(cache_size>=old_cache_size){
      // skip it, enough space available
      cache_size=old_cache_size;
      return;
    };
    SVMFLOAT** new_rows = new SVMFLOAT*[cache_size];
    SVMINT* new_last_used = new SVMINT[cache_size];
    SVMINT* new_index = new SVMINT[cache_size+1];
    SVMINT old_pos=0;
    SVMINT new_pos=0;
    new_index[cache_size] = new_examples_size;
    while((old_pos<old_cache_size) && (new_pos < cache_size)){
      if(last_used[old_pos] > 0){
	// copy example into new cache at new_pos
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -