dbt databricks: Auto generate model files (.sql) with source and column names
# You can run this script to generate the SQL files for the models and fetch column names from Databricks.
# Run it on the command line, like this:
#   python utils/generate_models.py \
#     -yml models/staging/salesforce/_source_salesforce.yml \
#     -o models/staging/salesforce
# Or see all options with: python utils/generate_models.py --help
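#
# A minimal sketch of the environment this script assumes (a Databricks SQL
# warehouse; the values below are hypothetical placeholders, not real
# credentials):
#   export DATABRICKS_HOST="dbc-a1b2c3d4-e5f6.cloud.databricks.com"
#   export DATABRICKS_TOKEN="dapi..."
#   export DATABRICKS_HTTP_PATH="/sql/1.0/warehouses/abc123def456"
# Note that DATABRICKS_HOST is set without the https:// scheme (the script
# prepends it), and the warehouse ID is taken from the last segment of
# DATABRICKS_HTTP_PATH.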
import argparse
import os
import time

import yaml
from databricks_cli.sdk.api_client import ApiClient
# Initialize the argument parser
parser = argparse.ArgumentParser(description="Process CLI arguments for the script.")

# Add arguments
parser.add_argument("--input_yaml", "-yml", help="Path to the input YAML file.")
parser.add_argument("--prefix", default="", help="Prefix to remove from the source table name.")
parser.add_argument("--source_name", default="", help="Name of a single source to generate files for.")
parser.add_argument("--output_path", "-o", help="Folder path for the output files.")

# Parse arguments
args = parser.parse_args()
# Use the arguments in the script
INPUT_YAML = args.input_yaml
PREFIX = args.prefix
SOURCE_NAME = args.source_name
OUTPUT_PATH = args.output_path

print(f"INPUT_YAML: {INPUT_YAML}")
print(f"PREFIX: {PREFIX}")
print(f"SOURCE_NAME: {SOURCE_NAME}")
print(f"OUTPUT_PATH: {OUTPUT_PATH}")
DATABRICKS_HOST = f"https://{os.environ.get('DATABRICKS_HOST')}"
DATABRICKS_TOKEN = os.environ.get("DATABRICKS_TOKEN")
# The warehouse ID is the last segment of the HTTP path
DATABRICKS_WAREHOUSE = os.environ.get("DATABRICKS_HTTP_PATH").split("/")[-1]
def get_columns_from_databricks(schema, table_name):
    # Initialize the Databricks API client
    client = ApiClient(host=DATABRICKS_HOST, token=DATABRICKS_TOKEN)

    # Execute a "show columns" statement via the SQL Statement Execution API
    data = {
        "warehouse_id": DATABRICKS_WAREHOUSE,
        "statement": f"show columns from {schema}.{table_name}",
        "wait_timeout": "30s",
        "on_wait_timeout": "CONTINUE",
    }
    print(f"Fetching columns for {table_name}. Params: {data}")
    resp = client.perform_query("POST", "/sql/statements", data=data)

    # Poll until the statement reaches a terminal state
    while resp["status"]["state"] not in [
        "SUCCEEDED",
        "FAILED",
        "CANCELED",
        "CLOSED",
    ]:
        time.sleep(2)  # Wait for a short duration before polling again
        resp = client.perform_query("GET", f'/sql/statements/{resp["statement_id"]}')

    # Extract column names from the result
    if resp["status"]["state"] == "SUCCEEDED":
        columns = [row[0] for row in resp["result"]["data_array"]]
        return columns
    else:
        print(
            f"Error fetching columns for {table_name}. State: {resp['status']['state']}"
        )
        return []
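# For reference, a minimal sketch of the source YAML this script expects
# (the source, schema, and table names are hypothetical examples; only the
# "name", "schema", and "tables" keys are actually read below):
#
#   sources:
#     - name: salesforce
#       schema: raw_salesforce
#       tables:
#         - name: sf_account
#         - name: sf_opportunity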
# Read the YAML file
with open(INPUT_YAML) as file:
    data = yaml.safe_load(file)

# Check if sources exist in the YAML data
if "sources" in data:
    for source in data["sources"]:
        # Process all sources, or only the one matching --source_name if given
        if source["name"] == SOURCE_NAME or SOURCE_NAME == "":
            # Iterate over the tables
            for table in source["tables"]:
                # Remove the prefix and double each underscore in the table name
                formatted_table_name = (
                    table["name"].replace(PREFIX, "").replace("_", "__")
                )
                # Use the source's own name so files are named correctly even
                # when --source_name is not passed
                file_name = (
                    f"{OUTPUT_PATH}/std_{source['name']}__{formatted_table_name}.sql"
                )

                # Fetch columns from Databricks
                columns = get_columns_from_databricks(source["schema"], table["name"])
                columns_string = ",\n".join([f"        {column}" for column in columns])
                # Construct the SQL content (doubled braces render as dbt Jinja)
                sql_content = (
                    "with source as (\n"
                    f"    select * from {{{{ source('{source['name']}', '{table['name']}') }}}}\n"
                    "),\n\n"
                    "renamed as (\n"
                    "    select\n"
                    f"{columns_string}\n\n"
                    "    from source\n"
                    ")\n\n"
                    "select * from renamed"
                )
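                # As an illustration, for a hypothetical table with columns
                # id and name, the generated file would contain:
                #
                #   with source as (
                #       select * from {{ source('salesforce', 'sf_account') }}
                #   ),
                #
                #   renamed as (
                #       select
                #           id,
                #           name
                #
                #       from source
                #   )
                #
                #   select * from renamed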
                # Write to the SQL file
                with open(file_name, "w") as file:
                    file.write(sql_content.strip())
                print(f"File '{file_name}' created with columns from Databricks.")

print("Script completed.")