⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 predict.cpp

📁 libsvm可视化程序
💻 CPP
字号:
#include "stdafx.h"
#include <stdio.h>
#include <ctype.h>
#include <stdlib.h>
#include <string.h>
#include "svm.h"

/* ---- State shared by predict() and test() ---------------------------- */

/* Scratch line buffer: allocated and freed by test(); not read elsewhere
 * in this file. */
char *line = NULL;
/* Size, in chars, of the buffer test() allocates for `line`. */
int max_line_len = 1024;

/* Feature vector of the instance being predicted.  predict() terminates it
 * with a node whose index is -1 and grows it when it runs out of room. */
struct svm_node *x = NULL;
/* Current capacity of `x`, in nodes; doubled by predict() when needed. */
int max_nr_attr = 64;

/* Model consulted by predict(); loaded and destroyed by test(). */
struct svm_model *model_test = NULL;
/* Nonzero requests probability estimates; read by test() and predict(). */
int predict_probability = 0;

void predict(FILE *input, FILE *output)
{
	int correct = 0;
	int total = 0;
	double error = 0;
	double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;

	int svm_type=svm_get_svm_type(model_test);
	int nr_class=svm_get_nr_class(model_test);
	int *labels=(int *) malloc(nr_class*sizeof(int));
	double *prob_estimates=NULL;
	int j,count=0;

	if(predict_probability)
	{
		if (svm_type==NU_SVR || svm_type==EPSILON_SVR)
			printf("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",svm_get_svr_probability(model_test));
		else
		{
			svm_get_labels(model_test,labels);
			prob_estimates = (double *) malloc(nr_class*sizeof(double));
			fprintf(output,"labels");
		
			for(j=0;j<nr_class;j++)
				fprintf(output," %d",labels[j]);
			fprintf(output,"\n");
		}
	}
	while(1)
	{
		int i = 0;
		int c;
		double target,v;

		if (fscanf(input,"%lf",&target)==EOF)
			break;

		while(1)
		{
			if(i>=max_nr_attr-1)	// need one more for index = -1
			{
				max_nr_attr *= 2;
				x = (struct svm_node *) realloc(x,max_nr_attr*sizeof(struct svm_node));
			}

			do {
				c = getc(input);
				if(c=='\n' || c==EOF) goto out2;
			} while(isspace(c));
			ungetc(c,input);
			fscanf(input,"%d:%lf",&x[i].index,&x[i].value);
			++i;
		}	

out2:
		x[i++].index = -1;

		if (predict_probability && (svm_type==C_SVC || svm_type==NU_SVC))
		{
			v = svm_predict_probability(model_test,x,prob_estimates);
			fprintf(output,"%g ",v);
			for(j=0;j<nr_class;j++)
				fprintf(output,"%g ",prob_estimates[j]);
				
			fprintf(output,"\n");
		}
		else
		{
			v = svm_predict(model_test,x);
				fprintf(output,"%g     ",v);
			count++;
			if(count==10){
			fprintf(output,"\n");
			count=0;
			}
		}

		if(v == target)
			++correct;
		error += (v-target)*(v-target);
		sumv += v;
		sumy += target;
		sumvv += v*v;
		sumyy += target*target;
		sumvy += v*target;
		++total;
	}
		fprintf(output,"\n");
	fprintf(output,"精确度 = %g%% (%d/%d) (classification)\n",
		(double)correct/total*100,correct,total);
	fprintf(output,"平均平方差 = %g (regression)\n",error/total);
	fprintf(output,"平方相关系数 = %g (regression)\n",
		((total*sumvy-sumv*sumy)*(total*sumvy-sumv*sumy))/
		((total*sumvv-sumv*sumv)*(total*sumyy-sumy*sumy))
		);
	if(predict_probability)
	{
		free(prob_estimates);
		free(labels);
	}
}



int test(char *test, char *strmodel)
{
	FILE *input, *output;
	
	input = fopen(test,"r");
	if(input == NULL)
	{
		//fprintf(stderr,"can't open input file %s\n",argv[i]);
		exit(1);
	}
	CString outfilename(test);
	outfilename+=".result";
	char strout[1024];
	memset(strout,0,1024);
	memcpy(strout,LPCTSTR(outfilename),outfilename.GetLength());
	output = fopen(strout,"w");
	if(output == NULL)
	{
		exit(1);
	}

	if((model_test=svm_load_model(strmodel))==0)
	{
		//fprintf(stderr,"can't open model file %s\n",argv[i+1]);
		exit(1);
	}
	
	line = (char *) malloc(max_line_len*sizeof(char));
	x = (struct svm_node *) malloc(max_nr_attr*sizeof(struct svm_node));
	if(predict_probability)
		if(svm_check_probability_model(model_test)==0)
		{
			//fprintf(stderr,"Model does not support probabiliy estimates\n");
			exit(1);
		}
	predict(input,output);
	svm_destroy_model(model_test);
	free(line);
	free(x);
	fclose(input);
	fclose(output);
	return 0;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -