Skip to content

Instantly share code, notes, and snippets.

@crhea93
Created February 19, 2025 13:52
Show Gist options
  • Save crhea93/f3ac84a570d96c0c3763d6b294d7f263 to your computer and use it in GitHub Desktop.
Save crhea93/f3ac84a570d96c0c3763d6b294d7f263 to your computer and use it in GitHub Desktop.
Benchmarking test for a PyTorch implementation of the pixel_to_pixel algorithm from the reproject package.
import time
import torch
import numpy as np
from astropy.wcs import WCS
from astropy.wcs.wcsapi import BaseHighLevelWCS
def pixel_to_pixel(wcs_in: BaseHighLevelWCS, wcs_out: BaseHighLevelWCS, *inputs):
    """
    CPU version: Transform pixel coordinates using NumPy.

    Pixel coordinates in the frame of ``wcs_in`` are converted to world
    coordinates with ``wcs_in.pixel_to_world`` and then to pixel coordinates
    in the frame of ``wcs_out`` with ``wcs_out.world_to_pixel``.

    Parameters
    ----------
    wcs_in, wcs_out : BaseHighLevelWCS
        Source and destination high-level WCS objects.
    *inputs : scalar or array-like
        One pixel coordinate per pixel axis of ``wcs_in``. Array inputs must
        be mutually broadcastable.

    Returns
    -------
    For scalar input, whatever ``wcs_out.world_to_pixel`` returns. Otherwise a
    list with one read-only array per pixel axis of ``wcs_out`` (a single
    array when ``wcs_out`` has one pixel axis), each with the broadcast shape
    of the inputs.
    """
    if np.isscalar(inputs[0]):
        world_outputs = wcs_in.pixel_to_world(*inputs)
        if not isinstance(world_outputs, (tuple, list)):
            world_outputs = (world_outputs,)
        return wcs_out.world_to_pixel(*world_outputs)

    # Evaluate the pixel -> world -> pixel chain once for all output axes;
    # the result does not depend on which axis is being extracted.
    pixel_inputs = np.broadcast_arrays(*inputs)
    # Use the common broadcast shape rather than inputs[0].shape, so that
    # broadcastable inputs of differing shapes do not fail in broadcast_to.
    broadcast_shape = pixel_inputs[0].shape

    world_outputs = wcs_in.pixel_to_world(*pixel_inputs)
    if not isinstance(world_outputs, (tuple, list)):
        world_outputs = (world_outputs,)

    pixel_outputs = wcs_out.world_to_pixel(*world_outputs)
    if wcs_out.pixel_n_dim == 1:
        # Normalize a lone array into a 1-tuple so indexing below is uniform.
        pixel_outputs = (pixel_outputs,)

    # np.broadcast_to returns read-only views of the requested shape.
    outputs = [
        np.broadcast_to(pixel_outputs[i], broadcast_shape)
        for i in range(wcs_out.pixel_n_dim)
    ]
    return outputs[0] if wcs_out.pixel_n_dim == 1 else outputs
def pixel_to_pixel_gpu(wcs_in: BaseHighLevelWCS, wcs_out: BaseHighLevelWCS, *inputs):
    """
    GPU version: Transform pixel coordinates, staging the inputs through PyTorch.

    Inputs (NumPy arrays or torch tensors) are placed on the selected torch
    device (CUDA when available, otherwise CPU) and broadcast against each
    other with ``torch.broadcast_tensors``. The WCS evaluation itself is done
    by ``wcs_in.pixel_to_world`` and ``wcs_out.world_to_pixel`` on host NumPy
    arrays, so the broadcast tensors are copied back to the CPU first.

    NOTE(review): no WCS arithmetic runs on the GPU here. On a CUDA machine
    this adds a host -> device -> host round trip on top of the same CPU work
    that ``pixel_to_pixel`` does.

    Parameters
    ----------
    wcs_in, wcs_out : BaseHighLevelWCS
        Source and destination high-level WCS objects.
    *inputs : scalar, array-like or torch.Tensor
        One pixel coordinate per pixel axis of ``wcs_in``. Non-scalar inputs
        must be mutually broadcastable.

    Returns
    -------
    For scalar input, whatever ``wcs_out.world_to_pixel`` returns. Otherwise a
    list with one entry per pixel axis of ``wcs_out`` (a single entry when
    ``wcs_out`` has one pixel axis), as produced by ``wcs_out.world_to_pixel``.
    """
    # Automatically select device (GPU if available, otherwise CPU).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if np.isscalar(inputs[0]):
        world_outputs = wcs_in.pixel_to_world(*inputs)
        if not isinstance(world_outputs, (tuple, list)):
            world_outputs = (world_outputs,)
        return wcs_out.world_to_pixel(*world_outputs)

    # as_tensor accepts both NumPy arrays and tensors and moves them straight
    # to the device, avoiding the extra host copy made by torch.tensor().
    pixel_inputs = torch.broadcast_tensors(
        *(torch.as_tensor(arr, device=device) for arr in inputs)
    )

    # The WCS objects operate on host data, so bring the tensors back to CPU.
    world_outputs = wcs_in.pixel_to_world(
        *(arr.cpu().numpy() for arr in pixel_inputs)
    )
    if not isinstance(world_outputs, (tuple, list)):
        world_outputs = (world_outputs,)

    pixel_outputs = wcs_out.world_to_pixel(*world_outputs)
    if wcs_out.pixel_n_dim == 1:
        # Normalize a lone array into a 1-tuple so indexing below is uniform.
        pixel_outputs = (pixel_outputs,)

    outputs = [pixel_outputs[i] for i in range(wcs_out.pixel_n_dim)]
    return outputs[0] if wcs_out.pixel_n_dim == 1 else outputs
def benchmark_pixel_to_pixel(wcs_in, wcs_out, inputs, use_gpu=False):
    """Benchmark the pixel_to_pixel function with PyTorch (GPU) or NumPy (CPU).

    Parameters
    ----------
    wcs_in, wcs_out
        WCS objects passed through to the timed transform.
    inputs : sequence
        Pixel coordinate arrays, unpacked as positional arguments.
    use_gpu : bool
        If True, time ``pixel_to_pixel_gpu``; otherwise time ``pixel_to_pixel``.

    Returns
    -------
    outputs
        Result of the timed transform call.
    duration : float
        Elapsed time of that single call, in seconds.
    """
    if use_gpu:
        label, transform = "GPU", pixel_to_pixel_gpu
    else:
        label, transform = "CPU", pixel_to_pixel

    print(f"Running {label} version...")
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # system clock, which can be adjusted mid-measurement.
    # NOTE(review): a single un-warmed call includes any one-time setup cost
    # (e.g. CUDA context creation) in the measured duration.
    start_time = time.perf_counter()
    outputs = transform(wcs_in, wcs_out, *inputs)
    duration = time.perf_counter() - start_time
    print(f"{label} execution time: {duration:.6f} seconds")
    return outputs, duration
def main():
    """Time the CPU and GPU transforms on large random pixel arrays and compare results."""
    # Mock WCS objects for the test. Both are constructed identically, so the
    # pixel -> world -> pixel round trip should reproduce the inputs up to
    # floating-point error.
    wcs_in = WCS(naxis=2)
    wcs_out = WCS(naxis=2)

    # Simulate large pixel arrays for benchmarking. The size is large so that
    # per-call overhead is small relative to the transform itself; a fixed
    # seed keeps runs reproducible.
    shape = (4000, 6000)
    rng = np.random.default_rng(0)
    x_coords = rng.random(shape) * 500
    y_coords = rng.random(shape) * 500
    inputs_cpu = (x_coords, y_coords)

    # Warm up both paths on a tiny slice so one-time costs (e.g. CUDA context
    # creation, first-call setup) are not charged to the timed runs.
    warmup_inputs = (x_coords[:1, :1], y_coords[:1, :1])
    pixel_to_pixel(wcs_in, wcs_out, *warmup_inputs)
    pixel_to_pixel_gpu(wcs_in, wcs_out, *warmup_inputs)

    # Test both implementations.
    cpu_results, cpu_time = benchmark_pixel_to_pixel(
        wcs_in, wcs_out, inputs_cpu, use_gpu=False
    )
    gpu_results, gpu_time = benchmark_pixel_to_pixel(
        wcs_in, wcs_out, inputs_cpu, use_gpu=True
    )

    # Performance comparison. The percentage is the fraction of CPU time saved
    # (negative when the GPU path is slower); the factor is the conventional
    # speedup cpu_time / gpu_time.
    time_saved = (cpu_time - gpu_time) / cpu_time * 100
    speedup = cpu_time / gpu_time if gpu_time > 0 else float("inf")
    print(f"GPU time reduction: {time_saved:.2f}% (speedup: {speedup:.2f}x)")

    # Verify accuracy. Explicit checks instead of `assert` so they still run
    # under `python -O`; the length check stops zip() from silently
    # truncating a mismatched result list.
    if len(cpu_results) != len(gpu_results):
        raise AssertionError(
            f"CPU returned {len(cpu_results)} outputs, GPU returned {len(gpu_results)}"
        )
    for cpu_result, gpu_result in zip(cpu_results, gpu_results):
        # rtol=1e-5 matches np.allclose's default (assert_allclose defaults to 1e-7).
        np.testing.assert_allclose(
            cpu_result,
            gpu_result,
            rtol=1e-5,
            atol=1e-6,
            err_msg="Mismatch between CPU and GPU results!",
        )
    print("Results match within tolerance.")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment