⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 lr.c

📁 ADaM is a data mining and image processing toolkit
💻 C
📖 第 1 页 / 共 3 页
字号:
/*  Logistic Regression using Truncated Iteratively Re-weighted Least Squares  (includes several programs)  Copyright (C) 2005  Paul Komarek  This program is free software; you can redistribute it and/or modify  it under the terms of the GNU General Public License as published by  the Free Software Foundation; either version 2 of the License, or  (at your option) any later version.  This program is distributed in the hope that it will be useful,  but WITHOUT ANY WARRANTY; without even the implied warranty of  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the  GNU General Public License for more details.  You should have received a copy of the GNU General Public License  along with this program; if not, write to the Free Software  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA  Author: Paul Komarek, komarek@cmu.edu  Alternate contact: Andrew Moore, awm@cs.cmu.edu*//*   File:        lr.c   Author:      Paul Komarek   Created:     Thu Mar  6 12:14:39 EST 2003   Description: Logistic regression implmentation.   Copyright 2003, The Auton Lab, CMU*/#include <stdio.h>#include <float.h>#include <math.h>#include "amma.h"#include "amdyv.h"#include "amdym.h"#include "amdyv_array.h"#include "spardat.h"#include "lrutils.h"#include "lin_conjgrad.h"#include "lr.h"/***********************************************************************//* Option Parsing                                                      *//***********************************************************************/lr_options *mk_lr_options(void){  lr_options *opts;  opts = AM_MALLOC( lr_options);  /* Options which are always available. */  opts->rrlambda   = 10.0;  /*   Termination criteria for lr iterations. */  opts->lreps      = 0.05;  opts->lrmax      = 30;  /* cg options are available in conjuagate gradient runs. */  opts->cgbinit    = 1;  opts->cgdeveps   = 0.005;    /* suggestion: 0.005 */  opts->cgeps      = 0.000;  /* multiplied by initial CG rsqr. */  opts->cgmax      = 200;  opts->cgwindow   = 3;      /* Number of bad iterations allowed. */  opts->cgdecay    = 1000.0; /* Factor worse than best-seen that is allowed. */    return opts;}lr_options *mk_copy_lr_options( lr_options *opts){  lr_options *newopts;  newopts = AM_MALLOC( lr_options);  newopts->rrlambda        = opts->rrlambda;  newopts->lreps           = opts->lreps;  newopts->lrmax           = opts->lrmax;  newopts->cgbinit         = opts->cgbinit;  newopts->cgdeveps        = opts->cgdeveps;  newopts->cgeps           = opts->cgeps;  newopts->cgmax           = opts->cgmax;  newopts->cgwindow        = opts->cgwindow;  newopts->cgdecay         = opts->cgdecay;  return newopts;}void free_lr_options( lr_options *opts){  if (opts != NULL) AM_FREE( opts, lr_options);  return;}void parse_lr_options( lr_options *opts, int argc, char **argv){  /* Create an options struct with mk_options(), then call this function     to get the options from the command line. */  opts->rrlambda = double_from_args( "rrlambda", argc, argv, opts->rrlambda);  opts->lreps = double_from_args( "lreps", argc, argv, opts->lreps);  opts->lrmax = int_from_args( "lrmax", argc, argv, opts->lrmax);  /* check_lr_options() ensures not both cgdeveps and cgeps are specified     on the command line. */  opts->cgdeveps = double_from_args( "cgdeveps", argc, argv, opts->cgdeveps);  if (opts->cgdeveps != 0.0) {    opts->cgeps = 0.0;    opts->cgbinit = TRUE;  }  opts->cgeps = double_from_args( "cgeps", argc, argv, opts->cgeps);  if (opts->cgeps != 0.0) opts->cgdeveps = 0.0;  opts->cgmax = int_from_args( "cgmax", argc, argv, opts->cgmax);  opts->cgwindow = int_from_args( "cgwindow", argc, argv, opts->cgwindow);  opts->cgdecay  = double_from_args( "cgdecay", argc, argv, opts->cgdecay);  return;}void check_lr_options( lr_options *opts, int argc, char **argv){  /* Checks that options are sane.  Aborts if options are not sane. */  int use_cgdeveps, use_cgeps;  char *co = "check_lr_options";  if (opts == NULL) my_errorf( "%s: opts==NULL", co);  /* Check that the Ridge Regression parameter is non-negative. */  if (opts->rrlambda < 0.0) {    my_errorf( "%s: rrlambda(=%g) must be non-negative.", co, opts->rrlambda);  }  /* Check lreps and lrmax. */  if (opts->lreps < 1e-10) my_errorf( "%s: lreps(=%g) < 1e-10 is unreasonable",                                     co, opts->lreps);  if (opts->lrmax < 0) my_errorf( "%s: lrmax(=%d) < 0 is unreasonable",                                     co, opts->lrmax);  use_cgdeveps = (index_of_arg( "cgdeveps", argc, argv) > 0);  use_cgeps = (index_of_arg( "cgeps", argc, argv) > 0);  /* Check cgdeveps, cgeps and cgmax. */  if ( use_cgeps && use_cgdeveps) {    my_errorf( "%s: Cannot specify both cgdeveps and cgeps", co);  }  if (use_cgdeveps && opts->cgdeveps < 1e-10) {    my_errorf( "%s: cgdeveps(=%g) < 1e-10 is unreasonable",               co, opts->cgdeveps);  }  if (use_cgeps && opts->cgeps < 1e-10) {    my_errorf( "%s: cgeps(=%g) < 1e-10 is unreasonable", co, opts->cgeps);  }  if (opts->cgmax < 0) my_errorf( "%s: cgmax(=%d) < 0 is unreasonable",                                     co, opts->cgmax);  /* Check cgwindow and cgdecay. */  if (opts->cgwindow < 0) my_errorf( "%s: cgwindow(=%d) < 0 is unreasonable",                                     co, opts->cgwindow);  if (opts->cgdecay < 1.0) my_errorf( "%s: cgdecay(=%g) < 1.0 is unreasonable",                                     co, opts->cgdecay);  return;}/***********************************************************************//* LR_STATE STRUCT                                                     *//***********************************************************************/lr_state *mk_lr_state( lr_train *lrt, lr_options *opts){  lr_state *lrs;  lrs = AM_MALLOC( lr_state);  lrs->b0         = 0.0;  lrs->b          = mk_zero_dyv(lrt->numatts-1);  lrs->n          = mk_zero_dyv(lrt->numrows);  lrs->u          = mk_zero_dyv(lrt->numrows);  lrs->w          = mk_zero_dyv(lrt->numrows);  lrs->z          = mk_zero_dyv(lrt->numrows);  return lrs;}lr_state *mk_copy_lr_state( lr_state *lrs){  lr_state *lrscopy;  lrscopy = AM_MALLOC( lr_state);  lrscopy->b0          = lrs->b0;  lrscopy->b           = mk_copy_dyv( lrs->b);  lrscopy->n           = mk_copy_dyv( lrs->n);  lrscopy->u           = mk_copy_dyv( lrs->u);  lrscopy->w           = mk_copy_dyv( lrs->w);  lrscopy->z           = mk_copy_dyv( lrs->z);  return lrscopy;}void fprintf_lr_state( FILE *f, char *pre, lr_state *lrs){  fprintf( f, "%sb0: %g\n", pre, lrs->b0);  fprintf( f, pre);  fprintf_oneline_dyv( f, "b:", lrs->b, "\n");  fprintf( f, pre);  fprintf_oneline_dyv( f, "n:", lrs->n, "\n");  fprintf( f, pre);  fprintf_oneline_dyv( f, "u:", lrs->u, "\n");  fprintf( f, pre);  fprintf_oneline_dyv( f, "w:", lrs->w, "\n");  fprintf( f, pre);  fprintf_oneline_dyv( f, "z:", lrs->z, "\n");  return;}/* Copies initb to lrs->b. */void lr_state_overwrite_b( lr_state *lrs, dyv *initb){  copy_dyv( initb, lrs->b);  return;}void free_lr_state( lr_state *lrs){  if (lrs != NULL) {    if (lrs->b != NULL) free_dyv( lrs->b);    if (lrs->n != NULL) free_dyv( lrs->n);    if (lrs->u != NULL) free_dyv( lrs->u);    if (lrs->w != NULL) free_dyv( lrs->w);    if (lrs->z != NULL) free_dyv( lrs->z);    AM_FREE( lrs, lr_state);  }  return;}/***********************************************************************//* LR_STATEARR STRUCT                                                  *//***********************************************************************/lr_statearr *mk_array_of_null_lr_states( int size){  int i;  lr_statearr *lrsarr;  lrsarr = AM_MALLOC( lr_statearr);  lrsarr->size = size;  lrsarr->arr = AM_MALLOC_ARRAY( lr_state *, size);  for (i=0; i<size; ++i) lrsarr->arr[i] = NULL;  return lrsarr;}lr_state *lr_statearr_ref( lr_statearr *lrsarr, int index){#ifndef AMFAST  if (index < 0 || index > lrsarr->size) {    my_errorf( "lr_statearr_ref: illegal index %d not within [%d,%d]",               index, 0, lrsarr->size-1);  }#endif  return lrsarr->arr[index];}/* Copies lr_state. */void lr_statearr_set( lr_statearr *lrsarr, int index, lr_state *lrs){#ifndef AMFAST  if (index < 0 || index > lrsarr->size) {    my_errorf( "lr_statearr_set: illegal index %d not within [%d,%d]",               index, 0, lrsarr->size-1);  }#endif  lrsarr->arr[index] = mk_copy_lr_state( lrs);  return;}void free_lr_statearr( lr_statearr *lrsarr){  int i;  if (lrsarr != NULL) {    if (lrsarr->arr != NULL) {      for (i=0; i < lrsarr->size; ++i) {        if (lrsarr->arr[i] != NULL) free_lr_state( lrsarr->arr[i]);      }      AM_FREE_ARRAY( lrsarr->arr, lr_state *, lrsarr->size);    }    AM_FREE( lrsarr, lr_statearr);  }}/***********************************************************************//* LR_TRAIN STRUCT                                                     *//***********************************************************************/lr_train *mk_lr_train_from_dym( dym *factors, dyv *outputs, lr_options *opts){  /* Set rows to NULL if you want all rows from ds to be used. */  int numrows, numatts;  lr_train *lrt;  numrows = dym_rows( factors);  numatts = dym_cols( factors)+1; /* Number of factors including constant. */  /* Create lr lrt structure. */  lrt = AM_MALLOC(lr_train);  /* Copy in opts. */  lrt->opts = mk_copy_lr_options( opts);  /* Assign factors and outputs into lr structure. */  lrt->X = NULL;  lrt->M = factors;  /* Outputs. */  lrt->y = mk_copy_dyv( outputs);  if (!dyv_is_binary( outputs)) {    my_error( "mk_lr_train: Error: outputs are not binary.\n");  }  /* Set log likelihood of saturated model. */  lrt->likesat = 0.0;  /* Initialize remainder of lr struct */  lrt->numatts = numatts;  lrt->numrows = numrows;  /* Create lr_state member. */  lrt->lrs = mk_lr_state( lrt, opts);  /* Now that the structure is complete, update n and u to prepare for     iterations. */  lr_train_update_n(lrt);  lr_train_update_u(lrt);  return lrt;}lr_train *mk_lr_train_from_spardat( spardat *X, lr_options *opts){  /* Copies spardat X. */  int numrows, numatts;  lr_train *lrt;  numrows = spardat_num_rows(X);  numatts = spardat_num_atts(X)+1; /* Add 1 for constant att. */  lrt = AM_MALLOC(lr_train);  /* Copy in opts. */  lrt->opts = mk_copy_lr_options( opts);  /* Do not make a copy of the caller's spardat; too expensive. */  lrt->X = X;  lrt->M = NULL;  /* Initialize reaminder of lr struct */  lrt->numatts    = numatts;  lrt->numrows    = numrows;  /* Futz with 0-1 probabilities. */  lrt->y = mk_dyv_from_ivec( X->row_to_outval);  /* Set log likelihood of saturated model. */  lrt->likesat = 0.0;  /* Make lr_state member. */  lrt->lrs = mk_lr_state( lrt, opts);  /* Now that the structure is complete, update n and u to prepare for     iterations. */  lr_train_update_n(lrt);  lr_train_update_u(lrt);  return lrt;}lr_train *mk_copy_lr_train( const lr_train *source){  lr_train *dest;  dest = AM_MALLOC(lr_train);  dest->opts = mk_copy_lr_options( source->opts);  /* Don't copy spardat or factors dym.  Just keep pointer. */  dest->X = source->X;  dest->M = source->M;  dest->numatts    = source->numatts;  dest->numrows    = source->numrows;  dest->y          = mk_copy_dyv(source->y);  dest->likesat    = source->likesat;  dest->lrs        = mk_copy_lr_state( source->lrs);  return dest;}void free_lr_train( lr_train *lrt){  if (lrt != NULL) {    if (lrt->lrs != NULL)      free_lr_state( lrt->lrs);    if (lrt->y != NULL)        free_dyv( lrt->y);    if (lrt->opts != NULL)     free_lr_options( lrt->opts);    AM_FREE(lrt, lr_train);  }  return;}void fprintf_lr_train( FILE *f, char *pre, lr_train *lrt){  int numatts, numrows;  numatts = lrt->numatts;  numrows = lrt->numrows;  fprintf( f, "%snumatts: %d\n", pre, lrt->numatts);  fprintf( f, "%snumrows: %d\n", pre, lrt->numrows);  /* Print lr_state member. */  if (numatts < 15 && numrows < 500) fprintf_lr_state( f, pre, lrt->lrs);  return;}/* Copies initb to lr->b. */void lr_train_overwrite_b( lr_train *lrt, dyv *initb){  lr_state_overwrite_b( lrt->lrs, initb);  return;}/***********************************************************************//* LR ITERATIONS                                                       */

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -