Created January 24, 2025 22:42
Save apivovarov/83698d079369f479de964290c40b044c to your computer and use it in GitHub Desktop.
JAX matmul accuracy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools

import jax
import jax.numpy as jnp
import numpy as np
# Show which devices JAX will run on; the JAX result below may depend on it.
print("JAX devices:", jax.devices())

# Fixed seed so repeated runs compare the same matrices.
SEED = 0
shape = (16, 16)
rng = np.random.default_rng(SEED)
a_np = rng.uniform(low=-1.0, high=1.0, size=shape).astype(np.float32)
b_np = rng.uniform(low=-1.0, high=1.0, size=shape).astype(np.float32)

# Reference product computed in float64 from the (exactly representable)
# float32 inputs. A float32 NumPy matmul would carry its own rounding error,
# so comparing against it measures agreement between two float32
# implementations rather than accuracy of the JAX result.
c_np = a_np.astype(np.float64) @ b_np.astype(np.float64)

# Device-side copies of the same float32 inputs for the JAX computation.
dtype = jnp.float32
a = jnp.array(a_np, dtype=dtype)
b = jnp.array(b_np, dtype=dtype)
@functools.partial(jax.jit, static_argnames=("precision",))
def myfunc(a, b, precision=None):
    """Return the matrix product of ``a`` and ``b`` via a jitted ``jnp.matmul``.

    Args:
      a, b: Arrays with shapes compatible for ``jnp.matmul``.
      precision: Forwarded unchanged to ``jnp.matmul``. ``None`` (the
        default) keeps JAX's default matmul precision; pass e.g.
        ``jax.lax.Precision.HIGHEST`` to request the most precise algorithm.
        It is a static argument because it selects which computation is
        compiled rather than being a traced array value.

    Returns:
      The product array produced by ``jnp.matmul``.
    """
    return jnp.matmul(a, b, precision=precision)
# Run the jitted JAX matmul and copy the result back to a NumPy array.
y = myfunc(a, b)
y_np = np.asarray(y)

# arr1 is the reference product, arr2 the JAX result. The difference is
# taken in float64 so the error metrics themselves add no float32 rounding.
arr1 = c_np
arr2 = y_np
diff = arr1.astype(np.float64) - arr2.astype(np.float64)
abs_diff = np.abs(diff)
mean_abs_diff = np.mean(abs_diff)
max_abs_diff = np.max(abs_diff)

# Element-wise relative difference. Reference entries close to zero make
# individual ratios very large (epsilon only guards against dividing by
# exactly zero), so this mean can be dominated by a few entries.
epsilon = 1e-12  # To avoid division by zero
rel_diff = abs_diff / (np.abs(arr1) + epsilon)
mean_rel_diff = np.mean(rel_diff)

# Norm-wise relative error ||ref - out||_F / ||ref||_F: a single figure
# that is not skewed by near-zero reference entries.
norm_rel_diff = np.linalg.norm(diff) / (np.linalg.norm(arr1) + epsilon)

# Output results
print("dtype:", arr2.dtype)
print("Mean Absolute Difference: %.3e" % mean_abs_diff)
print("Mean Relative Difference: %.3e" % mean_rel_diff)
print("Max Absolute Difference: %.3e" % max_abs_diff)
print("Frobenius Relative Difference: %.3e" % norm_rel_diff)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.