Skip to content

Instantly share code, notes, and snippets.

Last active March 28, 2024 20:35
Show Gist options
  • Save HDCharles/888bc5973198ca447046b974439dca03 to your computer and use it in GitHub Desktop.
Save HDCharles/888bc5973198ca447046b974439dca03 to your computer and use it in GitHub Desktop.
repro for subclass issue
import torch
import torch.nn as nn
from torch.utils._pytree import tree_flatten, tree_unflatten
class MultiTensor(torch.Tensor):
def __new__(cls, input, **kwargs):
    """Allocate the wrapper tensor object for a tensor or a group of tensors.

    ``input`` is either a single tensor or a list/tuple of tensors; for a
    sequence only its first element is consulted for metadata. The wrapper's
    dtype falls back to that tensor's dtype and its shape to that tensor's
    shape unless ``dtype`` / ``shape`` are passed as keyword arguments. All
    remaining keyword arguments are forwarded unchanged to
    ``torch.Tensor._make_wrapper_subclass``.
    """
    # A list/tuple is represented by its first member.
    template = input[0] if isinstance(input, (list, tuple)) else input
    # Keep a caller-supplied dtype; otherwise inherit the template's.
    kwargs.setdefault("dtype", template.dtype)
    # `shape` is passed positionally below, so it must leave kwargs first.
    wrapper_shape = kwargs.pop("shape", template.shape)
    return torch.Tensor._make_wrapper_subclass(cls, wrapper_shape, **kwargs)
def __init__(self, input, **kwargs):
    """Create the per-instance bookkeeping attributes.

    Sets ``values`` to a fresh empty list and turns the ``debug`` flag on.
    Neither ``input`` nor ``kwargs`` is read by the code shown here.
    """
    # NOTE(review): nothing visible here stores `input` in `values`, yet the
    # rest of the class reads `values` as if it were populated — presumably
    # a population step belongs here; confirm against the original script.
    self.values, self.debug = [], True
def __repr__(self):
return (
def add_tensors(self, input):
if isinstance(input, (tuple, list)):
for inp in input:
assert isinstance(input, torch.Tensor), f"MultiTensor can only use add_input for Tensors or lists of tensors but got {type(input)}"
return self
def count(self):
    """Return the number of entries currently held in ``self.values``."""
    held = self.values
    return len(held)
def __torch_function__(cls, func, types, args=(), kwargs=None, skip_gptq=False):
with torch._C.DisableTorchFunctionSubclass():
is_set_item = str(func)=="<slot wrapper '__setitem__' of 'torch._C.TensorBase' objects>"
if is_set_item:
# breakpoint()
def flat_to_grouped(flat):
# size of biggest MultiTensor
multi_tensor_size = max(
[x.count() if isinstance(x, MultiTensor) else 1 for x in flat]
# convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2] [A,b3,c3]]
grouped = list(
*[x.values if isinstance(x, MultiTensor) else [x] * multi_tensor_size for x in flat]
return grouped
# convert [[A,b1,c1], [A,b2,c2], [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
# where A is a non-Tensor value and the b's and c's are Tensors
def grouped_to_flat(grouped):
# convert [[A,b1,c1], [A,b2,c2] [A,b3,c3]] => [(A,A,A), (b1,b2,b3), (c1,c2,c3)]
flat_tups = list(zip(*grouped))
# convert [(A,A,A), (b1,b2,b3), (c1,c2,c3)] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
flattened = [
cls(tup).cpu() if isinstance(tup[0], torch.Tensor) else tup[0] for tup in flat_tups
# need to check that getting rid of all but one from each nonTensor tuple is OK
min([True]+[ # handle situation where tuples have size 0
tup[0]==x for x in tup # check all elements match
]) for tup in flat_tups if not isinstance(tup[0], torch.Tensor) # look at tuples of nonTensors
return flattened, non_tensors_equal
kwargs = {} if kwargs is None else kwargs
# combine args and kwargs and remove lists and tuples
flat_args, spec = tree_flatten((args, kwargs))
# convert [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)] => [[A,b1,c1], [A,b2,c2] [A,b3,c3]]
grouped_args = flat_to_grouped(flat_args)
# run function for each of the multitensors and return a multitensor
outputs = []
with torch._C.DisableTorchFunctionSubclass():
for inp in grouped_args:
# inp = tensors_to_cuda(inp)
cur_args, cur_kwargs = tree_unflatten(inp, spec)
out = func(*cur_args, **cur_kwargs)
outputs.append(out.cpu() if isinstance(out, torch.Tensor) else out)
except Exception as e:
grouped_outputs = [tree_flatten(x)[0] for x in outputs]
out_spec = tree_flatten(outputs[0])[1]
# convert [[A,b1,c1], [A,b2,c2] [A,b3,c3]] => [A, MultiTensor(b1,b2,b3), MultiTensor(c1,c2,c3)]
flat_outputs, non_tensors_equal = grouped_to_flat(grouped_outputs)
assert non_tensors_equal, (
f"ERR: found a function in model: {func} which "
+"caused an error in GPTQMultiInput, the function dispatch only works for functions"
+" with Tensor outputs or that have the same non-Tensor output value for all across all inputs"
return tree_unflatten(flat_outputs, out_spec)
except Exception as e:
def __torch_dispatch__(cls, func, types, args, kwargs):
def __tensor_flatten__(self):
    """Name the attribute that carries this wrapper's data.

    Returns a 2-tuple: a list containing the single attribute name
    ``"values"``, and ``None`` as the accompanying context object.
    """
    # NOTE(review): presumably consumed by PyTorch's wrapper-subclass
    # flatten/unflatten machinery — confirm the expected contract there.
    inner_attr_names = ["values"]
    extra_context = None
    return inner_attr_names, extra_context
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
class mod(torch.nn.Module):
    """Tiny module used to reproduce the type change on indexed assignment.

    ``forward`` runs a 10x10 linear layer on ``x`` and writes the result into
    the columns of a 10x20 zero tensor selected by ``indices``, printing the
    Python type of each intermediate along the way.
    """

    def __init__(self):
        # nn.Module's own initializer must run before any submodule is
        # assigned; without it `self.lin = ...` raises AttributeError.
        super().__init__()
        self.lin = torch.nn.Linear(10, 10)
        # Plain tensor attribute (not a Parameter or registered buffer),
        # all zeros, used as the destination of the indexed write.
        self.other = torch.randn(10, 20) * 0

    def forward(self, x, indices):
        """Write ``self.lin(x)`` into ``self.other[:, indices]`` and return it.

        Args:
            x: input passed to the linear layer.
            indices: column selector used on the left-hand side of the
                indexed assignment.

        Returns:
            The same object as ``self.other`` after the in-place write.
        """
        print(f"initial model input types x: {type(x)}, indices: {type(indices)} (fine)")
        y = self.lin(x)
        print(f"result after linear y: {type(y)}, should be a multitensor (and is)")
        # `z` aliases `self.other`, so the assignment below mutates the
        # module's stored tensor in place; this is the operation under test.
        z = self.other
        z[:, indices] = y
        print(f"after assigning y to index of z: {type(z)}, should be a multitensor (is torch.Tensor)")
        return z
model = mod()
multi = [
MultiTensor([torch.randn(10,10), torch.randn(10,10)]),
MultiTensor([torch.tensor([0,1,2,3,4,5,6,7,8,9]), torch.tensor([0,1,2,3,4,5,6,7,8,9])])
with torch.no_grad():
Copy link

❯ python
initial model input types x: <class '__main__.MultiTensor'>, indices: <class '__main__.MultiTensor'> (fine)
result after linear y: <class '__main__.MultiTensor'>, should be a multitensor (and is)
after assigning y to index of z: <class 'torch.Tensor'>, should be a multitensor (is torch.Tensor)

Copy link

High-level question: is there a way to handle "<slot wrapper '__setitem__' of 'torch._C.TensorBase' objects>" in `__torch_function__` or `__torch_dispatch__` so that it works as you would expect (propagating the MultiTensor through the network)?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment