codegen.py
来自「CVXMOD is a Python-based tool for expres」· Python 代码 · 共 986 行 · 第 1/2 页
PY
986 行
"""Code generation for CVXMOD."""# Copyright (C) 2006-2008 Jacob Mattingley and Stephen Boyd.## This file is part of CVXMOD.## CVXMOD is free software; you can redistribute it and/or modify it under the# terms of the GNU General Public License as published by the Free Software# Foundation; either version 3 of the License, or (at your option) any later# version.## CVXMOD is distributed in the hope that it will be useful, but WITHOUT ANY# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR# A PARTICULAR PURPOSE. See the GNU General Public License for more details.## You should have received a copy of the GNU General Public License along with# this program. If not, see <http://www.gnu.org/licenses/>.from cvxmod.base import *from cvxmod.base import withbrackets, addfunction, multfunction, Jwarn,\ addormultfunction, symbslicefrom cvxmod.symbolic import *from cvxmod.symbolic import concatvertsymb, concathorizsymb, constsymb, eyesymbolfrom cvxmod.util import bylowerstr, joinlistfrom cvxopt.amd import orderfrom time import strftimeimport osimport reclass nze(symbol): # Nonzero elements, eg for use in a matrix. Invariant under multiplication # and division, except for by zero. def __init__(self, name=None): self.rows = 1 self.cols = 1 self.name = name def __repr__(self): return str(self) def __str__(self): if self.name is None: return '<nze>' else: return '<nze %s>' % self.name def __div__(self, other): if iszero(other): raise ZeroDivisionError else: return self def __rdiv__(self, other): if iszero(other): return 0 else: return self def __mul__(self, other): if iszero(other): return 0 else: return self def __rmul__(self, other): if iszero(other): return 0 else: return self def __add__(self, other): return self def __radd__(self, other): return self def __sub__(self, other): return self def __rsub__(self, other): return selfclass divfunction(addormultfunction): # first cut. no size checking or anything. def __init__(self, lhs, rhs): self.lhs = lhs self.rhs = rhs self.rows = 1 self.cols = 1 self.brackets = True def __str__(self): return withbrackets(self.lhs) + '/' + withbrackets(self.rhs)def divide(lhs, rhs): if iszero(lhs) or iszero(rhs): return 0 else: return divfunction(lhs, rhs)class sparsedict(dict): def __init__(self): dict.__init__(self) def __getitem__(self, key): try: return dict.__getitem__(self, key) except KeyError: return 0def addsplit(obj): # Argument order is (pre-operations, actual, nonindexed). if isoptvar(obj) or isparam(obj) or isinstance(obj, (int, float)): if is1x1(obj): return ([], [], obj) else: return ([], [obj], 0) elif isinstance(obj, addfunction): return [x + y for (x, y) in zip(addsplit(obj.lhs), addsplit(obj.rhs))] elif isinstance(obj, multfunction): if is1x1(obj.lhs): (pre, actual, nonind) = addsplit(obj.rhs) pre = [obj.lhs*x for x in pre] actual = [obj.lhs*x for x in actual] nonind = obj.lhs*nonind return (pre, actual, nonind) else: raise NotImplementedError('could not handle ' + str(obj)) elif isinstance(obj, constsymb): return ([], [], obj.constval) else: raise NotImplementedError('could not handle ' + str(obj))def stuffmatrix(obj, name): s = '' s += '\n// Begin: definition of %s.\n' % name for l in strsymbmatrix(obj, name).splitlines(): s += '// %s\n' % l s += '%s = malloc((%s)*sizeof(double));\n\n' % \ (name, str(rows(obj) * cols(obj))) s += stuffentry(obj, name, rows(obj), 0, 0, 0) s += '// End: definition of %s.\n' % name return sdef stuffentry(obj, name, rs, csleft, rsabove, rsbelow, smult=1): def stuff1x1(val): return '%s[%s] = %s;\n' % (name, str(topleft), str(val)) def nwsm(n): # Name with scalar multiplier. # Takes smult implicitly. return '%s' % str(smult*n) # here (or elsewhere) keep track of existing multiplied stuff. for example, # if we have tr(c)*x, don't want to have to work it out every time. # Everything should be able to accept a scalar multiple for one of its # arguments. s = '' topleft = rs*csleft + rsabove if isinstance(obj, (int, float, matrix, spmatrix)) and is1x1(obj): s += stuff1x1(str(obj)) elif isinstance(obj, concatvertsymb): s += stuffentry(obj.lhs, name, rs, csleft, rsabove, rsbelow, smult) s += stuffentry(obj.rhs, name, rs, csleft, rsabove + rows(obj.lhs), rsbelow - rows(obj.lhs), smult) elif isinstance(obj, concathorizsymb): s += stuffentry(obj.lhs, name, rs, csleft, rsabove, rsbelow, smult) s += stuffentry(obj.rhs, name, rs, csleft + cols(obj.lhs), rsabove, rsbelow, smult) elif isinstance(obj, constsymb): s += 'for (col = 0; col < %s; col++)\n' % str(cols(obj)) s += ' '*4 + 'for (row = 0; row < %s; row++)\n' % str(rows(obj)) s += ' '*8 + '%s[%s + col*(%s) + row] = %s;\n\n' % \ (str(name), str(topleft), str(rs), nwsm(obj.constval)) elif isinstance(obj, eyesymbol): s += 'for (col = 0; col < %s; col++)\n' % str(cols(obj)) s += ' '*4 + 'for (row = 0; row < %s; row++)\n' % str(rows(obj)) s += ' '*8 + '%s[%s + col*(%s) + row] =\n (row == col) ? %s : 0;\n\n' % \ (str(name), str(topleft), str(rs), str(smult)) elif isparam(obj): s += 'for (col = 0; col < %s; col++)\n' % str(cols(obj)) s += ' '*4 + 'for (row = 0; row < %s; row++)\n' % str(rows(obj)) s += ' '*8 + '%s[%s + col*(%s) + row] =\n %s[col*(%s) + row];\n\n' % \ (str(name), str(topleft), str(rs), nwsm(obj), str(rows(obj))) elif isinstance(obj, multfunction): # nastiness of non-recursion here! if is1x1(obj.lhs): s += stuffentry(obj.rhs, name, rs, csleft, rsabove, rsbelow, obj.lhs) else: raise NotImplementedError('could not handle ' + str(obj)) elif isinstance(obj, addfunction): # pre: temporary variables that need to be calculated first. # (need to come up with an indexing scheme so that A*x, for example, # can be reused. Probably assume # actual: pieces of expression that should be indexed. # noniter: pieces of expression that should not be indexed. (pre, actual, nonind) = addsplit(obj) if pre: s += pre + '\n\n' rhs = "\n" + " "*11 + " + " .join( [str(smult*x) + ("[col*(%s) + row]" % str(cols(x))) for x in actual]) if nonind: rhs += "\n" + " "*10 + " + " + str(smult*nonind) s += 'for (col = 0; col < %s; col++)\n' % str(cols(obj)) s += ' '*4 + 'for (row = 0; row < %s; row++)\n' % str(rows(obj)) s += ' '*8 + '%s[%s + col*(%s) + row] = %s;\n' % \ (str(name), str(topleft), str(rs), rhs) else: raise NotImplementedError('could not handle ' + str(obj)) return sdef exprtoC(obj): if isinstance(obj, symbslice): (a, b) = obj.splitsl() if b is None: return str(obj.arg) + '[%d]' % a else: index = a + b*value(rows(obj.arg)) return str(obj.arg) + '[%d]' % index elif isinstance(obj, (int, float)): return str(obj) elif isinstance(obj, multfunction): if obj.lhs is -1: return '-' + exprtoC(obj.rhs) else: return exprtoC(obj.lhs) + '*' + exprtoC(obj.rhs) elif isinstance(obj, addfunction): return exprtoC(obj.lhs) + ' + ' + exprtoC(obj.rhs) else: raise NotImplementedError('cannot handle %s' % str(obj))class codegen(object): def __init__(self, p, outdir, skel): self.p = p self.outdir = outdir self.skel = skel def printmatrix(self, x): m = rows(x) n = cols(x) s = 'printf("%s =\\n");\n' % str(x) s += "printmatrix(%s, %s, %s);\n" % (str(x), str(m), str(n)) self.write(s) def declarematrix(self, A): s = "double %s[] = {\n " % str(A) offset = 1 v = matrix(value(A)) k = 0 for j in range(value(cols(A))): for i in range(value(rows(A))): t = "% .7g, " % v[i,j] # jem: enhance this later to be nicer looking. s += t.rjust(12) if (k % 7) == 6: s += "\n" s += ' ' * offset k += 1 #s += "\n" #s += ' ' * offset s = s[:-(offset+3)] + "};\n\n" return s def readme(self): p = self.p d = {} d['header'] = header() d['probdef'] = str(p) d['stdform'] = p.strsymbsolve() writecode(self.outdir, self.skel, 'README', d) def tostdform(self): p = self.p ((ct, At, bt, Gt, ht), xt, d, optvars) = p.symbsolve('cAbGh') d = {} d['header'] = header() # Now is when we determine the signature. sig = 'lpstdform tostdform(\n' + ' '*8 if getdims(p): sig += ', '.join(['int %s' % x for x in bylowerstr(getdims(p))]) if getparams(p): sig += ',\n' + ' '*8 sig += ', '.join(['double *%s' % x for x in bylowerstr(getparams(p))]) sig += ')' d['tostdform_sig'] = sig d['lsf_c'] = stuffmatrix(ct, 'lsf->c') d['lsf_A'] = stuffmatrix(At, 'lsf->A') d['lsf_b'] = stuffmatrix(bt, 'lsf->b') d['lsf_G'] = stuffmatrix(Gt, 'lsf->G') d['lsf_h'] = stuffmatrix(ht, 'lsf->h') sig = 'int convertback(outvar *vars,\n' + ' '*8 if getdims(p): sig += ', '.join(['int %s' % x for x in bylowerstr(getdims(p))]) sig += ',\n double *xt)' d['convertback_sig'] = sig d['lsf_m'] = str(rows(At)) d['lsf_n'] = str(cols(At)) d['lsf_p'] = str(rows(Gt)) s = '' i = 0 rs = 0 for x in optvars: if x in getoptvars(p): # jem: need to also deal with substitute back situations. s += '// Recover %s.\n' % str(x) name = 'vars[%d]' % i; s += '%s->m = %s;\n' % (name, str(rows(x))) s += '%s->n = 1;\n' % name # for the moment (jem). s += '%s->val = malloc((%s)*sizeof(double));\n' % (name, str(cols(x))) s += 'for (row = 0; row < %s; row++)\n' % str(rows(x)) s += ' %s->val[row] = xt[%s + row];\n\n' % (name, str(rs)) i += 1 rs = compactdims(rs + rows(x)) d['varrecover'] = s writecode(self.outdir, self.skel, 'tostdform.h', d) writecode(self.outdir, self.skel, 'tostdform.c', d) def data(self): p = self.p d = {} d['header'] = header() s = '' for x in getparams(p): s += self.declarematrix(x) d['params'] = s writecode(self.outdir, self.skel, 'data.h', d) def feasdata(self, As, bs, n, p, LOWER): d = {} d['header'] = header() print 'Solving feasibility problem to get feasible starting point.' As = value(As) bs = value(bs) p = value(p) x = optvar('x', n) prob = problem(constr=[As*x == bs, x[:p] >= 0, x[p:] >= -LOWER + 1]) prob.solve() x = value(x) s = '' for i in range(rows(x)): s += 'CM_x[%d] = % .8g;\n' % (i, x[i]) d['CM_x'] = s writecode(self.outdir, self.skel, 'feas.h', d) return x def testfa(self, n): p = self.p d = {} d['header'] = header() d['n'] = n d['params'] = joinlist(bylowerstr(getparams(p)), True) d['optvars'] = joinlist(bylowerstr(getoptvars(p)), True) s = joinlist(bylowerstr(getparams(p)), True) d['fa_args'] = s writecode(self.outdir, self.skel, 'testfa.c', d) def testba(self, n): p = self.p dt = {} dt['header'] = header() dt['n'] = n dt['params'] = joinlist(bylowerstr(getparams(p)), True) dt['optvars'] = joinlist(bylowerstr(getoptvars(p)), True) print 'Creating method signature.' if getparams(self.p): dt['fb_sig_params'] = ', '.join(['double *%s' % x for x in \ bylowerstr(getparams(self.p))]) dt['fb_sig_params'] += ',' else: dt['fb_sig_params'] = '' dt['fb_args'] = joinlist(bylowerstr(getparams(p)), True) writecode(self.outdir, self.skel, 'testba.c', dt) def test2p(self, As, bs, cs, m, n, p): dt = {} dt['header'] = header() dt['m'] = value(m) dt['n'] = value(n) dt['p'] = value(p) dt['params'] = joinlist(bylowerstr(getparams(self.p)), True) dt['optvars'] = joinlist(bylowerstr(getoptvars(self.p)), True) print 'Generating code for two phase fast barrier method:' print 'Creating method signature.' if getparams(self.p): dt['fb_sig_params'] = ', '.join(['double *%s' % x for x in \ bylowerstr(getparams(self.p))]) dt['fb_sig_params'] += ',' else: dt['fb_sig_params'] = '' A = nzentries(As) b = nzentries(bs) print 'Generating code for initial feasible point.' s = '' # CVXMOD_WL_t1. t1 = optvar('CM_t1', m, 1) for i in range(m): t = -b[i] for j in range(n): if (i,j) in A: t = t + A[i,j] s += 'CM_t1[%d] = %s;\n' % (i, exprtoC(t)) dt['CM_t1'] = s dt['fb_args'] = joinlist(bylowerstr(getparams(self.p)), True) writecode(self.outdir, self.skel, 'test2p.c', dt) def fastacent(self, As, cs, m, n, p): dt = {} dt['header'] = header() dt['m'] = value(m) dt['n'] = value(n) dt['p'] = value(p) print 'Generating code for fast analytic centering method:' print 'Framing objective.' cs = value(cs) s = '' for i in range(n): s += 'CM_c[%d] = %s;\n' % (i, exprtoC(cs[i])) dt['CM_c'] = s print 'Creating method signature.' if getparams(self.p): dt['params'] = ', '.join(['double *%s' % x for x in \ bylowerstr(getparams(self.p))]) dt['params'] += ',' else: dt['params'] = '' print 'Inspecting problem structure.' A = nzentries(As) AT = nzentries(tp(As)) # Create a gh optvar simply to make it easy to convert the multiplication. gh = optvar('CM_gh', n, 1) # Now multiply and solve.
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?