Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Last active March 19, 2023 04:42
Show Gist options
  • Save AmosLewis/b53dfecbe4918618031cac01c8a88fb9 to your computer and use it in GitHub Desktop.
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
class Test(torch.nn.Module):
    """Shift ``decoder_input_ids`` one position to the right along the last dim.

    Small repro module used to inspect the graph ``make_fx`` records for
    in-place slice/select assignments into a freshly allocated tensor.
    """

    def forward(self, input_ids, decoder_input_ids):
        """Return ``decoder_input_ids`` shifted right by one with a 0 in front.

        Args:
            input_ids: Accepted but never read. It appears as an unused
                placeholder (``#users=0``) in the traced graph.
            decoder_input_ids: Tensor whose last dimension is shifted.

        Returns:
            A new tensor with the same shape (and, via ``new_zeros``, the same
            dtype/device) as ``decoder_input_ids``, where
            ``out[..., 0] == 0`` and ``out[..., 1:] == decoder_input_ids[..., :-1]``.
            ``decoder_input_ids`` itself is not modified.
        """
        # e.g. tensor([[6536, 504, 24, 1]]) -> tensor([[0, 0, 0, 0]])
        shifted_input_ids = decoder_input_ids.new_zeros(decoder_input_ids.shape)
        # Copy every token except the last one slot to the right:
        # tensor([[0, 0, 0]]) = tensor([[6536, 504, 24]])
        # .clone() is not needed for correctness (the destination is a fresh
        # buffer), but removing it would drop the aten.clone node from the
        # traced graph this script prints.
        shifted_input_ids[..., 1:] = decoder_input_ids[..., :-1].clone()
        # Column 0 is already zero from new_zeros. The explicit write is
        # presumably kept to reproduce the select/fill_ pattern in the trace.
        # -> tensor([[ 0, 6536, 504, 24]])
        shifted_input_ids[..., 0] = 0
        return shifted_input_ids
# Sample inputs. input_ids (0..14 in one row) is passed to the model but not
# read by Test.forward; only decoder_input_ids determines the result.
model = Test()
input_ids = torch.tensor([list(range(15))])
decoder_input_ids = torch.tensor([[6536, 504, 24, 1]])
test_inputs = (input_ids, decoder_input_ids)

# Eager execution, printed for comparison with the traced graph below.
outputs = model(*test_inputs)
print("model(test_input): ")
print(outputs)

# Trace the module into an FX graph. The op list is intentionally empty, so
# the decomposition table is empty and no aten op is decomposed; ops listed
# here would be replaced by their registered decompositions during tracing.
decomposition_table = get_decompositions([])
tracer = make_fx(model, decomposition_table=decomposition_table)
fx_g = tracer(*test_inputs)
print("fx_g.graph: ")
print(fx_g.graph)
# Output of print(fx_g.graph) for the inputs above:
# graph():
# %arg0_1 : [#users=0] = placeholder[target=arg0_1]
# %arg1_1 : [#users=4] = placeholder[target=arg1_1]
# %new_zeros : [#users=5] = call_function[target=torch.ops.aten.new_zeros.default](args = (%arg1_1, [1, 4]), kwargs = {dtype: torch.int64, layout: torch.strided, device: cpu, pin_memory: False})
#
# %slice_1 : [#users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%arg1_1, 1, 0, -1), kwargs = {})
# %clone : [#users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_1,), kwargs = {})
# %slice_2 : [#users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%new_zeros, 1, 1, 9223372036854775807), kwargs = {})
# %copy_ : [#users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%slice_2, %clone), kwargs = {})
# %_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0]
# %lift_fresh_copy : [#users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant0,), kwargs = {})
# %select : [#users=1] = call_function[target=torch.ops.aten.select.int](args = (%new_zeros, 1, 0), kwargs = {})
# %fill_ : [#users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%select, %lift_fresh_copy), kwargs = {})
# %slice_3 : [#users=0] = call_function[target=torch.ops.aten.slice.Tensor](args = (%arg1_1, 1, 0, -1), kwargs = {})
# %slice_4 : [#users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%arg1_1, 1, 0, -1), kwargs = {})
# %clone_1 : [#users=0] = call_function[target=torch.ops.aten.clone.default](args = (%slice_4,), kwargs = {})
# %slice_5 : [#users=0] = call_function[target=torch.ops.aten.slice.Tensor](args = (%new_zeros, 1, 1, 9223372036854775807), kwargs = {})
# %select_1 : [#users=0] = call_function[target=torch.ops.aten.select.int](args = (%new_zeros, 1, 0), kwargs = {})
# return new_zeros
@AmosLewis
Copy link
Author

AmosLewis commented Mar 18, 2023

https://pytorch.org/cppdocs/notes/tensor_indexing.html

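Python ↔ C++ equivalents (the C++ side assumes `using namespace torch::indexing;`):

`1:`  ↔ `Slice(1, None)`

`:3`  ↔ `Slice(None, 3)`

`...` ↔ `"..."` (or `Ellipsis`)

For example, `decoder_input_ids[..., :-1]` in Python becomes `decoder_input_ids.index({"...", Slice(None, -1)})` in C++.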

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment