Last active
December 29, 2023 21:03
-
-
Save jgomezdans/6560e2971904794d91298f741acbbfc7 to your computer and use it in GitHub Desktop.
testing_frameworks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import numba | |
import jax.numpy as jnp | |
import jax | |
import time | |
import torch | |
import pandas as pd | |
from functools import partial | |
def spatial_regularisation_naive(x: np.ndarray, dy: int = 1, dx: int = 1) -> float:
    """Calculate spatial regularisation with plain Python loops.

    For every pixel of the 2D array `x`, the squared difference between the
    pixel and each neighbour inside a window of `-dy..dy` rows and `-dx..dx`
    columns is accumulated. Neighbours falling outside the array are ignored
    (no wrap-around). Every pair is visited twice, hence the final factor 0.5.

    Parameters
    ----------
    x : np.ndarray
        a 2D array of shape (`ny`, `nx`).
    dy : int
        neighbourhood around pixel vertically (rows -dy to dy inclusive).
    dx : int
        neighbourhood around pixel horizontally (columns -dx to dx inclusive).

    Returns
    -------
    float
        Associated cost.
    """
    ny, nx = x.shape
    total_cost = 0.0
    for i in range(ny):
        for j in range(nx):
            for m in range(-dy, dy + 1):
                # Skip whole rows of the window that fall outside the array.
                if not 0 <= i + m < ny:
                    continue
                for n in range(-dx, dx + 1):
                    # Strict upper bound: index nx itself is out of range.
                    if 0 <= j + n < nx:
                        # The centre offset (m == n == 0) contributes zero.
                        total_cost += (x[i, j] - x[i + m, j + n]) ** 2
    return 0.5 * total_cost
def spatial_regularisation_numpy(x: np.ndarray, dy: int = 1, dx: int = 1) -> float:
    """Calculate spatial regularisation using numpy for vectorization.

    For each non-zero offset `(m, n)` in the neighbourhood, the overlapping
    windows of `x` and of `x` displaced by `(m, n)` are subtracted, squared
    and summed. Only pixel pairs that both lie inside the array are compared,
    so pixels on opposite edges are never treated as neighbours.

    Parameters
    ----------
    x : np.ndarray
        a 2D array of shape (`ny`, `nx`).
    dy : int
        neighbourhood around pixel vertically.
    dx : int
        neighbourhood around pixel horizontally.

    Returns
    -------
    float
        Associated cost.
    """
    ny, nx = x.shape
    total_cost = 0.0
    for m in range(-dy, dy + 1):
        # An offset at least as large as the array leaves no overlapping rows.
        if abs(m) >= ny:
            continue
        for n in range(-dx, dx + 1):
            if (m == 0 and n == 0) or abs(n) >= nx:
                continue
            # `centre[i, j]` pairs with `neighbour[i, j] == x[i + m, j + n]`;
            # slicing (rather than rolling) avoids wrap-around at the edges.
            centre = x[max(-m, 0) : ny - max(m, 0), max(-n, 0) : nx - max(n, 0)]
            neighbour = x[max(m, 0) : ny - max(-m, 0), max(n, 0) : nx - max(-n, 0)]
            total_cost += np.sum((centre - neighbour) ** 2)
    return 0.5 * total_cost
@numba.jit(nopython=True, parallel=True, fastmath=True)
def spatial_regularisation_numba(x: np.ndarray, dy: int = 1, dx: int = 1) -> float:
    """Calculate spatial regularisation using Numba with parallelization.

    Loops explicitly over every pixel and every offset in the neighbourhood,
    accumulating the squared difference between the pixel and each neighbour
    that lies inside the array. The centre offset is excluded.

    Parameters
    ----------
    x : np.ndarray
        a 2D array of shape (`ny`, `nx`).
    dy : int
        neighbourhood around pixel vertically.
    dx : int
        neighbourhood around pixel horizontally.

    Returns
    -------
    float
        Associated cost.
    """
    n_rows, n_cols = x.shape
    cost = 0.0
    for row in numba.prange(n_rows):
        for col in numba.prange(n_cols):
            for d_row in range(-dy, dy + 1):
                for d_col in range(-dx, dx + 1):
                    # Neighbour must be inside the array and not the pixel itself.
                    inside = 0 <= row + d_row < n_rows and 0 <= col + d_col < n_cols
                    is_centre = d_row == 0 and d_col == 0
                    if inside and not is_centre:
                        cost += (x[row, col] - x[row + d_row, col + d_col]) ** 2
    return 0.5 * cost
@partial(jax.jit, static_argnums=(1, 2))
def spatial_regularisation_jax(x: jnp.ndarray, dy: int = 1, dx: int = 1) -> float:
    """Calculate spatial regularisation using JAX for efficient computation.

    The array is rolled by every non-zero offset of the neighbourhood and
    compared with itself; entries that the roll wrapped around the array
    edges are masked out before summing.

    Parameters
    ----------
    x : jnp.ndarray
        a 2D array of shape (`ny`, `nx`).
    dy : int
        neighbourhood around pixel vertically.
    dx : int
        neighbourhood around pixel horizontally.

    Returns
    -------
    float
        Associated cost.
    """
    # All neighbourhood offsets except the centre, in row-major order.
    offsets = [
        (m, n)
        for m in range(-dy, dy + 1)
        for n in range(-dx, dx + 1)
        if not (m == 0 and n == 0)
    ]
    cost = 0.0
    for row_shift, col_shift in offsets:
        rolled = jnp.roll(
            jnp.roll(x, shift=row_shift, axis=0), shift=col_shift, axis=1
        )
        # Start with everything valid, then knock out the wrapped-around band.
        keep = jnp.ones_like(x, dtype=bool)
        if row_shift != 0:
            wrapped_rows = (
                slice(None, row_shift) if row_shift > 0 else slice(row_shift, None)
            )
            keep = keep.at[wrapped_rows, :].set(False)
        if col_shift != 0:
            wrapped_cols = (
                slice(None, col_shift) if col_shift > 0 else slice(col_shift, None)
            )
            keep = keep.at[:, wrapped_cols].set(False)
        cost += jnp.sum((x - rolled) ** 2 * keep)
    return 0.5 * cost
@torch.jit.script
def spatial_regularisation_torch(
    x: torch.Tensor, dy: int = 1, dx: int = 1
) -> torch.Tensor:
    """Calculate spatial regularisation using PyTorch.

    The tensor is rolled by every non-zero offset of the neighbourhood and
    compared with itself; entries that the roll wrapped around the tensor
    edges are masked out before summing.

    Parameters
    ----------
    x : torch.Tensor
        a 2D tensor.
    dy : int
        neighbourhood around pixel vertically.
    dx : int
        neighbourhood around pixel horizontally.

    Returns
    -------
    torch.Tensor
        Associated cost (0-dim tensor on the same device as `x`).
    """
    # Accumulate in the input's dtype: a default float32 accumulator would
    # silently truncate the running sum when `x` is float64.
    total_cost = torch.tensor(0.0, dtype=x.dtype, device=x.device)
    for m in range(-dy, dy + 1):
        for n in range(-dx, dx + 1):
            if m == 0 and n == 0:
                continue
            # Shift the tensor and compute the squared difference
            shifted_x = torch.roll(x, shifts=m, dims=0)
            shifted_y = torch.roll(shifted_x, shifts=n, dims=1)
            squared_diff = (x - shifted_y) ** 2
            # Ensure we only sum valid comparisons (ignoring the wrapped edges)
            valid_mask = torch.ones_like(x, dtype=torch.bool)
            if m > 0:
                valid_mask[:m, :] = False
            elif m < 0:
                valid_mask[m:, :] = False
            if n > 0:
                valid_mask[:, :n] = False
            elif n < 0:
                valid_mask[:, n:] = False
            total_cost += torch.sum(squared_diff * valid_mask)
    return 0.5 * total_cost
if __name__ == "__main__":
    # Benchmark each implementation over a grid of image sizes and
    # neighbourhood half-widths, then plot the mean run time per library.
    sizes = [128, 256, 1024, 2048]
    neighbourhood = [1, 3, 5, 7, 9]
    functions = [
        spatial_regularisation_numpy,
        spatial_regularisation_numba,
        spatial_regularisation_jax,
        spatial_regularisation_torch,
    ]
    results = []
    num_runs = 5
    for npix in sizes:
        for n_neighs in neighbourhood:
            # Same random data for every backend, converted to its native type.
            x = np.random.rand(npix, npix)
            xx = jnp.array(x)
            xxx = torch.from_numpy(x)
            for func, array_in in zip(functions, [x, x, xx, xxx]):
                # Some compiled wrappers expose `.name` instead of `__name__`.
                func_name = getattr(func, "__name__", None)
                if func_name is None:
                    func_name = func.name
                print(npix, n_neighs, func_name)
                total_time = []
                # Dry run so JIT compilation is not included in the timings.
                warmup = func(array_in, dx=n_neighs, dy=n_neighs)
                if hasattr(warmup, "block_until_ready"):
                    warmup.block_until_ready()
                for _ in range(num_runs):
                    start_time = time.perf_counter()
                    result = func(array_in, dx=n_neighs, dy=n_neighs)
                    # JAX dispatches asynchronously: wait for the value so the
                    # computation itself is timed, not just the dispatch.
                    if hasattr(result, "block_until_ready"):
                        result.block_until_ready()
                    end_time = time.perf_counter()
                    total_time.append(end_time - start_time)
                avg_time = np.mean(total_time)
                std_time = np.std(total_time)
                results.append([npix, n_neighs, func_name, avg_time, std_time])
    df = pd.DataFrame(
        results,
        columns=["n_size", "dx/dy", "function", "time_mean", "time_std"],
    )
    # The library name is the last underscore-separated token of the function name.
    df["library"] = df["function"].str.split("_").str[-1]
    df.groupby("library")["time_mean"].plot(x="n_size", legend=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment