📄 kernel.cpp
字号:
#include "kernel.h"
/**
*
* kernel_container_c
*
**/
/**
 * Destructor: only forgets the kernel pointer.
 * The kernel object it pointed to is not deleted here.
 */
kernel_container_c::~kernel_container_c()
{
  // drop the reference without freeing the object
  this->kernel = 0;
}
/**
 * Returns the kernel held by this container.
 * If no kernel has been set yet, a kernel_dot_c is created, stored in the
 * container, and the container is flagged as linear (is_linear = 1).
 */
kernel_c* kernel_container_c::get_kernel()
{
  if(0 != kernel){
    // already have one, hand it out unchanged
    return kernel;
  }
  // nothing stored yet: fall back to the dot kernel
  kernel = new kernel_dot_c();
  is_linear = 1;
  return kernel;
}
void kernel_container_c::clear(){
// do not delete kernel, for reading of aggregation kernels
kernel = 0;
};
istream& operator >> (istream& data_stream, kernel_container_c& the_kernel){
char* s = new char[MAXCHAR];
if(data_stream.eof() || ('@' == data_stream.peek())){
// no kernel definition, take dot as default
if(0 != the_kernel.kernel){
delete the_kernel.kernel;
};
the_kernel.kernel = new kernel_dot_c();
// throw read_exception("No kernel definition found");
}
else{
while((! data_stream.eof()) &&
(('#' == data_stream.peek()) ||
('\n' == data_stream.peek()))){
// ignore comment & newline
data_stream.getline(s,MAXCHAR);
};
data_stream >> s;
if(0 == strcmp("type",s)) {
the_kernel.is_linear=0;
data_stream >> s;
if(0==strcmp("dot",s)){
if(0 != the_kernel.kernel){
delete the_kernel.kernel;
};
the_kernel.kernel = new kernel_dot_c();
the_kernel.is_linear=1;
data_stream >> *(the_kernel.kernel);
}
else if(0==strcmp("lin_dot",s)){
if(0 != the_kernel.kernel){
delete the_kernel.kernel;
};
the_kernel.kernel = new kernel_lin_dot_c();
data_stream >> *(the_kernel.kernel);
}
else if(0 == strcmp("polynomial",s)){
if(0 != the_kernel.kernel){
delete the_kernel.kernel;
};
the_kernel.kernel = new kernel_polynomial_c();
data_stream >> *(the_kernel.kernel);
}
else if(0 == strcmp("radial",s)){
if(0 != the_kernel.kernel){
delete the_kernel.kernel;
};
the_kernel.kernel = new kernel_radial_c();
data_stream >> *(the_kernel.kernel);
}
else if(0 == strcmp("neural",s)){
if(0 != the_kernel.kernel){
delete the_kernel.kernel;
};
the_kernel.kernel = new kernel_neural_c();
data_stream >> *(the_kernel.kernel);
}
else if(0 == strcmp("anova",s)){
if(0 != the_kernel.kernel){
delete the_kernel.kernel;
};
the_kernel.kernel = new kernel_anova_c();
data_stream >> *(the_kernel.kernel);
}
else if((0==strcmp("aggregation",s))||(0==strcmp("sum_aggregation",s))){
if(0 != the_kernel.kernel){
delete the_kernel.kernel;
};
the_kernel.kernel = new kernel_aggregation_c();
data_stream >> *(the_kernel.kernel);
}
else if(0 == strcmp("prod_aggregation",s)){
if(0 != the_kernel.kernel){
delete the_kernel.kernel;
};
the_kernel.kernel = new kernel_prod_aggregation_c();
data_stream >> *(the_kernel.kernel);
}
else if(0 == strcmp("zero",s)){
if(0 != the_kernel.kernel){
delete the_kernel.kernel;
};
the_kernel.kernel = new kernel_zero_c();
data_stream >> *(the_kernel.kernel);
}
else if(0 == strcmp("user",s)){
if(0 != the_kernel.kernel){
delete the_kernel.kernel;
};
the_kernel.kernel = new kernel_user_c();
data_stream >> *(the_kernel.kernel);
}
else if(0 == strcmp("user2",s)){
if(0 != the_kernel.kernel){
delete the_kernel.kernel;
};
the_kernel.kernel = new kernel_user2_c();
data_stream >> *(the_kernel.kernel);
}
// insert code for other kernels here
else{
char* t = new char[MAXCHAR];
strcpy(t,"Unknown kernel type: ");
strcat(t,s);
throw read_exception(t);
};
}
else{
cout<<"read: "<<s<<endl;
throw read_exception("kernel type has to be defined first");
};
};
delete []s;
return data_stream;
};
/**
 * Writes the kernel held by the_kernel to data_stream using the kernel's own
 * operator <<. If the container holds no kernel, the line "Empty kernel" is
 * written instead. Returns data_stream.
 */
ostream& operator << (ostream& data_stream, kernel_container_c& the_kernel)
{
  if(0 == the_kernel.kernel){
    // nothing to print
    data_stream << "Empty kernel"<<endl;
    return data_stream;
  }
  data_stream << *(the_kernel.kernel);
  return data_stream;
}
/**
*
* kernel_c
*
**/
/**
 * Constructs a kernel with an empty cache: all sizes and counters are zero
 * and no cache arrays are allocated. init() must be called to attach an
 * example set and size the cache.
 */
kernel_c::kernel_c(){
  dim =0;
  cache_size=0;
  examples_size = 0;
  rows = 0;
  last_used = 0;
  index = 0;
  counter=0;
  // previously left uninitialised; both are only assigned in init(), so give
  // them defined values in case they are read before init() is called
  the_examples = 0;
  cache_mem = 0;
}
/**
 * Destructor: frees all cache arrays via clean_cache().
 */
kernel_c::~kernel_c()
{
  this->clean_cache();
}
/**
 * Kernel value for the examples with indices i and j.
 * Both examples are fetched from the attached example set and passed to the
 * calculate_K overload that works on examples.
 */
SVMFLOAT kernel_c::calculate_K(const SVMINT i, const SVMINT j)
{
  svm_example first = the_examples->get_example(i);
  svm_example second = the_examples->get_example(j);
  return calculate_K(first, second);
}
/**
 * Kernel value for two examples. In this class it is the inner product x*y.
 */
inline
SVMFLOAT kernel_c::calculate_K(const svm_example x, const svm_example y)
{
  const SVMFLOAT dot_product = innerproduct(x, y);
  return dot_product;
}
/**
 * Returns the inner product x*y of two sparse examples.
 *
 * Both attribute arrays are walked in parallel; only attributes whose index
 * occurs in both examples contribute to the sum. The walk advances whichever
 * side currently has the smaller index, so it assumes the attributes are
 * stored in ascending index order.
 */
inline
SVMFLOAT kernel_c::innerproduct(const svm_example x, const svm_example y)
{
  SVMFLOAT sum = 0;
  SVMINT ix = 0;
  SVMINT iy = 0;
  while((ix < x.length) && (iy < y.length)){
    if(x.example[ix].index < y.example[iy].index){
      // attribute only present in x: contributes nothing
      ix++;
    }
    else if(x.example[ix].index > y.example[iy].index){
      // attribute only present in y: contributes nothing
      iy++;
    }
    else{
      // same attribute on both sides
      sum += (x.example[ix].att)*(y.example[iy].att);
      ix++;
      iy++;
    }
  }
  return sum;
}
/**
 * Tells whether row i is currently in the cache.
 *
 * Returns 1 if the index entry at the position reported by lookup(i) equals
 * i, 0 otherwise. Also returns 0 when no cache is allocated (index == 0, as
 * after construction or clean_cache()) instead of dereferencing a null
 * pointer.
 */
int kernel_c::cached(const SVMINT i){
  if(0 == index){
    // no cache allocated, so nothing can be cached
    return 0;
  }
  return(index[lookup(i)] == i);
}
/**
 * Returns the squared Euclidean distance ||x-y||^2 of two sparse examples.
 *
 * Both attribute arrays are walked in parallel. An attribute present in only
 * one example is treated as zero in the other, so it contributes its own
 * square. The walk advances whichever side currently has the smaller index,
 * so it assumes the attributes are stored in ascending index order.
 */
SVMFLOAT kernel_c::norm2(const svm_example x, const svm_example y)
{
  SVMFLOAT sum = 0;
  SVMFLOAT diff;
  svm_attrib* px = x.example;
  svm_attrib* py = y.example;
  svm_attrib* const end_x = px + x.length;
  svm_attrib* const end_y = py + y.length;

  // part where both examples still have attributes left
  while((px < end_x) && (py < end_y)){
    if(px->index == py->index){
      diff = px->att - py->att;
      ++px;
      ++py;
    }
    else if(px->index < py->index){
      diff = px->att;
      ++px;
    }
    else{
      diff = py->att;
      ++py;
    }
    sum += diff*diff;
  }

  // leftover attributes of x
  for(; px < end_x; ++px){
    diff = px->att;
    sum += diff*diff;
  }
  // leftover attributes of y
  for(; py < end_y; ++py){
    diff = py->att;
    sum += diff*diff;
  }
  return sum;
}
/**
 * Checks cache integrity, for debugging. Progress and errors go to cout.
 *
 * Verifies that
 *   - every rows[i] is allocated,
 *   - every index[i] (including the extra entry at cache_size) lies in
 *     [0, examples_size] and the entries are not descending,
 *   - every last_used[i] lies in [0, counter].
 * Finally each allocated row is copied into a freshly allocated array of
 * examples_size entries and the old array is freed, which exercises the row
 * memory.
 * NOTE(review): this assumes every allocated row holds at least
 * examples_size entries; rows are allocated outside this function.
 *
 * Returns 1 if all checks passed, 0 if any error was reported.
 */
int kernel_c::check(){
  int result = 1;
  cout<<"Checking cache"<<endl;
  SVMINT i;
  // rows != 0
  for(i=0;i<cache_size;i++){
    if(rows[i] == 0){
      cout<<"ERROR: row["<<i<<"] = 0"<<endl;
      result = 0;
    }
  }
  cout<<"rows[i] checked"<<endl;
  // 0 <= index <= examples_size
  if(index != 0){
    SVMINT last_i=index[0];
    // <= cache_size: index carries one extra trailing entry
    for(i=0;i<=cache_size;i++){
      if(index[i]<0){
        cout<<"ERROR: index["<<i<<"] = "<<index[i]<<endl;
        result = 0;
      }
      if(index[i]>examples_size){
        cout<<"ERROR: index["<<i<<"] = "<<index[i]<<endl;
        result = 0;
      }
      if(index[i]<last_i){
        cout<<"ERROR: index["<<i<<"] descending"<<endl;
        result = 0;
      }
      last_i = index[i];
    }
  }
  cout<<"index[i] checked"<<endl;
  // 0 <= last_used <= counter
  for(i=0;i<cache_size;i++){
    if(last_used[i]<0){
      cout<<"ERROR: last_used["<<i<<"] = "<<last_used[i]<<endl;
      result = 0;
    }
    if(last_used[i]>counter){
      cout<<"ERROR: last_used["<<i<<"] = "<<last_used[i]<<endl;
      result = 0;
    }
  }
  cout<<"last_used[i] checked"<<endl;
  cout<<"complete cache test"<<endl;
  SVMFLOAT* adummy;
  for(SVMINT i2=0;i2<cache_size;i2++){
    cout<<i2<<" "; cout.flush();
    if(0 == rows[i2]){
      // unallocated row (already reported above): nothing to copy
      continue;
    }
    adummy = new SVMFLOAT[examples_size];
    for(SVMINT ai=0;ai<examples_size;ai++) adummy[ai] = (rows[i2])[ai];
    delete [](rows[i2]);
    // store the copy in the slot just freed. The previous code wrote to
    // rows[i], which is rows[cache_size] at this point (out of bounds), and
    // left rows[i2] pointing at freed memory.
    rows[i2] = adummy;
  }
  cout<<"cache test succeeded"<<endl;
  return result;
}
/**
 * (Re)initialises the kernel for a new example set.
 *
 * Any existing cache is discarded, the example set is stored and its
 * dimension read, and the memory budget is set to cache_MB * 1048576 bytes.
 * If that budget is smaller than one row over all examples plus its
 * bookkeeping entries, it is raised to exactly that amount. The cache arrays
 * are then set up through set_examples_size().
 */
void kernel_c::init(SVMINT cache_MB, example_set_c* new_examples)
{
  // throw away whatever was cached for the previous example set
  clean_cache();

  the_examples = new_examples;
  dim = the_examples->get_dim();

  // requested budget in bytes
  cache_mem = cache_MB*1048576;

  // smallest useful budget: one row, its row pointer and two SVMINT entries
  const SVMINT min_cache_mem =
    (SVMINT)(sizeof(SVMFLOAT)*the_examples->size()+sizeof(SVMFLOAT*)+2*sizeof(SVMINT));
  if(cache_mem < min_cache_mem){
    // not enough space for one example, increase
    cache_mem = min_cache_mem;
  }

  set_examples_size(the_examples->size());
}
void kernel_c::clean_cache(){
counter=0;
if(rows != 0){
for(SVMINT i=0;i<cache_size;i++){
if(0 != rows[i]){
delete [](rows[i]);
rows[i]=0;
};
};
delete []rows;
};
if(last_used != 0) delete []last_used;
if(index != 0) delete []index;
rows=0;
last_used=0;
index=0;
cache_size=0;
examples_size=0;
};
/**
 * Finds row i in the cache by binary search over index[0..cache_size).
 *
 * Returns the position of element i if it is in the cache, otherwise the
 * position of the smallest element larger than i (cache_size if there is
 * none). The search assumes index is in ascending order.
 */
inline
SVMINT kernel_c::lookup(const SVMINT i)
{
  SVMINT lower = 0;
  SVMINT upper = cache_size;
  while(lower < upper){
    const SVMINT middle = (lower+upper)/2;
    if(index[middle] < i){
      // everything up to middle is too small
      lower = middle+1;
    }
    else{
      // middle is a candidate, keep it inside the range
      upper = middle;
    }
  }
  return upper;
}
void kernel_c::overwrite(const SVMINT i, const SVMINT j){
// overwirte entry i with entry j
// WARNING: only to be used for shrinking!
// cout<<"overwrite("<<i<<","<<j<<")"<<endl;
// i in cache?
SVMINT pos_i=lookup(i);
SVMINT pos_j=lookup(j);
if((index[pos_i] == i) && (index[pos_j] == j)){
// swap pos_i and pos_j
SVMFLOAT* dummy = rows[pos_i];
rows[pos_i] = rows[pos_j];
rows[pos_j] = dummy;
last_used[pos_i] = last_used[pos_j];
}
else{
// mark rows as invalid
if(index[pos_i] == i){
last_used[pos_i] = 0;
}
else if(index[pos_j] == j){
last_used[pos_j] = 0;
};
};
// swap i and j in all rows
SVMFLOAT* my_row;
for(pos_i=0;pos_i<cache_size;pos_i++){
my_row = rows[pos_i];
if(my_row != 0){
my_row[i] = my_row[j];
};
};
};
void kernel_c::set_examples_size(const SVMINT new_examples_size){
// cache row with new_examples_size entries only
// cout<<"shrinking from "<<examples_size<<" to "<<new_examples_size<<endl;
if(new_examples_size>examples_size){
clean_cache();
examples_size = new_examples_size;
cache_size = cache_mem/(sizeof(SVMFLOAT)*examples_size+sizeof(SVMFLOAT*)+2*sizeof(SVMINT));
if(cache_size>examples_size){
cache_size = examples_size;
};
// init
rows = new SVMFLOAT*[cache_size];
last_used = new SVMINT[cache_size];
index = new SVMINT[cache_size+1];
for(SVMINT i=0;i<cache_size;i++){
rows[i] = 0; // new SVMFLOAT[new_examples_size];
last_used[i] = 0;
index[i] = new_examples_size;
};
index[cache_size] = new_examples_size;
}
else if(new_examples_size<examples_size){
// copy as much rows into new cache as possible
SVMINT old_cache_size=cache_size;
cache_size = cache_mem/(sizeof(SVMFLOAT)*new_examples_size+sizeof(SVMFLOAT*)+2*sizeof(SVMINT));
if(cache_size > new_examples_size){
cache_size = new_examples_size;
};
if(cache_size>=old_cache_size){
// skip it, enough space available
cache_size=old_cache_size;
return;
};
SVMFLOAT** new_rows = new SVMFLOAT*[cache_size];
SVMINT* new_last_used = new SVMINT[cache_size];
SVMINT* new_index = new SVMINT[cache_size+1];
SVMINT old_pos=0;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -