Skip to content

Instantly share code, notes, and snippets.

@angusdev
Created August 8, 2024 21:19
Show Gist options
  • Save angusdev/9881e5822fbcc068c3dbac798c249ae1 to your computer and use it in GitHub Desktop.
Save angusdev/9881e5822fbcc068c3dbac798c249ae1 to your computer and use it in GitHub Desktop.
SQL Warehouse Data Source
# Example of how to use the SqlWarehouseDataSource class.
# NOTE(review): this example references SqlWarehouseDataSource, which is only
# defined further down in this listing; it presumably lives in a separate file
# of the gist that imports the class. If both are kept in one module, move this
# block to the end of the file, otherwise running it raises NameError.
if __name__ == "__main__":
    import os

    # Read connection details from the environment so secrets are not
    # hard-coded in source; the placeholders are used when a variable is unset.
    host = os.environ.get("DATABRICKS_HOST", "your-databricks-workspace-host")
    http_path = os.environ.get("DATABRICKS_HTTP_PATH", "your-http-path")
    client_id = os.environ.get("AZURE_CLIENT_ID", "your-client-id")
    client_secret = os.environ.get("AZURE_CLIENT_SECRET", "your-client-secret")
    tenant_id = os.environ.get("AZURE_TENANT_ID", "your-tenant-id")

    # Use default 300 seconds (5 minutes) threshold
    data_source = SqlWarehouseDataSource(host, http_path, client_id, client_secret, tenant_id)
    # Or use a custom 600 seconds (10 minutes) threshold:
    # data_source = SqlWarehouseDataSource(
    #     host, http_path, client_id, client_secret, tenant_id, token_expiry_threshold=600
    # )

    query = "SELECT * FROM your_table LIMIT 10"
    df = data_source.execute_query(query)
    print(df)
from datetime import datetime, timedelta, timezone

import pandas as pd
import requests
from databricks import sql
class SqlWarehouseDataSource:
    """Run SQL against a Databricks SQL warehouse using an Azure AD service principal.

    An OAuth2 client-credentials token is requested from Azure AD on first use
    and cached on the instance. It is re-requested once the current time is
    within ``token_expiry_threshold`` seconds of its expiry.
    """

    def __init__(
        self,
        host: str,
        http_path: str,
        client_id: str,
        client_secret: str,
        tenant_id: str,
        token_expiry_threshold: int = 300,
        request_timeout: float = 30.0,
    ):
        """
        Initialize the SqlWarehouseDataSource class with Databricks SQL connection details.

        Parameters:
        - host (str): Databricks workspace hostname.
        - http_path (str): HTTP path for the SQL warehouse.
        - client_id (str): Azure AD client ID.
        - client_secret (str): Azure AD client secret.
        - tenant_id (str): Azure AD tenant ID.
        - token_expiry_threshold (int): Threshold in seconds for token expiration check. Default is 300 seconds (5 minutes).
        - request_timeout (float): Timeout in seconds for the Azure AD token request. Default is 30 seconds.
        """
        self.host = host
        self.http_path = http_path
        self.client_id = client_id
        self.client_secret = client_secret
        self.tenant_id = tenant_id
        self.token_expiry_threshold = token_expiry_threshold
        self.request_timeout = request_timeout
        # Cached bearer token and its expiry (timezone-aware UTC datetime);
        # both stay None until the first token is fetched.
        self.access_token = None
        self.token_expiration = None

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current time as a timezone-aware UTC datetime."""
        # datetime.utcnow() is deprecated and returns a naive value.
        return datetime.now(timezone.utc)

    def __get_access_token(self) -> str:
        """Private method to get access token from Azure AD.

        Returns:
        - str: the newly cached access token.

        Raises:
        - RuntimeError: if the token endpoint answers with a non-200 status
          or its response carries no ``access_token``.
        """
        token_url = f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/v2.0/token"
        headers = {
            "Content-Type": "application/x-www-form-urlencoded",
        }
        payload = {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "scope": "https://databricks.azure.net/.default",
        }
        # Without a timeout, requests can block forever on a stalled connection.
        response = requests.post(
            token_url, headers=headers, data=payload, timeout=self.request_timeout
        )
        if response.status_code != 200:
            raise RuntimeError(f"Failed to obtain access token: {response.text}")
        self._store_token(response.json())
        return self.access_token

    def _store_token(self, data: dict) -> None:
        """Cache the access token and its expiry from a parsed token-endpoint response.

        Parameters:
        - data (dict): decoded JSON body of the token response.

        Raises:
        - RuntimeError: if ``data`` has no usable ``access_token``; the
          previously cached token is left untouched in that case.
        """
        access_token = data.get("access_token")
        if not access_token:
            raise RuntimeError("Token response did not contain an access_token")
        # expires_in may arrive as a string; default to 3600 seconds if not provided.
        expires_in = float(data.get("expires_in", 3600))
        self.access_token = access_token
        self.token_expiration = self._utc_now() + timedelta(seconds=expires_in)

    def __refresh_token_if_needed(self) -> None:
        """Private method to refresh the access token if it is expired or about to expire."""
        if (
            self.access_token is None
            # A token without a known expiry cannot be trusted; fetch a new one.
            or self.token_expiration is None
            or self._utc_now()
            >= self.token_expiration - timedelta(seconds=self.token_expiry_threshold)
        ):
            self.__get_access_token()

    def execute_query(self, query: str) -> pd.DataFrame:
        """Public method to execute the SQL query and return the results as a DataFrame.

        The result columns are named after ``cursor.description``. When the
        statement produces no result set (``cursor.description`` is None, as
        PEP 249 specifies for operations that return no rows), an empty
        DataFrame is returned instead of failing.
        """
        self.__refresh_token_if_needed()
        # The context managers close the cursor and connection even if the query fails.
        with sql.connect(
            server_hostname=self.host,
            http_path=self.http_path,
            access_token=self.access_token,
        ) as connection:
            with connection.cursor() as cursor:
                cursor.execute(query)
                if cursor.description is None:
                    return pd.DataFrame()
                column_names = [desc[0] for desc in cursor.description]
                rows = cursor.fetchall()
        return pd.DataFrame(rows, columns=column_names)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment