rocchio-main.c
static double ltc_wt(int f, int i);
static void find_threshold(symbol_t *pos_cl, vec_t *data);

/*************
** evaluates the learned weight function on a document (a vector of
** words)---returns the cosine-normalized dot product of the document's
** ltc weights with the prototype Sum_wt
***********/
static double eval_rule(vec_t *inst)
{
    int k, aix;
    double sum, sq_sum;
    double w;

    /* if inst is document i, set Seen[k] to f_ik for each word w_k */
    for (k=0; k<vmax(inst); k++) {
        aix = (*vref(symbol_t *,inst,k))->index;
        Seen[aix]++;
    }

    /* compute sum_{k=1}^t (w_{ik} * w_{+k}), ie

                          w_{+k} * tf_{ik} * log(N_D/n_k)
       sum_{k=1}^t  ----------------------------------------------
                    sqrt(sum_{j=1}^t [tf_{ij} * log(N_D/n_j)]^2)

       sq_sum is the quantity inside the sqrt, Sum_wt[k] is w_{+k},
       and on iteration k, w is tf_{ik} * log(N_D/n_k) */
    sum = sq_sum = 0.0;
    for (k=0; k<vmax(inst); k++) {
        aix = (*vref(symbol_t *,inst,k))->index;
        if (Seen[aix]) {
            w = ltc_wt(Seen[aix], aix);
            sq_sum += w*w;
            sum += w*Sum_wt[aix];
            Seen[aix] = 0;   /* handle each distinct word only once */
        }
    }
    if (sum==0.0) return 0.0;
    else return sum/sqrt(sq_sum);
}

/****************************************************************************/

/**********************
** top-level learning algorithm
**
** see description in Ittner, Lewis, Ahn 94
**********************/
static void model_rocchio(vec_t *data)
{
    int i,j;
    double m;
    double sq_sum;
    example_t *exi;
    int ajx;
    symbol_t *pos_cl;
    ex_count_t p,n;
    vec_t *seti;

    /* count and order the classes---pos_cl is the minority class */
    if (vmax(Classes)!=2) {
        fatal("this implementation of rocchio only handles two-class problems");
    }
    count_classes(data);
    reorder_classes(data);
    pos_cl = vref(atom_t,Classes,0)->nom;
    p = Class_counts[0];
    n = Class_counts[1];
    trace(SUMM) {
        printf("// %g examples (%g/%g)\n",p+n,p,n);
        printf("// class '%s' is treated as positive\n",pos_cl->name);
        fflush(stdout);
    }

    /* initialize the weights---for each word index i, initialize Seen[i]
       and Sum_wt[i] to zero, and set Feature_wt[i] to log(N_D/n_i) */
    trace(SUMM) { printf("// initializing weights...\n"); fflush(stdout); }
    init_wts(pos_cl,data);

    trace(SUMM) { printf("// setting weights...\n"); fflush(stdout); }

    /* for each word j set Sum_wt[j] to w'_{+j} where
       w_{+j} =  beta/|R_+| * sum_{i in R_+} w_ij
               - gamma/|R_-| * sum_{i in R_-} w_ij */
    for (i=0; i<vmax(data); i++) {
        exi = vref(example_t,data,i);
        seti = words(exi);

        /* count occurrences of each word j in document i,
           ie set Seen[j] to f_ij */
        for (j=0; j<vmax(seti); j++) {
            ajx = (*vref(symbol_t *,seti,j))->index;
            Seen[ajx]++;
        }

        /* set Wt[j] = tf_{ij} * log(N_D/n_j) and sq_sum to
           sum_{j=1}^t [ tf_{ij} * log(N_D/n_j) ]^2 */
        sq_sum = 0.0;
        for (j=0; j<vmax(seti); j++) {
            ajx = (*vref(symbol_t *,seti,j))->index;
            if (Seen[ajx] > 0) {
                Wt[ajx] = ltc_wt(Seen[ajx], ajx);
                sq_sum += Wt[ajx]*Wt[ajx];
                Seen[ajx] = -1;   /* only update each word once */
            }
        }

        /* increment Sum_wt appropriately */
        if (Two_prototypes) {
            m = (exi->lab.nom==pos_cl
                 ? 1.0 / (p * sqrt(sq_sum))
                 : 1.0 / (n * sqrt(sq_sum)));
        } else {
            m = (exi->lab.nom==pos_cl
                 ?   Beta / (p * sqrt(sq_sum))
                 : -Gamma / (n * sqrt(sq_sum)));
        }
        for (j=0; j<vmax(seti); j++) {
            ajx = (*vref(symbol_t *,seti,j))->index;
            if (Seen[ajx]) {
                Sum_wt[ajx] += m * Wt[ajx];
                if (Two_prototypes) {
                    if (exi->lab.nom==pos_cl) {
                        Sum_wt_p[ajx] += m * Wt[ajx];
                    } else {
                        Sum_wt_n[ajx] += m * Wt[ajx];
                    }
                }
                Seen[ajx] = 0;
            }
        }
    }

    trace(SUMM) { printf("// finalizing weights...\n"); fflush(stdout); }

    /* set Sum_wt[k] to w_{+k} and find the threshold */
    finalize_wts(pos_cl,data);
}

/*************
** computes (log(f)+1)*log(number of docs / number of docs with feature i)
***********/
static double ltc_wt(int f, int i)
{
    return (f < MAX_LOG_PRECOMP ? Log_precomp[f] : 1.0 + log(f))
           * Feature_wt[i];
}
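/*************
** The two routines above implement standard "ltc"-style TF-IDF scoring:
** a document's similarity to the prototype is the cosine-normalized dot
** product of its (1 + log tf) * idf weights.  The sketch below is a
** minimal, self-contained restatement of that computation on plain
** arrays; it is illustrative only and not part of this program---the
** names (demo_ltc_wt, demo_similarity, N_TERMS) and all the data are
** made up.  Compile with -DROCCHIO_LTC_DEMO to build it alone.
***********/
#ifdef ROCCHIO_LTC_DEMO
#include <math.h>
#include <stdio.h>

#define N_TERMS 4

/* (1 + log tf) * log(N_docs/df): the quantity ltc_wt() returns,
   without the Log_precomp lookup table */
static double demo_ltc_wt(int tf, double n_docs, double df)
{
    return (1.0 + log((double) tf)) * log(n_docs / df);
}

/* cosine-normalized dot product of a document's ltc weights with a
   prototype vector---the quantity eval_rule() returns */
static double demo_similarity(const int tf[], const double df[],
                              double n_docs, const double proto[])
{
    double w, sum = 0.0, sq_sum = 0.0;
    int k;

    for (k = 0; k < N_TERMS; k++) {
        if (tf[k] == 0) continue;
        w = demo_ltc_wt(tf[k], n_docs, df[k]);
        sq_sum += w * w;
        sum += w * proto[k];
    }
    return (sum == 0.0) ? 0.0 : sum / sqrt(sq_sum);
}

int main(void)
{
    /* made-up corpus statistics: 100 docs, per-term document frequencies */
    double df[N_TERMS]    = { 50.0, 10.0, 2.0, 80.0 };
    int    tf[N_TERMS]    = { 3, 0, 2, 1 };          /* counts in one doc */
    double proto[N_TERMS] = { 0.1, 0.0, 0.9, 0.05 }; /* a learned prototype */

    printf("similarity = %g\n", demo_similarity(tf, df, 100.0, proto));
    return 0;
}
#endif /* ROCCHIO_LTC_DEMO */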
/*************
** Allocates the weight vectors
***********/
static void alloc_wts()
{
    Seen = newmem(n_symbolic_values(),int);
    Wt = newmem(n_symbolic_values(),double);
    Sum_wt = newmem(n_symbolic_values(),double);
    if (Two_prototypes) {
        Sum_wt_p = newmem(n_symbolic_values(),double);
        Sum_wt_n = newmem(n_symbolic_values(),double);
    }
    Feature_wt = newmem(n_symbolic_values(),double);
}

/*************
** Initializes the weights of the weight vector.
***********/
static void init_wts(symbol_t *pos_cl,vec_t *data)
{
    int i;
    ex_count_t p,n;

    for (i=0; i<n_symbolic_values(); i++) {
        Seen[i] = 0;
        Wt[i] = Feature_wt[i] = 0.0;
        Sum_wt[i] = 0.0;
        if (Two_prototypes) {
            Sum_wt_p[i] = Sum_wt_n[i] = 0.0;
        }
    }
    compute_field_stats(pos_cl,0,data);
    for (i = 0; i < n_visited_symbols(); i++) {
        pos_field_stat(visited_symbol(i),&p,&n);
        Seen[i] = 0;
        Sum_wt[i] = 0.0;
        if (Two_prototypes) {
            Sum_wt_p[i] = Sum_wt_n[i] = 0.0;
        }
        if (p+n > 0) {
            Feature_wt[i] = log(((double) vmax(data))/(p+n));
        }
    }
    Log_precomp[0] = 0.0;
    for (i = 1; i < MAX_LOG_PRECOMP; i++) {
        Log_precomp[i] = 1.0 + log((double) i);
    }
}

/*************
** Finalizes weights for use by final evaluation rule.
***********/
static void finalize_wts(symbol_t *pos_cl,vec_t *data)
{
    int i;

    for (i = 0; i < n_symbolic_values(); i++) {
        if (Sum_wt[i] < 0.0) {
            Sum_wt[i] = 0.0;
        }
        if (Two_prototypes) {
            if (Sum_wt_p[i] < 0.0) { Sum_wt_p[i] = 0.0; }
            if (Sum_wt_n[i] < 0.0) { Sum_wt_n[i] = 0.0; }
        }
    }
    find_threshold(pos_cl,data);
}
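/*************
** model_rocchio() and finalize_wts() above together compute the classic
** Rocchio prototype: each cosine-normalized document vector is added
** with coefficient beta/|R+| if relevant and -gamma/|R-| otherwise, and
** negative components are then clipped to zero.  The sketch below
** restates that single-prototype update on dense arrays; it is
** illustrative only---demo_prototype is a hypothetical name and BETA
** and GAMMA are arbitrary example coefficients, not values taken from
** this program.  Compile with -DROCCHIO_PROTO_DEMO to include it.
***********/
#ifdef ROCCHIO_PROTO_DEMO
#include <math.h>

#define N_TERMS 4
#define BETA  16.0
#define GAMMA  4.0

/* wt[i][j]: ltc weight of term j in document i; pos[i]: nonzero iff
   document i is in the positive (relevant) class R+ */
static void demo_prototype(int n_docs, double wt[][N_TERMS],
                           const int pos[], double proto[])
{
    int i, j, n_pos = 0, n_neg = 0;
    double norm, m;

    for (j = 0; j < N_TERMS; j++) proto[j] = 0.0;
    for (i = 0; i < n_docs; i++) {
        if (pos[i]) n_pos++; else n_neg++;
    }
    for (i = 0; i < n_docs; i++) {
        /* squared Euclidean norm of document i's weight vector */
        norm = 0.0;
        for (j = 0; j < N_TERMS; j++) norm += wt[i][j] * wt[i][j];
        if (norm == 0.0) continue;
        /* beta/|R+| for relevant docs, -gamma/|R-| for the rest, with
           the document vector normalized to unit length */
        m = pos[i] ?  BETA / (n_pos * sqrt(norm))
                   : -GAMMA / (n_neg * sqrt(norm));
        for (j = 0; j < N_TERMS; j++) proto[j] += m * wt[i][j];
    }
    /* clip negative components, as finalize_wts() does */
    for (j = 0; j < N_TERMS; j++) {
        if (proto[j] < 0.0) proto[j] = 0.0;
    }
}
#endif /* ROCCHIO_PROTO_DEMO */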
/**************
** finds best threshold for weight function
************/
typedef struct pair_s {
    double val;
    BOOL lab;
} pair_t;

static int compare_pair(const void *a, const void *b)
{
    const pair_t *pi = (const pair_t *) a;
    const pair_t *pj = (const pair_t *) b;

    if (pi->val > pj->val) return 1;
    else if (pi->val < pj->val) return -1;
    else return 0;
}

static void find_threshold(symbol_t *pos_cl,vec_t *data)
{
    int t;
    example_t *ext;
    pair_t *pair;
    double bestf,besty,f;
    double y;
    int c[2][2],n_pos,n_neg;

    /* allocate and sort pairs, and count # pos/neg examples */
    n_pos = n_neg = 0;
    pair = newmem(vmax(data),pair_t);
    for (t=0; t<vmax(data); t++) {
        ext = vref(example_t,data,t);
        pair[t].val = eval_rule(words(ext));
        pair[t].lab = (ext->lab.nom == pos_cl);
        if (pair[t].lab) n_pos++;
        else n_neg++;
    }
    qsort(pair,vmax(data),sizeof(pair_t),compare_pair);

    /* find greatest similarity value for the training data */
    Max_sim = (vmax(data) > 0) ? pair[vmax(data)-1].val : 0.0;

    trace(DBG1) {
        printf("// sorted pairs:\n");
        for (t=0; t<vmax(data); t++) {
            printf("// %3d: (%g,%c)\n",t,pair[t].val,"-+"[pair[t].lab]);
        }
    }

    bestf = -MAXREAL;
    besty = 0.0;   /* fallback threshold if data is empty */

    /* c[act][pred] is initially the confusion matrix for the lowest
       possible threshold---ie all examples predicted positive */
    c[0][0] = 0;      /* # neg < threshold y */
    c[1][0] = 0;      /* # pos < threshold y */
    c[0][1] = n_neg;  /* # neg >= threshold y */
    c[1][1] = n_pos;  /* # pos >= threshold y */
    for (t=0; t<vmax(data); ) {
        y = pair[t].val;
        /* see if the current confusion matrix is the best so far */
        if (Metric==F1) {
            f = (2.0*c[1][1])/(2*c[1][1]+c[1][0]+c[0][1]);
        } else if (Metric==ACCURACY) {
            f = -(FP_cost * c[0][1]) - (FN_cost * c[1][0]);
        } else {
            fatal("unknown optimization metric!");
        }
        if (f>bestf) {
            if (t==0) besty = y;
            else besty = (y + pair[t-1].val)/2.0;
            bestf = f;
            trace(DBUG) { printf("// use=T "); }
        } else {
            trace(DBUG) { printf("// use=F "); }
        }
        /* update the confusion matrix for the next threshold, and advance
           t to point to that new threshold (with val!=y) */
        while (t<vmax(data) && pair[t].val==y) {
            c[pair[t].lab][0]++;
            c[pair[t].lab][1]--;
            trace(DBUG) {
                printf("%3d: (%g,%c)\tv=%g\t00=%d 01=%d 10=%d 11=%d\n",
                       t,pair[t].val,"-+"[pair[t].lab],f,
                       c[0][0],c[0][1],c[1][0],c[1][1]);
            }
            t++;
        }
    }

    /* also consider the default concept, which rejects everything */
    if (vmax(data)>0) {
        /* make a value bigger than the max similarity */
        y = (pair[vmax(data)-1].val+1.0) * 1.1;
        if (Metric==F1) {
            f = (2.0*c[1][1])/(2*c[1][1]+c[1][0]+c[0][1]);
        } else if (Metric==ACCURACY) {
            f = -(FP_cost * c[0][1]) - (FN_cost * c[1][0]);
        } else {
            fatal("unknown optimization metric!");
        }
        if (f>=bestf) {
            besty = y;
            bestf = f;
            trace(DBUG) {
                printf("// using [all neg] threshold=%g (value %f)\n",
                       besty,bestf);
            }
        }
    }
    trace(LONG) { printf("// best threshold=%g (value %f)\n",besty,bestf); }
    Final_thresh = besty;
    freemem(pair);
}
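/*************
** find_threshold() above picks the decision threshold by sweeping the
** sorted training scores and evaluating each candidate cut with the
** chosen metric.  The sketch below shows the same F1 sweep in
** isolation, assuming scores are already sorted ascending; it omits the
** reject-everything default concept, and demo_best_f1_threshold is a
** hypothetical name.  Compile with -DROCCHIO_THRESH_DEMO to build it.
***********/
#ifdef ROCCHIO_THRESH_DEMO
#include <stdio.h>

/* score[]: ascending similarity scores; lab[]: 1 = positive example.
   Returns the threshold t maximizing F1 of "predict positive iff
   score >= t", placed midway between adjacent distinct scores. */
static double demo_best_f1_threshold(const double score[], const int lab[],
                                     int n)
{
    int t, tp, fp, fn, n_pos = 0;
    double f, bestf = -1.0, besty = 0.0, y;

    for (t = 0; t < n; t++) if (lab[t]) n_pos++;
    /* lowest threshold: everything predicted positive */
    tp = n_pos; fp = n - n_pos; fn = 0;
    for (t = 0; t < n; ) {
        y = score[t];
        f = (2.0 * tp) / (2.0 * tp + fp + fn);
        if (f > bestf) {
            bestf = f;
            besty = (t == 0) ? y : (y + score[t-1]) / 2.0;
        }
        /* move all examples tied at y below the next candidate threshold */
        while (t < n && score[t] == y) {
            if (lab[t]) { tp--; fn++; } else { fp--; }
            t++;
        }
    }
    return besty;
}

int main(void)
{
    /* made-up scores and labels; best cut is between 0.20 and 0.35 */
    double score[] = { 0.05, 0.20, 0.35, 0.60, 0.80 };
    int    lab[]   = { 0, 0, 1, 0, 1 };

    printf("threshold = %g\n", demo_best_f1_threshold(score, lab, 5));
    return 0;
}
#endif /* ROCCHIO_THRESH_DEMO */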