📄 svm-train.cpp
字号:
} out2: if(j>=1 && x_space[j-1].index > max_index) max_index = x_space[j-1].index; x_space[j++].index = -1; }
if ( param.svm_type == CVR ) { param.C = (scale_param <= 0.0) ? 10000.0 : scale_param; if ( param.mu < 0.0 ) param.mu = 0.02; double maxY = -INF, minY = INF; for (i=0; i<prob.l; i++) { maxY = max(maxY, prob.y[i]); minY = min(minY, prob.y[i]); } maxY = max(maxY, -minY); param.C = param.C *maxY; param.mu = param.mu*maxY; printf("MU %.16g, ", param.mu); } else if ( param.svm_type == CVM_LS ) { param.C = (scale_param <= 0.0) ? 10000.0 : scale_param; param.mu = param.C/((reg_param < 0.0) ? 100.0 : reg_param)/prob.l; printf("MU %.16g, ", param.mu); } else // other SVM type { param.C = (reg_param <= 0.0) ? 100.0 : param.C = reg_param; }
if(param.gamma == 0.0) param.gamma = 1.0/max_index;
else if (param.gamma < -0.5)
param.gamma = 2.0/CalRBFWidth();
if(param.kernel_type == PRECOMPUTED) for(i=0;i<prob.l;i++) { if (prob.x[i][0].index != 0) { fprintf(stderr,"Wrong input format: first column must be 0:sample_serial_number\n"); exit(1); } if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index) { fprintf(stderr,"Wrong input format: sample_serial_number out of range\n"); exit(1); } } fclose(fp); switch(param.kernel_type) { case NORMAL_POLY: case POLY: printf("Degree %.16g, coef0 %.16g, ", param.degree, param.coef0); break; case RBF: case EXP: case INV_DIST: case INV_SQDIST: printf("Gamma %.16g, ", param.gamma); break; case SIGMOID: printf("Gamma %.16g, coef0 %.16g, ", param.gamma, param.coef0); break; } printf("C = %.16g\n", param.C);}
// Row interval intended for periodic debug/progress printing while loading
// data (i.e. print every D_ITER-th row).
// NOTE(review): no active code in the visible part of this file references it.
#define D_ITER 5000
/*
corrected solution with dense format extension
*/
void read_problem(const char *filename, const char *filename2)
{
int elements, max_index, i, j;
int type, dim;
FILE *fp = fopen(filename,"r+t");
if(fp == NULL)
{
fprintf(stderr,"can't open input file %s\n",filename);
exit(1);
}
prob.l = 0;
elements = 0;
type = 0; // sparse format
dim = 0;
int c;
do
{
c = fgetc(fp);
switch(c)
{
case '\n':
++prob.l;
// fall through,
// count the '-1' element
if ((type == 1) && (dim == 0)) // dense format
{
dim = elements;
}
break;
case ':':
++elements;
break;
case ',':
++elements;
type = 1;
break;
default:
;
}
} while (c != EOF);
rewind(fp);
int num1 = prob.l;
FILE *fp2 = fopen(filename2,"r+t");
if(fp2 == NULL)
{
fprintf(stderr,"can't open input file %s\n",filename);
}
else
{
do
{
c = fgetc(fp2);
switch(c)
{
case '\n':
++prob.l;
// fall through,
// count the '-1' element
if ((type == 1) && (dim == 0)) // dense format
{
dim = elements;
}
break;
case ':':
++elements;
break;
case ',':
++elements;
type = 1;
break;
default:
;
}
} while (c != EOF);
rewind(fp2);
}
prob.y = Malloc(double,prob.l);
prob.x = Malloc(struct svm_node *,prob.l);
x_space = Malloc(struct svm_node,elements + prob.l);
if (!prob.y || !prob.x || !x_space)
{
fprintf(stdout, "ERROR: not enough memory!\n");
prob.l = 0;
return;
}
max_index = 0;
j=0;
elements = 0;
for(i=0;i<num1;i++)
{
double label;
prob.x[i] = &x_space[j];
if (type == 0) // sparse format
{
fscanf(fp,"%lf",&label);
prob.y[i] = label;
}
/* if (i % D_ITER == 0)
printf ("\n\n%d ",i);
*/
int elementsInRow = 0;
while(1)
{
int c;
do {
c = getc(fp);
if(c=='\n') break;
} while(isspace(c));
if(c=='\n') break;
ungetc(c,fp);
if (type == 0) // sparse format
{
#ifdef INT_FEAT
int tmpindex;
int tmpvalue;
fscanf(fp,"%d:%d",&tmpindex,&tmpvalue);
x_space[j].index = tmpindex;
x_space[j].value = tmpvalue;
#else
fscanf(fp,"%d:%lf",&(x_space[j].index),&(x_space[j].value));
#endif
++j;
}
else if ((type == 1) && (elementsInRow < dim)) // dense format, read a feature
{
x_space[j].index = elementsInRow;
elementsInRow++;
#ifdef INT_FEAT
int tmpvalue;
fscanf(fp, "%d,", &tmpvalue);
x_space[j].value = tmpvalue;
#else
fscanf(fp, "%lf,", &(x_space[j].value));
#endif
++j;
}
else if ((type == 1) && (elementsInRow >= dim)) // dense format, read the label
{
fscanf(fp,"%lf",&(prob.y[i]));
}
/* if (i % D_ITER == 0)
printf ("%d:%d ",x_space[j-1].index,x_space[j-1].value);
*/
}
if(j>=1 && x_space[j-1].index > max_index)
max_index = x_space[j-1].index;
x_space[j++].index = -1;
}
// printf("Finish reading the first file!\n");
for(i=num1;i<prob.l;i++)
{
double label;
prob.x[i] = &x_space[j];
if (type == 0) // sparse format
{
fscanf(fp2,"%lf",&label);
prob.y[i] = label;
}
/* if (i % D_ITER == 0)
printf ("\n\n%d ",i);
*/
int elementsInRow = 0;
while(1)
{
int c;
do {
c = getc(fp2);
if(c=='\n') break;
} while(isspace(c));
if(c=='\n') break;
ungetc(c,fp2);
if (type == 0) // sparse format
{
#ifdef INT_FEAT
int tmpindex;
int tmpvalue;
fscanf(fp2,"%d:%d",&tmpindex,&tmpvalue);
x_space[j].index = tmpindex;
x_space[j].value = tmpvalue;
#else
fscanf(fp2,"%d:%lf",&(x_space[j].index),&(x_space[j].value));
#endif
++j;
}
else if ((type == 1) && (elementsInRow < dim)) // dense format, read a feature
{
x_space[j].index = elementsInRow;
elementsInRow++;
#ifdef INT_FEAT
int tmpvalue;
fscanf(fp2, "%d,", &tmpvalue);
x_space[j].value = tmpvalue;
#else
fscanf(fp2, "%lf,", &(x_space[j].value));
#endif
++j;
}
else if ((type == 1) && (elementsInRow >= dim)) // dense format, read the label
{
fscanf(fp2,"%lf",&(prob.y[i]));
}
/* if (i % D_ITER == 0)
printf ("%d:%d ",x_space[j-1].index,x_space[j-1].value);
*/
}
if(j>=1 && x_space[j-1].index > max_index)
max_index = x_space[j-1].index;
x_space[j++].index = -1;
}
if ( param.svm_type == CVR ) { param.C = (scale_param <= 0.0) ? 10000.0 : scale_param; if ( param.mu < 0.0 ) param.mu = 0.02; double maxY = -INF, minY = INF; for (i=0; i<prob.l; i++) { maxY = max(maxY, prob.y[i]); minY = min(minY, prob.y[i]); } maxY = max(maxY, -minY); param.C = param.C *maxY; param.mu = param.mu*maxY; printf("MU %.16g, ", param.mu); } else if ( param.svm_type == CVM_LS ) { param.C = (scale_param <= 0.0) ? 10000.0 : scale_param; param.mu = param.C/((reg_param < 0.0) ? 100.0 : reg_param)/prob.l; printf("MU %.16g, ", param.mu); } else // other SVM type { param.C = (reg_param <= 0.0) ? 100.0 : param.C = reg_param; }
if(param.gamma == 0.0)
param.gamma = 1.0/max_index;
else if (param.gamma < -0.5)
param.gamma = 2.0/CalRBFWidth();
if(param.kernel_type == PRECOMPUTED)
for(i=0;i<prob.l;i++)
{
if (prob.x[i][0].index != 0)
{
fprintf(stderr,"Wrong input format: first column must be 0:sample_serial_number\n");
exit(1);
}
if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
{
fprintf(stderr,"Wrong input format: sample_serial_number out of range\n");
exit(1);
}
}
fclose(fp);
if (fp2 != NULL)
{
fclose(fp2);
}
switch(param.kernel_type) { case NORMAL_POLY: case POLY: printf("Degree %.16g, coef0 %.16g, ", param.degree, param.coef0); break; case RBF: case EXP: case INV_DIST: case INV_SQDIST: printf("Gamma %.16g, ", param.gamma); break; case SIGMOID: printf("Gamma %.16g, coef0 %.16g, ", param.gamma, param.coef0); break; } printf("C = %.16g\n", param.C);
}
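// Legacy read_problem(const char *filename): reads a single problem file in
// svmlight (sparse "label index:value ...") format.
/*
Original solution, kept disabled for reference only: it was marked as having a memory bug and is superseded by the corrected two-file read_problem() above.
*/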
/*
void read_problem(const char *filename)
{
int elements, max_index, i, j;
FILE *fp = fopen(filename,"r");
if(fp == NULL)
{
fprintf(stderr,"can't open input file %s\n",filename);
exit(1);
}
prob.l = 0;
elements = 0;
while(1)
{
int c = fgetc(fp);
switch(c)
{
case '\n':
++prob.l;
// fall through,
// count the '-1' element
case ':':
++elements;
break;
case EOF:
goto out;
default:
;
}
}
out:
rewind(fp);
prob.y = Malloc(double,prob.l);
prob.x = Malloc(struct svm_node *,prob.l);
x_space = Malloc(struct svm_node,elements);
max_index = 0;
j=0;
for(i=0;i<prob.l;i++)
{
double label;
prob.x[i] = &x_space[j];
fscanf(fp,"%lf",&label);
prob.y[i] = label;
while(1)
{
int c;
do {
c = getc(fp);
if(c=='\n') goto out2;
} while(isspace(c));
ungetc(c,fp);
fscanf(fp,"%d:%lf",&(x_space[j].index),&(x_space[j].value));
++j;
}
out2:
if(j>=1 && x_space[j-1].index > max_index)
max_index = x_space[j-1].index;
x_space[j++].index = -1;
}
if(param.gamma == 0)
param.gamma = 1.0/max_index;
else if (param.gamma < -0.5)
param.gamma = 2.0/CalRBFWidth();
if(param.kernel_type == PRECOMPUTED)
for(i=0;i<prob.l;i++)
{
if (prob.x[i][0].index != 0)
{
fprintf(stderr,"Wrong input format: first column must be 0:sample_serial_number\n");
exit(1);
}
if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
{
fprintf(stderr,"Wrong input format: sample_serial_number out of range\n");
exit(1);
}
}
fclose(fp);
}
*/
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -