Created
December 13, 2023 10:13
-
-
Save chavarera/b41ab336a888935c3541332f4310119c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
import os
import time

import pandas as pd
import sqlglot
from sqlglot import parse_one, exp, optimizer
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.scope import traverse_scope
# Sample SQL statement analysed by this script: a CTE that aggregates sales
# per product, joined to the categories table.
sql_query = """
WITH SalesSummary AS (
SELECT p.Product_Name, SUM(s.Quantity) AS Total_Sales
FROM Products p
JOIN Sales s ON p.Product_ID = s.Product_ID
WHERE s.Sale_Date BETWEEN TO_DATE('2021-01-01', 'YYYY-MM-DD') AND TO_DATE('2021-12-31', 'YYYY-MM-DD')
GROUP BY p.Product_Name
)
SELECT ss.Product_Name, ss.Total_Sales, c.Category_Name
FROM SalesSummary ss
JOIN Categories c ON ss.Product_Name = c.Product_Name
ORDER BY ss.Total_Sales DESC;
"""
def qualify_columns(expression, schema=None):
    """Qualify the tables and columns of an SQL expression, best effort.

    Applies three sqlglot optimizer passes in order: ``qualify_tables``,
    ``isolate_table_selects`` and ``qualify_columns``.

    Args:
        expression: The parsed SQL expression to qualify.
        schema: Schema handed to the column-qualification pass. Optional;
            defaults to ``None``.

    Returns:
        The qualified SQL expression. If any pass raises ``OptimizeError``,
        the error is logged and whatever ``expression`` was bound to at that
        point is returned instead of propagating the failure.
    """
    try:
        expression = optimizer.qualify_tables.qualify_tables(expression)
        expression = optimizer.isolate_table_selects.isolate_table_selects(expression)
        expression = optimizer.qualify_columns.qualify_columns(expression, schema)
    except OptimizeError as err:
        # Qualification is deliberately best effort: callers still get an
        # expression back, but the failure is reported rather than hidden.
        logging.getLogger(__name__).warning(
            "Could not fully qualify SQL expression: %s", err
        )
    return expression
def parse_statement(sql_query, dialect):
    """Parse an SQL statement and list the physical-table columns it uses.

    Args:
        sql_query: The SQL statement string.
        dialect: The sqlglot dialect used to read the statement
            (e.g., "mysql").

    Returns:
        A list of dictionaries, one per column reference whose source in its
        scope is a real table (``exp.Table``). Each dictionary has the keys
        ``"table_name"``, ``"Column Name"`` and ``"conditions"``.
        ``"conditions"`` is the same list for every entry: the pretty-printed
        SQL of every WHERE clause found anywhere in the statement.
    """
    ast = parse_one(sql_query, read=dialect)
    ast = qualify_columns(ast, schema=None)

    # Collect every WHERE clause in the statement, formatted as SQL text.
    conditions = [where.sql(pretty=True) for where in ast.find_all(exp.Where)]

    # Keep only columns whose source resolves to a physical table; columns
    # coming from any other kind of source are skipped.
    physical_columns = []
    for scope in traverse_scope(ast):
        for column in scope.columns:
            source = scope.sources.get(column.table)
            if not isinstance(source, exp.Table):
                continue
            physical_columns.append({
                "table_name": source.name,
                "Column Name": column.name,
                "conditions": conditions,
            })
    return physical_columns
# Parse the SQL query and convert the extracted rows to a Pandas DataFrame.
# NOTE(review): the query calls TO_DATE(..., 'YYYY-MM-DD') while the dialect
# passed here is 'mysql' -- confirm this is the intended dialect.
result = parse_statement(sql_query, 'mysql')
df = pd.DataFrame(result)
# Save the DataFrame to a CSV file in the current working directory,
# without the index column.
df.to_csv('sample.csv', index=False)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment