Skip to content

Instantly share code, notes, and snippets.

@dojiong
Last active December 28, 2016 15:18
Show Gist options
  • Save dojiong/d3a0d5176a42bb6173f37ac1630308f5 to your computer and use it in GitHub Desktop.
Save dojiong/d3a0d5176a42bb6173f37ac1630308f5 to your computer and use it in GitHub Desktop.
Generate a SQLAlchemy model class from an existing MySQL table
#!/usr/bin/env python
import re
from collections import defaultdict
from sqlalchemy import create_engine
class Env(object):
    """Naming and import bookkeeping for one generated model module.

    Each dotted path given to the constructor is recorded as an import and
    reduced to its last component, which is the name the generated code uses
    (e.g. ``'sqlalchemy.Column'`` -> ``Column``).
    """

    def __init__(self, model='.db.Model', types='sqlalchemy.types',
                 column='sqlalchemy.Column', fix_onupdate=True):
        self.model = model.rsplit('.', 1)[-1]
        self.types = types.rsplit('.', 1)[-1]
        self.column = column.rsplit('.', 1)[-1]
        # Flag read by column rendering to decide how ON UPDATE is emitted.
        self.fix_onupdate = fix_onupdate
        # Bare module names, rendered as "import <name>".
        self.imports = set()
        # module path -> set of names, rendered as "from <module> import ...".
        self.from_imports = defaultdict(set)
        for path in (model, types, column):
            self.add_import(path)

    def add_import(self, path):
        """Record an import for *path*.

        ``'pkg.mod.Name'`` becomes ``from pkg.mod import Name``; a path with
        no dot (``'os'``) becomes ``import os``.
        """
        parts = path.rsplit('.', 1)
        if len(parts) == 1:
            self.imports.add(path)
        else:
            self.from_imports[parts[0]].add(parts[1])

    def get_imports(self):
        """Return all recorded imports as sorted source lines joined by newlines."""
        # Plain module imports need the "import " keyword to be valid Python;
        # previously the bare module name was emitted on its own.
        lines = ['import %s' % mod for mod in self.imports]
        lines.extend('from %s import %s' % (mod, ', '.join(sorted(names)))
                     for mod, names in self.from_imports.items())
        return '\n'.join(sorted(lines))
class ModelDesc(object):
    """Description of one table, rendered as a declarative model class.

    Columns and indexes are accumulated with ``add_column``/``add_index``;
    ``dump`` turns everything (imports, class header, ``__table_args__`` and
    column assignments) into Python source text.
    """

    def __init__(self, name, tablename, flask=False, fix_onupdate=True):
        self.name = name
        self.tablename = tablename
        self.columns = []
        self.indexes = []
        env_options = {'fix_onupdate': fix_onupdate}
        if flask:
            # Flask-SQLAlchemy style: column types come from the ``db`` object.
            env_options.update(model='.db.Model', types='.db.db')
        self.env = Env(**env_options)

    def add_column(self, column):
        """Append a column description; columns are dumped in insertion order."""
        self.columns.append(column)

    def add_index(self, name, columns, unique):
        """Record an index over the given column names."""
        self.indexes.append((name, columns, unique))

    def get_table_args(self, env):
        """Return the ``__table_args__`` source line(s), or '' without indexes.

        Registers the ``sqlalchemy.Index`` import on *env* when needed.
        """
        if not self.indexes:
            return ''
        env.add_import('sqlalchemy.Index')
        rendered = []
        for idx_name, idx_columns, is_unique in self.indexes:
            column_list = ', '.join('%r' % col for col in idx_columns)
            suffix = ', unique=True' if is_unique else ''
            rendered.append('Index(%r, %s%s)' % (idx_name, column_list, suffix))
        header = ' __table_args__ = ('
        # Continuation lines line up under the first Index(...) entry.
        separator = ',\n' + ' ' * len(header)
        return header + separator.join(rendered) + ',)\n'

    def dump(self):
        """Return the complete generated module source for this model."""
        env = self.env
        header = ('class {name}({model}):\n'
                  ' __tablename__ = \'{tablename}\'\n{tableargs}\n').format(
            name=self.name,
            model=env.model,
            tablename=self.tablename,
            tableargs=self.get_table_args(env))
        body = '\n'.join(' ' + column.dump(env) for column in self.columns)
        # Imports are rendered last: rendering the header and columns may
        # register additional imports on env.
        return env.get_imports() + '\n\n\n' + header + body
class ColumnDesc(object):
    """One table column, rendered as a ``name = Column(...)`` assignment."""

    def __init__(self, name, type_desc, nullable, is_primary, auto_inc,
                 server_default, server_onupdate, comment):
        self.name = name
        self.type_desc = type_desc  # raw column type text, e.g. "varchar(32)"
        self.nullable = nullable
        self.is_primary = is_primary
        self.auto_inc = auto_inc
        self.server_default = server_default  # SQL expression text, or None
        self.server_onupdate = server_onupdate  # SQL expression text, or None
        self.comment = comment

    def dump(self, env):
        """Return the source line declaring this column.

        *env* supplies the ``column`` name used in the output and collects any
        extra imports (``sqlalchemy.text``) the line needs.
        """
        args = [ColumnType.parse(self.type_desc, env)]
        args.extend(self._column_kwargs(env))
        return '%s = %s(%s)' % (self.name, env.column, ', '.join(args))

    def _column_kwargs(self, env):
        """Return the keyword-argument strings for this column's Column(...) call."""
        kwargs = []
        if self.is_primary:
            kwargs.append('primary_key=True')
        if self.auto_inc:
            kwargs.append('autoincrement=True')
        if self.server_default is not None:
            env.add_import('sqlalchemy.text')
            dft = self.server_default
            if env.fix_onupdate and self.server_onupdate:
                # server_onupdate is only a marker in SQLAlchemy and emits no
                # DDL, so fold the ON UPDATE clause into the DEFAULT text. The
                # clause must carry the on-update expression, not the default.
                dft += ' ON UPDATE ' + self.server_onupdate
            kwargs.append('server_default=text(%r)' % dft)
        if self.server_onupdate is not None:
            env.add_import('sqlalchemy.text')
            kwargs.append('server_onupdate=text(%r)' % self.server_onupdate)
        # Column() is nullable by default unless it is a primary key, so NOT
        # NULL must be explicit for every other column, with or without a
        # server default.
        if not self.nullable and not self.is_primary:
            kwargs.append('nullable=False')
        if self.comment is not None:
            kwargs.append('doc=%r' % self.comment)
        return kwargs
class ColumnType(object):
    """Mapping from one MySQL column type name to a SQLAlchemy type name.

    ``arg_parse``, when given, receives the comma-split text inside the type's
    parentheses (e.g. ``['255']`` for ``varchar(255)``) and returns a list of
    keyword-argument strings for the generated type constructor.
    """

    # Lowercase type name, optionally followed by a parenthesised argument
    # list; trailing modifiers such as "unsigned" are ignored.
    type_r = re.compile(r'^([a-z]+)(\([^)]+\))?')

    def __init__(self, name, model_type, arg_parse=None):
        self.name = name
        self.model_type = model_type
        self.arg_parse = arg_parse

    @classmethod
    def parse(cls, type_desc, env, table=None):
        """Translate *type_desc* into a type expression such as ``types.String(length=32)``.

        *env.types* is used as the prefix. *table* maps type names to
        ``ColumnType`` entries and defaults to the module-level ``type_table``.
        Prints an error and raises ``SystemExit(-1)`` for an unparsable or
        unsupported type.
        """
        if table is None:
            table = type_table
        match = cls.type_r.match(type_desc)
        if match is None:
            print('**Error: invalid type desc: `%s`' % type_desc)
            raise SystemExit(-1)
        typ, args = match.group(1), match.group(2)
        col_type = table.get(typ)
        if col_type is None:
            # Report the offending type name (col_type is always None here).
            print("**Error: unsupported column type: %s" % typ)
            raise SystemExit(-1)
        ret = '%s.%s' % (env.types, col_type.model_type)
        if args and col_type.arg_parse is not None:
            # Strip the surrounding parentheses before splitting.
            ret += '(%s)' % ', '.join(col_type.arg_parse(args[1:-1].split(',')))
        return ret
# Lookup table consulted by ColumnType.parse: MySQL type name -> ColumnType.
# Each arg_parse callback receives the comma-split text inside the type's
# parentheses (e.g. ['10', '2'] for decimal(10,2)).
type_table = {t.name: t for t in [
    ColumnType('char', 'String', lambda args: ['length=%s' % args[0]]),
    ColumnType('varchar', 'String', lambda args: ['length=%s' % args[0]]),
    ColumnType('float', 'Float'),
    ColumnType('double', 'Float'),
    # decimal(P) or decimal(P,S); scale is emitted only when present.
    ColumnType('decimal', 'Numeric',
               lambda args: (['precision=%s' % args[0]] +
                             ['scale=%s' % a for a in args[1:2]])),
    ColumnType('bigint', 'BigInteger'),
    ColumnType('int', 'Integer'),
    ColumnType('mediumint', 'Integer'),
    ColumnType('smallint', 'SmallInteger'),
    ColumnType('tinyint', 'SmallInteger'),
    ColumnType('tinytext', 'Text'),
    ColumnType('text', 'Text'),
    ColumnType('mediumtext', 'Text'),
    ColumnType('longtext', 'Text'),
    ColumnType('tinyblob', 'LargeBinary'),
    ColumnType('blob', 'LargeBinary'),
    ColumnType('mediumblob', 'LargeBinary'),
    ColumnType('longblob', 'LargeBinary'),
    ColumnType('timestamp', 'DateTime'),
    ColumnType('datetime', 'DateTime'),
    ColumnType('date', 'Date'),
    ColumnType('time', 'Time'),
    # Enum values arrive quoted ("'abc'"), hence the -2. The parentheses make
    # the subtraction happen before %-formatting (previously str - int).
    ColumnType('enum', 'String',
               lambda args: ['length=%d' % (max(len(x) for x in args) - 2)]),
]}
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='make model from db')
parser.add_argument('-n', '--model-name', help='model name', default='')
parser.add_argument('-o', '--output', help='output file', default='')
parser.add_argument('-H', '--host', help='db host', default='localhost')
parser.add_argument('-u', '--user', help='db user', default='root')
parser.add_argument('-p', '--passwd', help='db passwd', default='')
parser.add_argument('-P', '--port', help='db port', default=3306, type=int)
parser.add_argument('-d', '--dialect', help='db dialect', default='mysql+pymysql')
parser.add_argument('--flask', help='use flask-sqlalchemy stype',
default=False, action='store_true')
parser.add_argument('--fix-onupdate', help='fix sqlalchemy server_onupdate',
default=False, action='store_true')
parser.add_argument('db', help='database name')
parser.add_argument('table', help='table name')
args = parser.parse_args()
model = ModelDesc(args.model_name or args.table.title().replace('_', ''),
args.table, args.flask, args.fix_onupdate)
engine = create_engine('{dialect}://{user}:{pwd}@{host}:{port}/{db}'.format(
dialect=args.dialect, user=args.user, pwd=args.passwd,
host=args.host, port=args.port, db=args.db))
engine.execute('set names utf8')
query = ('select column_name,column_type,is_nullable,column_key,'
'column_default,extra,column_comment '
'from information_schema.columns '
'where table_schema="%s" and table_name="%s"') % (args.db, args.table)
for (name, type, nullable, key, default, extra, comment) in engine.execute(query):
server_default = default if default != 'NULL' else None
server_onupdate = None
if extra.startswith('on update'):
server_onupdate = extra[len('on update '):]
column = ColumnDesc(name, type,
nullable=nullable == 'YES',
is_primary=key == 'PRI',
auto_inc='auto_increment' in extra,
server_default=server_default,
server_onupdate=server_onupdate,
comment=comment or None)
model.add_column(column)
query = ('select index_name, group_concat(column_name order by seq_in_index), '
'min(non_unique) from information_schema.statistics '
'where table_schema="%s" and table_name="%s" and index_name != "PRIMARY" '
'group by index_name') % (args.db, args.table)
for (name, columns, non_unique) in engine.execute(query):
model.add_index(name, columns.split(','), not non_unique)
if args.output:
with open(args.output, 'w') as f:
f.write(model.dump())
else:
print(model.dump())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment