⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 kernel.cpp

📁 vc环境下的支持向量机源码
💻 CPP
📖 第 1 页 / 共 4 页
字号:
#include "stdafx.h"
#include "kernel.h"

/**
 *
 * kernel_container_c
 *
 **/

// Destructor. Only forgets the held kernel; it is deliberately NOT deleted.
// NOTE(review): the same convention is documented in clear() ("for reading of
// aggregation kernels"), presumably because the kernel object may be shared;
// confirm who owns and frees it.
kernel_container_c::~kernel_container_c(){
  kernel = 0;
};

// Returns the held kernel. If none has been set yet, a plain dot-product
// kernel is created on demand and the container is flagged as linear.
kernel_c* kernel_container_c::get_kernel(){
  if(kernel == 0){
    kernel = new kernel_dot_c();
    is_linear=1;
  };
  return kernel;
};

// Forgets the held kernel pointer without deleting the object.
void kernel_container_c::clear(){
  // do not delete kernel, for reading of aggregation kernels
  kernel = 0;
};

// Maps a kernel type name (the word following "type" in a kernel
// definition) to a newly allocated kernel object of that type.
// Returns 0 if the name is not a known kernel type.
// Insert code for other kernels here.
static kernel_c* new_kernel_by_name(const char* name){
  if(0 == strcmp("dot",name)) return new kernel_dot_c();
  if(0 == strcmp("lin_dot",name)) return new kernel_lin_dot_c();
  if(0 == strcmp("polynomial",name)) return new kernel_polynomial_c();
  if(0 == strcmp("radial",name)) return new kernel_radial_c();
  if(0 == strcmp("neural",name)) return new kernel_neural_c();
  if(0 == strcmp("anova",name)) return new kernel_anova_c();
  if(0 == strcmp("fourier",name)) return new kernel_fourier_c();
  if(0 == strcmp("reg_fourier",name)) return new kernel_reg_fourier_c();
  if(0 == strcmp("exponential",name)) return new kernel_exponential_c();
  if((0 == strcmp("complete_matrix",name)) || (0 == strcmp("comp",name))){
    // NOTE(review): builds the same class as "regularized" below; confirm
    // this alias is intended and not a copy-paste slip.
    return new kernel_regularized_c();
  };
  if((0 == strcmp("regularized",name)) || (0 == strcmp("reg",name))){
    return new kernel_regularized_c();
  };
  if((0 == strcmp("aggregation",name)) || (0 == strcmp("sum_aggregation",name))){
    return new kernel_aggregation_c();
  };
  if(0 == strcmp("prod_aggregation",name)) return new kernel_prod_aggregation_c();
  if(0 == strcmp("zero",name)) return new kernel_zero_c();
  if(0 == strcmp("lintransform",name)) return new kernel_lintransform_c();
  if(0 == strcmp("user",name)) return new kernel_user_c();
  if(0 == strcmp("user2",name)) return new kernel_user2_c();
  return 0;
};

// Reads a kernel definition from data_stream into the_kernel.
//
// If the stream is at EOF or the next character is '@', no kernel is
// defined: the held kernel is replaced by a dot-product kernel.
// Otherwise, leading lines starting with '#' and empty lines are skipped.
// The definition must then start with "type <name>". The held kernel is
// replaced by a new kernel of that type, which reads its own parameters
// from the stream.
//
// Throws read_exception if the definition does not start with "type" or
// the type name is unknown. In both cases the held kernel is left unchanged.
istream& operator >> (istream& data_stream, kernel_container_c& the_kernel){
  char* s = new char[MAXCHAR];
  // a failed extraction below must not leave s uninitialized for strcmp
  s[0] = '\0';
  try{
    if(data_stream.eof() || ('@' == data_stream.peek())){
      // no kernel definition, take dot as default
      if(0 != the_kernel.kernel){
        delete the_kernel.kernel;
      };
      the_kernel.kernel = new kernel_dot_c();
      // a dot kernel is linear, exactly as in get_kernel() and "type dot"
      the_kernel.is_linear=1;
      //    throw read_exception("No kernel definition found");
    }
    else{
      while((! data_stream.eof()) &&
            (('#' == data_stream.peek()) ||
             ('\n' == data_stream.peek()))){
        // ignore comment & newline
        data_stream.getline(s,MAXCHAR);
      };
      // bound the extraction so an over-long token cannot overflow s
      data_stream.width(MAXCHAR);
      data_stream >> s;
      if(0 == strcmp("type",s)) {
        the_kernel.is_linear=0;
        data_stream.width(MAXCHAR);
        data_stream >> s;
        // build the new kernel before discarding the old one, so a failure
        // never leaves a dangling pointer in the container
        kernel_c* new_kernel = new_kernel_by_name(s);
        if(0 == new_kernel){
          // size the message exactly: prefix + type name + terminator.
          // t is handed to read_exception and intentionally not freed here.
          const char* prefix = "Unknown kernel type: ";
          char* t = new char[strlen(prefix)+strlen(s)+1];
          strcpy(t,prefix);
          strcat(t,s);
          throw read_exception(t);
        };
        if(0 != the_kernel.kernel){
          delete the_kernel.kernel;
        };
        the_kernel.kernel = new_kernel;
        if(0 == strcmp("dot",s)){
          the_kernel.is_linear=1;
        };
        data_stream >> *(the_kernel.kernel);
      }
      else{
        cout<<"read: "<<s<<endl;
        throw read_exception("kernel type has to be defined first");
      };
    };
  }
  catch(...){
    // release the token buffer on every error path, then propagate unchanged
    delete []s;
    throw;
  };
  delete []s;
  return data_stream;
};

// Writes the held kernel's definition to data_stream, or "Empty kernel"
// followed by a newline if no kernel is held.
ostream& operator << (ostream& data_stream, kernel_container_c& the_kernel){
  if(0 != the_kernel.kernel){
    data_stream << *(the_kernel.kernel);
  }
  else{
    data_stream << "Empty kernel"<<endl;
  };
  return data_stream;
};

/**
 *
 * kernel_c
 *
 **/

// Constructs a kernel with an empty (unallocated) row cache.
kernel_c::kernel_c(){
  dim =0;
  cache_size=0;
  examples_size = 0;
  rows = 0;
  last_used = 0;
  index = 0;
  counter=0;
  // cache profiling:
  //  cache_misses = 0;
  //  cache_access = 0;
};

// Releases all cache memory.
kernel_c::~kernel_c(){
  //  cout<<"destructor"<<endl;
  // cache profiling:
  //  cout<<cache_access<<" access to the cache"<<endl;
  //  cout<<cache_misses<<" cache misses ("<<100.0*(SVMFLOAT)cache_misses/((SVMFLOAT)cache_access)<<"%)"<<endl;
  clean_cache();
};

// Kernel value for examples i and j of the bound example set. It is always
// recomputed from the examples; the cache lookup below is disabled.
SVMFLOAT kernel_c::calculate_K(const SVMINT i, const SVMINT j){
  //  cout<<"K("<<i<<","<<j<<")"<<endl;
  //    if(cached(i) && cached(j)){
  //      // both are cached -> not shrinked
  //      return(rows[lookup(i)][lookup(j)]);
  //    };
  //   SVMINT pos_x = lookup(i);
  //   SVMINT pos_y = lookup(j);
  //   if((index[pos_x] == i) && (index[pos_y] == j)
  //      && (last_used[pos_x] != 0) && (last_used[pos_y] != 0)){
  //     return rows[pos_x][j];
  //   };
  svm_example x = the_examples->get_example(i);
  svm_example y = the_examples->get_example(j);
  return(calculate_K(x,y));
};

// Kernel value for two examples. The base implementation is the plain
// inner product x*y. NOTE(review): presumably overridden by the derived
// kernel classes; confirm the virtual declaration in kernel.h.
inline
SVMFLOAT kernel_c::calculate_K(const svm_example x, const svm_example y){
  // default is inner product
  return innerproduct(x,y);
};

// Sparse inner product x*y. Walks both attribute lists in a single merge
// pass and multiplies values whose indices match. The merge is only correct
// if each list is sorted by increasing index.
// NOTE(review): presumably guaranteed by the example set; confirm.
inline
SVMFLOAT kernel_c::innerproduct(const svm_example x, const svm_example y){
  // returns x*y
  SVMFLOAT result=0;
  svm_attrib* att_x = x.example;
  svm_attrib* att_y = y.example;
  svm_attrib* length_x = &(att_x[x.length]);
  svm_attrib* length_y = &(att_y[y.length]);
  while((att_x < length_x) && (att_y < length_y)){
    if(att_x->index == att_y->index){
      result += (att_x->att)*(att_y->att);
      att_x++;
      att_y++;
    }
    else if(att_x->index < att_y->index){
      att_x++;
    }
    else{
      att_y++;
    };
  };
  return result;
};

// Returns 1 if example i has a valid cached row, i.e. it is present in
// index[] and its last_used stamp is positive. Returns 0 otherwise.
int kernel_c::cached(const SVMINT i){
  SVMINT pos = lookup(i);
  // && short-circuits, so last_used[pos] is only read when i is present
  return ((index[pos] == i) && (last_used[pos] > 0)) ? 1 : 0;
};

// Squared Euclidean distance ||x-y||^2 between two sparse examples. Uses the
// same sorted-index merge walk as innerproduct(). An attribute present in
// only one example contributes its own square.
SVMFLOAT kernel_c::norm2(const svm_example x, const svm_example y){
  // returns ||x-y||^2
  SVMFLOAT result=0;
  SVMINT length_x = x.length;
  SVMINT length_y = y.length;
  svm_attrib* att_x = x.example;
  svm_attrib* att_y = y.example;
  SVMINT pos_x=0;
  SVMINT pos_y=0;
  SVMFLOAT dummy;
  while((pos_x < length_x) && (pos_y < length_y)){
    if(att_x[pos_x].index == att_y[pos_y].index){
      dummy = att_x[pos_x++].att-att_y[pos_y++].att;
      result += dummy*dummy;
    }
    else if(att_x[pos_x].index < att_y[pos_y].index){
      dummy = att_x[pos_x++].att;
      result += dummy*dummy;
    }
    else{
      dummy = att_y[pos_y++].att;
      result += dummy*dummy;
    };
  };
  // remaining tails: attributes present in only one example
  while(pos_x < length_x){
    dummy = att_x[pos_x++].att;
    result += dummy*dummy;
  };
  while(pos_y < length_y){
    dummy = att_y[pos_y++].att;
    result += dummy*dummy;
  };
  return result;
};

// Debugging aid: checks the cache invariants and prints a report to cout.
// Returns 1 if no violation was found, 0 otherwise.
// As a final test it reallocates every allocated row and copies its contents,
// so row pointers obtained before the call are invalidated.
int kernel_c::check(){
  // check cache integrity, for debugging
  int result = 1;
  cout<<"Checking cache"<<endl;
  SVMINT i;
  // rows != 0
  for(i=0;i<cache_size;i++){
    if(rows[i] == 0){
      cout<<"ERROR: row["<<i<<"] = 0"<<endl;
      result = 0;
    };
  };
  cout<<"rows[i] checked"<<endl;
  // 0 <= index <= examples_size, non-descending. The loop includes
  // index[cache_size]: the array holds cache_size+1 entries (sentinel).
  if(index != 0){
    SVMINT last_i=index[0];
    for(i=0;i<=cache_size;i++){
      if(index[i]<0){
        cout<<"ERROR: index["<<i<<"] = "<<index[i]<<endl;
        result = 0;
      };
      if(index[i]>examples_size){
        cout<<"ERROR: index["<<i<<"] = "<<index[i]<<endl;
        result = 0;
      };
      if(index[i]<last_i){
        cout<<"ERROR: index["<<i<<"] descending"<<endl;
        result = 0;
      };
      last_i = index[i];
    };
  };
  cout<<"index[i] checked"<<endl;
  // 0 <= last_used <= counter
  for(i=0;i<cache_size;i++){
    if(last_used[i]<0){
      cout<<"ERROR: last_used["<<i<<"] = "<<last_used[i]<<endl;
      result = 0;
    };
    if(last_used[i]>counter){
      cout<<"ERROR: last_used["<<i<<"] = "<<last_used[i]<<endl;
      result = 0;
    };
  };
  cout<<"last_used[i] checked"<<endl;
  cout<<"complete cache test"<<endl;
  SVMFLOAT* adummy;
  SVMINT i2;
  for(i2=0;i2<cache_size;i2++){
    cout<<i2<<" "; cout.flush();
    if(rows[i2] == 0){
      // unallocated row (already reported above): nothing to copy
      continue;
    };
    adummy = new SVMFLOAT[examples_size];
    for(SVMINT ai=0;ai<examples_size;ai++) adummy[ai] = (rows[i2])[ai];
    delete [](rows[i2]);
    // the copy replaces the row it was taken from (this used to write
    // rows[i], i.e. rows[cache_size]: out of bounds, leaving rows[i2] dangling)
    rows[i2] = adummy;
  }
  cout<<"cache test succeeded"<<endl;
  return result;
};

// Binds the kernel to new_examples and rebuilds the row cache.
// The cache budget is cache_MB megabytes. It is raised if needed so that at
// least one full row (plus its bookkeeping) fits.
void kernel_c::init(SVMINT cache_MB, example_set_c* new_examples){
  //   cout<<"init"<<endl;
  //   cout<<"cache_size = "<<cache_size<<endl;
  //   cout<<"examples_size = "<<examples_size<<endl;
  //   cout<<"rows = "<<rows<<endl;
  //   if(rows != 0)cout<<"rows[0] = "<<rows[0]<<endl;
  clean_cache();
  the_examples = new_examples;
  dim = the_examples->get_dim();
  cache_mem = cache_MB*1048576;
  // memory for one cached row: its values, its row pointer, and one
  // last_used and one index entry
  SVMINT min_mem = (SVMINT)(sizeof(SVMFLOAT)*the_examples->size()+sizeof(SVMFLOAT*)+2*sizeof(SVMINT));
  // check if reserved memory big enough
  if(cache_mem<min_mem){
    // not enough space for one example, increase
    cache_mem = min_mem;
  };
  set_examples_size(the_examples->size());
};

// Frees every cached row and the bookkeeping arrays, then resets the cache
// to the empty state (all pointers 0, all sizes and the counter 0).
void kernel_c::clean_cache(){
  counter=0;
  if(rows != 0){
    SVMINT i;
    for(i=0;i<cache_size;i++){
      if(0 != rows[i]){
        delete [](rows[i]);
        rows[i]=0;
      };
    };
    delete []rows;
  };
  if(last_used != 0) delete []last_used;
  if(index != 0) delete []index;
  rows=0;
  last_used=0;
  index=0;
  cache_size=0;
  examples_size=0;
};

// Binary search for example i in the sorted index[0..cache_size).
// Returns the position of i if it is present. Otherwise returns the position
// of the smallest entry larger than i, or cache_size if there is none.
// index[cache_size] is a valid sentinel slot, so callers may read index[result].
inline
SVMINT kernel_c::lookup(const SVMINT i){
  // find row i in cache
  // returns pos of element i if i in cache,
  // returns pos of smallest element larger than i otherwise
  SVMINT low;
  SVMINT high;
  SVMINT med;
  low=0;
  high=cache_size;
  // binary search
  while(low<high){
    med = (low+high)/2;
    if(index[med]>=i){
      high=med;
    }
    else{
      low=med+1;
    };
  };
  return high;
};

// Overwrites example i with example j in the cache.
// WARNING: only to be used for shrinking!
void kernel_c::overwrite(const SVMINT i, const SVMINT j){
  // overwrite entry i with entry j
  // WARNING: only to be used for shrinking!
  // i in cache?
  SVMINT pos_i=lookup(i);
  SVMINT pos_j=lookup(j);
  if((index[pos_i] == i) && (index[pos_j] == j)){
    // both cached: swap the row pointers at pos_i and pos_j; pos_i takes
    // over pos_j's last_used stamp (pos_j keeps its own)
    SVMFLOAT* dummy = rows[pos_i];
    rows[pos_i] = rows[pos_j];
    rows[pos_j] = dummy;
    last_used[pos_i] = last_used[pos_j];
  }
  else{
    // only one of them cached: mark that row as invalid
    if(index[pos_i] == i){
      last_used[pos_i] = 0;
    }
    else if(index[pos_j] == j){
      last_used[pos_j] = 0;
    };
  };
  // copy column j into column i in every allocated row (not a swap)
  SVMFLOAT* my_row;
  for(pos_i=0;pos_i<cache_size;pos_i++){
    my_row = rows[pos_i];
    if(my_row != 0){
      my_row[i] = my_row[j];
    };
  };
};

void kernel_c::set_examples_size(const SVMINT new_examples_size){
  // cache row with new_examples_size entries only
  
  // cout<<"shrinking from "<<examples_size<<" to "<<new_examples_size<<endl;
  if(new_examples_size>examples_size){
    clean_cache();
    examples_size = new_examples_size;
    cache_size = cache_mem/(sizeof(SVMFLOAT)*examples_size+sizeof(SVMFLOAT*)+2*sizeof(SVMINT));
    if(cache_size>examples_size){
      cache_size = examples_size;
    };
    // init 
    rows = new SVMFLOAT*[cache_size];
    last_used = new SVMINT[cache_size];
    index = new SVMINT[cache_size+1];
    SVMINT i;
    for(i=0;i<cache_size;i++){
      rows[i] = 0; // new SVMFLOAT[new_examples_size];
      last_used[i] = 0;
      index[i] = new_examples_size;
    };
    index[cache_size] = new_examples_size;
  }
  else if(new_examples_size<examples_size){
    // copy as much rows into new cache as possible
    SVMINT old_cache_size=cache_size;
    cache_size = cache_mem/(sizeof(SVMFLOAT)*new_examples_size+sizeof(SVMFLOAT*)+2*sizeof(SVMINT));
    if(cache_size > new_examples_size){
      cache_size = new_examples_size;
    };
    if(cache_size>=old_cache_size){
      // skip it, enough space available
      cache_size=old_cache_size;
      return;
    };
    SVMFLOAT** new_rows = new SVMFLOAT*[cache_size];
    SVMINT* new_last_used = new SVMINT[cache_size];
    SVMINT* new_index = new SVMINT[cache_size+1];
    SVMINT old_pos=0;
    SVMINT new_pos=0;
    new_index[cache_size] = new_examples_size;
    while((old_pos<old_cache_size) && (new_pos < cache_size)){
      if(last_used[old_pos] > 0){
        // copy example into new cache at new_pos

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -