⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 kernel.cpp

📁 支持向量机(SVM)的VC源代码
💻 CPP
📖 第 1 页 / 共 3 页
字号:
#include "kernel.h"


/**
 *
 * kernel_container_c
 *
 **/

/// Destructor. Only forgets the held kernel pointer; it does NOT delete the
/// kernel object (the same policy clear() follows, whose comment says kernels
/// must not be deleted "for reading of aggregation kernels").
kernel_container_c::~kernel_container_c()
{
  kernel = 0; // drop the reference without freeing it
}


/// Returns the held kernel. If none has been set yet, a plain dot-product
/// kernel is created on demand and the container is flagged as linear.
kernel_c* kernel_container_c::get_kernel()
{
  if(kernel != 0){
    return kernel;
  }
  // lazy default: dot product
  kernel = new kernel_dot_c();
  is_linear = 1;
  return kernel;
}


/// Resets the container to "no kernel" without freeing the previous kernel.
/// Deleting is deliberately skipped "for reading of aggregation kernels" —
/// presumably because the old kernel object is still referenced elsewhere
/// (NOTE(review): confirm against the aggregation-kernel reader).
void kernel_container_c::clear()
{
  kernel = 0;
}


istream& operator >> (istream& data_stream, kernel_container_c& the_kernel){
  char* s = new char[MAXCHAR];

  if(data_stream.eof() || ('@' == data_stream.peek())){
    // no kernel definition, take dot as default
    if(0 != the_kernel.kernel){
      delete the_kernel.kernel;
    };
    the_kernel.kernel = new kernel_dot_c();
    //    throw read_exception("No kernel definition found");
  }
  else{
    while((! data_stream.eof()) &&
	  (('#' == data_stream.peek()) ||
	   ('\n' == data_stream.peek()))){
	// ignore comment & newline
	data_stream.getline(s,MAXCHAR);
    };
    data_stream >> s;
    if(0 == strcmp("type",s)) {
      the_kernel.is_linear=0;
      data_stream >> s;
      if(0==strcmp("dot",s)){
	if(0 != the_kernel.kernel){
	  delete the_kernel.kernel;
	};
	the_kernel.kernel = new kernel_dot_c();
	the_kernel.is_linear=1;
	data_stream >> *(the_kernel.kernel);
      }
      else if(0==strcmp("lin_dot",s)){
	if(0 != the_kernel.kernel){
	  delete the_kernel.kernel;
	};
	the_kernel.kernel = new kernel_lin_dot_c();
	data_stream >> *(the_kernel.kernel);
      }
      else if(0 == strcmp("polynomial",s)){
	if(0 != the_kernel.kernel){
	  delete the_kernel.kernel;
	};
	the_kernel.kernel = new kernel_polynomial_c();
	data_stream >> *(the_kernel.kernel);
      }
      else if(0 == strcmp("radial",s)){
	if(0 != the_kernel.kernel){
	  delete the_kernel.kernel;
	};
	the_kernel.kernel = new kernel_radial_c();
	data_stream >> *(the_kernel.kernel);
      }
      else if(0 == strcmp("neural",s)){
	if(0 != the_kernel.kernel){
	  delete the_kernel.kernel;
	};
	the_kernel.kernel = new kernel_neural_c();
	data_stream >> *(the_kernel.kernel);
      }
      else if(0 == strcmp("anova",s)){
	if(0 != the_kernel.kernel){
	  delete the_kernel.kernel;
	};
	the_kernel.kernel = new kernel_anova_c();
	data_stream >> *(the_kernel.kernel);
      }
      else if((0==strcmp("aggregation",s))||(0==strcmp("sum_aggregation",s))){
	if(0 != the_kernel.kernel){
	  delete the_kernel.kernel;
	};
	the_kernel.kernel = new kernel_aggregation_c();
	data_stream >> *(the_kernel.kernel);
      }
      else if(0 == strcmp("prod_aggregation",s)){
	if(0 != the_kernel.kernel){
	  delete the_kernel.kernel;
	};
	the_kernel.kernel = new kernel_prod_aggregation_c();
	data_stream >> *(the_kernel.kernel);
      }
      else if(0 == strcmp("zero",s)){
	if(0 != the_kernel.kernel){
	  delete the_kernel.kernel;
	};
	the_kernel.kernel = new kernel_zero_c();
	data_stream >> *(the_kernel.kernel);
      }
      else if(0 == strcmp("user",s)){
	if(0 != the_kernel.kernel){
	  delete the_kernel.kernel;
	};
	the_kernel.kernel = new kernel_user_c();
	data_stream >> *(the_kernel.kernel);
      }
      else if(0 == strcmp("user2",s)){
	if(0 != the_kernel.kernel){
	  delete the_kernel.kernel;
	};
	the_kernel.kernel = new kernel_user2_c();
	data_stream >> *(the_kernel.kernel);

      }
      // insert code for other kernels here
      else{
	char* t = new char[MAXCHAR];
	strcpy(t,"Unknown kernel type: ");
	strcat(t,s);
	throw read_exception(t);
      };
    }
    else{
      cout<<"read: "<<s<<endl;
      throw read_exception("kernel type has to be defined first");
    };
  };
  delete []s;
  return data_stream;
};


/// Writes the held kernel's definition to data_stream, or the line
/// "Empty kernel" if no kernel has been set.
ostream& operator << (ostream& data_stream, kernel_container_c& the_kernel){
  if(0 == the_kernel.kernel){
    data_stream << "Empty kernel" << endl;
    return data_stream;
  }
  data_stream << *(the_kernel.kernel);
  return data_stream;
}


/**
 *
 * kernel_c
 *
 **/
/**
 * Creates a kernel with no example set attached and no cache allocated.
 * init() must be called before the kernel is used: it attaches the example
 * set (needed by calculate_K(i,j)) and sets the cache budget.
 */
kernel_c::kernel_c(){
  the_examples = 0;  // previously left uninitialised; set by init()
  dim = 0;
  cache_mem = 0;     // previously left uninitialised; set by init()
  cache_size = 0;
  examples_size = 0;
  rows = 0;
  last_used = 0;
  index = 0;
  counter = 0;
}

/// Destructor: releases all cached kernel rows and cache bookkeeping arrays.
kernel_c::~kernel_c()
{
  clean_cache();
}

/// Kernel value between examples number i and j of the attached example set.
/// Fetches both examples (i first, then j) and delegates to the
/// example-pair overload of calculate_K.
SVMFLOAT kernel_c::calculate_K(const SVMINT i, const SVMINT j){
  svm_example first = the_examples->get_example(i);
  svm_example second = the_examples->get_example(j);
  return calculate_K(first, second);
}

inline
SVMFLOAT kernel_c::calculate_K(const svm_example x, const svm_example y){
  // Base-class kernel: the plain inner product x*y.
  SVMFLOAT dot = innerproduct(x,y);
  return dot;
}

inline
SVMFLOAT kernel_c::innerproduct(const svm_example x, const svm_example y){
  // Sparse inner product x*y. Each example holds `length` (index, att)
  // pairs; the merge below assumes both lists are sorted by ascending
  // index, so only indices present in both examples contribute.
  SVMFLOAT sum = 0;
  SVMINT px = 0;
  SVMINT py = 0;

  while((px < x.length) && (py < y.length)){
    const svm_attrib& ax = x.example[px];
    const svm_attrib& ay = y.example[py];
    if(ax.index < ay.index){
      ++px;            // index only in x: contributes 0
    }
    else if(ay.index < ax.index){
      ++py;            // index only in y: contributes 0
    }
    else{
      sum += (ax.att)*(ay.att);
      ++px;
      ++py;
    }
  }

  return sum;
}


int kernel_c::cached(const SVMINT i){
  return(index[lookup(i)] == i);
};


/// Squared Euclidean distance ||x-y||^2 between two sparse examples.
/// Attributes present in only one example contribute their own square.
/// The merge assumes both attribute lists are sorted by ascending index.
SVMFLOAT kernel_c::norm2(const svm_example x, const svm_example y){
  SVMFLOAT sq = 0;
  SVMFLOAT diff;
  svm_attrib* xa = x.example;
  svm_attrib* ya = y.example;
  svm_attrib* const x_end = xa + x.length;
  svm_attrib* const y_end = ya + y.length;

  // merge phase: both lists still have entries
  while((xa < x_end) && (ya < y_end)){
    if(xa->index == ya->index){
      diff = xa->att - ya->att;
      ++xa;
      ++ya;
    }
    else if(xa->index < ya->index){
      diff = xa->att;
      ++xa;
    }
    else{
      diff = ya->att;
      ++ya;
    }
    sq += diff*diff;
  }

  // tails: whatever is left belongs to only one example
  for(; xa < x_end; ++xa){
    diff = xa->att;
    sq += diff*diff;
  }
  for(; ya < y_end; ++ya){
    diff = ya->att;
    sq += diff*diff;
  }
  return sq;
}


int kernel_c::check(){
  // check cache integrity, for debugging
  int result = 1;

  cout<<"Checking cache"<<endl;
  SVMINT i;
  // rows != 0
  for(i=0;i<cache_size;i++){
    if(rows[i] == 0){
      cout<<"ERROR: row["<<i<<"] = 0"<<endl;
      result = 0;
    };
  };
  cout<<"rows[i] checked"<<endl;

  // 0 <= index <= examples_size
  if(index != 0){
    SVMINT last_i=index[0];
    for(i=0;i<=cache_size;i++){
      if(index[i]<0){
	cout<<"ERROR: index["<<i<<"] = "<<index[i]<<endl;
	result = 0;
      };
      if(index[i]>examples_size){
	cout<<"ERROR: index["<<i<<"] = "<<index[i]<<endl;
	result = 0;
      };
      if(index[i]<last_i){
	cout<<"ERROR: index["<<i<<"] descending"<<endl;
	result = 0;
      };
      last_i = index[i];
    };
  };
  cout<<"index[i] checked"<<endl;

  // 0 <= last_used <= counter
  for(i=0;i<cache_size;i++){
    if(last_used[i]<0){
      cout<<"ERROR: last_used["<<i<<"] = "<<last_used[i]<<endl;
      result = 0;
    };
    if(last_used[i]>counter){
      cout<<"ERROR: last_used["<<i<<"] = "<<last_used[i]<<endl;
      result = 0;
    };
  };
  cout<<"last_used[i] checked"<<endl;

  cout<<"complete cache test"<<endl;
  SVMFLOAT* adummy;
  for(SVMINT i2=0;i2<cache_size;i2++){
    cout<<i2<<" "; cout.flush();
    adummy = new SVMFLOAT[examples_size];
    for(SVMINT ai=0;ai<examples_size;ai++) adummy[ai] = (rows[i2])[ai];
    delete [](rows[i2]);
    rows[i] = adummy;
  }
  cout<<"cache test succeeded"<<endl;



  return result;
};


/**
 * Attaches a new example set and sizes the kernel-row cache for it.
 * Any previously cached rows are discarded first.
 *
 * @param cache_MB      cache budget in megabytes (converted to bytes below)
 * @param new_examples  example set the kernel will evaluate
 */
void kernel_c::init(SVMINT cache_MB, example_set_c* new_examples){
  clean_cache();
  the_examples = new_examples;
  dim = the_examples->get_dim();
  cache_mem = cache_MB*1048576;

  // Bytes needed to cache a single row: the row values, the row pointer and
  // two SVMINT bookkeeping entries. If the budget cannot hold even one row,
  // increase it to exactly that amount.
  SVMINT one_row = (SVMINT)(sizeof(SVMFLOAT)*the_examples->size()+sizeof(SVMFLOAT*)+2*sizeof(SVMINT));
  if(cache_mem < one_row){
    cache_mem = one_row;
  }

  set_examples_size(the_examples->size());
}


void kernel_c::clean_cache(){
  counter=0;
  if(rows != 0){
    for(SVMINT i=0;i<cache_size;i++){
      if(0 != rows[i]){
	delete [](rows[i]);
	rows[i]=0;
      };
    };
    delete []rows;
  };
  if(last_used != 0) delete []last_used;
  if(index != 0) delete []index;
  rows=0;
  last_used=0;
  index=0;
  cache_size=0;
  examples_size=0;
};


inline
SVMINT kernel_c::lookup(const SVMINT i){
  // Binary search for row i over index[0..cache_size), which is expected to
  // be sorted ascending. Returns the first slot whose label is >= i: the
  // slot of row i if it is cached, otherwise the slot of the smallest label
  // larger than i (cache_size if there is none; index[cache_size] exists as
  // a sentinel).
  SVMINT lo = 0;
  SVMINT hi = cache_size;
  for(; lo < hi; ){
    const SVMINT mid = (lo+hi)/2;
    if(index[mid] < i){
      lo = mid+1;
    }
    else{
      hi = mid;
    }
  }
  return hi;
}


void kernel_c::overwrite(const SVMINT i, const SVMINT j){
  // overwrite entry i with entry j
  // WARNING: only to be used for shrinking!
  // NOTE(review): presumably called when the active example set shrinks and
  // example j is relocated to position i — confirm against the caller.
  //
  // Effects, as implemented below:
  //  - rows of both i and j cached: the row buffers are swapped, so the slot
  //    labelled i now holds j's row, and j's last_used stamp is copied to
  //    i's slot (last_used of the slot labelled j is left unchanged);
  //  - only one of them cached: that slot's last_used is reset to 0.
  //    NOTE(review): the slot keeps its index label, so cached() still
  //    reports it as present — confirm that last_used == 0 is treated as
  //    "invalid" elsewhere;
  //  - in every cached row, column i is overwritten with column j.

  //  cout<<"overwrite("<<i<<","<<j<<")"<<endl;
  // find the slots for i and j; each is a cache hit iff index[pos] matches
  SVMINT pos_i=lookup(i);
  SVMINT pos_j=lookup(j);

  if((index[pos_i] == i) && (index[pos_j] == j)){
    // swap pos_i and pos_j
    SVMFLOAT* dummy = rows[pos_i];
    rows[pos_i] = rows[pos_j];
    rows[pos_j] = dummy;
    last_used[pos_i] = last_used[pos_j];

  }
  else{
    // mark rows as invalid
    if(index[pos_i] == i){
      last_used[pos_i] = 0;
    }
    else if(index[pos_j] == j){
      last_used[pos_j] = 0;
    };
  };

  // copy column j into column i in every cached row (a one-way copy,
  // not a swap: column j is left unchanged)
  SVMFLOAT* my_row;
  for(pos_i=0;pos_i<cache_size;pos_i++){
    my_row = rows[pos_i];
    if(my_row != 0){
      my_row[i] = my_row[j];
    };
  };
};


void kernel_c::set_examples_size(const SVMINT new_examples_size){
  // cache row with new_examples_size entries only
  
  // cout<<"shrinking from "<<examples_size<<" to "<<new_examples_size<<endl;
  if(new_examples_size>examples_size){
    clean_cache();
    examples_size = new_examples_size;
    cache_size = cache_mem/(sizeof(SVMFLOAT)*examples_size+sizeof(SVMFLOAT*)+2*sizeof(SVMINT));
    if(cache_size>examples_size){
      cache_size = examples_size;
    };
    // init 
    rows = new SVMFLOAT*[cache_size];
    last_used = new SVMINT[cache_size];
    index = new SVMINT[cache_size+1];
    for(SVMINT i=0;i<cache_size;i++){
      rows[i] = 0; // new SVMFLOAT[new_examples_size];
      last_used[i] = 0;
      index[i] = new_examples_size;
    };
    index[cache_size] = new_examples_size;
  }
  else if(new_examples_size<examples_size){
    // copy as much rows into new cache as possible
    SVMINT old_cache_size=cache_size;
    cache_size = cache_mem/(sizeof(SVMFLOAT)*new_examples_size+sizeof(SVMFLOAT*)+2*sizeof(SVMINT));
    if(cache_size > new_examples_size){
      cache_size = new_examples_size;
    };
    if(cache_size>=old_cache_size){
      // skip it, enough space available
      cache_size=old_cache_size;
      return;
    };

    SVMFLOAT** new_rows = new SVMFLOAT*[cache_size];
    SVMINT* new_last_used = new SVMINT[cache_size];
    SVMINT* new_index = new SVMINT[cache_size+1];
    SVMINT old_pos=0;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -