📄 dirichlet.c
字号:
#include <stdio.h>#include <math.h>#include <malloc.h>#include <assert.h>#include <stdlib.h>double *dirichlet_alphas;double *dirichlet_counts;int dirichlet_num_bags; /* e.g. the number of document classes */int dirichlet_num_dims; /* e.g. the vocabulary size */#define COUNTS(BI,DI) (counts[(BI * num_dims) + DI])/* Read NUM_BAGS, NUM_DIMS, allocate COUNTS and read them from FP */voiddirichlet_read_counts (FILE *fp, double **counts_ptr, int *num_bags_ptr, int *num_dims_ptr){ int i, j, num_bags, num_bags_capacity, num_dims; double *counts; fscanf (fp, "%d", &num_dims); assert (num_dims > 1); num_bags_capacity = 32; counts = malloc (num_bags_capacity * num_dims * sizeof (double)); for (i = 0; fscanf (fp, "%lf", counts+i) == 1; i++) { if (i+1 >= num_bags_capacity * num_dims) { num_bags_capacity *= 2; counts = realloc (counts, num_bags_capacity * num_dims * sizeof (double)); } //printf ("count[%d][%d] = %g\n", i/num_dims, i%num_dims, counts[i]); } if (i % num_dims) { fprintf (stderr, "Counts must be input in groups of %d\n", num_dims); exit (1); } num_bags = i / num_dims; *num_bags_ptr = num_bags; *num_dims_ptr = num_dims; *counts_ptr = counts;}/* Calculate the alpha parameters of a Dirichlet by moment matching and place them in ALPHAS. Return the sum of the alphas. */doubledirichlet_moment_match (int num_dims, int num_bags, double *counts, double *alphas){ /* The "sample" we will calculate mean and variance for is the "proportion" of each token type */ double *sample_mean = alloca (num_dims * sizeof (double)); double *sample_variance = alloca (num_dims * sizeof (double)); double *bag_total = alloca (num_bags * sizeof (double)); double x, y, alphas_sum; int i, j; /* Count the total number of words in each bag */ for (i = 0; i < num_bags; i++) { bag_total[i] = 0; for (j = 0; j < num_dims; j++) bag_total[i] += COUNTS(i,j); assert (bag_total[i]); } /* Calculate the sample mean for each dimension, j. This is = (1/#bags) * Sum_bags (count[bag][j]/total[bag]) */ for (j = 0; j < num_dims; j++) { sample_mean[j] = 0; for (i = 0; i < num_bags; i++) sample_mean[j] += COUNTS(i,j) / (bag_total[i] * num_bags); assert (sample_mean[j] == sample_mean[j]); }#if 0 for (j = 0; j < num_dims; j++) printf ("sample mean alpha[%d] = %g\n", j, sample_mean[j]);#endif /* Calculate the sample variance for each dimension, j. This is = E[x^2] - E[x]^2 = [(1/(#bags-1)) * Sum_bags (COUNTS(i,j) / bag_total[i])^2] - sample_mean[j]^2 */ for (j = 0; j < num_dims; j++) { sample_variance[j] = 0; for (i = 0; i < num_bags; i++) { x = COUNTS(i,j) / bag_total[i]; sample_variance[j] += x * x; } sample_variance[j] /= num_bags; /* We now have E[x^2] */ sample_variance[j] -= (sample_mean[j] * sample_mean[j]); /* We now have E[x^2] - E[x]^2 */ assert (sample_variance[j] == sample_variance[j]); } /* Calculate the sum of the alphas */ x = 0; for (j = 0; j < num_dims - 1; j++) { assert (sample_variance[j] > 0); y = ((sample_mean[j] * (1 - sample_mean[j])) / sample_variance[j]) - 1; assert (y > 0); x += log (y); assert (x == x); } x *= 1.0 / (num_dims - 1.0); alphas_sum = exp (x); assert (alphas_sum == alphas_sum); for (j = 0; j < num_dims; j++) alphas[j] = sample_mean[j] * alphas_sum; return alphas_sum;}double log_gamma(double x){ double result, y, xnum, xden; int i; static double d1 = -5.772156649015328605195174e-1; static double p1[] = { 4.945235359296727046734888e0, 2.018112620856775083915565e2, 2.290838373831346393026739e3, 1.131967205903380828685045e4, 2.855724635671635335736389e4, 3.848496228443793359990269e4, 2.637748787624195437963534e4, 7.225813979700288197698961e3 }; static double q1[] = { 6.748212550303777196073036e1, 1.113332393857199323513008e3, 7.738757056935398733233834e3, 2.763987074403340708898585e4, 5.499310206226157329794414e4, 6.161122180066002127833352e4, 3.635127591501940507276287e4, 8.785536302431013170870835e3 }; static double d2 = 4.227843350984671393993777e-1; static double p2[] = { 4.974607845568932035012064e0, 5.424138599891070494101986e2, 1.550693864978364947665077e4, 1.847932904445632425417223e5, 1.088204769468828767498470e6, 3.338152967987029735917223e6, 5.106661678927352456275255e6, 3.074109054850539556250927e6 }; static double q2[] = { 1.830328399370592604055942e2, 7.765049321445005871323047e3, 1.331903827966074194402448e5, 1.136705821321969608938755e6, 5.267964117437946917577538e6, 1.346701454311101692290052e7, 1.782736530353274213975932e7, 9.533095591844353613395747e6 }; static double d4 = 1.791759469228055000094023e0; static double p4[] = { 1.474502166059939948905062e4, 2.426813369486704502836312e6, 1.214755574045093227939592e8, 2.663432449630976949898078e9, 2.940378956634553899906876e10, 1.702665737765398868392998e11, 4.926125793377430887588120e11, 5.606251856223951465078242e11 }; static double q4[] = { 2.690530175870899333379843e3, 6.393885654300092398984238e5, 4.135599930241388052042842e7, 1.120872109616147941376570e9, 1.488613728678813811542398e10, 1.016803586272438228077304e11, 3.417476345507377132798597e11, 4.463158187419713286462081e11 }; static double c[] = { -1.910444077728e-03, 8.4171387781295e-04, -5.952379913043012e-04, 7.93650793500350248e-04, -2.777777777777681622553e-03, 8.333333333333333331554247e-02, 5.7083835261e-03 }; static double a = 0.6796875; if((x <= 0.5) || ((x > a) && (x <= 1.5))) { if(x <= 0.5) { result = -log(x); /* Test whether X < machine epsilon. */ if(x+1 == 1) { return result; } } else { result = 0; x = (x - 0.5) - 0.5; } xnum = 0; xden = 1; for(i=0;i<8;i++) { xnum = xnum * x + p1[i]; xden = xden * x + q1[i]; } result += x*(d1 + x*(xnum/xden)); } else if((x <= a) || ((x > 1.5) && (x <= 4))) { if(x <= a) { result = -log(x); x = (x - 0.5) - 0.5; } else { result = 0; x -= 2; } xnum = 0; xden = 1; for(i=0;i<8;i++) { xnum = xnum * x + p2[i]; xden = xden * x + q2[i]; } result += x*(d2 + x*(xnum/xden)); } else if(x <= 12) { x -= 4; xnum = 0; xden = -1; for(i=0;i<8;i++) { xnum = xnum * x + p4[i]; xden = xden * x + q4[i]; } result = d4 + x*(xnum/xden); } /* X > 12 */ else { y = log(x); result = x*(y - 1) - y*0.5 + .9189385332046727417803297; x = 1/x; y = x*x; xnum = c[6]; for(i=0;i<6;i++) { xnum = xnum * y + c[i]; } xnum *= x; result += xnum; } return result;}doubledirichlet_multinomial_log_evidence (int num_dims, int num_bags, double *counts, double *alphas){ double evidence, alphas_sum; double *bag_total = alloca (num_bags * sizeof (double)); int i, j; /* Calculate the sum of the alphas */ alphas_sum = 0; for (j = 0; j < num_dims; j++) alphas_sum += alphas[j]; /* Calculate the bag totals */ for (i = 0; i < num_bags; i++) { bag_total[i] = 0; for (j = 0; j < num_dims; j++) bag_total[i] += COUNTS(i,j); } evidence = 0; for (i = 0; i < num_bags; i++) { evidence += (log_gamma (alphas_sum) - log_gamma (bag_total[i] + alphas_sum)); for (j = 0; j < num_dims; j++) evidence += log_gamma(COUNTS(i,j) + alphas[j]) - log_gamma(alphas[j]); } return evidence;}voidprint_usage (const char *argv[]){ fprintf (stderr, "usage: \n");}intmain (int argc, const char *argv[]){ double sum; int i, j, argi, num_classes; /* Can be difference from num_dims when there are multiple classes */ int num_alphas; int index_of_correct_class = -1; num_classes = 0; for (argi = 1; argi < argc; argi++) { if (argv[argi][0] != '-') break; switch (argv[argi][1]) { case 'c': /* Do classification of bags according to evidence from several different dirichlets */ num_classes = atoi (argv[++argi]); break; case 'I': index_of_correct_class = atoi (argv[++argi]); break; default: fprintf (stderr, "%s: unrecognized option `%s'\n", argv[0], argv[argi]); print_usage (argv); exit (-1); } } if (argi < argc) { /* Get the alphas from the command line and then calculate the evidence of the counts read in on stdin */ int dirichlet_num_dims_capacity = 32; double evidence; dirichlet_alphas = malloc (dirichlet_num_dims_capacity * sizeof(double)); for (num_alphas = 0; argi < argc; argi++, num_alphas++) { if (num_alphas >= dirichlet_num_dims_capacity) { dirichlet_num_dims_capacity *= 2; dirichlet_alphas = realloc (dirichlet_alphas, dirichlet_num_dims_capacity * sizeof(double)); } dirichlet_alphas[num_alphas] = atof (argv[argi]); //printf("alpha[%d] = %g\n",num_alphas,dirichlet_alphas[num_alphas]); } dirichlet_read_counts (stdin, &dirichlet_counts, &dirichlet_num_bags, &dirichlet_num_dims); assert ((num_classes && num_alphas % dirichlet_num_dims == 0) || num_alphas == dirichlet_num_dims); if (num_classes) { double *ev = alloca (num_classes * sizeof (double)); double max_ev; int c, max_c, num_bags_correct; assert (num_alphas == dirichlet_num_dims * num_classes); num_bags_correct = 0; for (i = 0; i < dirichlet_num_bags; i++) { max_c = -1; max_ev = -DBL_MAX; for (c = 0; c < num_classes; c++) { ev[c] = dirichlet_multinomial_log_evidence (dirichlet_num_dims, 1, dirichlet_counts + (i * dirichlet_num_dims), dirichlet_alphas + (c * dirichlet_num_dims)); if (ev[c] > max_ev) { max_ev = ev[c]; max_c = c; } } if (index_of_correct_class != -1 && max_c == index_of_correct_class) num_bags_correct++; printf ("bag[%d] class=%d ", i, max_c); for (c = 0; c < num_classes; c++) printf ("class[%d]=%g ", c, ev[c]); printf ("\n"); } if (index_of_correct_class != -1) printf ("Correct %d/%d = %g\n", num_bags_correct, dirichlet_num_bags, ((float)num_bags_correct)/dirichlet_num_bags); } else { evidence = dirichlet_multinomial_log_evidence (dirichlet_num_dims, dirichlet_num_bags, dirichlet_counts, dirichlet_alphas); printf ("log(evidence) = %g\n", evidence); } } else { /* Read the counts on stdin, then calculate the alphas by moment matching */ dirichlet_read_counts (stdin, &dirichlet_counts, &dirichlet_num_bags, &dirichlet_num_dims); dirichlet_alphas = malloc (dirichlet_num_dims * sizeof (double)); sum = dirichlet_moment_match (dirichlet_num_dims, dirichlet_num_bags, dirichlet_counts, dirichlet_alphas); fprintf (stderr, "n = %d\n", dirichlet_num_bags); fprintf (stderr, "sum = %g\np = ", sum); for (j = 0; j < dirichlet_num_dims; j++) fprintf (stderr, "%9g ", dirichlet_alphas[j] / sum); fprintf (stderr, "\nalphas =\n "); for (j = 0; j < dirichlet_num_dims; j++) printf ("%15f\n", dirichlet_alphas[j]); } exit (0);}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -