Created
April 24, 2024 21:31
-
-
Save andylolu2/629098cf041108a77e37d2c8b6e91467 to your computer and use it in GitHub Desktop.
Better indexing with PyTorch that doesn't work
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect | |
from typing import Set | |
from functools import partial | |
import torch | |
from torch import Tensor | |
class ConstraintTrackingTensor(Tensor):
    """Tensor subclass that records the sizes of dimensions it is used to index.

    When an instance appears as an index in ``tensor[idx]`` (alone or inside a
    tuple of indices), ``__torch_function__`` intercepts the ``__getitem__``
    call and adds the size of the dimension being indexed to the instance's
    ``_constraints`` set.  ``infer_output_shape`` uses this to discover the
    valid range of each index variable.
    """

    # Sizes of the dimensions this tensor has been used to index into.
    # Created lazily by add_constraint() the first time a constraint appears.
    _constraints: Set[int]

    @staticmethod
    def add_constraint(tensor, size):
        """Record ``size`` as a constraint on ``tensor``.

        No-op when ``tensor`` is not a ConstraintTrackingTensor, so callers
        may pass arbitrary index objects (slices, ints, plain tensors).
        """
        if isinstance(tensor, ConstraintTrackingTensor):
            if hasattr(tensor, "_constraints"):
                tensor._constraints.add(size)
            else:
                # First constraint observed: create the set lazily.
                tensor._constraints = {size}

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Normalize kwargs once, up front, so every super() call below
        # receives a real dict (the original only normalized on the
        # non-__getitem__ path and forwarded None in the indexing branch).
        if kwargs is None:
            kwargs = {}
        args_l = list(args)
        if func.__name__ == "__getitem__":
            # args_l[0] is the tensor being indexed, args_l[1] the index.
            if isinstance(args_l[1], ConstraintTrackingTensor):
                # Single-index form: the index ranges over dimension 0.
                ConstraintTrackingTensor.add_constraint(args_l[1], args_l[0].shape[0])
            elif isinstance(args_l[1], tuple) and any(
                isinstance(i, ConstraintTrackingTensor) for i in args_l[1]
            ):
                # Tuple form: pair each index with the size of the dimension
                # it selects from (zip stops at the shorter of the two).
                for size, index in zip(args_l[0].shape, args_l[1]):
                    ConstraintTrackingTensor.add_constraint(index, size)
            if isinstance(args_l[0], ConstraintTrackingTensor):
                # Demote the indexed tensor to a plain Tensor so tracking does
                # not propagate through it.  NOTE(review): torch.tensor() on a
                # Tensor copies and emits a UserWarning; kept as in the
                # original since .clone().detach() could change autograd
                # behavior the author may rely on.
                args_l[0] = torch.tensor(args_l[0])
            return torch.tensor(
                super().__torch_function__(func, types, tuple(args_l), kwargs)
            )
        return super().__torch_function__(func, types, tuple(args_l), kwargs)
def infer_output_shape(f):
    """Infer the output shape implied by ``f``'s index arguments.

    Calls ``f`` once with one constraint-tracking dummy index per parameter;
    each dummy records the size of the dimension it was used to index, and
    those sizes (in parameter order) form the inferred shape.  Asserts that
    ``f`` returns a scalar and that every parameter picked up exactly one
    size constraint.
    """
    param_count = len(inspect.signature(f).parameters)
    probes = [ConstraintTrackingTensor(torch.tensor(0)) for _ in range(param_count)]
    result = f(*probes)
    assert result.ndim == 0
    recorded = [getattr(probe, "_constraints", set()) for probe in probes]
    assert all(len(sizes) == 1 for sizes in recorded)
    return tuple(next(iter(sizes)) for sizes in recorded)
def ein_arr(f):
    """Evaluate ``f`` over the full index grid implied by its constraints.

    Infers the output shape from ``f``, builds one broadcast index grid per
    output axis, vmaps ``f`` once per axis, and applies it to the grids.
    """
    shape = infer_output_shape(f)
    ndim = len(shape)

    grids = []
    for axis, size in enumerate(shape):
        # arange over this axis, viewed as (1, ..., size, ..., 1) and then
        # broadcast to the full output shape.
        view_shape = [1] * ndim
        view_shape[axis] = size
        grids.append(torch.arange(size).view(view_shape).broadcast_to(*shape))

    # Lift f over every output axis ("tensor up" the function).
    lifted = f
    for _ in range(ndim):
        lifted = torch.vmap(lifted)
    return lifted(*grids)
# Demo: a scalar-valued function of two indices, evaluated over all (i, j).
x = torch.randn(5, 10)
y = torch.randn(5, 10)
def f(i, j):
    # Dot product of row i of x with row j of y; scalar for scalar indices.
    return torch.dot(x[i], y[j])
# Shape inference works: each index picks up one size constraint.
print(infer_output_shape(f))
# (5, 5)
# Materialization does not (this is the gist's point): indexing with a
# vmapped tensor index makes PyTorch attempt .item() on a batched tensor.
print(ein_arr(f))
# RuntimeError: vmap: It looks like you're calling .item() on a Tensor
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.