📄 mysvm.cpp
字号:
#include <stdio.h>#include <stdlib.h>#include <string.h>#include <ctype.h>#include "svm.h"#include <iostream>using namespace std;#define Malloc(type,n) (type *)malloc((n)*sizeof(type))void exit_with_help(){ printf( "Usage: svm-train [options] training_set_file [model_file]\n" "options:\n" "-s svm_type : set type of SVM (default 0)\n" " 0 -- C-SVC\n" " 1 -- nu-SVC\n" " 2 -- one-class SVM\n" " 3 -- epsilon-SVR\n" " 4 -- nu-SVR\n" "-t kernel_type : set type of kernel function (default 2)\n" " 0 -- linear: u'*v\n" " 1 -- polynomial: (gamma*u'*v + coef0)^degree\n" " 2 -- radial basis function: exp(-gamma*|u-v|^2)\n" " 3 -- sigmoid: tanh(gamma*u'*v + coef0)\n" " 4 -- precomputed kernel (kernel values in training_set_file)\n" "-d degree : set degree in kernel function (default 3)\n" "-g gamma : set gamma in kernel function (default 1/k)\n" "-r coef0 : set coef0 in kernel function (default 0)\n" "-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)\n" "-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)\n" "-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)\n" "-m cachesize : set cache memory size in MB (default 100)\n" "-e epsilon : set tolerance of termination criterion (default 0.001)\n" "-h shrinking: whether to use the shrinking heuristics, 0 or 1 (default 1)\n" "-b probability_estimates: whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)\n" "-wi weight: set the parameter C of class i to weight*C, for C-SVC (default 1)\n" "-v n: n-fold cross validation mode\n" ); exit(1);}void read_problem(const char* filename);void do_cross_validation();void predict(FILE *input, FILE *output);struct svm_parameter param; struct svm_problem prob; struct svm_model *model;struct svm_node *x_space;int cross_validation;int nr_fold;//下面四个变量是用在预测中的char* line;int max_line_len = 1024;struct svm_node *x; //用来存储从测试文件中读取的每个测试样本 int max_nr_attr = 64; //定义最大的属性个数void main(){ //训练模型 const char *error_msg; const char* model_file_name="..\\mysvm.model"; const char* train_file_name="..\\train.txt"; param.svm_type = 1; //设置训练模型需要的参数 param.kernel_type = 2; param.degree = 3; param.gamma = 0; // 1/k param.coef0 = 0; param.nu = 0.5; param.cache_size = 100; param.C = 1; param.eps = 1e-3; param.p = 0.1; param.shrinking = 1; param.probability = 0; param.nr_weight = 0; param.weight_label = NULL; param.weight = NULL; cross_validation = 0; read_problem(train_file_name); //读取要训练的数据集 error_msg = svm_check_parameter(&prob,¶m); //检查参数设置和训练数据集是否符合规定,返回错误信息。 if(error_msg) //如果发生错误,则输出错误信息 { fprintf(stderr,"Error: %s\n",error_msg); exit(1); } if(cross_validation) //判断是否需要进行交互验证,需要的话就调用下面这个函数 { do_cross_validation(); } else { model = svm_train(&prob,¶m); //调用svm.cpp文件中的svm_train函数进行模型的训练,并返回一个svm_model*的对象 svm_save_model(model_file_name,model); //将训练好的模型存储在model_file_name命名的文件中 } //注意不能在这个地方销毁参数,因为会导致程序执行的错误 //svm_destroy_param(¶m); //销毁参数, //free(prob.y); //释放由malloc分配的资源空间 //free(prob.x); //free(x_space); //下面的程序是进行预测 FILE* input,*output; input=fopen("..\\test.txt","r"); //读入要测试的样本文件 if(input==NULL) { cout<<"can't open the test file!"<<endl; exit(1); } output=fopen("..\\result.txt","w"); //打开要存入结果的文件 if (output==NULL) { cout<<"can't open the result file!"<<endl; exit(1); } line = (char *) malloc(max_line_len*sizeof(char)); x = (struct svm_node *) malloc(max_nr_attr*sizeof(struct svm_node)); predict(input,output); //调用predict函数对输入的测试文件进行预测所属的类别标号 svm_destroy_model(model); //销毁模型 svm_destroy_param(¶m); free(prob.y); //释放资源 free(prob.x); free(x_space); free(line); free(x); fclose(input);//关闭文件流 fclose(output); while(1);//维持DOS框}void read_problem(const char* filename){ int elements, max_index, i, j; FILE *fp = fopen(filename,"r");//读取一个文件 if(fp == NULL)//判断文件打开是否成功 { cout<<"can't open the train file"<<endl; exit(1); } else { cout<<"open the train file successfully!"<<endl; //输出这个信息用于调试,看文件打开是否成功 } prob.l = 0;//初始化样本个数 elements = 0;//初始化总的元素个数 while(1) { int c = fgetc(fp); //读取文件流fp每个字符,返回的是一个int型的值 switch(c) { case '\n': //如果是回车符号,表示一个样本读取完成,则prob.l表示样本个数要加1(参见svm_problem的结构体) ++prob.l; // fall through, // count the '-1' element case ':': //如果是冒号,根据数据的格式,:表示的是属性,则元素个数加1 ++elements; break; case EOF: //文件读取结束,则跳出循环 goto out; default: ; } }out: cout<<"the number of the train samples is:"<<prob.l<<endl; rewind(fp); //将文件流的指针重新定位到文件流的开始 prob.y = Malloc(double,prob.l); //给prob的结构体成员y分配空间 prob.x = Malloc(struct svm_node *,prob.l); //给x分配空间,它是一个双重指针,*x表示每个样本地址的指针,**x表示的是样本 x_space = Malloc(struct svm_node,elements); //给样本空间分配空间,总共有elements个元素 max_index = 0; //每个样本的属性值的最大个数 j=0; for(i=0;i<prob.l;i++) { double label; prob.x[i] = &x_space[j]; //使得x[i]指向每个样本的首地址 fscanf(fp,"%lf",&label); prob.y[i] = label; //存储每个样本的类别标号 while(1) { int c; do { c = getc(fp); //读取fp的每个字符 if(c=='\n') goto out2; //读取一个样本结束的话,就转到out2的地方 } while(isspace(c)); //如果当前读取的字符是空格,表示一个属性值读取完毕,继续读取一个字符,既不是空格又不是回车符,则跳出do,while循环 ungetc(c,fp); //将刚读取的那个字符,即每个属性值的标号放回去,fp前移 if (fscanf(fp,"%d:%lf",&(x_space[j].index),&(x_space[j].value)) < 2) //判断属性的书写格式是否正确 { fprintf(stderr,"Wrong input format at line %d\n", i+1); exit(1); } ++j;//j自增,j记录的其实是样本属性值的个数,当c读到的是回车符的时候,当前j的值就是前面所有i样本的属性个数的总和,这样上面的prob.x[i]的赋值才会是正确的 } out2: if(j>=1 && x_space[j-1].index > max_index) max_index = x_space[j-1].index; x_space[j++].index = -1;//因为每个属性的存储格式,四个属性值的话,实际存储的是5个属性值,只是最后一个的index为-1,值为空。参加svm_node结构体,及其存储格式 } if(param.gamma == 0) //如果上面对para.gamma的设定为0,则修改它的值为1/max_index param.gamma = 1.0/max_index; if(param.kernel_type == PRECOMPUTED) //一般情况下kernel_type的值为RBF,所以一般下面的程序不会执行 for(i=0;i<prob.l;i++) { if (prob.x[i][0].index != 0) { fprintf(stderr,"Wrong input format: first column must be 0:sample_serial_number\n"); exit(1); } if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index) { fprintf(stderr,"Wrong input format: sample_serial_number out of range\n"); exit(1); } } fclose(fp); //关闭文件流}void do_cross_validation() //调整参数{ int i; int total_correct = 0; double total_error = 0; double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; double *target = Malloc(double,prob.l); svm_cross_validation(&prob,¶m,nr_fold,target); //调用svm.cpp文件中的函数进行交互验证 if(param.svm_type == EPSILON_SVR || //判断是否是回归任务 param.svm_type == NU_SVR) { for(i=0;i<prob.l;i++) { double y = prob.y[i]; double v = target[i]; total_error += (v-y)*(v-y); sumv += v; sumy += y; sumvv += v*v; sumyy += y*y; sumvy += v*y; } printf("Cross Validation Mean squared error = %g\n",total_error/prob.l); printf("Cross Validation Squared correlation coefficient = %g\n", ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/ ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy)) ); } else { for(i=0;i<prob.l;i++) if(target[i] == prob.y[i]) ++total_correct; printf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l); } free(target);}void predict(FILE *input, FILE *output){ int correct = 0; //记录预测正确的个数 int total = 0; //记录总的测试样本的个数 double error = 0; double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; int svm_type=svm_get_svm_type(model); //获取模型的svm类型 int nr_class=svm_get_nr_class(model); //获取训练好的模型的类别数 cout<<"the number of class is :"<<nr_class<<endl; while(1) { int i = 0; int c; double target,v; if (fscanf(input,"%lf",&target)==EOF) //读入测试文件的每个样本的类别标号label,并且存储在target变量中方便后面的正确与否的判断。如果是文件的结束的话就跳出循环,所有的样本都已经测试完毕 break; while(1) { //以下的if 判断是为了判断测试文件的每个样本的属性的最大个数是否会超过定义的宏,如果超过的话,我们重新定义该宏,并分配空间 if(i>=max_nr_attr-1) // need one more for index = -1 { max_nr_attr *= 2; x = (struct svm_node *) realloc(x,max_nr_attr*sizeof(struct svm_node)); } do { c = getc(input); //读取测试文件的每个字符 if(c=='\n' || c==EOF) goto out2; //判断该字符是否是回车符,如果是回车,表示一个样本读取完毕,转到out2的地方 } while(isspace(c)); //如果当前读到的字符是空格的话,就再往前读一个字符,如果该字符既不是回车又不是空格,就跳出do while循环,表示的是现在读到的是该样本的第i+1个属性的属性标号 ungetc(c,input); //把刚读到的字符会退到input中,然后判断测试文件样本的输入格式是否正确 if (fscanf(input,"%d:%lf",&x[i].index,&x[i].value) < 2) { fprintf(stderr,"Wrong input format at line %d\n", total+1); exit(1); } ++i; //如果输入的测试文件的样本格式是正确的,就自增i,继续读取下一个属性 } out2: x[i].index = -1; //根据数据存储的格式,最后要留一个index为-1,值为空的svm_node的节点 v = svm_predict(model,x); //利用模型model进行预测 fprintf(output,"%g\n",v); //将v格式化到output文件中,并输出output文件 if(v == target) //将预测的类别v和真实的类别target进行比较 ++correct; //如果相同,则将预测正确的个数自增1 error += (v-target)*(v-target); sumv += v; sumy += target; sumvv += v*v; sumyy += target*target; sumvy += v*target; ++total; //统计总的样本个数 } printf("Accuracy = %g%% (%d/%d) (classification)\n", (double)correct/total*100,correct,total);}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -