This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from time import perf_counter | |
| def benchmark_grid_sampler_3d(device, input_shape, grid_shape, interp=0, padding=0, align_corners=False, num_warmup=5, num_runs=20): | |
| input = torch.randn(input_shape, dtype=torch.float32).to(device) | |
| grid = torch.randn(grid_shape, dtype=torch.float32).to(device) | |
| for _ in range(num_warmup): | |
| _ = torch.grid_sampler_3d(input, grid, interp, padding, align_corners) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import sys | |
| import torch | |
| torch.manual_seed(42) | |
| def test_grid_sampler3d_output(input_shape, grid_shape, interp, padding, align_corners) -> tuple: | |
| """Test grid_sampler_3d output accuracy between CPU and MPS.""" | |
| # CPU computation | |
| device = "cpu" | |
| t = torch.randn(input_shape).to(device) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| torch.manual_seed(42) | |
| def test_grid_sampler3d_grad_input(input_shape, grid_shape, interp, padding, align_corners) -> bool: | |
| device = "cpu" | |
| t = torch.randn(input_shape).to(device).requires_grad_(True) | |
| g = torch.randn(grid_shape).to(device).tanh().requires_grad_(False) | |
| l = torch.grid_sampler_3d(t, g, interp, padding, align_corners).mean() | |
| l.backward() |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import socket | |
| import numpy as np | |
| HOST = '127.0.0.1' # The server's hostname or IP address | |
| PORT = 65432 # The port used by the server | |
| data = np.random.rand(1000000) | |
| with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: | |
| print(f"Sent {np.prod(data.shape)} elements.") |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
NewerOlder