Skip to content

Instantly share code, notes, and snippets.

@TomAugspurger
Created June 8, 2026 20:32
Show Gist options
  • Select an option

  • Save TomAugspurger/cebf449f7621f69ff70a5452f1ca5706 to your computer and use it in GitHub Desktop.

Select an option

Save TomAugspurger/cebf449f7621f69ff70a5452f1ca5706 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
"""Small benchmark for shuffle-based joins in cudf-polars."""
from __future__ import annotations
import argparse
import concurrent.futures
import contextlib
import dataclasses
import json
import tempfile
import textwrap
import time
from pathlib import Path
from typing import Any, TYPE_CHECKING
import numpy as np
import polars as pl
from cudf_polars import Translator
from cudf_polars.dsl.traversal import traversal
from cudf_polars.engine.options import StreamingOptions
from cudf_polars.streaming.parallel import lower_ir_graph
from cudf_polars.streaming.shuffle import Shuffle
from cudf_polars.streaming.statistics import collect_statistics
from cudf_polars.utils.config import ConfigOptions
if TYPE_CHECKING:
from collections.abc import Iterator
try:
import nvtx
except ImportError: # pragma: no cover
nvtx = None
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the shuffle-join benchmark.

    The parser combines benchmark-specific flags with the streaming-executor
    flags registered by ``StreamingOptions._add_cli_args``:

    - table sizes and key domain
    - join type and data mode
    - execution frontend
    - plan verification

    Afterwards, a few streaming defaults are overridden so that, unless the
    user says otherwise:

    - dynamic planning is off
    - partitions are small
    - the broadcast limit is 1

    This matches the description: a join "configured to prefer shuffle joins".

    Returns
    -------
    argparse.ArgumentParser
        The configured parser.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark a cudf-polars join configured to prefer shuffle joins.",
        # Several help strings below are hand-formatted multi-line lists
        # (built with textwrap.dedent). The default HelpFormatter re-wraps
        # help text and would collapse them into one paragraph.
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--n-left",
        type=int,
        default=8_000_000,
        help="Number of rows in the left table.",
    )
    parser.add_argument(
        "--n-right",
        type=int,
        default=2_000_000,
        help="Number of rows in the right table (one row per key).",
    )
    parser.add_argument(
        "--distinct-keys",
        type=int,
        default=2_000_000,
        help="Number of key values sampled by the left side\n"
        "(values below --n-right are raised to --n-right).",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Seed for the numpy random generator used to build the tables.",
    )
    parser.add_argument(
        "--how",
        type=str,
        default="inner",
        choices=["inner", "left", "right", "full", "semi", "anti"],
        help="Join type passed to LazyFrame.join.",
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=3,
        help="Number of timed collect() calls (must be >= 1).",
    )
    parser.add_argument(
        "--data-mode",
        type=str,
        default="in-memory",
        choices=["in-memory", "parquet"],
        help="Use DataFrameScan inputs or write+scan temporary parquet files.",
    )
    parser.add_argument(
        "--frontend",
        required=True,
        type=str,
        choices=["dask", "duckdb", "in-memory", "polars-cpu", "ray", "spmd"],
        help=textwrap.dedent("""\
            Execution frontend:
              - dask       : Dask distributed multi-GPU execution
              - duckdb     : DuckDB CPU execution
              - in-memory  : Single-process GPU, in-memory evaluation
              - polars-cpu : Polars CPU streaming engine (no GPU)
              - ray        : Ray actor-based multi-GPU execution
              - spmd       : SPMD execution via rrun launcher"""),
    )
    parser.add_argument(
        "--connect",
        dest="connect",
        default=None,
        type=str,
        help=textwrap.dedent("""\
            Connect to an existing cluster instead of creating a local one.
            Only supported with --frontend dask or ray:
              - dask : a TCP address (e.g. tcp://host:8786) or a scheduler file path
              - ray  : a Ray address (e.g. ray://host:10001 or "auto")"""),
    )
    parser.add_argument(
        "--num-gpus",
        dest="num_gpus",
        default=None,
        type=int,
        help="Number of GPUs for local cluster creation (--frontend ray/dask only).\n"
        "Cannot be used with --connect. Defaults to all visible GPUs.",
    )
    parser.add_argument(
        "--verify-plan",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Count Shuffle nodes in lowered IR (static planning only).",
    )
    parser.add_argument(
        "--check-broadcast-sensitivity",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Also inspect shuffle count with a high broadcast limit.",
    )
    parser.add_argument(
        "--sensitivity-broadcast-limit",
        type=int,
        default=1_000_000_000,
        help="Broadcast limit to use for sensitivity check.",
    )
    # Register the generic streaming-executor flags, then override some of
    # their defaults to steer the planner toward a shuffle join.
    StreamingOptions._add_cli_args(parser)
    parser.set_defaults(
        dynamic_planning=False,
        max_rows_per_partition=250_000,
        target_partition_size=1,
        broadcast_limit=1,
    )
    return parser
def make_tables(args: argparse.Namespace) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Generate the synthetic left and right tables to be joined.

    Key layout:

    - Right table: exactly one row per key ``0 .. n_right - 1``.
    - Left table: keys drawn uniformly at random from ``[0, key_domain)``,
      where ``key_domain = max(distinct_keys, n_right)``.

    Left keys at or above ``n_right`` therefore have no partner on the
    right. Each table also carries a random payload column and a row-id
    column.

    Parameters
    ----------
    args
        Parsed CLI namespace. Reads ``n_left``, ``n_right``,
        ``distinct_keys`` and ``seed``.

    Returns
    -------
    tuple[pl.DataFrame, pl.DataFrame]
        ``(left, right)``.

    Raises
    ------
    ValueError
        If either row count or ``distinct_keys`` is below 1.
    """
    if min(args.n_left, args.n_right) < 1:
        raise ValueError("Both --n-left and --n-right must be >= 1.")
    if args.distinct_keys < 1:
        raise ValueError("--distinct-keys must be >= 1.")

    n_left, n_right = args.n_left, args.n_right
    generator = np.random.default_rng(args.seed)
    # NOTE(review): --distinct-keys is described as the number of key values
    # the left side samples. This max() means it can only widen the domain
    # beyond n_right, never narrow it. Confirm that floor is intended.
    key_domain = max(args.distinct_keys, n_right)

    # Draw the random columns in exactly this order (left keys, left payload,
    # right payload) so a given seed always reproduces the same tables.
    left_keys = generator.integers(0, key_domain, size=n_left, dtype=np.int64)
    left_payload = generator.integers(0, 1_000_000, size=n_left, dtype=np.int64)
    right_payload = generator.integers(0, 1_000_000, size=n_right, dtype=np.int64)

    left_columns = {
        "key": left_keys,
        "left_payload": left_payload,
        "left_row_id": np.arange(n_left, dtype=np.int64),
    }
    right_columns = {
        "key": np.arange(n_right, dtype=np.int64),
        "right_payload": right_payload,
        "right_row_id": np.arange(n_right, dtype=np.int64),
    }
    return pl.DataFrame(left_columns), pl.DataFrame(right_columns)
@contextlib.contextmanager
def build_query(args: argparse.Namespace) -> Iterator[pl.LazyFrame]:
    """Yield a lazy join query over freshly generated synthetic tables.

    Behavior depends on ``args.data_mode``:

    - ``"in-memory"``: both sides are lazy views of the generated DataFrames.
    - Anything else (``"parquet"`` via the CLI): the tables are first written
      to a temporary directory and scanned back, so executing the query
      includes parquet reads. That directory is deleted when the ``with``
      block exits, so the query must be collected inside it.

    Parameters
    ----------
    args
        Parsed CLI namespace. Reads ``data_mode`` and ``how``, plus the
        fields consumed by ``make_tables``.

    Yields
    ------
    pl.LazyFrame
        ``left.join(right, on="key", how=args.how)``.
    """
    left, right = make_tables(args)
    if args.data_mode == "in-memory":
        yield left.lazy().join(right.lazy(), on="key", how=args.how)
        return

    with tempfile.TemporaryDirectory(prefix="cudf_polars_shuffle_join_") as tmp:
        left_path = Path(tmp) / "left.parquet"
        right_path = Path(tmp) / "right.parquet"
        left.write_parquet(left_path)
        right.write_parquet(right_path)
        # The query below only reads the files. Drop the host copies so the
        # suspended generator does not keep them alive for the whole run.
        del left, right
        yield pl.scan_parquet(left_path).join(
            pl.scan_parquet(right_path), on="key", how=args.how
        )
def create_engine(args: argparse.Namespace) -> pl.GPUEngine:
    """Return a streaming ``pl.GPUEngine`` configured from parsed CLI args.

    Both the executor options and the engine options come from
    ``StreamingOptions._from_argparse(args)``. ``raise_on_fail`` is set to
    ``True`` unless the engine options already provide a value for it.
    """
    options = StreamingOptions._from_argparse(args)
    executor_kwargs = options.to_executor_options()
    # Entries from the options object win over this fallback default.
    engine_kwargs = {"raise_on_fail": True, **options.to_engine_options()}
    return pl.GPUEngine(
        executor="streaming",
        executor_options=executor_kwargs,
        **engine_kwargs,
    )
def _executor_summary(engine: pl.GPUEngine) -> dict[str, object]:
    """Summarize the executor settings that *engine* resolves to.

    The settings are read from ``ConfigOptions.from_polars_engine(engine)``.
    A non-``None`` ``dynamic_planning`` config object is converted to a plain
    dict with ``dataclasses.asdict`` so the summary can be passed to
    ``json.dumps``.
    """
    executor = ConfigOptions.from_polars_engine(engine).executor
    planning = executor.dynamic_planning
    return {
        "cluster": executor.cluster,
        "max_rows_per_partition": executor.max_rows_per_partition,
        "target_partition_size": executor.target_partition_size,
        "broadcast_limit": executor.broadcast_limit,
        "dynamic_planning": (
            None if planning is None else dataclasses.asdict(planning)
        ),
        "min_device_size": executor.min_device_size,
    }
def count_shuffle_nodes(
    query: pl.LazyFrame, engine: pl.GPUEngine
) -> int | None:
    """Count the ``Shuffle`` nodes in the statically lowered IR of *query*.

    Steps:

    1. Translate the query into cudf-polars IR using *engine*.
    2. Collect statistics for it with a two-worker thread pool.
    3. Lower the graph with ``lower_ir_graph``.
    4. Walk every node of the lowered graph.

    Returns
    -------
    int | None
        The number of ``Shuffle`` nodes. Returns ``None``, without
        translating or lowering anything, when the engine's executor has
        dynamic planning configured.
    """
    config = ConfigOptions.from_polars_engine(engine)
    if config.executor.dynamic_planning is not None:
        return None

    ir = Translator(query._ldf.visit(), engine).translate_ir()
    # The pool is handed to collect_statistics (the name suggests it is used
    # for I/O-bound work). Lowering also happens inside the `with`, as in the
    # original flow.
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as io_pool:
        statistics = collect_statistics(ir, config, io_pool)
        lowered_ir, _ = lower_ir_graph(ir, config, statistics)
    # bool is an int subclass, so summing the isinstance checks counts matches.
    return sum(isinstance(node, Shuffle) for node in traversal([lowered_ir]))
def maybe_verify_plan(
    args: argparse.Namespace, query: pl.LazyFrame, engine: pl.GPUEngine
) -> None:
    """Print how many ``Shuffle`` nodes the lowered plan for *query* contains.

    Does nothing unless ``args.verify_plan`` is set.

    When ``args.check_broadcast_sensitivity`` is also set, the count is
    repeated with an engine rebuilt from the CLI options, with two changes:

    - dynamic planning forced off
    - ``broadcast_limit`` replaced by ``args.sensitivity_broadcast_limit``

    That second count is printed as well.
    """
    if not args.verify_plan:
        return

    count = count_shuffle_nodes(query, engine)
    if count is None:
        print("Plan verification skipped: dynamic planning is enabled.")
    else:
        print(f"Lowered Shuffle node count: {count}")
        verdict = (
            "Plan check: shuffle/hash join path detected."
            if count > 0
            else "Plan check: no shuffle nodes detected (likely broadcast join)."
        )
        print(verdict)

    if not args.check_broadcast_sensitivity:
        return

    # Force static planning (count_shuffle_nodes returns None otherwise) and
    # swap in the sensitivity broadcast limit. All other options are taken
    # from the CLI.
    options = StreamingOptions._from_argparse(args)
    sensitivity_executor_options = {
        **options.to_executor_options(),
        "dynamic_planning": None,
        "broadcast_limit": args.sensitivity_broadcast_limit,
    }
    sensitivity_engine_options = {
        "raise_on_fail": True,
        **options.to_engine_options(),
    }
    sensitivity_engine = pl.GPUEngine(
        executor="streaming",
        executor_options=sensitivity_executor_options,
        **sensitivity_engine_options,
    )
    sensitivity_count = count_shuffle_nodes(query, sensitivity_engine)
    limit = args.sensitivity_broadcast_limit
    print(
        f"Sensitivity check shuffle count (broadcast_limit={limit}): "
        f"{sensitivity_count}"
    )
def _check_benchmark_args(args: argparse.Namespace) -> None:
    """Reject unsupported or contradictory CLI combinations up front.

    This runs before any table is generated, so a bad invocation fails
    immediately instead of after building (and possibly writing) the data.
    """
    if args.frontend in {"dask", "duckdb", "polars-cpu"}:
        raise NotImplementedError(
            f"--frontend {args.frontend!r} is not implemented in this microbenchmark. "
            "Use in-memory, ray, or spmd."
        )
    if args.iterations < 1:
        # The timing summary needs at least one measurement.
        raise ValueError("--iterations must be >= 1.")
    if args.connect is not None and args.num_gpus is not None:
        raise ValueError("--num-gpus cannot be used with --connect.")
    if args.frontend != "ray" and (
        args.connect is not None or args.num_gpus is not None
    ):
        # Only the ray branch of _engine_context reads these flags. Anywhere
        # else they would be silently ignored.
        raise ValueError(
            "--connect and --num-gpus are only supported with --frontend ray "
            f"in this microbenchmark (got --frontend {args.frontend!r})."
        )


@contextlib.contextmanager
def _engine_context(
    args: argparse.Namespace, planning_engine: pl.GPUEngine
) -> Iterator[Any]:
    """Yield the engine that executes the query for ``args.frontend``.

    - ``in-memory``: reuses *planning_engine*.
    - ``ray``: opens a ``RayEngine``.
    - ``spmd``: opens an ``SPMDEngine``.

    The ray and spmd engines are built from the same streaming CLI options
    and closed when the context exits.
    """
    if args.frontend == "in-memory":
        yield planning_engine
        return

    stream_options = StreamingOptions._from_argparse(args)
    executor_options = stream_options.to_executor_options()
    # The "cluster" entry is dropped before handing options to the
    # Ray/SPMD engines (presumably they configure the cluster themselves).
    executor_options.pop("cluster", None)
    engine_options = stream_options.to_engine_options()
    engine_options.setdefault("raise_on_fail", True)

    if args.frontend == "ray":
        from cudf_polars.engine.ray import RayEngine

        ray_init_options: dict[str, object] = {}
        if args.connect is not None:
            ray_init_options["address"] = args.connect
        if args.num_gpus is not None:
            ray_init_options["num_gpus"] = args.num_gpus
        # NOTE(review): Nsight Systems tracing is always requested for Ray
        # workers. This presumably requires nsys on every worker and may add
        # overhead to the timings. Confirm it is wanted for benchmark runs.
        ray_init_options["runtime_env"] = {
            "nsight": {
                "python-backtrace": "cuda",
                "python-sampling": "true",
                "trace": "cuda,osrt,nvtx,python-gil,ucx",
            }
        }
        with RayEngine(
            rapidsmpf_options=stream_options.to_rapidsmpf_options(),
            executor_options=executor_options,
            engine_options=engine_options,
            ray_init_options=ray_init_options,
        ) as engine:
            yield engine
        return

    if args.frontend == "spmd":
        from cudf_polars.engine.spmd import SPMDEngine

        with SPMDEngine(
            rapidsmpf_options=stream_options.to_rapidsmpf_options(),
            executor_options=executor_options,
            engine_options=engine_options,
        ) as engine:
            yield engine
        return

    raise AssertionError(f"Unexpected frontend {args.frontend!r}")


def run_benchmark(args: argparse.Namespace) -> None:
    """Run the shuffle-join benchmark described by *args* and print timings.

    Steps:

    1. Validate the CLI combination.
    2. Print the resolved executor options as JSON.
    3. Build the query and optionally report its lowered plan.
    4. Collect the query ``args.iterations`` times on the frontend's engine,
       printing each wall-clock duration and row count, then a
       min/max/mean summary.

    For ``--frontend spmd``, ranks other than 0 run the same collects but
    print no timings.

    Raises
    ------
    NotImplementedError
        For ``--frontend`` dask, duckdb or polars-cpu.
    ValueError
        For ``--iterations`` below 1, or inconsistent ``--connect`` /
        ``--num-gpus`` usage.
    """
    _check_benchmark_args(args)

    # NOTE: with SPMD every rank runs this setup, so the summary and the plan
    # check below are printed once per rank.
    planning_engine = create_engine(args)
    print("Resolved streaming executor options:")
    print(
        json.dumps(
            _executor_summary(planning_engine),
            default=str,
            indent=2,
            sort_keys=True,
        )
    )

    with build_query(args) as query:
        maybe_verify_plan(args, query, planning_engine)
        with _engine_context(args, planning_engine) as engine:
            if args.frontend == "spmd" and getattr(engine, "rank", 0) != 0:
                # Non-root ranks still execute the query but don't print
                # benchmark output.
                for _ in range(args.iterations):
                    query.collect(engine=engine)
                return

            durations: list[float] = []
            for i in range(args.iterations):
                annotation = (
                    contextlib.nullcontext()
                    if nvtx is None
                    else nvtx.annotate(
                        message=f"shuffle-join iteration {i}",
                        domain="cudf_polars",
                        color="green",
                    )
                )
                with annotation:
                    start = time.perf_counter()
                    result = query.collect(engine=engine)
                    durations.append(time.perf_counter() - start)
                rows = result.height
                # Release this result now so it is not held in memory while
                # the next iteration is being timed.
                del result
                print(f"Iteration {i}: {durations[-1]:.4f}s, rows={rows}")

            mean_duration = sum(durations) / len(durations)
            print(
                "Timing summary (seconds): "
                f"min={min(durations):.4f}, max={max(durations):.4f}, mean={mean_duration:.4f}"
            )
def main() -> None:
    """Command-line entry point: parse ``sys.argv`` and run the benchmark."""
    run_benchmark(build_parser().parse_args())


if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment