Parse the SQL files to draw the table flows with DOT
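The script assumes stored-procedure style SQL: each .sql file defines a CREATE ... FUNCTION whose BEGIN ... END body is a sequence of INSERT INTO schema.table (...) SELECT ... FROM schema.table alias [JOIN schema.table alias ...] statements. A minimal sketch of such an input (the function name is hypothetical; the tables mirror the docstrings below):

CREATE OR REPLACE FUNCTION data.refresh_dim_customers() RETURNS void AS $$
BEGIN
    INSERT INTO staging.dim_customers (id, name)
    SELECT a.id, a.name
    FROM app.account a;

    INSERT INTO data.dim_customers (id, name)
    SELECT s.id, s.name
    FROM staging.dim_customers s;
END
$$ LANGUAGE plpgsql;

For a folder of such files the script prints the merged table-level lineage as JSON and renders app.account -> staging.dim_customers -> data.dim_customers into table_data_flow.png.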
import re
import os
import json
from collections import defaultdict

import pydot
def parse_relationships_from_sql_files(sql_folder_path):
    full_relationships = []
    for root, _, files in os.walk(sql_folder_path):
        for file_name in files:
            if file_name.endswith('.sql'):
                full_relationships += parse_relationships_from_sql_file(sql_file_path=f"{root}/{file_name}")
    return full_relationships
def parse_relationships_from_sql_file(sql_file_path):
    """
    return:
        [
            {'staging': {'dim_customers': [('app', 'account')]}},
            {'data': {'dim_customers': [('staging', 'dim_customers'),
                                        ('data', 'dim_customers')]}},
            {'data': {'dim_customers': [('staging', 'dim_customers')]}}
        ]
    """
    sql_raw_relationships = []
    with open(sql_file_path) as file:
        sql = file.read()
    sql_statements = parse_sql_statements(sql)
    for statement in sql_statements:
        relationship = parse_insert_statement(statement)
        if relationship is not None:
            sql_raw_relationships.append(relationship)
    return sql_raw_relationships
def parse_sql_statements(sql):
    sql_match = re.search(r"CREATE.+FUNCTION[\s\S]*(?<=BEGIN)([\s\S]*)(?=END)", sql)
    sql_statements = [x.strip() for x in sql_match.group(1).split(';')] if sql_match else []
    return sql_statements
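# Example (for the hypothetical function sketched above): parse_sql_statements
# grabs the BEGIN ... END body and splits it on ';', returning roughly
#   ['INSERT INTO staging.dim_customers (id, name) SELECT a.id, a.name FROM app.account a',
#    'INSERT INTO data.dim_customers (id, name) SELECT s.id, s.name FROM staging.dim_customers s',
#    '']
# (newlines inside each statement are preserved; the trailing empty string comes
# from the final ';' and is discarded later because parse_insert_statement
# returns None for it).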
def parse_insert_statement(sql):
    """
    return:
        {'data': {'dim_customers': [('staging', 'dim_customers')]}}
    """
    # Extract target table and columns
    target_table_match = re.search(r"INSERT INTO\s*(\w+\.\w+)\s*\(([^)]*)\)", sql, re.IGNORECASE)
    if not target_table_match:
        return None
    target_table = target_table_match.group(1)
    target_columns = [col.strip() for col in target_table_match.group(2).split(',')]

    # Extract source tables (FROM plus any JOINs) and columns
    source_table_match = re.search(r"FROM\s+(\w+\.\w+)\s(\w+)", sql, re.IGNORECASE)
    source_tables = []
    if source_table_match:
        source_tables.append((source_table_match.group(1), source_table_match.group(2)))
    join_tables = re.findall(r"JOIN\s+(\w+\.\w+)\s(\w+)", sql, re.IGNORECASE)
    if join_tables:
        source_tables += join_tables
    source_columns_match = re.search(r"SELECT\s+(.+)\s+FROM", sql, flags=re.DOTALL | re.IGNORECASE)
    source_columns_str = source_columns_match.group(1).strip() if source_columns_match else ""
    source_columns = [col.strip() for col in source_columns_str.split(',')]
    # target_columns / source_columns are parsed but not used yet (table-level lineage only)

    schema, table = target_table.split('.')
    relationship = {
        schema: {
            table: [tuple(tbl.split('.')) for tbl, _ in source_tables]
        }
    }
    return relationship
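# Example: for the statement
#   INSERT INTO staging.dim_customers (id, name)
#   SELECT a.id, a.name
#   FROM app.account a
# this returns {'staging': {'dim_customers': [('app', 'account')]}}.
# Statements without an "INSERT INTO schema.table (...)" header return None,
# which the caller filters out.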
def merge_relationships(relationships):
    """
    return:
        {
            'app': {'account': []},
            'staging': {'dim_customers': [('app', 'account')]},
            'data': {'dim_customers': [('staging', 'dim_customers')]}
        }
    """
    merged_relationships = defaultdict(lambda: defaultdict(list))
    for relationship in relationships:
        for schema, tables in relationship.items():
            for table, sources in tables.items():
                for source in sources:
                    src_schema, src_table = source
                    # skip self table to self table scenario
                    if (src_schema == schema) and (src_table == table):
                        continue
                    # touch the source node so it exists even when it has no upstream sources of its own
                    merged_relationships[src_schema][src_table]
                    if source not in merged_relationships[schema][table]:
                        merged_relationships[schema][table].append(source)
    return merged_relationships
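# Example: merging the two relationships above yields
#   {'app': {'account': []},
#    'staging': {'dim_customers': [('app', 'account')]},
#    'data': {'dim_customers': [('staging', 'dim_customers')]}}
# Self-references are skipped, duplicate edges are recorded once, and
# source-only tables (app.account here) are kept as empty entries so they
# still get their own node in the graph.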
def generate_dot_for_relationships(relationships):
    dot_content = "digraph dw { rankdir=LR;\n colorscheme=dark28;\n "
    relationship_dot = []
    # draw the nodes for each schema as subgraph
    for index, (schema, tables) in enumerate(relationships.items()):
        schema_dot = f"subgraph cluster_{schema} {{ label=\"{schema}\"; shape=box; color={index+1}; "
        for table, sources in tables.items():
            schema_dot += f" {schema}_{table}[shape=box,label=\"{schema}.{table}\"]; "
            for src_schema, src_table in sources:
                relationship_dot.append((src_schema, src_table, schema, table))
        schema_dot += "}\n"
        dot_content += schema_dot
    # draw the arrow between nodes
    for idx, r_dot in enumerate(relationship_dot):
        src_schema, src_table, schema, table = r_dot
        # idx = list(relationships.keys()).index(src_schema)
        dot_content += f"\n{src_schema}_{src_table} -> {schema}_{table}[color=\"/dark28/{(idx % 8)+1}\"];"
    dot_content += "}"
    return dot_content
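# Rough shape of the DOT this produces for the merged example above
# (one cluster per schema, one colored edge per source -> target pair):
#   digraph dw { rankdir=LR;
#    colorscheme=dark28;
#    subgraph cluster_app { label="app"; shape=box; color=1;  app_account[shape=box,label="app.account"]; }
#   subgraph cluster_staging { label="staging"; shape=box; color=2;  staging_dim_customers[shape=box,label="staging.dim_customers"]; }
#   subgraph cluster_data { label="data"; shape=box; color=3;  data_dim_customers[shape=box,label="data.dim_customers"]; }
#   app_account -> staging_dim_customers[color="/dark28/1"];
#   staging_dim_customers -> data_dim_customers[color="/dark28/2"];}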
def export_dot(dot_content, output_file):
    graphs = pydot.graph_from_dot_data(dot_content)
    graph = graphs[0]
    graph.write_png(output_file)
if __name__ == '__main__':
    folder_path = 'sqls'
    relationships = merge_relationships(parse_relationships_from_sql_files(folder_path))
    print(json.dumps(relationships, indent=4))
    dot_content = generate_dot_for_relationships(relationships)
    export_dot(dot_content, 'table_data_flow.png')
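# To try this end to end (assumed setup: pydot installed, the Graphviz `dot`
# binary on PATH since write_png shells out to it, and the .sql files under a
# ./sqls folder in the working directory), run the script directly; it prints
# the merged lineage as JSON and writes table_data_flow.png.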