Created
January 24, 2018 04:32
-
-
Save skwerlman/5610695e49ca605c1bdd6957fabdfa25 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
"""
Run migrations on the servatrice database.

Reads migration files (``*.sql``) from the migration directory
(``./migrations`` by default; see ``--migration-directory``) and runs
them on the database in sorted filename order.

Only runs migrations that are needed. This is determined by comparing a
version number parsed from each migration's filename against the value
stored in the ``cockatrice_schema_version`` table.

Stops running migrations as soon as any one fails, for any reason.
"""
import os | |
from argparse import ArgumentParser | |
import pymysql | |
import pymysql.cursors | |
# Module-level database cursor used by the query helpers. It stays None
# until main() opens the connection and assigns connection.cursor() to it,
# so the helpers must not be called before that.
SQL_CONTROLLER = None
def run_sql_command(sql: str, cursor=None) -> list:
    """Run a SQL command and return the rows of its first result set.

    Args:
        sql: The SQL text to execute. It may contain several statements
            when the connection was opened with multi-statement support.
        cursor: Cursor to execute on. Defaults to the module-level
            ``SQL_CONTROLLER`` so existing callers keep working unchanged.

    Returns:
        The rows returned by ``cursor.fetchall()`` for the first result set
        (dict rows when the cursor is a DictCursor).

    Raises:
        pymysql.err.MySQLError: If a statement fails, including statements
            after the first one in a multi-statement string.
    """
    if cursor is None:
        cursor = SQL_CONTROLLER
    cursor.execute(sql)
    result = cursor.fetchall()
    # With a multi-statement string, an error in a later statement is only
    # reported when its result set is read. Drain the remaining result sets
    # now, so the failure surfaces here (inside the caller's try/except)
    # instead of on the next, unrelated query.
    while cursor.nextset():
        pass
    return result
def get_all_migrations(args) -> list:
    """Return the paths of all ``.sql`` migration files, sorted by filename.

    Args:
        args: Parsed command-line arguments; only ``args.migration_directory``
            is read.

    Returns:
        A sorted list of paths, one per directory entry whose name ends in
        ``.sql``, each formed by joining the directory with the entry name.

    Raises:
        OSError: If the migration directory cannot be listed.
    """
    directory = args.migration_directory
    # os.path.join avoids a doubled separator when the directory is given
    # with a trailing slash (e.g. "./migrations/").
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith('.sql')
    )
def get_schema_version() -> int:
    """Return the schema version currently stored in the servatrice database.

    Reads the ``version`` column of the first row returned from the
    ``cockatrice_schema_version`` table.

    Returns:
        The stored version, converted to ``int``.

    Raises:
        RuntimeError: If the table has no rows, so no version can be read.
        pymysql.err.MySQLError: If the query itself fails.
    """
    rows = run_sql_command('SELECT version FROM cockatrice_schema_version;')
    # Without this check an empty table surfaces as a bare IndexError that
    # says nothing about what is actually wrong.
    if not rows:
        raise RuntimeError(
            'cockatrice_schema_version is empty; cannot determine the '
            'current schema version')
    return int(rows[0]['version'])
def valid_migrations(migrations: list, schema_version: int) -> list:
    """Return the migrations that still need to run, in version order.

    The version of each migration is the second ``_``-separated field of its
    *filename* parsed as an integer, e.g. ``2`` for
    ``servatrice_0002_to_0003.sql``. It is treated as the version the
    migration starts from (assumed naming convention -- confirm against the
    migrations directory).

    Args:
        migrations: Paths to migration files.
        schema_version: The database's current schema version.

    Returns:
        The paths whose parsed version is ``>= schema_version``, sorted by
        that version (ties broken by path).

    Raises:
        ValueError: If a filename's second field is not an integer.
        IndexError: If a filename contains no underscore.
    """
    def from_version(path: str) -> int:
        # Split only the file name. Splitting the whole path lets an
        # underscore in a directory name (e.g. "./my_migrations/") shift the
        # fields and break the version lookup.
        return int(os.path.basename(path).split('_')[1])

    pending = [m for m in migrations if from_version(m) >= schema_version]
    # Numeric ordering stays correct even if version numbers are not
    # zero-padded to the same width.
    pending.sort(key=lambda m: (from_version(m), m))
    return pending
def run_migration(migration: str) -> dict:
    """Load a migration file from disk and execute its SQL.

    Args:
        migration: Path to the ``.sql`` file to run.

    Returns:
        ``{'success': True, 'result': rows}`` when the SQL executes, or
        ``{'success': False, 'error': exc}`` when executing it raises a
        ``pymysql.err.MySQLError``.

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If the file is not valid UTF-8.
        Neither is folded into the status dict.
    """
    # An explicit encoding keeps the result independent of the platform's
    # default locale encoding.
    with open(migration, 'r', encoding='utf-8') as sql_file:
        sql = sql_file.read()
    try:
        result = run_sql_command(sql)
    except pymysql.err.MySQLError as exception:
        return {'success': False, 'error': exception}
    return {'success': True, 'result': result}
def _str_to_bool(value: str) -> bool:
    """Parse a command-line boolean such as 'true', 'no', '1' or 'off'.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``'False'``) as True, so a dedicated converter is needed.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised boolean.
    """
    from argparse import ArgumentTypeError

    normalized = value.strip().lower()
    if normalized in ('1', 'true', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'no', 'n', 'off'):
        return False
    raise ArgumentTypeError(f'expected a boolean value, got {value!r}')


def main() -> None:
    """Parse arguments, connect to MySQL and run every pending migration.

    Migrations run one at a time. Each successful one is committed before
    the next starts. The first failure is rolled back, reported on stderr,
    and stops the run. The connection is always closed on exit.
    """
    import sys
    from pymysql.constants import CLIENT

    global SQL_CONTROLLER
    parser = ArgumentParser(
        description='Run migrations on a servatrice database.',
        epilog='Be sure to manually verify migrations _before_ running them!'
    )
    mysql_group = parser.add_argument_group('MySql Server Args')
    mysql_group.add_argument('-u', '--user', required=True)
    mysql_group.add_argument('-p', '--password', '--pass', required=True)
    mysql_group.add_argument('-H', '--host', default='127.0.0.1')
    mysql_group.add_argument('-d', '--database', default='servatrice')
    mysql_group.add_argument('-P', '--port', type=int, default=3306)
    script_group = parser.add_argument_group('Script Args')
    script_group.add_argument('-D', '--migration-directory', default='./migrations')
    # NOTE(review): --safe-mode is parsed but not yet acted on anywhere.
    script_group.add_argument('--safe-mode', type=_str_to_bool, default=True)
    args = parser.parse_args()

    connection = pymysql.connect(
        host=args.host, port=args.port, user=args.user,
        password=args.password, db=args.database, charset='utf8mb4',
        cursorclass=pymysql.cursors.DictCursor,
        # A migration file may contain several ';'-separated statements,
        # which a single execute() only accepts when this flag is set.
        client_flag=CLIENT.MULTI_STATEMENTS)
    try:
        SQL_CONTROLLER = connection.cursor()
        migrations = get_all_migrations(args)
        schema_version = get_schema_version()
        for migration in valid_migrations(migrations, schema_version):
            status = run_migration(migration)
            if not status['success']:
                # MySQL implicitly commits most DDL, so this can only undo
                # the transactional (DML) part of the failed migration.
                connection.rollback()
                print(f"Migration {migration} failed: {status['error']}",
                      file=sys.stderr)
                break
            # pymysql does not autocommit by default. Without this, data
            # changes (e.g. an UPDATE of the schema version, if a migration
            # does one) could be lost.
            connection.commit()
    finally:
        connection.close()
    # TODO handle ctlaltca's concerns from #2969
if __name__ == '__main__':
    # Run the migrations only when executed as a script, not on import.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment