📄 rocchio.c
字号:
/* Copyright (C) 2001-2002 Mikael Ylikoski * See the accompanying file "README" for the full copyright notice *//** * @file * Rocchio learning algorithm. * * @author Mikael Ylikoski * @date 2001-2002 */#include <stdlib.h>#include <string.h>#include "multi.h"#include "rocchio.h"#include "utility.h"#include "vector.h"/** * Rocchio database. */struct rocchio_db_ { vector *df; vector *tf; float inc; /**< Positive learning rate */ float dec; /**< Negative learning rate */ double (*similarity_measure) (vector *u, vector *v); /**< Similarity measure function */};/** * Rocchio class. */typedef struct { vector *class; /**< Average vector */ vector *pos_class; /**< Average vector of positive examples */ vector *neg_class; /**< Average vector of negative examples */ int pos_nod; /**< Number of positive examples */ int neg_nod; /**< Number of negative examples */} rocchio_class;/** * Create a new global state. * Default learning rate is 1 for positive examples and 0 for negative. * * @return The classifier. */void *rocchio_new_db (const char *opts) { rocchio_db *db; db = my_malloc (sizeof(rocchio_db)); db->df = vector_new (100); if (!db->df) { free (db); return NULL; } db->tf = vector_new (100); if (!db->tf) { free (db->df); free (db); return NULL; } /* FIXME read opts */ db->inc = 1; db->dec = 0; db->similarity_measure = vector_cosine_similarity; return db;}/** * Create a new class. * * @return The class. */void *rocchio_new (void) { rocchio_class *rc; rc = my_malloc (sizeof(rocchio_class)); rc->class = NULL; rc->pos_class = NULL; rc->neg_class = NULL; rc->pos_nod = 0; rc->neg_nod = 0; return rc;}/** * Copy a rocchio class. * * @param data class to copy * @return The copy */void *rocchio_copy (void *data) { rocchio_class *rc; rocchio_class *rd; rc = my_malloc (sizeof(rocchio_class)); rd = (rocchio_class *)data; if (rd->pos_class) { rc->pos_class = vector_copy (rd->pos_class); if (!rc->pos_class) { free (rc); return NULL; } } else rc->pos_class = NULL; if (rd->neg_class) { rc->neg_class = vector_copy (rd->neg_class); if (!rc->neg_class) { vector_free (rc->pos_class); free (rc); return NULL; } } else rc->neg_class = NULL; rc->class = NULL; rc->pos_nod = rd->pos_nod; rc->neg_nod = rd->neg_nod; return rc;}/** * Free memory used by rocchio class. * * @param data class to free */voidrocchio_free (void *data) { rocchio_class *rc; rc = (rocchio_class *)data; if (rc->class) vector_free (rc->class); if (rc->pos_class) vector_free (rc->pos_class); if (rc->neg_class) vector_free (rc->neg_class); free (rc);}/** * Change the learning rate of a classifier. * * @param db classifier database * @param inc positive learning rate * @param dec negative learning rate */voidrocchio_set_learning_rate (rocchio_db *db, float inc, float dec) { db->inc = inc; db->dec = dec;}/** * Set similarity measure function for the rocchio classifier. * * @param db classifier database * @param sim_measure measure function */voidrocchio_set_similarity_measure (rocchio_db *db, double (*sim_measure) (vector *, vector *)) { db->similarity_measure = sim_measure;}/** * Learn from a positive example. * * @param db classifier database * @param data class * @param v example vector * @param class example class */introcchio_learn (void *db, void *data, vector *v, int class) { rocchio_class *rc; rocchio_db *rdb; rdb = (rocchio_db *)db; rc = (rocchio_class *)data; if (class == 1) if (rc->pos_class == NULL) { rc->pos_class = vector_copy (v); rc->pos_nod = 1; } else vector_n_mean (rc->pos_class, v, ++rc->pos_nod); else if (rc->neg_class == NULL) { rc->neg_class = vector_copy (v); rc->neg_nod = 1; } else vector_n_mean (rc->neg_class, v, ++rc->neg_nod); if (rc->class) { vector_free (rc->class); rc->class = NULL; } vector_add (rdb->tf, v); vector_add_v (rdb->df, v, 1); return 0;}/** * Update the classification vectors. * Must be called before classification after learning (this is done * automatically). * * @param db classifier database * @param rc class */static voidrocchio_sync_db (rocchio_db *db, rocchio_class *rc) { if (rc->class == NULL) { rc->class = vector_copy (rc->pos_class); vector_scale (rc->class, db->inc); } if (db->dec != 0) if (rc->class == NULL) vector_add_w (rc->class, rc->neg_class, -db->dec);}/** * Classify a vector. * * @param db classifier database * @param data class * @param v vector to classify * @return The most probable class. */doublerocchio_classify (void *db, void *data, vector *v) { rocchio_class *rc; rocchio_db *rdb; rdb = (rocchio_db *)db; rc = (rocchio_class *)data; rocchio_sync_db(rdb, rc); if (rc->class) return rdb->similarity_measure (v, rc->class); return 0;}doublenew_rocchio_classify (void *db, void *data, vector *v) { int i; rocchio_class *rc; rocchio_db *rdb; vector *w; rdb = (rocchio_db *)db; rc = (rocchio_class *)data; rocchio_sync_db (rdb, rc); if (rc->class) { w = vector_copy (rc->class); for (i = 0; i < v->nel; i++) { } return rdb->similarity_measure (v, w); } return 0;}void *rocchio_load_db (FILE *file) { int i; rocchio_db *rdb; rdb = my_malloc (sizeof(rocchio_db)); i = fscanf (file, "inc %f\n", &rdb->inc); i = fscanf (file, "dec %f\n", &rdb->dec); fscanf (file, "df "); rdb->df = vector_load (file); fscanf (file, "\n"); fscanf (file, "tf "); rdb->tf = vector_load (file); fscanf (file, "\n"); rdb->similarity_measure = vector_cosine_similarity; return rdb;}void *rocchio_load_class (FILE *file) { int i; rocchio_class *ncl; ncl = my_malloc (sizeof(rocchio_class)); i = fscanf (file, "p_nod %d\n", &ncl->pos_nod); if (i != 1) { free (ncl); return NULL; } i = fscanf (file, "n_nod %d\n", &ncl->neg_nod); if (i != 1) { free (ncl); return NULL; } fscanf (file, "p_vec "); ncl->pos_class = vector_load (file); fscanf (file, "\n"); fscanf (file, "n_vec "); ncl->neg_class = vector_load (file); fscanf (file, "\n"); ncl->class = NULL; return ncl;}introcchio_save_db (FILE *file, void *db) { rocchio_db *rdb; rdb = (rocchio_db *)db; fprintf (file, "inc %f\n", rdb->inc); fprintf (file, "dec %f\n", rdb->dec); fprintf (file, "df "); vector_save (rdb->df, file); fprintf (file, "\n"); fprintf (file, "tf "); vector_save (rdb->tf, file); fprintf (file, "\n"); return 0;}introcchio_save_class (FILE *file, void *data) { rocchio_class *ncl; ncl = (rocchio_class *)data; fprintf (file, "p_nod %d\n", ncl->pos_nod); fprintf (file, "n_nod %d\n", ncl->neg_nod); fprintf (file, "p_vec "); vector_save (ncl->pos_class, file); fprintf (file, "\n"); fprintf (file, "n_vec "); vector_save (ncl->neg_class, file); fprintf (file, "\n"); return 0;}/** * Keep cygwin happy. */intmain (void) { return 0;}/** * Rocchio classifier name. */const char *my_classifier_name = "Rocchio";/** * Rocchio classifier functions. */const multi_functions my_functions = { .new_db = rocchio_new_db, .new = rocchio_new, .copy = rocchio_copy, .free = rocchio_free, .learn = rocchio_learn, .classify = rocchio_classify, .load_db = rocchio_load_db, .load_class = rocchio_load_class, .save_db = rocchio_save_db, .save_class = rocchio_save_class, .option = 0};
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -