Skip to content

Instantly share code, notes, and snippets.

@apivovarov
Created January 29, 2025 01:50
Show Gist options
  • Save apivovarov/50b766f2ccb82cbdf0a04e0084f59f46 to your computer and use it in GitHub Desktop.
Save apivovarov/50b766f2ccb82cbdf0a04e0084f59f46 to your computer and use it in GitHub Desktop.
JAX reduce op accuracy
# REDUCE
import jax
import jax.numpy as jnp
import numpy as np
# Print the devices JAX can see, so the output records which backend computed the sums below.
print("JAX devices:", jax.devices())
# Test input: uint32 bit patterns reinterpreted (not numerically converted) as float32
# via .view(np.float32). The literal nests to shape (2, 6, 4, 4), i.e. 192 values.
# NOTE(review): exact bit patterns are presumably used so the input is reproducible
# bit-for-bit across platforms — confirm with the author.
a_np=np.array([[[[3214685668, 1050640488, 1060743252, 3209803584], [3204519310, 1067654368, 1067817699, 1067875232], [1056212128, 3212040969, 3205718709, 1065846737], [3212748857, 3210953055, 3206425550, 3214376535]], [[3216393659, 3204671589, 1046392801, 3210937971], [3212871310, 1011922854, 3201903270, 1056981194], [1057317906, 1057615558, 1049853029, 1054672679], [3212476770, 3221471932, 3220222283, 3214302880]], [[3223731609, 3214052862, 3180226583, 3214602181], [1058005824, 1066321194, 3196840352, 3205731707], [1063641844, 1058202109, 3199602305, 1062816921], [3213035287, 3205352409, 3207120713, 3215062456]], [[3213243313, 3198190149, 3200959705, 3220727198], [3216848691, 1065951977, 1058746486, 3187463331], [3208759739, 3209999898, 3201950053, 1057709270], [3215410801, 3220972337, 3217900520, 3205565316]], [[3204515228, 3200171121, 1036747732, 3212008346], [3212994805, 1064014337, 1045871186, 1049358571], [3206003701, 3199328665, 3197640218, 1054410096], [3213258490, 3223234703, 3214133538, 3205696757]], [[3208849732, 3197060701, 1068358080, 3206184405], [3192432672, 3175462240, 3205612694, 1056886796], [3206545803, 3194521889, 1045375233, 3198940504], [3216814735, 3208711995, 3186193045, 3214037638]]], [[[3216360967, 1036684125, 1063106641, 3207920868], [3210912356, 1057406131, 1061889454, 1070389906], [1063614043, 1011093775, 1051630319, 1042937344], [3218021979, 3211821692, 3213480277, 3218852962]], [[3198755067, 1042856514, 1062589330, 3218310721], [3209203833, 1052299815, 3206946607, 1042539670], [3211616419, 3213324222, 3174058236, 1065812446], [3212593428, 3222880238, 3215602142, 3211626657]], [[3220552116, 3214872403, 3197154275, 3215200023], [3209910984, 1060965356, 3183351687, 1035694192], [3197118079, 3200897243, 3192881157, 1047471646], [3212466763, 3215646388, 3216358017, 3215608310]], [[3213116663, 3201643522, 1064224623, 3211875058], [3211266780, 1058085186, 3189642974, 1054486153], [1044948679, 3202277747, 3202345174, 1060848778], [3210467392, 
3213487717, 3211134876, 3214338282]], [[3211677223, 1048864005, 1063947028, 3209123399], [3210500909, 1061565305, 1060888820, 1067545326], [1057679257, 1037987315, 3177818048, 3197682771], [3216272100, 3189838367, 3209647833, 3218242390]], [[3221011316, 3215432750, 1033295809, 3207701561], [1053112007, 1064012590, 3208772933, 3210342449], [3180445954, 1058780285, 1044248884, 1062992136], [3211900119, 3220486882, 3214684104, 3215066638]]]], dtype=np.uint32).view(np.float32)
@jax.jit
def myfunc(a):
    """Return the sum of every element of ``a`` computed with ``jnp.sum``.

    Defined once at module level rather than inside the loop; ``jax.jit``
    traces and compiles a separate executable for each input dtype/shape.
    """
    return jnp.sum(a)


def compare_to_reference(reference, result, epsilon=1e-12):
    """Return ``(max_abs_diff, max_rel_diff)`` of ``result`` against ``reference``.

    Both inputs are converted to float64 before subtracting, so low-precision
    results (float16, bfloat16) are compared without any additional rounding.
    ``epsilon`` avoids division by zero when ``reference`` is 0.
    """
    ref = np.asarray(reference, dtype=np.float64)
    out = np.asarray(result, dtype=np.float64)
    abs_diff = np.abs(ref - out)
    rel_diff = abs_diff / (np.abs(ref) + epsilon)
    return float(np.max(abs_diff)), float(np.max(rel_diff))


def main(a_np):
    """Sum ``a_np`` with JAX in float32/float16/bfloat16 and print each one's error."""
    # Accumulate the reference in float64. A float32 reference would only measure
    # the disagreement between NumPy's and XLA's float32 summation orders, not
    # the actual error of the JAX reduction.
    c_np = np.sum(a_np, dtype=np.float64)
    for dtype in [jnp.float32, jnp.float16, jnp.bfloat16]:
        # The input is cast to `dtype` here, so the reported error covers both
        # the input quantization and the low-precision accumulation.
        a = jnp.array(a_np, dtype=dtype)
        y_np = np.asarray(myfunc(a))
        max_abs_diff, max_rel_diff = compare_to_reference(c_np, y_np)
        print("dtype:", y_np.dtype)
        print("Max Absolute Difference: %.3e" % max_abs_diff)
        print("Max Relative Difference: %.3e" % max_rel_diff)


if __name__ == "__main__":
    main(a_np)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment