Skip to content

Instantly share code, notes, and snippets.

Last active January 29, 2025 01:22
Show Gist options
  • Save apivovarov/18f196a502a24cd1a24da4438b91728b to your computer and use it in GitHub Desktop.
Save apivovarov/18f196a502a24cd1a24da4438b91728b to your computer and use it in GitHub Desktop.
jax rsqrt accuracy
import jax
import jax.numpy as jnp
import numpy as np
# Measure the accuracy of jax.lax.rsqrt for several floating-point dtypes by
# comparing it against a float64 NumPy reference.
print("JAX devices:", jax.devices())

# Test inputs written as raw IEEE-754 float32 bit patterns and reinterpreted
# (not numerically converted) with .view(np.float32), so the exact float32
# inputs are reproducible. Every pattern has a zero sign bit and a normal
# exponent, i.e. all inputs are positive finite floats (roughly 0.2-0.36).
# Shape is (2, 6, 1), as written in the nested literal below.
a_np = np.array(
    [
        [[1048684900], [1052291356], [1049963963],
         [1050007938], [1051252185], [1050317382]],
        [[1045717137], [1050494007], [1050815620],
         [1049559979], [1051598875], [1051539171]],
    ],
    dtype=np.uint32,
).view(np.float32)


def rsqrt_reference(x):
    """Return ``1 / sqrt(x)`` evaluated in float64.

    The input is converted to float64 *before* the math, so the reference is
    far more precise than any of the dtypes under test. Computing it in
    float32 (as ``1.0 / np.sqrt(float32_array)`` does) would give a reference
    with its own ~1 ulp error, which is exactly the scale being measured.
    """
    return 1.0 / np.sqrt(np.asarray(x, dtype=np.float64))


def max_abs_rel_diff(reference, actual, epsilon=1e-12):
    """Return ``(max_abs_diff, max_rel_diff)`` of *actual* versus *reference*.

    Both inputs are converted to float64 before comparing, so low-precision
    arrays (float16 / bfloat16) are not mixed with float32 in the arithmetic.
    *epsilon* is added to ``|reference|`` in the relative-error denominator to
    avoid division by zero.
    """
    ref64 = np.asarray(reference, dtype=np.float64)
    act64 = np.asarray(actual, dtype=np.float64)
    abs_diff = np.abs(ref64 - act64)
    rel_diff = abs_diff / (np.abs(ref64) + epsilon)
    return float(np.max(abs_diff)), float(np.max(rel_diff))


def myfunc(a):
    """Function under test: elementwise reciprocal square root via XLA."""
    return jax.lax.rsqrt(a)


for dtype in [jnp.float32, jnp.float16, jnp.bfloat16]:
    # Casting to a narrower dtype rounds the inputs themselves.
    a = jnp.array(a_np, dtype=dtype)
    y = myfunc(a)
    # Build the reference from the *cast* input `a` rather than `a_np`, so the
    # reported error isolates rsqrt instead of including input quantization.
    c_np = rsqrt_reference(a)
    max_abs_diff, max_rel_diff = max_abs_rel_diff(c_np, y)
    # Output results
    print("dtype:", y.dtype)
    print(f"Max Absolute Difference: {max_abs_diff:.3e}")
    print(f"Max Relative Difference: {max_rel_diff:.3e}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment