📄 gp-pred.c
字号:
/* gp-pred.c: Generic program for doing predictions with Gaussian Processes. * * Log record from a specified log file within a specified interval are used to * make predictions for squared error loss, absolute error loss and negative * log probability loss. For squared error loss, the mean of the predictive * distributon is used. For absolute error loss, the meadian is approximated by * MEDIANSSIZE samples from each record. For negative log predictive loss the * width of the predictive distribution is enlarged if the sum of the widths * of the predictive Gaussians is less than the spread of the central 2/3 of * of the means. * * (c) Copyright 1996 by Carl Edward Rasmussen. */ #include <stdio.h>#include <stdlib.h>#include <math.h>#include "util.h"#include "rand.h"#define MEDIANSSIZE 11 /* sample size for median predictions */#define two_pi 6.28318530717959real *w, /* hyperparameters */ **K, /* main matrices */ **K1, **K2, *q; /* unused */int no_wts, /* number of hyperparameters */ no_inp, /* input dimension */ no_tar, /* number of targets */ nfgeval; /* unused */struct exampleset train, test; extern void pred(real *y, real *s2, struct exampleset test);extern void init();extern real median(real *, int);extern real select(real *, int, int);main(argc, argv) int argc; char **argv;{ double r; real **s2, **means, **meds, *hlp, tmp; int i, j, k, l, h, mm, low, high, mod; long tm; /* time */ char trainfile[50], testfile[50], targetfile[50], logfile[50], outfile[50]; FILE *fp; if (argc != 4) { fprintf(stderr, "Usage: %s log-file instance-number [@][min]:[max][{%%|+}Interval]\n", argv[0]); exit(-1); } parse_range(argv[3], &low, &high, &mod); sprintf(trainfile, "train.%s", argv[2]); train.num = test.num = no_inp = -1; /* default for "unknown" */ loadExamples(&train, &no_inp, (no_tar=1, &no_tar), trainfile, NULL); sprintf(testfile, "test.%s", argv[2]); sprintf(targetfile, "targets.%s", argv[2]); loadExamples(&test, &no_inp, &no_tar, testfile, targetfile); sprintf(logfile, "%s.%s", argv[1], argv[2]); if ((fp=fopen(logfile, "r")) == NULL) { fprintf(stderr, "Could not open log file %s for reading... bye!\n", logfile); exit(-1); } init(); if (low<0 || high==-1) { /* range is given in time or no upper limit */ while (fscanf(fp, "%d %ld %lf %lf %lf", &j, &tm, &r, &r, &r) == 5) { for (i=0; i<no_wts; i++) { fscanf(fp, "%lf", &r); w[i] = r; } for (i=0; i<no_wts; i++) fscanf(fp, "%lf", &r); if (low<0 && tm/1000>=-low) low = j; if (tm/1000<high || high==-1) k = j; } if (high != -2) high = k; rewind(fp); } if (high==-2) high = low; if (mod==0) mod = high-low+1; fprintf(stderr, "Using up to %d samples with indexes between %d and %d for predicting...\n", mod, low, high); fflush(stderr); K = createMatrix(train.num, train.num); s2 = createMatrix(mod, test.num); means = createMatrix(mod, test.num); for (k=j=l=0; l<mod; l++) { if (mod==1) h=low; else h=low+(l*(high-low))/(mod-1); do { fscanf(fp, "%d %ld %lf %lf %lf", &j, &tm, &r, &r, &r); for (i=0; i<no_wts; i++) { fscanf(fp, "%lf", &r); w[i] = r; } for (i=0; i<no_wts; i++) fscanf(fp, "%lf", &r); } while (h>j); if (j>high) break; /* there are no more samples */ pred(means[l], s2[l], test); k++; if (j==high) break; /* there are no more samples */ } fclose(fp); fp = openPredFile("cguess.S.%s", argv[2]); /* write predictions for S loss */ for (i=0; i<test.num; i++) { for (tmp=0.0,j=0; j<k; j++) tmp += means[j][i]; fprintf(fp, "%10.6f\n", tmp/k); } fclose(fp); fp = openPredFile("cguess.A.%s", argv[2]); /* write predictions for A loss */ meds = createMatrix(k, MEDIANSSIZE); for (i=0; i<test.num; i++) { for (j=0; j<k; j++) for (l=0; l<MEDIANSSIZE; l++) meds[j][l] = means[j][i]+sqrt(s2[j][i])*rand_gaussian(); fprintf(fp, "%10.6f\n", median(meds[0], MEDIANSSIZE*k)); } fclose(fp); fp = openPredFile("clptarg.L.%s", argv[2]); /* write preds for L loss */ for (i=0; i<test.num; i++) { for (tmp=0.0,j=0; j<k; j++) { tmp += sqrt(s2[j][i]); meds[0][j] = means[j][i]; } tmp /= select(meds[0], k, k*5/6) - select(meds[0], k, k/6); if (tmp < 1.0) for (tmp=sq(tmp),j=0; j<k; j++) s2[j][i] /= tmp; for (tmp=0.0,j=0; j<k; j++) tmp += 1/sqrt(two_pi*s2[j][i])* exp(-0.5*sq(test.tar[i][0]-means[j][i])/s2[j][i]); fprintf(fp, "%10.6f\n", log(tmp/k)); } fclose(fp); free(meds[0]); free(meds); free(means[0]); free(means); free(s2[0]); free(s2);}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -