⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ann.cc

📁 KNN algorithm, and related
💻 CC
字号:
/*--------------------------------------------------------------------------@COPYRIGHT  :              Copyright 1997, Vasken Kollokian,               McConnell Brain Imaging Centre,              Montreal Neurological Institute, McGill University.              Permission to use, copy, modify, and distribute this              software and its documentation for any purpose and without              fee is hereby granted, provided that the above copyright              notice appear in all copies.  The author and McGill University              make no representations about the suitability of this              software for any purpose.  It is provided "as is" without              express or implied warranty.---------------------------------------------------------------------------- $RCSfile: ann.cc,v $$Revision: 1.1.1.1 $$Author: jason $$Date: 2002/03/20 22:16:34 $$State: Exp $--------------------------------------------------------------------------*/#include <assert.h>#include "EBTKS/Minc.h"#include "EBTKS/FileIO.h"#include "EBTKS/SimpleArray.h" //#include "TrainingSet.h"    #include "EBTKS/ValueMap.h"#include "ann.h"#include "EBTKS/backProp.h"extern "C" {  #include <volume_io.h>  #include "../class_globals.h"}//Global variables:double           minInput  = -1.0;double           maxInput  = 1.0;double           minTarget = 0.1;double           maxTarget = 0.9;unsigned         nNodesInHiddenLayer = 10;Array<LinearMap> inputMaps;IntArray         nodeToClassMap;BP_ANN          *BP = 0;/* ----------------------------- MNI Header ----------------------------------@NAME          : ann_init_training@INPUT         : parameter file name@OUTPUT        :@RETURNS       :@DESCRIPTION   : initializes training for an artificial neural network@METHOD        :@GLOBALS       :@CALLS         : @CREATED       : Jul 24, 1995 (Samson ANTEL)@MODIFIED      : Aug 06, 1995 (Alex Zijdenbos)------------------------------------------------------------------------------ */void ann_init_training(char *param_filename){  if (BP) {    delete BP;    BP = 0;  }  InputFile annFile;  if (!param_filename || !annFile.attach(param_filename)) {    if (param_filename) {      cerr << "Couldn't open parameter file " << param_filename << endl;      exit(EXIT_FAILURE);    }    if (verbose)      cout << "No parameters specified - using defaults" << endl;    UnsignedArray topology(3);    topology[(unsigned int)0] = num_features;    topology[(unsigned int)1] = nNodesInHiddenLayer;    topology[(unsigned int)2] = num_classes;    BP = new BP_ANN(topology, verbose);    if (!BP) {      cerr << "Couldn't create BP_ANN" << endl;      exit(EXIT_FAILURE);    }  }  else {    if (verbose)      cout << "Reading network parameters..." << endl;    BP = new BP_ANN(annFile.stream(), verbose);    annFile.attach(param_filename); // Re-attach to rewind input stream    if (!BP) {      cerr << "Couldn't create BP_ANN" << endl;      exit(EXIT_FAILURE);    }    ann_load_train_param(annFile.stream());  }  if (num_features != BP->nInputNodes()) {    cerr << "Warning: num_features (" << num_features 	 << ") doesn't match # ann input nodes ("	 << BP->nInputNodes() << "). Forcing " << num_features << endl;    BP->nInputNodes(num_features);  }  if (num_classes != BP->nOutputNodes()) {    cerr << "Warning: num_classes (" << num_classes 	 << ") doesn't match # ann output nodes ("	 << BP->nOutputNodes() << "). Forcing " << num_classes << endl;    BP->nOutputNodes(num_classes);  }  /*  if (num_features != BP->nInputNodes()) {    cerr << "num_features (" << num_features << ") doesn't match # ann input nodes ("	 << BP->nInputNodes() << ")" << endl;    exit(EXIT_FAILURE);  }  if (num_classes != BP->nOutputNodes()) {    cerr << "num_classes (" << num_classes << ") doesn't match # ann output nodes ("	 << BP->nOutputNodes() << ")" << endl;    exit(EXIT_FAILURE);  }  */}/* ----------------------------- MNI Header ----------------------------------@NAME          : ann_train_samples@INPUT         :@OUTPUT        :@RETURNS       :@DESCRIPTION   : takes a feature matrix and trains an ANN classifier on it.@METHOD        :@GLOBALS       :@CALLS         : @CREATED       : Jul 24, 1995 (Samson ANTEL)@MODIFIED      : Aug 06, 1995 (Alex Zijdenbos)------------------------------------------------------------------------------ */void ann_train_samples(void){  if (!BP) {    fprintf(stderr, "BP_ANN not initialized!\n");    exit(EXIT_FAILURE);  }  unsigned i, j;  inputMaps.newSize(num_features);  for (j=0; j<num_features; j++){    double mmax = -MAXDOUBLE;    double mmin = MAXDOUBLE;    for (unsigned i=0; i<num_samples; i++){      double value = feature_matrix[i][j];      if (value > mmax)	mmax=value;      if (value < mmin)	mmin=value;    }    inputMaps[j](mmin, mmax, minInput, maxInput);  }         IntArray temp(class_column, num_samples);  int minClass = min(temp);  int maxClass = max(temp);  UnsignedArray classToNodeMap(maxClass - minClass + 1);  classToNodeMap.clear(0);  nodeToClassMap.newSize(num_classes);  nodeToClassMap.clear(minClass);  unsigned nodeCtr = 1;  for (i = 0; i < num_samples; i++) {    unsigned classID = class_column[i] - minClass;    if ((classID != 0) && !classToNodeMap[classID]) {      classToNodeMap[classID] = nodeCtr;      nodeToClassMap[nodeCtr] = classID + minClass;      nodeCtr++;    }  }  if (verbose)    ann_save_train_param(cout);  // Remap feature_matrix  if (verbose)    cout << "Remapping feature matrix..." << flush;  for (i = 0; i < num_samples; i++)    for (j = 0; j < num_features; j++)      feature_matrix[i][j] = inputMaps[j](feature_matrix[i][j]);  if (verbose)    cout << "Done" << endl;  BP->initTraining(num_samples);  BP->shuffleInterval(25);    DblArray  targetValues(minTarget, num_classes);  double   *targetValuePtr = targetValues.contents(); // For efficiency only  int sample = 0;  while (sample >= 0) {    // Create target values array    double *node = targetValuePtr + class_column[sample] - minClass;    *node = maxTarget;    sample = BP->train(feature_matrix[sample], targetValuePtr, ann_error_monitor);    *node = minTarget;  }  /*  if (verbose)    cout << "Creating training set..." << flush;    TrainingSet TS(num_samples, num_features, num_classes, minTarget, maxTarget);  for (i = 0; i < num_samples; i++){    DblArray inputValues(num_features);    for (unsigned j = 0; j < num_features; j++)      inputValues[j] = inputMaps[j](feature_matrix[i][j]);    TS.add(classToNodeMap[class_column[i] - minClass], inputValues);  }  if (verbose)    cout << "Done" << endl << "Training network..." << endl;  BP->train(TS, ann_error_monitor);  */}/* ----------------------------- MNI Header ----------------------------------@NAME          : ann_save_training;@INPUT         :@OUTPUT        :@RETURNS       :@DESCRIPTION   : saves the results of an ANN training session.@METHOD        :@GLOBALS       :@CALLS         : ann_save_train_param;@CREATED       : Jul 24, 1995 (Samson ANTEL)@MODIFIED      : Aug 05, 1995 (Alex ZIJDENBOS)------------------------------------------------------------------------------ */void ann_save_training(char *save_train_filename){  if (!BP) {    cerr << "BP_ANN not trained!" << endl;    exit(EXIT_FAILURE);  }  if (verbose)    cout << "Saving trained network..." << flush;  OutputFile save_train_file(save_train_filename, ios::out, OutputFile::NO_COMPRESS);   if (!save_train_file) {    cerr << "Cannot open " << save_train_filename << endl;    exit(EXIT_FAILURE);  }   BP->save(save_train_file);  ann_save_train_param(save_train_file);   if (verbose)    cout << "Done" << endl;}/* ----------------------------- MNI Header ----------------------------------@NAME          : ann_load_training@INPUT         :@OUTPUT        :@RETURNS       : @DESCRIPTION   : loads the results of an ANN training session@METHOD        :@GLOBALS       :@CALLS         : ann_load_train_param(FILE *)@CREATED       : Jul 24, 1995 (Samson ANTEL)@MODIFIED      : Aug 06, 1995 (Alex ZIJDENBOS)------------------------------------------------------------------------------ */void ann_load_training(char *load_train_filename){  InputFile load_train_file(load_train_filename);  if (!load_train_file) {    cerr << "Cannot open " << load_train_filename << endl;    exit(EXIT_FAILURE);  }  if (BP)    delete BP;    if (verbose)    cout << "Reading network..." << endl;  BP = new BP_ANN(load_train_file.stream(), verbose);  if (!BP) {    cerr << "Cannot create BP_ANN" << endl;    exit(EXIT_FAILURE);  }  num_features = BP->nInputNodes();  num_classes = BP->nOutputNodes();  load_train_file.attach(load_train_filename); // Re-attach to rewind input stream  ann_load_train_param(load_train_file.stream());  if (verbose)    ann_save_train_param(cout);}/* ----------------------------- MNI Header ----------------------------------@NAME          : ann_classify_sample@INPUT         :@OUTPUT        : class_number, class_prob, class_labels@RETURNS       : @DESCRIPTION   : given a feature vector, return a class, class probability values,                 and the corresponding class numbers.@METHOD        :@GLOBALS       :@CALLS         :@CREATED       : Jul 24, 1995 (Samson ANTEL)@MODIFIED      : Aug 07, 1995 (Alex ZIJDENBOS)------------------------------------------------------------------------------ */void ann_classify_sample(int *class_num, double *class_prob, int *class_labels){    if (!BP) {    fprintf(stderr, "BP_ANN doesn't exist!\n");    exit(EXIT_FAILURE);  }  DblArray inputValues(feature_vector, num_features);  DblArray outputValues(num_classes);  BP->evaluate(map(inputValues, inputMaps), outputValues);  unsigned node;  outputValues.max(&node);  *class_num = nodeToClassMap[node];  if (class_prob) {    outputValues.asCarray(class_prob);    if (class_labels)      nodeToClassMap.asCarray(class_labels);  }//    cout << "Classify: " << inputValues << " => " << *class_num << endl;} /* ----------------------------- MNI Header ----------------------------------@NAME          : ann_save_train_param@INPUT         :@OUTPUT        :@RETURNS       : @DESCRIPTION   : appends nCycles, mmax, mmin to saved network file@METHOD        :@GLOBALS       :@CALLS         :@CREATED       : Aug 01, 1995 (Samson ANTEL)@MODIFIED      : Aug 06, 1995 (Alex ZIJDENBOS)------------------------------------------------------------------------------ */ostream& ann_save_train_param(ostream& OS){  OS << "min_input:         " << minInput << endl     << "max_input:         " << maxInput << endl     << "min_target:        " << minTarget << endl     << "max_target:        " << maxTarget << endl     << "input_maps:        " << num_features << endl;  for (unsigned i = 0; i < num_features; i++)    OS << inputMaps[i].factor() << " " << inputMaps[i].offset() << endl;  //  OS << "node_to_class_map: " << nodeToClassMap << endl;    return OS;}/* ----------------------------- MNI Header ----------------------------------@NAME          : ann_load_train_param@INPUT         :@OUTPUT        :@RETURNS       : @DESCRIPTION   : loads nCycles, mmax, mmin from saved network file@METHOD        :@GLOBALS       :@CALLS         :@CREATED       : Aug 01, 1995 (Samson ANTEL)@MODIFIED      : Aug 06, 1995 (Alex ZIJDENBOS)------------------------------------------------------------------------------ */istream& ann_load_train_param(istream& IS){// Set defaults  minInput  = -1.0;  maxInput  = 1.0;  minTarget = 0.1;  maxTarget = 0.9;  nodeToClassMap = IntArray(0, 1, num_classes - 1);// Load values from stream  MString key;  while (IS >> key) {    if (key.contains("min_input"))      IS >> minInput;    else if (key.contains("max_input"))      IS >> maxInput;    else if (key.contains("min_target"))      IS >> minTarget;    else if (key.contains("max_target"))      IS >> maxTarget;    else if (key.contains("input_maps")) {      unsigned nMaps;      IS >> nMaps;      inputMaps.newSize(nMaps);      for (unsigned i = 0; i < nMaps; i++)	IS >> inputMaps[i].factor() >> inputMaps[i].offset();    }    else if (key.contains("node_to_class_map")) {      for (unsigned i = 0; i < num_classes; i++)	IS >> nodeToClassMap[i];      if (class_name == NULL) {	typedef char *charPtr;	class_name = new charPtr[num_classes];	assert(class_name);	for (unsigned i = 0; i < num_classes; i++) {	  MString intString;	  intString += nodeToClassMap[i];	  class_name[i] = new char[intString.length() + 1];	  strcpy(class_name[i], intString);	}      }    }  }  return IS;}/* ----------------------------- MNI Header ----------------------------------@NAME          : ann_error_monitor@INPUT         : cycle #, cycle MSE@OUTPUT        :@RETURNS       : @DESCRIPTION   : ANN output error monitor@METHOD        :@GLOBALS       :@CALLS         :@CREATED       : Aug 06, 1995 (Alex ZIJDENBOS)@MODIFIED      : ------------------------------------------------------------------------------ */voidann_error_monitor(unsigned cycle, double error){  if (verbose)    cout << "Cycle: " << cycle << " error: " << error << endl;}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -