📄 lr.c
字号:
/* Logistic Regression using Truncated Iteratively Re-weighted Least Squares (includes several programs) Copyright (C) 2005 Paul Komarek This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA Author: Paul Komarek, komarek@cmu.edu Alternate contact: Andrew Moore, awm@cs.cmu.edu*//* File: lr.c Author: Paul Komarek Created: Thu Mar 6 12:14:39 EST 2003 Description: Logistic regression implmentation. Copyright 2003, The Auton Lab, CMU*/#include <stdio.h>#include <float.h>#include <math.h>#include "amma.h"#include "amdyv.h"#include "amdym.h"#include "amdyv_array.h"#include "spardat.h"#include "lrutils.h"#include "lin_conjgrad.h"#include "lr.h"/***********************************************************************//* Option Parsing *//***********************************************************************/lr_options *mk_lr_options(void){ lr_options *opts; opts = AM_MALLOC( lr_options); /* Options which are always available. */ opts->rrlambda = 10.0; /* Termination criteria for lr iterations. */ opts->lreps = 0.05; opts->lrmax = 30; /* cg options are available in conjuagate gradient runs. */ opts->cgbinit = 1; opts->cgdeveps = 0.005; /* suggestion: 0.005 */ opts->cgeps = 0.000; /* multiplied by initial CG rsqr. */ opts->cgmax = 200; opts->cgwindow = 3; /* Number of bad iterations allowed. */ opts->cgdecay = 1000.0; /* Factor worse than best-seen that is allowed. */ return opts;}lr_options *mk_copy_lr_options( lr_options *opts){ lr_options *newopts; newopts = AM_MALLOC( lr_options); newopts->rrlambda = opts->rrlambda; newopts->lreps = opts->lreps; newopts->lrmax = opts->lrmax; newopts->cgbinit = opts->cgbinit; newopts->cgdeveps = opts->cgdeveps; newopts->cgeps = opts->cgeps; newopts->cgmax = opts->cgmax; newopts->cgwindow = opts->cgwindow; newopts->cgdecay = opts->cgdecay; return newopts;}void free_lr_options( lr_options *opts){ if (opts != NULL) AM_FREE( opts, lr_options); return;}void parse_lr_options( lr_options *opts, int argc, char **argv){ /* Create an options struct with mk_options(), then call this function to get the options from the command line. */ opts->rrlambda = double_from_args( "rrlambda", argc, argv, opts->rrlambda); opts->lreps = double_from_args( "lreps", argc, argv, opts->lreps); opts->lrmax = int_from_args( "lrmax", argc, argv, opts->lrmax); /* check_lr_options() ensures not both cgdeveps and cgeps are specified on the command line. */ opts->cgdeveps = double_from_args( "cgdeveps", argc, argv, opts->cgdeveps); if (opts->cgdeveps != 0.0) { opts->cgeps = 0.0; opts->cgbinit = TRUE; } opts->cgeps = double_from_args( "cgeps", argc, argv, opts->cgeps); if (opts->cgeps != 0.0) opts->cgdeveps = 0.0; opts->cgmax = int_from_args( "cgmax", argc, argv, opts->cgmax); opts->cgwindow = int_from_args( "cgwindow", argc, argv, opts->cgwindow); opts->cgdecay = double_from_args( "cgdecay", argc, argv, opts->cgdecay); return;}void check_lr_options( lr_options *opts, int argc, char **argv){ /* Checks that options are sane. Aborts if options are not sane. */ int use_cgdeveps, use_cgeps; char *co = "check_lr_options"; if (opts == NULL) my_errorf( "%s: opts==NULL", co); /* Check that the Ridge Regression parameter is non-negative. */ if (opts->rrlambda < 0.0) { my_errorf( "%s: rrlambda(=%g) must be non-negative.", co, opts->rrlambda); } /* Check lreps and lrmax. */ if (opts->lreps < 1e-10) my_errorf( "%s: lreps(=%g) < 1e-10 is unreasonable", co, opts->lreps); if (opts->lrmax < 0) my_errorf( "%s: lrmax(=%d) < 0 is unreasonable", co, opts->lrmax); use_cgdeveps = (index_of_arg( "cgdeveps", argc, argv) > 0); use_cgeps = (index_of_arg( "cgeps", argc, argv) > 0); /* Check cgdeveps, cgeps and cgmax. */ if ( use_cgeps && use_cgdeveps) { my_errorf( "%s: Cannot specify both cgdeveps and cgeps", co); } if (use_cgdeveps && opts->cgdeveps < 1e-10) { my_errorf( "%s: cgdeveps(=%g) < 1e-10 is unreasonable", co, opts->cgdeveps); } if (use_cgeps && opts->cgeps < 1e-10) { my_errorf( "%s: cgeps(=%g) < 1e-10 is unreasonable", co, opts->cgeps); } if (opts->cgmax < 0) my_errorf( "%s: cgmax(=%d) < 0 is unreasonable", co, opts->cgmax); /* Check cgwindow and cgdecay. */ if (opts->cgwindow < 0) my_errorf( "%s: cgwindow(=%d) < 0 is unreasonable", co, opts->cgwindow); if (opts->cgdecay < 1.0) my_errorf( "%s: cgdecay(=%g) < 1.0 is unreasonable", co, opts->cgdecay); return;}/***********************************************************************//* LR_STATE STRUCT *//***********************************************************************/lr_state *mk_lr_state( lr_train *lrt, lr_options *opts){ lr_state *lrs; lrs = AM_MALLOC( lr_state); lrs->b0 = 0.0; lrs->b = mk_zero_dyv(lrt->numatts-1); lrs->n = mk_zero_dyv(lrt->numrows); lrs->u = mk_zero_dyv(lrt->numrows); lrs->w = mk_zero_dyv(lrt->numrows); lrs->z = mk_zero_dyv(lrt->numrows); return lrs;}lr_state *mk_copy_lr_state( lr_state *lrs){ lr_state *lrscopy; lrscopy = AM_MALLOC( lr_state); lrscopy->b0 = lrs->b0; lrscopy->b = mk_copy_dyv( lrs->b); lrscopy->n = mk_copy_dyv( lrs->n); lrscopy->u = mk_copy_dyv( lrs->u); lrscopy->w = mk_copy_dyv( lrs->w); lrscopy->z = mk_copy_dyv( lrs->z); return lrscopy;}void fprintf_lr_state( FILE *f, char *pre, lr_state *lrs){ fprintf( f, "%sb0: %g\n", pre, lrs->b0); fprintf( f, pre); fprintf_oneline_dyv( f, "b:", lrs->b, "\n"); fprintf( f, pre); fprintf_oneline_dyv( f, "n:", lrs->n, "\n"); fprintf( f, pre); fprintf_oneline_dyv( f, "u:", lrs->u, "\n"); fprintf( f, pre); fprintf_oneline_dyv( f, "w:", lrs->w, "\n"); fprintf( f, pre); fprintf_oneline_dyv( f, "z:", lrs->z, "\n"); return;}/* Copies initb to lrs->b. */void lr_state_overwrite_b( lr_state *lrs, dyv *initb){ copy_dyv( initb, lrs->b); return;}void free_lr_state( lr_state *lrs){ if (lrs != NULL) { if (lrs->b != NULL) free_dyv( lrs->b); if (lrs->n != NULL) free_dyv( lrs->n); if (lrs->u != NULL) free_dyv( lrs->u); if (lrs->w != NULL) free_dyv( lrs->w); if (lrs->z != NULL) free_dyv( lrs->z); AM_FREE( lrs, lr_state); } return;}/***********************************************************************//* LR_STATEARR STRUCT *//***********************************************************************/lr_statearr *mk_array_of_null_lr_states( int size){ int i; lr_statearr *lrsarr; lrsarr = AM_MALLOC( lr_statearr); lrsarr->size = size; lrsarr->arr = AM_MALLOC_ARRAY( lr_state *, size); for (i=0; i<size; ++i) lrsarr->arr[i] = NULL; return lrsarr;}lr_state *lr_statearr_ref( lr_statearr *lrsarr, int index){#ifndef AMFAST if (index < 0 || index > lrsarr->size) { my_errorf( "lr_statearr_ref: illegal index %d not within [%d,%d]", index, 0, lrsarr->size-1); }#endif return lrsarr->arr[index];}/* Copies lr_state. */void lr_statearr_set( lr_statearr *lrsarr, int index, lr_state *lrs){#ifndef AMFAST if (index < 0 || index > lrsarr->size) { my_errorf( "lr_statearr_set: illegal index %d not within [%d,%d]", index, 0, lrsarr->size-1); }#endif lrsarr->arr[index] = mk_copy_lr_state( lrs); return;}void free_lr_statearr( lr_statearr *lrsarr){ int i; if (lrsarr != NULL) { if (lrsarr->arr != NULL) { for (i=0; i < lrsarr->size; ++i) { if (lrsarr->arr[i] != NULL) free_lr_state( lrsarr->arr[i]); } AM_FREE_ARRAY( lrsarr->arr, lr_state *, lrsarr->size); } AM_FREE( lrsarr, lr_statearr); }}/***********************************************************************//* LR_TRAIN STRUCT *//***********************************************************************/lr_train *mk_lr_train_from_dym( dym *factors, dyv *outputs, lr_options *opts){ /* Set rows to NULL if you want all rows from ds to be used. */ int numrows, numatts; lr_train *lrt; numrows = dym_rows( factors); numatts = dym_cols( factors)+1; /* Number of factors including constant. */ /* Create lr lrt structure. */ lrt = AM_MALLOC(lr_train); /* Copy in opts. */ lrt->opts = mk_copy_lr_options( opts); /* Assign factors and outputs into lr structure. */ lrt->X = NULL; lrt->M = factors; /* Outputs. */ lrt->y = mk_copy_dyv( outputs); if (!dyv_is_binary( outputs)) { my_error( "mk_lr_train: Error: outputs are not binary.\n"); } /* Set log likelihood of saturated model. */ lrt->likesat = 0.0; /* Initialize remainder of lr struct */ lrt->numatts = numatts; lrt->numrows = numrows; /* Create lr_state member. */ lrt->lrs = mk_lr_state( lrt, opts); /* Now that the structure is complete, update n and u to prepare for iterations. */ lr_train_update_n(lrt); lr_train_update_u(lrt); return lrt;}lr_train *mk_lr_train_from_spardat( spardat *X, lr_options *opts){ /* Copies spardat X. */ int numrows, numatts; lr_train *lrt; numrows = spardat_num_rows(X); numatts = spardat_num_atts(X)+1; /* Add 1 for constant att. */ lrt = AM_MALLOC(lr_train); /* Copy in opts. */ lrt->opts = mk_copy_lr_options( opts); /* Do not make a copy of the caller's spardat; too expensive. */ lrt->X = X; lrt->M = NULL; /* Initialize reaminder of lr struct */ lrt->numatts = numatts; lrt->numrows = numrows; /* Futz with 0-1 probabilities. */ lrt->y = mk_dyv_from_ivec( X->row_to_outval); /* Set log likelihood of saturated model. */ lrt->likesat = 0.0; /* Make lr_state member. */ lrt->lrs = mk_lr_state( lrt, opts); /* Now that the structure is complete, update n and u to prepare for iterations. */ lr_train_update_n(lrt); lr_train_update_u(lrt); return lrt;}lr_train *mk_copy_lr_train( const lr_train *source){ lr_train *dest; dest = AM_MALLOC(lr_train); dest->opts = mk_copy_lr_options( source->opts); /* Don't copy spardat or factors dym. Just keep pointer. */ dest->X = source->X; dest->M = source->M; dest->numatts = source->numatts; dest->numrows = source->numrows; dest->y = mk_copy_dyv(source->y); dest->likesat = source->likesat; dest->lrs = mk_copy_lr_state( source->lrs); return dest;}void free_lr_train( lr_train *lrt){ if (lrt != NULL) { if (lrt->lrs != NULL) free_lr_state( lrt->lrs); if (lrt->y != NULL) free_dyv( lrt->y); if (lrt->opts != NULL) free_lr_options( lrt->opts); AM_FREE(lrt, lr_train); } return;}void fprintf_lr_train( FILE *f, char *pre, lr_train *lrt){ int numatts, numrows; numatts = lrt->numatts; numrows = lrt->numrows; fprintf( f, "%snumatts: %d\n", pre, lrt->numatts); fprintf( f, "%snumrows: %d\n", pre, lrt->numrows); /* Print lr_state member. */ if (numatts < 15 && numrows < 500) fprintf_lr_state( f, pre, lrt->lrs); return;}/* Copies initb to lr->b. */void lr_train_overwrite_b( lr_train *lrt, dyv *initb){ lr_state_overwrite_b( lrt->lrs, initb); return;}/***********************************************************************//* LR ITERATIONS */
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -