⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 fann_trainer.py

📁 python 神经网络 数据挖掘 python实现的神经网络算法
💻 PY
字号:
#!/usr/bin/python
import os
import sys
from pyfann import libfann
import signal, random
from optparse import OptionParser
# Seed Python's global random generator at import time. Called without an
# argument, so the seed comes from OS randomness when available, otherwise from
# the current time (see the `random.seed` documentation).
# NOTE(review): nothing in the visible part of this file uses `random`
# afterwards -- confirm whether this call is still needed.
random.seed()

class Trainer:
    """Train and evaluate a FANN network (through pyfann's ``libfann``).

    Typical use, as done by the script entry point of this file::

        trainer = Trainer(save_path)
        trainer.load_data(train_file, val_file)
        trainer.create_ann(load_path)
        trainer.train()            # runs until interrupted with Ctrl-C
        mse, bit_fail = trainer.test()

    Samples are split into "low" and "high" by comparing the FIRST target
    value of each sample against 0.5; only that first output is inspected
    when counting classification errors.
    """

    def __init__(self, network_save_file):
        """Remember the path the best network found by `train` is saved to."""
        self.network_save_file = network_save_file

    @staticmethod
    def _count_lo_hi(targets):
        """Return ``[n_low, n_high]`` for a sequence of target rows.

        A row is "low" when its first value is < 0.5 and "high" when it is
        > 0.5; a first value of exactly 0.5 is counted in neither bucket.
        """
        low = sum(1 for row in targets if row[0] < 0.5)
        high = sum(1 for row in targets if row[0] > 0.5)
        return [low, high]

    def load_data(self, data_file, val_file=None):
        """Load the training set and, if available, a validation set.

        Sets ``train_data``, ``dim_input``, ``dim_output``, ``val_data`` and
        ``do_validation``.  When `val_file` is not given or does not exist,
        ``val_data`` aliases the training data and ``do_validation`` is False.

        data_file -- path of a FANN training-data file.
        val_file  -- optional path of a FANN training-data file used for
                     validation.
        """
        print("Loading data")
        self.train_data = libfann.training_data()
        self.train_data.read_train_from_file(data_file)
        self.dim_input = self.train_data.num_input_train_data()
        self.dim_output = self.train_data.num_output_train_data()

        train_counts = self._count_lo_hi(self.train_data.get_output())
        print("\t Train data is %d low and %d high" % tuple(train_counts))

        if val_file and os.path.exists(val_file):
            print("Loading validation data")
            self.do_validation = True
            self.val_data = libfann.training_data()
            self.val_data.read_train_from_file(val_file)
            val_counts = self._count_lo_hi(self.val_data.get_output())
            print("\t Validation data is %d low and %d high" % tuple(val_counts))
        else:
            if val_file:
                # A mistyped path used to be ignored silently; say so.
                print("\t Validation file %s not found, validation disabled" % val_file)
            self.val_data = self.train_data
            self.do_validation = False

    def create_ann(self, network_load_file=None, connection_rate=1, learning_rate=0.2, num_neurons_hidden=15, config_file=None):
        """Create ``self.ann``, either loaded from a file or built from scratch.

        network_load_file  -- saved network to load; when missing or not
                              existing, a new three-layer network (input,
                              one hidden layer, output) is created instead.
                              `load_data` must have run first in that case,
                              because the layer sizes come from the data.
        connection_rate    -- connection rate passed to create_sparse_array.
        learning_rate      -- learning rate set on the network.
        num_neurons_hidden -- size of the hidden layer of a new network.
        config_file        -- optional Python file executed after the three
                              values above are stored on ``self``; it runs
                              with this method's globals and locals, so it
                              can override them through ``self``.
        """
        self.connection_rate = connection_rate
        self.learning_rate = learning_rate
        self.num_neurons_hidden = num_neurons_hidden

        if config_file:
            # SECURITY: the config file is executed as arbitrary Python code.
            # Only pass files you trust.  exec(compile(...)) is the
            # Python 2/3 portable equivalent of execfile(config_file).
            with open(config_file) as handle:
                config_source = handle.read()
            exec(compile(config_source + "\n", config_file, "exec"), globals(), locals())

        self.ann = libfann.neural_net()

        if network_load_file and os.path.exists(network_load_file):
            print("Loading network")
            self.ann.create_from_file(network_load_file)
        else:
            print("Creating network")
            layers = (self.dim_input, self.num_neurons_hidden, self.dim_output)
            self.ann.create_sparse_array(self.connection_rate, layers)

            # Other activation functions that were tried here:
            # SIGMOID_SYMMETRIC, SIGMOID_STEPWISE, SIGMOID_SYMMETRIC_STEPWISE.
            self.ann.set_activation_function_hidden(libfann.SIGMOID)
            self.ann.set_activation_function_output(libfann.SIGMOID)

            # Alternative error function: ERRORFUNC_LINEAR.
            self.ann.set_train_error_function(libfann.ERRORFUNC_TANH)

            # Alternative algorithms: TRAIN_INCREMENTAL, TRAIN_QUICKPROP.
            self.ann.set_training_algorithm(libfann.TRAIN_RPROP)

            # The message is kept for output compatibility, but no explicit
            # initialisation happens: both init_weights(self.train_data) and
            # randomize_weights(-.10, .10) were disabled by the author.
            print("initializing weights")

        # Applied to loaded and newly created networks alike.
        self.ann.set_learning_rate(self.learning_rate)
        self.ann.set_rprop_increase_factor(1.2)
        self.ann.set_rprop_decrease_factor(0.5)
        self.ann.set_rprop_delta_min(0)
        self.ann.set_rprop_delta_max(50)
        self.ann.print_parameters()

    def train(self, desired_error=0.000001,
              max_iterations=1000, iterations_between_reports=10, error_center=0.5):
        """Train epoch by epoch until interrupted with Ctrl-C.

        Each time the tracked error (validation error when validation data
        is loaded, training error otherwise) drops below the best value seen
        so far, the network is saved to ``self.network_save_file``.  On
        KeyboardInterrupt the current network is saved to
        ``self.network_save_file + ".current"`` and the method returns.

        desired_error, max_iterations -- stored on ``self`` only.
            NOTE(review): neither is used as a stop condition; the loop ends
            only on KeyboardInterrupt.  Confirm whether that is intended.
        iterations_between_reports -- print a progress line every this many
            epochs; 0 disables the reports.
        error_center -- bit-fail limit set on the network and output
            threshold passed to `test`.
        """
        self.desired_error = desired_error
        self.max_iterations = max_iterations
        self.iterations_between_reports = iterations_between_reports
        self.error_center = error_center

        self.ann.set_bit_fail_limit(self.error_center)

        if self.do_validation:
            val_input = self.val_data.get_input()
            val_target = self.val_data.get_output()

        print("Training network")
        epoch = 0
        last_best = 0          # epoch at which the best network was saved
        last_best_err = -1     # replaced by the real error after epoch 0
        best_bit_fail = "N/A"
        err_tr = "N/A"         # no training error exists before epoch 0
        err = 0
        exit_cond = False
        while not exit_cond:
            try:
                if self.do_validation:
                    err = self.ann.test_data(self.val_data)

                if err < last_best_err:
                    # Use the instance attribute: a module-level
                    # `network_save_file` only exists when this file is run
                    # as a script.
                    self.ann.save(self.network_save_file)
                    last_best = epoch
                    last_best_err = err
                    if self.do_validation:
                        err, best_bit_fail = self.test(val_input, val_target, self.error_center)

                # The truthiness guard avoids ZeroDivisionError for 0.
                if self.iterations_between_reports and epoch % self.iterations_between_reports == 0:
                    if self.do_validation:
                        print("Epochs: %d, CTErr: %s, CVErr: %f, BVErr: %f, BitF[#err,[#elo,#ehi]]: %s, LSEp: %d" %
                              (epoch, str(err_tr), err, last_best_err, str(best_bit_fail), last_best))
                    else:
                        # err_tr is still the "N/A" placeholder at epoch 0;
                        # formatting it with %f would raise TypeError.
                        if isinstance(err_tr, str):
                            train_err_text = err_tr
                        else:
                            train_err_text = "%f" % err_tr
                        print("Epochs: %d, CTErr: %s, BitF: %d" %
                              (epoch, train_err_text, self.ann.get_bit_fail()))

                err_tr = self.ann.train_epoch(self.train_data)
                if not self.do_validation:
                    err = err_tr
                if epoch == 0:
                    # Baseline for the "did we improve" comparison above.
                    last_best_err = err

                epoch += 1

            except KeyboardInterrupt:
                self.ann.save(self.network_save_file + ".current")
                exit_cond = True

    def test(self, input=None, target=None, error_center=0.5):
        """Run the network over a data set and count misclassifications.

        input, target -- sequences of input rows and target rows; when
            either is missing or empty, the data in ``self.val_data`` is
            used instead.
        error_center  -- threshold applied to the first network output.

        A "low" sample (first target < 0.5) is an error when the first
        output is > error_center; a "high" sample (first target > 0.5) is an
        error when the first output is < error_center.

        Returns ``(mse, [total_errors, [low_errors, high_errors]])`` where
        ``mse`` is the value of ``self.ann.get_MSE()`` after the run.
        """
        if not input or not target:
            input = self.val_data.get_input()
            target = self.val_data.get_output()

        errors_lo_hi = [0, 0]
        self.ann.reset_MSE()
        for index, sample in enumerate(input):
            expected = target[index]
            out = self.ann.test(sample, expected)
            if expected[0] < 0.5 and out[0] > error_center:
                errors_lo_hi[0] += 1
            elif expected[0] > 0.5 and out[0] < error_center:
                errors_lo_hi[1] += 1

        return self.ann.get_MSE(), [sum(errors_lo_hi), errors_lo_hi]
        #return  tot_err/len(input), [sum(errors_lo_hi), errors_lo_hi, data_lo_hi]



def search_file(filename):
    """Return True when `filename` is non-empty and names an existing path.

    When a non-empty name is given but nothing exists at that path, a
    "could not be found" message is printed before returning False.  An
    empty or None `filename` returns False without printing anything.
    """
    if filename and os.path.exists(filename):
        return True
    if filename:
        # Call form of print: same output on Python 2, valid on Python 3.
        print("""\nFile %s could not be found\n""" % filename)
    return False

if __name__ == "__main__":
    # Command-line entry point: load data, build or load a network, train it
    # (unless --just-test is given) and print the test result.
    #
    # NOTE: the option values below are deliberately bound as module-level
    # names; other code in this file may read them as globals.

    parser = OptionParser("usage: %prog -s net_file -d data_file [-l net_load_file]")

    parser.add_option("-d", "--data-file", type="string", dest="data_file",
                      help="Read data from the given file", default=None)

    parser.add_option("-l", "--network-load", type="string", dest="network_load_file",
                      help="Read fann_network data from the given file", default=None)

    parser.add_option("-s", "--network-save", type="string", dest="network_save_file",
                      help="Save fann_network data to the given file", default=None)

    parser.add_option("-c", "--config-file", type="string", dest="config_file",
                      help="Load config from the given file", default=None)

    parser.add_option("-v", "--validation-data", type="string", dest="val_file",
                      help="Read validation data from the given file", default=None)

    parser.add_option("-t", "--just-test", action="store_true", dest="just_test",
                      help="Do not train", default=False)

    parser.add_option("-r", "--reset-file", action="store_true", dest="reset_file",
                      help="Do not load the network from save train file", default=False)

    (options, args) = parser.parse_args(sys.argv[1:])

    data_file = options.data_file
    network_save_file = options.network_save_file
    network_load_file = options.network_load_file
    val_file = options.val_file
    just_test = options.just_test
    reset_file = options.reset_file

    # Without an explicit -l, resume from the save file unless -r was given.
    if not network_load_file and not reset_file:
        network_load_file = network_save_file

    # A readable data file is always required.  In addition either a save
    # file must be named, or (for --just-test) the network to load must exist.
    have_network = network_save_file or (just_test and search_file(network_load_file))
    if not (search_file(data_file) and have_network):
        parser.print_help()
        print("%s %s %s" % (data_file, network_save_file, network_load_file))
        # parser.error() prints the message and exits with status 2, so
        # nothing after it runs (the former sys.exit(1) was unreachable).
        parser.error("Incorrect number of arguments" + str(options))

    trainer = Trainer(network_save_file)
    trainer.load_data(data_file, val_file)
    trainer.create_ann(network_load_file, config_file=options.config_file)
    if not just_test:
        trainer.train()
    print("Testing network")
    # Evaluate the saved (best) network when there is one.  Previously the
    # network was re-created unconditionally, so a missing save file (e.g.
    # "-t -l net" without -s, or a run that never improved) silently replaced
    # the loaded/trained network with a fresh untrained one before testing.
    if network_save_file and os.path.exists(network_save_file):
        trainer.create_ann(network_save_file, config_file=options.config_file)
    err, bf = trainer.test()
    print("%s %s %s" % ("Errors:  ", bf, err))
    sys.exit(0)

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -