📄 icsiboost.c
字号:
int number_of_workers=1;#endif string_t* stem=NULL; array_t* args=string_argv_to_array(argc,argv); string_t* arg=NULL; while((arg=(string_t*)array_shift(args))!=NULL) { if(string_eq_cstr(arg,"-n")) { string_free(arg); arg=(string_t*)array_shift(args); if(arg==NULL)die("value needed for -n"); maximum_iterations=string_to_int32(arg); if(maximum_iterations<=0)die("invalid value for -n [%s]",arg->data); } else if(string_eq_cstr(arg,"--cutoff")) { //die("feature count cutoff not supported yet"); string_free(arg); arg=(string_t*)array_shift(args); if(arg==NULL)die("value needed for -f"); feature_count_cutoff=string_to_int32(arg); if(feature_count_cutoff<=0)die("invalid value for -f [%s]",arg->data); } else if(string_eq_cstr(arg,"-E")) { string_free(arg); arg=(string_t*)array_shift(args); if(arg==NULL)die("value needed for -E"); smoothing=string_to_double(arg); if(isnan(smoothing) || smoothing<=0)die("invalid value for -E [%s]",arg->data); } else if(string_eq_cstr(arg,"--jobs")) {#ifdef USE_THREADS string_free(arg); arg=(string_t*)array_shift(args); if(arg==NULL)die("value needed for --jobs"); number_of_workers=string_to_int32(arg); if(number_of_workers<=0)die("invalid value for -jobs [%s]",arg->data);#else die("thread support has not been activated at compile time");#endif } else if(string_eq_cstr(arg,"-V")) { verbose=1; } else if(string_eq_cstr(arg,"-S") && stem==NULL) { string_free(arg); arg=(string_t*)array_shift(args); if(arg==NULL)die("value needed for -S"); stem=string_copy(arg); } else if(string_eq_cstr(arg,"--names")) { string_free(arg); arg=(string_t*)array_shift(args); if(arg==NULL)die("file name expected after --names"); names_filename=string_copy(arg); } else if(string_eq_cstr(arg,"--ignore")) { string_free(arg); arg=(string_t*)array_shift(args); ignore_columns=string_split(arg,",",NULL); } else if(string_eq_cstr(arg,"--version")) { print_version(argv[0]); exit(0); } else if(string_eq_cstr(arg,"-C")) { classification_mode=1; } else if(string_eq_cstr(arg,"-o")) { classification_output=1; } else if(string_eq_cstr(arg,"--model")) { string_free(arg); arg=(string_t*)array_shift(args); model_name=string_copy(arg); } else if(string_eq_cstr(arg,"--train")) { string_free(arg); arg=(string_t*)array_shift(args); data_filename=string_copy(arg); } else if(string_eq_cstr(arg,"--do-not-pack-model")) { pack_model=0; } else if(string_eq_cstr(arg,"--output-weights")) { output_weights=1; } else if(string_eq_cstr(arg,"--dryrun")) { dryrun_mode=1; } else usage(argv[0]); string_free(arg); } array_free(args); if(stem==NULL)usage(argv[0]); int i; // data structures vector_t* templates = vector_new(16); vector_t* classes = NULL; // read names file if(names_filename==NULL) { names_filename = string_copy(stem); string_append_cstr(names_filename, ".names"); } mapped_t* input = mapped_load_readonly(names_filename->data); hashtable_t* templates_by_name=hashtable_new(); if(input == NULL) die("can't load \"%s\"", names_filename->data); string_t* line = NULL; int line_num = 0; while((line = mapped_readline(input)) != NULL) // should add some validity checking !!! { if(string_match(line,"^(\\|| *$)","n")) // skip comments and blank lines { string_free(line); continue; } if(classes != NULL) // this line contains a template definition { array_t* parts = string_split(line, "(^ +| *: *| *\\.$)", NULL); template_t* template = (template_t*)MALLOC(sizeof(template_t)); template->column = line_num-1; template->name = (string_t*)array_get(parts, 0); string_t* type = (string_t*)array_get(parts, 1); template->dictionary = hashtable_new(); template->tokens = vector_new(16); template->values = vector_new_float(16); template->classifiers = NULL; //template->dictionary_counts = vector_new(16); template->ordered=NULL; tokeninfo_t* unknown_token=(tokeninfo_t*)MALLOC(sizeof(tokeninfo_t)); unknown_token->id=0; unknown_token->key=strdup("?"); unknown_token->count=0; unknown_token->examples=NULL; vector_push(template->tokens,unknown_token); if(!strcmp(type->data, "continuous")) template->type = FEATURE_TYPE_CONTINUOUS; else if(!strcmp(type->data, "text")) template->type = FEATURE_TYPE_TEXT; else if(!strcmp(type->data, "scored text")) template->type = FEATURE_TYPE_IGNORE; else if(!strcmp(type->data, "ignore")) template->type = FEATURE_TYPE_IGNORE; else template->type = FEATURE_TYPE_SET; if(template->type == FEATURE_TYPE_SET) { array_t* values = string_split(type,"(^ +| *, *| *\\.$)", NULL); if(values->length <= 1)die("invalid column definition \"%s\", line %d in %s", line->data, line_num+1, names_filename->data); for(i=0; i<values->length; i++) { string_t* value=(string_t*)array_get(values,i); tokeninfo_t* tokeninfo=(tokeninfo_t*)MALLOC(sizeof(tokeninfo_t)); tokeninfo->id=i+1; // skip unknown value (?) tokeninfo->key=strdup(value->data); tokeninfo->count=0; tokeninfo->examples=vector_new_int32_t(16); hashtable_set(template->dictionary, value->data, value->length, tokeninfo); vector_push(template->tokens,tokeninfo); } string_array_free(values); } if(hashtable_exists(templates_by_name, template->name->data, template->name->length)!=NULL) die("duplicate feature name \"%s\", line %d in %s",template->name->data, line_num+1, names_filename->data); vector_push(templates, template); hashtable_set(templates_by_name, template->name->data, template->name->length, template); string_free(type); array_free(parts); //if(verbose)fprintf(stdout,"TEMPLATE: %d %s %d\n",template->column,template->name->data,template->type); } else // first line contains the class definitions { array_t* parts = string_split(line, "(^ +| *, *| *\\.$)", NULL); if(parts->length <= 1)die("invalid classes definition \"%s\", line %d in %s", line->data, line_num+1, names_filename->data); classes = vector_from_array(parts); array_free(parts); /*if(verbose) { fprintf(stdout,"CLASSES:"); for(i=0;i<classes->length;i++)fprintf(stdout," %s",((string_t*)vector_get(classes,i))->data); fprintf(stdout,"\n"); }*/ } string_free(line); line_num++; } if(ignore_columns!=NULL) { for(i=0; i<ignore_columns->length; i++) { string_t* column=(string_t*)array_get(ignore_columns, i); template_t* template=hashtable_exists(templates_by_name, column->data, column->length); if(template!=NULL) { template->type=FEATURE_TYPE_IGNORE; if(verbose>0) { warn("ignoring column \"%s\"", column->data); } } } string_array_free(ignore_columns); } vector_optimize(templates); mapped_free(input); hashtable_free(templates_by_name); string_free(names_filename); if(classification_mode) { vector_t* classifiers=NULL; if(model_name==NULL) { model_name = string_copy(stem); string_append_cstr(model_name, ".shyp"); } classifiers=load_model(templates,classes,model_name->data); double sum_of_alpha=0; int errors=0; int num_examples=0; for(i=0;i<classifiers->length;i++) { weakclassifier_t* classifier=vector_get(classifiers,i); sum_of_alpha+=classifier->alpha; } string_free(model_name); string_t* line=NULL; int line_num=0; while((line=string_readline(stdin))!=NULL) { line_num++; string_chomp(line); if(string_match(line,"^(\\|| *$)","n")) // skip comments and blank lines { string_free(line); continue; } int l; array_t* array_of_tokens=string_split(line, " *, *", NULL); if(array_of_tokens->length<templates->length || array_of_tokens->length>templates->length+1) die("wrong number of columns (%zd), \"%s\", line %d in %s", array_of_tokens->length, line->data, line_num, "stdin"); double score[classes->length]; for(l=0; l<classes->length; l++) score[l]=0.0; for(i=0; i<templates->length; i++) { template_t* template=vector_get(templates, i); string_t* token=array_get(array_of_tokens, i); if(template->type == FEATURE_TYPE_TEXT || template->type == FEATURE_TYPE_SET) { hashtable_t* subtokens=hashtable_new(); if(string_cmp_cstr(token,"?")!=0) { char* subtoken=NULL; for(subtoken=strtok(token->data, " "); subtoken != NULL; subtoken=strtok(NULL, " ")) { tokeninfo_t* tokeninfo=hashtable_get(template->dictionary, subtoken, strlen(subtoken)); if(tokeninfo!=NULL) hashtable_set(subtokens, &tokeninfo->id, sizeof(tokeninfo->id), tokeninfo); } } int j; for(j=0; j<template->classifiers->length; j++) { weakclassifier_t* classifier=vector_get(template->classifiers, j); if(hashtable_get(subtokens, &classifier->token, sizeof(classifier->token))==NULL) for(l=0; l<classes->length; l++) score[l]+=classifier->alpha*classifier->c1[l]; else for(l=0; l<classes->length; l++) score[l]+=classifier->alpha*classifier->c2[l]; } hashtable_free(subtokens); } else if(template->type == FEATURE_TYPE_CONTINUOUS) { float value = NAN; if(string_cmp_cstr(token,"?")!=0)value=string_to_float(token); int j; for(j=0; j<template->classifiers->length; j++) { weakclassifier_t* classifier=vector_get(template->classifiers, j); if(isnan(value)) for(l=0; l<classes->length; l++) score[l]+=classifier->alpha*classifier->c0[l]; else if(value < classifier->threshold) for(l=0; l<classes->length; l++) score[l]+=classifier->alpha*classifier->c1[l]; else for(l=0; l<classes->length; l++) score[l]+=classifier->alpha*classifier->c2[l]; } } // FEATURE_TYPE_IGNORE } string_t* true_class=NULL; vector_t* tokens=array_to_vector(array_of_tokens); if(tokens->length > templates->length) { true_class=vector_get(tokens, tokens->length-1); while(true_class->data[true_class->length-1]=='.') { true_class->data[true_class->length-1]='\0'; true_class->length--; } } for(l=0; l<classes->length; l++)score[l]/=sum_of_alpha; if(!dryrun_mode) { if(classification_output==0) { if(true_class!=NULL) { for(l=0; l<classes->length; l++) { string_t* class=vector_get(classes,l); if(string_cmp(class,true_class)==0) fprintf(stdout,"1 "); else fprintf(stdout,"0 "); } } else { for(l=0; l<classes->length; l++)fprintf(stdout,"? "); } for(l=0; l<classes->length; l++) { fprintf(stdout,"%.12f",score[l]); if(l<classes->length-1)fprintf(stdout," "); } fprintf(stdout,"\n"); } else { fprintf(stdout,"\n\n"); for(i=0; i<templates->length; i++) { template_t* template=vector_get(templates,i); string_t* token=vector_get(tokens,i); fprintf(stdout,"%s: %s\n", template->name->data, token->data); } if(true_class!=NULL) { fprintf(stdout,"correct label = %s \n",true_class->data); for(l=0; l<classes->length; l++)
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -