Skip to content

Instantly share code, notes, and snippets.

@afspies
Forked from nmwsharp/printarr
Last active February 26, 2023 11:23
Show Gist options
  • Save afspies/64671e3a9822363df1129a38ceae16fe to your computer and use it in GitHub Desktop.
Save afspies/64671e3a9822363df1129a38ceae16fe to your computer and use it in GitHub Desktop.
Pretty-print tables summarizing properties of tensor arrays in NumPy, PyTorch, JAX, etc.
name | dtype | shape | type | device | min | max | mean
--------------------------------------------------------------------------------------------------------------
[None] | None | N/A | NoneType | | N/A | N/A | N/A
intval1 | int | scalar | int | | 7 | 7 | 7
intval2 | int | scalar | int | | -3 | -3 | -3
floatval0 | float | scalar | float | | 42 | 42 | 42
floatval1 | float | scalar | float | | 5.5e-12 | 5.5e-12 | 5.5e-12
floatval2 | float | scalar | float | | 7.72324e+44 | 7.72324e+44 | 7.72324e+44
npval1 | int64 | [100] | numpy.ndarray | | 0 | 99 | 49.5
npval2 | int64 | [10000] | numpy.ndarray | | 0 | 9999 | 4999.5
npval3 | uint64 | [10000] | numpy.ndarray | | 0 | 9999 | 4999.5
npval4 | float32 | [100, 10, 10] | numpy.ndarray | | 0 | 9999 | 4999.5
[temporary] | float32 | [10, 8] | numpy.ndarray | | 2 | 99 | 50.5
npval5 | int64 | [] | numpy.int64 | | 9999 | 9999 | 9999
torchval1 | torch.float32 | [1000, 12, 3] | torch.Tensor | cpu | -4.08445 | 3.90982 | 0.00404567
torchval2 | torch.float32 | [1000, 12, 3] | torch.Tensor | cuda:0 | -3.87309 | 3.90342 | 0.00339224
torchval3 | torch.int64 | [1000] | torch.Tensor | cpu | 0 | 999 | N/A
torchval4 | torch.int64 | [] | torch.Tensor | cpu | 0 | 0 | N/A
def printarr(*arrs, float_width=6):
    """
    Print a pretty table giving name, shape, dtype, type, and content information for input tensors or scalars.

    Call like: printarr(my_arr, some_other_arr, maybe_a_scalar). Accepts a variable number of arguments.

    Inputs can be:
        - Numpy tensor arrays
        - Pytorch tensor arrays
        - Jax tensor arrays
        - Python ints / floats
        - None

    It may also work with other array-like types, but they have not been tested.
    Objects that have no `.dtype` / `.shape` attribute (e.g. plain Python lists)
    are shown with 'N/A' in those columns rather than raising.

    Names are recovered by searching the caller's local variables for an object
    that is identical (`is`) to each argument. The first match wins; an argument
    not bound to any caller local is shown as "[temporary]" and a None argument
    as "[None]". Because the match is by identity, a value bound to several
    locals (e.g. small cached ints) may be reported under any one of them.

    Use the `float_width` option to set the minimum field width of the numeric
    min / max / mean entries. It is used as the width of a `g` format spec, so
    the number of significant digits stays at the `g` default of 6.

    Calling with no arguments prints nothing.

    Author: Nicholas Sharp (nmwsharp.com)
    Canonical source: https://gist.github.com/nmwsharp/54d04af87872a4988809f128e1a1d233
    License: This snippet may be used under an MIT license, and it is also released into the public domain.
    Please retain this docstring as a reference.
    """
    import inspect  # local import: no module-level import of `inspect` is visible in this file

    if not arrs:
        return

    # The caller's frame holds the variables whose names we want to display.
    # currentframe() may return None on interpreters without stack-frame support.
    frame = inspect.currentframe()
    frame = frame.f_back if frame is not None else None
    default_name = "[temporary]"

    ## helpers to gather data about each array

    def name_from_outer_scope(a):
        if a is None:
            return '[None]'
        if frame is None:
            return default_name
        # Identity match against the caller's locals; first hit wins.
        for k, v in frame.f_locals.items():
            if v is a:
                return k
        return default_name

    def dtype_str(a):
        if a is None:
            return 'None'
        if isinstance(a, int):
            return 'int'
        if isinstance(a, float):
            return 'float'
        dtype = getattr(a, 'dtype', None)
        return 'N/A' if dtype is None else str(dtype)

    def shape_str(a):
        if a is None:
            return 'N/A'
        if isinstance(a, (int, float)):
            return 'scalar'
        shape = getattr(a, 'shape', None)
        return 'N/A' if shape is None else str(list(shape))

    def type_str(a):
        # Same text as the "<class 'module.Name'>" repr with the wrapper removed;
        # builtin types are shown unqualified (e.g. 'int', 'NoneType').
        t = type(a)
        if t.__module__ == 'builtins':
            return t.__qualname__
        return f"{t.__module__}.{t.__qualname__}"

    def device_str(a):
        if hasattr(a, 'device'):
            dev = str(a.device)
            # heuristic: jax returns some goofy long string we don't want, ignore it
            if len(dev) < 10:
                return dev
        return ""

    def format_float(x):
        return f"{x:{float_width}g}"

    def try_format(compute):
        # Best-effort: if the reduction or its formatting fails for any reason
        # (missing method, unsupported dtype, empty array, ...), show 'N/A'.
        # Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
        try:
            return format_float(compute())
        except Exception:
            return "N/A"

    def minmaxmean_str(a):
        if a is None:
            return ('N/A', 'N/A', 'N/A')
        if isinstance(a, (int, float)):
            s = format_float(a)
            return (s, s, s)
        return (
            try_format(lambda: a.min()),
            try_format(lambda: a.max()),
            try_format(lambda: a.mean()),
        )

    try:
        props = ['name', 'dtype', 'shape', 'type', 'device', 'min', 'max', 'mean']

        # precompute all of the properties for each input
        str_props = []
        for a in arrs:
            min_s, max_s, mean_s = minmaxmean_str(a)
            str_props.append({
                'name': name_from_outer_scope(a),
                'dtype': dtype_str(a),
                'shape': shape_str(a),
                'type': type_str(a),
                'device': device_str(a),
                'min': min_s,
                'max': max_s,
                'mean': mean_s,
            })

        # widest entry in each column
        maxlen = {p: max(len(sp[p]) for sp in str_props) for p in props}
        # drop columns that are empty for every input (only 'device' can be empty)
        props = [p for p in props if maxlen[p] > 0]
        # account for the possibility that the header is longer than any of the values
        maxlen = {p: max(maxlen[p], len(p)) for p in props}

        def format_row(cells):
            # 'name' is right-aligned; every other column is left-aligned
            # and separated from the previous one by " | ".
            parts = []
            for p in props:
                prefix = "" if p == 'name' else " | "
                align = ">" if p == 'name' else "<"
                parts.append(f"{prefix}{cells[p]:{align}{maxlen[p]}}")
            return "".join(parts)

        header_str = format_row({p: p for p in props})
        print(header_str)
        print("-" * len(header_str))

        # now print the actual arrays
        for strp in str_props:
            print(format_row(strp))
    finally:
        # Drop the frame reference promptly to avoid a frame reference cycle.
        del frame
if __name__ == "__main__":
    ## test it!

    # plain python values
    noneval = None
    intval1 = 7
    intval2 = -3
    floatval0 = 42.0
    floatval1 = 5.5 * 1e-12
    floatval2 = 7.7232412351231231234 * 1e44

    # numpy values
    import numpy as np
    npval1 = np.arange(100)
    npval2 = np.arange(10000)
    npval3 = np.arange(10000).astype(np.uint64)
    npval4 = np.arange(10000).astype(np.float32).reshape(100, 10, 10)
    npval5 = np.arange(10000)[-1]

    # torch values: all skipped if torch is not installed; the CUDA tensor is
    # skipped when no CUDA device is available (calling .cuda() would raise).
    torchval1 = None
    torchval2 = None
    torchval3 = None
    torchval4 = None
    try:
        import torch
        torchval1 = torch.randn((1000, 12, 3))
        if torch.cuda.is_available():
            torchval2 = torch.randn((1000, 12, 3)).cuda()
        torchval3 = torch.arange(1000)
        torchval4 = torch.arange(1000)[0]
    except ModuleNotFoundError:
        pass

    # jax values: skipped if jax is not installed
    jaxval1 = None
    jaxval2 = None
    jaxval3 = None
    jaxval4 = None
    try:
        import jax.numpy as jnp
        jaxval1 = jnp.linspace(0, 1, 10000)
        jaxval2 = jnp.linspace(0, 1, 10000).reshape(100, 10, 10)
        jaxval3 = jnp.arange(1000)
        jaxval4 = jnp.arange(1000)[0]
    except ModuleNotFoundError:
        pass

    printarr(
        noneval,
        intval1, intval2,
        floatval0, floatval1, floatval2,
        npval1, npval2, npval3, npval4, npval4[0, :, 2:], npval5,
        torchval1, torchval2, torchval3, torchval4,
        jaxval1, jaxval2, jaxval3, jaxval4,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment