-
-
Save afspies/64671e3a9822363df1129a38ceae16fe to your computer and use it in GitHub Desktop.
Pretty-print tables summarizing properties of tensor arrays in NumPy, PyTorch, JAX, etc.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Pretty-print tables summarizing properties of tensor arrays in NumPy, PyTorch, JAX, etc.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
name | dtype | shape | type | device | min | max | mean
--------------------------------------------------------------------------------------------------------------
[None] | None | N/A | NoneType | | N/A | N/A | N/A
intval1 | int | scalar | int | | 7 | 7 | 7
intval2 | int | scalar | int | | -3 | -3 | -3
floatval0 | float | scalar | float | | 42 | 42 | 42
floatval1 | float | scalar | float | | 5.5e-12 | 5.5e-12 | 5.5e-12
floatval2 | float | scalar | float | | 7.72324e+44 | 7.72324e+44 | 7.72324e+44
npval1 | int64 | [100] | numpy.ndarray | | 0 | 99 | 49.5
npval2 | int64 | [10000] | numpy.ndarray | | 0 | 9999 | 4999.5
npval3 | uint64 | [10000] | numpy.ndarray | | 0 | 9999 | 4999.5
npval4 | float32 | [100, 10, 10] | numpy.ndarray | | 0 | 9999 | 4999.5
[temporary] | float32 | [10, 8] | numpy.ndarray | | 2 | 99 | 50.5
npval5 | int64 | [] | numpy.int64 | | 9999 | 9999 | 9999
torchval1 | torch.float32 | [1000, 12, 3] | torch.Tensor | cpu | -4.08445 | 3.90982 | 0.00404567
torchval2 | torch.float32 | [1000, 12, 3] | torch.Tensor | cuda:0 | -3.87309 | 3.90342 | 0.00339224
torchval3 | torch.int64 | [1000] | torch.Tensor | cpu | 0 | 999 | N/A
torchval4 | torch.int64 | [] | torch.Tensor | cpu | 0 | 0 | N/A
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def printarr(*arrs, float_width=6):
    """
    Print a pretty table giving name, shape, dtype, type, and content information for input tensors or scalars.

    Call like: printarr(my_arr, some_other_arr, maybe_a_scalar). Accepts a variable number of arguments.

    Inputs can be:
        - Numpy tensor arrays
        - Pytorch tensor arrays
        - Jax tensor arrays
        - Python ints / floats
        - None

    It may also work with other array-like types, but they have not been tested.

    Use the `float_width` option to specify the minimum field width of the printed min/max/mean
    values (they are formatted with Python's `g` presentation type at its default precision).

    Names are recovered by searching the caller's local variables for the identical object;
    inputs not bound to a caller variable are listed as "[temporary]". Columns whose values are
    all empty (typically "device") are omitted. The table is printed to stdout; nothing is returned.

    Author: Nicholas Sharp (nmwsharp.com)
    Canonical source: https://gist.github.com/nmwsharp/54d04af87872a4988809f128e1a1d233
    License: This snippet may be used under an MIT license, and it is also released into the public domain.
             Please retain this docstring as a reference.
    """
    # Local import: this snippet has no module-level `import inspect`.
    import inspect

    default_name = "[temporary]"

    # Keep only the caller's frame (not our own) so its locals can be searched for names.
    # currentframe() may return None on interpreters without stack-frame support.
    frame = inspect.currentframe()
    frame = frame.f_back if frame is not None else None

    ## helpers to gather data about each array

    def name_from_outer_scope(a):
        # Identity match against the caller's locals; the first hit wins. Interned objects
        # (e.g. small ints) may therefore match an unrelated variable holding the same object.
        if a is None:
            return '[None]'
        if frame is None:
            return default_name
        for k, v in frame.f_locals.items():
            if v is a:
                return k
        return default_name

    def dtype_str(a):
        if a is None:
            return 'None'
        if isinstance(a, int):
            return 'int'
        if isinstance(a, float):
            return 'float'
        return str(a.dtype)

    def shape_str(a):
        if a is None:
            return 'N/A'
        if isinstance(a, (int, float)):
            return 'scalar'
        return str(list(a.shape))

    def type_str(a):
        # Fully qualified class name (e.g. "numpy.ndarray"); builtins stay unqualified ("int"),
        # matching how Python itself renders a class's repr.
        cls = type(a)
        module = getattr(cls, '__module__', None)
        if not isinstance(module, str) or module == 'builtins':
            return cls.__qualname__
        return f"{module}.{cls.__qualname__}"

    def device_str(a):
        if hasattr(a, 'device'):
            dev = str(a.device)
            # heuristic: jax returns some goofy long string we don't want, ignore it
            if len(dev) < 10:
                return dev
        return ""

    def format_float(x):
        return f"{x:{float_width}g}"

    def safe_stat(a, method_name):
        # Best effort: any failure (missing method, e.g. mean() of an integer torch tensor,
        # unformattable result) is reported as 'N/A' instead of aborting the whole table.
        try:
            return format_float(getattr(a, method_name)())
        except Exception:
            return 'N/A'

    def minmaxmean_str(a):
        if a is None:
            return ('N/A', 'N/A', 'N/A')
        if isinstance(a, (int, float)):
            # min == max == mean for a scalar. Ints too large for a float cannot be 'g'-formatted.
            try:
                s = format_float(a)
            except OverflowError:
                s = 'N/A'
            return (s, s, s)
        return tuple(safe_stat(a, m) for m in ('min', 'max', 'mean'))

    try:
        props = ['name', 'dtype', 'shape', 'type', 'device', 'min', 'max', 'mean']

        # precompute all of the properties for each input
        str_props = []
        for a in arrs:
            min_str, max_str, mean_str = minmaxmean_str(a)
            str_props.append({
                'name': name_from_outer_scope(a),
                'dtype': dtype_str(a),
                'shape': shape_str(a),
                'type': type_str(a),
                'device': device_str(a),
                'min': min_str,
                'max': max_str,
                'mean': mean_str,
            })

        # width of the longest value in each column
        maxlen = {p: max((len(sp[p]) for sp in str_props), default=0) for p in props}
        # if any property got all empty strings, don't bother printing it
        props = [p for p in props if maxlen[p] > 0]
        # account for the possibility that the header is longer than any of the values
        maxlen = {p: max(maxlen[p], len(p)) for p in props}

        def format_row(values):
            # The name column is right-aligned, all others left-aligned, joined by " | ".
            cells = []
            for p in props:
                align = '>' if p == 'name' else '<'
                cells.append(f"{values[p]:{align}{maxlen[p]}}")
            return " | ".join(cells)

        header_str = format_row({p: p for p in props})
        print(header_str)
        print("-" * len(header_str))
        # now print the actual arrays
        for sp in str_props:
            print(format_row(sp))
    finally:
        # Drop the caller-frame reference so no frame <-> closure reference cycle outlives the call.
        del frame
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
if __name__ == "__main__":
    ## test it!

    # plain python values
    noneval = None
    intval1 = 7
    intval2 = -3
    floatval0 = 42.0
    floatval1 = 5.5 * 1e-12
    floatval2 = 7.7232412351231231234 * 1e44

    # numpy values
    import numpy as np
    npval1 = np.arange(100)
    npval2 = np.arange(10000)
    npval3 = np.arange(10000).astype(np.uint64)
    npval4 = np.arange(10000).astype(np.float32).reshape(100, 10, 10)
    npval5 = np.arange(10000)[-1]

    # torch values (optional dependency; left as None when torch is not installed)
    torchval1 = None
    torchval2 = None
    torchval3 = None
    torchval4 = None
    try:
        import torch
        torchval1 = torch.randn((1000, 12, 3))
        # .cuda() fails on installs without a usable GPU, so only build the CUDA example
        # when one is available instead of crashing the whole demo.
        if torch.cuda.is_available():
            torchval2 = torch.randn((1000, 12, 3)).cuda()
        torchval3 = torch.arange(1000)
        torchval4 = torch.arange(1000)[0]
    except ImportError:
        pass

    # jax values (optional dependency; left as None when jax is not installed)
    jaxval1 = None
    jaxval2 = None
    jaxval3 = None
    jaxval4 = None
    try:
        import jax.numpy as jnp
        jaxval1 = jnp.linspace(0, 1, 10000)
        jaxval2 = jnp.linspace(0, 1, 10000).reshape(100, 10, 10)
        jaxval3 = jnp.arange(1000)
        jaxval4 = jnp.arange(1000)[0]
    except ImportError:
        pass

    printarr(noneval,
             intval1, intval2,
             floatval0, floatval1, floatval2,
             npval1, npval2, npval3, npval4, npval4[0, :, 2:], npval5,
             torchval1, torchval2, torchval3, torchval4,
             jaxval1, jaxval2, jaxval3, jaxval4,
             )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment