Created
September 12, 2023 15:56
-
-
Save smsharma/dccc4a1a1f2ca8e9434a96bfa8f0057b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

# from diffrax import diffeqsolve, ODETerm, Dopri5, PIDController, SaveAt

# Global flag selecting the GPU backend. It must be set at startup, before any
# JAX computation has initialized a backend.
jax.config.update("jax_platform_name", "gpu")
def get_t(rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec):
    """Integrate elapsed time as a function of total energy density.

    Solves ``dt/drho_tot = 1 / (-3 * (rho_tot + P(rho_tot)))`` with ``t = 1.0``
    at ``rho_tot_vec[0]``, where ``rho_tot_vec = rho_g_vec + rho_nu_vec +
    rho_NP_vec`` and ``P`` is linearly interpolated from ``P_NP_vec``.

    Args:
        rho_g_vec, rho_nu_vec, rho_NP_vec: 1-D density samples of equal length.
            Assumes their sum is monotonically decreasing (as in the demo
            inputs), because ``jnp.interp`` needs the flipped total to be
            increasing -- TODO confirm for other callers.
        P_NP_vec: pressure samples aligned with the density samples.

    Returns:
        Array of shape ``(1000,)``: ``t`` evaluated on
        ``jnp.linspace(rho_tot_vec[0], rho_tot_vec[-1], 1000)``.
        Assumes ``rho_tot_vec[0] != rho_tot_vec[-1]``.
    """
    rho_tot_vec = rho_g_vec + rho_nu_vec + rho_NP_vec
    # jnp.interp requires increasing sample points; flip both axes.
    rho_tot_flip = jnp.flip(rho_tot_vec)
    # NOTE(review): only the NP pressure is interpolated even though rho_tot
    # also includes the photon and neutrino densities -- confirm intended.
    P_tot_flip = jnp.flip(P_NP_vec)

    def P_tot(rho_tot):
        # Interpolation dominates the GPU runtime (returning 0 here cut the run
        # to ~0.05 s); dot_interp from https://github.com/google/jax/issues/16182
        # speeds this step up.
        return jnp.interp(rho_tot, rho_tot_flip, P_tot_flip)

    def dt_drho(rho_tot):
        return 1.0 / (-3.0 * (rho_tot + P_tot(rho_tot)))

    rho_tot_init = rho_tot_vec[0]
    rho_tot_fin = rho_tot_vec[-1]

    # odeint only integrates over strictly increasing "times", but rho_tot
    # normally decreases. Integrate in u = sign * rho_tot, which is increasing
    # in either direction, using dt/du = sign * dt/drho (sign is +-1).
    sign = jnp.sign(rho_tot_fin - rho_tot_init)

    def rhs(t_elapsed, u):
        # odeint calls func(y, t): the state (elapsed time) comes first and the
        # independent variable second -- the reverse of diffrax's f(t, y, args).
        del t_elapsed  # the derivative does not depend on the state
        return sign * dt_drho(sign * u)

    rho_grid = jnp.linspace(rho_tot_init, rho_tot_fin, 1000)
    return odeint(rhs, 1.0, sign * rho_grid, rtol=1e-4, atol=1e-4, mxstep=4096)
@jax.jit
def get_abundances(
    rho_g_vec,
    rho_nu_vec,
    rho_NP_vec,
    P_NP_vec,
):
    """Return a 9-slot result vector for one input profile.

    Slot 0 holds element 3 of the time solution computed by ``get_t`` from the
    same four arrays; the other eight slots are 0.0 placeholders.
    """
    times = get_t(rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec)
    # Match the dtype the original array literal would infer from times[3].
    result = jnp.zeros(9, dtype=times.dtype)
    return result.at[0].set(times[3])
@jax.jit
def rho_gam(T):
    """Return the photon energy density ``g * pi**2 / 30 * T**4`` with ``g = 2``."""
    g_photon = 2
    # Same left-to-right evaluation as 2 * pi**2 / 30.0 * T**4.
    prefactor = g_photon * jnp.pi**2 / 30.0
    return prefactor * T**4
# ---------------------------------------------------------------------------
# Synthetic input profiles
# ---------------------------------------------------------------------------
# Photon temperature grid: 418 log-spaced samples from 8.617 down to 3.83e-4.
T_gamma_array = jnp.logspace(jnp.log10(8.617), jnp.log10(3.83e-4), num=418)
# Profile later passed to get_abundances as rho_nu_vec.
test_array_CPU = jnp.logspace(
    jnp.log10(3.17473950e03), jnp.log10(1.68692195e-15), num=418
)
# Extra-species density: 342 log-spaced samples followed by 76 zeros.
rho_extra_array_CPU = jnp.concatenate(
    (
        jnp.logspace(jnp.log10(6.34941711e03), jnp.log10(1.63508741e-20), num=342),
        jnp.zeros(76),
    )
)

T_gamma_array_gp, test_array_gp, rho_extra_array_gp = (
    T_gamma_array,
    test_array_CPU,
    rho_extra_array_CPU,
)

# Replicate each 1-D profile n_batch times along a new leading axis for vmap.
n_batch = 1024
rho_gamma_array, test_array_gp, rho_extra_array_gp = (
    jnp.tile(profile, (n_batch, 1))
    for profile in (rho_gam(T_gamma_array_gp), test_array_gp, rho_extra_array_gp)
)
# ---------------------------------------------------------------------------
# Benchmark
# ---------------------------------------------------------------------------
# Build the batched, compiled function once rather than re-wrapping
# get_abundances with jax.vmap on every call, and compute the input-only
# pressure array outside the timed region.
batched_get_abundances = jax.jit(jax.vmap(get_abundances))
P_extra_array_gp = rho_extra_array_gp / 3


def _timed_run():
    """Run the batched computation once, print its wall time, and return it.

    JAX dispatches work asynchronously, so the result is blocked on before the
    clock is read; otherwise only the dispatch latency would be measured.
    """
    start_time = time.perf_counter()
    out = batched_get_abundances(
        rho_gamma_array, test_array_gp, rho_extra_array_gp, P_extra_array_gp
    ).block_until_ready()
    print("finished in %s seconds" % (time.perf_counter() - start_time))
    return out


# Compilation run: includes tracing and XLA compilation time.
res = _timed_run()
Neff_vec = res[:, 0]

# Timing runs: steady-state cost of one batched evaluation.
for _ in range(10):
    res = _timed_run()
    Neff_vec = res[:, 0]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment