Created
December 8, 2022 02:39
-
-
Save Guitaricet/1aca93323b0d3f94a35a9001aa736467 to your computer and use it in GitHub Desktop.
Add a `named_shape` property (backed by the `NamedShape` helper class) to `torch.Tensor`.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class NamedShape:
    """Pair a tensor's dimension names with its sizes for readable printing.

    The ``names`` and ``shape`` attributes of the given object are captured at
    construction time; ``repr`` renders them as ``NamedShape[name=size, ...]``.
    """

    def __init__(self, tensor) -> None:
        """Store ``tensor.names`` and ``tensor.shape``.

        Args:
            tensor: Any object exposing ``names`` and ``shape`` attributes
                (written with ``torch.Tensor`` in mind).
        """
        self.names = tensor.names
        self.shape = tensor.shape

    def __repr__(self) -> str:
        """Return ``NamedShape[n1=s1, n2=s2, ...]``, one entry per dimension.

        Names and sizes are paired positionally with ``zip``; an object with
        no dimensions renders as ``NamedShape[]``.
        """
        # Walk the (name, size) pairs directly instead of building a dict
        # keyed by name: dimensions that share a name (e.g. ``None`` for
        # several unnamed dimensions) would overwrite one another in a dict
        # and silently vanish from the output.
        dims = ", ".join(
            f"{name}={size}" for name, size in zip(self.names, self.shape)
        )
        return f"NamedShape[{dims}]"
# Install NamedShape as a read-only property on every tensor: the property
# getter is called with the tensor instance, so ``t.named_shape`` evaluates
# to ``NamedShape(t)``.
torch.Tensor.named_shape = property(fget=NamedShape)

# Small demo: multiply two tensors created with dimension names and show the
# named shape of the product.
x = torch.rand(5, 3, names=("batch", "features"))
W = torch.randn(3, 7, names=("features", "neurons"))
y = torch.matmul(x, W)
# NOTE(review): expected to print NamedShape[batch=5, neurons=7], assuming
# matmul propagates dimension names as the named-tensor docs describe — confirm.
print(y.named_shape)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment