Created
June 14, 2024 08:53
-
-
Save pashu123/898636a138e41e1db2443acd1248d6d4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def compare_arrays(expected, computed, tol=1.0):
    """Print every element where ``computed`` disagrees with ``expected``.

    Two kinds of disagreement are reported, in this order:

    1. Numeric mismatches: both values are non-NaN and their absolute
       difference is strictly greater than ``tol``.
    2. NaN mismatches: exactly one of the two values is NaN.

    Positions where both values are NaN are not reported.  If the shapes
    differ, a single message is printed and nothing else is compared.

    Args:
        expected: Reference ("golden") array.
        computed: Array under test (labelled "iree" in the output).
        tol: Absolute difference above which two non-NaN values count as a
            mismatch.  Defaults to 1.0.

    Returns:
        None.  All results are written to stdout.
    """
    expected = np.asarray(expected)
    computed = np.asarray(computed)

    if expected.shape != computed.shape:
        print("Arrays have different shapes.")
        return

    expected_nan = np.isnan(expected)
    computed_nan = np.isnan(computed)

    # inf - inf yields NaN and an "invalid value" RuntimeWarning; the NaN
    # compares False against ``tol`` (equal infinities are not a mismatch),
    # so only the warning needs to be silenced.
    with np.errstate(invalid="ignore"):
        too_far = np.abs(expected - computed) > tol
    numeric_mismatch = too_far & ~expected_nan & ~computed_nan

    for raw_index in zip(*np.nonzero(numeric_mismatch)):
        # Plain ints keep the printed index readable on NumPy 2, where the
        # repr of a NumPy integer is ``np.int64(3)`` rather than ``3``.
        index = tuple(int(i) for i in raw_index)
        print(f"Mismatch at index {index}: golden={expected[index]}, iree={computed[index]}")

    # A NaN on one side only can never show up in the numeric check above.
    nan_mismatch = expected_nan != computed_nan
    for raw_index in zip(*np.nonzero(nan_mismatch)):
        index = tuple(int(i) for i in raw_index)
        print(f"NaN mismatch at index {index}: golden={expected[index]}, iree={computed[index]}")
def softmax(x, axis=None, *, stage="max", exp_dump_path=None):
    """Compute a numerically stable softmax, or one of its intermediate stages.

    The computation is split into the stages below so that each one can be
    compared on its own against another implementation's output:

    ========== ==========================================================
    ``stage``  returned value
    ========== ==========================================================
    "max"      ``max(x)`` along ``axis`` (reduced axis dropped)
    "sub"      ``x - max(x)``
    "exp"      ``exp(x - max(x))``
    "sum"      ``sum(exp(x - max(x)))`` along ``axis`` (reduced axis dropped)
    "softmax"  ``exp(x - max(x)) / sum(exp(x - max(x)))``
    ========== ==========================================================

    NOTE: the default is ``"max"``, not ``"softmax"``.  That keeps the
    result of ``softmax(x, axis)`` identical to what this function has
    always returned; pass ``stage="softmax"`` for the full result.

    Args:
        x: Input array.
        axis: Axis (or axes) to reduce over; ``None`` reduces over all
            elements.
        stage: Which stage to return; one of the names in the table above.
        exp_dump_path: If given, the ``exp`` stage is also written to this
            path with ``np.save``.  Only takes effect when ``stage`` is
            "exp", "sum" or "softmax".

    Returns:
        The array (or scalar) produced by the requested stage.

    Raises:
        ValueError: If ``stage`` is not one of the supported names.
    """
    stages = ("max", "sub", "exp", "sum", "softmax")
    if stage not in stages:
        raise ValueError(f"stage must be one of {stages}, got {stage!r}")

    if stage == "max":
        return np.max(x, axis=axis, keepdims=False)

    # keepdims=True keeps the reduced axis as size 1 so the max (and later
    # the sum) broadcast against x for any axis, not just special shapes.
    x_max = np.max(x, axis=axis, keepdims=True)

    # Subtract the max value along the specified axis to prevent overflow
    sub_in = x - x_max
    if stage == "sub":
        return sub_in

    e_x = np.exp(sub_in)
    if exp_dump_path is not None:
        np.save(exp_dump_path, e_x)
    if stage == "exp":
        return e_x

    if stage == "sum":
        return np.sum(e_x, axis=axis, keepdims=False)

    return e_x / np.sum(e_x, axis=axis, keepdims=True)
# Load the recorded input and report whether it contains any NaN.
inp_tensor = np.load('42_inputs.npy')
input_has_nan = np.isnan(inp_tensor).any()
print(input_has_nan)

# Build the NumPy reference along axis 3 and report whether it contains NaN.
golden_out = softmax(inp_tensor, axis=3)
golden_has_nan = np.isnan(golden_out).any()
print(golden_has_nan)

# Load the output under test and report whether it contains any NaN.
iree_out = np.load('42_out_repro.npy')
iree_has_nan = np.isnan(iree_out).any()
print(iree_has_nan)

# Print element-wise differences first, then fail hard if the two results
# are not within the tight tolerance.
compare_arrays(golden_out, iree_out)
np.testing.assert_allclose(
    golden_out,
    iree_out,
    rtol=1e-5,
    atol=1e-5,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment