@tombasche
Created September 6, 2019 02:22
Given a source and a destination database, restore any records that were written to the source while a pg_dump was in progress
"""
Usage:
python sync_db.py
"""
import os
import sys
import json
import psycopg2
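# Map of table name -> highest id already present in the destination;
# populated by set_max_ids_for_tables() and consumed by db_backup().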
TABLE_MAX_ID_MAP = {}


def db_load(conn, file):
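    """Load a JSON dump file produced by db_backup() into the destination.

    The target table name is recovered from the file name's prefix.
    """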
    with open(file) as f:
        loaded_data = json.load(f)
    table_name = file.split('-')[0]
    for _id, obj in loaded_data.items():
        columns = ",".join(obj.keys())
        values_list = ",".join(["%s" for _ in obj.values()])
        query = "INSERT INTO {table_name} ({columns}) VALUES ({values})".format(
            table_name=table_name,
            columns=columns,
            values=values_list
        )
        cur = conn.cursor()
        # a savepoint lets one failed insert be rolled back without aborting
        # the whole transaction and losing earlier inserts
        cur.execute("SAVEPOINT before_insert")
        try:
            cur.execute(query, list(obj.values()))
            cur.execute("RELEASE SAVEPOINT before_insert")
        except Exception as ex:
            # could be a conflict with a row the restore already copied across
            print("Error inserting {} into {}: {}".format(obj, table_name, ex))
            cur.execute("ROLLBACK TO SAVEPOINT before_insert")
    print("Inserted {} records into {}".format(len(loaded_data), table_name))


def set_max_ids_for_tables(source_conn, dest_conn, tables):
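    """Record, for each table, the highest id the destination already has.

    Tables missing from either database, or already up to date, are skipped.
    """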
    for table in tables:
        query = "SELECT max(id) FROM {}".format(table)
        cur = dest_conn.cursor()
        try:
            cur.execute(query)
        except Exception:
            # table doesn't exist in the destination
            cur.execute("ROLLBACK")
            continue
        dest_max_id = cur.fetchone()[0]
        if dest_max_id is None:
            continue
        cur = source_conn.cursor()
        try:
            cur.execute(query)
        except Exception:
            # tables don't match between dbs
            cur.execute("ROLLBACK")
            continue
        source_max_id = cur.fetchone()[0]
        # don't try updating if it's already up to date
        if dest_max_id == source_max_id:
            print("Skipping {} as it's already up to date at id: {}".format(table, source_max_id))
            continue
        TABLE_MAX_ID_MAP[table] = dest_max_id
    if not TABLE_MAX_ID_MAP:
        print("nothing to do: everything's up to date!")
        return
    print("Set max ids to {}".format(TABLE_MAX_ID_MAP))


def get_tables(source_conn):
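    """Return the names of all non-system tables in the source database."""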
query = "SELECT * FROM pg_catalog.pg_tables"
cur = source_conn.cursor()
cur.execute(query)
tables = cur.fetchall()
return [t[1] for t in tables if not t[1].startswith(('pg', 'sql'))]
def db_backup(conn, tables):
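    """Dump each table's rows above the destination's max id to a JSON file.

    Returns the list of file names written, one per table that had new rows.
    """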
    filenames = []
    for table_name in tables:
        if not TABLE_MAX_ID_MAP.get(table_name):
            continue
        table_json = "{}-from {}.json".format(table_name, TABLE_MAX_ID_MAP[table_name])
        base_query = "FROM {table} WHERE id > %s".format(table=table_name)
        vars_ = (TABLE_MAX_ID_MAP[table_name], )
        query = "SELECT * {0}".format(base_query)
        cur = conn.cursor()
        cur.execute(query, vars_)
        if cur.rowcount == 0:
            print("No results for {}".format(table_json))
            cur.execute("ROLLBACK")
            continue
        results = cur.fetchall()
        json_dict = {}
        columns = [desc[0] for desc in cur.description]
        for result in results:
            # key each record by its id, mapping column names to values
            json_dict[result[0]] = dict(zip(columns, result))
        with open(table_json, 'w+') as f:
            # default=str covers values json can't encode natively, e.g. timestamps
            f.write(json.dumps(json_dict, default=str))
        print("Wrote {} records to {}".format(len(json_dict), table_json))
        filenames.append(table_json)
    return filenames


def main():
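    """Copy any rows created in the source since the dump into the destination."""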
    source_conn = psycopg2.connect(os.environ["SOURCE_DATABASE"])
    tables_to_sync = get_tables(source_conn)
    dest_conn = psycopg2.connect(os.environ["DEST_DATABASE"])
    set_max_ids_for_tables(source_conn, dest_conn, tables_to_sync)
    dump_files = db_backup(source_conn, tables_to_sync)
    for file in dump_files:
        db_load(dest_conn, file)
    dest_conn.commit()
    source_conn.close()
    dest_conn.close()


if __name__ == "__main__":
    assert os.environ['SOURCE_DATABASE']
    assert os.environ['DEST_DATABASE']
    ret = main()
    sys.exit(ret)
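
For context, the end-to-end flow this script is designed to slot into might look something like the sketch below. The wrapper itself, the dump file name, and the pg_dump/pg_restore flags are illustrative assumptions and not part of the gist.

import os
import subprocess

source = os.environ["SOURCE_DATABASE"]
dest = os.environ["DEST_DATABASE"]

# 1. Dump the source while it is still accepting writes.
subprocess.run(["pg_dump", "--format=custom", "--file=dump.pgc", source], check=True)

# 2. Restore the dump into the destination.
subprocess.run(["pg_restore", "--no-owner", "--dbname=" + dest, "dump.pgc"], check=True)

# 3. Run sync_db.py to copy across rows written during steps 1 and 2.
subprocess.run(["python", "sync_db.py"], check=True)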