Created October 7, 2024 19:08
-
-
Save Hugoberry/d4e362a65678b62365800f634685e59a to your computer and use it in GitHub Desktop.
Use `sqlfluff` to build a lineage tree in D2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import sys
from typing import Any, Dict, Iterator, List, Tuple, Union

import sqlfluff
# Parses the given SQL file and returns the parsing result
def parse_sql_file(file_path: str, dialect: str = "tsql") -> Dict[str, Any]:
    """Parse a SQL file with sqlfluff and return its parse record.

    Args:
        file_path: Path to the SQL file. It is read with the ``utf-8-sig``
            codec, so a leading UTF-8 byte-order mark is dropped.
        dialect: sqlfluff dialect used for parsing (default ``"tsql"``).

    Returns:
        The record returned by ``sqlfluff.parse`` (nested dicts/lists), or an
        empty dict when sqlfluff reports parsing errors.

    Raises:
        OSError / UnicodeDecodeError: propagated from opening or reading the
            file.

    Status messages are written to stderr so that stdout carries only the
    generated D2 text.
    """
    with open(file_path, 'r', encoding='utf-8-sig') as file:
        sql_content = file.read()
    try:
        parsed = sqlfluff.parse(sql_content, dialect=dialect)
    except sqlfluff.api.simple.APIParsingError as e:
        # sqlfluff's APIParsingError lists the parse violations but carries no
        # partial parse tree, so there is nothing partial to fall back to:
        # report the violations and return an empty record.
        print(
            f"Warning: Unable to parse the SQL file {file_path}; "
            f"no lineage will be extracted.\n{e}",
            file=sys.stderr,
        )
        return {}
    print(f"Successfully parsed file: {file_path}", file=sys.stderr)
    return parsed
# Recursively searches for segments of the given type in the parse result
def get_json_segment(parse_result: Union[Dict[str, Any], List[Any]], segment_type: str) -> Iterator[Union[str, Dict[str, Any], List[Dict[str, Any]]]]:
    """Yield every value stored under the key *segment_type* in a parse record.

    The walk is depth-first, visiting dict entries and list elements in their
    stored order. When a key matches, its value is yielded and is NOT searched
    any further, so a segment nested inside a matching segment of the same
    type is not reported separately. Scalars (strings, numbers, None) yield
    nothing.
    """
    if isinstance(parse_result, list):
        for element in parse_result:
            yield from get_json_segment(element, segment_type)
        return
    if not isinstance(parse_result, dict):
        return
    for key, value in parse_result.items():
        if key == segment_type:
            yield value
            continue
        if isinstance(value, (dict, list)):
            yield from get_json_segment(value, segment_type)
# Helper function to extract aliases from FROM clauses
def extract_aliases(from_clause: Union[Dict[str, Any], List[Any]]) -> List[str]:
    """Describe each table of a FROM clause as an ``"alias: table_name"`` string.

    Every ``from_expression_element`` under *from_clause* (located with
    ``get_json_segment``) contributes at most one entry, in source order.
    *table_name* is the concatenation of the raw parts of the element's
    ``table_reference`` (e.g. ``dbo.Market``).

    * If the element has no ``naked_identifier`` alias, the last identifier of
      the table name is used as the key (``Market`` for ``dbo.Market``), which
      is the name SQL exposes for an unaliased table, instead of emitting an
      empty key such as ``": dbo.Market"``.
    * Elements without a ``table_reference`` (e.g. a derived table) are
      skipped instead of failing on ``None``.

    Args:
        from_clause: A parse-record subtree, normally a ``from_clause``.

    Returns:
        One ``"alias: table_name"`` string per table reference found.
    """

    def _children(node: Any) -> List[Dict[str, Any]]:
        # A segment's children appear either as one dict or, when a child
        # type repeats (e.g. the two identifiers of ``dbo.Market``), as a list
        # of single-key dicts. Normalise both forms; anything else -> [].
        if isinstance(node, dict):
            return [{key: value} for key, value in node.items()]
        if isinstance(node, list):
            return [child for child in node if isinstance(child, dict)]
        return []

    def _first(node: Any, key: str) -> Any:
        # Value of the first child stored under *key*, or None when absent.
        for child in _children(node):
            if key in child:
                return child[key]
        return None

    aliases: List[str] = []
    for from_element in get_json_segment(from_clause, "from_expression_element"):
        table_reference = _first(_first(from_element, 'table_expression'), 'table_reference')
        parts = _children(table_reference)
        table_name = ''.join(str(value) for part in parts for value in part.values())
        if not table_name:
            # Not a plain table (subquery, table-valued function, ...).
            continue
        alias = _first(_first(from_element, 'alias_expression'), 'naked_identifier') or ''
        if not alias:
            identifiers = [
                value
                for part in parts
                for key, value in part.items()
                if key in ('naked_identifier', 'quoted_identifier')
            ]
            alias = identifiers[-1] if identifiers else table_name
        aliases.append(f"{alias}: {table_name}")
    return aliases
# Helper function to extract comparisons from expressions
def extract_comparisons(expression: Union[Dict[str, Any], List[Any]]) -> List[str]:
    """Collect column-to-column equality comparisons as ``"left -> right"`` strings.

    Each ``expression`` segment under *expression* (located with
    ``get_json_segment``) is scanned left to right. Column references and
    comparison operators are accumulated into a run, and a
    ``binary_operator`` (e.g. AND/OR) closes the run. A run of exactly
    ``column, "=", column`` becomes ``"a.col -> b.col"``. The same rule
    applies to the last run of an expression, so non-equality comparisons
    are never emitted. Non-column operands such as literals are not
    collected, so ``a.x = 1`` never forms a complete run.

    NOTE(review): ``get_json_segment`` does not descend into a matched
    segment, so an ``expression`` nested inside another matched
    ``expression`` is not scanned on its own.

    Args:
        expression: A parse-record subtree, e.g. a ``from_clause``.

    Returns:
        The comparison strings in source order.
    """

    def _children(node: Any) -> List[Dict[str, Any]]:
        # A segment's children appear either as one dict or, when a child
        # type repeats, as a list of single-key dicts. Normalise both forms so
        # that a dict is never iterated key-by-key; anything else -> [].
        if isinstance(node, dict):
            return [{key: value} for key, value in node.items()]
        if isinstance(node, list):
            return [child for child in node if isinstance(child, dict)]
        return []

    def _flush(run: List[str], out: List[str]) -> None:
        # Store a complete "column = column" run.
        if len(run) == 3 and run[1] == '=':
            out.append(f"{run[0]} -> {run[2]}")

    comparisons: List[str] = []
    for exp in get_json_segment(expression, "expression"):
        current_comparison: List[str] = []
        for item in _children(exp):
            if 'column_reference' in item:
                column = '.'.join(
                    str(part['naked_identifier'])
                    for part in _children(item['column_reference'])
                    if 'naked_identifier' in part
                )
                current_comparison.append(column)
            elif 'comparison_operator' in item:
                # Multi-character operators such as ">=" or "<>" can arrive
                # as several raw_comparison_operator parts; join them.
                operator = ''.join(
                    str(part['raw_comparison_operator'])
                    for part in _children(item['comparison_operator'])
                    if 'raw_comparison_operator' in part
                )
                current_comparison.append(operator)
            elif 'binary_operator' in item:
                _flush(current_comparison, comparisons)
                current_comparison = []
        # The last run of the expression obeys the same "=" rule.
        _flush(current_comparison, comparisons)
    return comparisons
if __name__ == "__main__":
    # Usage: python <script> [path/to/file.sql]  (defaults to test/vwMarket.sql)
    # Prints a D2 diagram to stdout: one sql_table shape per table found by
    # extract_aliases, then the comparison lines from extract_comparisons.
    # Progress messages go to stderr so stdout can be redirected straight
    # into a .d2 file.
    file_path = sys.argv[1] if len(sys.argv) > 1 else 'test/vwMarket.sql'
    print(f"Parsing file: {file_path}", file=sys.stderr)
    parse_result = parse_sql_file(file_path)
    # Both passes below walk the same FROM clauses; collect them once.
    from_clauses = list(get_json_segment(parse_result, "from_clause"))
    # Extract and print aliases
    for from_clause in from_clauses:
        for alias in extract_aliases(from_clause):
            print(f"{alias} {{shape: sql_table}}")
    # Extract and print comparisons
    for from_clause in from_clauses:
        for comparison in extract_comparisons(from_clause):
            print(comparison)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment