Skip to content

Instantly share code, notes, and snippets.

@mlaves
Created July 29, 2025 20:27
Show Gist options
  • Save mlaves/a8055c04367a0907b829b2c115223113 to your computer and use it in GitHub Desktop.
Save mlaves/a8055c04367a0907b829b2c115223113 to your computer and use it in GitHub Desktop.
Test grid_sampler_3d output accuracy between CPU and MPS implementations
import sys
import torch
# Seed torch's global RNG once at import time so the random volumes and grids
# drawn by the comparison helper are the same from run to run.
torch.random.manual_seed(42)
def test_grid_sampler3d_output(input_shape, grid_shape, interp, padding, align_corners, device="mps") -> tuple:
    """Compare ``torch.grid_sampler_3d`` on the CPU against another device.

    A random input volume and a random sampling grid (squashed into (-1, 1)
    with ``tanh``) are generated on the CPU. The op is evaluated on the CPU
    and on ``device`` with identical inputs, one summary row is printed, and
    the element-wise absolute difference statistics are returned.

    Args:
        input_shape: Shape passed to ``torch.randn`` for the input tensor,
            e.g. ``(N, C, D, H, W)``.
        grid_shape: Shape passed to ``torch.randn`` for the sampling grid,
            e.g. ``(N, D_out, H_out, W_out, 3)``.
        interp: Integer interpolation mode forwarded to
            ``torch.grid_sampler_3d`` (0: bilinear/trilinear, 1: nearest).
        padding: Integer padding mode forwarded to ``torch.grid_sampler_3d``
            (0: zeros, 1: border, 2: reflection).
        align_corners: Forwarded to ``torch.grid_sampler_3d``.
        device: Device whose result is compared with the CPU reference.
            Defaults to ``"mps"``.

    Returns:
        ``(ok, max_diff, mean_diff, sum_diff)``. ``ok`` is the result of
        ``torch.allclose`` with its default tolerances. The three diffs are
        0-dim tensors.
    """
    # Reference result on the CPU.
    t = torch.randn(input_shape)
    g = torch.randn(grid_shape).tanh()
    output_cpu = torch.grid_sampler_3d(t, g, interp, padding, align_corners)

    # Same inputs on the device under test; bring the result back once for comparison.
    t_dev = t.clone().detach().to(device)
    g_dev = g.clone().detach().to(device)
    output_dev = torch.grid_sampler_3d(t_dev, g_dev, interp, padding, align_corners).cpu()

    diff = (output_cpu - output_dev).abs()
    max_diff = diff.max()
    mean_diff = diff.mean()
    sum_diff = diff.sum()
    # `!s` makes the bool print as True/False; a bare width spec would format it as 1/0.
    print(f"{str(input_shape):21} {str(grid_shape):20} {interp} {padding:8} {align_corners!s:9} {max_diff:12.8f} {mean_diff:12.8f} {sum_diff:12.8f}")
    return torch.allclose(output_cpu, output_dev), max_diff, mean_diff, sum_diff


# This is a parameterized helper, not a pytest test: stop pytest from
# collecting it (and erroring on the missing fixtures) because of its name.
test_grid_sampler3d_output.__test__ = False
def test_output_accuracy():
    """Run the CPU-vs-MPS comparison over a matrix of shapes and sampler options.

    Every (input_shape, grid_shape) pair with matching batch sizes is run with
    every interpolation mode, padding mode and ``align_corners`` value through
    ``test_grid_sampler3d_output``. A summary is printed at the end.

    Returns:
        True only if every configuration ran without raising and matched within
        ``torch.allclose`` tolerances. Returns False if any configuration
        mismatched or raised an exception.
    """
    import itertools

    print("Grid Sampler 3D Output Accuracy Test (MPS vs CPU)")
    print("input_shape grid_shape interp padding align_corners max_diff mean_diff sum_diff")
    print("-" * 120)
    max_diffs = []
    mean_diffs = []
    sum_diffs = []
    failed_tests = []
    errored_tests = []
    # Test different input and grid shapes
    input_shapes = [
        (1, 3, 128, 128, 128),  # Large 3D volume
        (2, 1, 32, 32, 32),  # Smaller volume, batch size 2
        (1, 1, 1, 1, 1),  # Minimal case
        (1, 4, 16, 16, 16),  # Medium volume with 4 channels
    ]
    grid_shapes = [
        (1, 64, 64, 64, 3),  # Large grid
        (2, 16, 16, 16, 3),  # Medium grid, batch size 2
        (1, 1, 1, 1, 3),  # Minimal grid
        (1, 8, 8, 8, 3),  # Small grid
    ]
    interps = [0, 1]  # PyTorch enum: 0: bilinear (trilinear on 5-D input), 1: nearest
    paddings = [0, 1, 2]  # 0: zeros, 1: border, 2: reflection
    align_options = [False, True]

    for input_shape, grid_shape in itertools.product(input_shapes, grid_shapes):
        # Skip mismatched batch sizes
        if input_shape[0] != grid_shape[0]:
            continue
        for interp, padding, align_corners in itertools.product(interps, paddings, align_options):
            config = {
                'input_shape': input_shape,
                'grid_shape': grid_shape,
                'interp': interp,
                'padding': padding,
                'align_corners': align_corners,
            }
            try:
                ok, max_diff, mean_diff, sum_diff = test_grid_sampler3d_output(
                    input_shape, grid_shape, interp, padding, align_corners
                )
            except Exception as e:
                # Keep going so one unsupported configuration doesn't hide the
                # rest, but record it: an errored configuration is not a pass.
                print(f"ERROR: {input_shape} {grid_shape} {interp} {padding} {align_corners} - {str(e)}")
                errored_tests.append({**config, 'error': str(e)})
                continue
            if not ok:
                failed_tests.append({**config, 'max_diff': max_diff.item()})
            max_diffs.append(max_diff.item())
            mean_diffs.append(mean_diff.item())
            sum_diffs.append(sum_diff.item())

    print("-" * 120)
    print("Summary:")
    print(f" Total tests: {len(max_diffs) + len(errored_tests)}")
    print(f" Failed tests: {len(failed_tests)}")
    print(f" Errored tests: {len(errored_tests)}")
    print(f" Max difference overall: {max(max_diffs) if max_diffs else 0:.8f}")
    print(f" Mean difference overall: {sum(mean_diffs) / len(mean_diffs) if mean_diffs else 0:.8f}")
    # Total absolute difference across all completed configurations.
    print(f" Sum difference overall: {sum(sum_diffs):.8f}")
    if failed_tests:
        print("\nFailed tests (tolerance not met):")
        for test in failed_tests:
            print(f" {test}")
    if errored_tests:
        print("\nErrored tests (raised an exception):")
        for test in errored_tests:
            print(f" {test}")
    return not failed_tests and not errored_tests
def test_edge_cases():
    """Compare CPU and MPS on small and degenerate shapes.

    Each edge case is run with every interpolation mode, padding mode and
    ``align_corners`` value through ``test_grid_sampler3d_output``. Mismatches
    and exceptions are collected and listed at the end instead of being
    silently discarded.

    Returns:
        True if every edge case ran and matched within ``torch.allclose``
        tolerances, False otherwise.
    """
    import itertools

    print("\nEdge Cases Test")
    print("input_shape grid_shape interp padding align_corners max_diff mean_diff sum_diff")
    print("-" * 120)
    edge_cases = [
        # Very small tensors
        ((1, 1, 2, 2, 2), (1, 1, 1, 1, 3)),
        # Single depth slice
        ((1, 1, 1, 10, 10), (1, 1, 5, 5, 3)),
        # Large channel count
        ((1, 16, 8, 8, 8), (1, 4, 4, 4, 3)),
    ]
    interps = [0, 1]  # PyTorch enum: 0: bilinear (trilinear on 5-D input), 1: nearest
    paddings = [0, 1, 2]  # 0: zeros, 1: border, 2: reflection
    align_options = [False, True]

    failures = []
    for (input_shape, grid_shape), (interp, padding, align_corners) in itertools.product(
        edge_cases, itertools.product(interps, paddings, align_options)
    ):
        try:
            ok, _, _, _ = test_grid_sampler3d_output(
                input_shape, grid_shape, interp, padding, align_corners
            )
        except Exception as e:
            print(f"ERROR: {input_shape} {grid_shape} {interp} {padding} {align_corners} - {str(e)}")
            failures.append((input_shape, grid_shape, interp, padding, align_corners, str(e)))
            continue
        if not ok:
            failures.append((input_shape, grid_shape, interp, padding, align_corners, "tolerance not met"))

    if failures:
        print(f"\nEdge case failures: {len(failures)}")
        for failure in failures:
            print(f" {failure}")
    return not failures
if __name__ == "__main__":
    # Test if MPS is available
    if not torch.backends.mps.is_available():
        print("MPS not available. This test requires MPS support.")
        sys.exit(1)
    success = test_output_accuracy()
    # Fold the edge-case result into the verdict. Only an explicit False counts
    # as a failure, so a None result (no verdict) does not fail the run.
    edge_success = test_edge_cases() is not False
    if success and edge_success:
        print("\nAll tests passed!")
    else:
        print("\nSome tests failed!")
        sys.exit(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment