Fit 1 parquet file from Criteo
import logging
import os

import etl

if __name__ == "__main__":
    parquet_dataset = "dongm-debug"
    n_workers = 4
    device_limit_frac = 0.8
    device_pool_frac = 0.9
    part_mem_frac = 0.125
    workflow = './workflow'
    split = 'fsspec_parquet'

    logging.basicConfig(level=logging.INFO)

    # Create Dask cluster
    logging.info('Creating Dask cluster.')
    client = etl.create_cluster(
        n_workers=n_workers,
        device_limit_frac=device_limit_frac,
        device_pool_frac=device_pool_frac
    )

    # Create data transformation workflow. This step will only
    # calculate statistics based on the transformations
    logging.info('Creating transformation workflow.')
    criteo_workflow = etl.create_criteo_nvt_workflow(client=client)

    # Create dataset to be fitted
    logging.info(f'Creating dataset to be analysed.')
    dataset = etl.create_parquet_dataset(
        client=client,
        data_path=os.path.join(
            parquet_dataset,
            split
        ),
        part_mem_frac=part_mem_frac
    )

    logging.info(f'Starting workflow fitting for {split} split.')
    criteo_workflow = etl.analyze_dataset(criteo_workflow, dataset)
    logging.info('Finished generating statistics for dataset.')

    etl.save_workflow(criteo_workflow, workflow)
    logging.info('Workflow saved to GCS')
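
The script above only fits and saves the workflow. For reference, here is a minimal sketch of the follow-on transform step built from the same etl helpers defined below (load_workflow, create_parquet_dataset, transform_dataset, save_dataset); the transformed-output path and the shuffle choice are assumptions, not values taken from the gist.

# Sketch only: apply the fitted workflow and write the transformed split.
# The output location and shuffle setting below are assumptions.
import os
import etl

client = etl.create_cluster(n_workers=4, device_limit_frac=0.8, device_pool_frac=0.9)
criteo_workflow = etl.load_workflow('./workflow', client=client)
dataset = etl.create_parquet_dataset(
    client=client,
    data_path=os.path.join('dongm-debug', 'fsspec_parquet'),
    part_mem_frac=0.125
)
transformed = etl.transform_dataset(dataset, criteo_workflow)
etl.save_dataset(
    transformed,
    output_path='gs://dongm-debug/transformed',  # hypothetical output location
    shuffle='PER_PARTITION'  # resolved to nvtabular.io.shuffle.Shuffle.PER_PARTITION by save_dataset
)

The etl.py module the script imports follows.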
# Copyright 2021 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data preprocessing."""

import os
from typing import Dict, Union

import fsspec
import numpy as np

import nvtabular as nvt
from nvtabular.utils import device_mem_size
from nvtabular.io.shuffle import Shuffle
from nvtabular.ops import (
    Categorify,
    Clip,
    FillMissing,
    Normalize,
)

from dask_cuda import LocalCUDACluster
from dask.distributed import Client

def create_csv_dataset(
    data_paths,
    sep,
    recursive,
    col_dtypes,
    part_mem_frac,
    client
):
    '''Create nvt.Dataset definition for CSV files.'''
    fs_spec = fsspec.filesystem('gs')
    rec_symbol = '**' if recursive else '*'

    valid_paths = []
    for path in data_paths:
        try:
            if fs_spec.isfile(path):
                valid_paths.append(path)
            else:
                path = os.path.join(path, rec_symbol)
                for i in fs_spec.glob(path):
                    if fs_spec.isfile(i):
                        valid_paths.append(f'gs://{i}')
        except FileNotFoundError as fnf_expt:
            print(fnf_expt)
            print(f'Incorrect path: {path}.')
        except OSError as os_err:
            print(os_err)
            print('Verify access to the bucket.')

    return nvt.Dataset(
        path_or_source=valid_paths,
        engine='csv',
        names=list(col_dtypes.keys()),
        sep=sep,
        dtypes=col_dtypes,
        part_size=int(part_mem_frac * device_mem_size()),
        client=client,
        assume_missing=True
    )

def convert_csv_to_parquet(
    output_path,
    dataset,
    shuffle=None
):
    '''Convert CSV files to parquet and write to GCS.'''
    if shuffle:
        shuffle = getattr(Shuffle, shuffle)

    dataset.to_parquet(
        output_path,
        preserve_files=True,
        shuffle=shuffle
    )

def create_criteo_nvt_workflow(client):
    '''Create a nvt.Workflow definition with all the transformation steps.'''
    # Columns definition
    cont_names = ["I" + str(x) for x in range(1, 14)]
    cat_names = ["C" + str(x) for x in range(1, 27)]

    # Transformation pipeline
    num_buckets = 10000000
    categorify_op = Categorify(max_size=num_buckets)
    cat_features = cat_names >> categorify_op
    cont_features = cont_names >> FillMissing() >> \
        Clip(min_value=0) >> Normalize()
    features = cat_features + cont_features + ['label']

    # Create and save workflow
    return nvt.Workflow(features, client)

def create_cluster(
    n_workers,
    device_limit_frac,
    device_pool_frac,
):
    '''
    Create a Dask cluster to apply the transformation steps to the dataset.
    '''
    device_size = device_mem_size()
    device_limit = int(device_limit_frac * device_size)
    device_pool_size = int(device_pool_frac * device_size)
    rmm_pool_size = (device_pool_size // 256) * 256

    cluster = LocalCUDACluster(
        n_workers=n_workers,
        device_memory_limit=device_limit,
        rmm_pool_size=rmm_pool_size
    )

    return Client(cluster)

def create_parquet_dataset(
    client,
    data_path,
    part_mem_frac
):
    '''Create a nvt.Dataset definition for the parquet files.'''
    fs = fsspec.filesystem('gs')
    file_list = fs.glob(
        os.path.join(data_path, '*.parquet')
    )

    if not file_list:
        raise FileNotFoundError('Parquet file(s) not found')

    file_list = [os.path.join('gs://', i) for i in file_list]

    return nvt.Dataset(
        file_list,
        engine="parquet",
        part_size=int(part_mem_frac * device_mem_size()),
        client=client
    )

def analyze_dataset(
    workflow,
    dataset,
):
    '''Calculate statistics for a given workflow.'''
    workflow.fit(dataset)

    return workflow

def transform_dataset(
    dataset,
    workflow
):
    '''Apply the transformations to the dataset.'''
    # workflow.transform() returns a new, lazily evaluated nvt.Dataset;
    # return that rather than the untransformed input.
    return workflow.transform(dataset)

def load_workflow(
    workflow_path,
    client,
):
    '''Load a workflow definition from a path.'''
    return nvt.Workflow.load(workflow_path, client)


def save_workflow(
    workflow,
    output_path
):
    '''Save workflow to a path.'''
    workflow.save(output_path)


def save_dataset(
    dataset,
    output_path,
    shuffle=None
):
    '''Save dataset to parquet files to path.'''
    if shuffle:
        shuffle = getattr(Shuffle, shuffle)

    dataset.to_parquet(
        output_path=output_path,
        shuffle=shuffle
    )
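
The driver only exercises the parquet path. Here is a hedged sketch of the CSV-to-parquet conversion that create_csv_dataset and convert_csv_to_parquet above support; the raw-data location, separator, and the Criteo column dtypes (int32 label and I1-I13, hex-encoded C1-C26) are assumptions rather than values taken from the gist.

# Sketch only: convert raw Criteo TSV files to parquet with the helpers above.
# The paths and dtype choices below are assumptions.
import numpy as np
import etl

col_dtypes = {'label': np.int32}
col_dtypes.update({f'I{i}': np.int32 for i in range(1, 14)})
col_dtypes.update({f'C{i}': 'hex' for i in range(1, 27)})  # assumed hex-encoded categoricals

client = etl.create_cluster(n_workers=4, device_limit_frac=0.8, device_pool_frac=0.9)
csv_dataset = etl.create_csv_dataset(
    data_paths=['gs://dongm-debug/criteo_raw'],  # hypothetical raw-data location
    sep='\t',
    recursive=False,
    col_dtypes=col_dtypes,
    part_mem_frac=0.125,
    client=client
)
etl.convert_csv_to_parquet(
    output_path='gs://dongm-debug/fsspec_parquet',  # hypothetical output location
    dataset=csv_dataset
)

The run log and traceback from fitting the parquet split follow.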
INFO:root:Creating Dask cluster.
INFO:numba.cuda.cudadrv.driver:init
distributed.preloading - INFO - Import preload module: dask_cuda.initialize
distributed.preloading - INFO - Import preload module: dask_cuda.initialize
distributed.preloading - INFO - Import preload module: dask_cuda.initialize
distributed.preloading - INFO - Import preload module: dask_cuda.initialize
INFO:root:Creating transformation workflow.
INFO:root:Creating dataset to be analysed.
distributed.nanny - WARNING - Worker exceeded 95% memory budget. Restarting
distributed.nanny - WARNING - Restarting worker
distributed.preloading - INFO - Import preload module: dask_cuda.initialize
distributed.nanny - WARNING - Worker exceeded 95% memory budget. Restarting
distributed.nanny - WARNING - Restarting worker
distributed.preloading - INFO - Import preload module: dask_cuda.initialize
distributed.nanny - WARNING - Worker exceeded 95% memory budget. Restarting
distributed.nanny - WARNING - Restarting worker
distributed.preloading - INFO - Import preload module: dask_cuda.initialize
distributed.nanny - WARNING - Worker exceeded 95% memory budget. Restarting
distributed.core - ERROR - None
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/core.py", line 575, in handle_stream
    handler(**merge(extra, msg))
  File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 4961, in client_releases_keys
    self.transitions(recommendations)
  File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 6975, in transitions
    self.send_all(client_msgs, worker_msgs)
  File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 5478, in send_all
    w = stream_comms[worker]
KeyError: None
distributed.core - ERROR - Exception while handling op register-client
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/core.py", line 502, in handle_comm
    result = await result
  File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 5192, in add_client
    await self.handle_stream(comm=comm, extra={"client": client})
  File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/core.py", line 575, in handle_stream
    handler(**merge(extra, msg))
  File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 4961, in client_releases_keys
    self.transitions(recommendations)
  File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 6975, in transitions
    self.send_all(client_msgs, worker_msgs)
  File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 5478, in send_all
    w = stream_comms[worker]
KeyError: None
tornado.application - ERROR - Exception in callback functools.partial(<function TCPServer._handle_connection.<locals>.<lambda> at 0x7f4cafeaf550>, <Task finished name='Task-211' coro=<BaseTCPListener._handle_stream() done, defined at /usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/comm/tcp.py:478> exception=KeyError(None)>)
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/tornado/ioloop.py", line 741, in _run_callback
    ret = callback()
  File "/usr/local/lib/python3.8/dist-packages/tornado/tcpserver.py", line 331, in <lambda>
    gen.convert_yielded(future), lambda f: f.result()
  File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/comm/tcp.py", line 495, in _handle_stream
    await self.comm_handler(comm)
  File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/core.py", line 502, in handle_comm
    result = await result
  File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 5192, in add_client
    await self.handle_stream(comm=comm, extra={"client": client})
  File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/core.py", line 575, in handle_stream
    handler(**merge(extra, msg))
  File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 4961, in client_releases_keys
    self.transitions(recommendations)
  File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 6975, in transitions
    self.send_all(client_msgs, worker_msgs)
  File "/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/scheduler.py", line 5478, in send_all
    w = stream_comms[worker]
KeyError: None
distributed.nanny - WARNING - Restarting worker
---------------------------------------------------------------------------
KilledWorker                              Traceback (most recent call last)
<ipython-input-4-bc185e579267> in <module>
----> 1 analyze_dataset_op(
      2 parquet_dataset=parquet_dataset,
      3 n_workers=4
      4 )

<ipython-input-1-ac68e5aa5289> in analyze_dataset_op(parquet_dataset, n_workers, device_limit_frac, device_pool_frac, part_mem_frac)
     31 # Create dataset to be fitted
     32 logging.info(f'Creating dataset to be analysed.')
---> 33 dataset = etl.create_parquet_dataset(
     34 client=client,
     35 data_path=os.path.join(

/src/preprocessing/etl.py in create_parquet_dataset(client, data_path, part_mem_frac)
    149 file_list = [os.path.join('gs://', i) for i in file_list]
    150
--> 151 return nvt.Dataset(
    152 file_list,
    153 engine="parquet",

/nvtabular/nvtabular/io/dataset.py in __init__(self, path_or_source, engine, npartitions, part_size, part_mem_fraction, storage_options, dtypes, client, cpu, base_dataset, schema, **kwargs)
    295 if isinstance(engine, str):
    296 if engine == "parquet":
--> 297 self.engine = ParquetDatasetEngine(
    298 paths, part_size, storage_options=storage_options, cpu=self.cpu, **kwargs
    299 )

/nvtabular/nvtabular/io/parquet.py in __init__(self, paths, part_size, storage_options, row_groups_per_part, legacy, batch_size, cpu, **kwargs)
    225
    226 if row_groups_per_part is None:
--> 227 self._real_meta, rg_byte_size_0 = run_on_worker(
    228 _sample_row_group, self._path0, self.fs, cpu=self.cpu, memory_usage=True
    229 )

/nvtabular/nvtabular/utils.py in run_on_worker(func, *args, **kwargs)
    176 if global_dask_client(kwargs.get("client", None)):
    177 # There is a specified or global Dask client. Use it
--> 178 return dask.delayed(func)(*args, **kwargs).compute()
    179 # No Dask client - Use simple function call
    180 return func(*args, **kwargs)

/usr/local/lib/python3.8/dist-packages/dask-2021.7.1-py3.8.egg/dask/base.py in compute(self, **kwargs)
    284 dask.base.compute
    285 """
--> 286 (result,) = compute(self, traverse=False, **kwargs)
    287 return result
    288

/usr/local/lib/python3.8/dist-packages/dask-2021.7.1-py3.8.egg/dask/base.py in compute(*args, **kwargs)
    566 postcomputes.append(x.__dask_postcompute__())
    567
--> 568 results = schedule(dsk, keys, **kwargs)
    569 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
    570

/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/client.py in get(self, dsk, keys, workers, allow_other_workers, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
   2741 should_rejoin = False
   2742 try:
-> 2743 results = self.gather(packed, asynchronous=asynchronous, direct=direct)
   2744 finally:
   2745 for f in futures.values():

/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/client.py in gather(self, futures, errors, direct, asynchronous)
   2018 else:
   2019 local_worker = None
-> 2020 return self.sync(
   2021 self._gather,
   2022 futures,

/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
    859 return future
    860 else:
--> 861 return sync(
    862 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
    863 )

/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
    324 if error[0]:
    325 typ, exc, tb = error[0]
--> 326 raise exc.with_traceback(tb)
    327 else:
    328 return result[0]

/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/utils.py in f()
    307 if callback_timeout is not None:
    308 future = asyncio.wait_for(future, callback_timeout)
--> 309 result[0] = yield future
    310 except Exception:
    311 error[0] = sys.exc_info()

/usr/local/lib/python3.8/dist-packages/tornado/gen.py in run(self)
    760
    761 try:
--> 762 value = future.result()
    763 except Exception:
    764 exc_info = sys.exc_info()

/usr/local/lib/python3.8/dist-packages/distributed-2021.7.1-py3.8.egg/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
   1883 exc = CancelledError(key)
   1884 else:
-> 1885 raise exception.with_traceback(traceback)
   1886 raise exc
   1887 if errors == "skip":

KilledWorker: ('_sample_row_group-73d451ad-5c12-41eb-b6e1-e5b606350eac', <WorkerState 'tcp://127.0.0.1:41009', name: 3, memory: 0, processing: 1>)