Last active
November 25, 2019 14:54
-
-
Save temporaer/317e6df339125431c3531019ad72ad9d to your computer and use it in GitHub Desktop.
pdbhelpers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# pdb aliases for inspecting tensors/arrays; lines starting with '#' are ignored by pdb.
# %1 is replaced by the first argument given to the alias.

# Rebind ipdb.set_trace to a no-op lambda (assumes the name `ipdb` is in scope
# at the point where the alias is used).
alias deltrace ipdb.set_trace = lambda: None
# Evaluate the .shape attribute of the argument.
alias pshape (%1).shape
# Evaluate len() of the argument.
alias plen len(%1)
# Print the .shape of every element of an iterable argument.
alias plshape for i in (%1): print(i.shape)
# Start an embedded IPython shell.
alias embed import IPython; IPython.embed()
# Pretty-print the shape/dtype summary of a nested structure; the summary is
# also left bound to the name A. Requires pdbhelpers to be importable.
alias ts from pprint import pprint; from pdbhelpers import tensor_shapes; A=tensor_shapes(%1); pprint(A)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Print shapes/dtypes of a complex data structure containing tensors/ndarrays.
# To use, put this in your ~/.pdbrc:
# alias ts from pprint import pprint; from pdbhelpers import tensor_shapes; A=tensor_shapes(%1); pprint(A)
# OrderedDict is standard library and must not depend on torch being present,
# so it is imported outside the optional-dependency guard.
from collections import OrderedDict

try:
    import torch
except ImportError:
    # torch is an optional dependency; the module stays importable without it.
    pass
def tensor_shapes(t):
    """Summarize the shapes/dtypes of tensors nested inside a data structure.

    Recursively walks ``t`` and returns a structure with the same layout in
    which every ``torch.Tensor`` / ``numpy.ndarray`` is replaced by a short
    string of the form ``"<dtype>(<dim>,<dim>,...)"``, e.g. ``"float32(2,3)"``.
    CUDA tensors get a ``_cuda`` suffix after the dtype.

    Neither torch nor numpy has to be installed: a library that has not been
    imported is simply skipped.

    Args:
        t: ``None``, a tensor/ndarray, a (possibly nested) tuple, list or
            dict, a ``str``, a ``float``/``int``, or any other object.

    Returns:
        * ``None``, ``str``, ``float`` and ``int`` values unchanged;
        * a tuple / list with every item summarized;
        * a plain ``dict`` with both keys and values summarized;
        * a summary string for tensors and ndarrays;
        * the class name for any other object.
    """
    import sys

    # An instance of Tensor/ndarray can only exist if its library was already
    # imported, so the classes are looked up in sys.modules instead of
    # importing (and requiring) the libraries here.
    tensor_cls = getattr(sys.modules.get("torch"), "Tensor", None)
    ndarray_cls = getattr(sys.modules.get("numpy"), "ndarray", None)

    def describe(array, suffix=""):
        # Shared formatter for tensors and ndarrays: dtype[suffix](d0,d1,...).
        dims = ",".join("%d" % d for d in array.shape)
        return "%s%s(%s)" % (array.dtype, suffix, dims)

    if t is None:
        return None
    if tensor_cls is not None and isinstance(t, tensor_cls):
        return describe(t, "_cuda" if t.is_cuda else "")
    if isinstance(t, tuple):
        return tuple(tensor_shapes(item) for item in t)
    if isinstance(t, list):
        return [tensor_shapes(item) for item in t]
    # OrderedDict is a dict subclass, so one check covers both.
    if isinstance(t, dict):
        return {tensor_shapes(k): tensor_shapes(v) for k, v in t.items()}
    if ndarray_cls is not None and isinstance(t, ndarray_cls):
        return describe(t)
    if isinstance(t, (str, float, int)):
        return t
    # Anything unrecognized is reduced to its type name.
    return t.__class__.__name__
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment