# Differentiable approximate plane wave Ultrasound PSF in JAX
#
# Source: GitHub gist astanziola/565681777ed2a6231bbd7184bfb4b80e
# Created: February 4, 2023 13:57
# This code is a quick reproduction of eq. (9) of
# "Mathematical Analysis of Ultrafast Ultrasound Imaging" by Alberti et al., 2016.
# https://arxiv.org/pdf/1604.04604.pdf
#
# It represents an approximate point spread function for Plane Wave Imaging that can be
# used to write a simple, yet powerful, 2D Plane Wave ultrasound simulator using spatially-variant
# convolutions.
#
# The function is fully differentiable.
from functools import partial

import jax
from jax import numpy as jnp
def f(t, *, v0, tau):
    """Return the complex transmit pulse evaluated at time ``t``.

    The pulse is a complex exponential carrier at frequency ``v0`` multiplied
    by a Gaussian envelope ``exp(-(v0 * t)**2 / tau**2)``.

    Args:
        t: Time at which the pulse is evaluated (scalar or array).
        v0: Carrier frequency (keyword-only).
        tau: Width of the Gaussian envelope, expressed in carrier cycles
            because the envelope argument is ``v0 * t`` (keyword-only).

    Returns:
        The complex pulse value, with the same shape as ``t``.
    """
    cycles = v0 * t
    carrier = jnp.exp(2 * jnp.pi * 1j * v0 * t)
    envelope = jnp.exp(-(cycles**2) / (tau**2))
    return carrier * envelope
def f_prime(t, v0, tau):
    """Return the time derivative of the pulse ``f`` at ``t``.

    Args:
        t: Real-valued time (scalar or array) at which to differentiate.
        v0: Carrier frequency forwarded to ``f``.
        tau: Envelope width forwarded to ``f``.

    Returns:
        The complex derivative ``df/dt``, with the same shape as ``t``.
    """
    pulse = partial(f, v0=v0, tau=tau)
    t = jnp.asarray(t)
    # Forward mode pushes a unit real tangent through the real-to-complex map
    # and returns d(Re f)/dt + 1j * d(Im f)/dt directly. Rebuilding the
    # derivative from reverse-mode pullbacks as vjp(1) + 1j * vjp(1j) gives the
    # complex conjugate instead, because the pullback of 1j evaluates to
    # -d(Im f)/dt for a real input.
    _, derivative = jax.jvp(pulse, (t,), (jnp.ones_like(t),))
    return derivative
def plane_wave_psf(
    x,
    z,
    theta,  # Transmit angle
    c0,  # Background sound speed
    F,  # Aperture size
    v0,  # Base frequency
    tau,  # Pulse width
):
    """Approximate plane-wave imaging point spread function at ``(x, z)``.

    Implements Eq. (9) of https://arxiv.org/pdf/1604.04604.pdf: the difference
    of the pulse derivative ``f_prime`` evaluated at two delays, scaled by
    ``c0 / (4 * pi * x)``.

    Args:
        x: Lateral coordinate.
        z: Axial coordinate.
        theta: Transmit angle; passed to ``sin``/``cos``, so in radians.
        c0: Background sound speed.
        F: Aperture parameter.
        v0: Base frequency of the pulse.
        tau: Pulse width.

    Returns:
        The complex PSF value. At ``x == 0`` the analytic limit of the
        expression is returned instead of the 0/0 the raw formula produces.
    """
    root = jnp.sqrt(1 + F**2)
    delay_scale = 1 / (c0 * root)
    z_component = (1 + root * jnp.cos(theta)) * z
    steering = root * jnp.sin(theta)

    f_1 = partial(f_prime, v0=v0, tau=tau)

    # Substitute a harmless non-zero x where x == 0 so that the branch which
    # is discarded there stays finite (keeps values and gradients NaN-free).
    on_axis = x == 0
    safe_x = jnp.where(on_axis, 1.0, x)

    prefact = c0 / (4 * jnp.pi * safe_x)
    x_component_left = (steering - F) * safe_x
    x_component_right = (steering + F) * safe_x
    square_bracket = f_1(delay_scale * (z_component + x_component_left)) - f_1(
        delay_scale * (z_component + x_component_right)
    )
    off_axis_value = prefact * square_bracket

    # As x -> 0 the bracket behaves like -2 * F * delay_scale * x * f''(t0),
    # so the product with c0 / (4 * pi * x) tends to a finite value.
    t0 = delay_scale * z_component
    _, f_2 = jax.jvp(f_1, (t0,), (jnp.ones_like(t0),))
    on_axis_value = -(c0 * delay_scale * F) / (2 * jnp.pi) * f_2

    return jnp.where(on_axis, on_axis_value, off_axis_value)
# Evaluate the PSF along a 1D array of lateral positions x at one depth z;
# every other argument is shared across the mapped axis.
psf_line = jax.vmap(plane_wave_psf, in_axes=(0, None, None, None, None, None, None))
# Additionally map over a 1D array of depths z: psf_fun(x, z, ...) returns a
# 2D array whose first axis follows z and whose second axis follows x.
psf_fun = jax.vmap(psf_line, in_axes=(None, 0, None, None, None, None, None))
# Example usage.
# The following code obtains a 2D PSF
x = jnp.linspace(-0.003, 0.003, 1000)  # Spatial coordinates
z = jnp.linspace(-0.003, 0.003, 1000)
aperture = 0.05  # Aperture size
# Kept under its own name: reusing `z` here would replace the coordinate array
# with a scalar, which psf_fun cannot map over.
depth = 0.15  # Depth
F = aperture / depth  # 1/f-number (as defined in Alberti et al.)
theta = 10 * jnp.pi / 180  # Steering angle in radians
c0 = 1540  # Background sound speed
v0 = 5e6  # Center frequency
tau = 2  # ~ number of cycles
result = psf_fun(x, z, -theta, c0, F, v0, tau)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment