📄 base.py
字号:
# NOTE(review): this span is a web-page extraction of base.py in which the
# original newlines and indentation were lost, so it is not runnable as-is:
# e.g. "return selfclass addfunction" and "self.rhsclass multfunction" are two
# statements fused together, and every "#" comment below now swallows the rest
# of its physical line. Restore the layout from the upstream file before
# editing any of this code.
#
# Contents, in order:
#   * The tail of an expression class's arithmetic operators. First comes a
#     method returning addup(other, self), whose `def` line is not visible here.
#     Then __sub__ / __rsub__, both built with addup and a negated operand;
#     __neg__, which returns -1*self; and __pos__, which returns self.
#   * class addfunction(addormultfunction): the sum lhs + rhs.
#     - __init__ accepts the operands when either is 1x1, or when both the row
#       and the column dimensions are equivalent per equivdims; otherwise it
#       raises DimError.
#     - Most other methods (value, psd/nsd, symm, getvarmult, getdimmult,
#       getvecmult, epiorhypo, etranspose, expandtr, getquadmult, vectorize,
#       cvx) combine the corresponding results computed for lhs and rhs.
#   * The beginning of class multfunction(addormultfunction): the product
#     lhs * rhs. Its definition continues past the end of this span.
#
# Review findings in the visible code (apply once the layout is restored):
#   * addfunction.nnz, non-scalar branch: it returns
#     `nnz(self.lhs) + self.rhs`, which adds the rhs expression object itself
#     rather than its nonzero count. Presumably
#     `nnz(self.lhs) + nnz(self.rhs)` was intended.
#   * addfunction.getdimmult binds lhs, rhs and s = lhs + rhs, then ignores
#     them and returns `getdimmult(self.lhs, d) + getdimmult(self.rhs, d)`.
#     This recurses into both operands a second time; `return s` yields the
#     same sum without the duplicate work.
#   * addfunction.__init__ tests `lhsrows is 1` / `lhscols is 1`. This is an
#     identity comparison with an int literal, which CPython 3.8+ reports as a
#     SyntaxWarning. In CPython it is True only for the int 1 itself, not for
#     1.0, True or numpy integers. If that restriction is intended,
#     `type(x) is int and x == 1` states it explicitly without invoking a
#     symbolic dimension's __eq__.
#   * multfunction.__str__ tests `self.lhs is -1 or self.lhs is -1.0`. A float
#     operand is generally a distinct object from the literal, so the -1.0
#     test is effectively always False. Only an int -1 lhs (presumably what
#     __neg__'s -1*self produces; TODO confirm) gets the "-x" rendering.
return addup(other, self) def __sub__(self, other): return addup(self, -other) def __rsub__(self, other): return addup(other, -self) def __neg__(self): return -1*self def __pos__(self): return selfclass addfunction(addormultfunction): def __init__(self, lhs, rhs): # First, extract the sizes. lhsrows = rows(lhs) lhscols = cols(lhs) rhsrows = rows(rhs) rhscols = cols(rhs) if is1x1(lhs) or is1x1(rhs) or \ equivdims(lhsrows, rhsrows) and equivdims(lhscols, rhscols): # Valid addition, proceed. if lhsrows is 1: self.rows = rhsrows else: self.rows = lhsrows if lhscols is 1: self.cols = rhscols else: self.cols = lhscols self.lhs = lhs self.rhs = rhs self.brackets = True else: raise DimError def __str__(self): if str(self.rhs)[0] == '-': return "%s - %s" %(str(self.lhs), str(self.rhs)[1:]) else: return "%s + %s" %(str(self.lhs), str(self.rhs)) def _getpsd(self): return ispsd(self.lhs) and ispsd(self.rhs) psd = property(_getpsd) def _getnsd(self): return isnsd(self.lhs) and isnsd(self.rhs) nsd = property(_getnsd) def _getvalue(self): return value(self.lhs) + value(self.rhs) value = property(_getvalue) def getvarmult(self, var): if var is None and not getoptvars(self): return self else: lhs = getvarmult(self.lhs, var) rhs = getvarmult(self.rhs, var) s = lhs + rhs if is1x1(s): # Test to see if lhs and rhs cancel, without getting errors if # the equality test is forbidden. if iszero(s): return 0 else: return s*ones(size(self), symb=True) else: return s def getdimmult(self, d): if d is None and not getdims(self): return self else: lhs = getdimmult(self.lhs, d) rhs = getdimmult(self.rhs, d) s = lhs + rhs return getdimmult(self.lhs, d) + getdimmult(self.rhs, d) def getvecmult(self, var): if var is None and not getoptvars(self): return vec(self) else: lhs = getvecmult(self.lhs, var) rhs = getvecmult(self.rhs, var) s = lhs + rhs if var is None and is1x1(s): # Test to see if lhs and rhs cancel, without getting errors if # the equality test is forbidden. 
if iszero(s): return 0 else: return s*ones(rows(self)*cols(self), 1, symb=True) else: return s def epiorhypo(self): if isaffine(self): return (self, stdstruct()) else: if isaffine(self.lhs): lhs = self.lhs sl = stdstruct() else: (lhs, sl) = self.lhs.epiorhypo() if isaffine(self.rhs): rhs = self.rhs sr = stdstruct() else: (rhs, sr) = self.rhs.epiorhypo() return (lhs + rhs, sl + sr) def etranspose(self): return etranspose(self.lhs) + etranspose(self.rhs) def expandtr(self): return expandtr(self.lhs) + expandtr(self.rhs) def getquadmult(self, var): return getquadmult(self.lhs, var) + getquadmult(self.rhs, var) def vectorize(self): # jem: need to expand this to include scalars? n = value(cols(self)) return [x + y for (x, y) in zip(vectorize(self.lhs, n), vectorize(self.rhs, n))] def cvx(self): r = cvx(self.rhs) if r[0] == '-': return cvx(self.lhs) + ' - ' + cvx(-self.rhs) else: return cvx(self.lhs) + ' + ' + cvx(self.rhs) def _getsymm(self): return issymm(self.lhs) and issymm(self.rhs) symm = property(_getsymm) def nnz(self): if is1x1(self.lhs) or is1x1(self.rhs): # A scalar on either side will (we assume) cause complete fill-in. return rows(self)*cols(self) else: return nnz(self.lhs) + self.rhsclass multfunction(addormultfunction): def __init__(self, lhs, rhs): # First, if we have optvars on both sides we have a problem because # we would be violating the product rule. However, things like 3*(x+y) # are ok, so check for multiple optvars. #if getoptvars(lhs) and getoptvars(rhs): jem ignoring this for now. # raise ProductError # First, check that we have appropriate sizes. rhrows = rows(rhs) lhcols = cols(lhs) if equivdims(lhcols, rhrows) or is1x1(lhs) or is1x1(rhs): # Valid multiplication, proceed. # Set the correct resulting size. if is1x1(lhs): self.rows = rhrows self.cols = cols(rhs) elif is1x1(rhs): self.rows = rows(lhs) self.cols = lhcols else: self.cols = cols(rhs) self.rows = rows(lhs) # Goal of next parts: make lhs the only scalar, if it exists. 
# If the rhs is a scalar, reverse lhs and rhs. if is1x1(rhs) and not is1x1(lhs): lhs, rhs = rhs, lhs # Put a strict scalar (i.e. float / int) on the lhs, regardless. if isinstance(rhs, (int, float)) and not isinstance(lhs, (int, float)): lhs, rhs = rhs, lhs # Next line is to avoid problems with ((3*A)*x)*4. if hasattr(lhs, 'lhs') and isinstance(lhs, multfunction): if is1x1(lhs.lhs): lhs, rhs = lhs.lhs, lhs.rhs*rhs # Collect all scalars to the lhs. if hasattr(rhs, 'lhs') and isinstance(rhs, multfunction) and \ is1x1(rhs.lhs): if is1x1(lhs): lhs, rhs = lhs*rhs.lhs, rhs.rhs elif hasattr(lhs, 'lhs') and is1x1(lhs.lhs): lhs, rhs = lhs.lhs * rhs.lhs, lhs.rhs * rhs.rhs else: if is1x1(lhs): lhs, rhs = lhs*rhs.lhs, rhs.rhs self.lhs = lhs self.rhs = rhs self.brackets = False else: raise DimError def epiorhypo(self): if isaffine(self): return (self, stdstruct()) else: if isaffine(self.lhs): lhs = self.lhs sl = stdstruct() else: (lhs, sl) = self.lhs.epiorhypo() if isaffine(self.rhs): rhs = self.rhs sr = stdstruct() else: (rhs, sr) = self.rhs.epiorhypo() return (lhs*rhs, sl + sr) def __str__(self): if self.lhs is -1 or self.lhs is -1.0: return '-' + withbrackets(self.rhs) else: return withbrackets(self.lhs) + '*' + withbrackets(self.rhs) def _getconvex(self): if getoptvars(self.lhs) and getoptvars(self.rhs): return False # Otherwise, figure out which side has the variables. if getoptvars(self.rhs): constbit = self.lhs varbit = self.rhs elif getoptvars(self.lhs): constbit = self.rhs varbit = self.lhs else: return True # Look at the rules. if isaffine(varbit): return True elif ispos(constbit) and isconvex(varbit): return True elif isneg(constbit) and isconcave(varbit): return True return False convex = property(_getconvex) def _getconcave(self): if getoptvars(self.lhs) and getoptvars(self.rhs): return False # Otherwise, figure out which side has the variables. 
if getoptvars(self.rhs): constbit = self.lhs varbit = self.rhs elif getoptvars(self.lhs): constbit = self.rhs varbit = self.lhs else: return True # Look at the rules. if isaffine(varbit): return True elif ispos(constbit) and isconcave(varbit): return True elif isneg(constbit) and isconvex(varbit): return True return False concave = property(_getconcave) def _getincreasing(self): return (ispos(self.lhs) and isincreasing(self.rhs)) or \ (isneg(self.lhs) and isdecreasing(self.rhs)) increasing = property(_getincreasing) def _getdecreasing(self): return (ispos(self.lhs) and isdecreasing(self.rhs)) or \ (isneg(self.lhs) and isincreasing(self.rhs)) decreasing = property(_getdecreasing) def _getpos(self): return (ispos(self.lhs) and ispos(self.rhs)) or \ (isneg(self.lhs) and isneg(self.rhs)) pos = property(_getpos) def _getneg(self): return (ispos(self.lhs) and isneg(self.rhs)) or \ (isneg(self.lhs) and ispos(self.rhs)) neg = property(_getneg) def _getpsd(self): return (ispsd(self.lhs) and ispsd(self.rhs)) or \ (isnsd(self.lhs) and isnsd(self.rhs)) or \ (transpose(self.lhs) is self.rhs) or \ (self.lhs is transpose(self.rhs)) psd = property(_getpsd) def _getnsd(self): return (ispsd(self.lhs) and isnsd(self.rhs)) or \ (isnsd(self.lhs) and ispsd(self.rhs)) nsd = property(_getnsd) def _getvalue(self): return value(self.lhs) * value(self.rhs) value = property(_getvalue) def getvarmult(self, var): if var is None: # constant expression portion. if getoptvars(self.lhs): lhs = getvarmult(self.lhs, None) else: lhs = self.lhs if getoptvars(self.rhs): rhs = getvarmult(self.rhs, None) else: rhs = self.rhs if iszero(lhs) or iszero(rhs): return 0 # Adjust for situations like u[0]*b, where u[0] should actually be # on the rhs. 
if is1x1(self.lhs): return rhs * lhs else: return lhs * rhs else: if var in getoptvars(self.lhs): lhs = getvarmult(self.lhs, var) else: lhs = self.lhs if var in getoptvars(self.rhs): rhs = getvarmult(self.rhs, var) else: rhs = self.rhs if iszero(lhs) or iszero(rhs): return 0 else: # Adjust for situations like u[0]*b, where u[0] should actually be # on the rhs. if is1x1(self.lhs): return rhs * lhs else: return lhs * rhs def getdimmult(self, d): if d is None: # constant expression portion. if getdims(self.lhs): lhs = getdimmult(self.lhs, None) else: lhs = self.lhs if getdims(self.rhs): rhs = getdimmult(self.rhs, None) else: rhs = self.rhs if iszero(lhs) or iszero(rhs): return 0 return lhs * rhs else: if d in getdims(self.lhs): lhs = getdimmult(self.lhs, d) else: lhs = self.lhs if d in getdims(self.rhs): rhs = getdimmult(self.rhs, d) else: rhs = self.rhs if iszero(lhs) or iszero(rhs): return 0 else: return lhs * rhs def getvecmult(self, var): if var is None: # constant expression portion. lhs = getvecmult(self.lhs, None) rhs = getvecmult(self.rhs, None) if iszero(lhs) or iszero(rhs): return 0 return lhs * rhs else: if var in getoptvars(self.lhs): lhs = getvecmult(self.lhs, var) else: if is1x1(self.lhs): lhs = self.lhs else: lhs = lhexp(self.lhs, size(self.rhs)) if var in getoptvars(self.rhs): rhs = getvecmult(self.rhs, var) else: rhs = rhexp(self.rhs, size(self.lhs)) if iszero(lhs) or iszero(rhs): return 0 else: # Keep trobjs at the rhs by swapping the sides. # Also swap the sides if the variable appears on the left. if isinstance(lhs, trobj) or var in getoptvars(self.lhs): return rhs * lhs else: return lhs * rhs def etranspose(self): return etranspose(self.rhs)*etranspose(self.lhs) def expandtr(self):
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -