codegen.py

来自「CVXMOD is a Python-based tool for expres」· Python 代码 · 共 986 行 · 第 1/2 页

PY
986
字号
"""Code generation for CVXMOD."""# Copyright (C) 2006-2008 Jacob Mattingley and Stephen Boyd.## This file is part of CVXMOD.## CVXMOD is free software; you can redistribute it and/or modify it under the# terms of the GNU General Public License as published by the Free Software# Foundation; either version 3 of the License, or (at your option) any later# version.## CVXMOD is distributed in the hope that it will be useful, but WITHOUT ANY# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR# A PARTICULAR PURPOSE. See the GNU General Public License for more details.## You should have received a copy of the GNU General Public License along with# this program. If not, see <http://www.gnu.org/licenses/>.from cvxmod.base import *from cvxmod.base import withbrackets, addfunction, multfunction, Jwarn,\        addormultfunction, symbslicefrom cvxmod.symbolic import *from cvxmod.symbolic import concatvertsymb, concathorizsymb, constsymb, eyesymbolfrom cvxmod.util import bylowerstr, joinlistfrom cvxopt.amd import orderfrom time import strftimeimport osimport reclass nze(symbol):    # Nonzero elements, eg for use in a matrix. Invariant under multiplication    # and division, except for by zero.    def __init__(self, name=None):        self.rows = 1        self.cols = 1        self.name = name    def __repr__(self):        return str(self)    def __str__(self):        if self.name is None:            return '<nze>'        else:            return '<nze %s>' % self.name    def __div__(self, other):        if iszero(other):            raise ZeroDivisionError        else:            return self    def __rdiv__(self, other):        if iszero(other):            return 0        else:            return self    def __mul__(self, other):        if iszero(other):            return 0        else:            return self    def __rmul__(self, other):        if iszero(other):            return 0        else:            return self    def __add__(self, other):        return self    def __radd__(self, other):        return self    def __sub__(self, other):        return self    def __rsub__(self, other):        return selfclass divfunction(addormultfunction):    # first cut. no size checking or anything.    def __init__(self, lhs, rhs):        self.lhs = lhs        self.rhs = rhs        self.rows = 1        self.cols = 1        self.brackets = True    def __str__(self):        return withbrackets(self.lhs) + '/' + withbrackets(self.rhs)def divide(lhs, rhs):    if iszero(lhs) or iszero(rhs):        return 0    else:        return divfunction(lhs, rhs)class sparsedict(dict):    def __init__(self):        dict.__init__(self)    def __getitem__(self, key):        try:            return dict.__getitem__(self, key)        except KeyError:            return 0def addsplit(obj):    # Argument order is (pre-operations, actual, nonindexed).    if isoptvar(obj) or isparam(obj) or isinstance(obj, (int, float)):        if is1x1(obj):            return ([], [], obj)        else:            return ([], [obj], 0)    elif isinstance(obj, addfunction):        return [x + y for (x, y) in zip(addsplit(obj.lhs), addsplit(obj.rhs))]    elif isinstance(obj, multfunction):        if is1x1(obj.lhs):            (pre, actual, nonind) = addsplit(obj.rhs)            pre = [obj.lhs*x for x in pre]            actual = [obj.lhs*x for x in actual]            nonind = obj.lhs*nonind                        return (pre, actual, nonind)        else:            raise NotImplementedError('could not handle ' + str(obj))    elif isinstance(obj, constsymb):        return ([], [], obj.constval)    else:        raise NotImplementedError('could not handle ' + str(obj))def stuffmatrix(obj, name):    s = ''    s += '\n// Begin: definition of %s.\n' % name    for l in strsymbmatrix(obj, name).splitlines():        s += '//   %s\n' % l    s += '%s = malloc((%s)*sizeof(double));\n\n' % \            (name, str(rows(obj) * cols(obj)))    s += stuffentry(obj, name, rows(obj), 0, 0, 0)    s += '// End: definition of %s.\n' % name    return sdef stuffentry(obj, name, rs, csleft, rsabove, rsbelow, smult=1):    def stuff1x1(val):        return '%s[%s] = %s;\n' % (name, str(topleft), str(val))    def nwsm(n):        # Name with scalar multiplier.        # Takes smult implicitly.        return '%s' % str(smult*n)    # here (or elsewhere) keep track of existing multiplied stuff. for example,    # if we have tr(c)*x, don't want to have to work it out every time.    # Everything should be able to accept a scalar multiple for one of its    # arguments.    s = ''    topleft = rs*csleft + rsabove    if isinstance(obj, (int, float, matrix, spmatrix)) and is1x1(obj):        s += stuff1x1(str(obj))    elif isinstance(obj, concatvertsymb):        s += stuffentry(obj.lhs, name, rs, csleft, rsabove, rsbelow, smult)        s += stuffentry(obj.rhs, name, rs, csleft, rsabove + rows(obj.lhs),                        rsbelow - rows(obj.lhs), smult)    elif isinstance(obj, concathorizsymb):        s += stuffentry(obj.lhs, name, rs, csleft, rsabove, rsbelow, smult)        s += stuffentry(obj.rhs, name, rs, csleft + cols(obj.lhs), rsabove,                        rsbelow, smult)    elif isinstance(obj, constsymb):        s += 'for (col = 0; col < %s; col++)\n' % str(cols(obj))        s += ' '*4 + 'for (row = 0; row < %s; row++)\n' % str(rows(obj))        s += ' '*8 + '%s[%s + col*(%s) + row] = %s;\n\n' % \                (str(name), str(topleft), str(rs), nwsm(obj.constval))    elif isinstance(obj, eyesymbol):        s += 'for (col = 0; col < %s; col++)\n' % str(cols(obj))        s += ' '*4 + 'for (row = 0; row < %s; row++)\n' % str(rows(obj))        s += ' '*8 + '%s[%s + col*(%s) + row] =\n          (row == col) ? %s : 0;\n\n' % \                (str(name), str(topleft), str(rs), str(smult))    elif isparam(obj):        s += 'for (col = 0; col < %s; col++)\n' % str(cols(obj))        s += ' '*4 + 'for (row = 0; row < %s; row++)\n' % str(rows(obj))        s += ' '*8 + '%s[%s + col*(%s) + row] =\n          %s[col*(%s) + row];\n\n' % \                (str(name), str(topleft), str(rs), nwsm(obj), str(rows(obj)))    elif isinstance(obj, multfunction):        # nastiness of non-recursion here!        if is1x1(obj.lhs):            s += stuffentry(obj.rhs, name, rs, csleft, rsabove, rsbelow, obj.lhs)        else:            raise NotImplementedError('could not handle ' + str(obj))    elif isinstance(obj, addfunction):        # pre: temporary variables that need to be calculated first.        # (need to come up with an indexing scheme so that A*x, for example,        # can be reused. Probably assume         # actual: pieces of expression that should be indexed.        # noniter: pieces of expression that should not be indexed.        (pre, actual, nonind) = addsplit(obj)        if pre:            s += pre + '\n\n'        rhs = "\n" + " "*11 + " + " .join(            [str(smult*x) + ("[col*(%s) + row]" % str(cols(x))) for x in actual])        if nonind:            rhs += "\n" + " "*10 + " + " + str(smult*nonind)        s += 'for (col = 0; col < %s; col++)\n' % str(cols(obj))        s += ' '*4 + 'for (row = 0; row < %s; row++)\n' % str(rows(obj))        s += ' '*8 + '%s[%s + col*(%s) + row] = %s;\n' % \                (str(name), str(topleft), str(rs), rhs)    else:        raise NotImplementedError('could not handle ' + str(obj))    return sdef exprtoC(obj):    if isinstance(obj, symbslice):        (a, b) = obj.splitsl()        if b is None:            return str(obj.arg) + '[%d]' % a        else:            index = a + b*value(rows(obj.arg))            return str(obj.arg) + '[%d]' % index    elif isinstance(obj, (int, float)):        return str(obj)    elif isinstance(obj, multfunction):        if obj.lhs is -1:            return '-' + exprtoC(obj.rhs)        else:            return exprtoC(obj.lhs) + '*' + exprtoC(obj.rhs)    elif isinstance(obj, addfunction):        return exprtoC(obj.lhs) + ' + ' + exprtoC(obj.rhs)    else:        raise NotImplementedError('cannot handle %s' % str(obj))class codegen(object):    def __init__(self, p, outdir, skel):        self.p = p        self.outdir = outdir        self.skel = skel    def printmatrix(self, x):        m = rows(x)        n = cols(x)        s = 'printf("%s =\\n");\n' % str(x)        s += "printmatrix(%s, %s, %s);\n" % (str(x), str(m), str(n))        self.write(s)    def declarematrix(self, A):        s = "double %s[] = {\n " % str(A)        offset = 1        v = matrix(value(A))        k = 0        for j in range(value(cols(A))):            for i in range(value(rows(A))):                t = "% .7g, " % v[i,j]                # jem: enhance this later to be nicer looking.                s += t.rjust(12)                if (k % 7) == 6:                    s += "\n"                    s += ' ' * offset                k += 1            #s += "\n"            #s += ' ' * offset        s = s[:-(offset+3)] + "};\n\n"        return s    def readme(self):        p = self.p        d = {}        d['header'] = header()        d['probdef'] = str(p)        d['stdform'] = p.strsymbsolve()        writecode(self.outdir, self.skel, 'README', d)    def tostdform(self):        p = self.p        ((ct, At, bt, Gt, ht), xt, d, optvars) = p.symbsolve('cAbGh')        d = {}        d['header'] = header()        # Now is when we determine the signature.        sig = 'lpstdform tostdform(\n' + ' '*8        if getdims(p):            sig += ', '.join(['int %s' % x for x in bylowerstr(getdims(p))])        if getparams(p):            sig += ',\n' + ' '*8            sig += ', '.join(['double *%s' % x for x in bylowerstr(getparams(p))])        sig += ')'        d['tostdform_sig'] = sig        d['lsf_c'] = stuffmatrix(ct, 'lsf->c')        d['lsf_A'] = stuffmatrix(At, 'lsf->A')        d['lsf_b'] = stuffmatrix(bt, 'lsf->b')        d['lsf_G'] = stuffmatrix(Gt, 'lsf->G')        d['lsf_h'] = stuffmatrix(ht, 'lsf->h')        sig = 'int convertback(outvar *vars,\n' + ' '*8        if getdims(p):            sig += ', '.join(['int %s' % x for x in bylowerstr(getdims(p))])        sig += ',\n        double *xt)'        d['convertback_sig'] = sig        d['lsf_m'] = str(rows(At))        d['lsf_n'] = str(cols(At))        d['lsf_p'] = str(rows(Gt))        s = ''        i = 0        rs = 0        for x in optvars:            if x in getoptvars(p):                # jem: need to also deal with substitute back situations.                s += '// Recover %s.\n' % str(x)                name = 'vars[%d]' % i;                s += '%s->m = %s;\n' % (name, str(rows(x)))                s += '%s->n = 1;\n' % name # for the moment (jem).                s += '%s->val = malloc((%s)*sizeof(double));\n' % (name, str(cols(x)))                s += 'for (row = 0; row < %s; row++)\n' % str(rows(x))                s += '    %s->val[row] = xt[%s + row];\n\n' % (name, str(rs))                i += 1            rs = compactdims(rs + rows(x))        d['varrecover'] = s        writecode(self.outdir, self.skel, 'tostdform.h', d)        writecode(self.outdir, self.skel, 'tostdform.c', d)    def data(self):        p = self.p        d = {}        d['header'] = header()        s = ''        for x in getparams(p):            s += self.declarematrix(x)        d['params'] = s        writecode(self.outdir, self.skel, 'data.h', d)    def feasdata(self, As, bs, n, p, LOWER):        d = {}        d['header'] = header()        print 'Solving feasibility problem to get feasible starting point.'        As = value(As)        bs = value(bs)        p = value(p)        x = optvar('x', n)        prob = problem(constr=[As*x == bs, x[:p] >= 0, x[p:] >= -LOWER + 1])        prob.solve()        x = value(x)        s = ''        for i in range(rows(x)):            s += 'CM_x[%d] = % .8g;\n' % (i, x[i])        d['CM_x'] = s        writecode(self.outdir, self.skel, 'feas.h', d)        return x    def testfa(self, n):        p = self.p        d = {}        d['header'] = header()        d['n'] = n        d['params'] = joinlist(bylowerstr(getparams(p)), True)        d['optvars'] = joinlist(bylowerstr(getoptvars(p)), True)        s = joinlist(bylowerstr(getparams(p)), True)        d['fa_args'] = s        writecode(self.outdir, self.skel, 'testfa.c', d)    def testba(self, n):        p = self.p        dt = {}        dt['header'] = header()        dt['n'] = n        dt['params'] = joinlist(bylowerstr(getparams(p)), True)        dt['optvars'] = joinlist(bylowerstr(getoptvars(p)), True)        print 'Creating method signature.'        if getparams(self.p):            dt['fb_sig_params'] = ', '.join(['double *%s' % x for x in \                                     bylowerstr(getparams(self.p))])            dt['fb_sig_params'] += ','        else:            dt['fb_sig_params'] = ''        dt['fb_args'] = joinlist(bylowerstr(getparams(p)), True)        writecode(self.outdir, self.skel, 'testba.c', dt)    def test2p(self, As, bs, cs, m, n, p):        dt = {}        dt['header'] = header()        dt['m'] = value(m)        dt['n'] = value(n)        dt['p'] = value(p)        dt['params'] = joinlist(bylowerstr(getparams(self.p)), True)        dt['optvars'] = joinlist(bylowerstr(getoptvars(self.p)), True)        print 'Generating code for two phase fast barrier method:'        print 'Creating method signature.'        if getparams(self.p):            dt['fb_sig_params'] = ', '.join(['double *%s' % x for x in \                                     bylowerstr(getparams(self.p))])            dt['fb_sig_params'] += ','        else:            dt['fb_sig_params'] = ''        A = nzentries(As)        b = nzentries(bs)        print 'Generating code for initial feasible point.'        s = ''        # CVXMOD_WL_t1.        t1 = optvar('CM_t1', m, 1)        for i in range(m):            t = -b[i]            for j in range(n):                if (i,j) in A:                    t = t + A[i,j]            s += 'CM_t1[%d] = %s;\n' % (i, exprtoC(t))        dt['CM_t1'] = s        dt['fb_args'] = joinlist(bylowerstr(getparams(self.p)), True)        writecode(self.outdir, self.skel, 'test2p.c', dt)    def fastacent(self, As, cs, m, n, p):        dt = {}        dt['header'] = header()        dt['m'] = value(m)        dt['n'] = value(n)        dt['p'] = value(p)        print 'Generating code for fast analytic centering method:'        print 'Framing objective.'        cs = value(cs)        s = ''        for i in range(n):            s += 'CM_c[%d] = %s;\n' % (i, exprtoC(cs[i]))        dt['CM_c'] = s        print 'Creating method signature.'        if getparams(self.p):            dt['params'] = ', '.join(['double *%s' % x for x in \                                      bylowerstr(getparams(self.p))])            dt['params'] += ','        else:            dt['params'] = ''        print 'Inspecting problem structure.'        A = nzentries(As)        AT = nzentries(tp(As))        # Create a gh optvar simply to make it easy to convert the multiplication.        gh = optvar('CM_gh', n, 1)                # Now multiply and solve.

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?