Created July 29, 2025 20:25
-
-
Save mlaves/7343c1a445a7c2b0bd23f23c655f9ab9 to your computer and use it in GitHub Desktop.
Test gradient of grid_sampler_3d for MPS
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch

# Seed the global RNG once at import time so every run draws the same random
# inputs and grids, making printed diffs comparable between runs.
torch.manual_seed(seed=42)
def test_grid_sampler3d_grad_input(
    input_shape: tuple[int, ...],
    grid_shape: tuple[int, ...],
    interp: int,
    padding: int,
    align_corners: bool,
    *,
    device: str = "mps",
    rtol: float = 1e-5,
    atol: float = 1e-8,
) -> tuple[bool, torch.Tensor]:
    """Compare the gradient of ``torch.grid_sampler_3d`` w.r.t. its input on ``device`` vs. CPU.

    A random input and a random grid squashed into (-1, 1) by ``tanh`` are drawn on
    the CPU. The mean of the sampled output is back-propagated once on the CPU
    (reference) and once on ``device`` using identical copies of both tensors. Only
    the input requires grad; the grid is held constant.

    Args:
        input_shape: shape passed to ``torch.randn`` for the sampled input.
        grid_shape: shape passed to ``torch.randn`` for the sampling grid.
        interp: integer interpolation mode forwarded to ``torch.grid_sampler_3d``.
        padding: integer padding mode forwarded to ``torch.grid_sampler_3d``.
        align_corners: forwarded to ``torch.grid_sampler_3d``.
        device: device under test (default ``"mps"``); the reference always runs on CPU.
        rtol: relative tolerance forwarded to ``torch.allclose``.
        atol: absolute tolerance forwarded to ``torch.allclose``.

    Returns:
        ``(ok, diff)``. ``ok`` is whether the two input gradients are
        ``torch.allclose``. ``diff`` is a 0-dim tensor holding the sum of absolute
        element-wise differences.
    """
    # Reference pass on the CPU.
    t = torch.randn(input_shape, device="cpu").requires_grad_(True)
    g = torch.randn(grid_shape, device="cpu").tanh()
    torch.grid_sampler_3d(t, g, interp, padding, align_corners).mean().backward()

    # Identical copies on the device under test; clone() keeps them independent of
    # the CPU tensors even when device == "cpu".
    t_dev = t.detach().clone().to(device).requires_grad_(True)
    g_dev = g.detach().clone().to(device)
    torch.grid_sampler_3d(t_dev, g_dev, interp, padding, align_corners).mean().backward()

    ref = t.grad
    got = t_dev.grad.cpu()
    diff = (ref - got).abs().sum()
    # `!s` so the bool prints as True/False; a bare width spec formats bools as 0/1.
    print(f"{str(input_shape):21} {str(grid_shape):20} {interp} {padding:8} {align_corners!s:9} {diff.item():24}")
    return torch.allclose(ref, got, rtol=rtol, atol=atol), diff


# Parameterized helper driven by inp(), not a pytest test. Without this flag pytest
# collects it because of the `test_` prefix and errors on the missing fixtures.
test_grid_sampler3d_grad_input.__test__ = False
def test_grid_sampler3d_grad_grid(
    input_shape: tuple[int, ...],
    grid_shape: tuple[int, ...],
    interp: int,
    padding: int,
    align_corners: bool,
    *,
    device: str = "mps",
    rtol: float = 1e-5,
    atol: float = 1e-8,
) -> tuple[bool, torch.Tensor]:
    """Compare the gradient of ``torch.grid_sampler_3d`` w.r.t. its grid on ``device`` vs. CPU.

    A random input and a random grid squashed into (-1, 1) by ``tanh`` are drawn on
    the CPU. The mean of the sampled output is back-propagated once on the CPU
    (reference) and once on ``device`` using identical copies of both tensors. Only
    the grid requires grad; the input is held constant.

    Args:
        input_shape: shape passed to ``torch.randn`` for the sampled input.
        grid_shape: shape passed to ``torch.randn`` for the sampling grid.
        interp: integer interpolation mode forwarded to ``torch.grid_sampler_3d``.
        padding: integer padding mode forwarded to ``torch.grid_sampler_3d``.
        align_corners: forwarded to ``torch.grid_sampler_3d``.
        device: device under test (default ``"mps"``); the reference always runs on CPU.
        rtol: relative tolerance forwarded to ``torch.allclose``.
        atol: absolute tolerance forwarded to ``torch.allclose``.

    Returns:
        ``(ok, diff)``. ``ok`` is whether the two grid gradients are
        ``torch.allclose``. ``diff`` is a 0-dim tensor holding the sum of absolute
        element-wise differences.
    """
    # Reference pass on the CPU. tanh() output is still a leaf here because its
    # input does not require grad, so requires_grad_ is allowed on it.
    t = torch.randn(input_shape, device="cpu")
    g = torch.randn(grid_shape, device="cpu").tanh().requires_grad_(True)
    torch.grid_sampler_3d(t, g, interp, padding, align_corners).mean().backward()

    # Identical copies on the device under test; clone() keeps them independent of
    # the CPU tensors even when device == "cpu".
    t_dev = t.detach().clone().to(device)
    g_dev = g.detach().clone().to(device).requires_grad_(True)
    torch.grid_sampler_3d(t_dev, g_dev, interp, padding, align_corners).mean().backward()

    ref = g.grad
    got = g_dev.grad.cpu()
    diff = (ref - got).abs().sum()
    # `!s` so the bool prints as True/False; a bare width spec formats bools as 0/1.
    print(f"{str(input_shape):21} {str(grid_shape):20} {interp} {padding:8} {align_corners!s:9} {diff.item():24}")
    return torch.allclose(ref, got, rtol=rtol, atol=atol), diff


# Parameterized helper driven by grid(), not a pytest test. Without this flag pytest
# collects it because of the `test_` prefix and errors on the missing fixtures.
test_grid_sampler3d_grad_grid.__test__ = False
def inp():
    """Sweep input-gradient comparisons of ``grid_sampler_3d`` (MPS vs. CPU).

    Runs ``test_grid_sampler3d_grad_input`` over every combination of input shape,
    grid shape, interpolation mode (0, 1), padding mode (0, 1) and ``align_corners``.
    The modes are presumably bilinear/nearest and zeros/border per PyTorch's
    grid-sampler enums; confirm against the torch version in use. Prints one row per
    combination, then the largest summed absolute difference.

    Raises:
        AssertionError: on the first combination whose gradients are not allclose.
            Raised explicitly, not via ``assert``, so it still fires under ``python -O``.
    """
    from itertools import product

    print("grad wrt. input")
    print("input_shape grid_shape interp padding align_corners diff")
    diffs = []
    # product() varies the last axis fastest, matching the original nested-loop order.
    sweep = product(
        [(1, 3, 128, 128, 128), (1, 1, 1, 1, 1)],  # input_shape
        [(1, 64, 64, 64, 3), (1, 1, 1, 1, 3)],  # grid_shape
        [0, 1],  # interp
        [0, 1],  # padding
        [False, True],  # align_corners
    )
    for params in sweep:
        ok, diff = test_grid_sampler3d_grad_input(*params)
        if not ok:
            raise AssertionError(
                "input gradients differ for (input_shape, grid_shape, interp, "
                f"padding, align_corners)={params}: diff={diff.item()}"
            )
        diffs.append(diff.item())
    print("max diff for input gradients: ", max(diffs))
def grid():
    """Sweep grid-gradient comparisons of ``grid_sampler_3d`` (MPS vs. CPU).

    Runs ``test_grid_sampler3d_grad_grid`` over every combination of input shape,
    grid shape, interpolation mode (0, 1), padding mode (0, 1) and ``align_corners``.
    The modes are presumably bilinear/nearest and zeros/border per PyTorch's
    grid-sampler enums; confirm against the torch version in use. Prints one row per
    combination, then the largest summed absolute difference.

    Raises:
        AssertionError: on the first combination whose gradients are not allclose.
            Raised explicitly, not via ``assert``, so it still fires under ``python -O``.
    """
    from itertools import product

    print("grad wrt. grid")
    print("input_shape grid_shape interp padding align_corners diff")
    diffs = []
    # product() varies the last axis fastest, matching the original nested-loop order.
    sweep = product(
        [(1, 3, 128, 128, 128), (1, 1, 1, 1, 1)],  # input_shape
        [(1, 64, 64, 64, 3), (1, 1, 1, 1, 3)],  # grid_shape
        [0, 1],  # interp
        [0, 1],  # padding
        [False, True],  # align_corners
    )
    for params in sweep:
        ok, diff = test_grid_sampler3d_grad_grid(*params)
        if not ok:
            raise AssertionError(
                "grid gradients differ for (input_shape, grid_shape, interp, "
                f"padding, align_corners)={params}: diff={diff.item()}"
            )
        diffs.append(diff.item())
    print("max diff for grid gradients: ", max(diffs))
if __name__ == "__main__":
    # The sweeps copy tensors to "mps"; check up front so a machine without the MPS
    # backend gets a clear message instead of failing on the first .to("mps") call.
    if not torch.backends.mps.is_available():
        raise SystemExit("MPS backend is not available; this script compares MPS gradients against CPU.")
    inp()
    grid()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.