@mlaves
Created July 29, 2025 21:14
Benchmark grid_sampler_3d on MPS vs. CPU
import torch
from time import perf_counter


def benchmark_grid_sampler_3d(device, input_shape, grid_shape, interp=0, padding=0,
                              align_corners=False, num_warmup=5, num_runs=20):
    input = torch.randn(input_shape, dtype=torch.float32).to(device)
    grid = torch.randn(grid_shape, dtype=torch.float32).to(device)

    # Warm-up runs so kernel compilation and caching do not skew the measurement.
    for _ in range(num_warmup):
        _ = torch.grid_sampler_3d(input, grid, interp, padding, align_corners)

    # Synchronize before starting the timer so pending warm-up work on the MPS
    # queue is not counted in the measured window.
    torch.mps.synchronize()
    start_time = perf_counter()
    for _ in range(num_runs):
        _ = torch.grid_sampler_3d(input, grid, interp, padding, align_corners)
    torch.mps.synchronize()
    end_time = perf_counter()

    run_time = (end_time - start_time) * 1000  # total elapsed time in milliseconds
    return run_time / num_runs  # average time per run


def main():
    for batch_size in [1, 2, 4, 8, 16]:
        input_shape = (batch_size, 3, 128, 128, 128)
        grid_shape = (batch_size, 64, 64, 64, 3)
        run_time_cpu = benchmark_grid_sampler_3d("cpu", input_shape, grid_shape)
        run_time_mps = benchmark_grid_sampler_3d("mps", input_shape, grid_shape)
        print(f"Batch size: {batch_size:2}, CPU time: {run_time_cpu:.6f} ms, "
              f"MPS time: {run_time_mps:.6f} ms, speedup: {run_time_cpu / run_time_mps:.2f}x")


if __name__ == "__main__":
    main()
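
For context (not part of the original gist): torch.grid_sampler_3d is the low-level ATen op behind the public torch.nn.functional.grid_sample API, and the integer arguments used above, interp=0 and padding=0, are assumed to correspond to mode="bilinear" and padding_mode="zeros". A minimal sketch of the equivalent call through the public API, under that assumption:

import torch
import torch.nn.functional as F

inp = torch.randn(1, 3, 16, 16, 16)        # (N, C, D, H, W) volume
grid = torch.rand(1, 8, 8, 8, 3) * 2 - 1   # normalized sampling coordinates in [-1, 1]

# Assumed enum mapping: interp=0 -> "bilinear", padding=0 -> "zeros".
out_public = F.grid_sample(inp, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
out_lowlevel = torch.grid_sampler_3d(inp, grid, 0, 0, False)

print(torch.allclose(out_public, out_lowlevel))  # expected: True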