Last active
May 19, 2023 00:15
-
-
Save naren-dremio/31b4a8065a3c53a458e2a5103a485106 to your computer and use it in GitHub Desktop.
For Dremio Cloud
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import certifi | |
import sys | |
from http.cookies import SimpleCookie | |
from pyarrow import flight | |
class DremioClientAuthMiddlewareFactory(flight.ClientMiddlewareFactory):
    """Factory that hands out a DremioClientAuthMiddleware per Flight call.

    The factory also holds the most recently stored call credential in
    ``call_credential``; it is replaced through ``set_call_credential``.
    """

    def __init__(self):
        # Empty until a credential is stored via set_call_credential().
        self.call_credential = list()

    def set_call_credential(self, call_credential):
        """Replace the stored credential with ``call_credential``."""
        self.call_credential = call_credential

    def start_call(self, info):
        """Create a middleware bound to this factory (``info`` is not used)."""
        middleware = DremioClientAuthMiddleware(self)
        return middleware
class DremioClientAuthMiddleware(flight.ClientMiddleware):
    """Client middleware that captures the server's authorization header.

    When response headers arrive, the first value of the ``authorization``
    header is UTF-8 encoded and passed to the owning factory through
    ``factory.set_call_credential``.

    Args:
        factory: The object whose ``set_call_credential`` method receives
            the captured credential.
    """

    def __init__(self, factory):
        self.factory = factory

    def received_headers(self, headers):
        """Extract the authorization header and store it on the factory.

        Args:
            headers: Mapping of header name to a list of header values.

        Raises:
            Exception: If no header named ``authorization`` (compared
                case-insensitively) with a non-empty value list is present.
        """
        auth_header_key = 'authorization'
        authorization_header = []
        for key in headers:
            if key.lower() == auth_header_key:
                # Look the value up with the key exactly as it appears in
                # ``headers``. Using the lowercase constant here would miss a
                # header delivered as e.g. ``Authorization`` even though the
                # comparison above just matched it.
                authorization_header = headers.get(key)
        if not authorization_header:
            raise Exception('Did not receive authorization header back from server.')
        # Stored as a two-item list: the header name and the encoded value.
        self.factory.set_call_credential([
            b'authorization', authorization_header[0].encode('utf-8')])
class CookieMiddlewareFactory(flight.ClientMiddlewareFactory):
    """Factory that hands out a CookieMiddleware per Flight call.

    ``cookies`` is a single dict owned by the factory; every middleware it
    creates receives the factory itself and therefore sees the same dict.
    """

    def __init__(self):
        # Starts with no cookies recorded.
        self.cookies = dict()

    def start_call(self, info):
        """Create a middleware bound to this factory (``info`` is not used)."""
        middleware = CookieMiddleware(self)
        return middleware
class CookieMiddleware(flight.ClientMiddleware):
    """Client middleware that records ``set-cookie`` headers and replays them.

    Cookies are kept in ``factory.cookies`` so they persist across the
    middleware instances created by the same factory.
    """

    def __init__(self, factory):
        self.factory = factory

    def received_headers(self, headers):
        """Parse every ``set-cookie`` header value into the factory's cookies."""
        for header_name in headers:
            # Header names are compared case-insensitively.
            if header_name.lower() != 'set-cookie':
                continue
            parsed = SimpleCookie()
            for raw_cookie in headers.get(header_name):
                parsed.load(raw_cookie)
            self.factory.cookies.update(parsed.items())

    def sending_headers(self):
        """Return a ``cookie`` header built from the stored cookies, or ``{}``."""
        jar = self.factory.cookies
        if not jar:
            return {}
        pairs = []
        for name, morsel in jar.items():
            pairs.append("{!s}={!s}".format(name, morsel.value))
        return {b'cookie': '; '.join(pairs).encode('utf-8')}
def connect_to_dremio_flight_server_endpoint(host, port, query,
                                             tls, pat_or_auth_token,
                                             username=None, password=None,
                                             certs=None,
                                             disable_server_verification=True,
                                             engine=None,
                                             session_properties=None):
    """
    Connects to Dremio Flight server endpoint with the provided credentials.
    It also runs the query and retrieves the result set.

    Args:
        host: Host name of the Flight endpoint.
        port: Port of the Flight endpoint.
        query: SQL text to run. When falsy, only the connection (and, for
            username/password, the authentication) is performed.
        tls: When truthy, connect with ``grpc+tls`` instead of ``grpc+tcp``.
        pat_or_auth_token: Token sent as ``Bearer`` in the ``authorization``
            header. Takes precedence over ``username``/``password``.
        username: User name for basic authentication, used only when no
            token is given.
        password: Password for basic authentication, used only when no
            token is given.
        certs: Path to a file of trusted root certificates. Consulted only
            when ``tls`` is truthy and ``disable_server_verification`` is
            falsy.
        disable_server_verification: When truthy (the default, matching the
            previously hard-coded value) the server's TLS certificate is NOT
            verified. Pass ``False`` together with ``certs`` for a verified
            connection.
        engine: Optional engine name sent as the ``routing_engine`` header.
        session_properties: Optional iterable of ``(bytes, bytes)`` header
            tuples sent with every call. It is copied, never mutated.

    Raises:
        SystemExit: If TLS verification is requested without ``certs``, or
            if neither a token nor a username/password pair is supplied.
        Exception: Any other error is printed and re-raised unchanged.
    """
    try:
        # Default to use an unencrypted TCP connection.
        scheme = "grpc+tcp"
        connection_args = {}
        if tls:
            scheme = "grpc+tls"
            if disable_server_verification:
                connection_args['disable_server_verification'] = disable_server_verification
            elif certs:
                print('[INFO] Trusted certificates provided')
                # TLS certificates are provided in a list of connection arguments.
                with open(certs, "rb") as root_certs:
                    connection_args["tls_root_certs"] = root_certs.read()
            else:
                print('[ERROR] Trusted certificates must be provided to establish a TLS connection')
                sys.exit()

        # Copy so that appending below never alters the caller's sequence.
        headers = list(session_properties) if session_properties else []
        if engine:
            headers.append((b'routing_engine', engine.encode('utf-8')))

        # Two WLM settings can be provided upon initial authentication with
        # the Dremio Server Flight Endpoint:
        #   routing_tag
        #   routing_queue
        headers.append((b'routing_tag', b'test-routing-tag'))
        headers.append((b'routing_queue', b'Low Cost User Queries'))

        location = "{}://{}:{}".format(scheme, host, port)
        client_cookie_middleware = CookieMiddlewareFactory()

        if pat_or_auth_token:
            client = flight.FlightClient(location,
                                         middleware=[client_cookie_middleware],
                                         **connection_args)
            headers.append((b'authorization',
                            "Bearer {}".format(pat_or_auth_token).encode('utf-8')))
        elif username and password:
            client_auth_middleware = DremioClientAuthMiddlewareFactory()
            client = flight.FlightClient(location,
                                         middleware=[client_auth_middleware,
                                                     client_cookie_middleware],
                                         **connection_args)
            # Authenticate with the server endpoint; the returned header pair
            # is attached to every subsequent call.
            bearer_token = client.authenticate_basic_token(
                username, password, flight.FlightCallOptions(headers=headers))
            headers.append(bearer_token)
        else:
            print('[ERROR] Username/password or PAT/Auth token must be supplied.')
            sys.exit()

        if query:
            # Construct FlightDescriptor for the query result set.
            flight_desc = flight.FlightDescriptor.for_command(query)
            # In addition to the bearer token, a query context can also be
            # provided as an extra header entry, e.g. (b'schema', b'test.schema').
            options = flight.FlightCallOptions(headers=headers)

            # Retrieve the schema of the result set. The value is not used
            # below; the call is kept so the request sequence is unchanged.
            client.get_schema(flight_desc, options)

            # Get the FlightInfo message to retrieve the Ticket corresponding
            # to the query result set.
            flight_info = client.get_flight_info(flight_desc, options)

            # Retrieve the result set as a stream of Arrow record batches.
            # NOTE(review): only the first endpoint is read — confirm the
            # server never returns more than one.
            reader = client.do_get(flight_info.endpoints[0].ticket, options)
            print(reader.read_pandas())

    except Exception as exception:
        print("[ERROR] Exception: {}".format(repr(exception)))
        raise
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment