Last active: January 29, 2025 01:22
-
-
Save apivovarov/18f196a502a24cd1a24da4438b91728b to your computer and use it in GitHub Desktop.
JAX rsqrt accuracy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# RSQRT accuracy check: compare jax.lax.rsqrt against a NumPy reference
# for several floating-point dtypes.
import jax
import jax.numpy as jnp
import numpy as np

# Test inputs given as raw IEEE-754 float32 bit patterns; reinterpreted
# via .view(np.float32) in main(), giving an array of shape (2, 6, 1).
INPUT_BITS = [
    [[1048684900], [1052291356], [1049963963], [1050007938], [1051252185], [1050317382]],
    [[1045717137], [1050494007], [1050815620], [1049559979], [1051598875], [1051539171]],
]

# Added to the relative-error denominator to avoid division by zero.
REL_EPSILON = 1e-12


@jax.jit
def myfunc(a):
    """Return ``jax.lax.rsqrt(a)``, compiled once with ``jax.jit``.

    Defined at module level so the jitted function is not re-created on
    every loop iteration (it is still traced once per input dtype/shape).
    """
    return jax.lax.rsqrt(a)


def rsqrt_errors(a_np, dtype, epsilon=REL_EPSILON):
    """Measure the error of ``jax.lax.rsqrt`` on ``a_np`` cast to ``dtype``.

    Args:
        a_np: NumPy array of input values.
        dtype: JAX dtype the input is cast to before calling rsqrt.
        epsilon: Small constant added to the relative-error denominator.

    Returns:
        ``(out_dtype, max_abs_diff, max_rel_diff)``, where ``out_dtype`` is
        the dtype of the JAX result and the two maxima are Python floats
        (0.0 for an empty input).
    """
    a = jnp.array(a_np, dtype=dtype)
    y_np = np.asarray(myfunc(a))

    # The reference uses the values the kernel actually saw (after rounding
    # to `dtype`), not the original float32 inputs. Otherwise, input
    # quantization error would be counted as rsqrt error for float16/bfloat16.
    # It is computed in float64 so the reference's own rounding error is
    # negligible. Going through float32 first keeps the cast exact and works
    # for bfloat16 arrays.
    a_seen = np.asarray(a).astype(np.float32).astype(np.float64)
    expected = 1.0 / np.sqrt(a_seen)
    actual = y_np.astype(np.float32).astype(np.float64)

    abs_diff = np.abs(expected - actual)
    rel_diff = abs_diff / (np.abs(expected) + epsilon)
    # initial=0.0 makes the reduction well-defined for empty inputs.
    max_abs_diff = float(np.max(abs_diff, initial=0.0))
    max_rel_diff = float(np.max(rel_diff, initial=0.0))
    return y_np.dtype, max_abs_diff, max_rel_diff


def main():
    """Print rsqrt error statistics for float32, float16 and bfloat16."""
    print("JAX devices:", jax.devices())
    a_np = np.array(INPUT_BITS, dtype=np.uint32).view(np.float32)
    for dtype in (jnp.float32, jnp.float16, jnp.bfloat16):
        out_dtype, max_abs_diff, max_rel_diff = rsqrt_errors(a_np, dtype)
        # Output results
        print("dtype:", out_dtype)
        print("Max Absolute Difference: %.3e" % max_abs_diff)
        print("Max Relative Difference: %.3e" % max_rel_diff)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment