Created
November 2, 2023 06:43
-
-
Save ryancollingwood/1bc64520adac3f6c53edfaf330ff4922 to your computer and use it in GitHub Desktop.
For a column that is a composite of other columns in a SQL query (subqueries or CTEs), get the SQL that makes up that column.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Dict | |
import re | |
from dataclasses import dataclass, field | |
import sqlparse | |
@dataclass
class SQLColumnLineage:
    """Find the SQL fragment(s) that define a target column in a query.

    The query is parsed with ``sqlparse`` and its token tree is walked.
    Every identifier whose real name or alias equals ``target_column``
    (case-insensitively) contributes its SQL text to the result.
    """

    # Full SQL text to analyse.
    sql: str
    # Column name or alias to look for; compared case-insensitively.
    target_column: str
    # Lower-cased real name -> alias, recorded as identifiers are visited.
    # The mapping persists on the instance across calls.
    aliased_columns: Dict[str, str] = field(default_factory=dict)

    def get_columns_referenced(self, token):
        """Collect the matching SQL fragments found in ``token`` and below it.

        Args:
            token: A ``sqlparse`` token or token group.

        Returns:
            A set of SQL strings, one per identifier that matches
            ``target_column``. Empty for whitespace tokens.
        """
        found = set()
        if token.is_whitespace:
            return found

        if isinstance(token, sqlparse.sql.IdentifierList):
            # Only the listed identifiers themselves are inspected here;
            # their child tokens are not descended into.
            for identifier in token.get_identifiers():
                found |= self.extract_column_references(identifier)
            return found

        if isinstance(token, sqlparse.sql.Identifier):
            found |= self.extract_column_references(token)

        # Identifiers, parentheses, statements and any other grouping token
        # are searched recursively; leaf tokens have no ``tokens`` attribute.
        for child in getattr(token, "tokens", ()):
            found |= self.get_columns_referenced(child)
        return found

    def _resolve_aliases(self, parts):
        """Substitute recorded aliases into ``parts`` until nothing changes.

        Substitution is repeated because aliases can be chained. If the
        recorded aliases form a cycle (for example ``a AS b`` together with
        ``b AS a``) there is no stable result, so the fragment is returned
        exactly as written instead of looping forever.
        """
        original = parts
        seen = {tuple(parts)}
        while True:
            substituted = [
                self.aliased_columns.get(part.lower(), part) for part in parts
            ]
            if substituted == parts:
                return parts
            state = tuple(substituted)
            if state in seen:
                return original
            seen.add(state)
            parts = substituted

    def extract_column_references(self, token):
        """Return the SQL text of ``token`` if it names the target column.

        As a side effect, a plain-name identifier that carries an alias is
        recorded in ``aliased_columns``.

        Args:
            token: An object offering ``get_real_name()``, ``get_alias()``,
                ``tokens`` and ``flatten()`` (a ``sqlparse`` identifier).

        Returns:
            A set holding at most one SQL string; empty when the token has
            no usable real name or does not match ``target_column``.
        """
        result = set()
        try:
            referenced_column = token.get_real_name().lower()
        except AttributeError:
            # Either not an identifier-like token, or it has no real name.
            return result

        alias = token.get_alias() or ""

        # Record real-name -> alias only when an alias actually exists.
        # Storing an empty alias would later blank that name out of every
        # fragment it appears in. Only plain Name tokens qualify, so that a
        # function or keyword is never treated as an aliased column.
        if (
            alias
            and alias.lower() != referenced_column
            and str(token.tokens[0].ttype) == "Token.Name"
        ):
            self.aliased_columns[referenced_column] = alias

        target = self.target_column.lower()
        if target not in (referenced_column, alias.lower()):
            return result

        try:
            parts = [str(piece) for piece in token.flatten()]
        except AttributeError:
            # No child tokens available: use the token's own SQL text.
            result.add(token.value)
            return result

        result.add("".join(self._resolve_aliases(parts)))
        return result

    def find_columns_in_query(self):
        """Return the SQL that makes up ``target_column`` in ``self.sql``.

        Returns:
            A single string when exactly one fragment is found; otherwise a
            list of strings (possibly empty), sorted so the order is stable.
            Whitespace runs in each fragment are collapsed to one space.
        """
        found = set()
        for statement in sqlparse.parse(self.sql):
            for token in statement.tokens:
                found |= self.get_columns_referenced(token)

        target = self.target_column.lower()
        self_reference = f"{target} as {target}"

        fragments = sorted(re.sub(r"\s+", " ", raw).strip() for raw in found)
        # Drop trivial "col as col" self references.
        fragments = [
            fragment for fragment in fragments
            if fragment.lower() != self_reference
        ]

        if len(fragments) == 1:
            return fragments[0]
        return fragments
if __name__ == "__main__":
    # Demo: print the SQL expression that builds the composite key column.
    sql = """SELECT *, CONCAT( COALESCE( cast(trim(SiteID) as string) , '') ,'--' ,COALESCE( cast(trim(SystemOfOriginID) as string), '') ) as zz_meta_key FROM reference_sites"""
    target_column = "zz_meta_key"
    finder = SQLColumnLineage(
        sql=sql,
        target_column=target_column,
    )
    # Get columns referenced in the SQL query for the specified column
    referenced_columns = finder.find_columns_in_query()
    # find_columns_in_query() returns a bare string when there is exactly
    # one result; joining a string would interleave ", " between its
    # characters, so wrap it in a list first.
    if isinstance(referenced_columns, str):
        referenced_columns = [referenced_columns]
    print(f"Columns that make up {target_column}: {', '.join(referenced_columns)}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Leverages https://pypi.org/project/sqlparse/ to do the heavy lifting