Skip to content

Instantly share code, notes, and snippets.

@mlaves
Created July 29, 2025 20:25
Show Gist options
Save mlaves/7343c1a445a7c2b0bd23f23c655f9ab9 to your computer and use it in GitHub Desktop.
Test the gradients of grid_sampler_3d on MPS against the CPU reference
import torch
# Seed the global RNG once at import so the random volumes and grids, and
# therefore the printed diffs, are the same from run to run.
torch.manual_seed(seed=42)
def _grad_abs_diff(input_shape, grid_shape, interp, padding, align_corners, wrt, device):
    """Compare CPU and ``device`` gradients of ``torch.grid_sampler_3d(...).mean()``.

    A random volume ``t`` and a random grid ``g`` (squashed into [-1, 1] by
    ``tanh``) are drawn on the CPU. The scalar loss is back-propagated once on
    the CPU and once on ``device`` using copies of the same data. The gradient
    selected by ``wrt`` is then compared between the two runs, and one
    formatted result row is printed.

    Args:
        input_shape: Shape passed to ``torch.randn`` for the sampled volume.
        grid_shape: Shape passed to ``torch.randn`` for the sampling grid.
        interp: Interpolation mode code forwarded to ``torch.grid_sampler_3d``.
        padding: Padding mode code forwarded to ``torch.grid_sampler_3d``.
        align_corners: Forwarded to ``torch.grid_sampler_3d``.
        wrt: ``"input"`` to differentiate w.r.t. the volume, ``"grid"`` to
            differentiate w.r.t. the grid.
        device: Device whose gradient is checked against the CPU reference.

    Returns:
        ``(ok, diff)``: ``ok`` is ``torch.allclose`` (default tolerances) of the
        two gradients, and ``diff`` is the 0-dim tensor sum of their absolute
        differences.

    Raises:
        ValueError: If ``wrt`` is neither ``"input"`` nor ``"grid"``.
    """
    if wrt not in ("input", "grid"):
        raise ValueError(f"wrt must be 'input' or 'grid', got {wrt!r}")
    wrt_input = wrt == "input"

    # CPU reference pass. t is drawn before g so seeded runs see the same data.
    t = torch.randn(input_shape).to("cpu").requires_grad_(wrt_input)
    g = torch.randn(grid_shape).to("cpu").tanh().requires_grad_(not wrt_input)
    torch.grid_sampler_3d(t, g, interp, padding, align_corners).mean().backward()

    # Pass under test on independent leaf copies: clone().detach() shares
    # neither storage nor autograd history with the reference tensors.
    t_dev = t.clone().detach().to(device).requires_grad_(wrt_input)
    g_dev = g.clone().detach().to(device).requires_grad_(not wrt_input)
    torch.grid_sampler_3d(t_dev, g_dev, interp, padding, align_corners).mean().backward()

    ref_grad = t.grad if wrt_input else g.grad
    dev_grad = (t_dev.grad if wrt_input else g_dev.grad).cpu()
    diff = (ref_grad - dev_grad).abs().sum()
    # str(align_corners) makes the column read True/False; a bool formatted
    # with only a width spec is rendered as the integer 1/0.
    print(
        f"{str(input_shape):21} {str(grid_shape):20} {interp} {padding:8} "
        f"{str(align_corners):9} {diff:24}"
    )
    return torch.allclose(ref_grad, dev_grad), diff


def test_grid_sampler3d_grad_input(
    input_shape, grid_shape, interp, padding, align_corners, device="mps"
) -> "tuple[bool, torch.Tensor]":
    """Check the gradient w.r.t. the input volume on ``device`` against the CPU.

    The arguments have the same meaning as in ``_grad_abs_diff``; ``device``
    defaults to ``"mps"``.

    Returns:
        ``(ok, diff)``: whether the two gradients are allclose, and the 0-dim
        tensor sum of their absolute differences.
    """
    return _grad_abs_diff(
        input_shape, grid_shape, interp, padding, align_corners, "input", device
    )


def test_grid_sampler3d_grad_grid(
    input_shape, grid_shape, interp, padding, align_corners, device="mps"
) -> "tuple[bool, torch.Tensor]":
    """Check the gradient w.r.t. the sampling grid on ``device`` against the CPU.

    The arguments have the same meaning as in ``_grad_abs_diff``; ``device``
    defaults to ``"mps"``.

    Returns:
        ``(ok, diff)``: whether the two gradients are allclose, and the 0-dim
        tensor sum of their absolute differences.
    """
    return _grad_abs_diff(
        input_shape, grid_shape, interp, padding, align_corners, "grid", device
    )


# These are parameterised helpers driven by inp()/grid(), not pytest tests.
# Opt them out of pytest collection, which would otherwise try to resolve
# their positional parameters as fixtures and fail.
test_grid_sampler3d_grad_input.__test__ = False
test_grid_sampler3d_grad_grid.__test__ = False
def _sweep(label, check):
    """Run ``check`` on every shape/mode combination and report the largest diff.

    Args:
        label: ``"input"`` or ``"grid"``; used in the printed headings and in
            the failure message.
        check: Called as ``check(input_shape, grid_shape, interp, padding,
            align_corners)``; must return ``(ok, diff)`` where ``diff`` is a
            0-dim tensor.

    Raises:
        AssertionError: On the first combination for which ``ok`` is false.
    """
    import itertools

    print(f"grad wrt. {label}")
    print("input_shape grid_shape interp padding align_corners diff")
    # product() yields cases in the same order as nested loops with
    # input_shape outermost and align_corners innermost.
    cases = itertools.product(
        [(1, 3, 128, 128, 128), (1, 1, 1, 1, 1)],  # input_shape
        [(1, 64, 64, 64, 3), (1, 1, 1, 1, 3)],  # grid_shape
        [0, 1],  # interp
        [0, 1],  # padding
        [False, True],  # align_corners
    )
    diffs = []
    for case in cases:
        ok, diff = check(*case)
        if not ok:
            # Explicit raise: unlike `assert False`, it survives `python -O`,
            # and it reports which case failed.
            raise AssertionError(
                f"{label} gradients do not match for (input_shape, grid_shape, "
                f"interp, padding, align_corners) = {case}; "
                f"sum |diff| = {diff.item()}"
            )
        diffs.append(diff.item())
    print(f"max diff for {label} gradients: ", max(diffs))


def inp():
    """Compare CPU and MPS gradients w.r.t. the input volume over all cases."""
    _sweep("input", test_grid_sampler3d_grad_input)


def grid():
    """Compare CPU and MPS gradients w.r.t. the sampling grid over all cases."""
    _sweep("grid", test_grid_sampler3d_grad_grid)
if __name__ == "__main__":
    # Input-volume gradients first, then sampling-grid gradients.
    for run_sweep in (inp, grid):
        run_sweep()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment