Skip to content

Instantly share code, notes, and snippets.

@renxida
Created October 18, 2024 21:31
Show Gist options
  • Save renxida/fde763bbdd659f158f83cddfa6e61dae to your computer and use it in GitHub Desktop.
Save renxida/fde763bbdd659f158f83cddfa6e61dae to your computer and use it in GitHub Desktop.
Function for logging summary information about any shortfin tensor (host or device array), `array.array`, or `np.ndarray`.
def log_tensor_stats(tensor, name="Tensor"):
    """Log summary statistics about a tensor for debugging.

    Accepted inputs:
      * ``sfnp.device_array`` -- copied into the array returned by its
        ``for_transfer()`` and then handled like a host array;
      * ``sfnp.base_array`` -- its ``.items`` are read (expected to be an
        ``array.array``);
      * ``array.array`` -- viewed as an ``np.ndarray`` via
        ``np.frombuffer`` (no copy);
      * ``np.ndarray`` -- used as is.

    After conversion it logs, through the module-level ``logger``: the NaN
    count, shape and dtype, and -- over the non-NaN values -- min, max,
    mean, mode, and the first/last 10 elements.

    Args:
        tensor: The tensor to inspect.
        name: Label printed at the start of the report.

    Raises:
        TypeError: If ``tensor`` does not end up as an ``np.ndarray`` after
            the conversions above.
    """
    import array

    import numpy as np

    logger.info(f"{name} stats:")
    logger.info(f" type :{type(tensor)}")
    # array.array exposes .typecode but no .dtype; fall back to it so the
    # header does not raise AttributeError for array.array input.
    source_dtype = getattr(tensor, "dtype", None)
    if source_dtype is None:
        source_dtype = getattr(tensor, "typecode", None)
    logger.info(f" dtype:{source_dtype}")

    # Remember the logical shape (if the object exposes one) so it can be
    # restored after the flat array.array -> np.ndarray conversion below.
    logical_shape = getattr(tensor, "shape", None)

    is_fp16 = False
    if isinstance(tensor, sfnp.device_array):
        # Copy into the array returned by for_transfer() and continue with
        # that copy.
        # NOTE(review): if copy_from only enqueues the transfer, the data may
        # not have landed yet when it is read below -- confirm whether a wait
        # on the device/fiber is required before calling this helper.
        host_tensor = tensor.for_transfer()
        host_tensor.copy_from(tensor)
        tensor = host_tensor
    if isinstance(tensor, sfnp.base_array):
        # The array module has no float16 typecode, so float16 items come
        # back under some other typecode; remember to reinterpret the raw
        # bytes as float16 below.
        is_fp16 = tensor.dtype == sfnp.float16
        tensor = tensor.items
    if isinstance(tensor, array.array):
        dtype = np.float16 if is_fp16 else tensor.typecode
        tensor = np.frombuffer(tensor, dtype=dtype)
        if logical_shape is not None:
            try:
                tensor = tensor.reshape(tuple(logical_shape))
            except (TypeError, ValueError):
                # Best effort only: keep the flat view if the recorded shape
                # is unusable or does not match the element count.
                pass
    if not isinstance(tensor, np.ndarray):
        raise TypeError(
            f"log_tensor_stats: unsupported tensor type {type(tensor)!r}"
        )

    nan_mask = np.isnan(tensor)
    nan_count = nan_mask.sum()
    # Boolean-mask indexing always yields a 1-D array of the kept values.
    non_nan = tensor[~nan_mask]

    logger.info(f" NaN count: {nan_count} / {tensor.size}")
    logger.info(f" Shape: {tensor.shape}, dtype: {tensor.dtype}")
    if tensor.size == 0:
        logger.warning(f" {name} is empty")
        return
    if non_nan.size == 0:
        logger.warning(f" All values are NaN in {name}")
        return

    # Most frequent value; ties resolve to the smallest value because
    # np.unique returns sorted values and argmax picks the first maximum.
    values, counts = np.unique(non_nan, return_counts=True)
    mode = values[np.argmax(counts)]

    logger.info(f" Min (excluding NaN): {non_nan.min()}")
    logger.info(f" Max (excluding NaN): {non_nan.max()}")
    logger.info(f" Mean (excluding NaN): {non_nan.mean()}")
    logger.info(f" Mode (excluding NaN): {mode}")
    logger.info(f" First 10 elements: {non_nan[:10]}")
    logger.info(f" Last 10 elements: {non_nan[-10:]}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment