Skip to content

Instantly share code, notes, and snippets.

@mengdong
Forked from leiterenato/analyze_dataset_op.py
Last active October 28, 2021 20:44
Show Gist options
  • Save mengdong/8526f2765707be492b1b697c3fb8c687 to your computer and use it in GitHub Desktop.
Save mengdong/8526f2765707be492b1b697c3fb8c687 to your computer and use it in GitHub Desktop.
Fit 1 parquet file from Criteo
if __name__ == "__main__":
    import logging
    import os

    import etl

    # Run configuration: input location, output location, cluster size and
    # the memory fractions handed to the etl helpers.
    parquet_dataset = "dongm-debug"
    split = 'fsspec_parquet'
    workflow = './workflow'
    n_workers = 4
    device_limit_frac = 0.8
    device_pool_frac = 0.9
    part_mem_frac = 0.125

    logging.basicConfig(level=logging.INFO)

    # Start the Dask cluster that the following steps run on.
    logging.info('Creating Dask cluster.')
    client = etl.create_cluster(
        n_workers=n_workers,
        device_limit_frac=device_limit_frac,
        device_pool_frac=device_pool_frac,
    )

    # Define the transformation workflow. Only the definition is built here;
    # the statistics are calculated by the fitting step further down.
    logging.info('Creating transformation workflow.')
    criteo_workflow = etl.create_criteo_nvt_workflow(client=client)

    # Point a dataset at the parquet files of the chosen split.
    logging.info('Creating dataset to be analysed.')
    split_path = os.path.join(parquet_dataset, split)
    dataset = etl.create_parquet_dataset(
        client=client,
        data_path=split_path,
        part_mem_frac=part_mem_frac,
    )

    # Fit the workflow on the dataset, then persist the fitted workflow.
    logging.info(f'Starting workflow fitting for {split} split.')
    criteo_workflow = etl.analyze_dataset(criteo_workflow, dataset)
    logging.info('Finished generating statistics for dataset.')

    etl.save_workflow(criteo_workflow, workflow)
    # NOTE(review): the message mentions GCS, but `workflow` is a relative
    # path ('./workflow') -- confirm where the workflow actually lands.
    logging.info('Workflow saved to GCS')
# Copyright 2021 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data preprocessing."""
from typing import Dict, Union
import numpy as np
import fsspec
import os
import nvtabular as nvt
from nvtabular.utils import device_mem_size
from nvtabular.io.shuffle import Shuffle
from nvtabular.ops import (
Categorify,
Clip,
FillMissing,
Normalize,
)
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
def create_csv_dataset(
    data_paths,
    sep,
    recursive,
    col_dtypes,
    part_mem_frac,
    client
):
    '''Create nvt.Dataset definition for CSV files.

    Every entry of ``data_paths`` is checked on the ``gs`` filesystem. An
    entry that is a file is used as given. Any other entry is treated as a
    prefix and globbed (``**`` when ``recursive`` is true, ``*`` otherwise);
    each match that is a file is added with a ``gs://`` prefix.

    Lookup errors are printed and the offending entry is skipped, so the
    returned dataset is built from whatever paths could be resolved.

    Args:
        data_paths: iterable of file paths or prefixes to search.
        sep: field separator forwarded to ``nvt.Dataset``.
        recursive: selects the ``**`` glob pattern instead of ``*``.
        col_dtypes: mapping of column name to dtype; its keys are also
            used as the column names.
        part_mem_frac: fraction of ``device_mem_size()`` used as the
            partition size.
        client: client object forwarded to ``nvt.Dataset``.

    Returns:
        The ``nvt.Dataset`` built over the resolved CSV paths.
    '''
    fs_spec = fsspec.filesystem('gs')
    rec_symbol = '**' if recursive else '*'

    valid_paths = []
    for path in data_paths:
        try:
            if fs_spec.isfile(path):
                valid_paths.append(path)
            else:
                # Keep `path` untouched so an error message reports the
                # entry the caller supplied, not the derived glob pattern.
                pattern = os.path.join(path, rec_symbol)
                for match in fs_spec.glob(pattern):
                    if fs_spec.isfile(match):
                        valid_paths.append(f'gs://{match}')
        except FileNotFoundError as fnf_expt:
            print(fnf_expt)
            # The original literal lacked the f prefix and printed '{path}'.
            print(f'Incorrect path: {path}.')
        except OSError as os_err:
            print(os_err)
            print('Verify access to the bucket.')

    return nvt.Dataset(
        path_or_source=valid_paths,
        engine='csv',
        names=list(col_dtypes.keys()),
        sep=sep,
        dtypes=col_dtypes,
        part_size=int(part_mem_frac * device_mem_size()),
        client=client,
        assume_missing=True
    )
def convert_csv_to_parquet(
    output_path,
    dataset,
    shuffle = None
):
    '''Write ``dataset`` as parquet files under ``output_path``.

    When ``shuffle`` is truthy it is taken as the name of an attribute of
    ``Shuffle`` and replaced by that attribute; a falsy value is passed
    through unchanged. ``preserve_files=True`` is always requested.
    '''
    shuffle_mode = getattr(Shuffle, shuffle) if shuffle else shuffle
    dataset.to_parquet(
        output_path,
        preserve_files=True,
        shuffle=shuffle_mode
    )
def create_criteo_nvt_workflow(client):
    '''Create a nvt.Workflow definition with all the transformation steps.

    Columns ``C1``..``C26`` go through ``Categorify`` (``max_size`` of ten
    million). Columns ``I1``..``I13`` go through ``FillMissing``, then
    ``Clip(min_value=0)``, then ``Normalize``. The ``label`` column is
    appended without any operator.
    '''
    # Column names
    continuous_cols = [f'I{idx}' for idx in range(1, 14)]
    categorical_cols = [f'C{idx}' for idx in range(1, 27)]

    # Categorical branch
    max_categories = 10_000_000
    encoded_cats = categorical_cols >> Categorify(max_size=max_categories)

    # Continuous branch, one operator per step
    scaled_conts = continuous_cols >> FillMissing()
    scaled_conts = scaled_conts >> Clip(min_value=0)
    scaled_conts = scaled_conts >> Normalize()

    # Combine both branches with the label column into one workflow
    return nvt.Workflow(encoded_cats + scaled_conts + ['label'], client)
def create_cluster(
    n_workers,
    device_limit_frac,
    device_pool_frac,
):
    '''
    Start a LocalCUDACluster and return a Dask Client connected to it.

    ``device_limit_frac`` and ``device_pool_frac`` are fractions of
    ``device_mem_size()``; they set the cluster's ``device_memory_limit``
    and ``rmm_pool_size`` respectively.
    '''
    total_device_mem = device_mem_size()
    memory_limit = int(device_limit_frac * total_device_mem)

    # Round the pool size down to a multiple of 256.
    pool_size = int(device_pool_frac * total_device_mem)
    pool_size -= pool_size % 256

    return Client(
        LocalCUDACluster(
            n_workers=n_workers,
            device_memory_limit=memory_limit,
            rmm_pool_size=pool_size,
        )
    )
def create_parquet_dataset(
    client,
    data_path,
    part_mem_frac
):
    '''Create a nvt.Dataset over the ``*.parquet`` files under ``data_path``.

    The files are looked up on the ``gs`` filesystem and each match is
    turned into a ``gs://`` path. The partition size is ``part_mem_frac``
    times ``device_mem_size()``.

    Raises:
        FileNotFoundError: when the glob matches nothing.
    '''
    gcs = fsspec.filesystem('gs')
    pattern = os.path.join(data_path, '*.parquet')
    matches = gcs.glob(pattern)
    if not matches:
        raise FileNotFoundError('Parquet file(s) not found')

    parquet_uris = []
    for match in matches:
        parquet_uris.append(os.path.join('gs://', match))

    return nvt.Dataset(
        parquet_uris,
        engine="parquet",
        part_size=int(part_mem_frac * device_mem_size()),
        client=client,
    )
def analyze_dataset(
    workflow,
    dataset,
):
    '''Fit ``workflow`` on ``dataset`` and return that same workflow.'''
    fitted = workflow
    fitted.fit(dataset)
    return fitted
def transform_dataset(
    dataset,
    workflow
):
    '''Apply the transformations to the dataset.

    Args:
        dataset: the dataset to transform.
        workflow: object whose ``transform`` method is applied to ``dataset``.

    Returns:
        Whatever ``workflow.transform(dataset)`` returns. The previous
        version dropped that result and returned the input ``dataset``, so
        callers received the data without the workflow applied.
    '''
    return workflow.transform(dataset)
def load_workflow(
    workflow_path,
    client,
):
    '''Return the workflow that ``nvt.Workflow.load`` reads from ``workflow_path``.

    ``client`` is forwarded to ``nvt.Workflow.load`` as its second argument.
    '''
    loaded = nvt.Workflow.load(workflow_path, client)
    return loaded
def save_workflow(
    workflow,
    output_path
):
    '''Persist ``workflow`` by calling its ``save`` method with ``output_path``.'''
    destination = output_path
    workflow.save(destination)
def save_dataset(
    dataset,
    output_path,
    shuffle = None
):
    '''Write ``dataset`` as parquet files to ``output_path``.

    A truthy ``shuffle`` is taken as the name of an attribute of ``Shuffle``
    and replaced by that attribute; a falsy value is passed through as is.
    '''
    shuffle_mode = getattr(Shuffle, shuffle) if shuffle else shuffle
    dataset.to_parquet(output_path=output_path, shuffle=shuffle_mode)
INFO:root:Creating Dask cluster.
INFO:numba.cuda.cudadrv.driver:init
distributed.preloading - INFO - Import preload module: dask_cuda.initialize
distributed.preloading - INFO - Import preload module: dask_cuda.initialize
distributed.preloading - INFO - Import preload module: dask_cuda.initialize
distributed.preloading - INFO - Import preload module: dask_cuda.initialize
INFO:root:Creating transformation workflow.
INFO:root:Creating dataset to be analysed.
distributed.nanny - WARNING - Worker exceeded 95% memory budget. Restarting
distributed.nanny - WARNING - Restarting worker
distributed.preloading - INFO - Import preload module: dask_cuda.initialize
distributed.nanny - WARNING - Worker exceeded 95% memory budget. Restarting
distributed.nanny - WARNING - Restarting worker
distributed.preloading - INFO - Import preload module: dask_cuda.initialize
distributed.nanny - WARNING - Worker exceeded 95% memory budget. Restarting
distributed.nanny - WARNING - Restarting worker
distributed.preloading - INFO - Import preload module: dask_cuda.initialize
distributed.nanny - WARNING - Worker exceeded 95% memory budget. Restarting
distributed.core - ERROR - None
Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/core.py", line 575, in handle_stream
handler(**merge(extra, msg))
File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 4961, in client_releases_keys
self.transitions(recommendations)
File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 6975, in transitions
self.send_all(client_msgs, worker_msgs)
File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 5478, in send_all
w = stream_comms[worker]
KeyError: None
distributed.core - ERROR - Exception while handling op register-client
Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/core.py", line 502, in handle_comm
result = await result
File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 5192, in add_client
await self.handle_stream(comm=comm, extra={"client": client})
File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/core.py", line 575, in handle_stream
handler(**merge(extra, msg))
File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 4961, in client_releases_keys
self.transitions(recommendations)
File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 6975, in transitions
self.send_all(client_msgs, worker_msgs)
File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 5478, in send_all
w = stream_comms[worker]
KeyError: None
tornado.application - ERROR - Exception in callback functools.partial(<function TCPServer._handle_connection.<locals>.<lambda> at 0x7f4cafeaf550>, <Task finished name='Task-211' coro=<BaseTCPListener._handle_stream() done, defined at /usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/comm/tcp.py:478> exception=KeyError(None)>)
Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/tornado/ioloop.py", line 741, in _run_callback
ret = callback()
File "/usr/local/lib/python3.8/dist-packages/tornado/tcpserver.py", line 331, in <lambda>
gen.convert_yielded(future), lambda f: f.result()
File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/comm/tcp.py", line 495, in _handle_stream
await self.comm_handler(comm)
File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/core.py", line 502, in handle_comm
result = await result
File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 5192, in add_client
await self.handle_stream(comm=comm, extra={"client": client})
File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/core.py", line 575, in handle_stream
handler(**merge(extra, msg))
File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 4961, in client_releases_keys
self.transitions(recommendations)
File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 6975, in transitions
self.send_all(client_msgs, worker_msgs)
File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 5478, in send_all
w = stream_comms[worker]
KeyError: None
distributed.nanny - WARNING - Restarting worker
---------------------------------------------------------------------------
KilledWorker Traceback (most recent call last)
<ipython-input-4-bc185e579267> in <module>
----> 1 analyze_dataset_op(
2 parquet_dataset=parquet_dataset,
3 n_workers=4
4 )
<ipython-input-1-ac68e5aa5289> in analyze_dataset_op(parquet_dataset, n_workers, device_limit_frac, device_pool_frac, part_mem_frac)
31 # Create dataset to be fitted
32 logging.info(f'Creating dataset to be analysed.')
---> 33 dataset = etl.create_parquet_dataset(
34 client=client,
35 data_path=os.path.join(
/src/preprocessing/etl.py in create_parquet_dataset(client, data_path, part_mem_frac)
149 file_list = [os.path.join('gs://', i) for i in file_list]
150
--> 151 return nvt.Dataset(
152 file_list,
153 engine="parquet",
/nvtabular/nvtabular/io/dataset.py in __init__(self, path_or_source, engine, npartitions, part_size, part_mem_fraction, storage_options, dtypes, client, cpu, base_dataset, schema, **kwargs)
295 if isinstance(engine, str):
296 if engine == "parquet":
--> 297 self.engine = ParquetDatasetEngine(
298 paths, part_size, storage_options=storage_options, cpu=self.cpu, **kwargs
299 )
/nvtabular/nvtabular/io/parquet.py in __init__(self, paths, part_size, storage_options, row_groups_per_part, legacy, batch_size, cpu, **kwargs)
225
226 if row_groups_per_part is None:
--> 227 self._real_meta, rg_byte_size_0 = run_on_worker(
228 _sample_row_group, self._path0, self.fs, cpu=self.cpu, memory_usage=True
229 )
/nvtabular/nvtabular/utils.py in run_on_worker(func, *args, **kwargs)
176 if global_dask_client(kwargs.get("client", None)):
177 # There is a specified or global Dask client. Use it
--> 178 return dask.delayed(func)(*args, **kwargs).compute()
179 # No Dask client - Use simple function call
180 return func(*args, **kwargs)
/usr/local/lib/python3.8/dist-packages/dask-2021.7.1-py3.8.egg/dask/base.py in compute(self, **kwargs)
284 dask.base.compute
285 """
--> 286 (result,) = compute(self, traverse=False, **kwargs)
287 return result
288
/usr/local/lib/python3.8/dist-packages/dask-2021.7.1-py3.8.egg/dask/base.py in compute(*args, **kwargs)
566 postcomputes.append(x.__dask_postcompute__())
567
--> 568 results = schedule(dsk, keys, **kwargs)
569 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
570
/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/client.py in get(self, dsk, keys, workers, allow_other_workers, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
2741 should_rejoin = False
2742 try:
-> 2743 results = self.gather(packed, asynchronous=asynchronous, direct=direct)
2744 finally:
2745 for f in futures.values():
/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/client.py in gather(self, futures, errors, direct, asynchronous)
2018 else:
2019 local_worker = None
-> 2020 return self.sync(
2021 self._gather,
2022 futures,
/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
859 return future
860 else:
--> 861 return sync(
862 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
863 )
/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
324 if error[0]:
325 typ, exc, tb = error[0]
--> 326 raise exc.with_traceback(tb)
327 else:
328 return result[0]
/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/utils.py in f()
307 if callback_timeout is not None:
308 future = asyncio.wait_for(future, callback_timeout)
--> 309 result[0] = yield future
310 except Exception:
311 error[0] = sys.exc_info()
/usr/local/lib/python3.8/dist-packages/tornado/gen.py in run(self)
760
761 try:
--> 762 value = future.result()
763 except Exception:
764 exc_info = sys.exc_info()
/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
1883 exc = CancelledError(key)
1884 else:
-> 1885 raise exception.with_traceback(traceback)
1886 raise exc
1887 if errors == "skip":
KilledWorker: ('_sample_row_group-73d451ad-5c12-41eb-b6e1-e5b606350eac', <WorkerState 'tcp://127.0.0.1:41009', name: 3, memory: 0, processing: 1>)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment