Last active: December 28, 2016 15:18
-
-
Save dojiong/d3a0d5176a42bb6173f37ac1630308f5 to your computer and use it in GitHub Desktop.
Generate a SQLAlchemy model class from an existing database table.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import re | |
from collections import defaultdict | |
from sqlalchemy import create_engine | |
class Env(object):
    """Collect the names and imports a generated model module needs.

    ``model``, ``types`` and ``column`` are dotted import paths. Generated
    code refers to each one by its last component (``'sqlalchemy.types'`` is
    referenced as ``types``), and the matching import line is recorded.

    :param model: import path of the declarative base class.
    :param types: import path of the module providing column types.
    :param column: import path of the ``Column`` construct.
    :param fix_onupdate: flag read when columns are rendered; when true the
        MySQL ``ON UPDATE`` clause is folded into ``server_default``.
    """

    def __init__(self, model='.db.Model', types='sqlalchemy.types',
                 column='sqlalchemy.Column', fix_onupdate=True):
        self.model = model.rsplit('.', 1)[-1]
        self.types = types.rsplit('.', 1)[-1]
        self.column = column.rsplit('.', 1)[-1]
        self.fix_onupdate = fix_onupdate
        # Module names imported with a plain ``import x``.
        self.imports = set()
        # ``from pkg import a, b`` names, keyed by package.
        self.from_imports = defaultdict(set)
        for path in (model, types, column):
            self.add_import(path)

    def add_import(self, path):
        """Record that the generated module must import *path*.

        ``'pkg.name'`` becomes ``from pkg import name``, a relative path such
        as ``'.name'`` becomes ``from . import name``, and a name without any
        dot becomes ``import name``.
        """
        parts = path.rsplit('.', 1)
        if len(parts) == 1:
            self.imports.add(path)
            return
        pkg, name = parts
        # rsplit consumed the dot separating a pure-relative package ('' or
        # '.') from the name; put it back so '.Model' -> 'from . import ...'.
        if not pkg.strip('.'):
            pkg += '.'
        self.from_imports[pkg].add(name)

    def get_imports(self):
        """Return every recorded import as newline-joined, sorted source lines."""
        lines = ['import %s' % name for name in self.imports]
        lines.extend('from %s import %s' % (pkg, ', '.join(sorted(names)))
                     for pkg, names in self.from_imports.items())
        return '\n'.join(sorted(lines))
class ModelDesc(object):
    """Description of one table, rendered as a declarative model class.

    :param name: class name of the generated model.
    :param tablename: value emitted for ``__tablename__``.
    :param flask: when true, use Flask-SQLAlchemy style names (``Model`` and
        ``db`` imported from the local ``.db`` module).
    :param fix_onupdate: forwarded to the ``Env`` that collects imports.
    """

    # PEP 8 indentation for the body of the generated class.
    INDENT = ' ' * 4

    def __init__(self, name, tablename, flask=False, fix_onupdate=True):
        self.name = name
        self.tablename = tablename
        self.columns = []
        self.indexes = []
        if flask:
            self.env = Env(model='.db.Model', types='.db.db',
                           fix_onupdate=fix_onupdate)
        else:
            self.env = Env(fix_onupdate=fix_onupdate)

    def add_column(self, column):
        """Append a column description; columns are emitted in insertion order."""
        self.columns.append(column)

    def add_index(self, name, columns, unique):
        """Record an index over *columns* (column names, in index order)."""
        self.indexes.append((name, columns, unique))

    def get_table_args(self, env):
        """Return the ``__table_args__`` line(s) for the recorded indexes.

        Returns an empty string when there are no indexes. Otherwise the
        ``sqlalchemy.Index`` import is registered on *env*.
        """
        if not self.indexes:
            return ''
        env.add_import('sqlalchemy.Index')
        rendered = []
        for name, columns, unique in self.indexes:
            cols = ', '.join('%r' % c for c in columns)
            if unique:
                rendered.append('Index(%r, %s, unique=True)' % (name, cols))
            else:
                rendered.append('Index(%r, %s)' % (name, cols))
        head = self.INDENT + '__table_args__ = ('
        # Continuation lines line up under the first ``Index(`` entry.
        sep = ',\n' + ' ' * len(head)
        return head + sep.join(rendered) + ',)\n'

    def dump(self):
        """Return the generated module source: imports, then the model class."""
        head = ('class {name}({model}):\n'
                '{indent}__tablename__ = \'{tablename}\'\n'
                '{tableargs}\n').format(name=self.name,
                                        model=self.env.model,
                                        indent=self.INDENT,
                                        tablename=self.tablename,
                                        tableargs=self.get_table_args(self.env))
        # Render columns before collecting imports: column rendering may
        # register extra imports (e.g. sqlalchemy.text) on the env.
        columns = '\n'.join(self.INDENT + c.dump(self.env) for c in self.columns)
        return self.env.get_imports() + '\n\n\n' + head + columns
class ColumnDesc(object):
    """Description of one table column, rendered as a ``Column(...)`` line.

    :param name: column (and model attribute) name.
    :param type_desc: database type string such as ``'varchar(32)'``; it is
        translated by ``ColumnType.parse``.
    :param nullable: whether the column accepts NULL.
    :param is_primary: whether the column is part of the primary key.
    :param auto_inc: whether the column auto-increments.
    :param server_default: SQL text of the server-side default, or None.
    :param server_onupdate: SQL text of the ON UPDATE expression, or None.
    :param comment: column comment emitted as ``doc=``; None to omit it.
    """

    def __init__(self, name, type_desc, nullable, is_primary, auto_inc,
                 server_default, server_onupdate, comment):
        self.name = name
        self.type_desc = type_desc
        self.nullable = nullable
        self.is_primary = is_primary
        self.auto_inc = auto_inc
        self.server_default = server_default
        self.server_onupdate = server_onupdate
        self.comment = comment

    def dump(self, env):
        """Return ``'<name> = <Column>(<type>, <kwargs...>)'`` for this column.

        Any extra import the rendered line needs (``sqlalchemy.text``) is
        registered on *env*.
        """
        args = [ColumnType.parse(self.type_desc, env)]
        if self.is_primary:
            args.append('primary_key=True')
        if self.auto_inc:
            args.append('autoincrement=True')
        if self.server_default is not None:
            env.add_import('sqlalchemy.text')
            default = self.server_default
            if env.fix_onupdate and self.server_onupdate:
                # server_onupdate by itself renders no DDL, so the MySQL
                # ON UPDATE clause is appended to the default instead.
                default += ' ON UPDATE ' + self.server_onupdate
            args.append('server_default=text(%r)' % default)
        if self.server_onupdate is not None:
            env.add_import('sqlalchemy.text')
            args.append('server_onupdate=text(%r)' % self.server_onupdate)
        # SQLAlchemy already makes primary-key columns NOT NULL; every other
        # NOT NULL column must say so explicitly, even when it has a default.
        if not self.nullable and not self.is_primary:
            args.append('nullable=False')
        if self.comment is not None:
            args.append('doc=%r' % self.comment)
        return '%s = %s(%s)' % (self.name, env.column, ', '.join(args))
class ColumnType(object):
    """Mapping from one database type name to a SQLAlchemy type.

    :param name: database type name as reported in ``column_type``,
        e.g. ``'varchar'``.
    :param model_type: attribute name of the SQLAlchemy type in the types
        module, e.g. ``'String'``.
    :param arg_parse: optional callable. It receives the raw type arguments
        (the text between the parentheses, split on ``','``) and returns
        keyword-argument strings for the SQLAlchemy type.
    """

    # Type name, optionally followed by a parenthesised argument list.
    type_r = re.compile(r'^([a-z]+)(\([^)]+\))?')

    def __init__(self, name, model_type, arg_parse=None):
        self.name = name
        self.model_type = model_type
        self.arg_parse = arg_parse

    @classmethod
    def parse(cls, type_desc, env):
        """Translate a database type string into SQLAlchemy source code.

        :param type_desc: database type, e.g. ``'varchar(32)'`` or
            ``'int(11) unsigned'``. Trailing modifiers are ignored.
        :param env: object whose ``types`` attribute is the name of the types
            module in generated code.
        :returns: source text such as ``'types.String(length=32)'``.
        :raises SystemExit: after printing an error, when *type_desc* cannot
            be parsed or its type is not listed in ``type_table``.
        """
        match = cls.type_r.match(type_desc)
        if match is None:
            print('**Error: invalid type desc: `%s`' % type_desc)
            raise SystemExit(-1)
        typ, args = match.groups()
        col_type = type_table.get(typ)
        if col_type is None:
            print("**Error: unsupported column type: %s" % typ)
            raise SystemExit(-1)
        ret = '%s.%s' % (env.types, col_type.model_type)
        if args and col_type.arg_parse is not None:
            ret += '(%s)' % ', '.join(col_type.arg_parse(args[1:-1].split(',')))
        return ret


def _length_args(args):
    """Render ``char(N)`` / ``varchar(N)`` arguments as ``length=N``."""
    return ['length=%s' % args[0]]


def _enum_length_args(args):
    """Size an enum's String column to its longest value.

    The values arrive still wrapped in quotes, hence the ``- 2``.
    """
    return ['length=%d' % (max(len(x) for x in args) - 2)]


def _numeric_args(args):
    """Render ``decimal(P[,S])`` arguments as precision/scale keywords."""
    out = ['precision=%s' % args[0]]
    if len(args) > 1:
        out.append('scale=%s' % args[1])
    return out


# Database type name -> ColumnType; consulted by ColumnType.parse.
type_table = {t.name: t for t in [
    ColumnType('char', 'String', _length_args),
    ColumnType('varchar', 'String', _length_args),
    ColumnType('float', 'Float'),
    ColumnType('double', 'Float'),
    ColumnType('decimal', 'Numeric', _numeric_args),
    ColumnType('numeric', 'Numeric', _numeric_args),
    ColumnType('bigint', 'BigInteger'),
    ColumnType('int', 'Integer'),
    ColumnType('mediumint', 'Integer'),
    ColumnType('smallint', 'SmallInteger'),
    ColumnType('tinyint', 'SmallInteger'),
    ColumnType('tinytext', 'Text'),
    ColumnType('text', 'Text'),
    ColumnType('mediumtext', 'Text'),
    ColumnType('longtext', 'Text'),
    ColumnType('tinyblob', 'LargeBinary'),
    ColumnType('blob', 'LargeBinary'),
    ColumnType('mediumblob', 'LargeBinary'),
    ColumnType('longblob', 'LargeBinary'),
    ColumnType('timestamp', 'DateTime'),
    ColumnType('datetime', 'DateTime'),
    ColumnType('date', 'Date'),
    ColumnType('time', 'Time'),
    ColumnType('enum', 'String', _enum_length_args),
]}
if __name__ == '__main__':
    import argparse

    from sqlalchemy import text

    try:  # Python 3
        from urllib.parse import quote_plus
    except ImportError:  # Python 2
        from urllib import quote_plus

    parser = argparse.ArgumentParser(description='make model from db')
    parser.add_argument('-n', '--model-name', help='model name', default='')
    parser.add_argument('-o', '--output', help='output file', default='')
    parser.add_argument('-H', '--host', help='db host', default='localhost')
    parser.add_argument('-u', '--user', help='db user', default='root')
    parser.add_argument('-p', '--passwd', help='db passwd', default='')
    parser.add_argument('-P', '--port', help='db port', default=3306, type=int)
    parser.add_argument('-d', '--dialect', help='db dialect',
                        default='mysql+pymysql')
    parser.add_argument('--flask', help='use flask-sqlalchemy style',
                        default=False, action='store_true')
    parser.add_argument('--fix-onupdate', help='fix sqlalchemy server_onupdate',
                        default=False, action='store_true')
    parser.add_argument('db', help='database name')
    parser.add_argument('table', help='table name')
    args = parser.parse_args()

    model = ModelDesc(args.model_name or args.table.title().replace('_', ''),
                      args.table, args.flask, args.fix_onupdate)

    # Quote the credentials so characters such as '@', ':' or '/' in them
    # cannot corrupt the connection URL.
    engine = create_engine('{dialect}://{user}:{pwd}@{host}:{port}/{db}'.format(
        dialect=args.dialect, user=quote_plus(args.user),
        pwd=quote_plus(args.passwd), host=args.host, port=args.port,
        db=args.db))

    # Schema and table names are bound parameters, never interpolated SQL.
    params = {'schema': args.db, 'table': args.table}
    column_query = text(
        'select column_name, column_type, is_nullable, column_key, '
        'column_default, extra, column_comment '
        'from information_schema.columns '
        'where table_schema = :schema and table_name = :table '
        'order by ordinal_position')
    index_query = text(
        'select index_name, '
        'group_concat(column_name order by seq_in_index), '
        'min(non_unique) '
        'from information_schema.statistics '
        'where table_schema = :schema and table_name = :table '
        "and index_name != 'PRIMARY' "
        'group by index_name '
        'order by index_name')
    on_update = 'on update '

    # Use one connection throughout so 'set names' applies to the queries
    # that follow; separate engine.execute() calls may each check out a
    # different pooled connection.
    with engine.connect() as conn:
        conn.execute(text('set names utf8'))
        for (name, type_desc, nullable, key, default, extra,
             comment) in conn.execute(column_query, params):
            extra = extra or ''
            # Some servers report a NULL default as the string 'NULL'.
            server_default = default if default != 'NULL' else None
            # Search rather than startswith: newer MySQL versions may prefix
            # the clause (e.g. 'DEFAULT_GENERATED on update ...').
            pos = extra.find(on_update)
            server_onupdate = extra[pos + len(on_update):] if pos >= 0 else None
            model.add_column(ColumnDesc(
                name, type_desc,
                nullable=nullable == 'YES',
                is_primary=key == 'PRI',
                auto_inc='auto_increment' in extra,
                server_default=server_default,
                server_onupdate=server_onupdate,
                comment=comment or None))
        for name, columns, non_unique in conn.execute(index_query, params):
            model.add_index(name, columns.split(','), not non_unique)

    output = model.dump()
    if args.output:
        with open(args.output, 'w') as f:
            f.write(output)
    else:
        print(output)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment