📄 compiler.py
字号:
def visit_unary(self, unary, **kwargs): s = self.process(unary.element) if unary.operator: s = self.operator_string(unary.operator) + " " + s if unary.modifier: s = s + " " + self.operator_string(unary.modifier) return s def visit_binary(self, binary, **kwargs): op = self.operator_string(binary.operator) if callable(op): return op(self.process(binary.left), self.process(binary.right)) else: return self.process(binary.left) + " " + op + " " + self.process(binary.right) def operator_string(self, operator): return self.operators.get(operator, str(operator)) def visit_bindparam(self, bindparam, **kwargs): name = self._truncate_bindparam(bindparam) if name in self.binds: existing = self.binds[name] if existing is not bindparam and (existing.unique or bindparam.unique): raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) self.binds[bindparam.key] = self.binds[name] = bindparam return self.bindparam_string(name) def _truncate_bindparam(self, bindparam): if bindparam in self.bind_names: return self.bind_names[bindparam] bind_name = bindparam.key bind_name = self._truncated_identifier("bindparam", bind_name) # add to bind_names for translation self.bind_names[bindparam] = bind_name return bind_name def _truncated_identifier(self, ident_class, name): if (ident_class, name) in self.generated_ids: return self.generated_ids[(ident_class, name)] anonname = ANONYMOUS_LABEL.sub(self._process_anon, name) if len(anonname) > self.dialect.max_identifier_length: counter = self.generated_ids.get(ident_class, 1) truncname = anonname[0:self.dialect.max_identifier_length - 6] + "_" + hex(counter)[2:] self.generated_ids[ident_class] = counter + 1 else: truncname = anonname self.generated_ids[(ident_class, name)] = truncname return truncname def _process_anon(self, match): (ident, derived) = match.group(1,2) key = ('anonymous', ident) if key in self.generated_ids: return self.generated_ids[key] else: anonymous_counter = self.generated_ids.get(('anon_counter', derived), 1) newname = derived + "_" + str(anonymous_counter) self.generated_ids[('anon_counter', derived)] = anonymous_counter + 1 self.generated_ids[key] = newname return newname def _anonymize(self, name): return ANONYMOUS_LABEL.sub(self._process_anon, name) def bindparam_string(self, name): if self.positional: self.positiontup.append(name) return self.bindtemplate % {'name':name, 'position':len(self.positiontup)} def visit_alias(self, alias, asfrom=False, **kwargs): if asfrom: return self.process(alias.original, asfrom=True, **kwargs) + " AS " + self.preparer.format_alias(alias, self._anonymize(alias.name)) else: return self.process(alias.original, **kwargs) def label_select_column(self, select, column, asfrom): """label columns present in a select().""" if isinstance(column, sql._Label): return column if select.use_labels and getattr(column, '_label', None): return column.label(column._label) if \ asfrom and \ isinstance(column, sql._ColumnClause) and \ not column.is_literal and \ column.table is not None and \ not isinstance(column.table, sql.Select): return column.label(column.name) elif not isinstance(column, (sql._UnaryExpression, sql._TextClause)) and (not hasattr(column, 'name') or isinstance(column, sql._Function)): return column.anon_label else: return column def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, **kwargs): stack_entry = {'select':select} prev_entry = self.stack and self.stack[-1] or None if asfrom or (prev_entry and 'select' in prev_entry): stack_entry['is_subquery'] = True if prev_entry and 'iswrapper' in prev_entry: column_clause_args = {'result_map':self.result_map} else: column_clause_args = {} elif iswrapper: column_clause_args = {} stack_entry['iswrapper'] = True else: column_clause_args = {'result_map':self.result_map} if self.stack and 'from' in self.stack[-1]: existingfroms = self.stack[-1]['from'] else: existingfroms = None froms = select._get_display_froms(existingfroms) correlate_froms = util.Set() for f in froms: correlate_froms.add(f) correlate_froms.update(f._get_from_objects()) # TODO: might want to propigate existing froms for select(select(select)) # where innermost select should correlate to outermost# if existingfroms:# correlate_froms = correlate_froms.union(existingfroms) stack_entry['from'] = correlate_froms self.stack.append(stack_entry) # the actual list of columns to print in the SELECT column list. inner_columns = util.OrderedSet() for co in select.inner_columns: l = self.label_select_column(select, co, asfrom=asfrom) inner_columns.add(self.process(l, **column_clause_args)) collist = string.join(inner_columns.difference(util.Set([None])), ', ') text = " ".join(["SELECT"] + [self.process(x) for x in select._prefixes]) + " " text += self.get_select_precolumns(select) text += collist whereclause = select._whereclause from_strings = [] for f in froms: from_strings.append(self.process(f, asfrom=True)) w = self.get_whereclause(f) if w is not None: if whereclause is not None: whereclause = sql.and_(w, whereclause) else: whereclause = w if froms: text += " \nFROM " text += string.join(from_strings, ', ') else: text += self.default_from() if whereclause is not None: t = self.process(whereclause) if t: text += " \nWHERE " + t group_by = self.process(select._group_by_clause) if group_by: text += " GROUP BY " + group_by if select._having is not None: t = self.process(select._having) if t: text += " \nHAVING " + t text += self.order_by_clause(select) text += (select._limit or select._offset) and self.limit_clause(select) or "" text += self.for_update_clause(select) self.stack.pop(-1) if asfrom and parens: return "(" + text + ")" else: return text def get_select_precolumns(self, select): """Called when building a ``SELECT`` statement, position is just before column list.""" return select._distinct and "DISTINCT " or "" def order_by_clause(self, select): order_by = self.process(select._order_by_clause) if order_by: return " ORDER BY " + order_by else: return "" def for_update_clause(self, select): if select.for_update: return " FOR UPDATE" else: return "" def limit_clause(self, select): text = "" if select._limit is not None: text += " \n LIMIT " + str(select._limit) if select._offset is not None: if select._limit is None: text += " \n LIMIT -1" text += " OFFSET " + str(select._offset) return text def visit_table(self, table, asfrom=False, **kwargs): if asfrom: if getattr(table, "schema", None): return self.preparer.quote(table, table.schema) + "." + self.preparer.quote(table, table.name) else: return self.preparer.quote(table, table.name) else: return "" def visit_join(self, join, asfrom=False, **kwargs): return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \ self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause)) def visit_sequence(self, seq): return None def visit_insert(self, insert_stmt): self.isinsert = True colparams = self._get_colparams(insert_stmt) preparer = self.preparer return ("INSERT INTO %s (%s) VALUES (%s)" % (preparer.format_table(insert_stmt.table), ', '.join([preparer.quote(c[0], c[0].name) for c in colparams]), ', '.join([c[1] for c in colparams]))) def visit_update(self, update_stmt): self.stack.append({'from':util.Set([update_stmt.table])}) self.isupdate = True colparams = self._get_colparams(update_stmt) text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0], c[0].name), c[1]) for c in colparams], ', ') if update_stmt._whereclause: text += " WHERE " + self.process(update_stmt._whereclause) self.stack.pop(-1) return text def _get_colparams(self, stmt): """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. """ def create_bind_param(col, value): bindparam = sql.bindparam(col.key, value, type_=col.type) self.binds[col.key] = bindparam return self.bindparam_string(self._truncate_bindparam(bindparam)) self.postfetch = [] self.prefetch = [] # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: return [(c, create_bind_param(c, None)) for c in stmt.table.columns] # if we have statement parameters - set defaults in the # compiled params if self.column_keys is None: parameters = {} else: parameters = dict([(getattr(key, 'key', key), None) for key in self.column_keys]) if stmt.parameters is not None: for k, v in stmt.parameters.iteritems(): parameters.setdefault(getattr(k, 'key', k), v) # create a list of column assignment clauses as tuples values = [] for c in stmt.table.columns: if c.key in parameters: value = parameters[c.key] if sql._is_literal(value): value = create_bind_param(c, value) else: self.postfetch.append(c) value = self.process(value.self_group()) values.append((c, value)) elif isinstance(c, schema.Column): if self.isinsert: if (c.primary_key and self.dialect.preexecute_pk_sequences and not self.inline): if (((isinstance(c.default, schema.Sequence) and not c.default.optional) or not self.dialect.supports_pk_autoincrement) or (c.default is not None and not isinstance(c.default, schema.Sequence))): values.append((c, create_bind_param(c, None))) self.prefetch.append(c) elif isinstance(c.default, schema.ColumnDefault): if isinstance(c.default.arg, sql.ClauseElement): values.append((c, self.process(c.default.arg.self_group()))) if not c.primary_key: # dont add primary key column to postfetch self.postfetch.append(c) else: values.append((c, create_bind_param(c, None))) self.prefetch.append(c) elif isinstance(c.default, schema.PassiveDefault): if not c.primary_key: self.postfetch.append(c) elif isinstance(c.default, schema.Sequence): proc = self.process(c.default) if proc is not None: values.append((c, proc)) if not c.primary_key: self.postfetch.append(c) elif self.isupdate: if isinstance(c.onupdate, schema.ColumnDefault): if isinstance(c.onupdate.arg, sql.ClauseElement): values.append((c, self.process(c.onupdate.arg.self_group()))) self.postfetch.append(c) else: values.append((c, create_bind_param(c, None))) self.prefetch.append(c) elif isinstance(c.onupdate, schema.PassiveDefault): self.postfetch.append(c) return values def visit_delete(self, delete_stmt): self.stack.append({'from':util.Set([delete_stmt.table])}) self.isdelete = True text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table) if delete_stmt._whereclause: text += " WHERE " + self.process(delete_stmt._whereclause) self.stack.pop(-1) return text def visit_savepoint(self, savepoint_stmt): return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) def visit_rollback_to_savepoint(self, savepoint_stmt): return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) def visit_release_savepoint(self, savepoint_stmt): return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) def __str__(self): return self.string or ''class DDLBase(engine.SchemaIterator): def find_alterables(self, tables): alterables = [] class FindAlterables(schema.SchemaVisitor):
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -