Skip to content

Instantly share code, notes, and snippets.

@TomAugspurger
Created November 14, 2025 20:03
Show Gist options
  • Select an option

  • Save TomAugspurger/d5b0d3b0e5765e448aa07a4fcc706171 to your computer and use it in GitHub Desktop.

Select an option

Save TomAugspurger/d5b0d3b0e5765e448aa07a4fcc706171 to your computer and use it in GitHub Desktop.
import numpy as np
import nvtx
import pyarrow as pa
import pylibcudf as plc
import rmm.mr
import rmm.pylibrmm.stream

import rapidsmpf
import rapidsmpf.buffer.buffer
import rapidsmpf.buffer.resource
import rapidsmpf.communicator.single
import rapidsmpf.config
import rapidsmpf.integrations.cudf.partition
import rapidsmpf.progress_thread
import rapidsmpf.rmm_resource_adaptor
import rapidsmpf.shuffler
import rapidsmpf.statistics
def make_table(
    nbytes: int = 10 * 1024 * 1024,
    num_distinct: int = 12,
) -> plc.Table:
    """Build a single-column int64 ``plc.Table`` of roughly ``nbytes`` bytes.

    Column ``"a"`` holds the values ``0, 1, ..., num_distinct - 1`` repeated
    cyclically, so the column contains at most ``num_distinct`` distinct keys.
    Calling with no arguments reproduces the original fixed-size table
    (10 MiB of int64 data, keys modulo 12).

    Parameters
    ----------
    nbytes
        Target size of the column data in bytes. The row count is
        ``nbytes // 8`` because every value is an 8-byte int64.
    num_distinct
        Number of distinct key values; must be positive.

    Returns
    -------
    plc.Table
        A pylibcudf table converted from the equivalent pyarrow table.
    """
    num_rows = nbytes // np.dtype(np.int64).itemsize
    # Pin the dtype: np.arange's default integer type is platform-dependent,
    # while the row-count computation above assumes 8-byte elements.
    keys = np.arange(num_rows, dtype=np.int64) % num_distinct
    return plc.Table.from_arrow(pa.Table.from_pydict({"a": keys}))
def main() -> None:
    """Run a single-rank rapidsmpf shuffle over eight tables and print stats.

    Each input table is hash-partitioned on column 0 into ``len(tables)``
    partitions on its own RMM stream. The packed partitions are inserted into
    a ``rapidsmpf.shuffler.Shuffler``, every partition is marked finished, and
    each output partition is extracted once it is complete. The collected
    statistics report is printed at the end.

    The buffer resource limits available device memory to 1 MiB, far below
    the size of one input table. NOTE(review): this is presumably meant to
    force the shuffler to spill; confirm that is the intended scenario.
    """
    opts = rapidsmpf.config.Options()
    comm = rapidsmpf.communicator.single.new_communicator(opts)
    stats = rapidsmpf.statistics.Statistics(enable=True)

    tables = [make_table() for _ in range(8)]
    num_partitions = len(tables)
    # One stream per input table. Presumably this lets the partition/pack
    # work for different tables overlap; TODO confirm.
    streams = [rmm.pylibrmm.stream.Stream() for _ in range(num_partitions)]

    mr = rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor(
        rmm.mr.CudaAsyncMemoryResource()
    )
    br = rapidsmpf.buffer.resource.BufferResource(
        mr,
        memory_available={
            rapidsmpf.buffer.buffer.MemoryType.DEVICE: (
                rapidsmpf.buffer.resource.LimitAvailableMemory(mr, 1024 * 1024)
            ),
        },
    )
    progress_thread = rapidsmpf.progress_thread.ProgressThread(comm, stats)
    shuffler = rapidsmpf.shuffler.Shuffler(
        comm=comm,
        progress_thread=progress_thread,
        op_id=0,
        total_num_partitions=num_partitions,
        br=br,
        statistics=stats,
    )

    try:
        for i, (table, stream) in enumerate(zip(tables, streams)):
            partitioned_and_packed = (
                rapidsmpf.integrations.cudf.partition.partition_and_pack(
                    table,
                    columns_to_hash=[0],
                    num_partitions=num_partitions,
                    br=br,
                    stream=stream,
                )
            )
            with nvtx.annotate("insert_chunks", payload=i):
                shuffler.insert_chunks(partitioned_and_packed)

        # Signal that no more chunks will be inserted for any partition.
        for pid in range(num_partitions):
            shuffler.insert_finished(pid)

        # Shuffler.extract() may return only the chunks received so far.
        # Wait for each partition to be fully ready before extracting it, so
        # every chunk is actually shuffled and accounted for.
        while not shuffler.finished():
            pid = shuffler.wait_any()
            shuffler.extract(pid)
    finally:
        # Shut the shuffler down even if partitioning or extraction raised.
        shuffler.shutdown()

    print(stats.report())


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment