Created
February 8, 2022 01:30
-
-
Save mengdong/d6a24fc266d9806ccd74cd9890b67c6a to your computer and use it in GitHub Desktop.
nvt gcs test
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import time | |
from dask_cuda import LocalCUDACluster | |
from dask.distributed import Client, performance_report | |
import nvtabular as nvt | |
from nvtabular.utils import device_mem_size | |
from nvtabular.ops import Categorify, FillMissing, Clip, Normalize | |
def create_criteo_nvt_workflow(client):
    """Build the NVTabular workflow used to preprocess the Criteo dataset.

    Categorical columns ``C1``..``C26`` go through ``Categorify`` (created
    with ``max_size`` of 10,000,000). Continuous columns ``I1``..``I13`` go
    through ``FillMissing``, ``Clip(min_value=0)`` and ``Normalize``. The
    ``label`` column is appended to the output selection as-is.

    Args:
        client: Forwarded as the second positional argument of
            ``nvt.Workflow``.

    Returns:
        The ``nvt.Workflow`` built from the combined column selection; no
        fitting is performed here.
    """
    max_categories = 10000000
    categorical_columns = [f"C{idx}" for idx in range(1, 27)]
    continuous_columns = [f"I{idx}" for idx in range(1, 14)]

    # Categorical branch: a single Categorify op over all C* columns.
    encoded_categoricals = categorical_columns >> Categorify(max_size=max_categories)

    # Continuous branch: fill missing values -> clip at zero -> normalize.
    filled = continuous_columns >> FillMissing()
    clipped = filled >> Clip(min_value=0)
    scaled_continuous = clipped >> Normalize()

    selection = encoded_categoricals + scaled_continuous + ['label']
    return nvt.Workflow(selection, client)
def analyze_dataset(workflow, dataset):
    """Fit ``workflow`` on ``dataset`` and hand the same workflow back.

    Args:
        workflow: Object exposing ``fit(dataset)``.
        dataset: Dataset passed to ``workflow.fit``.

    Returns:
        The ``workflow`` object that was passed in, after ``fit`` was called
        on it.
    """
    # The return value of fit() is deliberately ignored; the caller gets
    # back the very object it supplied.
    workflow.fit(dataset)
    return workflow
def transform_dataset(dataset, workflow):
    """Apply the workflow's transformations to the dataset.

    The original implementation called ``workflow.transform(dataset)`` but
    discarded the result and returned the untouched input, so callers never
    received the transformed data. The transformed result is now returned.

    Args:
        dataset: Dataset passed to ``workflow.transform``.
        workflow: Object exposing ``transform(dataset)``.

    Returns:
        Whatever ``workflow.transform(dataset)`` returns.
    """
    return workflow.transform(dataset)
def load_workflow(workflow_path, client):
    """Restore a previously saved workflow from ``workflow_path``.

    Args:
        workflow_path: Location handed to ``nvt.Workflow.load``.
        client: Forwarded as the second positional argument of
            ``nvt.Workflow.load``.

    Returns:
        The object returned by ``nvt.Workflow.load``.
    """
    restored_workflow = nvt.Workflow.load(workflow_path, client)
    return restored_workflow
def main():
    """Fit the Criteo workflow on a local Dask-CUDA cluster and report the time.

    Steps: size the GPU memory limits from ``device_mem_size()``, start a
    ``LocalCUDACluster`` with a connected ``Client``, open the raw Criteo
    parquet dataset, build the workflow, and fit it once while a Dask
    performance report is written to ``workflow_fit_profile.html``.

    Changes relative to the previous version:
      * The workflow is fitted exactly once, inside the timed section.
        Previously it was fitted via ``analyze_dataset`` before the timer
        and then fitted again inside it.
      * The unused local ``num_buckets`` was removed.
      * The cluster and client are used as context managers so they are
        shut down when the run finishes or fails.
      * The elapsed time uses ``time.perf_counter()``, the clock intended
        for measuring intervals.
    """
    data_path = 'gs://criteo-datasets/criteo_raw_parquet'
    n_workers = 4
    frac_size = 0.10  # passed to nvt.Dataset as part_mem_fraction
    memory_limit = 100_000_000_000  # passed to LocalCUDACluster as memory_limit

    # Derive the device limits from the GPU memory size reported by NVTabular.
    device_size = device_mem_size()
    device_limit_frac, device_pool_frac = 0.60, 0.90
    device_limit = int(device_limit_frac * device_size)
    device_pool_size = int(device_pool_frac * device_size)
    # Round the pool size down to a multiple of 256
    # (presumably the alignment RMM expects for pool sizes - confirm).
    rmm_pool_size = (device_pool_size // 256) * 256

    # Spin up the local cluster; both objects are closed on exit.
    with LocalCUDACluster(
        n_workers=n_workers,
        memory_limit=memory_limit,
        device_memory_limit=device_limit,
        rmm_pool_size=rmm_pool_size,
    ) as cluster, Client(cluster) as client:
        print(client)

        # Define dataset
        dataset = nvt.Dataset(
            data_path,
            engine="parquet",
            part_mem_fraction=frac_size,
            client=client,
        )

        # Create workflow (not fitted yet)
        criteo_workflow = create_criteo_nvt_workflow(client=client)

        # Fit once, under the profiler and the timer.
        timei = time.perf_counter()
        with performance_report(filename="workflow_fit_profile.html"):
            criteo_workflow = analyze_dataset(criteo_workflow, dataset)
        timef = time.perf_counter()
        print(f"\nWorkflow fit in {timef-timei} seconds.\n")
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Works on 4x T4 GPUs + n1-highmem-64.