import time

import torch
from torch.utils.cpp_extension import load_inline
# CUDA and C++ sources for two sleep kernels of a known duration that
# stand in for real computation and communication kernels.
src = {
    "cuda": r"""
#include <cuda_runtime.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>

__global__ void computation_kernel(unsigned long long total_nanosec) {
    const unsigned long long max_sleep = 1'000'000; // 1 ms per __nanosleep call
    unsigned long long slept = 0;
    while (slept + max_sleep <= total_nanosec) {
        __nanosleep(max_sleep);
        slept += max_sleep;
    }
    // Sleep the remainder, if any
    if (slept < total_nanosec) {
        __nanosleep(total_nanosec - slept);
    }
}

__global__ void communication_kernel(unsigned long long total_nanosec) {
    const unsigned long long max_sleep = 1'000'000; // 1 ms per __nanosleep call
    unsigned long long slept = 0;
    while (slept + max_sleep <= total_nanosec) {
        __nanosleep(max_sleep);
        slept += max_sleep;
    }
    // Sleep the remainder, if any
    if (slept < total_nanosec) {
        __nanosleep(total_nanosec - slept);
    }
}

// Launch on PyTorch's current CUDA stream so the kernels respect
// whatever stream ordering is set up from Python.
void computation(unsigned long long nanosec) {
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    computation_kernel<<<1, 1, 0, stream>>>(nanosec);
}

void communication(unsigned long long nanosec) {
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    communication_kernel<<<1, 1, 0, stream>>>(nanosec);
}
""",
    "cpp": r"""
void computation(unsigned long long nanosec);
void communication(unsigned long long nanosec);
""",
}
# Load the inline extension with stream-aware kernel launchers.
# __nanosleep requires compute capability 7.0+, hence -arch=sm_70.
nanosleep_module = load_inline(
    name="nanosleep_ext",
    cpp_sources=src["cpp"],
    cuda_sources=src["cuda"],
    functions=["computation", "communication"],
    extra_cuda_cflags=["-arch=sm_70"],
    verbose=False,
    with_cuda=True,
)
def one_giant_batch():
    # Everything back-to-back on the current stream: 4 x 200 ms, no overlap.
    nanosleep_module.computation(200_000_000)
    nanosleep_module.communication(200_000_000)
    nanosleep_module.computation(200_000_000)
    nanosleep_module.communication(200_000_000)
stream1 = torch.cuda.Stream()
stream2 = torch.cuda.Stream()

def microbatch_overlapping():
    # Fork stream1 and stream2 from the current stream.
    stream = torch.cuda.current_stream()
    event = torch.cuda.Event()
    event.record(stream)
    stream1.wait_event(event)
    stream2.wait_event(event)
    # stream2 will wait for stream1 at each step
    stream1_event = None
    funcs = [
        nanosleep_module.computation,
        nanosleep_module.communication,
        nanosleep_module.computation,
        nanosleep_module.communication,
    ]
    for func in funcs:
        with torch.cuda.stream(stream1):
            # microbatch 0
            func(100_000_000)
            stream1_event = torch.cuda.Event()
            stream1_event.record(stream1)
        with torch.cuda.stream(stream2):
            # microbatch 1: start this step only after microbatch 0
            # has finished the same step on stream1
            stream2.wait_event(stream1_event)
            func(100_000_000)
    # Join both streams back into the main stream.
    e = torch.cuda.Event()
    e.record(stream1)
    stream.wait_event(e)
    e = torch.cuda.Event()
    e.record(stream2)
    stream.wait_event(e)
# Warm up once, then time each variant inside an NVTX range for profiling.
one_giant_batch()
torch.cuda.current_stream().synchronize()

with torch.cuda.nvtx.range("one_giant_batch"):
    start = time.time()
    one_giant_batch()
    torch.cuda.current_stream().synchronize()
    end = time.time()
print(f"one_giant_batch takes: {end - start:.3f} sec")

microbatch_overlapping()
torch.cuda.current_stream().synchronize()

with torch.cuda.nvtx.range("microbatch_overlapping"):
    start = time.time()
    microbatch_overlapping()
    torch.cuda.current_stream().synchronize()
    end = time.time()
print(f"microbatch_overlapping takes: {end - start:.3f} sec")
# Capture the overlapped schedule into a CUDA graph and replay it.
graph = torch.cuda.CUDAGraph()
graph.enable_debug_mode()
with torch.cuda.graph(graph):
    microbatch_overlapping()

with torch.cuda.nvtx.range("microbatch_overlapping (cudagraph mode)"):
    start = time.time()
    graph.replay()
    torch.cuda.current_stream().synchronize()
    end = time.time()
print(f"microbatch_overlapping (cudagraph mode) takes: {end - start:.3f} sec")

# Dump the captured graph topology to a Graphviz dot file.
graph.debug_dump("arch.dot")
Run with

nsys profile -o overlap python test.py

and open the resulting `overlap.nsys-rep` in Nsight Systems. In the timeline we can see:

- one giant batch on the default stream
- two microbatches overlapping across 2 streams (eager mode)
- two microbatches overlapping on the default stream (cudagraph mode)
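As a side note, kernel timings can also be summarized without the GUI via `nsys stats` (report names vary across nsys versions; `cuda_gpu_kern_sum` is the name in recent releases):

nsys stats --report cuda_gpu_kern_sum overlap.nsys-rep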
The overlapping implementation is compatible with CUDA graphs: because the side streams fork from and join back into the capture stream via events, we can capture the whole schedule and replay it on the default stream.
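Since a captured graph replays with a single launch, a natural follow-up (a sketch, not part of the original script; `n_iters` is made up here) is to amortize the timing over several replays:

n_iters = 5
start = time.time()
for _ in range(n_iters):
    graph.replay()
torch.cuda.current_stream().synchronize()
end = time.time()
print(f"avg replay time: {(end - start) / n_iters:.3f} sec")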
We can also inspect the content of the cudagraph by displaying the `arch.dot` file (a Graphviz dot file) with https://dreampuf.github.io/GraphvizOnline :
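Alternatively (assuming the Graphviz CLI is installed; `dot` is standard Graphviz, not part of the gist), the same file can be rendered locally:

dot -Tsvg arch.dot -o arch.svg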