Skip to content

Instantly share code, notes, and snippets.

@cbassa
Created February 3, 2024 10:36
Show Gist options
  • Save cbassa/272f7baa3450c8e6c3f37cd3d8d6f1c9 to your computer and use it in GitHub Desktop.
Save cbassa/272f7baa3450c8e6c3f37cd3d8d6f1c9 to your computer and use it in GitHub Desktop.
Different implementations of sky imagers for speed testing
#!/usr/bin/env python3
import time
import numpy as np
import matplotlib.pyplot as plt
import numba
import cupy as cp
import warnings
# Silence every warning process-wide so the benchmark output stays readable.
# NOTE(review): this also hides e.g. the NumPy RuntimeWarning raised when
# np.sqrt receives a negative argument (pixels with l**2 + m**2 > 1) and any
# numba/cupy warnings; consider a narrower filter while debugging.
warnings.filterwarnings(action="ignore")

# Speed of light in vacuum in m/s (exact SI value). Each imager divides
# freq * baseline by it to get a phase (presumably baselines in metres and
# freq in Hz — confirm with the data files).
SPEED_OF_LIGHT = 2.99792458e8
def sky_imager_simple(visibilities, baselines, freq, npix_l, npix_m):
    """Form a dirty image pixel by pixel with plain Python loops (reference).

    Pixel ``(row, col)`` uses the direction cosines
    ``m = -1 + 2*row/npix_m`` and ``l = 1 - 2*col/npix_l``. So ``m`` grows with
    the row index and ``l`` shrinks with the column index, and neither reaches
    the far edge (+1 for m, -1 for l). The third cosine is
    ``n = sqrt(1 - l**2 - m**2) - 1``. The pixel value is the real part of the
    mean over all baseline entries of::

        visibilities * exp(-2j*pi*freq*(bx*l + by*m + bz*n) / SPEED_OF_LIGHT)

    where ``bx, by, bz = baselines[:, :, 0], baselines[:, :, 1], baselines[:, :, 2]``.

    Parameters
    ----------
    visibilities : complex array broadcastable against ``baselines[:, :, 0]``
        (presumably shape (n_ant, n_ant) — confirm with caller).
    baselines : array with at least 3 entries on its last axis, indexed as
        ``baselines[:, :, k]``.
    freq : observing frequency, scalar.
    npix_l, npix_m : number of pixels along l (columns) and m (rows).

    Returns
    -------
    numpy.ndarray
        float64 array of shape (npix_m, npix_l). Pixels with
        ``l**2 + m**2 > 1`` are NaN, because np.sqrt of a negative float
        returns NaN.
    """
    sky = np.zeros((npix_m, npix_l), dtype=np.complex128)
    bx, by, bz = (baselines[:, :, k] for k in range(3))
    # Loop-invariant prefactor; evaluated in the same order as the inline form
    # ((-2j * pi) * freq), so every pixel value is bit-identical.
    scale = -2j * np.pi * freq
    for row, col in np.ndindex(npix_m, npix_l):
        m = -1 + row * 2 / npix_m
        l = 1 - col * 2 / npix_l
        n = np.sqrt(1 - l * l - m * m) - 1
        path = bx * l + by * m + bz * n
        sky[row, col] = np.mean(visibilities * np.exp(scale * path / SPEED_OF_LIGHT))
    return np.real(sky)
@numba.jit(parallel=True, fastmath=True)
def sky_imager_numba(visibilities, baselines, freq, npix_l, npix_m):
    """JIT-compiled (numba) version of the pixel-by-pixel imager.

    Uses the same grid and formula as the pure-Python loop version.
    Row ``m_ix`` maps to ``m = -1 + 2*m_ix/npix_m`` and column ``l_ix`` to
    ``l = 1 - 2*l_ix/npix_l``, with ``n = sqrt(1 - l**2 - m**2) - 1``.
    Each pixel is the real part of the mean of
    ``visibilities * exp(-2j*pi*freq*(bx*l + by*m + bz*n)/SPEED_OF_LIGHT)``,
    where ``bx, by, bz`` are ``baselines[:, :, 0..2]``.

    Parameters
    ----------
    visibilities : complex array broadcastable against ``baselines[:, :, 0]``.
    baselines : 3-D array indexed as ``baselines[:, :, k]`` for k = 0, 1, 2.
    freq : observing frequency, scalar.
    npix_l, npix_m : number of pixels along l (columns) and m (rows).

    Returns
    -------
    numpy.ndarray
        float64 array of shape (npix_m, npix_l).

    NOTE(review): ``fastmath=True`` lets LLVM assume no NaNs occur. Pixels
    with ``l**2 + m**2 > 1`` take the square root of a negative number, so
    their values are not guaranteed to be NaN here. Mask them before use.
    """
    img = np.zeros((npix_m, npix_l), dtype=np.complex128)
    # numba only distributes loops written with prange across threads. With a
    # plain range() the parallel=True request left the pixel loop serial, and
    # only the small per-pixel array expressions were candidates for
    # parallelisation. Each iteration writes its own row of img, so there is
    # no data race.
    for m_ix in numba.prange(npix_m):
        m = -1 + m_ix * 2 / npix_m
        for l_ix in range(npix_l):
            l = 1 - l_ix * 2 / npix_l
            n = np.sqrt(1 - l * l - m * m) - 1
            img[m_ix, l_ix] = np.mean(visibilities * np.exp(-2j * np.pi * freq *
                                                            (baselines[:, :, 0] * l +
                                                             baselines[:, :, 1] * m +
                                                             baselines[:, :, 2] * n) /
                                                            SPEED_OF_LIGHT))
    return np.real(img)
def sky_imager_cupy(visibilities, baselines, freq, npix_l, npix_m):
    """Fully broadcast imager evaluated on the GPU with CuPy.

    Grid: ``l = linspace(-1, 1, npix_l)`` along columns and
    ``m = linspace(1, -1, npix_m)`` along rows (both endpoints included, float32).
    The third cosine is ``n = sqrt(1 - l**2 - m**2) - 1``. The baseline
    projection ``u*l + v*m + w*n`` is built as a 4-D array, cast to complex64,
    and turned into a phasor. The image is the real part of the mean of
    ``visibilities * phasor`` over the first two (baseline) axes.

    Parameters
    ----------
    visibilities : complex host array, presumably (n_ant, n_ant) — confirm.
    baselines : host array; assumed 3-D (n_ant, n_ant, 3) — confirm.
    freq : observing frequency, scalar.
    npix_l, npix_m : number of pixels along l (columns) and m (rows).

    Returns
    -------
    numpy.ndarray
        Host array of shape (npix_m, npix_l), copied back with ``cp.asnumpy``.
        Pixels with ``l**2 + m**2 > 1`` are NaN.
    """
    # Index that appends two pixel axes to a 2-D per-baseline array.
    grow = (slice(None), slice(None), cp.newaxis, cp.newaxis)

    l_axis = cp.linspace(-1, 1, npix_l, dtype="float32")
    m_axis = cp.linspace(1, -1, npix_m, dtype="float32")
    l, m = cp.meshgrid(l_axis, m_axis)
    n = cp.sqrt(1 - l**2 - m**2) - 1

    vis_dev = cp.array(visibilities)
    # NOTE(review): .T reverses *all* axes of the 3-D baseline cube, so
    # u[i, j] == baselines[j, i, 0], which is then paired with
    # visibilities[i, j]. Confirm this pairing is intended.
    u, v, w = cp.array(baselines.astype("float32")).T

    path = (u[grow] * l + v[grow] * m + w[grow] * n).astype("complex64")
    phasor = cp.exp(-2j * cp.pi * freq * path / SPEED_OF_LIGHT)
    # Release the 4-D intermediate before the next large allocation
    # (presumably to keep peak GPU memory down).
    del path
    sky = cp.real(cp.mean(vis_dev[grow] * phasor, axis=(0, 1)))
    return cp.asnumpy(sky)
def sky_imager_numpy(visibilities, baselines, freq, npix_l, npix_m):
    """Fully broadcast float64 imager in a single NumPy expression.

    Grid: ``l = linspace(-1, 1, npix_l)`` along columns and
    ``m = linspace(1, -1, npix_m)`` along rows (both endpoints included).
    The third cosine is ``n = sqrt(1 - l**2 - m**2) - 1``. A 4-D
    (baseline, baseline, m, l) phasor array is built, and the image is the real
    part of its visibility-weighted mean over the two baseline axes.

    Parameters
    ----------
    visibilities : complex array, presumably (n_ant, n_ant) — confirm.
    baselines : assumed 3-D array (n_ant, n_ant, 3) — confirm.
    freq : observing frequency, scalar.
    npix_l, npix_m : number of pixels along l (columns) and m (rows).

    Returns
    -------
    numpy.ndarray
        Array of shape (npix_m, npix_l). Pixels with ``l**2 + m**2 > 1``
        are NaN.
    """
    # Index that appends two pixel axes to a 2-D per-baseline array.
    px = (slice(None), slice(None), np.newaxis, np.newaxis)

    l_grid, m_grid = np.meshgrid(np.linspace(-1, 1, npix_l), np.linspace(1, -1, npix_m))
    n_grid = np.sqrt(1 - l_grid**2 - m_grid**2) - 1

    # NOTE(review): .T reverses *all* axes of the 3-D baseline cube, so
    # u[i, j] == baselines[j, i, 0], which is then paired with
    # visibilities[i, j]. Confirm this pairing is intended.
    u, v, w = baselines.T
    projection = u[px] * l_grid + v[px] * m_grid + w[px] * n_grid
    phasor = np.exp(-2j * np.pi * freq * projection / SPEED_OF_LIGHT)
    return np.real((visibilities[px] * phasor).mean(axis=(0, 1)))
def sky_imager_numpy_real(visibilities, baselines, freq, npix_l, npix_m):
    """Broadcast float64 imager that avoids complex exponentials.

    Uses the same grid as the complex broadcast imager:
    ``l = linspace(-1, 1, npix_l)`` along columns,
    ``m = linspace(1, -1, npix_m)`` along rows, and
    ``n = sqrt(1 - l**2 - m**2) - 1``. With
    ``theta = -2*pi*freq*(u*l + v*m + w*n)/SPEED_OF_LIGHT`` it relies on::

        Re(V * exp(1j*theta)) = Re(V)*cos(theta) - Im(V)*sin(theta)

    so only real cos/sin arrays are materialised.

    Parameters
    ----------
    visibilities : complex array, presumably (n_ant, n_ant) — confirm.
    baselines : assumed 3-D array (n_ant, n_ant, 3) — confirm.
    freq : observing frequency, scalar.
    npix_l, npix_m : number of pixels along l (columns) and m (rows).

    Returns
    -------
    numpy.ndarray
        Array of shape (npix_m, npix_l). Pixels with ``l**2 + m**2 > 1``
        are NaN.
    """
    l_grid, m_grid = np.meshgrid(np.linspace(-1, 1, npix_l), np.linspace(1, -1, npix_m))
    n_grid = np.sqrt(1 - l_grid**2 - m_grid**2) - 1

    # NOTE(review): .T reverses all axes, so u[i, j] == baselines[j, i, 0].
    u, v, w = baselines.T
    px = (slice(None), slice(None), np.newaxis, np.newaxis)
    path = u[px] * l_grid + v[px] * m_grid + w[px] * n_grid
    theta = -2 * np.pi * freq * path / SPEED_OF_LIGHT
    cos_t, sin_t = np.cos(theta), np.sin(theta)

    re_v = np.real(visibilities)[px]
    im_v = np.imag(visibilities)[px]
    return np.mean(re_v * cos_t - im_v * sin_t, axis=(0, 1))
def sky_imager_numpy_float32(visibilities, baselines, freq, npix_l, npix_m):
    """Single-precision broadcast imager.

    Grid: ``l = linspace(-1, 1, npix_l)`` along columns and
    ``m = linspace(1, -1, npix_m)`` along rows, both cast to float32, with
    ``n = sqrt(1 - l**2 - m**2) - 1``. The baselines are cast to float32, and
    the 4-D phasor is kept in complex64 regardless of the scalar type of
    ``freq``.

    Parameters
    ----------
    visibilities : complex array, presumably (n_ant, n_ant) — confirm.
    baselines : assumed 3-D array (n_ant, n_ant, 3) — confirm.
    freq : observing frequency, scalar (Python float or NumPy scalar).
    npix_l, npix_m : number of pixels along l (columns) and m (rows).

    Returns
    -------
    numpy.ndarray
        Array of shape (npix_m, npix_l). It is float32 for complex64
        visibilities and float64 for complex128 visibilities (the product with
        the phasor promotes). Pixels with ``l**2 + m**2 > 1`` are NaN.
    """
    l, m = np.meshgrid(np.linspace(-1, 1, npix_l).astype("float32"),
                       np.linspace(1, -1, npix_m).astype("float32"))
    n = np.sqrt(1 - l**2 - m**2) - 1
    # NOTE(review): .T reverses all axes, so u[i, j] == baselines[j, i, 0].
    u, v, w = baselines.astype("float32").T
    prod = (u[:, :, np.newaxis, np.newaxis] * l +
            v[:, :, np.newaxis, np.newaxis] * m +
            w[:, :, np.newaxis, np.newaxis] * n)
    # Fold all scalar factors into one complex64 coefficient. Under NumPy >= 2
    # (NEP 50) a NumPy float64 freq (e.g. freq[0] from np.load) is a strongly
    # typed scalar. It would promote the whole 4-D phasor to complex128 and
    # silently defeat the float32 path.
    coeff = np.complex64(-2j * np.pi * freq / SPEED_OF_LIGHT)
    phasor = np.exp(coeff * prod)
    img = np.real(np.mean(visibilities[:, :, np.newaxis, np.newaxis] * phasor, axis=(0, 1)))
    return img
def sky_imager_numpy_float32_ravel(visibilities, baselines, freq, npix_l, npix_m):
    """Single-precision imager that only evaluates pixels inside the unit circle.

    Grid: ``l = linspace(-1, 1, npix_l)`` along columns and
    ``m = linspace(1, -1, npix_m)`` along rows, both float32. Only pixels with
    ``l**2 + m**2 < 1`` (strict) are computed, as a flat list. This avoids the
    wasted work and sqrt-of-negative values in the corners. The phasor is kept
    in complex64 regardless of the scalar type of ``freq``.

    Parameters
    ----------
    visibilities : complex array, presumably (n_ant, n_ant) — confirm.
    baselines : assumed 3-D array (n_ant, n_ant, 3) — confirm.
    freq : observing frequency, scalar (Python float or NumPy scalar).
    npix_l, npix_m : number of pixels along l (columns) and m (rows).

    Returns
    -------
    numpy.ndarray
        float64 array of shape (npix_m, npix_l), the same layout as the
        meshgrid. Pixels with ``l**2 + m**2 >= 1`` are NaN.
    """
    l, m = np.meshgrid(np.linspace(-1, 1, npix_l).astype("float32"),
                       np.linspace(1, -1, npix_m).astype("float32"))
    # Boolean indexing already yields flat arrays in row-major (mask) order.
    inside = l**2 + m**2 < 1
    l, m = l[inside], m[inside]
    n = np.sqrt(1 - l**2 - m**2) - 1
    # NOTE(review): .T reverses all axes, so u[i, j] == baselines[j, i, 0].
    u, v, w = baselines.astype("float32").T
    prod = (u[:, :, np.newaxis] * l +
            v[:, :, np.newaxis] * m +
            w[:, :, np.newaxis] * n)
    # One complex64 coefficient keeps the phasor single precision. Under NumPy
    # >= 2 (NEP 50) a NumPy float64 freq would otherwise promote it to
    # complex128.
    coeff = np.complex64(-2j * np.pi * freq / SPEED_OF_LIGHT)
    phasor = np.exp(coeff * prod)
    # The mask has the meshgrid shape (npix_m, npix_l), so scatter the values
    # back through it. Reshaping to (npix_l, npix_m) scrambles non-square images.
    img = np.full((npix_m, npix_l), np.nan)
    img[inside] = np.real(np.mean(visibilities[:, :, np.newaxis] * phasor, axis=(0, 1)))
    return img
if __name__ == "__main__":
    # Benchmark inputs, read from the working directory. Only the first entry
    # of vis and freq is imaged (vis[0], freq[0]).
    vis = np.load("vis.npy")
    baselines = np.load("baselines.npy")
    freq = np.load("freq.npy")
    npix_l, npix_m = 151, 151
    inputs = (vis[0], baselines, freq[0])

    # numba compiles on the first call, and cupy initialises the CUDA context
    # (and may compile kernels) on first use. Run both once on a tiny 2x2 grid
    # so those one-off start-up costs are not charged to the timed runs below.
    for imager in (sky_imager_numba, sky_imager_cupy):
        imager(*inputs, 2, 2)

    benchmarks = (
        ("numpy_real", sky_imager_numpy_real),
        ("numpy", sky_imager_numpy),
        ("numpy_float32_ravel", sky_imager_numpy_float32_ravel),
        ("numpy_float32", sky_imager_numpy_float32),
        ("cupy", sky_imager_cupy),
        ("simple", sky_imager_simple),
        ("numba", sky_imager_numba),
    )
    for label, imager in benchmarks:
        # perf_counter is monotonic and high-resolution; time.time() is neither
        # guaranteed, which makes it unsuitable for interval timing.
        tstart = time.perf_counter()
        img = imager(*inputs, npix_l, npix_m)
        dt = time.perf_counter() - tstart
        print(f"{label}: {dt:.3f}s")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment