⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 svm.h

📁 良好的代码实现
💻 H
字号:
// SVM.h: interface for the CSVM class.
//
//////////////////////////////////////////////////////////////////////

#if !defined(AFX_SVM_H__169833C2_5291_4CDB_BE81_22DBD80908F0__INCLUDED_)
#define AFX_SVM_H__169833C2_5291_4CDB_BE81_22DBD80908F0__INCLUDED_

#if _MSC_VER > 1000
#pragma once
#endif // _MSC_VER > 1000

#include "Compute_Param.h"
#include "Compute_Prompt.h"
#include "Compute_Result.h"

# define CFLOAT  float       /* the type of float to use for caching */
/* kernel evaluations. Using float saves */
/* us some memory, but you can use double, too */
# define FNUM    long        /* the type used for storing feature ids */
# define FVAL    float       /* the type used for storing feature values */

# define LINEAR  0           /* linear kernel type */
# define POLY    1           /* polynoial kernel type */
# define RBF     2           /* rbf kernel type */
# define SIGMOID 3           /* sigmoid kernel type */
# define CUSTOM  4

typedef struct word
{
	FNUM    wnum;	
	FVAL    weight;
} SVM_WORD;

typedef struct doc
{
	long    docnum;
	double  twonorm_sq;
	SVM_WORD    *words;
} DOC;

typedef struct learn_parm 
{
	double svm_c;                /* upper bound C on alphas */
	double svm_costratio;        /* factor to multiply C for positive examples */
	double transduction_posratio;/* fraction of unlabeled examples to be */	/* classified as positives */
	long   biased_hyperplane;    /* if nonzero, use hyperplane w*x+b=0 	otherwise w*x=0 */
	long   svm_maxqpsize;        /* size q of working set */
	long   svm_newvarsinqp;      /* new variables to enter the working set in each iteration */
	double epsilon_crit;         /* tolerable error for distances used in stopping criterion */
	double epsilon_shrink;       /* how much a multiplier should be above  zero for shrinking */
	long   svm_iter_to_shrink;   /* iterations h after which an example can be removed by shrinking */
	long   remove_inconsistent;  /* exclude examples with alpha at C and  retrain */
	
								 
	long   skip_final_opt_check; 
								 /* do not check KT-Conditions at the end of optimization for examples removed by  
								 shrinking. WARNING: This might lead to sub-optimal solutions! */
	
	long   compute_loo;          /* if nonzero, computes leave-one-out	 estimates */
	double rho;                  /* parameter in xi/alpha-estimates and for pruning leave-one-out range [1..2] */
	long   xa_depth;             /* parameter in xi/alpha-estimates upper  bounding the number of SV the current alpha_t is distributed over */
	char predfile[200];          /* file for predicitions on unlabeled examples					 in transduction */
	char alphafile[200];         
								/* file to store optimal alphas in. use
									empty string if alphas should not be 	 output */
	
	
	/* you probably do not want to touch the following */
	double epsilon_const;        /* tolerable error on eq-constraint */
	double epsilon_a;            /* tolerable error on alphas at bounds */
	double opt_precision;        /* precision of solver, set to e.g. 1e-21 	 if you get convergence problems */
	
	
	/* the following are only for internal use */
	long   svm_c_steps;          /* do so many steps for finding optimal C */
	double svm_c_factor;         /* increase C by this factor every step */
	double svm_costratio_unlab;
	double svm_unlabbound;
	double *svm_cost;            /* individual upper bounds for each var */
	long   totwords;             /* number of features */
} LEARN_PARM;

typedef struct kernel_parm 
{
	long    kernel_type;   
	long    poly_degree;
	double  rbf_gamma;
	double  coef_lin;
	double  coef_const;
	char    custom[50];    /* for user supplied kernel */
} KERNEL_PARM;

typedef struct model 
{
	long    sv_num;	
	long    at_upper_bound;
	double  b;
	DOC     **supvec;
	double  *alpha;
	long    *index;       /* index from docnum to position in model */
	long    totwords;     /* number of features */
	long    totdoc;       /* number of training documents */
	KERNEL_PARM kernel_parm; /* kernel */
	
	/* the following values are not written to file */
	double  loo_error,loo_recall,loo_precision; /* leave-one-out estimates */
	double  xa_error,xa_recall,xa_precision;    /* xi/alpha estimates */
												double  *lin_weights;  	/* weights for linear case using folding */
} MODEL;

typedef struct quadratic_program 
{
	long   opt_n;            /* number of variables */
	long   opt_m;            /* number of linear equality constraints */
	double *opt_ce,*opt_ce0; /* linear equality constraints */
	double *opt_g;           /* hessian of objective */
	double *opt_g0;          /* linear part of objective */
	double *opt_xinit;       /* initial value for variables */
	double *opt_low,*opt_up; /* box constraints */
} QP;

typedef struct kernel_cache 
{
	long   *index;  /* cache some kernel evalutations */
	CFLOAT *buffer; /* to improve speed */
	long   *invindex;
	long   *active2totdoc;
	long   *totdoc2active;
	long   *lru;
	long   *occu;
	long   elems;
	long   max_elems;
	long   time;
	long   activenum;
	long   buffsize;
} KERNEL_CACHE;

 
typedef struct timing_profile 
{
	long   time_kernel;
	long   time_opti;
	long   time_shrink;
	long   time_update;
	long   time_model;
	long   time_check;
	long   time_select;
} TIMING;
 
typedef struct shrink_state 
{
	long   *active;
	long   *inactive_since;
	long   deactnum;
	double **a_history;
} SHRINK_STATE;

typedef struct cache_parm_s {
  KERNEL_CACHE *kernel_cache;
  CFLOAT *cache;
  DOC *docs; 
  long m;
  KERNEL_PARM *kernel_parm;
  long offset,stepsize;
} cache_parm_t;

class CSVM  
{
public: 
	CSVM();
	virtual ~CSVM();
	int svm_learn_main (int);
	int svm_classify (int, double*);
	double svm_classify(DOC &doc);
private:
	void set_learn_parameters(LEARN_PARM* learn_parm,KERNEL_PARM* kernel_parm);
private:   //svm_common.h
	double classify_example(MODEL *, DOC *);
	double classify_example_linear(MODEL *, DOC *);
	CFLOAT kernel(KERNEL_PARM *, DOC *, DOC *); 
	double custom_kernel(KERNEL_PARM *, DOC *, DOC *); 
	double sprod_ss(SVM_WORD *, SVM_WORD *);
	double model_length_s(MODEL *, KERNEL_PARM *);
	void   clear_vector_n(double *, long);
	void   add_vector_ns(double *, SVM_WORD *, double);
	double sprod_ns(double *, SVM_WORD *);
	double sprod_ss1(SVM_WORD *a,SVM_WORD*b,int offset);
	double sprod_ss2(SVM_WORD *a,SVM_WORD*b,int offset);
	void   add_weight_vector_to_linear_model(MODEL *);
	int    read_model(char *, MODEL *, long, long);
	int    read_documents(char *, DOC *, long *, long, long, long *, long *, int);
	int    parse_document(char *, DOC *, long *, long *, long);
	int    nol_ll(char *, long *, long *, long *);
	double ktt(int ta,int tb,double pa[],double pb[]);
	double kt(int t,double ta[],double tb[]);
	double fi(double* tt);
	double fs(double ta[]);
	double sumofword(DOC* a);
	long   minl(long, long);
	long   maxl(long, long);
	long   get_runtime();
	void   *my_malloc(long); 
	void   copyright_notice();
	//void  SetInitParam();         ?????????
	int isnan(double);
private:    //svm_common.h
	void	printm(char*);
	void    printe(char*);
private:    //svm_learn.h
	void   svm_learn(DOC *, long *, long, long, LEARN_PARM *, KERNEL_PARM *, 
		 KERNEL_CACHE *, MODEL *);
	long   optimize_to_convergence(DOC *, long *, long, long, LEARN_PARM *,
			       KERNEL_PARM *, KERNEL_CACHE *, SHRINK_STATE *,
			       MODEL *, long *, long *, double *, double *, 
			       TIMING *, double *, long, long);
	double compute_objective_function(double *, double *, long *, long *);
	void   clear_index(long *);
	void   add_to_index(long *, long);
	long   compute_index(long *,long, long *);
	void   optimize_svm(DOC *, long *, long *, long *, long *, MODEL *, long, 
		    long *, long, double *, double *, LEARN_PARM *, CFLOAT *, 
		    KERNEL_PARM *, QP *, double *);
	void   compute_matrices_for_optimization(DOC *, long *, long *, long *, 
					 long *, long *, MODEL *, double *, 
					 double *, long, long, LEARN_PARM *, 
					 CFLOAT *, KERNEL_PARM *, QP *);
	long   calculate_svm_model(DOC *, long *, long *, double *, double *, 
			   double *, LEARN_PARM *, long *, MODEL *);
	long   check_optimality(MODEL *, long *, long *, double *, double *,long, 
			LEARN_PARM *,double *, double, long *, long *, long *,
			long *, long, KERNEL_PARM *);
	long   identify_inconsistent(double *, long *, long *, long, LEARN_PARM *, 
			     long *, long *);
	long   identify_misclassified(double *, long *, long *, long,
			      MODEL *, long *, long *);
	long   identify_one_misclassified(double *, long *, long *, long,
				  MODEL *, long *, long *);
	long   incorporate_unlabeled_examples(MODEL *, long *,long *, long *,
				      double *, double *, long, double *,
				      long *, long *, long, KERNEL_PARM *,
				      LEARN_PARM *);
	void   update_linear_component(DOC *, long *, long *, double *, double *, 
			       long *, long, long, KERNEL_PARM *, 
			       KERNEL_CACHE *, double *,
			       CFLOAT *, double *);
	long   select_next_qp_subproblem_grad(long *, long *, double *, double *, long,
				      long, LEARN_PARM *, long *, long *, 
				      long *, double *, long *, KERNEL_CACHE *,
				      long *, long *);
	long   select_next_qp_subproblem_grad_cache(long *, long *, double *, double *,
					    long, long, LEARN_PARM *, long *, 
					    long *, long *, double *, long *,
					    KERNEL_CACHE *, long *, long *);
	void   select_top_n(double *, long, long *, long);
	void   init_shrink_state(SHRINK_STATE *, long, long);
	void   shrink_state_cleanup(SHRINK_STATE *);
	long shrink_problem(
			/* shrink some variables away */
			/* do the shrinking only if at least minshrink variables can be removed */
			LEARN_PARM *learn_parm,
			SHRINK_STATE *shrink_state,
			long *active2dnum,long iteration,long* last_suboptimal_at,
			long totdoc,long minshrink,
			double *a,
			long *inconsistent);
	void   reactivate_inactive_examples(long *, long *, double *, SHRINK_STATE *,
				    double *, long, long, long, LEARN_PARM *, 
				    long *, DOC *, KERNEL_PARM *,
				    KERNEL_CACHE *, MODEL *, CFLOAT *, 
				    double *, double *);
	/* cache kernel evalutations to improve speed */
	void   get_kernel_row(KERNEL_CACHE *,DOC *, long, long, long *, CFLOAT *, 
		      KERNEL_PARM *);
	void   cache_kernel_row(KERNEL_CACHE *,DOC *, long, KERNEL_PARM *);
	void   cache_multiple_kernel_rows(KERNEL_CACHE *,DOC *, long *, long, 
				  KERNEL_PARM *);
	void   kernel_cache_shrink(KERNEL_CACHE *,long, long, long *);
	void   kernel_cache_init(KERNEL_CACHE *,long, long);
	void   kernel_cache_reset_lru(KERNEL_CACHE *);
	void   kernel_cache_cleanup(KERNEL_CACHE *);
	long   kernel_cache_malloc(KERNEL_CACHE *);
	void   kernel_cache_free(KERNEL_CACHE *,long);
	long   kernel_cache_free_lru(KERNEL_CACHE *);
	CFLOAT *kernel_cache_clean_and_malloc(KERNEL_CACHE *,long);
	long   kernel_cache_touch(KERNEL_CACHE *,long);
	long   kernel_cache_check(KERNEL_CACHE *,long);
	void compute_xa_estimates(
		MODEL *model,                           /* xa-estimate of error rate, */
		long *label,long *unlabeled,long totdoc,          /* recall, and precision      */
		DOC	*docs,       
		double *lin,double *a,                      
		KERNEL_PARM *kernel_parm,
		LEARN_PARM *learn_parm,
		double *error,double *recall,double *precision);
	double xa_estimate_error(MODEL *, long *, long *, long, DOC *, 
			 double *, double *, KERNEL_PARM *, 
			 LEARN_PARM *);
	double xa_estimate_recall(MODEL *, long *, long *, long, DOC *, 
			  double *, double *, KERNEL_PARM *, 
			  LEARN_PARM *);
	double xa_estimate_precision(MODEL *, long *, long *, long, DOC *, 
			     double *, double *, KERNEL_PARM *, 
			     LEARN_PARM *);
	void avg_similarity_of_sv_of_one_class(MODEL *, DOC *, double *, long *, KERNEL_PARM *, double *, double *);
	double most_similar_sv_of_same_class(MODEL *, DOC *, double *, long, long *, KERNEL_PARM *, LEARN_PARM *);
	double distribute_alpha_t_greedily(long *, long, DOC *, double *, long, long *, KERNEL_PARM *, LEARN_PARM *, double);
	double distribute_alpha_t_greedily_noindex(MODEL *, DOC *, double *, long, long *, KERNEL_PARM *, LEARN_PARM *, double); 
	void estimate_transduction_quality(MODEL *, long *, long *, long, DOC *, double *);
	double estimate_margin_vcdim(MODEL *, double, double, KERNEL_PARM *);
	double estimate_sphere(MODEL *, KERNEL_PARM *);
	double estimate_r_delta_average(DOC *, long, KERNEL_PARM *); 
	double estimate_r_delta(DOC *, long, KERNEL_PARM *); 
	double length_of_longest_document_vector(DOC *, long, KERNEL_PARM *); 
	void   write_model(char *, MODEL *);
	void   write_prediction(char *, MODEL *, double *, double *, long *, long *,
			long, LEARN_PARM *);
	void   write_alphas(char *, double *, long *, long);
private:    //svm_hideo.h
	double *optimize_qp(QP *, double *, long, double *, LEARN_PARM *);
	int optimize_hildreth_despo(long,long,double,double,double,long,long,long,double,double *,
			    double *,double *,double *,double *,double *,
			    double *,double *,double *,long *,double *);
	int solve_dual(long,long,double,double,long,double *,double *,double *,
	       double *,double *,double *,double *,double *,double *,
	       double *,double *,double *,double *,long);
	void linvert_matrix(double *, long, double *, double, long *);
	void lprint_matrix(double *, long);
	void ladd_matrix(double *, long, double);
	void lcopy_matrix(double *, long, double *);
	void lswitch_rows_matrix(double *, long, long, long);
	void lswitchrk_matrix(double *, long, long, long);
	double calculate_qp_objective(long, double *, double *, double *);
public:
	CCompute_Prompt com_pro;
	CCompute_Param com_param;
	CCompute_Result com_result;
private:    //svm_hideo.h
	double *primal,*dual;
	long   precision_violations;
	double opt_precision;
	long   maxiter;
	double lindep_sensitivity;
	double *buffer;
	long   *nonoptimal;
	long  smallroundcount;
private:
	char temstr[MAX_PATH*10];
};

#endif // !defined(AFX_SVM_H__169833C2_5291_4CDB_BE81_22DBD80908F0__INCLUDED_)

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -