"""
<name>Test Learners</name>
<description>Estimates the predictive performance of learners on a data set.</description>
<icon>icons/TestLearners.png</icon>
<contact>Blaz Zupan (blaz.zupan(@at@)fri.uni-lj.si)</contact> 
<priority>200</priority>
"""
#
# OWTestLearners.py
#

from qttable import *
from OWWidget import *
import orange, orngTest, orngStat, OWGUI
import sys
import warnings
warnings.filterwarnings("ignore", "'id' is not a builtin attribute", orange.AttributeWarning)

##############################################################################

class OWTestLearners(OWWidget):
    settingsList = ["sampleMethod", "nFolds", "pLearning", "useStat", "pRepeat", "precision"]
    callbackDeposit = []

    stat = ( ('Classification Accuracy', 'CA', 'CA(res)'),
             ('Sensitivity', 'Sens', 'sens(cm)'),
             ('Specificity', 'Spec', 'spec(cm)'),
             ('Area Under ROC Curve', 'AUC', 'AUC(res)'),
             ('Information Score', 'IS', 'IS(res)'),
             ('Brier Score', 'Brier', 'BrierScore(res)')
           )
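    # Each stat entry is a (full name, column label, expression) triple; the expression
    # string is evaluated in test() via eval('orngStat.' + expr), where 'res' holds the
    # orngTest results and 'cm' the per-learner confusion matrices, so e.g. 'CA(res)'
    # becomes orngStat.CA(res) and returns one score per learner.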
    
    def __init__(self,parent=None, signalManager = None):
        OWWidget.__init__(self, parent, signalManager, "TestLearners")
        
        self.inputs = [("Data", ExampleTableWithClass, self.cdata, Default), ("Separate Test Data", ExampleTableWithClass, self.testdata), ("Learner", orange.Learner, self.learner, Multiple)]
        self.outputs = [("Evaluation Results", orngTest.ExperimentResults)]

        # Settings
        self.sampleMethod = 0           # cross validation
        self.nFolds = 5                 # cross validation folds
        self.pLearning = 70   # size of learning set when sampling [%]
        self.useStat = [1] * len(self.stat)
        self.pRepeat = 10
        self.precision = 4
        self.loadSettings()
        
        self.data = None                # input data set
        self.testdata = None            # separate test data set
        self.learners = None            # set of learners (input)
        self.results = None             # from orngTest
        self.scores = None              # to be displayed in the table

        # GUI
        self.s = [None] * 5
        self.sBox = QVButtonGroup("Sampling", self.controlArea)        
        self.s[0] = QRadioButton('Cross Validation', self.sBox)

        box = QHBox(self.sBox)
        QWidget(box).setFixedSize(19, 8)
        OWGUI.spin(box, self, 'nFolds', 2, 100, step=1, label='Number of Folds:  ')

        self.s[1] = QRadioButton('Leave-One-Out', self.sBox)
        self.s[2] = QRadioButton('Random Sampling', self.sBox)

        box = QHBox(self.sBox)
        QWidget(box).setFixedSize(19, 8)
        OWGUI.spin(box, self, 'pRepeat', 1, 100, step=1, label='Repeat Train/Test:  ')

        self.h2Box = QHBox(self.sBox)
        QWidget(self.h2Box).setFixedSize(19, 8)
        QLabel("Relative Training Set Size:", self.h2Box)
        box = QHBox(self.sBox)
        QWidget(box).setFixedSize(19, 8)
        OWGUI.hSlider(box, self, 'pLearning', minValue=10, maxValue=100, step=1, ticks=10, labelFormat="   %d%%")        

        self.s[3] = QRadioButton('Test on Train Data', self.sBox)        
        self.s[4] = self.testDataBtn = QRadioButton('Test on Test Data', self.sBox)

        QWidget(self.sBox).setFixedSize(0, 8)
        self.applyBtn = QPushButton("&Apply", self.sBox)
        self.applyBtn.setDisabled(TRUE)

        if self.sampleMethod == 4:
            self.sampleMethod = 0
        self.s[self.sampleMethod].setChecked(TRUE)        
        OWGUI.separator(self.controlArea)
        
        # statistics
        self.statBox = QVGroupBox(self.controlArea)
        self.statBox.setTitle('Statistics')
        self.statBtn = []
        for i in range(len(self.stat)):
            self.statBtn.append(QCheckBox(self.stat[i][0], self.statBox))
            self.statBtn[i].setChecked(self.useStat[i])

        # table with results
        self.layout=QVBoxLayout(self.mainArea)
        self.g = QVGroupBox(self.mainArea)
        self.g.setTitle('Evaluation Results')

        self.tab=QTable(self.g)
        self.tab.setSelectionMode(QTable.NoSelection)
        self.layout.add(self.g)

        self.lab = QLabel(self.g)
            
        # signals
        self.connect(self.applyBtn, SIGNAL("clicked()"), self.test)
        self.dummy1 = [None]*len(self.s)
        for i in range(len(self.s)):
            self.dummy1[i] = lambda x, v=i: self.sChanged(x, v)
            self.connect(self.s[i], SIGNAL("toggled(bool)"), self.dummy1[i])
        self.dummy2 = [None]*len(self.stat)
        for i in range(len(self.stat)):
            self.dummy2[i] = lambda x, v=i: self.statChanged(x, v)
            self.connect(self.statBtn[i], SIGNAL('toggled(bool)'), self.dummy2[i])

        self.resize(600,400)

    # test() evaluates the learners on a single data set.
    # If a learner is given, it is either a new or an updated (old) learner whose
    # results should be added to, or recomputed in, the existing results.
    # If learner=None, all results are recomputed (user pressed the Apply button).
    def test(self, learner=None):
        if not self.data:
            return
        if learner:
            learners = [learner]
        else:
            learners = self.learners

        if self.sampleMethod==4 and not self.testdata:
            self.results = None
            self.setStatTable() # makes table with results empty
            return

        pb = None
        if self.sampleMethod==0:
            pb = ProgressBar(self, iterations=self.nFolds)
            res = orngTest.crossValidation(learners, self.data, folds=self.nFolds, strat=orange.MakeRandomIndices.StratifiedIfPossible, callback=pb.advance, storeExamples = True)
        elif self.sampleMethod==1:
            pb = ProgressBar(self, iterations=len(self.data))
            res = orngTest.leaveOneOut(learners, self.data, callback=pb.advance, storeExamples = True)
        elif self.sampleMethod==2:
            pb = ProgressBar(self, iterations=self.pRepeat)
            res = orngTest.proportionTest(learners, self.data, self.pLearning/100., times=self.pRepeat, callback=pb.advance)
        elif self.sampleMethod==3:
            res = orngTest.learnAndTestOnLearnData(learners, self.data)
        elif self.sampleMethod==4:                
            res = orngTest.learnAndTestOnTestData(learners, self.data, self.testdata)

        cm = orngStat.computeConfusionMatrices(res, classIndex = self.classindex)
        cdt = orngStat.computeCDT(res, classIndex = self.classindex)

        # merging of results and scores (if necessary)
        if self.results and learner:
            if learner.id not in [l.id for l in self.learners]:
                # this is a new learner, add new results
                self.results.classifierNames.append(learner.name)
                self.results.numberOfLearners += 1
                for i,r in enumerate(self.results.results):
                    r.classes.append(res.results[i].classes[0])
                    r.probabilities.append(res.results[i].probabilities[0])
                for (i, stat) in enumerate(self.stat):
                    try:
                        self.scores[i].append(eval('orngStat.' + stat[2])[0])
                    except:
                        self.scores[i].append(-1) # handle the exception
                        type, val, traceback = sys.exc_info()
                        sys.excepthook(type, val, traceback)  # print the exception
                        self.error("Caught an exception while evaluating classifiers")
            else:
                # this is an old but updated learner
                indx = [l.id for l in self.learners].index(learner.id)
                self.results.classifierNames[indx] = learner.name
                for i,r in enumerate(self.results.results):
                    r.classes[indx] = res.results[i].classes[0]
                    r.probabilities[indx] = res.results[i].probabilities[0]
                for (i, stat) in enumerate(self.stat):
                    try:
                        self.scores[i][indx] = eval('orngStat.' + stat[2])[0]
                    except:
                        self.scores[i][indx] = -1
                        type, val, traceback = sys.exc_info()
                        sys.excepthook(type, val, traceback)  # print the exception
                        self.error("Caught an exception while evaluating classifiers")
                    
        else: # test on all learners, or on the new learner with no other learners in the memory
            self.results = res
            self.scores = []
            for i in range(len(self.stat)):
                try:
                    self.scores.append(eval('orngStat.' + self.stat[i][2]))
                except:
                    self.scores.append([-1 for c in range(self.results.numberOfLearners)]) # flag failed scores with -1
                    type, val, traceback = sys.exc_info()
                    sys.excepthook(type, val, traceback)  # print the exception
                    self.error("Caught an exception while evaluating classifiers")

        # update the tables that show the results
        self.setStatTable()
        self.send("Evaluation Results", self.results)
        if pb: pb.finish()

#        except Exception, msg:
#            QMessageBox.critical(self, self.title + ": Execution error", "Error while testing: '%s'" % msg)

    # slots: handle input signals
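    # Rough summary of the signal flow handled below (descriptive comment only):
    #   cdata(data)          -> stores the class-filtered data and re-runs test() if learners are present
    #   testdata(data)       -> stores the separate test set; re-tests when 'Test on Test Data' is selected
    #   learner(learner, id) -> adds, updates or removes a learner; test(learner) merges
    #                           its scores into the existing results by learner id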
    def cdata(self, data):
        if not data:
            self.data = None
            self.results = None
            self.setStatTable()
            return

##        if self.testdata and data.domain <> self.testdata.domain:
##            self.testdata = None
##            self.results = None
##            # self.setStatTable()
        self.data = orange.Filter_hasClassValue(data)
        self.classindex = 0 # data.targetValIndx
        if self.learners:
            self.applyBtn.setDisabled(FALSE)
            self.results = None; self.scores = None
            self.test()

    def testdata(self, data):
        self.testdata = data
        if self.sampleMethod == 4:
            self.test()

    def learner(self, learner, id=None):
        if learner: # a new or updated learner
            # print 'ADD/UPD', learner.name, ", id:", id
            learner.id = id # remember id's of learners
            self.test(learner)
            if self.learners:
                if id not in [l.id for l in self.learners]:
                    self.learners.append(learner)
            else:
                self.learners = [learner]
            self.applyBtn.setDisabled(FALSE)
        else: # remove a learner and corresponding results
            # print 'REMOVE', id, 'FROM', self.learners
            ids = [l.id for l in self.learners]
            if id not in ids:
                return                  # can happen if the learner widget clears its output before this widget saw the learner
            indx = ids.index(id)

            if self.results:
                del self.results.classifierNames[indx]
                self.results.numberOfLearners -= 1
                for i, r in enumerate(self.results.results):
                    del r.classes[indx]
                    del r.probabilities[indx]
                for (i, stat) in enumerate(self.stat):
                    del self.scores[i][indx]
                self.setStatTable()
                self.send("Evaluation Results", self.results)

            del self.learners[indx]
            
    # signal processing
    def statChanged(self, value, id):
        self.useStat[id] = value
        if value:
            self.tab.showColumn(id+1)
            self.tab.adjustColumn(id+1)
        else:
            self.tab.hideColumn(id+1)

    def sChanged(self, value, id):
        if self.sampleMethod != id:
            self.sampleMethod = id
            if self.data:
                self.results = None
                self.test()

    # reporting on evaluation results
    def setStatTable(self):
        if not self.results:
            self.tab.setNumRows(0)
            return
        self.tab.setNumCols(len(self.stat)+1)
        self.tabHH=self.tab.horizontalHeader()
        self.tabHH.setLabel(0, 'Classifier')
        for i in range(len(self.stat)):
            self.tabHH.setLabel(i+1, self.stat[i][1])

        self.tab.setNumRows(self.results.numberOfLearners)
        for i in range(len(self.results.classifierNames)):
            self.tab.setText(i, 0, self.results.classifierNames[i])

        prec="%%.%df" % self.precision

        for i in range(self.results.numberOfLearners):
            for j in range(len(self.stat)):
                if self.scores[j][i] < 1e-8: # failed (-1) or near-zero scores are shown as N/A
                    self.tab.setText(i, j+1, "N/A")
                else:
                    self.tab.setText(i, j+1, prec % self.scores[j][i])

        for i in range(len(self.stat)+1):
            self.tab.adjustColumn(i)

        for i in range(len(self.stat)):
            if not self.useStat[i]:
                self.tab.hideColumn(i+1)

#
class ProgressBar:
    def __init__(self, widget, iterations):
        self.iter = iterations
        self.widget = widget
        self.count = 0
        self.widget.progressBarInit()
    def advance(self):
        self.count += 1
        self.widget.progressBarSet(int(self.count*100/self.iter))
    def finish(self):
        self.widget.progressBarFinished()
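
# ProgressBar is a thin adapter around the OWWidget progress-bar hooks
# (progressBarInit/progressBarSet/progressBarFinished); test() above uses it
# roughly like this (sketch only):
#   pb = ProgressBar(self, iterations=self.nFolds)
#   orngTest.crossValidation(learners, self.data, folds=self.nFolds, callback=pb.advance, ...)
#   pb.finish()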

##############################################################################
# Test the widget, run from DOS prompt

if __name__=="__main__":
    a=QApplication(sys.argv)
    ow=OWTestLearners()
    a.setMainWidget(ow)

    data = orange.ExampleTable('voting')

    l1 = orange.MajorityLearner(); l1.name = '1 - Majority'

    l2 = orange.BayesLearner()
    l2.estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10)
    l2.conditionalEstimatorConstructor = orange.ConditionalProbabilityEstimatorConstructor_ByRows(estimatorConstructor = orange.ProbabilityEstimatorConstructor_m(m=10))
    l2.name = '2 - NBC (m=10)'

    l3 = orange.BayesLearner(); l3.name = '3 - NBC (default)'

    l4 = orange.MajorityLearner(); l4.name = "4 - Majority"

    testcase = 2

    if testcase == 0: # 1(UPD), 3, 4
        ow.cdata(data)
        ow.learner(l1, 1)
        ow.learner(l2, 2)
        ow.learner(l3, 3)
        l1.name = l1.name + " UPD"
        ow.learner(l1, 1)
        ow.learner(None, 2)
        ow.learner(l4, 4)
    if testcase == 1: # no data, all learners removed
        ow.learner(l1, 1)
        ow.learner(l2, 2)
        ow.learner(None, 2)
        ow.learner(None, 1)
        ow.cdata(data)
    if testcase == 2: # sends data, then learner, then removes the learner
        ow.cdata(data)
        ow.learner(l1, 1)
        ow.learner(None, 1)
        
    ow.show()
    a.exec_loop()
    ow.saveSettings()
