Created August 8, 2024 21:19
-
-
Save angusdev/9881e5822fbcc068c3dbac798c249ae1 to your computer and use it in GitHub Desktop.
SQL Warehouse Data Source
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example usage of the SqlWarehouseDataSource class.
# NOTE(review): SqlWarehouseDataSource must be defined or imported in this module;
# the import is not shown here.
if __name__ == "__main__":
    # Placeholder connection settings -- replace with real values before running.
    connection_settings = {
        "host": "your-databricks-workspace-host",
        "http_path": "your-http-path",
        "client_id": "your-client-id",
        "client_secret": "your-client-secret",
        "tenant_id": "your-tenant-id",
    }

    # token_expiry_threshold defaults to 300 seconds (5 minutes). To refresh earlier,
    # pass e.g. token_expiry_threshold=600 (10 minutes) as an extra keyword argument.
    data_source = SqlWarehouseDataSource(**connection_settings)

    print(data_source.execute_query("SELECT * FROM your_table LIMIT 10"))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime, timedelta, timezone
from typing import Any

import pandas as pd
import requests
from databricks import sql
class SqlWarehouseDataSource:
    """Run SQL queries on a Databricks SQL warehouse and return pandas DataFrames.

    Authentication uses the Azure AD OAuth 2.0 client-credentials grant. The access
    token is cached on the instance and refreshed before a query whenever it is
    missing or within ``token_expiry_threshold`` seconds of expiring.
    """

    def __init__(
        self,
        host: str,
        http_path: str,
        client_id: str,
        client_secret: str,
        tenant_id: str,
        token_expiry_threshold: int = 300,
        request_timeout: float = 30.0,
    ):
        """
        Initialize the SqlWarehouseDataSource class with Databricks SQL connection details.

        No network calls are made here; the first token is fetched lazily by
        ``execute_query``.

        Parameters:
        - host (str): Databricks workspace hostname.
        - http_path (str): HTTP path for the SQL warehouse.
        - client_id (str): Azure AD client ID.
        - client_secret (str): Azure AD client secret.
        - tenant_id (str): Azure AD tenant ID.
        - token_expiry_threshold (int): Refresh the token once it is within this many
          seconds of expiring. Default is 300 seconds (5 minutes).
        - request_timeout (float): Timeout in seconds for the Azure AD token request.
          Default is 30 seconds.
        """
        self.host = host
        self.http_path = http_path
        self.client_id = client_id
        self.client_secret = client_secret
        self.tenant_id = tenant_id
        self.token_expiry_threshold = token_expiry_threshold
        self.request_timeout = request_timeout
        # Cached token and its expiry (a timezone-aware UTC datetime); both stay None
        # until the first token is fetched.
        self.access_token = None
        self.token_expiration = None

    @staticmethod
    def _parse_token_response(data: dict) -> tuple:
        """Extract ``(access_token, expires_in_seconds)`` from a token-endpoint JSON body.

        ``expires_in`` defaults to 3600 seconds when absent. It is coerced to ``int``
        because it may be delivered as a string.

        Raises:
            RuntimeError: If the body has no (or an empty) ``access_token``.
        """
        access_token = data.get("access_token")
        if not access_token:
            # Only list the keys: the body could contain sensitive values.
            raise RuntimeError(
                f"Token response did not contain an access_token (keys: {sorted(data)})"
            )
        expires_in = int(data.get("expires_in", 3600))
        return access_token, expires_in

    @staticmethod
    def _to_dataframe(rows, description) -> pd.DataFrame:
        """Build a DataFrame from fetched rows, naming columns from ``cursor.description``.

        Returns an empty DataFrame when ``description`` is ``None``, i.e. the
        statement produced no result set.
        """
        if description is None:
            return pd.DataFrame()
        column_names = [column[0] for column in description]
        return pd.DataFrame(rows, columns=column_names)

    def __get_access_token(self) -> str:
        """Request a new access token from Azure AD and cache it on the instance.

        Returns:
            The new access token.

        Raises:
            RuntimeError: If the endpoint does not answer with HTTP 200 or the response
                lacks an ``access_token``.
            requests.RequestException: On network failure or when the request exceeds
                ``request_timeout`` seconds.
        """
        token_url = f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/v2.0/token"
        headers = {
            "Content-Type": "application/x-www-form-urlencoded"
        }
        payload = {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "scope": "https://databricks.azure.net/.default"
        }
        # Without a timeout, requests may block indefinitely on an unresponsive endpoint.
        response = requests.post(
            token_url, headers=headers, data=payload, timeout=self.request_timeout
        )
        if response.status_code != 200:
            raise RuntimeError(f"Failed to obtain access token: {response.text}")
        access_token, expires_in = self._parse_token_response(response.json())
        self.access_token = access_token
        self.token_expiration = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
        return self.access_token

    def __refresh_token_if_needed(self):
        """Fetch a new token if none is cached or the cached one expires within the threshold."""
        if self.access_token is None or self.token_expiration is None:
            self.__get_access_token()
            return
        refresh_at = self.token_expiration - timedelta(seconds=self.token_expiry_threshold)
        if datetime.now(timezone.utc) >= refresh_at:
            self.__get_access_token()

    def execute_query(self, query: str, parameters: Any = None) -> pd.DataFrame:
        """Execute ``query`` on the SQL warehouse and return the results as a DataFrame.

        Parameters:
        - query (str): SQL text to execute.
        - parameters: Optional query parameters, passed as the second argument of
          ``cursor.execute``. Prefer this to formatting values into ``query`` yourself.
          When omitted, ``cursor.execute(query)`` is called exactly as before.

        Returns:
        - pd.DataFrame: One column per result column. If the statement produced no
          result set, an empty DataFrame is returned.
        """
        self.__refresh_token_if_needed()
        # Open a connection and cursor for this query only; both close on exit.
        with sql.connect(
            server_hostname=self.host,
            http_path=self.http_path,
            access_token=self.access_token
        ) as connection:
            with connection.cursor() as cursor:
                if parameters is None:
                    cursor.execute(query)
                else:
                    cursor.execute(query, parameters)
                description = cursor.description
                # description is None for statements that return no rows (e.g. DDL),
                # so there is nothing to fetch.
                rows = cursor.fetchall() if description is not None else []
        return self._to_dataframe(rows, description)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.