Last active
March 8, 2022 04:30
-
-
Save randerzander/93ddbe1092ad2d722c0e0e8a71861f6d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# determines whether a GPU is available in your instance and starts an appropriate Dask Cluster | |
import re
import time
import warnings
from urllib.parse import urlparse

import dask
from dask.distributed import Client, LocalCluster
from dask_cuda import LocalCUDACluster
from dask_sql import Context
from IPython.core.magic import needs_local_scope, register_cell_magic
from pyngrok import ngrok
def _gpu_available():
    """Return True only if the CUDA driver initialises successfully.

    cuda-python's low-level driver bindings report failure through the
    returned ``CUresult`` (as a ``(CUresult,)`` tuple) instead of raising.
    The status code is therefore checked explicitly. Otherwise a machine that
    merely has cuda-python installed would be treated as having a usable GPU.
    """
    try:
        from cuda import cuda

        result = cuda.cuInit(0)
    except Exception:  # cuda-python missing, driver library not loadable, ...
        return False
    # Accept either the documented 1-tuple or a bare CUresult.
    err = result[0] if isinstance(result, tuple) else result
    return err == cuda.CUresult.CUDA_SUCCESS


GPU = _gpu_available()

# Pick the DataFrame libraries and the cluster type to match the hardware.
# ``dd``/``pd`` are aliased so the rest of the notebook is backend-agnostic.
if GPU:
    import dask_cudf as dd
    import cudf as pd

    cluster = LocalCUDACluster(
        protocol="tcp",
        jit_unspill=True,
        rmm_pool_size="12GB",
        rmm_maximum_pool_size="15GB",
    )
else:
    import dask.dataframe as dd
    import pandas as pd

    cluster = LocalCluster()

client = Client(cluster)
client.amm.start()

# Expose the Dask dashboard through an ngrok tunnel. urlparse copes with any
# host form (including bracketed IPv6), unlike splitting on ':'.
dash_port = str(urlparse(client.dashboard_link).port)
dashboard_tunnel = ngrok.connect(dash_port)
print(dashboard_tunnel)

# Silence all warnings for the rest of the session (presumably to keep
# notebook output readable; note this also hides useful ones).
warnings.filterwarnings("ignore")
# Matches a CREATE [OR REPLACE] TABLE/VIEW [IF NOT EXISTS] statement,
# case-insensitively and with any whitespace; group 1 is the relation name.
_CREATE_RELATION_RE = re.compile(
    r"\bCREATE\s+(?:OR\s+REPLACE\s+)?(?:TABLE|VIEW)\s+"
    r"(?:IF\s+NOT\s+EXISTS\s+)?([^\s(]+)",
    re.IGNORECASE,
)
# Matches a CREATE [OR REPLACE] MODEL statement.
_CREATE_MODEL_RE = re.compile(
    r"\bCREATE\s+(?:OR\s+REPLACE\s+)?MODEL\b", re.IGNORECASE
)


@register_cell_magic
@needs_local_scope
def sql(line, cell, local_ns):
    """Run the cell body as a dask-sql statement against the global Context ``c``.

    ``{name}`` placeholders in the cell are filled from the notebook's local
    namespace with ``str.format`` before the statement is sent to dask-sql.

    Behaviour by statement kind:

    * CREATE [OR REPLACE] TABLE/VIEW: the statement runs once, then the first
      5 rows of the new relation are computed and returned as a preview.
    * CREATE [OR REPLACE] MODEL: the statement runs once and the value
      returned by ``c.sql`` is passed back without computing it.
    * Anything else: the lazy result is computed on the client. A result
      that has no ``compute`` method is returned unchanged.

    On GPU, any result exposing ``to_pandas`` (i.e. cuDF) is converted to
    pandas so it can be handed directly to Plotly.

    Args:
        line: Text after ``%%sql`` on the magic line (unused).
        cell: The SQL text of the cell.
        local_ns: Caller's local namespace, injected by ``needs_local_scope``.

    Returns:
        A DataFrame for queries and table/view previews; otherwise whatever
        ``c.sql`` returned for the statement.
    """
    sql_statement = cell.format(**local_ns)
    t0 = time.time()
    # Execute exactly once: running the statement again would, for
    # example, retrain a model.
    res = c.sql(sql_statement)

    created = _CREATE_RELATION_RE.search(sql_statement)
    if created is not None:
        # Preview the freshly created table/view.
        res = c.sql(f"SELECT * FROM {created.group(1)} LIMIT 5").compute()
    elif _CREATE_MODEL_RE.search(sql_statement) is None and hasattr(res, "compute"):
        # Ordinary query: pull the lazy result to the client. Results without
        # ``compute`` (presumably None from other DDL such as DROP TABLE --
        # confirm against dask-sql) are passed through instead of raising.
        res = res.compute()

    # since we want to pass DF results to Plotly, convert any cuDF result to
    # pandas so Plotly can graph it
    if GPU and hasattr(res, "to_pandas"):
        res = res.to_pandas()
    print(f"Execution time: {time.time() - t0:.2f}s")
    return res


# avoid name conflicts for automagic to work on line magics.
del sql
# Every dask-sql query in this session goes through a single Context object,
# bound to the module-level name ``c`` (the %%sql magic calls ``c.sql``).
c = Context()

# SQL session options are taken from dask's global configuration.
# dask-sql treats identifiers as case-sensitive by default; switch that off.
dask.config.set({"sql.identifier.case_sensitive": False})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment