Skip to content

Instantly share code, notes, and snippets.

@naren-dremio
Last active May 19, 2023 00:15
Show Gist options
  • Save naren-dremio/31b4a8065a3c53a458e2a5103a485106 to your computer and use it in GitHub Desktop.
Save naren-dremio/31b4a8065a3c53a458e2a5103a485106 to your computer and use it in GitHub Desktop.
For Dremio Cloud
import certifi
import sys
from http.cookies import SimpleCookie
from pyarrow import flight
class DremioClientAuthMiddlewareFactory(flight.ClientMiddlewareFactory):
    """Factory producing one DremioClientAuthMiddleware per Flight call.

    The factory also acts as the shared holder for the most recent
    call credential reported back by the middleware it creates.
    """

    def __init__(self):
        # Empty until a middleware instance stores a credential here.
        self.call_credential = []

    def start_call(self, info):
        """Return a new middleware bound to this factory; ``info`` is unused."""
        return DremioClientAuthMiddleware(self)

    def set_call_credential(self, call_credential):
        """Replace the stored credential with ``call_credential``."""
        self.call_credential = call_credential
class DremioClientAuthMiddleware(flight.ClientMiddleware):
    """Client middleware that captures the server's authorization header.

    The captured value is handed to the owning factory through
    ``factory.set_call_credential`` as ``[b'authorization', <value bytes>]``.
    """

    def __init__(self, factory):
        """Remember the factory that will receive the extracted credential."""
        self.factory = factory

    def received_headers(self, headers):
        """Extract the authorization header from ``headers`` and store it.

        Args:
            headers: Mapping of header name to a list of values, as passed
                to the middleware by the Flight client.

        Raises:
            Exception: If no authorization header (matched
                case-insensitively) with a non-empty value is present.
        """
        auth_header_key = 'authorization'
        authorization_header = []
        for key in headers:
            if key.lower() == auth_header_key:
                # Look the value up under the key exactly as received. The
                # match above is case-insensitive, so fetching with the
                # lowercase constant would miss e.g. 'Authorization'.
                authorization_header = headers.get(key)
        if not authorization_header:
            raise Exception('Did not receive authorization header back from server.')
        # Only the first value of the header is used as the credential.
        self.factory.set_call_credential([
            b'authorization', authorization_header[0].encode('utf-8')])
class CookieMiddlewareFactory(flight.ClientMiddlewareFactory):
    """Factory producing one CookieMiddleware per Flight call.

    All middleware instances created here share the ``cookies`` mapping
    held on the factory.
    """

    def __init__(self):
        # Cookie name -> cookie entry; filled in by CookieMiddleware.
        self.cookies = {}

    def start_call(self, info):
        """Return a new CookieMiddleware bound to this factory; ``info`` is unused."""
        return CookieMiddleware(self)
class CookieMiddleware(flight.ClientMiddleware):
    """Client middleware that records received cookies and sends them back.

    Cookies are kept on the owning factory so they are visible to every
    middleware instance that factory creates.
    """

    def __init__(self, factory):
        """Remember the factory whose ``cookies`` mapping is read and updated."""
        self.factory = factory

    def received_headers(self, headers):
        """Parse every ``set-cookie`` header and merge it into the factory's cookies."""
        for header_name in headers:
            if header_name.lower() != 'set-cookie':
                continue
            parsed = SimpleCookie()
            for raw_value in headers.get(header_name):
                parsed.load(raw_value)
            self.factory.cookies.update(parsed.items())

    def sending_headers(self):
        """Return a ``cookie`` header built from the stored cookies, or ``{}`` if none."""
        jar = self.factory.cookies
        if not jar:
            return {}
        pairs = ["{!s}={!s}".format(name, morsel.value) for (name, morsel) in jar.items()]
        return {b'cookie': '; '.join(pairs).encode('utf-8')}
def connect_to_dremio_flight_server_endpoint(host, port, query,
                                             tls, pat_or_auth_token, *,
                                             username=None, password=None,
                                             certs=None,
                                             disable_server_verification=True,
                                             session_properties=None,
                                             engine=None):
    """
    Connects to Dremio Flight server endpoint with the provided credentials.
    It also runs the query and prints the result set as a pandas DataFrame.

    Args:
        host: Host name of the Flight endpoint.
        port: Port of the Flight endpoint.
        query: SQL text to run; nothing is executed when it is falsy.
        tls: When truthy, connect with ``grpc+tls`` instead of ``grpc+tcp``.
        pat_or_auth_token: Token sent as an ``authorization: Bearer`` header.
            When falsy, ``username``/``password`` are used instead.
        username: User name for basic authentication (keyword-only).
        password: Password for basic authentication (keyword-only).
        certs: Path to a file of trusted root certificates, read only when
            ``tls`` is set and ``disable_server_verification`` is falsy.
        disable_server_verification: Passed through to ``FlightClient`` when
            ``tls`` is set. Defaults to ``True`` because that value was
            previously hard-coded here.
            NOTE(review): skipping server verification is insecure; pass
            ``False`` together with ``certs`` where possible.
        session_properties: Optional iterable of ``(bytes, bytes)`` header
            pairs sent with every call. The caller's object is not mutated.
        engine: Optional engine name, sent as the ``routing_engine`` header.

    Raises:
        SystemExit: With status 1 when TLS is requested with verification
            enabled but no ``certs``, or when no credentials are supplied.
        Exception: Any error from the Flight client is printed and re-raised.
    """
    try:
        # Default to use an unencrypted TCP connection.
        scheme = "grpc+tcp"
        connection_args = {}
        if tls:
            scheme = "grpc+tls"
            if disable_server_verification:
                connection_args['disable_server_verification'] = disable_server_verification
            elif certs:
                print('[INFO] Trusted certificates provided')
                # TLS certificates are provided in a list of connection arguments.
                with open(certs, "rb") as root_certs:
                    connection_args["tls_root_certs"] = root_certs.read()
            else:
                print('[ERROR] Trusted certificates must be provided to establish a TLS connection')
                # Exit with a non-zero status so the failure is visible to callers.
                sys.exit(1)

        # Copy so that appending below never mutates the caller's list.
        headers = list(session_properties) if session_properties else []

        if engine:
            headers.append((b'routing_engine', engine.encode('utf-8')))

        # Two WLM settings can be provided upon initial authentication with the Dremio Server Flight Endpoint:
        # routing_tag
        # routing_queue
        headers.append((b'routing_tag', b'test-routing-tag'))
        headers.append((b'routing_queue', b'Low Cost User Queries'))

        client_cookie_middleware = CookieMiddlewareFactory()
        location = "{}://{}:{}".format(scheme, host, port)

        if pat_or_auth_token:
            client = flight.FlightClient(location,
                                         middleware=[client_cookie_middleware],
                                         **connection_args)
            headers.append((b'authorization', "Bearer {}".format(pat_or_auth_token).encode('utf-8')))
        elif username and password:
            client_auth_middleware = DremioClientAuthMiddlewareFactory()
            client = flight.FlightClient(location,
                                         middleware=[client_auth_middleware, client_cookie_middleware],
                                         **connection_args)
            # Authenticate with the server endpoint; the returned header pair
            # is appended so later calls carry the bearer token.
            bearer_token = client.authenticate_basic_token(username, password,
                                                           flight.FlightCallOptions(headers=headers))
            headers.append(bearer_token)
        else:
            print('[ERROR] Username/password or PAT/Auth token must be supplied.')
            sys.exit(1)

        if query:
            # Construct FlightDescriptor for the query result set.
            flight_desc = flight.FlightDescriptor.for_command(query)
            # In addition to the bearer token, a query context can also
            # be provided as an entry of FlightCallOptions, e.g. an extra
            # (b'schema', b'test.schema') header.
            options = flight.FlightCallOptions(headers=headers)

            # Retrieve the schema of the result set. The result is not used;
            # the call is retained so the request sequence is unchanged.
            client.get_schema(flight_desc, options)

            # Get the FlightInfo message to retrieve the Ticket corresponding
            # to the query result set.
            flight_info = client.get_flight_info(flight_desc, options)

            # Retrieve the result set as a stream of Arrow record batches.
            reader = client.do_get(flight_info.endpoints[0].ticket, options)
            print(reader.read_pandas())
    except Exception as exception:
        print("[ERROR] Exception: {}".format(repr(exception)))
        raise
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment