@chavarera
Created December 13, 2023 10:13
import pandas as pd
from sqlglot import parse_one, exp, optimizer
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.errors import OptimizeError
# Define the SQL query string
sql_query = """
WITH SalesSummary AS (
    SELECT p.Product_Name, SUM(s.Quantity) AS Total_Sales
    FROM Products p
    JOIN Sales s ON p.Product_ID = s.Product_ID
    WHERE s.Sale_Date BETWEEN TO_DATE('2021-01-01', 'YYYY-MM-DD') AND TO_DATE('2021-12-31', 'YYYY-MM-DD')
    GROUP BY p.Product_Name
)
SELECT ss.Product_Name, ss.Total_Sales, c.Category_Name
FROM SalesSummary ss
JOIN Categories c ON ss.Product_Name = c.Product_Name
ORDER BY ss.Total_Sales DESC;
"""
def qualify_columns(expression, schema):
    """
    Attempt to qualify tables and columns in a SQL expression.

    Args:
        expression: The sqlglot expression to qualify.
        schema: The schema to use for qualification (optional).

    Returns:
        The qualified expression (unchanged if qualification fails).
    """
    try:
        expression = optimizer.qualify_tables.qualify_tables(expression)
        expression = optimizer.isolate_table_selects.isolate_table_selects(expression)
        expression = optimizer.qualify_columns.qualify_columns(expression, schema)
    except OptimizeError:
        # Qualification is best-effort: fall back to the partially qualified expression.
        pass
    return expression
def parse_statement(sql_query, dialect):
    """
    Parse a SQL statement and extract table, column, and condition information.

    Args:
        sql_query: The SQL statement string.
        dialect: The dialect of the SQL statement (e.g., "mysql").

    Returns:
        A list of dictionaries containing table, column, and condition data.
    """
    ast = parse_one(sql_query, read=dialect)
    ast = qualify_columns(ast, schema=None)

    # Extract WHERE clauses using `find_all` and format them
    conditions = [ele.sql(pretty=True) for ele in ast.find_all(exp.Where)]

    # Collect columns that resolve to physical tables (columns from CTEs are skipped)
    physical_columns = []
    for scope in traverse_scope(ast):
        for c in scope.columns:
            if isinstance(scope.sources.get(c.table), exp.Table):
                physical_columns.append({
                    "table_name": scope.sources.get(c.table).name,
                    "column_name": c.name,
                    "conditions": conditions,
                })
    return physical_columns
# Parse the SQL query and convert to a Pandas DataFrame
result = parse_statement(sql_query, 'mysql')
df = pd.DataFrame(result)
# Save the DataFrame to a CSV file
df.to_csv('sample.csv', index=False)
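
# Optional: a minimal sketch of passing a concrete schema instead of schema=None,
# so sqlglot can resolve columns that are not table-qualified in the query.
# The table and column names mirror the query above; the column types are assumptions,
# and the nested {table: {column: type}} dict is the schema shape sqlglot's optimizer
# helpers generally accept. Verify against your installed sqlglot version.
example_schema = {
    "Products": {"Product_ID": "INT", "Product_Name": "VARCHAR"},
    "Sales": {"Product_ID": "INT", "Quantity": "INT", "Sale_Date": "DATE"},
    "Categories": {"Product_Name": "VARCHAR", "Category_Name": "VARCHAR"},
}
qualified_ast = qualify_columns(parse_one(sql_query, read="mysql"), schema=example_schema)
print(qualified_ast.sql(pretty=True))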