⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 rocchio-main.c

📁 Ripper 分类算法
💻 C
📖 第 1 页 / 共 2 页
字号:
/* bug ???? cross-validation also somehow seems to throw off the similarity computation*/#include <stdio.h>#include "ripper.h"#include "protos.h"#include "mdb.h"static double binomial_std_err(double,int);static double rocchio_error_rate(vec_t *,char *,BOOL);static void model_rocchio(vec_t *);static int eval_final_rule(vec_t *);static double eval_rule(vec_t *);static double inner_product(vec_t *,double *);static void alloc_wts();/* adjustable parameters */static double Beta;static double Gamma;static BOOL Transform_data;static BOOL Two_prototypes;/* what to optimize */static enum {ACCURACY,F1} Metric;BOOL Predict;BOOL Predict_probabilities;/* for probabilistic prediction, where one * outputs a probability 0-1 */static double Max_sim;#define words(exi) (vref(aval_t,exi->inst,0)->u.set)/* for transforming data */static double *Sum_wt,*Sum_wt_p,*Sum_wt_n;/* interface to cross-validation*/static void train_rocchio(vec_t *data){    model_rocchio(data);}static double test_rocchio(vec_t *data){    return rocchio_error_rate(data,"foobar",FALSE);}/******************************************************************************/char *Program="rocchio";char *Help_str[] = {    "syntax: rocchio [options] [stem]",    "   use Rocchio's method to learn a classifier",    "",    "options are:",    "  -s:       read from std input",    "  -kN       estimate error rate by N-fold cross-validation",       "  -l        estimate error rate via leave-one-out method",    "  -Ln       set loss ratio",    "  -2        use two prototype vectors",    "  -Pn       set factor for positive prototype (default 16)",    "  -Nn       set factor for positive prototype (default 4)",    "  -v#:      set verbosity",    "  -m s:     set threshold to optimize metric s",    "              s must be 'accuracy' or 'f-measure' (default 'accuracy')",    "  -p:       print predictions for test cases to stem.pred",    "  -X:       transform data by adding distance to prototype",    "             (creates stem_cos.data, stem_cos.test)",    "  -R:       rank test cases with 'probabilistic' predictions",    NULL};main(argc,argv)int argc;char *argv[];{    vec_t *train_data;    vec_t *test_data;    char *stem=NULL;    BOOL use_stdin;    int o;    int folds;    double err;     BOOL crossval;    double loss_ratio;    example_t *exi;    int i;    FILE *fptest,*fptrain;        crossval = FALSE;    Class_ordering = INCFREQ;    /* defaults */    use_stdin = FALSE;    set_trace_level(SUMM);    Metric = ACCURACY;    Predict = FALSE;    FP_cost = FN_cost = 1.0;    Predict_probabilities = FALSE;    Two_prototypes = FALSE;    Beta = 16.0;    Gamma = 4.0;    while ((o=getopt(argc,argv,"stv:m:pP:N:k:lL:hRX2"))!=EOF) {	switch (o) {	  case 'k':	    crossval = TRUE;	    folds = atoi(optarg); 	    break;	  case 'l':	    crossval = TRUE;	    folds = 0;	    break;	  case 's':	    use_stdin = TRUE;	    break;	  case 'R':	    Predict = TRUE;	    Predict_probabilities = TRUE;	    break;	  case 'X':	    Transform_data = TRUE;	    break;	  case 'p':	    Predict = TRUE;	    break;	  case 'm':	    if (optarg[0]=='a' || optarg[0]=='A') {		Metric = ACCURACY;	    } else {		Metric = F1;	    }	  case 'L':	    loss_ratio = atof(optarg);	    FP_cost = 2.0*loss_ratio/(loss_ratio+1.0);	    FN_cost = 2.0/(loss_ratio+1.0);	    trace(SUMM) {		printf("option: ratio of cost of FP to cost of FN is %g:%g\n",		       FP_cost,FN_cost);	    }	    break;	  case 'P':	    Beta = atof(optarg);	    break;	  case 'N':	    Gamma = atof(optarg);	    break;	  case '2':	    Two_prototypes = TRUE;	    break;	  case 'v':	    set_trace_level(atoi(optarg)); 	    break;	  case '?':	  default: 	    give_help();	    if (o=='h') exit(0);	    else fatal("option not implemented");	}    }    test_data = NULL;    if (optind<argc) {	stem = argv[optind++];	ld_names(add_ext(stem,".names"));	if (use_stdin) train_data = ld_data(NULL);	else {	    train_data = ld_data(add_ext(stem,".data"));	    test_data = ld_data(add_ext(stem,".test"));	}    } else {	train_data = ld_data(NULL);    }    if (optind<argc) {	warning("not all arguments were used: %s ...",argv[optind]);    }    if (!train_data || vmax(train_data)==0) fatal("no examples");    if (!set_field(0) || n_fields()>1) {	fatal("rocchio requires one set-valued field");    }    /* allocate internally used weight vectors */     alloc_wts();    if (crossval) {	cross_validate(folds,train_data,&train_rocchio,&test_rocchio);    } else {	model_rocchio(train_data);	err = rocchio_error_rate(train_data,stem,FALSE);	trace(SUMM) {	    printf("Train error rate:   %.2f%% +/- %.2f%% (%d datapoints)   <<\n",		   100*err,		   100*binomial_std_err(err,vmax(train_data)),		   vmax(train_data));	}	if (test_data) {	    err = rocchio_error_rate(test_data,stem,Predict);	    trace(SUMM) {		printf("Test error rate:   %.2f%% +/- %.2f%% (%d datapoints)   <<\n",		       100*err,		       100*binomial_std_err(err,vmax(test_data)),		       vmax(test_data));	    }	}		if (Transform_data) {	    if (stem==NULL) stem="foo";	    if ((fptrain = fopen(add_ext(stem,"_cos.data"),"w"))==NULL) {		fatal("can't open file for transformed training data");	    }	    if (test_data && 		(fptest = fopen(add_ext(stem,"_cos.test"),"w"))==NULL) {		fatal("can't open file for transformed test data");	    }	    for (i=0; i<vmax(train_data); i++) {		exi = vref(example_t,train_data,i);		if (Two_prototypes) {		    fprintf(fptrain,"%g,%g,",			   inner_product(words(exi),Sum_wt_p),			   inner_product(words(exi),Sum_wt_n));		} else {		    fprintf(fptrain,"%g,",eval_rule(words(exi)));		}		fprint_example(fptrain,exi);	    }	    if (test_data) {		for (i=0; i<vmax(test_data); i++) {		    exi = vref(example_t,test_data,i);		    if (Two_prototypes) {			fprintf(fptest,"%g,%g,",				inner_product(words(exi),Sum_wt_p),				inner_product(words(exi),Sum_wt_n));		    } else {			fprintf(fptest,"%g,",eval_rule(words(exi)));		    }		    fprint_example(fptest,exi);		}	    } /* if test data */	} /* if transform data */    } /* else not cross validating */}static double binomial_std_err(double err_rate,int n){    return sqrt( (err_rate)*(1-err_rate)/((double)n-1) );}/* compute error rate and maybe print predictions */static double rocchio_error_rate(vec_t *data,char *stem,BOOL print_predictions){    int i;    example_t *exi;    symbol_t *pos_cl,*neg_cl;    symbol_t *actual,*predicted;    ex_count_t err,tot;    double p;    FILE *fp=NULL;    if (print_predictions) {	fp = fopen(add_ext(stem,".pred"),"w");	if (fp==NULL) 	  error("can't open prediction file %s",add_ext(stem,".pred"));    }    pos_cl = vref(atom_t,Classes,0)->nom;    neg_cl = vref(atom_t,Classes,1)->nom;    trace (LONG)       printf("// evaluating weights on %d examples\n",vmax(data));         err = tot = 0;    for (i=0; i<vmax(data); i++) {	exi = vref(example_t,data,i);	tot += exi->wt;	actual = exi->lab.nom;	predicted = eval_final_rule(words(exi)) ? pos_cl : neg_cl;	if (actual!=predicted) err += exi->wt;	if (fp!=NULL) {	    if (Predict_probabilities) {		p = eval_rule(words(exi))/Max_sim;		fprint_symbol(fp,actual);		fprintf(fp," %.20f ",p);		fprint_symbol(fp,predicted);		fprintf(fp,"\n");	    } else {		fprint_symbol(fp,predicted);		fprintf(fp," 1 1 ");		fprint_symbol(fp,actual);		fprintf(fp,"\n");	    }	}    }    if (fp!=NULL) fclose(fp);    return err/tot;}/****************************************************************************//* implementation of this algorithm stolen wholesale from Rob */#define NUM_THRESH_INC  100#define MAX_LOG_PRECOMP    (256)#ifndef MAXDOUBLE#define MAXDOUBLE MAXREAL#endifstatic int *Seen;static double *Wt;static double *Feature_wt;static double Log_precomp[MAX_LOG_PRECOMP];static double ltc_wt(int, int);static void init_wts(symbol_t *,vec_t *);static void find_threshold(symbol_t *,vec_t *);static void finalize_wts(symbol_t *,vec_t *);static double Final_thresh; /************** classify an instance************/static int eval_final_rule(vec_t *inst){    if (Two_prototypes) {	return (inner_product(inst,Sum_wt_p) >= inner_product(inst,Sum_wt_n));    } else {	return (eval_rule(inst) >= Final_thresh);    }}/************** computes similarity of an instance to the prototype for pos************/static double eval_rule(vec_t *inst) {    return inner_product(inst,Sum_wt);}/************** computes similarity of an instance a given prototype vector************/static double inner_product(vec_t *inst,double *sum_wt){  int k,aix;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -