Skip to content

Instantly share code, notes, and snippets.

@TheWaWaR
Last active August 29, 2015 14:13
Show Gist options
  • Save TheWaWaR/3139ba1321d3f2645217 to your computer and use it in GitHub Desktop.
Save TheWaWaR/3139ba1321d3f2645217 to your computer and use it in GitHub Desktop.
Convert SQLAlchemy model definitions into Alembic `op.create_table` definitions.
import re
# Pattern a table or column name must satisfy (checked with ``re.match``).
VAR_REGEXP = "[_A-Za-z][_a-zA-Z0-9]*$"
# Text that marks a column definition line in the model source.
COLUMN_MARK = 'db.Column('
# Prefix used in the model source, and the prefix that replaces it in the
# generated column definitions.
DB_PREFIX = 'db.'
SA_PREFIX = 'sa.'
# Names substituted for ``%(op)s`` and ``%(sa)s`` in the templates below.
ALEMBIC_OP_NAME = 'op'
SQLALCHEMY_NAME = 'sa'
# Indentation units for the generated code. The value must really be four
# spaces: it is substituted verbatim for ``%(indent_1)s`` / ``%(indent_2)s``.
INDENT_1 = '    '  # 4 spaces
INDENT_2 = INDENT_1 * 2
# Template for one generated function holding all tables of a module.
MODULE_TMPL = u'''def create_%(module_name)s():\n%(tables_str)s'''
# Template for a single ``create_table`` call.
TABLE_TMPL = u'''%(indent_1)s%(op)s.create_table(
%(indent_2)s"%(table_name)s",
%(columns_str)s
%(indent_1)s)'''
# Template for a single column inside a ``create_table`` call.
COLUMN_TMPL = '''%(indent_2)s%(sa)s.Column("%(col_name)s", %(col_define)s)'''
def _parse_column_line(line):
    """Split one ``name = db.Column(...)`` line into ``(col_name, col_define)``.

    ``col_define`` is the argument text of the last ``COLUMN_MARK`` call on the
    line, with ``DB_PREFIX`` rewritten to ``SA_PREFIX``. When the first
    argument is a quoted string it is used as the column name instead of the
    assignment target.

    Raises:
        ValueError: if the line does not end with ``)`` (optionally followed
            by a comma), the column name is not a valid identifier, or the
            remaining definition is too short to be a type.
    """
    col_name, col_define = [part.strip() for part in line.split('=', 1)]
    # Drop an optional trailing comma *before* removing the closing
    # parenthesis; otherwise a stray ")" stays in the definition.
    if col_define.endswith('),'):
        col_define = col_define[:-1]
    if not col_define.endswith(')'):
        raise ValueError('Invalid column:' + col_define)
    col_define = col_define[:-1].split(COLUMN_MARK)[-1].strip()

    if col_define[:1] in ("'", '"'):
        # Explicit column name given as the first positional argument.
        quoted_name, comma, remainder = col_define.partition(',')
        if not comma:
            raise ValueError('Invalid column:' + line)
        col_name = quoted_name.strip()[1:-1]
        col_define = remainder.strip()

    print('col_name = [%s], col_define = [%s]' % (col_name, col_define))
    if re.match(VAR_REGEXP, col_name) is None:
        raise ValueError('Invalid column name:' + col_name)
    if len(col_define) < 2:
        raise ValueError('Invalid column:' + line)
    # Internal sanity check: both parts were sliced out of ``line``.
    assert col_name in line and col_define in line

    if DB_PREFIX:
        # Only rewrite the prefix where it starts a name, so identifiers that
        # merely end in it (e.g. 'userdb.id') are left alone.
        col_define = re.sub(r'(?<!\w)' + re.escape(DB_PREFIX),
                            lambda match: SA_PREFIX, col_define)
    return col_name, col_define


def print_tbls(s):
    """Parse model source text into table and column definitions.

    The text is scanned line by line (blank lines are skipped). A line that
    starts with ``__tablename__`` begins a new table; a line containing both
    ``=`` and ``COLUMN_MARK`` adds a column to the table being collected.
    Each handled line and its parse result are printed to stdout.

    Args:
        s: Source text containing the model definitions.

    Returns:
        A list of ``[table_name, columns]`` pairs, where ``columns`` is a list
        of ``[col_name, col_define]`` pairs. The list is empty when the text
        contains neither a table name nor a column.

    Raises:
        ValueError: if a table name or column line is malformed. The
            offending line and the error are printed before re-raising.
    """
    tables = []
    table_name, columns = None, []
    for line in s.splitlines():
        line = line.strip()
        if not line:
            continue
        print(line)
        try:
            print('-----')
            if line.startswith('__tablename__'):
                # Value right of the last "=", with its quotes removed.
                name = line.split('=')[-1].strip()[1:-1]
                if re.match(VAR_REGEXP, name) is None:
                    raise ValueError('Invalid table name:' + name)
                if table_name is not None:
                    # Flush the previous table before starting the next one.
                    tables.append([table_name, columns])
                    columns = []
                table_name = name
            elif '=' in line and COLUMN_MARK in line:
                columns.append(list(_parse_column_line(line)))
            else:
                print('::::: Nothing')
        except Exception as e:
            print(line)
            print(e)
            raise
        print('==============')
    # Flush the last table, but do not invent one for input that had nothing.
    if table_name is not None or columns:
        tables.append([table_name, columns])
    return tables
def parse(filename):
    """Read the file at ``filename`` and return ``print_tbls`` of its text."""
    with open(filename, 'r') as source_file:
        source_text = source_file.read()
    return print_tbls(source_text)
def build_module(module_name, tables):
    """Render the source of a ``create_<module_name>()`` function.

    Args:
        module_name: Name inserted after ``create_`` in the function name.
        tables: Iterable of ``(table_name, columns)`` pairs, where ``columns``
            is an iterable of ``(col_name, col_define)`` pairs.

    Returns:
        The generated function source: one ``create_table`` call per table,
        separated by blank lines, or a ``pass`` body when there are no tables.
    """
    table_str_lst = []
    for table_name, columns in tables:
        # Explicit mappings instead of ``locals()``: inside a comprehension
        # ``locals()`` does not reliably include the function's variables.
        columns_str = ',\n'.join(
            COLUMN_TMPL % {
                'indent_2': INDENT_2,
                'sa': SQLALCHEMY_NAME,
                'col_name': col_name,
                'col_define': col_define,
            }
            for col_name, col_define in columns
        )
        table_str_lst.append(TABLE_TMPL % {
            'indent_1': INDENT_1,
            'indent_2': INDENT_2,
            'op': ALEMBIC_OP_NAME,
            'table_name': table_name,
            'columns_str': columns_str,
        })
    # A function with no statements is a syntax error, so fall back to ``pass``.
    tables_str = '\n\n'.join(table_str_lst) or INDENT_1 + 'pass'
    return MODULE_TMPL % {
        'module_name': module_name,
        'tables_str': tables_str,
    }
if __name__ == '__main__':
    import os
    import sys

    # Model files to convert: the command-line arguments when given,
    # otherwise the built-in default list.
    filenames = sys.argv[1:] or [
        'module1.py', 'module1.py', 'module1.py',
        'module1.py', 'module1.py', 'module1.py',
    ]
    modules = []
    for fn in filenames:
        # Name the generated function after the file's base name up to its
        # first dot, so directory components never end up in the name.
        module_name = os.path.basename(fn).split('.')[0]
        modules.append([module_name, parse(fn)])
    # Separator between the parse trace printed above and the generated code.
    print('###########################################\n\n')
    print('\n\n\n'.join([build_module(module_name, tables)
                         for module_name, tables in modules]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment