📄 rocchio-main.c
/* bug ???? cross-validation also somehow seems to throw off the similarity computation */

#include <stdio.h>
#include <stdlib.h>     /* atoi, atof, exit */
#include <math.h>       /* sqrt */
#include "ripper.h"
#include "protos.h"
#include "mdb.h"

static double binomial_std_err(double,int);
static double rocchio_error_rate(vec_t *,char *,BOOL);
static void model_rocchio(vec_t *);
static int eval_final_rule(vec_t *);
static double eval_rule(vec_t *);
static double inner_product(vec_t *,double *);
static void alloc_wts();

/* adjustable parameters */
static double Beta;
static double Gamma;
static BOOL Transform_data;
static BOOL Two_prototypes;

/* what to optimize */
static enum {ACCURACY,F1} Metric;

BOOL Predict;
BOOL Predict_probabilities;

/* for probabilistic prediction, where one
 * outputs a probability 0-1 */
static double Max_sim;

#define words(exi) (vref(aval_t,exi->inst,0)->u.set)

/* for transforming data */
static double *Sum_wt,*Sum_wt_p,*Sum_wt_n;

/* interface to cross-validation */
static void train_rocchio(vec_t *data)
{
    model_rocchio(data);
}

static double test_rocchio(vec_t *data)
{
    return rocchio_error_rate(data,"foobar",FALSE);
}

/******************************************************************************/

char *Program="rocchio";

char *Help_str[] = {
    "syntax: rocchio [options] [stem]",
    "   use Rocchio's method to learn a classifier",
    "",
    "options are:",
    " -s:   read from std input",
    " -kN   estimate error rate by N-fold cross-validation",
    " -l    estimate error rate via leave-one-out method",
    " -Ln   set loss ratio",
    " -2    use two prototype vectors",
    " -Pn   set factor for positive prototype (default 16)",
    " -Nn   set factor for negative prototype (default 4)",
    " -v#:  set verbosity",
    " -m s: set threshold to optimize metric s",
    "       s must be 'accuracy' or 'f-measure' (default 'accuracy')",
    " -p:   print predictions for test cases to stem.pred",
    " -X:   transform data by adding distance to prototype",
    "       (creates stem_cos.data, stem_cos.test)",
    " -R:   rank test cases with 'probabilistic' predictions",
    NULL
};

main(argc,argv)
int argc;
char *argv[];
{
    vec_t *train_data;
    vec_t *test_data;
    char *stem=NULL;
    BOOL use_stdin;
    int o;
    int folds;
    double err;
    BOOL crossval;
    double loss_ratio;
    example_t *exi;
    int i;
    FILE *fptest,*fptrain;

    crossval = FALSE;
    Class_ordering = INCFREQ;

    /* defaults */
    use_stdin = FALSE;
    set_trace_level(SUMM);
    Metric = ACCURACY;
    Predict = FALSE;
    FP_cost = FN_cost = 1.0;
    Predict_probabilities = FALSE;
    Two_prototypes = FALSE;
    Beta = 16.0;
    Gamma = 4.0;

    /* parse command-line options */
    while ((o=getopt(argc,argv,"stv:m:pP:N:k:lL:hRX2"))!=EOF) {
        switch (o) {
        case 'k':
            crossval = TRUE;
            folds = atoi(optarg);
            break;
        case 'l':
            crossval = TRUE;
            folds = 0;
            break;
        case 's':
            use_stdin = TRUE;
            break;
        case 'R':
            Predict = TRUE;
            Predict_probabilities = TRUE;
            break;
        case 'X':
            Transform_data = TRUE;
            break;
        case 'p':
            Predict = TRUE;
            break;
        case 'm':
            if (optarg[0]=='a' || optarg[0]=='A') {
                Metric = ACCURACY;
            } else {
                Metric = F1;
            }
            break;
        case 'L':
            loss_ratio = atof(optarg);
            FP_cost = 2.0*loss_ratio/(loss_ratio+1.0);
            FN_cost = 2.0/(loss_ratio+1.0);
            trace(SUMM) {
                printf("option: ratio of cost of FP to cost of FN is %g:%g\n",
                       FP_cost,FN_cost);
            }
            break;
        case 'P':
            Beta = atof(optarg);
            break;
        case 'N':
            Gamma = atof(optarg);
            break;
        case '2':
            Two_prototypes = TRUE;
            break;
        case 'v':
            set_trace_level(atoi(optarg));
            break;
        case '?':
        default:
            give_help();
            if (o=='h') exit(0);
            else fatal("option not implemented");
        }
    }

    /* load names and data, either from stem.* files or from stdin */
    test_data = NULL;
    if (optind<argc) {
        stem = argv[optind++];
        ld_names(add_ext(stem,".names"));
        if (use_stdin) train_data = ld_data(NULL);
        else {
            train_data = ld_data(add_ext(stem,".data"));
            test_data = ld_data(add_ext(stem,".test"));
        }
    } else {
        train_data = ld_data(NULL);
    }
    if (optind<argc) {
        warning("not all arguments were used: %s ...",argv[optind]);
    }
    if (!train_data || vmax(train_data)==0) fatal("no examples");
    if (!set_field(0) || n_fields()>1) {
        fatal("rocchio requires one set-valued field");
    }

    /* allocate internally used weight vectors */
    alloc_wts();

    if (crossval) {
        cross_validate(folds,train_data,&train_rocchio,&test_rocchio);
    } else {
        model_rocchio(train_data);
        err = rocchio_error_rate(train_data,stem,FALSE);
        trace(SUMM) {
            printf("Train error rate: %.2f%% +/- %.2f%% (%d datapoints) <<\n",
                   100*err,
                   100*binomial_std_err(err,vmax(train_data)),
                   vmax(train_data));
        }
        if (test_data) {
            err = rocchio_error_rate(test_data,stem,Predict);
            trace(SUMM) {
                printf("Test error rate: %.2f%% +/- %.2f%% (%d datapoints) <<\n",
                       100*err,
                       100*binomial_std_err(err,vmax(test_data)),
                       vmax(test_data));
            }
        }
        if (Transform_data) {
            /* write each example prefixed by its similarity to the prototype(s) */
            if (stem==NULL) stem="foo";
            if ((fptrain = fopen(add_ext(stem,"_cos.data"),"w"))==NULL) {
                fatal("can't open file for transformed training data");
            }
            if (test_data
                && (fptest = fopen(add_ext(stem,"_cos.test"),"w"))==NULL) {
                fatal("can't open file for transformed test data");
            }
            for (i=0; i<vmax(train_data); i++) {
                exi = vref(example_t,train_data,i);
                if (Two_prototypes) {
                    fprintf(fptrain,"%g,%g,",
                            inner_product(words(exi),Sum_wt_p),
                            inner_product(words(exi),Sum_wt_n));
                } else {
                    fprintf(fptrain,"%g,",eval_rule(words(exi)));
                }
                fprint_example(fptrain,exi);
            }
            if (test_data) {
                for (i=0; i<vmax(test_data); i++) {
                    exi = vref(example_t,test_data,i);
                    if (Two_prototypes) {
                        fprintf(fptest,"%g,%g,",
                                inner_product(words(exi),Sum_wt_p),
                                inner_product(words(exi),Sum_wt_n));
                    } else {
                        fprintf(fptest,"%g,",eval_rule(words(exi)));
                    }
                    fprint_example(fptest,exi);
                }
            } /* if test data */
        } /* if transform data */
    } /* else not cross validating */
}

static double binomial_std_err(double err_rate,int n)
{
    return sqrt( (err_rate)*(1-err_rate)/((double)n-1) );
}

/* compute error rate and maybe print predictions */
static double rocchio_error_rate(vec_t *data,char *stem,BOOL print_predictions)
{
    int i;
    example_t *exi;
    symbol_t *pos_cl,*neg_cl;
    symbol_t *actual,*predicted;
    ex_count_t err,tot;
    double p;
    FILE *fp=NULL;

    if (print_predictions) {
        fp = fopen(add_ext(stem,".pred"),"w");
        if (fp==NULL) error("can't open prediction file %s",add_ext(stem,".pred"));
    }
    pos_cl = vref(atom_t,Classes,0)->nom;
    neg_cl = vref(atom_t,Classes,1)->nom;
    trace(LONG) printf("// evaluating weights on %d examples\n",vmax(data));
    err = tot = 0;
    for (i=0; i<vmax(data); i++) {
        exi = vref(example_t,data,i);
        tot += exi->wt;
        actual = exi->lab.nom;
        predicted = eval_final_rule(words(exi)) ? pos_cl : neg_cl;
        if (actual!=predicted) err += exi->wt;
        if (fp!=NULL) {
            if (Predict_probabilities) {
                p = eval_rule(words(exi))/Max_sim;
                fprint_symbol(fp,actual);
                fprintf(fp," %.20f ",p);
                fprint_symbol(fp,predicted);
                fprintf(fp,"\n");
            } else {
                fprint_symbol(fp,predicted);
                fprintf(fp," 1 1 ");
                fprint_symbol(fp,actual);
                fprintf(fp,"\n");
            }
        }
    }
    if (fp!=NULL) fclose(fp);
    return err/tot;
}

/****************************************************************************/
/* implementation of this algorithm stolen wholesale from Rob */

#define NUM_THRESH_INC 100
#define MAX_LOG_PRECOMP (256)
#ifndef MAXDOUBLE
#define MAXDOUBLE MAXREAL
#endif

static int *Seen;
static double *Wt;
static double *Feature_wt;
static double Log_precomp[MAX_LOG_PRECOMP];

static double ltc_wt(int, int);
static void init_wts(symbol_t *,vec_t *);
static void find_threshold(symbol_t *,vec_t *);
static void finalize_wts(symbol_t *,vec_t *);

static double Final_thresh;

/**************
  classify an instance
**************/
static int eval_final_rule(vec_t *inst)
{
    if (Two_prototypes) {
        return (inner_product(inst,Sum_wt_p) >= inner_product(inst,Sum_wt_n));
    } else {
        return (eval_rule(inst) >= Final_thresh);
    }
}

/**************
  computes similarity of an instance to the prototype for pos
**************/
static double eval_rule(vec_t *inst)
{
    return inner_product(inst,Sum_wt);
}

/**************
  computes similarity of an instance to a given prototype vector
**************/
static double inner_product(vec_t *inst,double *sum_wt)
{
    int k,aix;
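For reference, here is a minimal sketch of the prototype construction that Beta (-P) and Gamma (-N) parameterize: a standard Rocchio prototype is the Beta-scaled centroid of the positive examples minus the Gamma-scaled centroid of the negatives, and an instance is scored by its inner product with that prototype. The listing above does not show model_rocchio() or its ltc term weighting, so the dense-array representation, the uniform averaging, and the names rocchio_prototype/rocchio_score below are illustrative assumptions, not the program's actual implementation.

#include <stddef.h>

/* prototype[j] = beta * mean_j(positives) - gamma * mean_j(negatives);
 * pos/neg are row-major arrays of n_pos (resp. n_neg) instances of length dim */
void rocchio_prototype(const double *pos, size_t n_pos,
                       const double *neg, size_t n_neg,
                       size_t dim, double beta, double gamma,
                       double *prototype)
{
    size_t i, j;
    for (j = 0; j < dim; j++) prototype[j] = 0.0;
    for (i = 0; i < n_pos; i++)
        for (j = 0; j < dim; j++)
            prototype[j] += beta * pos[i*dim + j] / (double)n_pos;
    for (i = 0; i < n_neg; i++)
        for (j = 0; j < dim; j++)
            prototype[j] -= gamma * neg[i*dim + j] / (double)n_neg;
}

/* similarity of one instance to the prototype, the dense analogue of what
 * eval_rule()/inner_product() compute for the sparse set-valued instances above */
double rocchio_score(const double *x, const double *prototype, size_t dim)
{
    double s = 0.0;
    size_t j;
    for (j = 0; j < dim; j++)
        s += x[j] * prototype[j];
    return s;
}

With the defaults above (Beta=16, Gamma=4) the positive centroid dominates the prototype; eval_final_rule() then either thresholds this score at Final_thresh or, with -2, compares the similarity to the positive prototype against the similarity to a separate negative prototype.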