Created November 6, 2020 19:01
postway and postsplit: postway is a Flyway-like migration tool for raw psql; postsplit is a DAG-based splitter for pg_dump output
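A minimal usage sketch, assuming the two scripts below are saved as postsplit.py and postway.py (file names, paths and credentials are hypothetical; the flags come from each script's argparse definitions):

python postsplit.py -i mydb_schema.sql -o ./migrations --postway -vv
python postway.py migrate -U enterprisedb -W secret -d mydb -b /usr/edb/as11/bin -m ./migrations -vv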
#!/usr/bin/env python | |
import argparse | |
import logging | |
import os | |
import re | |
from enum import Enum | |
from itertools import groupby, chain | |
# import matplotlib.pyplot as plt | |
import networkx as nx | |
class ParentName(Enum): | |
none = 1 | |
before = 2 | |
after = 3 | |
def __str__(self): | |
return self.name | |
@staticmethod | |
def from_string(s): | |
try: | |
return ParentName[s] | |
except KeyError: | |
raise ValueError() | |
class PostSplit: | |
"""A class object for the PG dump split utility script. | |
Split pg_dump files to schema/schema_type file hierarchy | |
Use with files produced by pg_dump -s | |
Original author: kmatt - from https://gist.github.com/kmatt/2572360
https://gist.github.com/rbellamy/e61bbe89abe97afb0fd6150190856f4f
""" | |
def __init__(self): | |
self.parser = argparse.ArgumentParser( | |
description='A script for parsing a text pg_dump file and creating either a) per-object files, or b) ' | |
'per-object files in a form consumable by postway/postway') | |
self.logger = logging.getLogger('pg_dump_split') | |
self.logger.setLevel(logging.DEBUG) | |
self.console_handler = logging.StreamHandler() # sys.stderr | |
self.console_handler.setLevel( | |
logging.CRITICAL) # set later by set_log_level_from_verbose() in interactive sessions | |
self.console_handler.setFormatter( | |
logging.Formatter('[%(levelname)s] %(message)s')) | |
self.logger.addHandler(self.console_handler) | |
self.create_public_schema_sql = ''' | |
CREATE SCHEMA public; | |
GRANT ALL ON SCHEMA public TO postgres; | |
GRANT ALL ON SCHEMA public TO PUBLIC; | |
COMMENT ON SCHEMA public IS 'standard public schema'; | |
''' | |
self.args = None | |
def _get_sections(self): | |
infile = self.args.infile | |
with open(infile) as f: | |
groups = groupby(f, lambda x: x.startswith('-- TOC')) | |
for key, value in groups: | |
if key: | |
yield chain([next(value)], (next(groups)[1]))  # the TOC header line plus everything up to the next -- TOC
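# Illustration of the grouping above on a hypothetical dump fragment:
#   -- TOC entry 201 (class 1259 OID 16385)
#   -- Name: orders; Type: TABLE; Schema: public; Owner: app
#   CREATE TABLE public.orders (...);
#   -- TOC entry 202 ...
# groupby() keys consecutive lines on startswith('-- TOC'); each True group is a
# TOC header line and the False group that follows is that object's body, so the
# chain() call yields header-plus-body sections, one per dumped object.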
def _add_edges_by_parent_name(self, graph): | |
"""assign edges based on parent_name inspection""" | |
for parent_n, parent_d in graph.nodes(data=True): | |
if parent_d['schema_type'] == 'table': | |
parent_schema = parent_d['schema'] | |
parent_object_name = parent_d['object_name'] | |
for child_n, child_d in graph.nodes(data=True): | |
child_schema_type = child_d['schema_type']
child_schema = child_d['schema']
# .get() avoids a KeyError for nodes that carry no parent_name attribute
child_object_name = child_d.get('parent_name')
if child_object_name is not None and \
parent_schema == child_schema and \
parent_object_name == child_object_name:
self.logger.debug( | |
'[toc table <- {}] edge: {}.{}({}) <- {}.{}({})'.format(child_schema_type, | |
parent_schema, | |
parent_object_name, | |
parent_n, | |
child_schema, | |
child_object_name, | |
child_n)) | |
graph.add_edge(child_n, parent_n) | |
return graph | |
def _add_edges_by_sql(self, parent_schema_types, child_schema_types, graph): | |
""" | |
- if parent 'object_name' is in 'sql' of child, then parent is dependency of child. | |
- if 'myent_pkg.r' is in 'check_missing_goals_setup' function, then 'myent_pkg'(TOC == 1765) is a dependency of | |
'check_missing_goals_setup'(TOC == 4054) | |
>>> graph.add_edge(1765, 4054) | |
:param parent_schema_types: | |
:type parent_schema_types: | |
:param child_schema_types: | |
:type child_schema_types: | |
:param graph: | |
:type graph: | |
:return: | |
:rtype: | |
""" | |
for outer_n, outer_d in graph.nodes(data=True): | |
if outer_d['schema_type'] in parent_schema_types: | |
parent_schema = outer_d['schema'] | |
parent_object_name = outer_d['object_name'] | |
for inner_n, inner_d in graph.nodes(data=True): | |
if inner_d['schema_type'] in child_schema_types: | |
child_schema = inner_d['schema'] | |
child_object_name = inner_d['object_name'] | |
child_sql = inner_d['sql'] | |
if self._is_dependent_sql(parent_schema, | |
parent_object_name, | |
child_schema, | |
child_object_name, | |
child_sql): | |
self.logger.debug('[{} <- {}] edge: {}.{}({}) <- {}.{}({})'.format( | |
'/'.join(parent_schema_types), | |
'/'.join(child_schema_types), | |
parent_schema, | |
parent_object_name, | |
outer_n, | |
child_schema, | |
child_object_name, | |
inner_n)) | |
graph.add_edge(inner_n, outer_n) | |
return graph | |
def _merge_acl(self, graph): | |
""" | |
apply schema_type == 'acl' to all types - this effectively removes the acl type | |
*NOTE: There is a bug - currently does not correctly match with foreign tables, sequences and views* | |
""" | |
acl_delete_list = [] | |
for outer_n, outer_d in graph.nodes(data=True): | |
parent_schema = outer_d['schema'] | |
parent_schema_type = outer_d['schema_type'] | |
parent_object_name = outer_d['object_name'] | |
if parent_schema_type != 'acl': | |
for inner_n, inner_d in graph.nodes(data=True): | |
if inner_d['schema_type'] == 'acl': | |
acl_schema = inner_d['schema'] | |
acl_parent_schema_type = self._get_acl_parent_type(inner_d['sql']) | |
acl_object_name = inner_d['object_name'] | |
if parent_schema_type == acl_parent_schema_type and \ | |
parent_object_name == acl_object_name and \ | |
parent_schema == acl_schema: | |
self.logger.debug( | |
'[{} <- acl] merge: {}.{}({}) <- {}.{}({})'.format(parent_schema_type, | |
parent_schema, | |
parent_object_name, | |
outer_n, | |
acl_schema, | |
acl_object_name, | |
inner_n)) | |
graph.nodes[outer_n]['sql'] = '{}\n{}'.format(outer_d['sql'], inner_d['sql']) | |
acl_delete_list.append(inner_n) | |
graph.remove_nodes_from(acl_delete_list) | |
return graph | |
@staticmethod | |
def _is_dependent_sql(parent_schema, parent_object_name, child_schema, child_object_name, sql): | |
""" | |
1. If the fully-qualified parent and child names match, then no match. | |
2. If the schema names are different, then the parent must be matched with a fully-qualified name in the child. | |
3. If the schema names are the same, then the parent must be matched with a relative name, if a fully-qualified | |
parent with the same relative name, but different schema, hasn't already matched the child. | |
:param parent_schema: | |
:type parent_schema: | |
:param parent_object_name: | |
:type parent_object_name: | |
:param child_schema: | |
:type child_schema: | |
:param child_object_name: | |
:type child_object_name: | |
:param sql: | |
:type sql: | |
:return: | |
:rtype: | |
""" | |
parent = r'{}\.{}'.format(parent_schema, parent_object_name)
child = r'{}\.{}'.format(child_schema, child_object_name)
relative_regex = re.compile(r'.*[\s,;\(\)]{}[\s,;\(\)].*'.format(parent_object_name)) | |
fully_qualified_regex = re.compile(r'.*[\s,;\(\)]{}[\s,;\(\)].*'.format(parent)) | |
is_dependent_sql = False | |
for line in sql.splitlines(keepends=True): | |
if is_dependent_sql or parent == child: | |
break | |
if not line.startswith('-- Name'): | |
if parent_schema != child_schema: | |
is_dependent_sql = fully_qualified_regex.match(line) | |
else: | |
is_dependent_sql = relative_regex.match(line) | |
return is_dependent_sql | |
@staticmethod | |
def _header(owner, schema, set_role): | |
s0 = 'BEGIN;' | |
if set_role and owner: | |
s1 = 'SET LOCAL ROLE {};'.format(owner) | |
else: | |
s1 = '' | |
s2 = 'SET LOCAL check_function_bodies = false;' | |
s3 = 'SET SEARCH_PATH TO {}, pg_catalog, sys, dbo;'.format(schema) | |
return '{}\n{}\n{}\n{}\n'.format(s0, s1, s2, s3) | |
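# For a hypothetical owner 'app_owner' and schema 'billing' (with set_role=True),
# the header above expands to:
#   BEGIN;
#   SET LOCAL ROLE app_owner;
#   SET LOCAL check_function_bodies = false;
#   SET SEARCH_PATH TO billing, pg_catalog, sys, dbo;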
@staticmethod | |
def _footer(): | |
return 'COMMIT;' | |
def _drop_trigger(self, schema, table_name, trigger_name): | |
self.logger.debug('table_name: {}, trigger_name: {}'.format(table_name, trigger_name)) | |
s1 = 'DROP TRIGGER IF EXISTS {} ON {}.{};'.format(trigger_name.lower(), schema, table_name) | |
s2 = 'DROP TRIGGER IF EXISTS "{}" ON {}.{};'.format(trigger_name.upper(), schema, table_name) | |
return '{}\n{}\n'.format(s1, s2) | |
@staticmethod | |
def _drop_foreign_table(object_name): | |
s1 = 'DROP FOREIGN TABLE IF EXISTS {};'.format(object_name) | |
return s1 | |
@staticmethod | |
def _inject_create_or_replace(schema_type, sql): | |
if schema_type != 'type': | |
sql = re.sub(r'CREATE\s{}'.format(schema_type.upper()), | |
'CREATE OR REPLACE {}'.format(schema_type.upper()), | |
sql) | |
return sql | |
@staticmethod | |
def _fixup_quoted_create(schema, schema_type, object_name, sql): | |
sql = re.sub(r'CREATE(.*{}.*)(\"{}\")'.format(schema_type.upper(), object_name), | |
'CREATE\\1{}'.format(object_name.lower()), | |
sql) | |
return sql | |
@staticmethod | |
def _inject_postway_separator(schema_type, object_name, sql): | |
indices = [s.start() for s in re.finditer('CREATE OR REPLACE', sql)] | |
if len(indices) == 2: | |
sql = '{}/\n{}'.format(sql[:indices[1]], sql[indices[1]:]) | |
if schema_type == 'package': | |
indices = [s.start() for s in re.finditer('ALTER PACKAGE', sql)]
if len(indices) == 1: | |
sql = '{}/\n{}'.format(sql[:indices[0]], sql[indices[0]:]) | |
return sql | |
@staticmethod | |
def _get_acl_parent_type(sql): | |
parent_types = ['aggregate', 'foreign_table', 'function', 'materialized_view', | |
'package', 'procedure', 'schema', 'server', 'table', 'trigger', 'type', 'view'] | |
acl_sql = ''.join([line for line in sql.splitlines(keepends=True) if not line.startswith('-- Name')]) | |
return next((pt for pt in parent_types if pt.upper() in acl_sql), None) | |
def _write_file(self, outdir, schema, schema_type, fname, sql): | |
self.logger.info('Schema: {}; Type: {}; Name: {}'.format(schema, schema_type, fname)) | |
sqlpath = os.path.join(outdir, schema, schema_type) | |
if not os.path.exists(sqlpath): | |
print('*** mkdir {}'.format(sqlpath)) | |
os.makedirs(sqlpath) | |
sqlf = os.path.join(sqlpath, fname) | |
self.logger.debug('sqlf: {}'.format(sqlf)) | |
sql = re.sub(r'\n{3,}', r'\n\n', sql.strip())
with open(sqlf, 'w') as out:
out.write(sql)
def _get_file_name(self, postway, version, versioned, name): | |
fname = '{}.sql'.format(name) | |
if versioned: | |
version += 1 | |
fname = 'V{}__{}.sql'.format(version, name) | |
elif postway: | |
fname = 'R__{}.sql'.format(name) | |
return fname, version | |
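# Naming convention produced above (names hypothetical): versioned scripts come out
# as e.g. V7__orders_table.sql (the running version counter is incremented first),
# repeatable scripts as R__get_orders_function.sql, and plain splits as name.sql.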
def parse_arguments(self): | |
""" | |
Parse command line arguments. | |
Sets self.args parameter for use throughout class/script. | |
""" | |
self.parser.add_argument('-i', '--infile', required=True, | |
help='The dump file created using pg_dump in text mode.') | |
self.parser.add_argument('-o', '--outdir', | |
help="The directory to use as the parent for the type'd directories. Created if it " | |
"doesn't exist. Defaults to the current directory.") | |
self.parser.add_argument('--include-parent-name', type=ParentName.from_string, choices=list(ParentName), | |
default=ParentName.none, | |
help='When naming files for objects that have parents (triggers, fk, pk, etc), use ' | |
'the parent name in the file name.') | |
self.parser.add_argument('--postway', action='store_true', | |
help='Create postway-compliant migration scripts.') | |
self.parser.add_argument('--postway-versioned-only', action='store_true', | |
help='Create postway-compliant VERSIONED migration scripts ONLY. Normally, postsplit ' | |
'creates both versioned and repeatable scripts, with ' | |
'functions/procedures/packages etc. being repeatable and ' | |
'tables/indexes/constraints etc. being versioned. However, EDB EPAS '
'"check_function_bodies=false" DOES NOT WORK for packages. This means that if ' | |
'there are packages that have dependencies, they will fail to build unless they ' | |
'just happen to sort correctly.') | |
self.parser.add_argument('--postway-version', default=1, type=int,
help='The starting version number for versioned migration scripts.')
self.parser.add_argument('--plot', action='store_true', | |
help='Show plot.') | |
self.parser.add_argument('-V', '--version', action='version', version='%(prog)s 1.0.0', | |
help='Print the version number of postsplit.') | |
self.parser.add_argument('-v', '--verbose', action='count', help='verbose level... repeat up to three times.') | |
self.args = self.parser.parse_args() | |
return self | |
def set_log_level_from_verbose(self): | |
if not self.args.verbose: | |
self.console_handler.setLevel('ERROR') | |
elif self.args.verbose == 1: | |
self.console_handler.setLevel('WARNING') | |
elif self.args.verbose == 2: | |
self.console_handler.setLevel('INFO') | |
elif self.args.verbose >= 3: | |
self.console_handler.setLevel('DEBUG') | |
else: | |
self.logger.critical('UNEXPLAINED NEGATIVE COUNT!') | |
return self | |
def build_graph(self): | |
outdir = self.args.outdir | |
if outdir == '': | |
outdir = os.path.dirname(self.args.infile) | |
graph = nx.DiGraph() | |
type_regex = re.compile( | |
r'-- Name: ([-\w\s\.\$\^]+)(?:\([-\w\s\[\],.\*\"]*\))?; Type: ([-\w\s]+); Schema: ([-\w]+); Owner: ([-\w]*)(?:; Tablespace: )?([-\w]*)\n', | |
flags=re.IGNORECASE) | |
toc_regex = re.compile(r'-- TOC entry (\d*) \(class (\d+) OID (\d+)\)\n', flags=re.IGNORECASE) | |
dep_regex = re.compile(r'-- Dependencies: (.*)') | |
name_regex = re.compile(r'(\w+)\s+(\w+)') | |
user_mapping_name_regex = re.compile(r'USER\sMAPPING\s([\w\s]+)') | |
name, schema_type, schema, owner, tablespace = [''] * 5 | |
toc_id = 0 | |
sql = '' | |
type_line = '' | |
# build the graph based on TOC | |
for sec in self._get_sections(): | |
# remove all lines with just `--` | |
section = [x for x in list(sec) if x != '--\n'] | |
toc_id = toc_regex.search(section[0]).group(1) | |
graph.add_node(toc_id, schema='', schema_type='', name='', sql='') | |
if section[1].startswith('-- Dependencies'): | |
dep_ids = dep_regex.search(section[1]).group(1).split() | |
# edges defined by TOC dependencies | |
for dep_id in set(dep_ids): | |
# dep_id is a dependency of toc_id | |
self.logger.debug('TOC edge: {}({}) <- {}({})'.format('unknown', | |
dep_id, | |
'unknown', | |
toc_id)) | |
graph.add_edge(toc_id, dep_id) | |
type_line = section[2] | |
else: | |
type_line = section[1] | |
name, schema_type, schema, owner, tablespace = type_regex.search(type_line).groups() | |
# ignore the schema_version table if building for postway | |
if self.args.postway and 'schema_version' in name: | |
continue | |
schema_type = schema_type.replace(' ', '_').lower() | |
if schema_type == 'user_mapping': | |
parent_name = '' | |
name = user_mapping_name_regex.search(name).group(1).replace(' ', '_').lower() | |
object_name = name | |
else: | |
name_match = name_regex.match(name) | |
if name_match: | |
parent_name = name_match.group(1)
object_name = name_match.group(2)
if self.args.include_parent_name == ParentName.none:
name = name_match.group(2)
else:
if self.args.include_parent_name == ParentName.before:
name = '{}_{}'.format(name_match.group(1), name_match.group(2))
else:
name = '{}_{}'.format(name_match.group(2), name_match.group(1))
else: | |
parent_name = '' | |
object_name = name | |
name = name.lower() | |
if self.args.postway: | |
name = '{}_{}'.format(name, schema_type) | |
if schema == '-': | |
schema = 'public' | |
# section = anything that doesn't start with: | |
# 1. -- TOC | |
# 2. -- Dependencies | |
# 3. SET search_path | |
sql = [y for y in section if not (y.startswith('-- TOC') or | |
y.startswith('-- Dependencies') or | |
y.startswith('SET search_path'))] | |
self.logger.debug( | |
'Owner: {}; Schema: {}; Type: {}; Name: {}; Parent Name: {}; Object Name: {}'.format(owner, schema, | |
schema_type, name, | |
parent_name, | |
object_name)) | |
graph.nodes[toc_id]['schema'] = schema | |
graph.nodes[toc_id]['schema_type'] = schema_type | |
graph.nodes[toc_id]['name'] = name | |
graph.nodes[toc_id]['parent_name'] = parent_name | |
graph.nodes[toc_id]['object_name'] = object_name | |
graph.nodes[toc_id]['owner'] = owner | |
graph.nodes[toc_id]['sql'] = ''.join(sql) | |
# prune the graph - remove nodes whose name attribute is missing, None or empty
prune_list = [pr for pr, d in graph.nodes(data=True) if 'name' not in d or d['name'] is None or d['name'] == ''] | |
graph.remove_nodes_from(prune_list) | |
# find edges where type is a dependency of type/function/procedure/package | |
graph = self._add_edges_by_sql(['type'], ['type', 'function', 'procedure', 'package'], graph) | |
# find edges where table/view is a dependency of view | |
graph = self._add_edges_by_sql(['table', 'view'], ['view'], graph) | |
# find edges where table is a dependency of index | |
graph = self._add_edges_by_sql(['table'], ['index'], graph) | |
# find edges where table is a dependency of constraints/fk_constraints | |
graph = self._add_edges_by_parent_name(graph) | |
# merge acl into parent script | |
graph = self._merge_acl(graph) | |
# add public schema | |
# graph.add_node(0, schema='public', schema_type='schema', name='public_schema', | |
# sql=self.create_public_schema_sql) | |
return graph | |
def prepare_and_write(self, toc_id, postway, version, versioned_only, outdir, node): | |
self.logger.debug('{}: {}'.format(toc_id, node)) | |
owner, schema, schema_type, name, parent_name, object_name, sql = map(node.get, | |
('owner', 'schema', 'schema_type', 'name', | |
'parent_name', 'object_name', 'sql')) | |
if name is not None and name > '': | |
self.logger.info( | |
'Owner: {}; Schema: {}; Type: {}; Name: {}; Parent Name: {}; Object Name: {}'.format(owner, schema, | |
schema_type, name, | |
parent_name, | |
object_name)) | |
if postway and schema_type != 'schema': | |
# sql = self._inject_postway_separator(schema_type, object_name, sql) | |
set_role = True | |
sql = self._fixup_quoted_create(schema, schema_type, object_name, sql) | |
if schema_type in ['function', 'procedure', 'package', 'view', 'trigger', 'foreign_table']: | |
fname, version = self._get_file_name(postway, version, versioned_only, name) | |
if schema_type == 'trigger': | |
sql = '{}\n{}'.format(self._drop_trigger(schema, parent_name, object_name), sql) | |
elif schema_type == 'foreign_table': | |
sql = '{}\n{}'.format(self._drop_foreign_table(object_name), sql) | |
else: | |
sql = self._inject_create_or_replace(schema_type, sql) | |
else: | |
fname, version = self._get_file_name(postway, version, True, name) | |
sql = '{}\n{}\n{}'.format(self._header(owner, schema, set_role), sql, self._footer()) | |
else: | |
fname, version = self._get_file_name(False, version, False, name) | |
self._write_file(outdir, schema, schema_type, fname, sql) | |
return version | |
if __name__ == '__main__': | |
postsplit = PostSplit().parse_arguments().set_log_level_from_verbose() | |
outdir = postsplit.args.outdir | |
if outdir is None or outdir == '': | |
outdir = os.path.dirname(postsplit.args.infile) | |
print('Start parsing {}.'.format(postsplit.args.infile)) | |
postsplit.logger.debug('args: {}'.format(postsplit.args)) | |
g = postsplit.build_graph() | |
# if postsplit.args.plot: | |
# print('Plot the graph.') | |
# nx.draw(g, with_labels=True, font_weight='bold') | |
# plt.show() | |
print('Sort the DAG.') | |
try: | |
s = list(reversed(list(nx.topological_sort(g)))) | |
postsplit.logger.debug(s) | |
print('Write the files.') | |
p = postsplit.args.postway | |
p_v = postsplit.args.postway_version # postway version - used in version migration scripts | |
p_v_o = postsplit.args.postway_versioned_only | |
for n in s: | |
p_v = postsplit.prepare_and_write(n, p, p_v, p_v_o, outdir, g.nodes[n]) | |
print('Done processing {}.'.format(postsplit.args.infile)) | |
except Exception as err:
postsplit.logger.error(err)
# find_cycle() itself raises NetworkXNoCycle when the sort failed for another reason
try:
postsplit.logger.debug(nx.find_cycle(g))
except nx.NetworkXNoCycle:
pass
exit(1)
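The splitter's write order comes from a reversed topological sort of the dependency DiGraph built above; a minimal standalone sketch of that idea, with toy node IDs and names that are not taken from any real dump:

import networkx as nx

g = nx.DiGraph()
g.add_node(1, name='orders_table')
g.add_node(2, name='orders_view')          # defined on orders_table
g.add_node(3, name='orders_summary_view')  # defined on orders_view
# edges point from dependent object to its dependency, as in build_graph()
g.add_edge(2, 1)
g.add_edge(3, 2)
order = list(reversed(list(nx.topological_sort(g))))
print([g.nodes[n]['name'] for n in order])
# ['orders_table', 'orders_view', 'orders_summary_view'] - dependencies first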
#!/usr/bin/env python | |
import argparse | |
import collections.abc | |
import logging | |
import ntpath | |
import os | |
import re | |
import subprocess | |
import zlib | |
from collections import namedtuple | |
from enum import Enum | |
from glob import glob | |
from timeit import default_timer as timer | |
import psycopg2 | |
from packaging import version | |
from psycopg2._psycopg import AsIs | |
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT | |
from psycopg2.extensions import make_dsn | |
from psycopg2.extras import LoggingConnection | |
class OrderedSet(collections.abc.MutableSet): | |
def __init__(self, iterable=None): | |
self.end = end = [] | |
end += [None, end, end] # sentinel node for doubly linked list | |
self.map = {} # key --> [key, prev, next] | |
if iterable is not None: | |
self |= iterable | |
def __len__(self): | |
return len(self.map) | |
def __contains__(self, key): | |
return key in self.map | |
def add(self, key): | |
if key not in self.map: | |
end = self.end | |
curr = end[1] | |
curr[2] = end[1] = self.map[key] = [key, curr, end] | |
def discard(self, key): | |
if key in self.map: | |
key, prev, next = self.map.pop(key) | |
prev[2] = next | |
next[1] = prev | |
def __iter__(self): | |
end = self.end | |
curr = end[2] | |
while curr is not end: | |
yield curr[0] | |
curr = curr[2] | |
def __reversed__(self): | |
end = self.end | |
curr = end[1] | |
while curr is not end: | |
yield curr[0] | |
curr = curr[1] | |
def pop(self, last=True): | |
if not self: | |
raise KeyError('set is empty') | |
key = self.end[1][0] if last else self.end[2][0] | |
self.discard(key) | |
return key | |
def __repr__(self): | |
if not self: | |
return '%s()' % (self.__class__.__name__,) | |
return '%s(%r)' % (self.__class__.__name__, list(self)) | |
def __eq__(self, other): | |
if isinstance(other, OrderedSet): | |
return len(self) == len(other) and list(self) == list(other) | |
return set(self) == set(other) | |
class PostWayCommand(Enum): | |
baseline = 1 | |
version = 2 | |
clean = 3 | |
migrate = 4 | |
validate = 5 | |
info = 6 | |
def __str__(self): | |
return self.name | |
@staticmethod | |
def from_string(s): | |
try: | |
return PostWayCommand[s] | |
except KeyError: | |
raise ValueError() | |
class PostWay: | |
""" | |
PostWay is a wholly new derivative of Flyway from BoxFuse. It's a derivative in that it derives its workflow and
the schema_version table directly from Flyway. It's wholly new in that the code and focus are specific to | |
EDB/Postgres, and the method for executing SQL scripts uses native EDB/Postgres tools - specifically psql. | |
Management routines are managed via psycopg2. Migration scripts are run through psql. | |
""" | |
SchemaRecord = namedtuple('SchemaRecord', | |
'name, script, baseline_version, max_current_version, max_schema_version, ' | |
'repeatable_schema_migrations, versioned_schema_migrations, ' | |
'diff_repeatable_migrations, diff_versioned_migrations, changed_repeatable_migrations') | |
MigrationRecord = namedtuple('MigrationRecord', 'script, version, description, checksum') | |
ChangedMigrationRecord = namedtuple('ChangedMigrationRecord', | |
'script, version, description, old_checksum, new_checksum') | |
ExecutingMigrationRecord = namedtuple('ExecutingMigrationRecord', 'schema, script, version, description, checksum') | |
def __init__(self): | |
self.parser = argparse.ArgumentParser( | |
add_help=False, | |
description=''' | |
PostWay is a wholly new derivative of Flyway from BoxFuse. It's a derivative in that it derives its workflow and
the schema_version table directly from Flyway. It's wholly new in that the code and focus are specific to | |
EDB/Postgres, and the method for executing SQL scripts uses native EDB/Postgres tools - specifically psql. | |
Management routines are managed via psycopg2. Migration scripts are run through psql. | |
''') | |
self.logger = logging.getLogger('postway') | |
self.logger.setLevel(logging.DEBUG) | |
self.console_handler = logging.StreamHandler() # sys.stderr | |
self.console_handler.setLevel( | |
logging.CRITICAL) # set later by set_log_level_from_verbose() in interactive sessions | |
self.console_handler.setFormatter(logging.Formatter('[%(levelname)s] %(message)s')) | |
self.logger.addHandler(self.console_handler) | |
self.args = None | |
self.conn = None | |
self.is_db = False | |
self.schema_records = None | |
self.is_max_schema_version = False | |
self.is_baselined = False | |
self.is_validated = False | |
self.versioned_prefix = 'V' | |
self.repeatable_prefix = 'R' | |
self.migration_separator = '__' | |
self.migration_suffix = '.sql' | |
self.baseline_version = 1 | |
self.baseline_description = '<< Flyway Baseline >>' | |
self.max_schema_version_sql = ''' | |
SELECT version | |
FROM %(nspname)s.schema_version | |
WHERE installed_rank = (SELECT MAX(installed_rank) FROM %(nspname)s.schema_version WHERE version IS NOT NULL); | |
''' | |
self.get_repeatable_migrations_sql = ''' | |
SELECT script, version, description, checksum | |
FROM %(nspname)s.schema_version | |
WHERE version IS NULL | |
ORDER BY installed_rank ASC | |
''' | |
self.get_versioned_migrations_sql = ''' | |
SELECT script, version, description, checksum | |
FROM %(nspname)s.schema_version | |
WHERE version IS NOT NULL AND script != %(description)s | |
ORDER BY installed_rank ASC | |
''' | |
self.schema_exists_sql = ''' | |
SELECT EXISTS( | |
SELECT 1 FROM pg_class c | |
JOIN pg_namespace n ON c.relnamespace = n.oid | |
WHERE n.nspname = %(schema)s | |
AND c.relkind = 'r' | |
AND c.oid NOT IN (SELECT inhrelid FROM pg_inherits) | |
); | |
''' | |
self.schema_version_exists_sql = ''' | |
SELECT EXISTS( | |
SELECT 1 FROM pg_class c | |
JOIN pg_namespace n ON c.relnamespace = n.oid | |
WHERE n.nspname = %(schema)s | |
AND c.relkind = 'r' | |
AND c.oid NOT IN (SELECT inhrelid FROM pg_inherits) | |
AND c.relname = 'schema_version' | |
); | |
''' | |
self.create_schema_version_sql = ''' | |
-- | |
-- Copyright 2010-2017 Boxfuse GmbH | |
-- | |
-- Licensed under the Apache License, Version 2.0 (the 'License'); | |
-- you may not use this file except in compliance with the License. | |
-- You may obtain a copy of the License at | |
-- | |
-- http://www.apache.org/licenses/LICENSE-2.0 | |
-- | |
-- Unless required by applicable law or agreed to in writing, software | |
-- distributed under the License is distributed on an 'AS IS' BASIS, | |
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
-- See the License for the specific language governing permissions and | |
-- limitations under the License. | |
-- | |
CREATE TABLE %(nspname)s.schema_version ( | |
installed_rank SERIAL PRIMARY KEY, | |
version VARCHAR(50), | |
description VARCHAR(200) NOT NULL, | |
type VARCHAR(20) NOT NULL, | |
script VARCHAR(1000) NOT NULL, | |
checksum NUMBER, | |
installed_by VARCHAR(100) NOT NULL, | |
installed_on TIMESTAMP NOT NULL DEFAULT now(), | |
execution_time INTEGER NOT NULL, | |
success BOOLEAN NOT NULL | |
) WITH ( | |
OIDS=FALSE | |
); | |
CREATE INDEX schema_version_s_idx ON %(nspname)s.schema_version (version, success); | |
''' | |
self.is_baselined_sql = 'SELECT EXISTS(SELECT 1 FROM %(nspname)s.schema_version WHERE script = %(script)s);' | |
self.baselined_version_sql = 'SELECT version FROM %(nspname)s.schema_version WHERE script = %(script)s;' | |
self.insert_schema_version_sql = ''' | |
INSERT INTO %(nspname)s.schema_version (version, description, "type", script, checksum, installed_by, execution_time, success) | |
VALUES (%(version)s, %(description)s, %(type)s, %(script)s, %(checksum)s, %(installed_by)s, %(execution_time)s, %(success)s); | |
''' | |
self.drop_schema_sql = 'DROP SCHEMA IF EXISTS %(nspname)s CASCADE;' | |
self.db_no_more_connections_sql = "UPDATE pg_database SET datallowconn = 'false' WHERE datname = %(dbname)s;" | |
self.db_disconnect_sql = ''' | |
SELECT pg_terminate_backend(pid) | |
FROM pg_stat_activity | |
WHERE datname = %(dbname)s; | |
''' | |
self.db_drop_sql = 'DROP DATABASE IF EXISTS %(dbname)s' | |
self.db_create_sql = ''' | |
CREATE DATABASE %(dbname)s WITH OWNER=%(owner)s LC_COLLATE='C' LC_CTYPE='C' TEMPLATE='template0'; | |
''' | |
@property | |
def _dsn(self): | |
return make_dsn(host=self.args.host, | |
port=self.args.port, | |
dbname=self.args.dbname, | |
user=self.args.user, | |
password=self.args.password) | |
def _get_search_path(self): | |
with self.conn: | |
with self.conn.cursor() as curs: | |
curs.execute('SHOW search_path;') | |
search_path = curs.fetchone() | |
return search_path | |
def _set_search_path(self, schema): | |
search_path = self._get_search_path() | |
self.logger.debug('old search_path: {}'.format(search_path[0])) | |
with self.conn: | |
with self.conn.cursor() as curs: | |
curs.execute('SET search_path=%(schema)s,%(search_path)s', | |
{'schema': schema, 'search_path': AsIs(', '.join(search_path))}) | |
self.logger.debug('new search_path: {}'.format(self._get_search_path()[0])) | |
def _is_schema(self, schema): | |
with self.conn: | |
with self.conn.cursor() as curs: | |
curs.execute(self.schema_exists_sql, {'schema': schema}) | |
res = curs.fetchone() | |
schema_exists = res[0]
self.logger.warning('schema exists: {}'.format(schema_exists))
return schema_exists
def _is_schema_version(self, schema): | |
with self.conn: | |
with self.conn.cursor() as curs: | |
curs.execute(self.schema_version_exists_sql, {'schema': schema}) | |
res = curs.fetchone() | |
is_schema_version = res[0] | |
self.logger.warning('schema_version exists: {}'.format(is_schema_version)) | |
return is_schema_version | |
def _schema_baseline(self, schema, script): | |
with self.conn: | |
with self.conn.cursor() as curs: | |
curs.execute(self.baselined_version_sql, {'nspname': AsIs(schema), 'script': script}) | |
res = curs.fetchone() | |
if res is not None: | |
baselined_version = res[0] | |
self.logger.warning('Schema is baselined') | |
return baselined_version | |
else: | |
self.logger.warning('Schema is not baselined') | |
return 0 | |
def _get_repeatable_schema_migrations(self, schema): | |
with self.conn: | |
with self.conn.cursor() as curs: | |
curs.execute(self.get_repeatable_migrations_sql, {'nspname': AsIs(schema)}) | |
res = curs.fetchall() | |
ret = list(map(self.MigrationRecord._make, res)) | |
return ret | |
def _get_versioned_schema_migrations(self, schema, baseline_description): | |
with self.conn: | |
with self.conn.cursor() as curs: | |
curs.execute(self.get_versioned_migrations_sql, | |
{'nspname': AsIs(schema), 'description': baseline_description}) | |
res = curs.fetchall() | |
# noinspection PyTypeChecker | |
ret = list(map(self.MigrationRecord._make, [(r[0], version.parse(r[1]), r[2], r[3]) for r in res])) | |
return ret | |
def _get_current_migrations(self, migration_base_directory, schema, repeatable_prefix, versioned_prefix, | |
migration_separator, migration_suffix): | |
max_current_version = 0
repeatable_migrations = self._get_repeatable_migrations(migration_base_directory, schema, repeatable_prefix,
migration_separator, migration_suffix)
versioned_migrations = self._get_versioned_migrations(migration_base_directory, schema, versioned_prefix,
migration_separator, migration_suffix)
if versioned_migrations: | |
max_current_version = max(versioned_migrations, key=lambda mi: mi.version).version | |
return max_current_version, repeatable_migrations, versioned_migrations | |
def _get_schema_migrations(self, schema, baseline_description, baseline_version, user): | |
repeatable_schema_migrations = self._get_repeatable_schema_migrations(schema) | |
versioned_schema_migrations = self._get_versioned_schema_migrations(schema, baseline_description) | |
return repeatable_schema_migrations, versioned_schema_migrations | |
def _get_repeatable_migrations(self, migration_base_directory, schema, prefix, separator, suffix): | |
files = self._get_files(os.path.join(migration_base_directory, schema), | |
'{}*{}*{}'.format(prefix, separator, suffix)) | |
migration_files = [self._repeatable_file_parts(schema, f, separator, prefix, suffix) for f in files] | |
return sorted(migration_files, key=lambda fi: self._path_leaf(fi.script)) | |
def _get_versioned_migrations(self, migration_base_directory, schema, prefix, separator, suffix): | |
files = self._get_files(os.path.join(migration_base_directory, schema), | |
'{}*{}*{}'.format(prefix, separator, suffix)) | |
migration_files = [self._versioned_file_parts(f, separator, prefix, suffix) for f in files] | |
return sorted(migration_files, key=lambda fi: fi.version) | |
def _repeatable_file_parts(self, schema, file_path, separator, prefix, suffix): | |
description, file_parts = self._migration_file_parts(file_path, separator, suffix) | |
checksum = self._crc(file_path) | |
return self.MigrationRecord(file_path, None, description, checksum) | |
def _versioned_file_parts(self, file_path, separator, prefix, suffix): | |
description, file_parts = self._migration_file_parts(file_path, separator, suffix) | |
ver = version.parse(file_parts[0].replace(prefix, '').replace('_', '.')) | |
checksum = self._crc(file_path) | |
return self.MigrationRecord(file_path, ver, description, checksum) | |
def _migration_file_parts(self, file_path, separator, suffix): | |
file_name = self._path_leaf(file_path) | |
file_parts = file_name.split(separator) | |
description = file_parts[1].replace(suffix, '').replace('_', ' ') | |
return description, file_parts | |
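# Parsing illustration with a hypothetical file name 'V1_2__add_orders_table.sql':
# splitting on '__' gives ['V1_2', 'add_orders_table.sql']; _versioned_file_parts
# strips the 'V' prefix and maps '1_2' to version 1.2, while the description here
# becomes 'add orders table'.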
def _create_schema_version(self, schema): | |
if self._is_schema_version(schema): | |
return True | |
else: | |
with self.conn: | |
with self.conn.cursor() as curs: | |
self.logger.warning('Creating schema_version in {}'.format(schema)) | |
curs.execute(self.create_schema_version_sql, {'nspname': AsIs(schema)}) | |
return True | |
def _execute_migration(self, migration, user, bindir, host, port, dbname, password): | |
penv = os.environ.copy() | |
penv['PGPASSWORD'] = password | |
psql = ['{}/psql'.format(bindir), '-h', host, '-p', '{}'.format(port), '-d', dbname, '-U', user] | |
with self.conn: | |
with self.conn.cursor() as curs: | |
start, end, success = self._execute_psql(penv, psql, migration.script) | |
if not success: | |
exit(1) | |
execution_time = end - start | |
curs.execute(self.insert_schema_version_sql, | |
{'nspname': AsIs(migration.schema), | |
'version': None if migration.version is None else '{}'.format(migration.version), | |
'description': migration.description, | |
'type': 'SQL', | |
'script': migration.script, | |
'checksum': migration.checksum, | |
'installed_by': user, | |
'execution_time': int(round(execution_time * 1000)), | |
'success': True}) | |
def _execute_psql(self, penv, psql, script): | |
success = True | |
error_regex = re.compile(r'psql\.bin:(.*):([\d]+):\sERROR:\s+(.*)') | |
warning_regex = re.compile(r'psql\.bin:(.*):([\d]+):\sNOTICE:\s+(.*)') | |
if self.args.verbose is not None and self.args.verbose >= 3: | |
psql.extend(['-e']) | |
psql.extend(['-f', script]) | |
self.logger.warning(subprocess.list2cmdline(psql)) | |
start = timer() | |
end, output = self._execute_process(psql, penv) | |
error_matches = error_regex.search(output) | |
if error_matches: | |
success = False | |
self.logger.error( | |
'Error found at {start}-{end}: {match}'.format(start=error_matches.start(), | |
end=error_matches.end(), | |
match=error_matches.group())) | |
warning_matches = warning_regex.search(output) | |
if warning_matches: | |
self.logger.warning( | |
'Warning found at {start}-{end}: {match}'.format(start=warning_matches.start(), | |
end=warning_matches.end(), | |
match=warning_matches.group())) | |
return start, end, success | |
def _execute_clean_db(self, user, host, port, dbname, password): | |
self._execute_dropdb(user, host, port, dbname, password) | |
self._execute_createdb(user, host, port, dbname, password) | |
def _execute_dropdb(self, user, host, port, dbname, password): | |
start = timer() | |
db_dsn = make_dsn(host=host, | |
port=port, | |
dbname='postgres', | |
user=user, | |
password=password) | |
with psycopg2.connect(db_dsn, connection_factory=LoggingConnection) as conn: | |
conn.initialize(self.logger) | |
with conn.cursor() as curs: | |
self.logger.warning('Setting database {} to accept no more connections.'.format(dbname)) | |
curs.execute(self.db_no_more_connections_sql, {'dbname': dbname}) | |
with conn.cursor() as curs: | |
self.logger.warning('Disconnecting database {}.'.format(dbname)) | |
curs.execute(self.db_disconnect_sql, {'dbname': dbname}) | |
self.logger.warning('Dropping database {}.'.format(dbname)) | |
connect = psycopg2.connect(db_dsn, connection_factory=LoggingConnection) | |
connect.initialize(self.logger) | |
connect.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) | |
cursor = connect.cursor() | |
cursor.execute(self.db_drop_sql, {'dbname': AsIs(dbname)}) | |
end = timer() | |
return start, end | |
def _execute_createdb(self, user, host, port, dbname, password): | |
start = timer() | |
db_dsn = make_dsn(host=host, | |
port=port, | |
dbname='postgres', | |
user=user, | |
password=password) | |
self.logger.warning('Creating database {}.'.format(dbname)) | |
connect = psycopg2.connect(db_dsn, connection_factory=LoggingConnection) | |
connect.initialize(self.logger) | |
connect.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) | |
cursor = connect.cursor() | |
cursor.execute(self.db_create_sql, {'dbname': AsIs(dbname), 'owner': user}) | |
end = timer() | |
return start, end | |
def _execute_process(self, process, penv): | |
try: | |
stdout = subprocess.check_output(process, stderr=subprocess.STDOUT, env=penv) | |
end = timer() | |
output = stdout.decode('utf-8') | |
self.logger.debug(output) | |
except subprocess.CalledProcessError as db_exc: | |
output = db_exc.output.decode('utf-8') | |
self.logger.error(output) | |
raise | |
return end, output | |
# noinspection PyTypeChecker | |
def _find_changed_migrations(self, expected, difference): | |
changed = [] | |
for exp in expected: | |
for dif in difference: | |
if exp.script == dif.script: | |
changed.append(self.ChangedMigrationRecord._make(list(exp) + [dif.checksum])) | |
return OrderedSet(changed) | |
@staticmethod | |
def _crc(file_path): | |
crc = 0 | |
with open(file_path, "rb") as f: | |
data = f.read() | |
crc = zlib.crc32(data) & 0xffffffff | |
return crc | |
@staticmethod | |
def _path_leaf(path): | |
head, tail = ntpath.split(path) | |
return tail or ntpath.basename(head) | |
@staticmethod | |
def _get_files(path, glob_pattern): | |
files = [y for x in os.walk(path) for y in | |
glob(os.path.join(x[0], glob_pattern))] | |
return files | |
def set_log_level_from_verbose(self): | |
if not self.args.verbose: | |
self.console_handler.setLevel('ERROR') | |
elif self.args.verbose == 1: | |
self.console_handler.setLevel('WARNING') | |
elif self.args.verbose == 2: | |
self.console_handler.setLevel('INFO') | |
elif self.args.verbose >= 3: | |
self.console_handler.setLevel('DEBUG') | |
else: | |
self.logger.critical('UNEXPLAINED NEGATIVE COUNT!') | |
return self | |
def parse_arguments(self): | |
""" | |
Parse command line arguments. | |
Sets self.args parameter for use throughout class/script. | |
""" | |
self.parser.add_argument('-m', '--migration-base-directory', | |
default=os.getcwd(), | |
help='The base directory in which to look for schema-named directories for migration ' | |
'scripts.') | |
self.parser.add_argument('command', | |
default=PostWayCommand.validate, | |
type=PostWayCommand.from_string, | |
choices=list(PostWayCommand), | |
help='The PostWay command to execute.') | |
self.parser.add_argument('-h', '--host', default='localhost', help='The EPAS host.') | |
self.parser.add_argument('--port', default=5444, help='The EPAS port.') | |
self.parser.add_argument('-U', '--user', required=True, help='The EPAS username.') | |
self.parser.add_argument('-W', '--password', required=True, help='The EPAS password.') | |
self.parser.add_argument('-d', '--dbname', required=True, help='The EPAS database.') | |
self.parser.add_argument('-s', '--schema', dest='schemas', action='append', | |
help='The EPAS schema(s). If none, postway assumes a full DB run, iterating ' | |
'over each *_schema.sql file in the "{MIGRATION_DIRECTORY}/public/schema" ' | |
'directory. Each use appends a schema to the list.') | |
self.parser.add_argument('-b', '--bindir', | |
required=True, | |
help='The path to the EPAS bin directory where psql is located. Required to actually ' | |
'run the migration scripts.') | |
self.parser.add_argument('-V', '--version', action='version', version='%(prog)s 1.0.0', | |
help='Print the version number of postway.') | |
self.parser.add_argument('-v', '--verbose', action='count', help='verbose level... repeat up to three times.') | |
self.parser.add_argument('--help', action='help', default=argparse.SUPPRESS, | |
help='show this help message and exit')
self.args = self.parser.parse_args() | |
return self | |
def get_schemas(self): | |
if not self.schema_records: | |
schemas = self.args.schemas | |
directory = self.args.migration_base_directory | |
suffix = self.migration_suffix | |
files = self._get_files(os.path.join(directory, 'public', 'schema'), '*_schema{}'.format(suffix)) | |
public = self.SchemaRecord('public', None, None, None, None, None, None, None, None, None) | |
if not schemas: | |
self.schema_records = [self.SchemaRecord._make( | |
[self._path_leaf(f).replace('_schema.sql', ''), f, None, None, None, None, None, None, None, None]) | |
for f in files] | |
self.schema_records.extend([public]) | |
self.is_db = True | |
else: | |
self.schema_records = [self.SchemaRecord._make( | |
[self._path_leaf(f).replace('_schema.sql', ''), f, None, None, None, None, None, None, None, None]) | |
for s in schemas for f in files if s in f] | |
if 'public' in schemas: | |
self.schema_records.extend([public]) | |
return self | |
def connect(self): | |
if not self.conn: | |
self.conn = psycopg2.connect(self._dsn, connection_factory=LoggingConnection) | |
self.conn.initialize(self.logger) | |
return self | |
def do_clean(self, user, bindir, host, port, dbname, password): | |
if self.is_db: | |
self._execute_clean_db(user, host, port, dbname, password) | |
self.connect() | |
for sr in self.schema_records: | |
if sr.name != 'public': | |
with self.conn: | |
with self.conn.cursor() as curs: | |
self.logger.warning('Dropping {} schema'.format(sr.name)) | |
curs.execute(self.drop_schema_sql, {'nspname': AsIs(sr.name)}) | |
self.logger.warning('Creating {} schema'.format(sr.name)) | |
penv = os.environ.copy() | |
penv['PGPASSWORD'] = password
psql = ['{}/psql'.format(bindir), '-h', host, '-p', '{}'.format(port), '-d', dbname, '-U', user] | |
start, end, success = self._execute_psql(penv, psql, sr.script) | |
print('{} cleaned'.format(sr.name)) | |
if not success: | |
exit(1) | |
def do_baseline(self, baseline_version, baseline_description, user): | |
self.connect() | |
if not self.is_baselined: | |
schema_records = [] | |
for sr in self.schema_records: | |
schema_baseline_version = -1 | |
if self._is_schema(sr.name) and self._create_schema_version(sr.name): | |
schema_baseline_version = self._schema_baseline(sr.name, baseline_description) | |
if schema_baseline_version == 0: | |
with self.conn: | |
with self.conn.cursor() as curs: | |
self.logger.warning( | |
'Inserting baseline version {} for {}'.format(baseline_version, sr.name))
schema_baseline_version = baseline_version | |
curs.execute(self.insert_schema_version_sql, | |
{'nspname': AsIs(sr.name), | |
'version': baseline_version, | |
'description': baseline_description, | |
'type': 'SQL', | |
'script': baseline_description, | |
'checksum': 0, | |
'installed_by': user, | |
'execution_time': 0, | |
'success': True}) | |
schema_records.append(self.SchemaRecord(sr.name, sr.script, baseline_version, sr.max_current_version,
sr.max_schema_version, | |
sr.repeatable_schema_migrations, sr.versioned_schema_migrations, | |
sr.diff_repeatable_migrations, sr.diff_versioned_migrations, | |
sr.changed_repeatable_migrations)) | |
print('Baseline schema_version for {}: {}'.format(sr.name, schema_baseline_version)) | |
if schema_records: | |
self.schema_records = schema_records | |
self.is_baselined = True | |
def get_max_schema_version(self, baseline_version, baseline_description, user): | |
self.connect() | |
if not self.is_max_schema_version: | |
schema_records = [] | |
for sr in self.schema_records: | |
max_schema_version = 0 | |
self.do_baseline(baseline_version, baseline_description, user) | |
with self.conn: | |
with self.conn.cursor() as curs: | |
curs.execute(self.max_schema_version_sql, {'nspname': AsIs(sr.name)}) | |
res = curs.fetchone() | |
max_schema_version = version.parse(res[0]) | |
schema_records.append(self.SchemaRecord(sr.name, sr.script, sr.baseline_version, sr.max_current_version, | |
max_schema_version, | |
sr.repeatable_schema_migrations, sr.versioned_schema_migrations, | |
sr.diff_repeatable_migrations, sr.diff_versioned_migrations, | |
sr.changed_repeatable_migrations)) | |
print('Max schema_version for {}: {}'.format(sr.name, max_schema_version)) | |
if schema_records: | |
self.schema_records = schema_records | |
self.is_max_schema_version = True | |
def migrate(self, baseline_version, baseline_description, user, bindir, host, port, dbname, password, | |
migration_base_directory, repeatable_prefix, versioned_prefix, migration_separator, | |
migration_suffix): | |
self.connect() | |
print('Beginning migrations.') | |
self.validate(baseline_version, baseline_description, user, migration_base_directory, repeatable_prefix, | |
versioned_prefix, migration_separator, migration_suffix) | |
diff_versioned_migrations = [] | |
diff_repeatable_migrations = [] | |
for sr in self.schema_records: | |
if sr.diff_versioned_migrations: | |
# noinspection PyTypeChecker | |
diff_versioned_migrations.extend( | |
map(self.ExecutingMigrationRecord._make, | |
[(sr.name, d_v.script, d_v.version, d_v.description, d_v.checksum) | |
for d_v in sr.diff_versioned_migrations])) | |
if sr.diff_repeatable_migrations: | |
# noinspection PyTypeChecker | |
diff_repeatable_migrations.extend( | |
map(self.ExecutingMigrationRecord._make, | |
[(sr.name, d_r.script, d_r.version, d_r.description, d_r.checksum) | |
for d_r in sr.diff_repeatable_migrations])) | |
versioned_migrations = sorted(diff_versioned_migrations, key=lambda v: v.version) | |
repeatable_migrations = sorted(diff_repeatable_migrations, key=lambda r: r.script) | |
versioned = 0 | |
repeatable = 0 | |
try: | |
for m in versioned_migrations: | |
self._execute_migration(m, user, bindir, host, port, dbname, password) | |
versioned += 1 | |
for m in repeatable_migrations: | |
self._execute_migration(m, user, bindir, host, port, dbname, password) | |
repeatable += 1 | |
finally: | |
print('Executed - Versioned: {}, Repeatable: {}'.format(versioned, repeatable)) | |
def validate(self, baseline_version, baseline_description, user, migration_base_directory, repeatable_prefix, | |
versioned_prefix, migration_separator, migration_suffix): | |
""" | |
Determine if the migration is valid by checking the following metadata about the current migration: | |
1. Max applied version: the max version in the schema_version table. | |
2. Max current version: the max version of the scripts. | |
3. The number of versioned scripts that need to be applied: | |
a. new versioned scripts. | |
b. throw an error if a versioned script has been changed (checksum is different). | |
4. The number of repeatable migration scripts that need to be applied: | |
a. new repeatable scripts. | |
b. old scripts that have changed (checksum is different). | |
5. The number of unchanged versioned scripts. | |
6. The number of unchanged repeatable scripts. | |
:param baseline_version: | |
:type baseline_version: | |
:param baseline_description: | |
:type baseline_description: | |
:param user: | |
:type user: | |
:param migration_base_directory: | |
:type migration_base_directory: | |
:param repeatable_prefix: | |
:type repeatable_prefix: | |
:param versioned_prefix: | |
:type versioned_prefix: | |
:param migration_separator: | |
:type migration_separator: | |
:param migration_suffix: | |
:type migration_suffix: | |
:return: | |
:rtype: | |
""" | |
self.connect() | |
print('Validating migrations.') | |
if not self.is_validated: | |
self.get_max_schema_version(baseline_version, baseline_description, user) | |
schema_records = [] | |
for sr in self.schema_records: | |
max_current_version, \ | |
repeatable_migrations, \ | |
versioned_migrations = self._get_current_migrations(migration_base_directory, sr.name, | |
repeatable_prefix, | |
versioned_prefix, migration_separator, | |
migration_suffix) | |
repeatable_schema_migrations, \ | |
versioned_schema_migrations = self._get_schema_migrations(sr.name, baseline_description, | |
baseline_version, user) | |
repeatable_expected = OrderedSet(repeatable_schema_migrations) | |
repeatable_found = OrderedSet(repeatable_migrations) | |
versioned_expected = OrderedSet(versioned_schema_migrations) | |
versioned_found = OrderedSet(versioned_migrations) | |
diff_repeatable_migrations = OrderedSet( | |
sorted(repeatable_found - repeatable_expected, key=lambda r: r.script)) | |
diff_versioned_migrations = OrderedSet( | |
sorted(versioned_found - versioned_expected, key=lambda v: v.version)) | |
changed_repeatable_migrations = self._find_changed_migrations(repeatable_expected, | |
diff_repeatable_migrations) | |
changed_versioned_migrations = self._find_changed_migrations(versioned_expected, | |
diff_versioned_migrations) | |
schema_records.append(self.SchemaRecord(sr.name, sr.script, sr.baseline_version, max_current_version, | |
sr.max_schema_version, repeatable_schema_migrations, | |
versioned_schema_migrations, diff_repeatable_migrations, | |
diff_versioned_migrations, changed_repeatable_migrations)) | |
if changed_versioned_migrations:
self.logger.error(
'Applied versioned migrations changed for {} - THIS SHOULD NEVER HAPPEN!!!'.format(sr.name))
self.logger.error(
'YOU MUST REVERT THE CHANGES TO THESE SCRIPTS, AND CREATE NEW VERSIONED MIGRATIONS.')
self.logger.error('{}: {}'.format(sr.name, changed_versioned_migrations))
exit(1) | |
print('{}: max applied version {}, max current version {}.'
.format(sr.name, sr.max_schema_version, max_current_version))
print('versioned scripts ready to run for {}: {}' | |
.format(sr.name, len(diff_versioned_migrations))) | |
print('repeatable scripts ready to run for {}: {} ({} changed)' | |
.format(sr.name, len(diff_repeatable_migrations), len(changed_repeatable_migrations))) | |
self.logger.warning('{} versioned: {}'.format(sr.name, diff_versioned_migrations)) | |
self.logger.warning('{} repeatable: {}'.format(sr.name, diff_repeatable_migrations)) | |
self.logger.warning('{} changed repeatable: {}'.format(sr.name, changed_repeatable_migrations)) | |
if schema_records: | |
self.schema_records = schema_records | |
self.is_validated = True | |
def info(self, baseline_version, baseline_description, user, migration_base_directory, repeatable_prefix, | |
versioned_prefix, migration_separator, migration_suffix): | |
self.connect() | |
self.validate(baseline_version, baseline_description, user, migration_base_directory, repeatable_prefix, | |
versioned_prefix, migration_separator, migration_suffix) | |
versioned_schema_migrations_count = 0 | |
diff_versioned_migrations_count = 0 | |
repeatable_schema_migrations_count = 0 | |
changed_repeatable_migrations_count = 0 | |
diff_repeatable_migrations_count = 0 | |
for sr in self.schema_records: | |
versioned_schema_migrations_count += len(sr.versioned_schema_migrations) | |
print('{} VERSIONED MIGRATIONS - APPLIED: {}'.format(sr.name, len(sr.versioned_schema_migrations))) | |
for v_s in sr.versioned_schema_migrations: | |
self.logger.info(v_s) | |
diff_versioned_migrations_count += len(sr.diff_versioned_migrations) | |
print('{} VERSIONED MIGRATIONS - PENDING: {}'.format(sr.name, len(sr.diff_versioned_migrations))) | |
for d_v in sr.diff_versioned_migrations: | |
self.logger.info(d_v) | |
repeatable_schema_migrations_count += len(sr.repeatable_schema_migrations) | |
print('{} REPEATABLE MIGRATIONS - APPLIED: {}'.format(sr.name, len(sr.repeatable_schema_migrations))) | |
for r_s in sr.repeatable_schema_migrations: | |
self.logger.info(r_s) | |
changed_repeatable_migrations_count += len(sr.changed_repeatable_migrations) | |
print( | |
'{} REPEATABLE MIGRATIONS - CHANGED/PENDING: {}'.format(sr.name, len(sr.changed_repeatable_migrations))) | |
for c_r in sr.changed_repeatable_migrations: | |
self.logger.info(c_r) | |
diff_repeatable_migrations_count += len(sr.diff_repeatable_migrations) | |
print('{} REPEATABLE MIGRATIONS - PENDING: {}'.format(sr.name, len(sr.diff_repeatable_migrations))) | |
for d_r in sr.diff_repeatable_migrations: | |
self.logger.info(d_r) | |
if self.is_db: | |
print('TOTAL VERSIONED MIGRATIONS - APPLIED: {}'.format(versioned_schema_migrations_count)) | |
print('TOTAL VERSIONED MIGRATIONS - PENDING: {}'.format(diff_versioned_migrations_count)) | |
print('TOTAL REPEATABLE MIGRATIONS - APPLIED: {}'.format(repeatable_schema_migrations_count)) | |
print('TOTAL REPEATABLE MIGRATIONS - CHANGED/PENDING: {}'.format(changed_repeatable_migrations_count)) | |
print('TOTAL REPEATABLE MIGRATIONS - PENDING: {}'.format(diff_repeatable_migrations_count)) | |
if __name__ == '__main__': | |
postway = PostWay().parse_arguments().set_log_level_from_verbose().get_schemas() | |
postway.logger.debug('args: {}'.format(postway.args)) | |
m_dir = postway.args.migration_base_directory | |
b_dir = postway.args.bindir | |
u = postway.args.user | |
pwd = postway.args.password | |
h = postway.args.host | |
p = postway.args.port | |
d = postway.args.dbname | |
v_prefix = postway.versioned_prefix | |
r_prefix = postway.repeatable_prefix | |
sep = postway.migration_separator | |
suf = postway.migration_suffix | |
bs_ver = postway.baseline_version | |
bs_desc = postway.baseline_description | |
c = postway.args.command | |
if c == PostWayCommand.clean: | |
postway.do_clean(u, b_dir, h, p, d, pwd) | |
elif c == PostWayCommand.baseline: | |
postway.do_baseline(bs_ver, bs_desc, u) | |
elif c == PostWayCommand.version: | |
postway.get_max_schema_version(bs_ver, bs_desc, u) | |
elif c == PostWayCommand.migrate: | |
postway.migrate(bs_ver, bs_desc, u, b_dir, h, p, d, pwd, m_dir, r_prefix, v_prefix, sep, suf) | |
elif c == PostWayCommand.validate: | |
postway.validate(bs_ver, bs_desc, u, m_dir, r_prefix, v_prefix, sep, suf) | |
elif c == PostWayCommand.info: | |
postway.info(bs_ver, bs_desc, u, m_dir, r_prefix, v_prefix, sep, suf) |
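Change detection in validate() above rests on the CRC32 checksum stored in schema_version; a minimal standalone sketch of that comparison (script names and contents are hypothetical, and the checksum is computed over in-memory bytes rather than a file as in _crc):

import zlib
from collections import namedtuple

Migration = namedtuple('Migration', 'script version description checksum')

def crc(data: bytes) -> int:
    # same masking as PostWay._crc, applied to in-memory bytes for the example
    return zlib.crc32(data) & 0xffffffff

applied = Migration('R__report_pkg.sql', None, 'report pkg', crc(b'CREATE OR REPLACE ...'))
on_disk = Migration('R__report_pkg.sql', None, 'report pkg', crc(b'CREATE OR REPLACE ... -- edited'))
# a repeatable script with a new checksum shows up as pending and as changed
print(on_disk != applied, on_disk.script == applied.script)  # True True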