mushroom.py

From "A Powerful Neural Network Analysis Program" · Python code · 43 lines

```python
#!/usr/bin/python
# Train a neural network on the mushroom data set with the fann bindings,
# measure its error on separate test data, and save it to disk.
import fann


def print_callback(epochs, error):
    # Progress reporter (defined here but never registered with the network)
    print("Epochs     %8d. Current MSE-Error: %.10f" % (epochs, error))
    return 0


# initialize network parameters
connection_rate = 1               # 1 = fully connected layers
learning_rate = 0.7
num_neurons_hidden = 32
desired_error = 0.000001          # stop training once the MSE drops to this
max_iterations = 300              # upper bound on training epochs
iterations_between_reports = 1    # report progress after every epoch

# create training data, and ann object
print("Creating network.")
train_data = fann.read_train_from_file("datasets/mushroom.train")
# Three layers: input and output sizes are taken from the training file
ann = fann.create(connection_rate, learning_rate,
                  (train_data.get_num_input(), num_neurons_hidden,
                   train_data.get_num_output()))

# start training the network
print("Training network")
ann.set_activation_function_hidden(fann.SIGMOID_SYMMETRIC_STEPWISE)
ann.set_activation_function_output(fann.SIGMOID_STEPWISE)
ann.set_training_algorithm(fann.TRAIN_INCREMENTAL)  # update weights after each pattern

ann.train_on_data(train_data, max_iterations, iterations_between_reports, desired_error)

# test outcome: accumulate the MSE over every pattern in the test set
print("Testing network")
test_data = fann.read_train_from_file("datasets/mushroom.test")

ann.reset_MSE()
for i in range(test_data.get_num_data()):
    ann.test(test_data.get_input(i), test_data.get_output(i))

print("MSE error on test data: %f" % ann.get_MSE())

# save network to disk
print("Saving network")
ann.save("mushroom_float.net")
```
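The script loads both data sets with `read_train_from_file`. The sketch below shows one way to read such a file by hand, assuming the usual FANN plain-text layout: a header line `num_data num_input num_output`, then, for each pattern, one line of inputs followed by one line of outputs. The function `parse_fann_train_text` is hypothetical and not part of the `fann` module, and the sample data is made up; the contents of `datasets/mushroom.train` are not shown here.

```python
from typing import List, Tuple


def parse_fann_train_text(text: str) -> Tuple[List[List[float]], List[List[float]]]:
    """Hypothetical parser for FANN-style training text (assumed layout:
    header 'num_data num_input num_output', then alternating input/output lines)."""
    tokens = text.split()
    num_data, num_input, num_output = (int(t) for t in tokens[:3])
    values = [float(t) for t in tokens[3:]]
    expected = num_data * (num_input + num_output)
    if len(values) != expected:
        raise ValueError("expected %d values, found %d" % (expected, len(values)))
    inputs, outputs = [], []
    pos = 0
    for _ in range(num_data):
        # Each pattern is num_input input values followed by num_output targets
        inputs.append(values[pos:pos + num_input])
        pos += num_input
        outputs.append(values[pos:pos + num_output])
        pos += num_output
    return inputs, outputs


# Made-up sample: 2 patterns, 3 inputs, 1 output
sample = """2 3 1
1 0 1
1
0 1 0
0
"""
inputs, outputs = parse_fann_train_text(sample)
print(inputs)   # [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
print(outputs)  # [[1.0], [0.0]]
```

Splitting on all whitespace ignores line breaks and relies only on the header counts, which keeps the parser short; the price is that a value placed on the wrong line is not detected, only a wrong total count.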

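`TRAIN_INCREMENTAL` means the weights are adjusted after every single training pattern rather than once per pass over the data, and the test loop (`reset_MSE`, `test`, `get_MSE`) accumulates a mean squared error over the test set. The from-scratch sketch below illustrates both ideas on a tiny input-hidden-output network. It is not the `fann` API: `TinyNet` and its methods are hypothetical, exact `tanh` and logistic functions stand in for FANN's stepwise approximations, the weight initialisation range of [-0.5, 0.5] is an assumption, the toy OR data replaces the mushroom files, and the MSE here is averaged over every output of every pattern (FANN's own normalisation may differ).

```python
import math
import random
from typing import List, Sequence, Tuple


def logistic(x: float) -> float:
    # Numerically stable logistic sigmoid
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


class TinyNet:
    """Hypothetical fully connected net: input -> hidden (tanh) -> output (logistic)."""

    def __init__(self, n_in: int, n_hidden: int, n_out: int, seed: int = 0):
        rng = random.Random(seed)
        # Each row holds one neuron's weights; the last entry is its bias weight
        self.w_h = [[rng.uniform(-0.5, 0.5) for _ in range(n_in + 1)] for _ in range(n_hidden)]
        self.w_o = [[rng.uniform(-0.5, 0.5) for _ in range(n_hidden + 1)] for _ in range(n_out)]

    def forward(self, x: Sequence[float]) -> Tuple[List[float], List[float], List[float]]:
        xb = list(x) + [1.0]  # append constant bias input
        h = [math.tanh(sum(w * v for w, v in zip(row, xb))) for row in self.w_h]
        hb = h + [1.0]
        y = [logistic(sum(w * v for w, v in zip(row, hb))) for row in self.w_o]
        return xb, hb, y

    def mse(self, inputs, targets) -> float:
        """Mean squared error over every output of every pattern."""
        total, count = 0.0, 0
        for x, t in zip(inputs, targets):
            _, _, y = self.forward(x)
            total += sum((ti - yi) ** 2 for ti, yi in zip(t, y))
            count += len(t)
        return total / count

    def train_incremental(self, inputs, targets, max_epochs: int,
                          desired_error: float, learning_rate: float) -> Tuple[int, float]:
        """Online backpropagation: weights change after every single pattern."""
        for epoch in range(1, max_epochs + 1):
            for x, t in zip(inputs, targets):
                xb, hb, y = self.forward(x)
                # Output deltas: error times logistic derivative y * (1 - y)
                d_o = [(ti - yi) * yi * (1.0 - yi) for ti, yi in zip(t, y)]
                # Hidden deltas (computed before w_o changes): tanh derivative 1 - h^2
                d_h = [(1.0 - hb[j] ** 2) * sum(d_o[k] * self.w_o[k][j] for k in range(len(d_o)))
                       for j in range(len(self.w_h))]
                for k, row in enumerate(self.w_o):
                    for j in range(len(row)):
                        row[j] += learning_rate * d_o[k] * hb[j]
                for j, row in enumerate(self.w_h):
                    for i in range(len(row)):
                        row[i] += learning_rate * d_h[j] * xb[i]
            # Stop early once the whole-set error reaches the target
            error = self.mse(inputs, targets)
            if error <= desired_error:
                return epoch, error
        return max_epochs, self.mse(inputs, targets)


# Toy data (logical OR), used only to exercise the sketch
inputs = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
targets = [[0.0], [1.0], [1.0], [1.0]]
net = TinyNet(2, 3, 1, seed=42)
before = net.mse(inputs, targets)
epochs, after = net.train_incremental(inputs, targets, max_epochs=300,
                                      desired_error=0.000001, learning_rate=0.7)
print("epochs: %d, MSE before: %.4f, after: %.4f" % (epochs, before, after))
```

The parameters mirror the script's `max_iterations`, `desired_error` and `learning_rate`, so the sketch stops either after 300 epochs or once the error target is met, much as `train_on_data` is asked to do above.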