Skip to content

Instantly share code, notes, and snippets.

@youkaichao
Created May 18, 2025 05:11
Show Gist options
  • Save youkaichao/cc37500079dcc57db7a98b44cd17698a to your computer and use it in GitHub Desktop.
Save youkaichao/cc37500079dcc57db7a98b44cd17698a to your computer and use it in GitHub Desktop.
import torch
from torch.utils.cpp_extension import load_inline
# CUDA translation unit handed to nvcc. It defines two kernels with identical
# bodies (`computation_kernel` and `communication_kernel`): each sleeps for
# `total_nanosec` in `__nanosleep` chunks of at most 1'000'000 ns and then
# sleeps whatever remainder is left. The host wrappers `computation` and
# `communication` launch the matching kernel with one block of one thread on
# the stream returned by `at::cuda::getCurrentCUDAStream()`.
# NOTE(review): the two kernels presumably differ only by name so they can be
# told apart in a profiler timeline -- confirm with the author.
_CUDA_SOURCE = r"""
#include <cuda_runtime.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>
__global__ void computation_kernel(unsigned long long total_nanosec) {
const unsigned long long max_sleep = 1'000'000; // 1 ms per nanosleep
unsigned long long slept = 0;
while (slept + max_sleep <= total_nanosec) {
__nanosleep(max_sleep);
slept += max_sleep;
}
// Sleep the remainder, if any
if (slept < total_nanosec) {
__nanosleep(total_nanosec - slept);
}
}
__global__ void communication_kernel(unsigned long long total_nanosec) {
const unsigned long long max_sleep = 1'000'000; // 1 ms per nanosleep
unsigned long long slept = 0;
while (slept + max_sleep <= total_nanosec) {
__nanosleep(max_sleep);
slept += max_sleep;
}
// Sleep the remainder, if any
if (slept < total_nanosec) {
__nanosleep(total_nanosec - slept);
}
}
void computation(unsigned long long nanosec) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
computation_kernel<<<1, 1, 0, stream>>>(nanosec);
}
void communication(unsigned long long nanosec) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
communication_kernel<<<1, 1, 0, stream>>>(nanosec);
}
"""

# C++ side: forward declarations only, for the two host launchers defined in
# the CUDA source above.
_CPP_SOURCE = r"""
void computation(unsigned long long nanosec);
void communication(unsigned long long nanosec);
"""

# Inline extension sources, keyed by the language they are written in.
src = {
    "cuda": _CUDA_SOURCE,
    "cpp": _CPP_SOURCE,
}
# Build the inline extension and expose the two host launchers as Python
# callables: `nanosleep_module.computation` / `nanosleep_module.communication`.
# Both launchers enqueue onto the current CUDA stream (see the CUDA source).
nanosleep_module = load_inline(
    name="nanosleep_ext",
    # What to bind into the resulting Python module.
    functions=["computation", "communication"],
    # What to compile.
    cuda_sources=src["cuda"],
    cpp_sources=src["cpp"],
    with_cuda=True,
    # NOTE(review): sm_70 is presumably required because the kernels call
    # `__nanosleep` -- confirm the minimum compute capability in the CUDA
    # programming guide.
    extra_cuda_cflags=["-arch=sm_70"],
    verbose=False,
)
def one_giant_batch():
    """Enqueue the un-split workload: two computation/communication rounds.

    Each of the four kernels is asked to sleep for 200_000_000 ns, and they are
    issued back to back in the order computation, communication, computation,
    communication. The call only enqueues work; it does not synchronize.
    """
    full_batch_ns = 200_000_000
    for _ in range(2):
        nanosleep_module.computation(full_batch_ns)
        nanosleep_module.communication(full_batch_ns)
# Two side streams, created once at module level and reused by every call to
# microbatch_overlapping(): stream1 carries microbatch 0, stream2 microbatch 1.
stream1, stream2 = torch.cuda.Stream(), torch.cuda.Stream()
def microbatch_overlapping():
    """Enqueue the workload as two half-size microbatches on two side streams.

    The same four stages as the single-batch variant are issued (computation,
    communication, computation, communication), but each stage is launched
    twice with 100_000_000 ns: once on ``stream1`` (microbatch 0) and once on
    ``stream2`` (microbatch 1). Before launching a stage, ``stream2`` waits on
    an event recorded on ``stream1`` right after microbatch 0's launch of that
    same stage.

    Both side streams are forked from, and joined back into, whatever stream
    is current when the function is called. The call only enqueues work; it
    does not synchronize with the host.
    """
    main_stream = torch.cuda.current_stream()

    # Fork: neither side stream may start before the work already queued on
    # the current stream has completed.
    fork_point = torch.cuda.Event()
    fork_point.record(main_stream)
    for side_stream in (stream1, stream2):
        side_stream.wait_event(fork_point)

    half_batch_ns = 100_000_000
    stages = (nanosleep_module.computation, nanosleep_module.communication) * 2
    for stage in stages:
        # Microbatch 0: run the stage on stream1 and mark its completion.
        with torch.cuda.stream(stream1):
            stage(half_batch_ns)
            stage_done = torch.cuda.Event()
            stage_done.record(stream1)
        # Microbatch 1: run the same stage on stream2, but only once
        # microbatch 0 has finished it.
        with torch.cuda.stream(stream2):
            stream2.wait_event(stage_done)
            stage(half_batch_ns)

    # Join: the calling stream must not proceed until both side streams have
    # drained.
    for side_stream in (stream1, stream2):
        drained = torch.cuda.Event()
        drained.record(side_stream)
        main_stream.wait_event(drained)
import time


def _run_timed(label, workload):
    """Run ``workload`` inside an NVTX range and print its elapsed time.

    Args:
        label: Name used both for the NVTX range and for the printed line.
        workload: Zero-argument callable that enqueues GPU work on the
            current CUDA stream.

    The current stream is synchronized before the clock is stopped, so the
    reported time covers GPU execution and not just the launch overhead.
    ``time.perf_counter()`` is used instead of ``time.time()`` because it is
    a monotonic, high-resolution clock intended for measuring intervals.
    """
    with torch.cuda.nvtx.range(label):
        start = time.perf_counter()
        workload()
        torch.cuda.current_stream().synchronize()
        elapsed = time.perf_counter() - start
    print(f"{label} takes: {elapsed:.3f} sec")


# Each variant is run once untimed first (warm-up), then timed.
one_giant_batch()
torch.cuda.current_stream().synchronize()
_run_timed("one_giant_batch", one_giant_batch)

microbatch_overlapping()
torch.cuda.current_stream().synchronize()
_run_timed("microbatch_overlapping", microbatch_overlapping)

# Capture the two-stream variant into a CUDA graph, then time a replay.
# Debug mode is enabled before capture so the graph can be dumped afterwards.
graph = torch.cuda.CUDAGraph()
graph.enable_debug_mode()
with torch.cuda.graph(graph):
    microbatch_overlapping()
_run_timed("microbatch_overlapping (cudagraph mode)", graph.replay)

# Write the captured graph in Graphviz dot format for offline inspection.
graph.debug_dump("arch.dot")
@youkaichao
Copy link
Author

youkaichao commented May 18, 2025

Run with `nsys profile -o overlap python test.py`, then open the resulting `overlap.nsys-rep` in Nsight Systems to inspect the timeline. We can see:

one giant batch on the default stream

image

two microbatch overlapping with 2 streams (eager mode)

image

two microbatch overlapping with the default stream (cudagraph mode)

The overlapping implementation is compatible with CUDA graphs: we can capture it and replay it on the default stream.

image

We can also inspect the content of the cudagraph by displaying the arch.dot file (a dot graph) using https://dreampuf.github.io/GraphvizOnline :

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment