Created
July 29, 2025 20:27
-
-
Save mlaves/a8055c04367a0907b829b2c115223113 to your computer and use it in GitHub Desktop.
Test grid_sampler_3d output accuracy between CPU and MPS implementations
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import sys | |
| import torch | |
# Seed the global PyTorch RNG once at import time so the random inputs and
# grids generated below (and therefore the printed differences) are reproducible.
torch.manual_seed(seed=42)
def test_grid_sampler3d_output(input_shape, grid_shape, interp, padding, align_corners,
                               *, device="mps") -> tuple:
    """Test grid_sampler_3d output accuracy between CPU and another device.

    A random input volume and a random sampling grid are generated on the CPU.
    ``torch.grid_sampler_3d`` is run on the CPU (reference) and on ``device``,
    and the absolute element-wise differences are printed as one table row.

    Args:
        input_shape: Shape passed to ``torch.randn`` for the input volume.
        grid_shape: Shape passed to ``torch.randn`` for the sampling grid.
        interp: Interpolation-mode enum forwarded to ``torch.grid_sampler_3d``
            (PyTorch: 0 = bilinear/trilinear, 1 = nearest).
        padding: Padding-mode enum (0 = zeros, 1 = border, 2 = reflection).
        align_corners: Forwarded unchanged to ``torch.grid_sampler_3d``.
        device: Device whose result is compared to the CPU reference.
            Defaults to ``"mps"``. Passing ``"cpu"`` is a useful sanity check.

    Returns:
        ``(ok, max_diff, mean_diff, sum_diff)``. ``ok`` is the result of
        ``torch.allclose`` with default tolerances. The diffs are 0-dim tensors.
    """
    # CPU reference computation. tanh squashes grid coordinates into (-1, 1).
    t = torch.randn(input_shape)
    g = torch.randn(grid_shape).tanh()
    output_cpu = torch.grid_sampler_3d(t, g, interp, padding, align_corners)

    # Device computation on identical data. A cross-device .to() already copies,
    # so no clone()/detach() is needed.
    output_dev = torch.grid_sampler_3d(
        t.to(device), g.to(device), interp, padding, align_corners
    )
    # Transfer back once and reuse it for both the diff and allclose.
    output_dev = output_dev.cpu()

    # Compare outputs
    diff = (output_cpu - output_dev).abs()
    max_diff = diff.max()
    mean_diff = diff.mean()
    sum_diff = diff.sum()
    # `!s` is required: a bool with a width spec is formatted as an int (1/0).
    print(f"{str(input_shape):21} {str(grid_shape):20} {interp} {padding:8} {align_corners!s:9} {max_diff:12.8f} {mean_diff:12.8f} {sum_diff:12.8f}")
    return torch.allclose(output_cpu, output_dev), max_diff, mean_diff, sum_diff


# This is a parameterized helper, not a pytest test. Without this, pytest would
# collect it (because of the `test_` prefix) and error on missing fixtures.
test_grid_sampler3d_output.__test__ = False
def test_output_accuracy():
    """Test output accuracy across different configurations.

    Pairs every input shape with every grid shape, skipping pairs whose batch
    sizes differ. Each pair runs with every interpolation mode, padding mode and
    ``align_corners`` setting. One row is printed per case, then a summary.

    Returns:
        True only if every case ran without raising and passed
        ``torch.allclose``; False otherwise.
    """
    import itertools

    print("Grid Sampler 3D Output Accuracy Test (MPS vs CPU)")
    print("input_shape grid_shape interp padding align_corners max_diff mean_diff sum_diff")
    print("-" * 120)

    max_diffs = []
    mean_diffs = []
    sum_diffs = []
    failed_tests = []  # ran, but torch.allclose() rejected the result
    errored_tests = []  # raised before a result could be compared

    # Test different input and grid shapes
    input_shapes = [
        (1, 3, 128, 128, 128),  # Large 3D volume
        (2, 1, 32, 32, 32),  # Smaller volume, batch size 2
        (1, 1, 1, 1, 1),  # Minimal case
        (1, 4, 16, 16, 16),  # Medium volume with 4 channels
    ]
    grid_shapes = [
        (1, 64, 64, 64, 3),  # Large grid
        (2, 16, 16, 16, 3),  # Medium grid, batch size 2
        (1, 1, 1, 1, 3),  # Minimal grid
        (1, 8, 8, 8, 3),  # Small grid
    ]
    interp_modes = [0, 1]  # PyTorch enum: 0 = bilinear (trilinear in 3-D), 1 = nearest
    padding_modes = [0, 1, 2]  # 0 = zeros, 1 = border, 2 = reflection
    align_corners_options = [False, True]

    for input_shape, grid_shape in itertools.product(input_shapes, grid_shapes):
        # Skip mismatched batch sizes
        if input_shape[0] != grid_shape[0]:
            continue
        for interp, padding, align_corners in itertools.product(
            interp_modes, padding_modes, align_corners_options
        ):
            config = {
                'input_shape': input_shape,
                'grid_shape': grid_shape,
                'interp': interp,
                'padding': padding,
                'align_corners': align_corners,
            }
            try:
                ok, max_diff, mean_diff, sum_diff = test_grid_sampler3d_output(
                    input_shape, grid_shape, interp, padding, align_corners
                )
            except Exception as e:
                # Keep going so one unsupported configuration does not hide the
                # others, but record it: an exception is a failure, not a pass.
                print(f"ERROR: {input_shape} {grid_shape} {interp} {padding} {align_corners} - {str(e)}")
                errored_tests.append({**config, 'error': str(e)})
                continue
            if not ok:
                failed_tests.append({**config, 'max_diff': max_diff.item()})
            max_diffs.append(max_diff.item())
            mean_diffs.append(mean_diff.item())
            sum_diffs.append(sum_diff.item())

    print("-" * 120)
    print("Summary:")
    print(f" Total tests: {len(max_diffs) + len(errored_tests)}")
    print(f" Failed tests: {len(failed_tests)}")
    print(f" Errored tests: {len(errored_tests)}")
    print(f" Max difference overall: {max(max_diffs) if max_diffs else 0:.8f}")
    print(f" Mean difference overall: {sum(mean_diffs) / len(mean_diffs) if mean_diffs else 0:.8f}")
    # The label promises a sum, so sum the per-case sums (sum([]) == 0).
    print(f" Sum difference overall: {sum(sum_diffs):.8f}")
    if failed_tests:
        print("\nFailed tests (tolerance not met):")
        for test in failed_tests:
            print(f" {test}")
    if errored_tests:
        print("\nErrored tests (exception raised):")
        for test in errored_tests:
            print(f" {test}")
    return not failed_tests and not errored_tests
def test_edge_cases():
    """Test edge cases and boundary conditions.

    Each (input_shape, grid_shape) pair below runs with both interpolation
    modes, all three padding modes and both ``align_corners`` settings. A row is
    printed per case, and any failures or exceptions are listed at the end.

    Returns:
        True if every case ran without raising and passed ``torch.allclose``;
        False otherwise.
    """
    print("\nEdge Cases Test")
    print("input_shape grid_shape interp padding align_corners max_diff mean_diff sum_diff")
    print("-" * 120)
    edge_cases = [
        # Very small tensors
        ((1, 1, 2, 2, 2), (1, 1, 1, 1, 3)),
        # Single depth slice
        ((1, 1, 1, 10, 10), (1, 1, 5, 5, 3)),
        # Large channel count
        ((1, 16, 8, 8, 8), (1, 4, 4, 4, 3)),
    ]
    problems = []  # (case, reason) for every case that failed or raised
    for input_shape, grid_shape in edge_cases:
        for interp in [0, 1]:  # PyTorch enum: 0 = bilinear (trilinear in 3-D), 1 = nearest
            for padding in [0, 1, 2]:  # 0 = zeros, 1 = border, 2 = reflection
                for align_corners in [False, True]:
                    case = (input_shape, grid_shape, interp, padding, align_corners)
                    try:
                        ok, max_diff, _, _ = test_grid_sampler3d_output(*case)
                    except Exception as e:
                        print(f"ERROR: {input_shape} {grid_shape} {interp} {padding} {align_corners} - {str(e)}")
                        problems.append((case, f"error: {e}"))
                        continue
                    # Previously the comparison result was discarded; keep it.
                    if not ok:
                        problems.append((case, f"max_diff={max_diff.item():.8f}"))
    if problems:
        print("\nEdge case failures:")
        for case, reason in problems:
            print(f"  {case} {reason}")
    return not problems
if __name__ == "__main__":
    # Test if MPS is available. Without it every comparison would error, so
    # exit non-zero instead of reporting a vacuous result.
    if not torch.backends.mps.is_available():
        print("MPS not available. This test requires MPS support.")
        sys.exit(1)
    success = test_output_accuracy()
    # Only an explicit False from the edge-case suite counts as a failure; a
    # None return carries no verdict and must not fail the run.
    edge_ok = test_edge_cases() is not False
    if success and edge_ok:
        print("\nAll tests passed!")
    else:
        print("\nSome tests failed!")
        sys.exit(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment