📄 mushroom.py
字号:
#!/usr/bin/python
"""Train a FANN network on the mushroom dataset, evaluate it, and save it.

Reads ``datasets/mushroom.train`` and ``datasets/mushroom.test`` (paths are
relative to the current working directory), trains a three-layer network,
prints the MSE on the test set, and writes the trained network to
``mushroom_float.net``.
"""
import fann

# initialize network parameters (module-level so they can be tweaked/imported)
connection_rate = 1          # passed straight to fann.create
learning_rate = 0.7          # passed straight to fann.create
num_neurons_hidden = 32      # size of the single hidden layer
desired_error = 0.000001     # passed to train_on_data (presumably the MSE stop target)
max_iterations = 300         # passed to train_on_data (presumably the epoch cap)
iterations_between_reports = 1


def print_callback(epochs, error):
    """Print one training-progress line and return 0.

    :param epochs: number of epochs completed so far.
    :param error: current mean-squared error.
    :returns: 0.

    NOTE(review): this callback is not registered with the network in
    ``main()`` -- ``train_on_data`` is called without it. It is kept so
    existing importers still find it.
    """
    # No trailing "\n": print already terminates the line, and the extra
    # newline (carried over from a C printf format) produced blank lines.
    print("Epochs %8d. Current MSE-Error: %.10f" % (epochs, error))
    return 0


def main():
    """Create, train, test and save the mushroom network."""
    # create training data, and ann object
    print("Creating network.")
    train_data = fann.read_train_from_file("datasets/mushroom.train")
    # Layer sizes: input and output widths come from the training file.
    ann = fann.create(
        connection_rate,
        learning_rate,
        (train_data.get_num_input(), num_neurons_hidden,
         train_data.get_num_output()),
    )

    # start training the network
    print("Training network")
    ann.set_activation_function_hidden(fann.SIGMOID_SYMMETRIC_STEPWISE)
    ann.set_activation_function_output(fann.SIGMOID_STEPWISE)
    ann.set_training_algorithm(fann.TRAIN_INCREMENTAL)
    ann.train_on_data(train_data, max_iterations, iterations_between_reports,
                      desired_error)

    # test outcome
    print("Testing network")
    test_data = fann.read_train_from_file("datasets/mushroom.test")
    # Reset before the loop so get_MSE() below reflects only the test samples
    # (presumably ann.test() accumulates into the MSE -- confirm in the binding).
    ann.reset_MSE()
    for i in range(test_data.get_num_data()):
        ann.test(test_data.get_input(i), test_data.get_output(i))
    print("MSE error on test data: %f" % ann.get_MSE())

    # save network to disk
    print("Saving network")
    ann.save("mushroom_float.net")


if __name__ == "__main__":
    main()
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -