Last active
August 8, 2024 09:12
-
-
Save corporatepiyush/6a6909c4a4b3fee29edfbef41be6b8b2 to your computer and use it in GitHub Desktop.
Migrate schema and data from a MySQL/MariaDB server to another MySQL/MariaDB server
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pymysql
import traceback
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Rows fetched/inserted per round-trip when copying table data.
BATCH_SIZE = 1000
# Attempts per query in execute_query() before the error is re-raised.
MAX_RETRIES = 3

# Define a top-level ThreadPoolExecutor
# Shared pool for parallel batch inserts (default worker count); shut down
# in the __main__ finally block at the bottom of the file.
executor = ThreadPoolExecutor()
@contextmanager
def get_connection(config):
    """Yield a pymysql connection built from *config*, guaranteeing close on exit."""
    conn = pymysql.connect(**config)
    try:
        yield conn
    finally:
        conn.close()
def execute_query(cursor, query, params=None):
    """Execute *query* on *cursor* with up to MAX_RETRIES attempts.

    Args:
        cursor: DB-API cursor (needs .execute() / .fetchall()).
        query: SQL text, optionally with %s placeholders.
        params: optional parameter sequence for the placeholders.

    Returns:
        All result rows via cursor.fetchall() (empty for DDL/DML).

    Raises:
        pymysql.Error: re-raised after the final failed attempt.
    """
    for attempt in range(MAX_RETRIES):
        try:
            # Bug fix: test `is not None`, not truthiness — an empty
            # parameter tuple is a legitimate value and must still be
            # passed through to cursor.execute().
            if params is not None:
                cursor.execute(query, params)
            else:
                cursor.execute(query)
            return cursor.fetchall()
        except pymysql.Error as e:
            logger.error(f"Error executing query (attempt {attempt + 1}/{MAX_RETRIES}): {e}")
            if attempt == MAX_RETRIES - 1:
                raise
def get_foreign_key_dependencies(cursor, database_name):
    """Return (table, column, referenced_table, referenced_column) rows for
    every foreign key whose referenced table lives in *database_name*."""
    fk_query = """
    SELECT TABLE_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME
    FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
    WHERE REFERENCED_TABLE_SCHEMA = %s AND REFERENCED_TABLE_NAME IS NOT NULL;
    """
    return execute_query(cursor, fk_query, (database_name,))
def get_insertion_order(dependencies, tables=None):
    """Topologically sort tables so referenced (parent) tables come first.

    Args:
        dependencies: iterable of (table, column, ref_table, ref_column)
            rows, e.g. from get_foreign_key_dependencies().
        tables: optional iterable of all table names. Bug fix /
            generalization: tables that take part in no foreign key at all
            never appear in *dependencies* and were previously omitted from
            the result; pass the full table list here to include them.
            Defaults to None for backward compatibility.

    Returns:
        List of table names in a valid insertion order. Tables inside a
        dependency cycle cannot be ordered and are omitted (as before).
    """
    from collections import defaultdict, deque

    dependents = defaultdict(list)  # parent table -> child tables
    in_degree = defaultdict(int)    # child table -> number of unmet parents
    nodes = set(tables) if tables is not None else set()
    for table, _, ref_table, _ in dependencies:
        dependents[ref_table].append(table)
        in_degree[table] += 1
        nodes.add(table)
        nodes.add(ref_table)

    # Sorted so independent tables come out in a deterministic order.
    order = [name for name in sorted(nodes) if in_degree[name] == 0]
    queue = deque(order)
    while queue:
        current = queue.popleft()
        for child in dependents[current]:
            in_degree[child] -= 1
            if in_degree[child] == 0:
                queue.append(child)
                order.append(child)
    return order
def read_data_in_batches(cursor, query, batch_size=BATCH_SIZE):
    """Generator: run *query* and yield its rows in lists of at most
    *batch_size*. On error, log and stop yielding (best-effort)."""
    try:
        cursor.execute(query)
        while True:
            chunk = cursor.fetchmany(batch_size)
            if not chunk:
                return
            yield chunk
    except Exception as e:
        logger.error(f"Error reading data in batches: {e}")
        traceback.print_exc()
def create_temporary_table(cursor):
    """Create the `_row_count_temp` bookkeeping table if it does not exist."""
    ddl = """
    CREATE TABLE IF NOT EXISTS `_row_count_temp` (
        `table_name` VARCHAR(255) PRIMARY KEY,
        `row_count` INT DEFAULT 0
    )
    """
    execute_query(cursor, ddl)
def initialize_row_count(cursor, table_name):
    """Reset the `_row_count_temp` counter for *table_name* to zero."""
    upsert = """
    INSERT INTO `_row_count_temp` (`table_name`, `row_count`)
    VALUES (%s, 0)
    ON DUPLICATE KEY UPDATE `row_count` = 0
    """
    execute_query(cursor, upsert, (table_name,))
def update_row_count(cursor, table_name, count):
    """Add *count* to the `_row_count_temp` counter for *table_name*."""
    increment = """
    UPDATE `_row_count_temp`
    SET `row_count` = `row_count` + %s
    WHERE `table_name` = %s
    """
    execute_query(cursor, increment, (count, table_name))
def create_tables_and_constraints(remote_config, local_config, database_name):
    """Recreate the schema of *database_name* on the local server.

    Tables are (re)created in foreign-key order (referenced tables first);
    tables with no FK relationships are appended afterwards. Any existing
    local table of the same name is dropped first, and a `_row_count_temp`
    row is initialised per table for progress accounting.

    Raises:
        Exception: re-raised after logging if any step fails.
    """
    with get_connection(remote_config) as remote_conn, get_connection(local_config) as local_conn:
        with remote_conn.cursor() as remote_cursor, local_conn.cursor() as local_cursor:
            try:
                # Bug fix: execute_query() already consumes the result via
                # fetchall(); a second remote_cursor.fetchall() here always
                # returned [], so use the return value directly.
                tables = [row[0] for row in execute_query(remote_cursor, f"SHOW TABLES FROM `{database_name}`")]

                dependencies = get_foreign_key_dependencies(remote_cursor, database_name)
                insertion_order = get_insertion_order(dependencies)
                # Bug fix: tables without any FK never appear in the
                # dependency graph and were previously never created —
                # append them. Also drop names not in this schema
                # (KEY_COLUMN_USAGE may list referencing tables from
                # other schemas).
                table_set = set(tables)
                ordered_tables = [t for t in insertion_order if t in table_set]
                ordered_tables += [t for t in tables if t not in set(insertion_order)]

                execute_query(local_cursor, f"CREATE DATABASE IF NOT EXISTS `{database_name}`")
                execute_query(local_cursor, f"USE `{database_name}`")
                create_temporary_table(local_cursor)
                local_conn.commit()

                for table_name in ordered_tables:
                    logger.info(f"Creating table: {table_name}")
                    # Bug fix: use execute_query()'s return value — the old
                    # fetchone() after execute_query() always returned None,
                    # so an existing table was never dropped and the
                    # subsequent CREATE TABLE failed.
                    table_exists = execute_query(local_cursor, f"SHOW TABLES LIKE '{table_name}'")
                    if table_exists:
                        logger.info(f"Dropping existing table: {table_name}")
                        execute_query(local_cursor, f"DROP TABLE `{table_name}`")
                        local_conn.commit()
                    create_table_query = execute_query(remote_cursor, f"SHOW CREATE TABLE `{database_name}`.`{table_name}`")[0][1]
                    execute_query(local_cursor, create_table_query)
                    local_conn.commit()
                    recreate_indexes(local_cursor, database_name, table_name, remote_cursor)
                    initialize_row_count(local_cursor, table_name)
                    local_conn.commit()
                    logger.info(f"Initialized row count for table: {table_name}")
                logger.info("All tables created successfully.")
            except Exception as e:
                logger.error(f"An error occurred while creating tables: {e}")
                traceback.print_exc()
                raise
def recreate_indexes(local_cursor, database_name, table_name, remote_cursor):
    """Ensure every secondary index on the remote table also exists locally.

    SHOW CREATE TABLE (run by the caller) already reproduces the primary
    key and the indexes embedded in the table definition, so only indexes
    missing from the local copy are created here. Bug fixes vs the
    original: (1) skip index names already present locally — creating them
    again raised a duplicate-key-name error; (2) removed the dead
    `primary_key_columns and not primary_key_exists` branch, whose
    condition could never be true (both were set together).

    Raises:
        Exception: re-raised after logging if any step fails.
    """
    try:
        logger.info(f"Recreating indexes for {table_name} on local server.")
        remote_indexes = execute_query(remote_cursor, f"SHOW INDEX FROM `{database_name}`.`{table_name}`")
        # Index names already on the local table (column 2 = Key_name).
        local_existing = {row[2] for row in execute_query(local_cursor, f"SHOW INDEX FROM `{table_name}`")}
        index_queries = {}
        for index in remote_indexes:
            index_name = index[2]   # Key_name
            column_name = index[4]  # Column_name
            non_unique = index[1]   # Non_unique (0 => unique index)
            # PRIMARY is always carried over by SHOW CREATE TABLE.
            if index_name == 'PRIMARY' or index_name in local_existing:
                continue
            index_type = 'UNIQUE' if non_unique == 0 else 'INDEX'
            if index_name not in index_queries:
                index_queries[index_name] = f"CREATE {index_type} INDEX `{index_name}` ON `{table_name}` (`{column_name}`)"
            else:
                # Subsequent rows of a multi-column index: splice the next
                # column in before the closing parenthesis.
                index_queries[index_name] = index_queries[index_name][:-1] + f", `{column_name}`)"
        for query in index_queries.values():
            execute_query(local_cursor, query)
    except Exception as e:
        logger.error(f"An error occurred while recreating indexes: {e}")
        traceback.print_exc()
        raise
def insert_batch(local_config, database_name, table_name, insert_query, rows):
    """Insert one batch of *rows* into *table_name* on a fresh local connection.

    Returns the number of rows inserted, or 0 when the batch fails (the
    transaction is rolled back and the error logged; the copy continues).
    """
    with get_connection(local_config) as conn:
        with conn.cursor() as cur:
            try:
                execute_query(cur, f"USE `{database_name}`")
                cur.executemany(insert_query, rows)
                count = len(rows)
                update_row_count(cur, table_name, count)
                conn.commit()
                logger.info(f"Successfully inserted batch of {count} rows into {table_name}")
                return count
            except pymysql.err.IntegrityError as e:
                logger.error(f"Error inserting batch into {table_name}: {e}")
                conn.rollback()
                return 0
            except Exception as e:
                logger.error(f"Unexpected error inserting batch into {table_name}: {e}")
                traceback.print_exc()
                conn.rollback()
                return 0
def insert_table_data(remote_config, local_config, database_name, table_name):
    """Stream every row of the remote table into the local copy, submitting
    each batch to the shared executor and summing the per-batch counts."""
    with get_connection(remote_config) as remote_conn:
        with remote_conn.cursor() as remote_cursor:
            try:
                column_rows = execute_query(remote_cursor, f"SHOW COLUMNS FROM `{database_name}`.`{table_name}`")
                column_names = [row[0] for row in column_rows]
                columns_list = ", ".join(f"`{col}`" for col in column_names)
                placeholders = ", ".join(["%s"] * len(column_names))
                insert_query = f"INSERT INTO `{table_name}` ({columns_list}) VALUES ({placeholders})"
                select_all = f"SELECT * FROM `{database_name}`.`{table_name}`"
                pending = [
                    executor.submit(insert_batch, local_config, database_name, table_name, insert_query, rows)
                    for rows in read_data_in_batches(remote_cursor, select_all)
                ]
                total_inserted = sum(f.result() for f in as_completed(pending))
                logger.info(f"Total rows inserted into {table_name}: {total_inserted}")
            except Exception as e:
                logger.error(f"An error occurred while inserting data into {table_name}: {e}")
                traceback.print_exc()
def copy_database(remote_config, local_config, database_name):
    """Recreate the schema of *database_name* locally, then copy all data
    for its base tables (views are excluded)."""
    create_tables_and_constraints(remote_config, local_config, database_name)
    with get_connection(remote_config) as remote_conn:
        with remote_conn.cursor() as remote_cursor:
            # Parameterized instead of f-string interpolation: consistent
            # with the other information_schema queries and safe against
            # quoting issues in schema names.
            table_query = (
                "SELECT TABLE_NAME FROM information_schema.tables "
                "WHERE table_schema = %s AND TABLE_TYPE = 'BASE TABLE'"
            )
            tables = [row[0] for row in execute_query(remote_cursor, table_query, (database_name,))]
            for table_name in tables:
                logger.info(f"Inserting data into {table_name}")
                insert_table_data(remote_config, local_config, database_name, table_name)
def main():
    """Copy every non-system database from the remote server to the local
    one, then report per-table inserted-row totals from `_row_count_temp`."""
    remote_config = {
        'host': '<host_ip>',
        'user': '<user>',
        'password': '<password>',
        'database': '<database>',
        'charset': 'utf8mb4',
        'cursorclass': pymysql.cursors.Cursor,
        'port': 2546,
    }
    # NOTE: no 'database' key — connections start with no default schema.
    local_config = {
        'host': '<host_ip>',
        'user': '<user>',
        'password': '<password>',
        'charset': 'utf8mb4',
        'cursorclass': pymysql.cursors.Cursor,
    }
    system_schemas = ('information_schema', 'mysql', 'performance_schema', 'sys')
    copied_databases = []
    with get_connection(remote_config) as remote_conn:
        with remote_conn.cursor() as cursor:
            databases = execute_query(cursor, "SHOW DATABASES")
            for (database_name,) in databases:
                if database_name not in system_schemas:
                    logger.info(f"Copying database: {database_name}")
                    copy_database(remote_config, local_config, database_name)
                    copied_databases.append(database_name)
    # Print total inserted rows per table.
    # Bug fix: each schema has its own `_row_count_temp`, and local_config
    # selects no default database, so the old unqualified SELECT failed
    # with "No database selected" — qualify the table per copied schema.
    with get_connection(local_config) as local_conn:
        with local_conn.cursor() as local_cursor:
            for database_name in copied_databases:
                results = execute_query(local_cursor, f"SELECT `table_name`, `row_count` FROM `{database_name}`.`_row_count_temp`")
                for table_name, row_count in results:
                    logger.info(f"Total rows inserted into {table_name}: {row_count}")
if __name__ == "__main__": | |
try: | |
main() | |
except Exception as e: | |
logger.error(f"An unexpected error occurred: {e}") | |
traceback.print_exc() | |
finally: | |
# Shutdown the ThreadPoolExecutor after use | |
executor.shutdown(wait=True) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment