/* fann_cpp.h — C++ wrapper around the FANN C library (web-viewer chrome removed) */
{
public:
/** Constructor. Creates an empty wrapper that owns no underlying fann
    network; call one of the create functions before using the object. */
neural_net() : ann(NULL)
{
}
/** Destructor. Frees the wrapped network automatically via destroy().
    NOTE(review): the destructor is only virtual when USE_VIRTUAL_DESTRUCTOR
    is defined, so deleting a derived object through a neural_net pointer is
    unsafe in the default build — confirm intended usage. */
#ifdef USE_VIRTUAL_DESTRUCTOR
virtual
#endif
~neural_net()
{
destroy();
}
/** Releases the wrapped fann network, if one exists.
    Safe to call multiple times; invoked automatically by the destructor. */
void destroy()
{
    if (ann == NULL)
    {
        return;
    }
    fann_destroy(ann);
    ann = NULL;
}
/** Constructs a backpropagation neural network from a connection rate, a
    learning rate, the number of layers and a variadic list of per-layer
    neuron counts (input layer first, output layer last).

    The connection rate controls how many connections the network gets:
    1 means fully connected, 0.5 means only half of the possible
    connections are created. Every layer except the output layer gets a
    bias neuron, connected to all neurons of the next layer; when the
    network is run the bias node always emits 1.

    Returns true when the underlying network was created successfully. */
bool create(float connection_rate, float learning_rate,
    /* the number of layers, including the input and output layer */
    unsigned int num_layers,
    /* the number of neurons in each of the layers, starting with
       the input layer and ending with the output layer */
    ...)
{
    if (num_layers == 0)
    {
        return false;
    }
    /* Copy the variadic layer sizes into a real array. The previous
       reinterpret_cast of the va_list itself to unsigned int* was
       undefined behaviour and only worked on ABIs where a va_list is a
       plain pointer into the argument stack. */
    unsigned int *layer_sizes = new unsigned int[num_layers];
    va_list args;
    va_start(args, num_layers);
    for (unsigned int i = 0; i < num_layers; i++)
    {
        layer_sizes[i] = va_arg(args, unsigned int);
    }
    va_end(args);
    bool status = create_array(connection_rate, learning_rate,
        num_layers, layer_sizes);
    delete[] layer_sizes;
    return status;
}
/** Same as create, but the layer sizes are supplied as an array instead
    of variadic arguments. Returns true on success. */
bool create_array(float connection_rate, float learning_rate,
    unsigned int num_layers, unsigned int * layers)
{
    destroy(); /* drop any network we already own before replacing it */
    ann = fann_create_array(connection_rate, learning_rate, num_layers, layers);
    return ann != NULL;
}
/** Creates a fully connected neural network with shortcut connections,
    from a learning rate, the number of layers and a variadic list of
    per-layer neuron counts (input layer first, output layer last).
    Returns true when the underlying network was created successfully. */
bool create_shortcut(float learning_rate,
    /* the number of layers, including the input and output layer */
    unsigned int num_layers,
    /* the number of neurons in each of the layers, starting with
       the input layer and ending with the output layer */
    ...)
{
    if (num_layers == 0)
    {
        return false;
    }
    /* Copy the variadic layer sizes into a real array. The previous
       reinterpret_cast of the va_list itself to unsigned int* was
       undefined behaviour on most ABIs. */
    unsigned int *layer_sizes = new unsigned int[num_layers];
    va_list args;
    va_start(args, num_layers);
    for (unsigned int i = 0; i < num_layers; i++)
    {
        layer_sizes[i] = va_arg(args, unsigned int);
    }
    va_end(args);
    bool status = create_shortcut_array(learning_rate, num_layers, layer_sizes);
    delete[] layer_sizes;
    return status;
}
/** Creates a neural network with shortcut connections from an array of
    layer sizes. Returns true on success. */
bool create_shortcut_array(float learning_rate, unsigned int num_layers,
    unsigned int * layers)
{
    destroy(); /* drop any network we already own before replacing it */
    ann = fann_create_shortcut_array(learning_rate, num_layers, layers);
    return ann != NULL;
}
/** Feeds one input vector through the network and returns a pointer to
    the network's output array, or NULL when no network exists. */
fann_type* run(fann_type *input)
{
    return (ann == NULL) ? NULL : fann_run(ann, input);
}
/** Randomizes all weights into [min_weight, max_weight].
    (Freshly created networks start with random weights in [-0.1, 0.1].) */
void randomize_weights(fann_type min_weight, fann_type max_weight)
{
    if (ann == NULL)
    {
        return;
    }
    fann_randomize_weights(ann, min_weight, max_weight);
}
/** Initializes the weights from the given training data using
    Widrow + Nguyen's algorithm. No-op without a network or data. */
void init_weights(const training_data &data)
{
    if (ann == NULL || data.train_data == NULL)
    {
        return;
    }
    fann_init_weights(ann, data.train_data);
}
/** Prints the network's connections (delegates to fann_print_connections). */
void print_connections()
{
    if (ann == NULL)
    {
        return;
    }
    fann_print_connections(ann);
}
/** Loads a network from a configuration file previously written by save().
    Returns true when the network was loaded successfully. */
bool create_from_file(const char *configuration_file)
{
    destroy(); /* drop any network we already own before replacing it */
    ann = fann_create_from_file(configuration_file);
    return ann != NULL;
}
/** Writes the entire network to a configuration file.
    No-op when no network has been created. */
void save(const char *configuration_file)
{
    if (ann == NULL)
    {
        return;
    }
    fann_save(ann, configuration_file);
}
/** Saves the entire network to a configuration file in fixed point
    format, regardless of the format it is currently in. This is useful
    for training a network in floating point and later executing it in
    fixed point.

    Returns the bit position of the fix point, which indicates how
    accurate the fixed point network will be: high values mean high
    precision; a negative value signals very low precision and a strong
    possibility of overflow (the actual fix point is then set to 0, since
    a negative fix point does not make sense). Generally a fix point lower
    than 6 is bad and should be avoided — the best remedy is fewer
    connections per neuron or fewer neurons per layer.

    Fixed point execution is only intended for machines without a
    floating point processor (e.g. an iPAQ); on normal computers the
    floating point version is actually faster.

    Returns 0 when no network has been created. */
int save_to_fixed(const char *configuration_file)
{
    if (ann == NULL)
    {
        return 0;
    }
    return fann_save_to_fixed(ann, configuration_file);
}
#ifndef FIXEDFANN
/** Performs one training iteration on a single input / desired-output pair. */
void train(fann_type *input, fann_type *desired_output)
{
    if (ann == NULL)
    {
        return;
    }
    fann_train(ann, input, desired_output);
}
#endif /* NOT FIXEDFANN */
/** Tests the network on one input / desired-output pair. Updates the
    mean square error but does not change the network in any way.
    Returns the network output, or NULL when no network exists. */
fann_type * test(fann_type *input, fann_type *desired_output)
{
    if (ann == NULL)
    {
        return NULL;
    }
    return fann_test(ann, input, desired_output);
}
/** Returns the network's current mean square error (0 if no network). */
float get_MSE()
{
    return (ann == NULL) ? 0.0f : fann_get_MSE(ann);
}
/** Resets the network's accumulated mean square error to zero. */
void reset_MSE()
{
    if (ann == NULL)
    {
        return;
    }
    fann_reset_MSE(ann);
}
#ifndef FIXEDFANN
/** Trains one epoch over the whole training set and returns the
    resulting mean square error (0 without a network or data). */
float train_epoch(const training_data &data)
{
    if (ann == NULL || data.train_data == NULL)
    {
        return 0.0f;
    }
    return fann_train_epoch(ann, data.train_data);
}
/** Tests the whole training set and returns the resulting mean square
    error (0 without a network or data). */
float test_data(const training_data &data)
{
    if (ann == NULL || data.train_data == NULL)
    {
        return 0.0f;
    }
    return fann_test_data(ann, data.train_data);
}
/** Trains on an entire dataset for at most max_epochs epochs, or until
    the mean square error drops below desired_error. A progress report is
    printed every epochs_between_reports epochs (0 = no reports). */
void train_on_data(const training_data &data, unsigned int max_epochs,
    unsigned int epochs_between_reports, float desired_error)
{
    if (ann == NULL || data.train_data == NULL)
    {
        return;
    }
    fann_train_on_data(ann, data.train_data, max_epochs,
        epochs_between_reports, desired_error);
}
/** Same as train_on_data, but reports progress through the supplied
    callback (useful for GUI programs). If the callback returns -1 the
    training is terminated; otherwise it continues until the normal stop
    criteria are met. */
void train_on_data_callback(const training_data &data, unsigned int max_epochs,
    unsigned int epochs_between_reports, float desired_error,
    int (FANN_API *callback)(unsigned int epochs, float error))
{
    if (ann == NULL || data.train_data == NULL)
    {
        return;
    }
    fann_train_on_data_callback(ann, data.train_data, max_epochs,
        epochs_between_reports, desired_error, callback);
}
/** Same as train_on_data, but reads the training data directly from a file. */
void train_on_file(char *filename, unsigned int max_epochs,
    unsigned int epochs_between_reports, float desired_error)
{
    if (ann == NULL)
    {
        return;
    }
    fann_train_on_file(ann, filename, max_epochs,
        epochs_between_reports, desired_error);
}
/** Same as train_on_data_callback, but reads the training data directly
    from a file. */
void train_on_file_callback(char *filename, unsigned int max_epochs,
    unsigned int epochs_between_reports, float desired_error,
    int (FANN_API *callback)(unsigned int epochs, float error))
{
    if (ann == NULL)
    {
        return;
    }
    fann_train_on_file_callback(ann, filename, max_epochs,
        epochs_between_reports, desired_error, callback);
}
#endif /* NOT FIXEDFANN */
/** Prints all parameters and options of the network. */
void print_parameters()
{
    if (ann == NULL)
    {
        return;
    }
    fann_print_parameters(ann);
}
/** Returns the training algorithm currently in use
    (value 0 when no network has been created). */
training_algorithm_enum get_training_algorithm()
{
    if (ann == NULL)
    {
        return static_cast<training_algorithm_enum>(0);
    }
    return static_cast<training_algorithm_enum>(fann_get_training_algorithm(ann));
}
/** Selects the training algorithm. */
void set_training_algorithm(training_algorithm_enum training_algorithm)
{
    if (ann == NULL)
    {
        return;
    }
    fann_set_training_algorithm(ann, training_algorithm);
}
/** Returns the learning rate (0 when no network has been created). */
float get_learning_rate()
{
    return (ann == NULL) ? 0.0f : fann_get_learning_rate(ann);
}
/** Sets the learning rate. */
void set_learning_rate(float learning_rate)
{
    if (ann == NULL)
    {
        return;
    }
    fann_set_learning_rate(ann, learning_rate);
}
/** Returns the activation function of the hidden layers
    (value 0 when no network has been created). */
activation_function_enum get_activation_function_hidden()
{
    if (ann == NULL)
    {
        return static_cast<activation_function_enum>(0);
    }
    return static_cast<activation_function_enum>(fann_get_activation_function_hidden(ann));
}
/** Sets the activation function for the hidden layers. */
void set_activation_function_hidden(activation_function_enum activation_function)
{
    if (ann == NULL)
    {
        return;
    }
    fann_set_activation_function_hidden(ann, activation_function);
}
/** Returns the activation function of the output layer
    (value 0 when no network has been created). */
activation_function_enum get_activation_function_output()
{
    if (ann == NULL)
    {
        return static_cast<activation_function_enum>(0);
    }
    return static_cast<activation_function_enum>(fann_get_activation_function_output(ann));
}
/** Sets the activation function for the output layer. */
void set_activation_function_output(activation_function_enum activation_function)
{
    if (ann == NULL)
    {
        return;
    }
    fann_set_activation_function_output(ann, activation_function);
}
/** Returns the steepness of the hidden layers' activation function
    (0 when no network has been created). */
fann_type get_activation_steepness_hidden()
{
    if (ann == NULL)
    {
        return 0;
    }
    return fann_get_activation_steepness_hidden(ann);
}
/* (end of excerpt — web-viewer keyboard-shortcut help text removed) */