📄 multiclass.c
字号:
for (cisi = 0; cisi < cis_size+2; cisi++) { //doc->cis_mixture[cisi] = exp (doc->cis_mixture[cisi] - max); cis_mixture_sum += doc->cis_mixture[cisi]; } bow_verbosify (bow_verbose, "%s ", doc->filename); for (cisi = 0; cisi < cis_size+2; cisi++) { doc->cis_mixture[cisi] /= cis_mixture_sum; bow_verbosify (bow_verbose, "%s=%g,", (cisi < cis_size ? bow_int2str (crossbow_classnames, doc->cis[cisi]) : (cisi == cis_size ? "root" : "uniform")), doc->cis_mixture[cisi]); } bow_verbosify (bow_verbose, "\n"); } } /* Normalize all per-word M-step results */ bow_treenode_set_words_from_new_words_all (crossbow_root, 0.0); bow_treenode_set_prior_and_extra_from_new_prior_all (crossbow_root, &multiclass_uniform_new_prior, &multiclass_uniform_prior, 0.0); /* xxx Increase this M? */ cmixture_set_from_new (0, 0.01, 1); bow_verbosify (bow_progress, "PP2=%g\n", -log_prob_of_data2 / num_data_words); return exp (-log_prob_of_data / num_data_words);}voidmulticlass_train (){ int iteration, max_num_iterations = 999999; double pp, old_pp; void print_diagnostics () { treenode *iterator, *tn; printf ("uniform prior=%g\n", multiclass_uniform_prior); for (iterator = crossbow_root; (tn = bow_treenode_iterate_all (&iterator));) { printf ("%s prior=%g\n", tn->name, tn->prior); bow_treenode_word_probs_print (tn, 10); printf ("\n"); //bow_treenode_word_likelihood_ratios_print (tn, 5); } cmixture_print_diagnostics (stdout); } bow_treenode_set_lambdas_leaf_only_all (crossbow_root); multiclass_place_labeled_data (); print_diagnostics (); for (iteration = 0, old_pp = 3000, pp = 2000; iteration < max_num_iterations && (old_pp - pp) > 0.0001; iteration++) { old_pp = pp; pp = multiclass_em_one_iteration (); printf ("PP=%g\n", pp); if (old_pp < pp) bow_verbosify (bow_progress, "Perplexity rose!\n"); print_diagnostics (); }}voidbow_sort_scores (bow_score *scores, int count){ static int score_compare (const void *x, const void *y) { if (((bow_score *)x)->weight > ((bow_score *)y)->weight) return -1; else if (((bow_score *)x)->weight == ((bow_score *)y)->weight) return 0; else return 1; } qsort (scores, count, sizeof (bow_score), score_compare);}doublemulticlass_log_prob_of_classes_given_doc (int *cis, int cis_size, crossbow_doc *doc){ int cisi, wvi, actual_cis_size; double *mixture; double log_prob_of_classes; bow_wv *wv; int wv_word_count; //int num_mixtures = cis_size + 2; double pr_w; cmixture *m; static int verbose = 0; static int factored_prior = 0; /* Get the mcombo entry for this set of classes */ m = cmixture_for_cis (cis, cis_size, 0, &actual_cis_size); assert (actual_cis_size == cis_size); /* Allocate space for word-specific mixture weights */ mixture = alloca ((cis_size + 1 + 1) * sizeof (double)); multiclass_mixture_given_cis (cis, cis_size, mixture); /* Incoporate the prior of this class combination */ if (factored_prior) { log_prob_of_classes = 0; for (cisi = 0; cisi < cis_size; cisi++) if (cis[cisi] >= 0) log_prob_of_classes += log (crossbow_root->children[cis[cisi]]->prior); log_prob_of_classes += log (crossbow_root->prior); log_prob_of_classes += log (multiclass_uniform_prior); } else { /* If the CIS mixture includes a class that has no training data, then reject by returning an impossibly low score. */ for (cisi = 0; cisi < cis_size; cisi++) if (crossbow_root->children[cis[cisi]]->prior == 0) return -FLT_MAX; if (m) log_prob_of_classes = log (m->prior); else log_prob_of_classes = log (multiclass_mixture_prior_alpha / multiclass_mixture_prior_normalizer); } wv = crossbow_wv_at_di (doc->di); wv_word_count = bow_wv_word_count (wv); for (wvi = 0; wvi < wv->num_entries; wvi++) {#if 0 /* Get "complete"-knowledge mixture weights specific to this word */ static int complete = 0; double *mixture_weights; double mixture_sum; if (complete) { mixture_sum = 0; for (cisi = 0; cisi < cis_size; cisi++) if (cis[cisi] >= 0) { mixture[cisi] = mixture_weights[cisi] * crossbow_root->children[cis[cisi]]->words[wv->entry[wvi].wi]; mixture_sum += mixture[cisi]; } mixture[cis_size] = mixture_weights[cis_size] * crossbow_root->words[wv->entry[wvi].wi]; mixture_sum += mixture[cis_size]; mixture[cis_size+1] = mixture_weights[cis_size+1] * 1.0 / bow_num_words (); mixture_sum += mixture[cis_size+1]; /* Normalize them */ for (cisi = 0; cisi < cis_size+2; cisi++) if (cis[cisi] >= 0) mixture[cisi] /= mixture_sum; }#endif pr_w = 0; for (cisi = 0; cisi < cis_size; cisi++) { if (cis[cisi] >= 0) pr_w += mixture[cisi] * crossbow_root->children[cis[cisi]]->words[wv->entry[wvi].wi]; } pr_w += mixture[cis_size] * crossbow_root->words[wv->entry[wvi].wi]; pr_w += mixture[cis_size+1] * 1.0 / bow_num_words (); assert (pr_w > 0); log_prob_of_classes += wv->entry[wvi].count * log (pr_w); if (verbose) { fprintf (stdout, "%04d %06d %-16s %12.3f %12g ", doc->di, wv->entry[wvi].wi, bow_int2word (wv->entry[wvi].wi), -log_prob_of_classes, pr_w); for (cisi = 0; cisi < cis_size; cisi++) if (cis[cisi] >= 0) fprintf (stdout, "%s,", bow_int2str (crossbow_classnames, cis[cisi])); fprintf (stdout, "\n"); } }#define USE_BIC 0#if USE_BIC return log_prob_of_classes - (num_mixtures-1) / 2 * log(wv_word_count);#else return log_prob_of_classes;#endif}#if 0static voidmulticlass_classify_doc_into_single_class (crossbow_doc *doc, bow_score *scores, int scores_count){ int ci; int cis[3]; assert (scores_count >= crossbow_root->children_count); cis[1] = cis[2] = -1; for (ci = 0; ci < crossbow_root->children_count; ci++) { cis[0] = ci; scores[ci].di = ci; scores[ci].weight = multiclass_log_prob_of_classes_given_doc (cis, 1, doc); } bow_sort_scores (scores, crossbow_root->children_count);}#endifstatic intmulticlass_cis_is_in_top (int *cis, multiclass_score *scores, int top_count){ int si; int ci; /* Temporarily always say yes. */ return 1; for (si = 0; si < top_count; si++) { for (ci = 0; ci < 3; ci++) if (scores[si].c[ci] != cis[ci]) break; if (ci == 3) /* All CIS were matched, return `yes'. */ return 1; } /* CIS not found in the TOP_COUNT entries of SCORES. Return `no'. */ return 0;}/* Good below *//* Returns 0 when there is no next */intmulticlass_next_cis (int *cis, int cis_size){ int cisi = cis_size - 1; bow_error ("Not implemented"); while (cis[cisi] < crossbow_root->children_count) { cis[cis_size]++; return 1; } return 0;}intmulticlass_cis_scores_index (int *cis, int cis_size, multiclass_score *scores, int scores_count){ int si, cisi; assert (cis_size <= MAX_NUM_MIXTURE_CLASSES); for (si = 0; si < scores_count; si++) { for (cisi = 0; cisi < cis_size; cisi++) { if (scores[si].c[cisi] != cis[cisi]) goto next_si; else if (cis[cisi] == -1) break; } return si; next_si: } return -1;}/* Allow CIS's that have size 3 or less, or are already in the training data. */intmulticlass_artificially_prune_cis (int *cis, int cis_size){ return (cis_size > 3 && !cmixture_for_cis (cis, cis_size, 0, NULL));}/* Greedily add classes to CIS by P(c|d,\vec{c}). */intmulticlass_explore_cis_greedy0 (crossbow_doc *doc, multiclass_score *scores, int *scores_count, int scores_capacity, const int *cis, int cis_size, int cis_capacity, const int *exclude_cis, int exclude_cis_size, int exclude_cis_capacity){ int nc = crossbow_root->children_count; int *local_cis, local_cis_size, *local_exclude_cis, local_exclude_cis_size; int cisi, ci, ci2, si, max_si = -1; int max_ci, max_ci2; double max_score; local_cis_size = cis_size + 1; if (local_cis_size > MAX_NUM_MIXTURE_CLASSES) return 0; local_exclude_cis = alloca (exclude_cis_capacity * sizeof(int)); local_cis = alloca (cis_capacity * sizeof(int)); local_exclude_cis_size = exclude_cis_size; for (cisi = 0; cisi < exclude_cis_capacity; cisi++) local_exclude_cis[cisi] = exclude_cis[cisi]; for (cisi = 0; cisi < cis_capacity; cisi++) local_cis[cisi] = cis[cisi]; max_score = -FLT_MAX; max_ci = -1; for (ci = 0; ci < nc; ci++) { if (crossbow_root->children[ci]->prior == 0) goto next_class1; for (cisi = 0; cisi < exclude_cis_size; cisi++) if (exclude_cis[cisi] == ci) goto next_class1; for (cisi = 0; cisi < cis_size; cisi++) if (cis[cisi] == ci) goto next_class1; /* Copy the old CIS into LOCAL_CIS, plus the new class */ for (cisi = 0; cisi < cis_size; cisi++) local_cis[cisi] = cis[cisi]; local_cis[cis_size] = ci; qsort (local_cis, local_cis_size, sizeof (int), compare_ints); if ((si = multiclass_cis_scores_index (local_cis, local_cis_size, scores, *scores_count)) == -1 && !multiclass_artificially_prune_cis (local_cis, local_cis_size)) { for (cisi = 0; cisi < MAX_NUM_MIXTURE_CLASSES; cisi++) scores[*scores_count].c[cisi] = local_cis[cisi]; scores[*scores_count].score = multiclass_log_prob_of_classes_given_doc (local_cis, local_cis_size, doc); if (scores[*scores_count].score > max_score) { max_score = scores[*scores_count].score; max_si = *scores_count; max_ci = ci; } (*scores_count)++; assert (*scores_count < scores_capacity); } else if (si != -1 && scores[si].score > max_score) { max_score = scores[si].score; max_si = si; max_ci = ci; } next_class1: } if (local_exclude_cis_size + 1 < exclude_cis_capacity/2 && local_exclude_cis_size < 5 && max_ci >= 0) { /* Do some exploration by making a recursive call that excludes the winner */ local_exclude_cis[local_exclude_cis_size++] = max_ci; assert (local_exclude_cis_size < exclude_cis_capacity); multiclass_explore_cis_greedy0 (doc, scores, scores_count, scores_capacity, cis, cis_size, cis_capacity, local_exclude_cis, local_exclude_cis_size, exclude_cis_capacity); local_exclude_cis_size--; local_exclude_cis[local_exclude_cis_size] = -1; } local_cis_size = cis_size + 2; if (local_cis_size > MAX_NUM_MIXTURE_CLASSES) return 0; max_ci = max_ci2 = -1; for (ci = 0; ci < nc; ci++) { if (crossbow_root->children[ci]->prior == 0) goto next_class2; for (cisi = 0; cisi < exclude_cis_size; cisi++) if (exclude_cis[cisi] == ci) goto next_class2; for (cisi = 0; cisi < cis_size; cisi++) if (cis[cisi] == ci) goto next_class2; for (ci2 = ci+1; ci2 < nc; ci2++) { if (crossbow_root->children[ci2]->prior == 0) goto next_class22; for (cisi = 0; cisi < exclude_cis_size; cisi++) if (exclude_cis[cisi] == ci2) goto next_class22; for (cisi = 0; cisi < cis_size; cisi++) if (cis[cisi] == ci2) goto next_class22; /* Copy the old CIS into LOCAL_CIS, plus the new classes */ for (cisi = 0; cisi < cis_size; cisi++) local_cis[cisi] = cis[cisi]; local_cis[cis_size] = ci; local_cis[cis_size+1] = ci2; qsort (local_cis, local_cis_size, sizeof (int), compare_ints); if ((si = multiclass_cis_scores_index (local_cis, local_cis_size, scores, *scores_count)) == -1 && !multiclass_artificially_prune_cis (local_cis,local_cis_size)) { for (cisi = 0; cisi < MAX_NUM_MIXTURE_CLASSES; cisi++) scores[*scores_count].c[cisi] = local_cis[cisi]; scores[*scores_count].score = multiclass_log_prob_of_classes_given_doc (local_cis, local_cis_size, doc); if (scores[*scores_count].score > max_score) { max_score = scores[*scores_count].score; max_si = *scores_count; max_ci = ci; max_ci2 = ci2; } (*scores_count)++; assert (*scores_count < scores_capacity); } else if (si != -1 && scores[si].score > max_score) { max_score = scores[si].score; max_si = si; max_ci = ci; max_ci2 = ci2; } next_class22: } next_class2: } assert (max_si >= 0); if (local_exclude_cis_size + 2 < exclude_cis_capacity/2 && local_exclude_cis_size < 5 && max_ci >= 0 && max_ci2 >= 0) { /* Do some exploration by making a recursive call that excludes the winner */ local_exclude_cis[local_exclude_cis_size++] = max_ci; local_exclude_cis[local_exclude_cis_size++] = max_ci2; assert (local_exclude_cis_size < exclude_cis_capacity); multiclass_explore_cis_greedy0 (doc, scores, scores_count, scores_capacity, cis, cis_size, cis_capacity, local_exclude_cis, local_exclude_cis_size, exclude_cis_capacity); local_exclude_cis_size--; local_exclude_cis[local_exclude_cis_size] = -1; local_exclude_cis_size--; local_exclude_cis[local_exclude_cis_size] = -1; } /* Make a recursive call */ if (cis_size + 1 < MAX_NUM_MIXTURE_CLASSES) { /* Copy the current highest scorer into LOCAL_CIS */ local_cis_size = MAX_NUM_MIXTURE_CLASSES; for (cisi = 0; cisi < MAX_NUM_MIXTURE_CLASSES; cisi++) { local_cis[cisi] = scores[max_si].c[cisi]; if (local_cis_size == 0 && scores[max_si].c[cisi] == -1) local_cis_size = cisi; } assert (local_cis_size > cis_size); if (local_cis_size < MAX_NUM_MIXTURE_CLASSES) multiclass_explore_cis_greedy0 (doc, scores, scores_count, scores_capacity, local_cis, local_cis_size, cis_capacity,
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -