Fork of https://github.com/gruns/icecream with NumPy support. Python 3.6+
#!/usr/bin/env python
#
# IceCream - Never use print() to debug again
#
# Ansgar Grunseid
# grunseid.com
# [email protected]
#
# Pavel Maevskikh
# [email protected]
#
# License: MIT
#
# pip install asttokens colorama executing numpy pygments
#

__all__ = ['ic']

import ast
import inspect
import pprint
import shutil
import sys
from collections.abc import Iterable, Iterator, Mapping
from dataclasses import is_dataclass, replace
from datetime import datetime
from os.path import basename
from textwrap import dedent
from typing import Dict, List, NamedTuple, Tuple

import colorama
import executing
import numpy as np
from pygments import highlight
from pygments.formatters import TerminalFormatter
from pygments.lexers.python import PythonLexer

colorama.init()

PREFIX = 'ic| '
LINE_WRAP_WIDTH = 70  # Characters
FORMATTER = TerminalFormatter(bg='dark')
LEXER = PythonLexer(ensurenl=False)


def is_literal(s) -> bool:
    try:
        ast.literal_eval(s)
        return True
    except Exception:  # noqa: PIE786
        return False


class NoSourceAvailableError(OSError):
    """
    Raised when icecream fails to find or access source code that's
    required to parse and analyze. This can happen, for example, when

    - ic() is invoked inside a REPL or interactive shell, e.g. from the
      command line (CLI) or with python -i.
    - The source code is mangled and/or packaged, e.g. with a project
      freezer like PyInstaller.
    - The underlying source code changed during execution. See
      https://stackoverflow.com/a/33175832.
    """
    info_message = (
        'Failed to access the underlying source code for analysis. Was ic() '
        'invoked in a REPL (e.g. from the command line), a frozen application '
        '(e.g. packaged with PyInstaller), or did the underlying source code '
        'change during execution?')


class Source(executing.Source):
    def get_text_with_indentation(self, node) -> str:
        result = self.asttokens().get_text(node)
        if '\n' in result:
            result = ' ' * node.first_token.start[1] + result
            result = dedent(result)
        return result.strip()


def indented_lines(prefix: str, lines: str) -> List[str]:
    space = ' ' * len(prefix)
    first, *rest = lines.splitlines()
    return [prefix + first] + [space + line for line in rest]


def format_pair(prefix: str, arg: str, value: str) -> str:
    # Align the start of multiline strings.
    if value[0] + value[-1] in ["''", '""']:
        first, *rest = value.splitlines(keepends=True)
        value = first + ''.join(' ' + line for line in rest)
    *lines, tail = indented_lines(prefix, arg)
    return '\n'.join(lines + indented_lines(tail + ': ', value))


def _get_nd_grad(arr: np.ndarray) -> 'Iterator[str]':
    # A somewhat involved, but fast, way to estimate the gradient along
    # every axis: split the tensor into halves along each axis, compute
    # the mean of each cell, then aggregate those means per axis and
    # take the difference between the two halves
    # (see the worked example after this function).
    arr_f4 = arr.astype('f4')

    # Pyramid of splits
    splits: Dict[Tuple[int, ...], np.ndarray] = {(): arr_f4}
    for axis, size in enumerate(arr.shape):
        if size == 1:
            splits = {(*k, 0): s for k, s in splits.items()}
        else:
            half = size // 2
            splits = {
                (*k, k2): ss for k, s in splits.items()
                for k2, ss in enumerate(np.split(s, [half, -half], axis))
            }

    # Tensor of means
    s_shape = [1 if size == 1 else 3 for size in arr.shape]
    means = np.zeros(s_shape)
    weights = np.zeros(s_shape, int)
    for loc, s in splits.items():
        means[loc] = s.mean() if s.size else 0
        weights[loc] = s.size

    # Aggregate and compute the gradients
    grad = np.zeros(arr.ndim)
    for axis, size in enumerate(arr.shape):
        if size == 1:
            continue
        axes = tuple(a for a in range(arr.ndim) if a != axis)
        means_ = np.take(means, [0, 2], axis)
        weights_ = np.take(weights, [0, 2], axis)
        head, tail = np.average(means_, axes, weights_)
        grad[axis] = tail - head
    if grad.any():
        yield f'grad={grad.round(8)}'
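

# Worked example (illustrative, not part of the original gist): for a ramp
# such as np.arange(100, dtype='f4').reshape(10, 10), the two halves along
# axis 0 average 24.5 and 74.5, while the two halves along axis 1 average
# 47 and 52, so _get_nd_grad yields something like 'grad=[50.  5.]'.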


def _get_properties(arr: np.ndarray) -> 'Iterator[str]':
    yield f'{arr.shape}, dtype={arr.dtype}'
    if not arr.size:
        return

    # Small array, print contents as is
    if arr.size < 40:
        yield f'data={arr.ravel()}'
        return

    # Small enough binary array, hexify
    if arr.size < 500 and arr.dtype == bool:
        data = np.packbits(arr.flat).tobytes()
        line = ''.join(f'{v:02x}' for v in data).replace('0', '_')
        yield f'data={line!r}'
        return

    # Too much data, use statistics
    lo = arr.min()
    hi = arr.max()
    if arr.dtype.kind == 'f':
        yield f'x∈[{lo:.8f}, {hi:.8f}]'
        yield f'μ={arr.mean():.8f}, σ={arr.std():.8f}'
    # Wide range, only low/high
    elif int(hi) - int(lo) > 100:
        yield f'x∈[{lo}, {hi}]'
    # Medium range or zero crossing, low/high + nuniq
    elif int(lo) < 0 or int(hi) > 10:
        nuniq = np.unique(arr.ravel()).size
        yield f'x∈[{lo}, {hi}], nuniq={nuniq}'
    # Narrow range, raw distribution
    else:
        weights = np.bincount(arr.ravel()).astype('f8') / arr.size
        yield f'weights={weights}'

    yield from _get_nd_grad(arr)


class _ReprArray(NamedTuple):
    data: np.ndarray

    def __str__(self) -> str:
        return str(self.data)

    def __repr__(self) -> str:
        return 'np.ndarray(' + ', '.join(_get_properties(self.data)) + ')'
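

# Illustrative note (assumption, not part of the original gist):
# repr(_ReprArray(np.zeros((512, 512), 'f4'))) gives a compact one-line
# summary along the lines of
#   np.ndarray((512, 512), dtype=float32, x∈[0.00000000, 0.00000000], μ=0.00000000, σ=0.00000000)
# instead of dumping all 262144 values.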


def _patch_repr_types(obj):
    if isinstance(obj, np.ndarray):
        return _ReprArray(obj)
    if isinstance(obj, (str, bytes, bytearray, range)):
        return obj
    if is_dataclass(obj):
        return replace(obj, **_patch_repr_types(vars(obj)))
    # namedtuple
    if isinstance(obj, tuple) and hasattr(obj, '_fields'):
        return type(obj)(*(_patch_repr_types(x) for x in obj))
    if isinstance(obj, Mapping):
        return dict(_patch_repr_types(kv) for kv in obj.items())
    if isinstance(obj, Iterable) and not isinstance(obj, Iterator):
        return type(obj)(_patch_repr_types(x) for x in obj)
    return obj
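

# Illustrative note (assumption, not part of the original gist): containers
# are rebuilt recursively, so _patch_repr_types({'a': np.eye(3)}) returns
# {'a': _ReprArray(...)}, while strings, bytes, and iterators such as
# generators pass through unchanged.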


def argument_to_string(obj) -> str:
    obj = _patch_repr_types(obj)
    # Preserve string newlines in output.
    width = shutil.get_terminal_size().columns
    return pprint.pformat(obj, width=width).replace('\\n', '\n')


def _format_time() -> str:
    now = f'{datetime.now():%H:%M:%S.%f}'[:-3]
    return f' at {now}'


def _format_context(frame, call_node) -> str:
    info = inspect.getframeinfo(frame)
    parent_fn = info.function
    if parent_fn != '<module>':
        parent_fn = f'{parent_fn}()'
    return f'{basename(info.filename)}:{call_node.lineno} in {parent_fn}'


def _construct_argument_output(context, pairs) -> str:
    pairs = [(arg, argument_to_string(val)) for arg, val in pairs]

    # For cleaner output, if <arg> is a literal, e.g. 3, "string", b'bytes',
    # etc., only output the value, not the argument and the value, as the
    # argument and the value will be identical or nigh identical. E.g. with
    # ic("hello"), just output
    #
    #   ic| 'hello',
    #
    # instead of
    #
    #   ic| "hello": 'hello'.
    #
    all_args_on_one_line = ', '.join(
        val if is_literal(arg) else f'{arg}: {val}' for arg, val in pairs)
    context_delimiter = f'{context}- ' if context else ''
    all_pairs = PREFIX + context_delimiter + all_args_on_one_line

    if len(all_args_on_one_line.splitlines()) <= 1 \
            and len(all_pairs.splitlines()[0]) <= LINE_WRAP_WIDTH:
        # ic| foo.py:11 in foo()- a: 1, b: 2
        # ic| a: 1, b: 2, c: 3
        return PREFIX + context_delimiter + all_args_on_one_line

    # ic| foo.py:11 in foo()
    #     multilineStr: 'line1
    #                    line2'
    #
    # ic| foo.py:11 in foo()
    #     a: 11111111111111111111
    #     b: 22222222222222222222
    if context:
        space = len(PREFIX) * ' '
        return '\n'.join(
            [PREFIX + context] +
            [format_pair(space, arg, value) for arg, value in pairs])

    # ic| multilineStr: 'line1
    #                    line2'
    #
    # ic| a: 11111111111111111111
    #     b: 22222222222222222222
    lines = '\n'.join(format_pair('', arg, value) for arg, value in pairs)
    return '\n'.join(indented_lines(PREFIX, lines))


def _format(frame, *args) -> str:
    call_node = Source.executing(frame).node
    if call_node is None:
        raise NoSourceAvailableError()

    context = _format_context(frame, call_node)
    if not args:
        return PREFIX + context + _format_time()

    source = Source.for_frame(frame)
    sanitized_arg_strs = [
        source.get_text_with_indentation(arg) for arg in call_node.args
    ]
    pairs = zip(sanitized_arg_strs, args)
    return _construct_argument_output(context, pairs)


def ic(*args):
    frame = inspect.currentframe()
    assert frame
    try:
        out = _format(frame.f_back, *args)
    except NoSourceAvailableError as err:
        out = f'{PREFIX}Error: {err.info_message}'
    s = highlight(out, LEXER, FORMATTER)
    print(s, file=sys.stderr)

    if not args:
        return None  # E.g. ic().
    if len(args) == 1:
        return args[0]  # E.g. ic(1).
    return args  # E.g. ic(1, 2, 3).
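

# Minimal usage sketch (illustrative, not part of the original gist): save
# this file next to your code as, say, icecream.py, and call ic() on any
# expression; NumPy arrays are summarized instead of being dumped in full.
if __name__ == '__main__':
    x = 3
    data = np.random.rand(256, 256).astype('f4')

    ic()          # prints 'ic| <file>:<line> in <module> at HH:MM:SS.mmm'
    ic(x, x * 2)  # prints 'ic| x: 3, x * 2: 6'
    ic(data)      # prints a summary such as
                  # 'data: np.ndarray((256, 256), dtype=float32, x∈[...], μ=..., σ=..., grad=[...])'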