⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 grid.py

📁 library for SVMclassification and regression. It solves C-SVM classification, nu-SVM classification
💻 PY
字号:
#!/usr/bin/env pythonimport os, sys, tracebackimport Queueimport getpassfrom threading import Threadfrom string import find, split, joinfrom subprocess import *# svmtrain and gnuplot executableis_win32 = (sys.platform == 'win32')if not is_win32:       svmtrain_exe = "../svm-train"       gnuplot_exe = "/usr/bin/gnuplot"else:       # example for windows       svmtrain_exe = r"..\windows\svmtrain.exe"       gnuplot_exe = r"c:\tmp\gnuplot\bin\pgnuplot.exe"# global parameters and their default valuesfold = 5c_begin, c_end, c_step = -5,  15, 2g_begin, g_end, g_step =  3, -15, -2global dataset_pathname, dataset_title, pass_through_stringglobal out_filename, png_filename# experimentaltelnet_workers = []ssh_workers = []nr_local_worker = 1# process command line options, set global parametersdef process_options(argv=sys.argv):    global fold    global c_begin, c_end, c_step    global g_begin, g_end, g_step    global dataset_pathname, dataset_title, pass_through_string    global svmtrain_exe, gnuplot_exe, gnuplot, out_filename, png_filename        usage = """\Usage: grid.py [-log2c begin,end,step] [-log2g begin,end,step] [-v fold] [-svmtrain pathname] [-gnuplot pathname] [-out pathname] [-png pathname][additional parameters for svm-train] dataset"""    if len(argv) < 2:        print usage        sys.exit(1)    dataset_pathname = argv[-1]    dataset_title = os.path.split(dataset_pathname)[1]    out_filename = '%s.out' % dataset_title    png_filename = '%s.png' % dataset_title    pass_through_options = []    i = 1    while i < len(argv) - 1:        if argv[i] == "-log2c":            i = i + 1            (c_begin,c_end,c_step) = map(float,split(argv[i],","))        elif argv[i] == "-log2g":            i = i + 1            (g_begin,g_end,g_step) = map(float,split(argv[i],","))        elif argv[i] == "-v":            i = i + 1            fold = argv[i]        elif argv[i] in ('-c','-g'):            print "Option -c and -g are renamed."            print usage            sys.exit(1)        elif argv[i] == '-svmtrain':            i = i + 1            svmtrain_exe = argv[i]        elif argv[i] == '-gnuplot':            i = i + 1            gnuplot_exe = argv[i]        elif argv[i] == '-out':            i = i + 1            out_filename = argv[i]        elif argv[i] == '-png':            i = i + 1            png_filename = argv[i]        else:            pass_through_options.append(argv[i])        i = i + 1    pass_through_string = join(pass_through_options," ")    assert os.path.exists(svmtrain_exe),"svm-train executable not found"        assert os.path.exists(gnuplot_exe),"gnuplot executable not found"    assert os.path.exists(dataset_pathname),"dataset not found"    gnuplot = Popen(gnuplot_exe,stdin = PIPE).stdindef range_f(begin,end,step):    # like range, but works on non-integer too    seq = []    while 1:        if step > 0 and begin > end: break        if step < 0 and begin < end: break        seq.append(begin)        begin = begin + step    return seqdef permute_sequence(seq):    n = len(seq)    if n <= 1: return seq    mid = int(n/2)    left = permute_sequence(seq[:mid])    right = permute_sequence(seq[mid+1:])    ret = [seq[mid]]    while left or right:        if left: ret.append(left.pop(0))        if right: ret.append(right.pop(0))    return retdef redraw (db,tofile=0):    if len(db) == 0: return    begin_level = round(max(map(lambda(x):x[2],db))) - 3    step_size = 0.5    if tofile:        gnuplot.write("set term png transparent small\n")        gnuplot.write("set output \"%s\"\n" % png_filename.replace('\\','\\\\'))        #gnuplot.write("set term postscript color solid\n")        #gnuplot.write("set output \"%s.ps\"\n" % dataset_title)    else:        if is_win32:            gnuplot.write("set term windows\n")        else:            gnuplot.write("set term x11\n")    gnuplot.write("set xlabel \"lg(C)\"\n")    gnuplot.write("set ylabel \"lg(gamma)\"\n")    gnuplot.write("set xrange [%s:%s]\n" % (c_begin,c_end))    gnuplot.write("set yrange [%s:%s]\n" % (g_begin,g_end))    gnuplot.write("set contour\n")    gnuplot.write("set cntrparam levels incremental %s,%s,100\n" % (begin_level,step_size))    gnuplot.write("set nosurface\n")    gnuplot.write("set view 0,0\n")    gnuplot.write("set label \"%s\" at screen 0.4,0.9\n" % dataset_title)    gnuplot.write("splot \"-\" with lines\n")    def cmp (x,y):        if x[0] < y[0]: return -1        if x[0] > y[0]: return 1        if x[1] > y[1]: return -1        if x[1] < y[1]: return 1        return 0    db.sort(cmp)    prevc = db[0][0]    for line in db:        if prevc != line[0]:            gnuplot.write("\n")            prevc = line[0]        gnuplot.write("%s %s %s\n" % line)    gnuplot.write("e\n")    gnuplot.flush()def calculate_jobs():    c_seq = permute_sequence(range_f(c_begin,c_end,c_step))    g_seq = permute_sequence(range_f(g_begin,g_end,g_step))    nr_c = float(len(c_seq))    nr_g = float(len(g_seq))    i = 0    j = 0    jobs = []    while i < nr_c or j < nr_g:        if i/nr_c < j/nr_g:            # increase C resolution            line = []            for k in range(0,j):                line.append((c_seq[i],g_seq[k]))            i = i + 1            jobs.append(line)        else:            # increase g resolution            line = []            for k in range(0,i):                line.append((c_seq[k],g_seq[j]))            j = j + 1            jobs.append(line)    return jobsclass WorkerStopToken:  # used to notify the worker to stop        passclass Worker(Thread):    def __init__(self,name,job_queue,result_queue):        Thread.__init__(self)        self.name = name        self.job_queue = job_queue        self.result_queue = result_queue    def run(self):        while 1:            (cexp,gexp) = self.job_queue.get()            if cexp is WorkerStopToken:                self.job_queue.put((cexp,gexp))                # print 'worker %s stop.' % self.name                break            try:                rate = self.run_one(2.0**cexp,2.0**gexp)                if rate is None: raise "get no rate"            except:                # we failed, let others do that and we just quit                traceback.print_tb(sys.exc_traceback)                self.job_queue.put((cexp,gexp))                print 'worker %s quit.' % self.name                break            else:                self.result_queue.put((self.name,cexp,gexp,rate))class LocalWorker(Worker):    def run_one(self,c,g):        cmdline = '%s -c %s -g %s -v %s %s %s' % \          (svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)        result = Popen(cmdline,shell=True,stdout=PIPE).stdout        for line in result.readlines():            if find(line,"Cross") != -1:                return float(split(line)[-1][0:-1])class SSHWorker(Worker):    def __init__(self,name,job_queue,result_queue,host):        Worker.__init__(self,name,job_queue,result_queue)        self.host = host        self.cwd = os.getcwd()    def run_one(self,c,g):        cmdline = 'ssh -x %s "cd %s; %s -c %s -g %s -v %s %s %s"' % \          (self.host,self.cwd,           svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)        result = Popen(cmdline,shell=True,stdout=PIPE).stdout        for line in result.readlines():            if find(line,"Cross") != -1:                return float(split(line)[-1][0:-1])class TelnetWorker(Worker):    def __init__(self,name,job_queue,result_queue,host,username,password):        Worker.__init__(self,name,job_queue,result_queue)        self.host = host        self.username = username        self.password = password            def run(self):        import telnetlib        self.tn = tn = telnetlib.Telnet(self.host)        tn.read_until("login: ")        tn.write(self.username + "\n")        tn.read_until("Password: ")        tn.write(self.password + "\n")        # XXX: how to know whether login is successful?        tn.read_until(self.username)        #         print 'login ok', self.host        tn.write("cd "+os.getcwd()+"\n")        Worker.run(self)        tn.write("exit\n")                   def run_one(self,c,g):        cmdline = '%s -c %s -g %s -v %s %s %s' % \          (svmtrain_exe,c,g,fold,pass_through_string,dataset_pathname)        result = self.tn.write(cmdline+'\n')        (idx,matchm,output) = self.tn.expect(['Cross.*\n'])        for line in split(output,'\n'):            if find(line,"Cross") != -1:                return float(split(line)[-1][0:-1])def main():    # set parameters    process_options()    # put jobs in queue    jobs = calculate_jobs()    job_queue = Queue.Queue(0)    result_queue = Queue.Queue(0)    for line in jobs:        for (c,g) in line:            job_queue.put((c,g))    # hack the queue to become a stack --    # this is important when some thread    # failed and re-put a job. If we still    # use FIFO, the job will be put    # into the end of the queue, and the graph    # will only be updated in the end    def _put(self,item):        if sys.hexversion >= 0x020400A1:            self.queue.appendleft(item)        else:            self.queue.insert(0,item)    import new    job_queue._put = new.instancemethod(_put,job_queue,job_queue.__class__)    # fire telnet workers    if telnet_workers:        nr_telnet_worker = len(telnet_workers)        username = getpass.getuser()        password = getpass.getpass()        for host in telnet_workers:            TelnetWorker(host,job_queue,result_queue,                     host,username,password).start()    # fire ssh workers    if ssh_workers:        for host in ssh_workers:            SSHWorker(host,job_queue,result_queue,host).start()    # fire local workers    for i in range(nr_local_worker):        LocalWorker('local',job_queue,result_queue).start()    # gather results    done_jobs = {}    result_file = open(out_filename,'w',0)    db = []    best_rate = -1    for line in jobs:        for (c,g) in line:            while not done_jobs.has_key((c,g)):                (worker,c1,g1,rate) = result_queue.get()                done_jobs[(c1,g1)] = rate                result_file.write('%s %s %s\n' %(c1,g1,rate))                result_file.flush()                print "[%s] %s %s %s" % (worker,c1,g1,rate),                if (rate > best_rate) or (rate==best_rate and g1==best_g1 and c1<best_c1):                    best_rate = rate                    best_c1,best_g1=c1,g1                    best_c = 2.0**c1                    best_g = 2.0**g1                print " (best c=%s, g=%s, rate=%s)" % \                    (best_c, best_g, best_rate)            db.append((c,g,done_jobs[(c,g)]))        redraw(db)        redraw(db,1)    job_queue.put((WorkerStopToken,None))    print "%s %s %s" % (best_c, best_g, best_rate)main()

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -