@corporatepiyush
Last active August 8, 2024 09:12
Migrate schema and data from a MySQL/MariaDB server to another MySQL/MariaDB server
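The script below copies every non-system database from a remote MySQL/MariaDB server to a target server. It recreates each schema in foreign-key order (parent tables before children), then streams each table's rows in batches of 1000 and inserts them concurrently through a shared thread pool, tracking per-table row counts in a helper table named `_row_count_temp`.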
import pymysql
import traceback
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

BATCH_SIZE = 1000
MAX_RETRIES = 3

# Define a top-level ThreadPoolExecutor, shared by all batch inserts.
executor = ThreadPoolExecutor()
@contextmanager
def get_connection(config):
    """Context manager for database connections."""
    connection = pymysql.connect(**config)
    try:
        yield connection
    finally:
        connection.close()
def execute_query(cursor, query, params=None):
    """Execute a query with error handling and retries."""
    for attempt in range(MAX_RETRIES):
        try:
            if params:
                cursor.execute(query, params)
            else:
                cursor.execute(query)
            return cursor.fetchall()
        except pymysql.Error as e:
            logger.error(f"Error executing query (attempt {attempt + 1}/{MAX_RETRIES}): {e}")
            if attempt == MAX_RETRIES - 1:
                raise
def get_foreign_key_dependencies(cursor, database_name):
    """Return (child_table, child_column, parent_table, parent_column) rows
    for every foreign key in the given schema."""
    query = """
    SELECT TABLE_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME
    FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
    WHERE REFERENCED_TABLE_SCHEMA = %s AND REFERENCED_TABLE_NAME IS NOT NULL;
    """
    return execute_query(cursor, query, (database_name,))
def get_insertion_order(tables, dependencies):
    """Topologically sort tables so referenced (parent) tables come before their children."""
    from collections import defaultdict, deque
    graph = defaultdict(list)
    in_degree = {table: 0 for table in tables}
    # Deduplicate edges so composite foreign keys count only once, and skip
    # self-references, which would otherwise never reach in-degree zero.
    edges = {(table, ref_table) for table, _, ref_table, _ in dependencies if table != ref_table}
    for table, ref_table in edges:
        graph[ref_table].append(table)
        in_degree[table] += 1
    # Seed with tables that reference nothing, including tables with no foreign keys at all.
    order = [table for table in tables if in_degree[table] == 0]
    queue = deque(order)
    while queue:
        current = queue.popleft()
        for dependent in graph[current]:
            in_degree[dependent] -= 1
            if in_degree[dependent] == 0:
                queue.append(dependent)
                order.append(dependent)
    return order
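# Illustrative example of the ordering (hypothetical table names): given
# tables ['customers', 'orders', 'audit_log'] and the single dependency row
# ('orders', 'customer_id', 'customers', 'id'), get_insertion_order returns
# ['customers', 'audit_log', 'orders']: referenced tables first, tables
# without foreign keys among the roots, dependents last.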
def read_data_in_batches(cursor, query, batch_size=BATCH_SIZE):
    """Yield the result set of `query` in slices of `batch_size` rows."""
    try:
        cursor.execute(query)
        while True:
            batch = cursor.fetchmany(batch_size)
            if not batch:
                break
            yield batch
    except Exception as e:
        logger.error(f"Error reading data in batches: {e}")
        traceback.print_exc()
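# Note: with pymysql's default Cursor the full result set is buffered
# client-side on execute(); fetchmany() only slices the buffer. For true
# streaming of very large tables one could pass
# cursorclass=pymysql.cursors.SSCursor (an unbuffered cursor) in the remote
# config; this script assumes the buffered default.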
def create_temporary_table(cursor):
    """Create the helper table that tracks per-table insert counts."""
    query = """
    CREATE TABLE IF NOT EXISTS `_row_count_temp` (
        `table_name` VARCHAR(255) PRIMARY KEY,
        `row_count` INT DEFAULT 0
    )
    """
    execute_query(cursor, query)

def initialize_row_count(cursor, table_name):
    """Reset a table's row count to zero."""
    query = """
    INSERT INTO `_row_count_temp` (`table_name`, `row_count`)
    VALUES (%s, 0)
    ON DUPLICATE KEY UPDATE `row_count` = 0
    """
    execute_query(cursor, query, (table_name,))

def update_row_count(cursor, table_name, count):
    """Add `count` to a table's running row count."""
    query = """
    UPDATE `_row_count_temp`
    SET `row_count` = `row_count` + %s
    WHERE `table_name` = %s
    """
    execute_query(cursor, query, (count, table_name))
def create_tables_and_constraints(remote_config, local_config, database_name):
    """Recreate the schema of `database_name` on the local server."""
    with get_connection(remote_config) as remote_conn, get_connection(local_config) as local_conn:
        with remote_conn.cursor() as remote_cursor, local_conn.cursor() as local_cursor:
            try:
                # execute_query() already fetches the result, so use its return
                # value directly instead of calling fetchall()/fetchone() again.
                tables = [table[0] for table in execute_query(remote_cursor, f"SHOW TABLES FROM `{database_name}`")]
                dependencies = get_foreign_key_dependencies(remote_cursor, database_name)
                insertion_order = get_insertion_order(tables, dependencies)
                execute_query(local_cursor, f"CREATE DATABASE IF NOT EXISTS `{database_name}`")
                execute_query(local_cursor, f"USE `{database_name}`")
                create_temporary_table(local_cursor)
                local_conn.commit()
                for table_name in insertion_order:
                    logger.info(f"Creating table: {table_name}")
                    table_exists = execute_query(local_cursor, f"SHOW TABLES LIKE '{table_name}'")
                    if table_exists:
                        logger.info(f"Dropping existing table: {table_name}")
                        # Disable FK checks so a parent table left over from a
                        # previous run can be dropped while its children still exist.
                        execute_query(local_cursor, "SET FOREIGN_KEY_CHECKS = 0")
                        execute_query(local_cursor, f"DROP TABLE `{table_name}`")
                        execute_query(local_cursor, "SET FOREIGN_KEY_CHECKS = 1")
                        local_conn.commit()
                    create_table_query = execute_query(remote_cursor, f"SHOW CREATE TABLE `{database_name}`.`{table_name}`")[0][1]
                    execute_query(local_cursor, create_table_query)
                    local_conn.commit()
                    recreate_indexes(local_cursor, database_name, table_name, remote_cursor)
                    initialize_row_count(local_cursor, table_name)
                    local_conn.commit()
                    logger.info(f"Initialized row count for table: {table_name}")
                logger.info("All tables created successfully.")
            except Exception as e:
                logger.error(f"An error occurred while creating tables: {e}")
                traceback.print_exc()
                raise
def recreate_indexes(local_cursor, database_name, table_name, remote_cursor):
    """Recreate any indexes missing on the local copy. SHOW CREATE TABLE already
    carries the primary key and indexes, so only absent ones are created here."""
    try:
        logger.info(f"Recreating indexes for {table_name} on local server.")
        indexes = execute_query(remote_cursor, f"SHOW INDEX FROM `{database_name}`.`{table_name}`")
        existing = {index[2] for index in execute_query(local_cursor, f"SHOW INDEX FROM `{table_name}`")}
        index_queries = {}
        primary_key_columns = []
        for index in indexes:
            index_name = index[2]
            column_name = index[4]
            non_unique = index[1]
            if index_name == 'PRIMARY':
                primary_key_columns.append(column_name)
            elif index_name not in existing:
                # 'UNIQUE ' or '' here; the original 'INDEX' produced the
                # invalid statement "CREATE INDEX INDEX ...".
                index_type = 'UNIQUE ' if non_unique == 0 else ''
                if index_name not in index_queries:
                    index_queries[index_name] = f"CREATE {index_type}INDEX `{index_name}` ON `{table_name}` (`{column_name}`)"
                else:
                    # Append further columns of a composite index.
                    index_queries[index_name] = index_queries[index_name][:-1] + f", `{column_name}`)"
        if primary_key_columns and 'PRIMARY' not in existing:
            primary_key_query = f"ALTER TABLE `{table_name}` ADD PRIMARY KEY ({', '.join(f'`{col}`' for col in primary_key_columns)})"
            execute_query(local_cursor, primary_key_query)
        for query in index_queries.values():
            execute_query(local_cursor, query)
    except Exception as e:
        logger.error(f"An error occurred while recreating indexes: {e}")
        traceback.print_exc()
        raise
def insert_batch(local_config, database_name, table_name, insert_query, rows):
    """Insert one batch of rows on its own connection. A fresh connection per
    batch keeps the worker threads isolated, since pymysql connections are
    not thread-safe."""
    with get_connection(local_config) as local_conn:
        with local_conn.cursor() as local_cursor:
            try:
                execute_query(local_cursor, f"USE `{database_name}`")
                local_cursor.executemany(insert_query, rows)
                count = len(rows)
                update_row_count(local_cursor, table_name, count)
                local_conn.commit()
                logger.info(f"Successfully inserted batch of {count} rows into {table_name}")
                return count
            except pymysql.err.IntegrityError as e:
                logger.error(f"Error inserting batch into {table_name}: {e}")
                local_conn.rollback()
                return 0
            except Exception as e:
                logger.error(f"Unexpected error inserting batch into {table_name}: {e}")
                traceback.print_exc()
                local_conn.rollback()
                return 0
def insert_table_data(remote_config, local_config, database_name, table_name):
    """Stream a table's rows from the remote server and insert them as parallel batches."""
    with get_connection(remote_config) as remote_conn:
        with remote_conn.cursor() as remote_cursor:
            try:
                columns = [column[0] for column in execute_query(remote_cursor, f"SHOW COLUMNS FROM `{database_name}`.`{table_name}`")]
                columns_list = ", ".join([f"`{col}`" for col in columns])
                placeholders = ", ".join(["%s"] * len(columns))
                insert_query = f"INSERT INTO `{table_name}` ({columns_list}) VALUES ({placeholders})"
                futures = []
                for rows in read_data_in_batches(remote_cursor, f"SELECT * FROM `{database_name}`.`{table_name}`"):
                    futures.append(executor.submit(insert_batch, local_config, database_name, table_name, insert_query, rows))
                total_inserted = sum(future.result() for future in as_completed(futures))
                logger.info(f"Total rows inserted into {table_name}: {total_inserted}")
            except Exception as e:
                logger.error(f"An error occurred while inserting data into {table_name}: {e}")
                traceback.print_exc()
def copy_database(remote_config, local_config, database_name):
    """Copy one database: schema first, then data table by table."""
    create_tables_and_constraints(remote_config, local_config, database_name)
    with get_connection(remote_config) as remote_conn:
        with remote_conn.cursor() as remote_cursor:
            # Parameterized to avoid interpolating the schema name into the SQL.
            tables = [table[0] for table in execute_query(
                remote_cursor,
                "SELECT TABLE_NAME FROM information_schema.tables WHERE table_schema = %s AND TABLE_TYPE = 'BASE TABLE'",
                (database_name,))]
            for table_name in tables:
                logger.info(f"Inserting data into {table_name}")
                insert_table_data(remote_config, local_config, database_name, table_name)
def main():
    remote_config = {
        'host': '<host_ip>',
        'user': '<user>',
        'password': '<password>',
        'database': '<database>',
        'charset': 'utf8mb4',
        'cursorclass': pymysql.cursors.Cursor,
        'port': 2546,
    }
    local_config = {
        'host': '<host_ip>',
        'user': '<user>',
        'password': '<password>',
        'charset': 'utf8mb4',
        'cursorclass': pymysql.cursors.Cursor,
    }
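    # The placeholders above might be filled in like this for a test run
    # (hypothetical values; substitute your own hosts, ports, and credentials):
    #   remote_config = {'host': '203.0.113.5', 'port': 2546, 'user': 'migrator', ...}
    #   local_config = {'host': '127.0.0.1', 'user': 'root', ...}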
    system_schemas = ('information_schema', 'mysql', 'performance_schema', 'sys')
    with get_connection(remote_config) as remote_conn:
        with remote_conn.cursor() as cursor:
            databases = [db for (db,) in execute_query(cursor, "SHOW DATABASES")
                         if db not in system_schemas]
    for database_name in databases:
        logger.info(f"Copying database: {database_name}")
        copy_database(remote_config, local_config, database_name)
    # Print total inserted rows per table; each copied database carries its own
    # `_row_count_temp`, so select the database before querying the counts.
    with get_connection(local_config) as local_conn:
        with local_conn.cursor() as local_cursor:
            for database_name in databases:
                execute_query(local_cursor, f"USE `{database_name}`")
                results = execute_query(local_cursor, "SELECT `table_name`, `row_count` FROM `_row_count_temp`")
                for table_name, row_count in results:
                    logger.info(f"Total rows inserted into {database_name}.{table_name}: {row_count}")
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        logger.error(f"An unexpected error occurred: {e}")
        traceback.print_exc()
    finally:
        # Shut down the ThreadPoolExecutor after use.
        executor.shutdown(wait=True)
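To run the script: install the one dependency with pip install pymysql, fill in the <host_ip>, <user>, <password>, and <database> placeholders in main(), and execute the file with python. Broadly, the source account needs read (SELECT/SHOW) access and the target account needs CREATE, DROP, ALTER, INDEX, and INSERT privileges; every non-system database on the source server will be copied.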