# JAX matmul accuracy
#
# GitHub gist by @apivovarov (83698d079369f479de964290c40b044c),
# created January 24, 2025 22:42.
import jax
import jax.numpy as jnp
import numpy as np
# Report the devices JAX sees.
print("JAX devices:", jax.devices())

# Two random float32 operands in the range -1.0 to 1.0 (a is drawn first, then b).
shape = (16, 16)
a_np, b_np = (
    np.random.uniform(-1.0, 1.0, shape).astype(np.float32) for _ in range(2)
)

# Matrix product computed by NumPy.
c_np = np.matmul(a_np, b_np)

# The same operands as JAX arrays of the chosen dtype.
dtype = jnp.float32
a, b = (jnp.array(host, dtype=dtype) for host in (a_np, b_np))
@jax.jit
def myfunc(a, b):
    """Return the matrix product of ``a`` and ``b``.

    The computation is ``jnp.matmul(a, b)``; the function is wrapped in
    ``jax.jit``.

    Args:
        a: Left operand passed to ``jnp.matmul``.
        b: Right operand passed to ``jnp.matmul``.

    Returns:
        The result of ``jnp.matmul(a, b)``.
    """
    return jnp.matmul(a, b)
# Run the jitted matmul and convert its result to a NumPy array.
y = myfunc(a, b)
y_np = np.asarray(y)

# arr1 is the NumPy product, arr2 is the JAX product.
arr1, arr2 = c_np, y_np

# Absolute difference
abs_diff = np.absolute(arr1 - arr2)
mean_abs_diff = abs_diff.mean()

# Relative difference, measured against arr1
epsilon = 1e-12  # To avoid division by zero
rel_diff = abs_diff / (np.absolute(arr1) + epsilon)
mean_rel_diff = rel_diff.mean()

# Output results
print("dtype:", arr2.dtype)
print("Mean Absolute Difference: %.3e" % mean_abs_diff)
print("Mean Relative Difference: %.3e" % mean_rel_diff)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment