Created
February 19, 2025 13:52
-
-
Save crhea93/f3ac84a570d96c0c3763d6b294d7f263 to your computer and use it in GitHub Desktop.
Benchmarking test for a PyTorch implementation of the pixel_to_pixel algorithm from the reproject package.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import torch | |
import numpy as np | |
from astropy.wcs import WCS | |
from astropy.wcs.wcsapi import BaseHighLevelWCS | |
def pixel_to_pixel(wcs_in: "BaseHighLevelWCS", wcs_out: "BaseHighLevelWCS", *inputs):
    """
    CPU version: transform pixel coordinates from ``wcs_in`` to ``wcs_out`` using NumPy.

    Pixel coordinates are mapped to world coordinates with
    ``wcs_in.pixel_to_world`` and back to pixel coordinates with
    ``wcs_out.world_to_pixel``.

    Parameters
    ----------
    wcs_in, wcs_out : BaseHighLevelWCS
        Source and destination WCS.
    *inputs : scalars or array-likes
        One pixel coordinate per input pixel dimension; non-scalar inputs are
        broadcast together.

    Returns
    -------
    For scalar inputs, whatever ``wcs_out.world_to_pixel`` returns. Otherwise a
    single array when ``wcs_out.pixel_n_dim == 1``, else a list with one array
    per output pixel dimension, each a read-only view broadcast to the common
    input shape.
    """
    if np.isscalar(inputs[0]):
        world_outputs = wcs_in.pixel_to_world(*inputs)
        if not isinstance(world_outputs, (tuple, list)):
            world_outputs = (world_outputs,)
        return wcs_out.world_to_pixel(*world_outputs)

    # Broadcast once and take the *broadcast* shape, so inputs that differ in
    # shape but are broadcast-compatible are handled too.
    pixel_inputs = np.broadcast_arrays(*inputs)
    shape = pixel_inputs[0].shape

    # A single round trip yields every output dimension. Repeating it per
    # output dimension (as before) did pixel_n_dim times the work and skewed
    # the CPU-vs-GPU comparison.
    world_outputs = wcs_in.pixel_to_world(*pixel_inputs)
    if not isinstance(world_outputs, (tuple, list)):
        world_outputs = (world_outputs,)
    pixel_outputs = wcs_out.world_to_pixel(*world_outputs)
    if wcs_out.pixel_n_dim == 1:
        # Normalise the 1-D case so it can be indexed like the N-D case.
        pixel_outputs = (pixel_outputs,)

    outputs = [np.broadcast_to(pixel_outputs[i], shape) for i in range(wcs_out.pixel_n_dim)]
    return outputs[0] if wcs_out.pixel_n_dim == 1 else outputs
def pixel_to_pixel_gpu(wcs_in: "BaseHighLevelWCS", wcs_out: "BaseHighLevelWCS", *inputs):
    """
    PyTorch-staged version: transform pixel coordinates from ``wcs_in`` to ``wcs_out``.

    Array inputs are converted to torch tensors on the selected device (CUDA
    when available, otherwise CPU) and broadcast there. They are then copied
    back to host NumPy arrays, because the WCS ``pixel_to_world`` /
    ``world_to_pixel`` calls below are made with host arrays. The coordinate
    transform itself is therefore the same host computation as in
    ``pixel_to_pixel``; what differs is the host/device staging around it.

    Parameters
    ----------
    wcs_in, wcs_out : BaseHighLevelWCS
        Source and destination WCS.
    *inputs : scalars, array-likes or torch.Tensor
        One pixel coordinate per input pixel dimension; non-scalar inputs are
        broadcast together.

    Returns
    -------
    For scalar inputs, whatever ``wcs_out.world_to_pixel`` returns. Otherwise a
    single array when ``wcs_out.pixel_n_dim == 1``, else a list with one array
    per output pixel dimension, each broadcast to the common input shape
    (matching ``pixel_to_pixel``).
    """
    if np.isscalar(inputs[0]):
        world_outputs = wcs_in.pixel_to_world(*inputs)
        if not isinstance(world_outputs, (tuple, list)):
            world_outputs = (world_outputs,)
        return wcs_out.world_to_pixel(*world_outputs)

    # Automatically select device (GPU if available, otherwise CPU).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # as_tensor accepts ndarrays and tensors alike and skips the extra host
    # copy that torch.tensor(arr) makes before the device transfer.
    tensors = torch.broadcast_tensors(*(torch.as_tensor(arr, device=device) for arr in inputs))
    shape = tuple(tensors[0].shape)

    # Hand the WCS API plain host ndarrays rather than torch tensors.
    host_inputs = [t.cpu().numpy() for t in tensors]

    world_outputs = wcs_in.pixel_to_world(*host_inputs)
    if not isinstance(world_outputs, (tuple, list)):
        world_outputs = (world_outputs,)
    pixel_outputs = wcs_out.world_to_pixel(*world_outputs)
    if wcs_out.pixel_n_dim == 1:
        # Normalise the 1-D case so it can be indexed like the N-D case.
        pixel_outputs = (pixel_outputs,)

    # Broadcast to the input shape so the result has the same form as pixel_to_pixel.
    outputs = [np.broadcast_to(pixel_outputs[i], shape) for i in range(wcs_out.pixel_n_dim)]
    return outputs[0] if wcs_out.pixel_n_dim == 1 else outputs
def benchmark_pixel_to_pixel(wcs_in, wcs_out, inputs, use_gpu=False):
    """
    Time one call of the PyTorch-staged or the NumPy pixel_to_pixel implementation.

    Parameters
    ----------
    wcs_in, wcs_out
        WCS objects forwarded unchanged to the implementation.
    inputs : sequence
        Pixel-coordinate arrays, unpacked as positional arguments.
    use_gpu : bool
        If True, time ``pixel_to_pixel_gpu``; otherwise time ``pixel_to_pixel``.

    Returns
    -------
    tuple
        ``(outputs, duration)``: the implementation's result and the elapsed
        wall-clock time in seconds.
    """
    label, func = ("GPU", pixel_to_pixel_gpu) if use_gpu else ("CPU", pixel_to_pixel)
    print(f"Running {label} version...")
    # perf_counter is monotonic and high-resolution; time.time() can jump with
    # system clock adjustments and is coarser on some platforms.
    start_time = time.perf_counter()
    outputs = func(wcs_in, wcs_out, *inputs)
    duration = time.perf_counter() - start_time
    print(f"{label} execution time: {duration:.6f} seconds")
    return outputs, duration
def main():
    """Benchmark the NumPy and PyTorch-staged pixel_to_pixel on random pixel arrays and compare results."""
    # Default (identical) WCS objects for the test.
    wcs_in = WCS(naxis=2)
    wcs_out = WCS(naxis=2)

    # Large pixel arrays so per-call overheads are small relative to the work.
    shape = (4000, 6000)
    # Fixed seed so repeated runs benchmark the same data.
    rng = np.random.default_rng(0)
    x_coords = rng.random(shape) * 500
    y_coords = rng.random(shape) * 500
    inputs = (x_coords, y_coords)

    # Warm up both paths on a tiny input so one-time initialisation costs
    # (e.g. lazy torch/CUDA setup) are not charged to the timed runs.
    tiny = (x_coords[:1, :1], y_coords[:1, :1])
    pixel_to_pixel(wcs_in, wcs_out, *tiny)
    pixel_to_pixel_gpu(wcs_in, wcs_out, *tiny)

    cpu_results, cpu_time = benchmark_pixel_to_pixel(wcs_in, wcs_out, inputs, use_gpu=False)
    gpu_results, gpu_time = benchmark_pixel_to_pixel(wcs_in, wcs_out, inputs, use_gpu=True)

    # (cpu - gpu) / cpu is the fraction of time saved, not a speedup factor;
    # report both quantities under accurate labels.
    time_saved = (cpu_time - gpu_time) / cpu_time * 100
    print(f"GPU time saved vs CPU: {time_saved:.2f}%")
    print(f"GPU speedup: {cpu_time / gpu_time:.2f}x")

    # np.testing raises even under `python -O`, where plain `assert` is stripped.
    # Same criterion as np.allclose(cpu, gpu, atol=1e-6): rtol=1e-5, NaN != NaN.
    for cpu_result, gpu_result in zip(cpu_results, gpu_results):
        np.testing.assert_allclose(
            cpu_result,
            gpu_result,
            rtol=1e-5,
            atol=1e-6,
            equal_nan=False,
            err_msg="Mismatch between CPU and GPU results!",
        )
    print("Results match within tolerance.")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment