Skip to content

Instantly share code, notes, and snippets.

@mengdong
Created February 8, 2022 01:30
Show Gist options
  • Save mengdong/d6a24fc266d9806ccd74cd9890b67c6a to your computer and use it in GitHub Desktop.
Save mengdong/d6a24fc266d9806ccd74cd9890b67c6a to your computer and use it in GitHub Desktop.
nvt gcs test
import os
import time
from dask_cuda import LocalCUDACluster
from dask.distributed import Client, performance_report
import nvtabular as nvt
from nvtabular.utils import device_mem_size
from nvtabular.ops import Categorify, FillMissing, Clip, Normalize
def create_criteo_nvt_workflow(client):
    """Build the NVTabular workflow used to preprocess the Criteo columns.

    Categorical columns ``C1``..``C26`` are run through ``Categorify``;
    continuous columns ``I1``..``I13`` are run through ``FillMissing``,
    ``Clip(min_value=0)`` and ``Normalize``. The ``label`` column is passed
    through as-is. The workflow is only constructed here; it is neither
    fitted nor written to disk.

    Args:
        client: Dask client forwarded to ``nvt.Workflow``.

    Returns:
        The unfitted ``nvt.Workflow`` covering all output columns.
    """
    # Input column names: I1..I13 (continuous) and C1..C26 (categorical).
    continuous_cols = [f"I{idx}" for idx in range(1, 14)]
    categorical_cols = [f"C{idx}" for idx in range(1, 27)]

    # Categorical branch; max_size is the value handed to Categorify.
    encoded_cats = categorical_cols >> Categorify(max_size=10000000)

    # Continuous branch: fill missing values, clip at zero, then normalize.
    fill_op = FillMissing()
    clip_op = Clip(min_value=0)
    scaled_conts = continuous_cols >> fill_op >> clip_op >> Normalize()

    output_columns = encoded_cats + scaled_conts + ['label']
    return nvt.Workflow(output_columns, client)
def analyze_dataset(workflow, dataset):
    """Fit ``workflow`` on ``dataset`` and return that same workflow object.

    Args:
        workflow: Object exposing ``fit(dataset)`` (an ``nvt.Workflow`` here).
        dataset: Dataset the statistics are computed from.

    Returns:
        The ``workflow`` argument itself, after ``fit`` has been called on it.
    """
    fitted = workflow
    # The return value of fit() is intentionally ignored; the caller gets
    # back the object it passed in.
    fitted.fit(dataset)
    return fitted
def transform_dataset(
    dataset,
    workflow,
):
    """Apply the workflow's transformations to the dataset.

    Args:
        dataset: Input dataset to transform.
        workflow: Object exposing ``transform(dataset)`` (an ``nvt.Workflow``
            in this script); presumably already fitted -- verify at call site.

    Returns:
        Whatever ``workflow.transform(dataset)`` returns, i.e. the
        transformed dataset rather than the untouched input.
    """
    # Previously the result of transform() was dropped and the original,
    # untransformed dataset was returned to the caller.
    return workflow.transform(dataset)
def load_workflow(workflow_path, client):
    """Restore a previously saved workflow via ``nvt.Workflow.load``.

    Args:
        workflow_path: Location handed to ``nvt.Workflow.load``.
        client: Dask client handed to ``nvt.Workflow.load``.

    Returns:
        The object returned by ``nvt.Workflow.load``.
    """
    restored = nvt.Workflow.load(workflow_path, client)
    return restored
def main():
    """Fit the Criteo workflow on a local Dask-CUDA cluster and time the fit.

    Starts a ``LocalCUDACluster``, opens the raw Criteo parquet dataset,
    builds and fits the workflow, then prints the wall-clock duration of the
    profiled ``fit`` call. A Dask performance report for that call is written
    to ``workflow_fit_profile.html``. The cluster and client are shut down on
    exit, including when an exception is raised.
    """
    data_path = 'gs://criteo-datasets/criteo_raw_parquet'
    n_workers = 4
    frac_size = 0.10
    # Passed as LocalCUDACluster(memory_limit=...); presumably bytes of host
    # memory per worker -- TODO confirm against the dask-cuda docs.
    memory_limit = 100_000_000_000

    # Derive the GPU memory settings from the size reported by
    # device_mem_size().
    device_size = device_mem_size()
    device_limit_frac, device_pool_frac = 0.60, 0.90
    device_limit = int(device_limit_frac * device_size)
    device_pool_size = int(device_pool_frac * device_size)
    # Round the pool size down to a multiple of 256 (presumably an RMM
    # alignment requirement -- TODO confirm).
    rmm_pool_size = (device_pool_size // 256) * 256

    # Spin up the local cluster; the context managers guarantee that the
    # client and the worker processes are closed when main() finishes.
    with LocalCUDACluster(
        n_workers=n_workers,
        memory_limit=memory_limit,
        device_memory_limit=device_limit,
        rmm_pool_size=rmm_pool_size,
    ) as cluster, Client(cluster) as client:
        print(client)

        # Define dataset
        dataset = nvt.Dataset(
            data_path,
            engine="parquet",
            part_mem_fraction=frac_size,
            client=client,
        )

        # Create workflow
        criteo_workflow = create_criteo_nvt_workflow(client=client)
        # NOTE(review): this already fits the workflow once; the timed block
        # below fits it a second time. Kept as in the original in case the
        # first pass is meant as a warm-up -- confirm and drop one if not.
        criteo_workflow = analyze_dataset(criteo_workflow, dataset)

        # Timed and profiled workflow.fit()
        timei = time.time()
        with performance_report(filename="workflow_fit_profile.html"):
            criteo_workflow.fit(dataset)
        timef = time.time()
        print(f"\nWorkflow fit in {timef-timei} seconds.\n")


if __name__ == "__main__":
    main()
@mengdong
Copy link
Author

mengdong commented Feb 8, 2022

Works on 4× T4 GPUs with an n1-highmem-64 machine.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment