Skip to content

Instantly share code, notes, and snippets.

@dumkydewilde
Last active November 2, 2023 12:01
Show Gist options
  • Save dumkydewilde/d7051ad3175856a3234792390b4927c5 to your computer and use it in GitHub Desktop.
dbt databricks: Auto generate model files (.sql) with source and column names
# You can run this script to generate the SQL files for the models and fetch columns names from Databricks.
# Run it on the command line, like this:
# python utils/generate_models.py \
# -yml models/staging/salesforce/_source_salesforce.yml \
# -o models/staging/salesforce
# Or see all options with python utils/generate_models.py --help
# Standard library
import argparse
import os
import time

# Third-party
import requests
import yaml

# Databricks CLI SDK
from databricks_cli.sdk.api_client import ApiClient
from databricks_cli.workspace.api import WorkspaceApi
# --- CLI arguments -----------------------------------------------------------
parser = argparse.ArgumentParser(description="Process CLI arguments for the script.")
parser.add_argument("--input_yaml", "-yml", help="Path to the input YAML file.")
parser.add_argument("--prefix", default="", help="Prefix to remove from the source table.")
parser.add_argument("--source_name", default="", help="Name of the source to generate files for only a single source.")
parser.add_argument("--output_path", "-o", help="Folder path for the output files.")
args = parser.parse_args()

INPUT_YAML = args.input_yaml
PREFIX = args.prefix
SOURCE_NAME = args.source_name
OUTPUT_PATH = args.output_path
print(f"INPUT_YAML: {INPUT_YAML}")
print(f"PREFIX: {PREFIX}")
print(f"SOURCE_NAME: {SOURCE_NAME}")
print(f"OUTPUT_PATH: {OUTPUT_PATH}")

# --- Databricks connection settings (from environment) -----------------------
# Fail fast with a clear message instead of the cryptic AttributeError the
# original raised on `.split()` when DATABRICKS_HTTP_PATH was unset.
_host = os.environ.get("DATABRICKS_HOST")
_http_path = os.environ.get("DATABRICKS_HTTP_PATH")
if not _host or not _http_path:
    raise SystemExit(
        "The DATABRICKS_HOST and DATABRICKS_HTTP_PATH environment variables must be set."
    )
DATABRICKS_HOST = f"https://{_host}"
DATABRICKS_TOKEN = os.environ.get("DATABRICKS_TOKEN")
# The warehouse id is the last segment of the SQL endpoint HTTP path,
# e.g. /sql/1.0/warehouses/<warehouse_id>.
DATABRICKS_WAREHOUSE = _http_path.split("/")[-1]
def get_columns_from_databricks(schema, table_name):
    """Return the column names of ``schema.table_name`` via the Databricks
    SQL Statement Execution API.

    Submits ``show columns from <schema>.<table_name>`` to the configured
    SQL warehouse and polls until the statement reaches a terminal state.

    Args:
        schema: Databricks schema (database) containing the table.
        table_name: Name of the table to inspect.

    Returns:
        A list of column-name strings, or an empty list if the statement
        failed, was cancelled, or did not finish within the polling budget.
    """
    client = ApiClient(host=DATABRICKS_HOST, token=DATABRICKS_TOKEN)
    data = {
        "warehouse_id": DATABRICKS_WAREHOUSE,
        "statement": f"show columns from {schema}.{table_name}",
        "wait_timeout": "30s",
        "on_wait_timeout": "CONTINUE",
    }
    print(
        f"Fetching columns for {table_name}. Params: {data}"
    )
    resp = client.perform_query("POST", "/sql/statements", data=data)

    # Poll until the statement reaches a terminal state. The Statement
    # Execution API reports the statement id as "statement_id" (the
    # original code read resp["id"], which KeyErrors on current API
    # responses); keep "id" as a fallback for older response shapes.
    # Polling is bounded so a stuck statement cannot spin forever.
    terminal_states = {"FINISHED", "FAILED", "CANCELED", "SUCCEEDED"}
    max_polls = 60  # ~2 minutes at 2s per poll
    polls = 0
    while resp["status"]["state"] not in terminal_states and polls < max_polls:
        time.sleep(2)  # Wait for a short duration before polling again
        statement_id = resp.get("statement_id", resp.get("id"))
        resp = client.perform_query("GET", f"/sql/statements/{statement_id}")
        polls += 1

    if resp["status"]["state"] in ["FINISHED", "SUCCEEDED"]:
        # Each row of `show columns` carries the column name in field 0.
        return [row[0] for row in resp["result"]["data_array"]]

    print(
        f"Error fetching columns for {table_name}. State: {resp['status']['state']}"
    )
    return []
# --- Generate one dbt model .sql file per source table -----------------------
with open(INPUT_YAML) as file:
    # safe_load is sufficient for a dbt sources file and avoids executing
    # arbitrary YAML tags (yaml.load with FullLoader is broader than needed).
    data = yaml.safe_load(file)

if "sources" in data:
    for source in data["sources"]:
        # Process every source unless a specific --source_name was requested.
        if SOURCE_NAME and source["name"] != SOURCE_NAME:
            continue
        for table in source["tables"]:
            table_name = table["name"]
            # Strip PREFIX only when it is an actual prefix — the original
            # .replace(PREFIX, "") also removed it from the middle of the
            # name — then double underscores per the dbt naming convention.
            if PREFIX and table_name.startswith(PREFIX):
                stripped_name = table_name[len(PREFIX):]
            else:
                stripped_name = table_name
            formatted_table_name = stripped_name.replace("_", "__")
            # Use the source's own name so the file is named correctly even
            # when no --source_name filter was given (SOURCE_NAME == "",
            # which previously produced files like std___table.sql).
            file_name = (
                f"{OUTPUT_PATH}/std_{source['name']}__{formatted_table_name}.sql"
            )
            # Fetch columns from Databricks
            columns = get_columns_from_databricks(source["schema"], table_name)
            columns_string = ",\n".join([f" {column}" for column in columns])
            # Construct the SQL content: a source CTE, a renamed CTE listing
            # every column, and a final select.
            sql_content = (
                ""
                + "with source as (\n"
                + f" select * from {{{{ source('{source['name']}', '{table_name}') }}}}\n"
                + "),\n\n"
                + "renamed as (\n"
                + " select\n"
                + columns_string
                + "\n\n from source\n"
                + ")\n\n"
                + "select * from renamed"
            )
            # Write to the SQL file
            with open(file_name, "w") as file:
                file.write(sql_content.strip())
            print(f"File '{file_name}' created with columns from Databricks.")
print("Script completed.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment