📄 recurrent2.cpp
字号:
/*
nn-utility (Provides neural networking utilities for c++ programmers)
Copyright (C) 2003 Panayiotis Thomakos
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
*/
//To contact the author send an email to panthomakos@users.sourceforge.net
#include <nn-utility.h>
#define OUT_SIZE 3
using namespace nn_utility;
/*
 * MINE: application-specific helper derived from nn_utility_functions<float>.
 * It supplies its own GetInput, which is presumably the hook that
 * nn_utility_functions<float>::train uses to fetch training samples
 * (TODO confirm against nn-utility.h).
 * A struct gives public inheritance and public members by default, which is
 * the same access as the original "class ... : public ... { public:" form.
 */
struct MINE : nn_utility_functions<float> {
    // Square NN_UTIL_SIZE x NN_UTIL_SIZE float matrix used for samples.
    typedef float MATRIX[NN_UTIL_SIZE][NN_UTIL_SIZE];

    // Writes one training sample into send and its row count into inputs.
    void GetInput( int interation, MATRIX &send, MATRIX &target, int &inputs );
} derived; // single global instance used by main for utilities and training
// Index (normally 0-3) of the next pattern handed out by MINE::GetInput.
int id = 0;
/*
 * Supplies one training sample per call, cycling through four two-step
 * patterns in this order: (1,1), (0,1), (1,0), (0,0).
 * Each step is a one-element vector: step 0 goes to send[0], step 1 to send[1].
 * inputs is always set to 2 (the number of steps written).
 * interation and target are not read or written here.
 * If id is outside 0..3, send is left untouched; id still advances and wraps
 * to 0 once it would exceed 3.
 */
void MINE::GetInput( int interation, MATRIX &send, MATRIX &target, int &inputs ){
    // Row = pattern index, column = step value passed to LoadVectorf.
    static const double patterns[4][2] = {
        { 1.0, 1.0 },
        { 0.0, 1.0 },
        { 1.0, 0.0 },
        { 0.0, 0.0 }
    };
    const int row = id;
    if ( row >= 0 && row <= 3 ){
        LoadVectorf( send[0], 1, patterns[row][0] );
        LoadVectorf( send[1], 1, patterns[row][1] );
    }
    inputs = 2;
    if ( ++id > 3 )
        id = 0;
}
/*
 * Fills every element of (*TOSET)->matrix, over row x col entries, with value.
 * (*TOSET)->matrix is presumably the SOFM's weight matrix; TODO confirm in
 * nn-utility.h. TOSET and *TOSET must be non-null; no check is made.
 */
void SetConstant( KOHEN_SOFM ** TOSET, float value ){
    KOHEN_SOFM *net = *TOSET;
    for ( int r = 0; r < net->row; ++r )
        for ( int c = 0; c < net->col; ++c )
            net->matrix[r][c] = value;
}
int main(){
layer<float> *multi = new layer<float>();
multi->define_recurrent();
KOHEN_SOFM *classify = new KOHEN_SOFM();
classify->define( OUT_SIZE+1,OUT_SIZE );
layer<float> *ppClassify = classify;
multi->add( &ppClassify );
classify->radius = 1;
SetConstant( &classify, 0.1 );
nn_utility_functions<float>::MATRIX FINAL, FINAL2, FINAL3, FINAL4, INPUT, INPUT2, INPUT3, INPUT4;
derived.LoadVectorf( INPUT[0], 1, 0.0 );
derived.LoadVectorf( INPUT[1], 1, 0.0 );
derived.LoadVectorf( INPUT2[0], 1, 1.0 );
derived.LoadVectorf( INPUT2[1], 1, 1.0 );
derived.LoadVectorf( INPUT3[0], 1, 0.0 );
derived.LoadVectorf( INPUT3[1], 1, 1.0 );
derived.LoadVectorf( INPUT4[0], 1, 1.0 );
derived.LoadVectorf( INPUT4[1], 1, 0.0 );
multi->FeedForward( INPUT, FINAL, 2 );
multi->FeedForward( INPUT2, FINAL2, 2 );
multi->FeedForward( INPUT3, FINAL3, 2 );
multi->FeedForward( INPUT4, FINAL4, 2 );
derived.PrintMatrix( "before training (false)", FINAL, 2, OUT_SIZE );
derived.PrintMatrix( "before training (false)", FINAL3, 2, OUT_SIZE );
derived.PrintMatrix( "before training (false)", FINAL4, 2, OUT_SIZE );
derived.PrintMatrix( "before training (true)", FINAL2, 2, OUT_SIZE );
derived.train( &multi, 5, 1.0 );
multi->FeedForward( INPUT, FINAL, 2 );
multi->FeedForward( INPUT2, FINAL2, 2 );
multi->FeedForward( INPUT3, FINAL3, 2 );
multi->FeedForward( INPUT4, FINAL4, 2 );
derived.PrintMatrix( "after training (false)", FINAL, 2, OUT_SIZE );
derived.PrintMatrix( "after training (false)", FINAL3, 2, OUT_SIZE );
derived.PrintMatrix( "after training (false)", FINAL4, 2, OUT_SIZE );
derived.PrintMatrix( "after training (true)", FINAL2, 2, OUT_SIZE );
cout << '\n';
return 0;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -