Created March 18, 2015 09:42
-
-
Save adoc/d06c30a6d915b248ada4 to your computer and use it in GitHub Desktop.
SQLAlchemy Migration Tool
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
# Original Author: Tyler Lesmann | |
# Original Source: http://www.tylerlesmann.com/2009/apr/27/copying-databases-across-platforms-sqlalchemy/ | |
# TODO: Look at https://gist.github.com/thoas/1589496 if any | |
# "sequencing" issues come in to play. | |
* Python 3 Only (Deal with it.) | |
""" | |
from pprint import pprint | |
import sys | |
import getopt | |
import urllib.parse | |
import sqlalchemy | |
import sqlalchemy.orm | |
import sqlalchemy.exc | |
from sqlalchemy.ext.declarative import declarative_base | |
class Context:
    """Engine, session and model metadata for one database.

    ``SourceContext`` and ``DestinationContext`` specialise this for the two
    ends of a migration.
    """

    def __init__(self, connection_string, Base):
        """Connect to ``connection_string`` for the tables declared on ``Base``.

        :param connection_string: SQLAlchemy database URL. An optional
            ``?schema=name`` query parameter is read by this tool (see
            ``parse_connection_string``) and is not passed on to the driver.
        :param Base: declarative base whose ``metadata`` lists the tables.
        """
        self.connection_string = connection_string
        self.Base = Base
        # Declarative bases expose their MetaData as ``metadata`` (lowercase).
        self.MetaData = Base.metadata
        self.dialect = None
        self.schema = None
        self.database = None
        self.parse_connection_string()
        # ``convert_unicode`` is deliberately not passed: it is deprecated
        # since SQLAlchemy 1.3 (Python 3 DBAPIs already return str) and is
        # rejected by SQLAlchemy 2.0.
        self.Engine = sqlalchemy.create_engine(self._engine_url(), echo=False)
        self.Session = sqlalchemy.orm.sessionmaker(bind=self.Engine)()
        # NOTE: contexts created from the same ``Base`` share one MetaData,
        # so this binding is overwritten by whichever context is built last.
        self.MetaData.bind = self.Engine

    def _engine_url(self):
        """Return the URL to hand to ``create_engine``.

        The ``schema`` query parameter belongs to this tool. Most dialects
        forward URL query parameters to the DBAPI ``connect()`` call, where an
        unknown ``schema`` argument would typically be rejected, so it is
        stripped here.
        """
        if self.schema is None:
            return self.connection_string
        url = sqlalchemy.engine.url.make_url(self.connection_string)
        if hasattr(url, 'difference_update_query'):
            # SQLAlchemy 1.4+: URL objects are immutable.
            return url.difference_update_query(['schema'])
        # Older SQLAlchemy: ``URL.query`` is a plain, mutable dict.
        url.query.pop('schema', None)
        return url

    def parse_connection_string(self):
        """Populate ``dialect``, ``schema`` and ``database`` from the URL.

        ``dialect`` is the backend part of the URL scheme (``postgresql`` for
        ``postgresql+psycopg2://...``). ``database`` is the URL path with
        surrounding slashes removed. ``schema`` is the first ``schema`` query
        value, and is left untouched when the URL has none.
        """
        result = urllib.parse.urlparse(self.connection_string)
        # The scheme may carry a driver suffix ("dialect+driver"); keep only
        # the dialect name.
        self.dialect = result.scheme.split('+', 1)[0] or None
        query = urllib.parse.parse_qs(result.query)
        if 'schema' in query:
            self.schema = query['schema'][0]
        self.database = result.path.strip('/')

    def get_table_list(self):
        """Return the metadata's table keys, referenced tables first.

        ``MetaData.sorted_tables`` orders tables by foreign-key dependency,
        so copying rows in this order reaches parent tables before children.
        """
        return [table.key for table in self.MetaData.sorted_tables]

    def iter_table(self):
        """Yield ``(table_name, Table)`` pairs in ``get_table_list`` order.

        The Table objects come straight from the metadata. The names already
        come from that metadata, so re-resolving them through
        ``sqlalchemy.Table(..., autoload=True)`` is unnecessary, and the
        ``autoload`` flag is removed in SQLAlchemy 2.0.
        """
        for table_name in self.get_table_list():
            yield table_name, self.MetaData.tables[table_name]

    def commit(self):
        """Commit the pending work in this context's session."""
        self.Session.commit()
class SourceContext(Context):
    """Connection to the database that rows are copied *from*."""

    def iter_records(self, table):
        """Lazily yield each row of ``table`` fetched through the session.

        The query only runs once the generator is first advanced.
        """
        rows = self.Session.query(table).all()
        yield from rows
class DestinationContext(Context):
    """Connection to the database that receives the copied rows."""

    def create_table(self, source_table):
        """Ensure ``source_table`` exists in the destination database.

        This calls ``create_all`` on the table's whole MetaData instead of
        creating just this one table. ``create_all`` emits CREATE statements
        in foreign-key dependency order and skips tables that already exist
        (``checkfirst`` defaults to True). Repeated calls are therefore
        harmless, and referenced tables are created before the tables that
        point at them.
        """
        source_table.metadata.create_all(self.Engine)

    def init_table(self, table_name):
        """Return a Table for ``table_name``, reflected from this database.

        ``autoload_with`` pins reflection to this context's engine and, on
        its own, enables reflection. Relying on ``MetaData.bind`` is fragile:
        contexts built from the same ``Base`` share one MetaData and
        overwrite each other's binding. If ``table_name`` is already present
        in the metadata, that existing Table is returned.
        """
        return sqlalchemy.Table(table_name, self.MetaData,
                                autoload_with=self.Engine)
def quick_mapper(table):
    """Return a throwaway declarative model class mapped onto ``table``.

    A fresh ``declarative_base()`` is created on every call, so each
    generated ``GenericMapper`` class lives in its own declarative registry.
    Mapping many tables under the same class name therefore does not collide.

    The ORM needs primary-key columns to track row identity, and
    ``Session.merge`` looks rows up by them. For a table that declares no
    primary key, every column is used as the mapper's primary key so that
    mapping does not fail.
    """
    Base = declarative_base()

    class GenericMapper(Base):
        __table__ = table
        if not len(table.primary_key.columns):
            # No declared PK: identify rows by the full set of columns.
            __mapper_args__ = {'primary_key': list(table.columns)}

    return GenericMapper
def migrate(src_connection_string, dest_connection_string, Base):
    """Copy every table described by ``Base`` from one database to another.

    Missing tables are created in the destination. Each source row is then
    ``merge``d into the destination session, so a row whose primary key
    already exists there is updated rather than duplicated. Everything is
    committed in one transaction at the end.

    :param src_connection_string: SQLAlchemy URL of the database to read.
    :param dest_connection_string: SQLAlchemy URL of the database to write.
    :param Base: declarative base whose metadata lists the tables to copy.
    """
    print("Creating source database session...")
    src_ctx = SourceContext(src_connection_string, Base)
    print("Creating destination database session...")
    dest_ctx = DestinationContext(dest_connection_string, Base)
    print("Iterating Tables...")
    for table_name, src_table in src_ctx.iter_table():
        print("Copying table %s..." % table_name)
        dest_ctx.create_table(src_table)
        Model = quick_mapper(src_table)
        columns = src_table.columns.keys()
        for record in src_ctx.iter_records(src_table):
            data = {column: getattr(record, column) for column in columns}
            dest_ctx.Session.merge(Model(**data))
        # The mappers built by quick_mapper define no relationships. Flushing
        # table by table, in the order iter_table yields them, writes each
        # table's rows before moving on. This avoids relying on the unit of
        # work to order INSERTs across these mappers.
        dest_ctx.Session.flush()
    dest_ctx.commit()
def usage(argv):
    """Return the command-line help text.

    :param argv: the ``sys.argv`` list; ``argv[0]`` (the script name) is
        substituted into the text.
    """
    prog = argv[0]
    return """
Give a source and a destination connection string and one or more model
modules; every table the models describe is copied from source to destination.

Usage: {prog} -f source_server -t destination_server model_module [model_module2 model_module3]

-f, -t = driver://user[:password]@host[:port]/database[?schema=schema_name]

Example: {prog} -f sqlite:///old.db -t postgresql://user:secret@localhost/new myapp.models
""".format(prog=prog)
def main(argv):
    """Command-line entry point; ``argv`` is ``sys.argv``.

    Each positional argument names an importable module. Every such module
    must define a declarative ``Base`` whose metadata lists the tables to
    copy. ``migrate`` runs once per distinct ``Base``.
    Exits with status 1 after printing the help text on bad arguments.
    """
    import importlib

    try:
        optlist, model_modules = getopt.getopt(argv[1:], 'f:t:')
    except getopt.GetoptError as err:
        # Unknown option or missing option value: show the problem and the
        # help text instead of a traceback.
        print(err)
        print(usage(argv))
        sys.exit(1)
    options = dict(optlist)
    if '-f' not in options or '-t' not in options or not model_modules:
        print(usage(argv))
        sys.exit(1)
    seen_bases = set()
    for module_name in model_modules:
        module = importlib.import_module(module_name)
        Base = getattr(module, 'Base', None)
        if Base is None:
            sys.exit("Model module %r does not define a declarative 'Base'."
                     % module_name)
        # Several model modules may share one Base (e.g. imported from a
        # common module); its tables only need to be copied once.
        if id(Base) in seen_bases:
            continue
        seen_bases.add(id(Base))
        migrate(options['-f'], options['-t'], Base)
if __name__ == "__main__":
    # Executed as a script: pass the raw argument vector to the CLI entry point.
    main(argv=sys.argv)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.