Skip to content

Instantly share code, notes, and snippets.

@zhongxiao37
Last active August 24, 2023 09:30
Show Gist options
  • Save zhongxiao37/b5949fa1927beaafc8db45e371e26de3 to your computer and use it in GitHub Desktop.
Save zhongxiao37/b5949fa1927beaafc8db45e371e26de3 to your computer and use it in GitHub Desktop.
Parse SQL files and draw the table-to-table data flows as a Graphviz DOT graph
import re
import os
from collections import defaultdict
import json
import pydot
def parse_relationships_from_sql_files(sql_folder_path):
    """
    Collect the relationships of every ``.sql`` file below a folder.

    The folder is walked recursively with ``os.walk``; each file whose name
    ends with ``.sql`` is handed to ``parse_relationships_from_sql_file`` and
    the lists it returns are concatenated in walk order.

    return:
    a flat list of relationship dicts (one per parsed INSERT statement)
    """
    collected = []
    for directory, _subdirs, entries in os.walk(sql_folder_path):
        sql_paths = [f"{directory}/{entry}" for entry in entries if entry.endswith('.sql')]
        for path in sql_paths:
            collected.extend(parse_relationships_from_sql_file(sql_file_path=path))
    return collected
def parse_relationships_from_sql_file(sql_file_path):
    """
    Parse one SQL file and return the relationship of each INSERT statement.

    The file is read in full, split into statements by
    ``parse_sql_statements`` and every statement is passed to
    ``parse_insert_statement``. Statements for which that function returns
    ``None`` (i.e. not recognised as an INSERT) are left out.

    return:
    [
    {'staging': {'dim_customers': [('app', 'account')]}},
    {'data': {'dim_customers': [('staging', 'dim_customers'),
    ('data', 'dim_customers')]}},
    {'data': {'dim_customers': [('staging', 'dim_customers')]}}
    ]
    """
    with open(sql_file_path) as file:
        sql = file.read()

    parsed = (parse_insert_statement(statement) for statement in parse_sql_statements(sql))
    # keep only the statements that produced a relationship
    return [relationship for relationship in parsed if relationship is not None]
def parse_sql_statements(sql):
    """
    Split the body of a ``CREATE ... FUNCTION`` definition into statements.

    The body is the text between a ``BEGIN`` and an ``END`` keyword that
    follow ``CREATE ... FUNCTION`` (``CREATE`` and ``FUNCTION`` must be on the
    same line). Keywords are matched case-insensitively and as whole words,
    so identifiers such as ``end_date`` are not mistaken for ``END``.
    Matching is greedy: the body starts after the last ``BEGIN`` that still
    has an ``END`` behind it and runs up to the last ``END``.

    The body is split on every ``;`` (including any inside string literals),
    each piece is stripped and empty pieces are dropped.

    return:
    a list of statement strings, or [] when no function body is found
    """
    body_match = re.search(
        r"CREATE.+FUNCTION[\s\S]*\bBEGIN\b([\s\S]*)\bEND\b",
        sql,
        re.IGNORECASE,
    )
    if body_match is None:
        return []
    pieces = (piece.strip() for piece in body_match.group(1).split(';'))
    # the text after the final ';' is only whitespace; do not report it
    return [piece for piece in pieces if piece]
def parse_insert_statement(sql):
    """
    Extract the target table and the source tables of one INSERT statement.

    Only schema-qualified names (``schema.table``) are recognised. The target
    is the table after ``INSERT INTO`` (a column list is optional). The
    sources are the table after the first ``FROM`` that is followed by a
    qualified name, plus the table after every ``JOIN``. Keywords are matched
    case-insensitively and an alias after the table name is optional.

    A qualified name only counts as a table when it is followed by
    whitespace, ``,``, ``;`` or the end of the statement; this keeps
    expressions such as ``EXTRACT(YEAR FROM t.created_at)`` from being read
    as a source table.

    return:
    {'data': {'dim_customers': [('staging', 'dim_customers')]}}
    or None when the statement contains no ``INSERT INTO schema.table``
    """
    target_match = re.search(r"INSERT\s+INTO\s*(\w+\.\w+)", sql, re.IGNORECASE)
    if target_match is None:
        return None

    # qualified table name, terminated by whitespace, ',', ';' or end of text
    table_ref = r"\s+(\w+\.\w+)(?=[\s;,]|$)"

    source_tables = []
    from_match = re.search(r"\bFROM" + table_ref, sql, re.IGNORECASE)
    if from_match:
        source_tables.append(from_match.group(1))
    source_tables += re.findall(r"\bJOIN" + table_ref, sql, re.IGNORECASE)

    schema, table = target_match.group(1).split('.')
    return {
        schema: {
            table: [tuple(name.split('.')) for name in source_tables]
        }
    }
def merge_relationships(relationships):
    """
    Merge per-statement relationships into one schema -> table -> sources map.

    Every (schema, table) source is also registered as a table of its own
    (with an empty source list if nothing feeds it), so that it shows up as a
    node. A table listed as its own source is skipped, and duplicate sources
    of the same target are stored once. A target whose source list is empty
    or contains only itself is not registered by this function.

    return:
    {
    'app': {'account': []},
    'staging': {'dim_customers': [('app', 'account')]},
    'data': {'dim_customers': [('staging', 'dim_customers')]}
    }
    (a defaultdict of defaultdicts of lists)
    """
    merged_relationships = defaultdict(lambda: defaultdict(list))
    for relationship in relationships:
        for schema, tables in relationship.items():
            for table, sources in tables.items():
                for source in sources:
                    src_schema, src_table = source
                    # skip self table to self table scenario
                    if src_schema == schema and src_table == table:
                        continue
                    # make sure the source table exists as a node of its own
                    merged_relationships[src_schema].setdefault(src_table, [])
                    known_sources = merged_relationships[schema][table]
                    if source not in known_sources:
                        known_sources.append(source)
    return merged_relationships
def generate_dot_for_relationships(relationships):
    """
    Render merged relationships as the source text of a DOT digraph.

    Each schema becomes a ``cluster_<schema>`` subgraph holding one box node
    per table; node ids have the form ``<schema>_<table>``. After all
    clusters, one edge ``source -> target`` is emitted per stored source.

    Cluster and edge colours are taken from the ``dark28`` colour scheme and
    both wrap around after eight entries, so more than eight schemas never
    request a colour index above 8.

    relationships: mapping schema -> table -> list of (src_schema, src_table)

    return:
    the DOT source as a single string
    """
    palette_size = 8  # same modulus the edge colours use for /dark28/
    parts = ["digraph dw { rankdir=LR;\n colorscheme=dark28;\n "]
    edges = []
    # draw the nodes for each schema as subgraph
    for index, (schema, tables) in enumerate(relationships.items()):
        cluster_color = (index % palette_size) + 1
        parts.append(
            f"subgraph cluster_{schema} {{ label=\"{schema}\"; "
            f"shape=box; color={cluster_color}; "
        )
        for table, sources in tables.items():
            parts.append(f" {schema}_{table}[shape=box,label=\"{schema}.{table}\"]; ")
            # remember the edges; they are written after all clusters
            edges.extend(
                (src_schema, src_table, schema, table)
                for src_schema, src_table in sources
            )
        parts.append("}\n")
    # draw the arrow between nodes
    for idx, (src_schema, src_table, schema, table) in enumerate(edges):
        edge_color = (idx % palette_size) + 1
        parts.append(
            f"\n{src_schema}_{src_table} -> {schema}_{table}"
            f"[color=\"/dark28/{edge_color}\"];"
        )
    parts.append("}")
    return "".join(parts)
def export_dot(dot_content, output_file):
    """
    Render DOT source text to a PNG file using pydot.

    dot_content: DOT source as a string
    output_file: path of the PNG file to write

    raises:
    ValueError when pydot yields no graph for ``dot_content``
    """
    graphs = pydot.graph_from_dot_data(dot_content)
    # NOTE(review): pydot appears to return None when the DOT text cannot be
    # parsed - confirm for the installed version. Fail with a clear message
    # instead of an opaque TypeError/IndexError on graphs[0].
    if not graphs:
        raise ValueError("pydot could not parse the DOT content; nothing to export")
    graphs[0].write_png(output_file)
# Script entry: parse every .sql file below the 'sqls' folder, print the
# merged relationships as JSON and render them to 'table_data_flow.png'.
# NOTE(review): these statements run at import time (no __main__ guard).
folder_path = 'sqls'
relationships = merge_relationships(parse_relationships_from_sql_files(folder_path))
# (schema, table) tuples are emitted as JSON arrays
print(json.dumps(relationships, indent=4))
dot_content = generate_dot_for_relationships(relationships)
export_dot(dot_content, 'table_data_flow.png')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment