Skip to content

Instantly share code, notes, and snippets.

@Hugoberry
Created October 7, 2024 19:08
Show Gist options
  • Save Hugoberry/d4e362a65678b62365800f634685e59a to your computer and use it in GitHub Desktop.
Save Hugoberry/d4e362a65678b62365800f634685e59a to your computer and use it in GitHub Desktop.
Use `sqlfluff` to build a lineage tree in D2
import os
import sys
from typing import Any, Dict, Iterator, Union, List, Tuple

import sqlfluff
def parse_sql_file(file_path: str) -> Dict[str, Any]:
    """Read a T-SQL file and return sqlfluff's parse record for it.

    Progress and warning messages are written to stderr so that stdout stays
    reserved for the D2 text printed by the script entry point.

    Args:
        file_path: Path of the SQL file. It is decoded as UTF-8; the
            ``utf-8-sig`` codec strips a leading byte-order mark if present.

    Returns:
        The parse record returned by ``sqlfluff.parse`` (dialect ``tsql``).
        If sqlfluff raises ``APIParsingError``, the exception's
        ``partial_result`` attribute is returned when it exists and is
        non-empty. Otherwise an empty dict is returned and the parse errors
        are reported on stderr.

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    with open(file_path, 'r', encoding='utf-8-sig') as file:
        sql_content = file.read()
    try:
        parsed = sqlfluff.parse(sql_content, dialect="tsql")
    except sqlfluff.api.simple.APIParsingError as e:
        # NOTE(review): recent sqlfluff releases appear to expose only
        # `violations` on this exception, not `partial_result`. Verify against
        # the installed version before relying on partial output.
        partial_result = getattr(e, 'partial_result', None)
        if partial_result:
            print(f"Warning: Unable to parse the entire SQL file {file_path}. Proceeding with partial parsing.", file=sys.stderr)
            return partial_result
        # Nothing usable was recovered: say so instead of claiming partial parsing,
        # and surface sqlfluff's own error text so the user can fix the SQL.
        print(f"Warning: Unable to parse SQL file {file_path}; no partial parse result is available.\n{e}", file=sys.stderr)
        return {}
    print(f"Successfully parsed file: {file_path}", file=sys.stderr)
    return parsed
def get_json_segment(parse_result: Union[Dict[str, Any], List[Any]], segment_type: str) -> Iterator[Union[str, Dict[str, Any], List[Dict[str, Any]]]]:
    """Yield the content stored under every ``segment_type`` key in a parse record.

    The record is walked depth-first, in list order and dict insertion order.
    When a dict key equals ``segment_type``, its value is yielded as-is and is
    *not* searched further. Segments nested inside a match are therefore
    not reported separately. Non-container values are leaves: they are only
    yielded when they sit directly under a matching key.
    """
    if isinstance(parse_result, list):
        for element in parse_result:
            yield from get_json_segment(element, segment_type)
        return
    if not isinstance(parse_result, dict):
        return
    for key, value in parse_result.items():
        if key == segment_type:
            yield value
        elif isinstance(value, (list, dict)):
            # Only containers can hold further segments; raw token strings are leaves.
            yield from get_json_segment(value, segment_type)
def _first_children(node: Any) -> Dict[str, Any]:
    """Map each child segment type of *node* to its first child of that type.

    Parse-record nodes come in two shapes: a dict keyed by child type, or a
    list of single-key dicts (presumably used by sqlfluff when a child type
    repeats, e.g. several ``whitespace`` tokens). Normalising both shapes to
    a dict lets callers use ``.get`` uniformly. Any other value (a raw
    string, ``None``) yields an empty mapping.
    """
    if isinstance(node, dict):
        return node
    children: Dict[str, Any] = {}
    if isinstance(node, list):
        for child in node:
            if isinstance(child, dict):
                for child_type, content in child.items():
                    # Keep the first occurrence of each child type.
                    children.setdefault(child_type, content)
    return children


def _describe_from_element(from_element: Any) -> str:
    """Render one ``from_expression_element`` as an ``"<alias>: <table>"`` entry.

    ``<table>`` is the concatenated raw text of the ``table_reference``
    (e.g. ``dbo.Market``). ``<alias>`` is the alias identifier. When there
    is no alias, it falls back to the last identifier of the table name,
    which is how SQL refers to an unaliased table.

    Returns an empty string when the element has no ``table_reference``
    (for example a derived table / subquery), since there is no table to name.
    """
    element = _first_children(from_element)
    table_reference = _first_children(element.get('table_expression')).get('table_reference')
    if table_reference is None:
        return ''
    # A single-part name is dict-shaped; a dotted name is a list of parts.
    parts = table_reference if isinstance(table_reference, list) else [table_reference]
    name_pieces: List[str] = []
    bare_name = ''
    for part in parts:
        if not isinstance(part, dict):
            continue
        for part_type, text in part.items():
            if not isinstance(text, str):
                continue
            name_pieces.append(text)
            if part_type.endswith('identifier'):
                bare_name = text
    table_name = ''.join(name_pieces)
    if not table_name:
        return ''
    alias_children = _first_children(element.get('alias_expression'))
    alias = (
        alias_children.get('naked_identifier')
        or alias_children.get('quoted_identifier')
        or bare_name
    )
    return f"{alias}: {table_name}"


def extract_aliases(from_clause: Union[Dict[str, Any], List[Any]]) -> List[str]:
    """Return ``"<alias>: <table>"`` strings for the tables named in *from_clause*.

    Each ``from_expression_element`` located by ``get_json_segment`` is
    rendered by ``_describe_from_element``. Elements that do not reference a
    named table (e.g. derived tables) are skipped instead of raising.
    ``get_json_segment`` does not search inside a matched element, so tables
    inside a subquery nested in a FROM element are not reported here.
    """
    aliases = []
    for from_element in get_json_segment(from_clause, "from_expression_element"):
        entry = _describe_from_element(from_element)
        if entry:
            aliases.append(entry)
    return aliases
def _ordered_children(node: Any) -> List[Tuple[str, Any]]:
    """Return ``(segment_type, content)`` pairs for the children of *node*, in order.

    Handles both node shapes in the parse record: a dict keyed by child type
    and a list of single-key dicts. Other values (raw strings, ``None``)
    have no children and yield an empty list.
    """
    if isinstance(node, dict):
        return list(node.items())
    if isinstance(node, list):
        return [
            (child_type, content)
            for child in node
            if isinstance(child, dict)
            for child_type, content in child.items()
        ]
    return []


def _equi_join_edges(expression: Any) -> List[str]:
    """Turn ``a.x = b.y`` terms of one expression into ``"a.x -> b.y"`` edges.

    The expression's children are split into terms at every
    ``binary_operator`` child (e.g. ``AND``). Within a term, only column
    references and comparison operators are collected. A term produces an
    edge when what was collected is exactly ``<column> = <column>`` with
    both column names non-empty. Everything else (non-equality operators,
    comparisons against literals, and so on) is ignored.
    """
    terms: List[List[str]] = [[]]
    for child_type, content in _ordered_children(expression):
        if child_type == 'binary_operator':
            terms.append([])
        elif child_type == 'column_reference':
            # A bare column is dict-shaped and a qualified one list-shaped; both
            # go through _ordered_children. Only naked_identifier parts are
            # kept, so quoted parts such as [Col] are dropped.
            # NOTE(review): confirm that is acceptable for T-SQL inputs.
            terms[-1].append('.'.join(
                text
                for part_type, text in _ordered_children(content)
                if part_type == 'naked_identifier' and isinstance(text, str)
            ))
        elif child_type == 'comparison_operator':
            # Multi-character operators (>=, <>) may arrive as several raw parts.
            terms[-1].append(''.join(
                text
                for part_type, text in _ordered_children(content)
                if part_type == 'raw_comparison_operator' and isinstance(text, str)
            ))
    return [
        f"{term[0]} -> {term[2]}"
        for term in terms
        if len(term) == 3 and term[1] == '=' and term[0] and term[2]
    ]


def extract_comparisons(expression: Union[Dict[str, Any], List[Any]]) -> List[str]:
    """Return ``"left.col -> right.col"`` edges for equality comparisons in *expression*.

    Every ``expression`` segment located by ``get_json_segment`` is scanned
    by ``_equi_join_edges``. The edges are returned in the order they are
    found. ``get_json_segment`` does not search inside a matched
    expression, so comparisons nested in parentheses within it are not seen.
    """
    comparisons = []
    for exp in get_json_segment(expression, "expression"):
        comparisons.extend(_equi_join_edges(exp))
    return comparisons
if __name__ == "__main__":
    # Usage: python <script> [path/to/file.sql]   (defaults to test/vwMarket.sql)
    # stdout carries only the D2 diagram text; progress messages go to stderr so
    # the output can be redirected straight into a .d2 file.
    file_path = sys.argv[1] if len(sys.argv) > 1 else 'test/vwMarket.sql'
    print(f"Parsing file: {file_path}", file=sys.stderr)
    parse_result = parse_sql_file(file_path)
    # Collect the FROM clauses once; both passes below walk the same list.
    from_clauses = list(get_json_segment(parse_result, "from_clause"))
    # Table declarations: one D2 sql_table shape per alias.
    for from_clause in from_clauses:
        for alias in extract_aliases(from_clause):
            print(f"{alias} {{shape: sql_table}}")
    # Column-to-column edges taken from equality comparisons.
    for from_clause in from_clauses:
        for comparison in extract_comparisons(from_clause):
            print(comparison)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment