Skip to content

Instantly share code, notes, and snippets.

@guyskk
Created November 16, 2017 06:20
Show Gist options
  • Save guyskk/a92a0ebfef9e2f8ba19380f73b5ee675 to your computer and use it in GitHub Desktop.
Save guyskk/a92a0ebfef9e2f8ba19380f73b5ee675 to your computer and use it in GitHub Desktop.
Format SQL in python source code
"""
Format SQL embedded in triple-quoted strings of Python source code.

Requires:
    pip install click sqlparse
"""
import re
import tokenize
from textwrap import indent, dedent

import click
import sqlparse
_leading_whitespace_re = re.compile('(^[ \t]*)(?:[^ \t\n])')
_divider_re = re.compile('({}|{})'.format('"""', "'''"))
def _is_create_table_sql(code):
cols = re.findall('({})'.format('|'.join([
'varchar',
'string',
'int',
'bigint',
'double',
'row format',
])), code, re.I)
return len(cols) > 3
def format_code(text, **options):
    """Return ``text`` with the SQL inside its triple-quoted strings reformatted.

    Each triple-quoted literal (either quote style) is handled as follows:

    * if ``_get_last_line_info`` reports that the text before the opening
      quotes ends in a whitespace-only line, the literal is treated as a
      docstring and copied through unchanged;
    * if its body looks like a table definition (``_is_create_table_sql``)
      it is copied through unchanged;
    * otherwise the body is formatted by ``_format_sql``, indented four
      columns deeper than the line that opened the literal, and the closing
      quotes are aligned with that line's indentation.

    Every literal keeps its original quote style, and a literal only ends at
    the same delimiter that opened it, so one quote style nested inside the
    other does not end the string early. An unterminated literal is copied
    through verbatim. Triple quotes inside comments or escaped with a
    backslash are not recognised.

    Args:
        text: Python source code.
        **options: keyword arguments forwarded to ``_format_sql`` (and from
            there to ``sqlparse.format``).

    Returns:
        The rewritten source code.
    """
    pieces = []
    pos = 0
    while True:
        opening = _divider_re.search(text, pos)
        if opening is None:
            break
        quote = opening.group()
        # A literal ends at the next occurrence of the *same* delimiter.
        close_at = text.find(quote, opening.end())
        if close_at == -1:
            # Unterminated literal: stop here so the remainder, including
            # the opening quotes, is copied through rather than dropped.
            break
        prefix = text[pos:opening.start()]
        body = text[opening.end():close_at]
        pieces.append(prefix)
        is_docstring, indent_size = _get_last_line_info(prefix)
        if is_docstring or _is_create_table_sql(body):
            # Reproduce the literal exactly, original delimiter included.
            pieces.append(quote + body + quote)
        else:
            sql = indent(_format_sql(body, **options), ' ' * (indent_size + 4))
            pieces.append(quote + sql + ' ' * indent_size + quote)
        pos = close_at + len(quote)
    pieces.append(text[pos:])
    return ''.join(pieces)
def _get_indent_size(line):
    """Return the width of the leading spaces/tabs of ``line`` (tab = 4).

    Returns 0 when nothing but spaces/tabs occur before the end of the
    string or the first newline.
    """
    found = _leading_whitespace_re.search(line)
    return _compute_indent_size(found.group(1)) if found else 0
def _compute_indent_size(chars):
return sum(4 if x == '\t' else 1 for x in chars)
def _get_last_line_info(code):
    """Describe the last line of ``code``, ignoring leading/trailing newlines.

    Returns a tuple ``(is_docstring, indent_size)``:

    * ``is_docstring`` is True when that line is empty or whitespace-only, and
      ``indent_size`` is then the width of the whole line;
    * otherwise ``is_docstring`` is False and ``indent_size`` is the width of
      the line's leading spaces/tabs.

    Widths count a tab as 4 columns.
    """
    tail = code.strip('\n').split('\n')[-1]
    if tail.strip():
        return False, _get_indent_size(tail)
    # Blank last line: its whole width is the indentation.
    return True, _compute_indent_size(tail)
def _format_sql(sql, **options):
    """Pretty-print one SQL snippet with ``sqlparse.format``.

    Surrounding newlines are stripped and common indentation removed before
    formatting. Unless overridden through ``options``, the formatter runs
    with ``reindent=True``, ``keyword_case='upper'`` and ``indent_width=4``.
    The formatted text is returned with one leading and one trailing newline.
    """
    settings = dict(reindent=True, keyword_case='upper', indent_width=4)
    settings.update(options)
    cleaned = dedent(sql.strip('\n'))
    formatted = sqlparse.format(cleaned, **settings)
    return '\n{}\n'.format(formatted)
# Sample Python source printed by ``cli --demo`` after being run through
# ``format_code``. It contains SQL assigned to variables, a ``my_func``
# function followed by a triple-quoted text block, and a single-line SQL
# string.
# NOTE(review): every line inside this literal starts at column 0, so the text
# before the ``my_func`` triple-quoted block ends in ``def my_func():`` rather
# than a whitespace-only line, and ``_get_last_line_info`` will not classify
# it as a docstring. Confirm whether the demo was meant to be indented.
DEMO = '''
print('haha')
mysql = """
select * from mytable
where id=1
and name='xxx'
"""
print('hhhh')
def my_func():
"""
This
IS
Doc
String
"""
print('haha')
mysql = """
select * from mytable
where id=1
and name='xxx'
"""
print('hhhh')
mysql = """
select * from mytable where id=1 and name='xxx' """
'''
# Appended as a separate literal so this sample can use the other
# triple-quote style (single quotes) for its SQL string.
DEMO += """
print('haha')
mysql = '''
select * from mytable
where id=1
and name='xxx'
'''
print('hhhh')
"""
# FILEPATH is optional so that ``--demo`` can run without a real file; its
# absence is rejected explicitly below for normal runs.
@click.command()
@click.argument('filepath', type=click.Path(exists=True), required=False)
@click.option('--dryrun', '-d', default=False, is_flag=True, help="Don't write file, just print it.")
@click.option('--demo', default=False, is_flag=True, help='Run demo test')
def cli(filepath, dryrun=False, demo=False):
    """Format SQL in python source code"""
    # (The docstring above doubles as the ``--help`` text.)
    if demo:
        print(format_code(DEMO))
        return
    if filepath is None:
        raise click.UsageError('Missing argument "FILEPATH".')
    # tokenize.open honours the PEP 263 coding cookie / BOM (UTF-8 by
    # default) instead of the platform's locale encoding.
    with tokenize.open(filepath) as f:
        encoding = f.encoding
        text = f.read()
    formatted = format_code(text)
    if dryrun:
        print(formatted)
    elif formatted != text:
        # Write back in the encoding the file was read with; skip the write
        # entirely when nothing changed so the file is left untouched.
        with open(filepath, 'w', encoding=encoding) as f:
            f.write(formatted)


if __name__ == '__main__':
    cli()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment