Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created December 13, 2024 06:21
Show Gist options
  • Save AmosLewis/7cbf94ef1f18daa35631c3fc08b314db to your computer and use it in GitHub Desktop.
Save AmosLewis/7cbf94ef1f18daa35631c3fc08b314db to your computer and use it in GitHub Desktop.
class NonzeroDecomposeModule(torch.nn.Module):
    """Decompose ``torch.nonzero`` into cumsum / scatter_add / index arithmetic.

    ``forward`` returns one row per non-zero element of ``t`` and one column
    per dimension of ``t``, with the non-zero elements visited in row-major
    (flattened) order.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.bool, True),
        ]
    )
    def forward(self, t):
        # --- Step 1: flat indices of the non-zero elements -------------------
        # flatten() collapses every dimension. The previous flatten(0, 0) only
        # collapsed dim 0 (a no-op), which broke the multi-dim path below.
        t_flat = t.flatten()
        nonzero_mask = (t_flat != 0).long()
        # The k-th non-zero element (0-based) is written to output slot k.
        destination_indices = torch.cumsum(nonzero_mask, 0) - 1
        # Zero elements preceding the first non-zero get index -1; clamp keeps
        # the scatter index in range. Zero elements contribute src value 0
        # (see iota below), so adding them into any slot is a no-op.
        destination_indices_clamp = torch.clamp(destination_indices, min=0)
        iota = torch.arange(t_flat.size(0)) * nonzero_mask
        scatter_self = torch.zeros_like(t_flat, dtype=torch.int64)
        # Out-of-place scatter_add is used rather than the in-place
        # scatter_(..., reduce="add") variant.
        compacted = torch.scatter_add(
            scatter_self, dim=0, index=destination_indices_clamp, src=iota
        )
        # e.g. for input [0, 0, 1, 1, 0, 0] this is [2, 3].
        result_flat = compacted[: torch.sum(nonzero_mask)]

        # --- Step 2: flat indices -> per-dimension indices --------------------
        original_shape = t.shape
        input_shape_tensor = torch.tensor(original_shape)
        # Suffix products of the shape: (a, b, c) -> (a*b*c, b*c, c).
        strides = torch.cumprod(torch.flip(input_shape_tensor, [0]), 0).flip(0)
        one = torch.tensor([1])
        if t.dim() > 1:
            # Row-major strides drop the leading total product and end in 1:
            # (a, b, c) -> (b*c, c, 1). Slicing [1:-1] here would also drop the
            # last suffix product and yield too few strides.
            strides = torch.cat([strides[1:], one])
        else:
            strides = one
        # (K, 1) // (1, D) -> (K, D); then take each column modulo its dim size.
        flat_col = result_flat.unsqueeze(1)
        stride_row = strides.unsqueeze(0)
        multi_indices = (flat_col // stride_row) % input_shape_tensor
        return multi_indices
@register_test_case(module_factory=lambda: NonzeroDecomposeModule())
def NonzeroDecomposeModule_basic(module, tu: TestUtils):
    """Invoke forward on a 1-D bool tensor holding two adjacent True values."""
    sample = torch.tensor([False, False, True, True, False, False])
    module.forward(sample)
@AmosLewis
Copy link
Author

python -m e2e_testing.main --config=onnx -v --filter AtenNonzero1DModule_one_nonzero # fail
# python -m e2e_testing.main --config=linalg -v --filter NonzeroDecomposeModule_basic # Passed: 1

# python -m e2e_testing.main --config=onnx -v --filter NonzeroDecomposeModule_basic # Failed: 1

# python -m e2e_testing.main --config=linalg -v --filter NonzeroFlattenDynamicModule # Passed: 1

# python -m e2e_testing.main --config=onnx -v --filter ScatterAddDynamicModule_basic #  Passed: 1

# python -m e2e_testing.main --config=onnx -v --filter NonzeroCatModule # Passed: 1
# python -m e2e_testing.main --config=linalg -v --filter NonzeroCatModule # Failed: 1
# tensor with unknown dtype "torch.aten.cat"(%31, %4) : (!torch.list<vtensor>, !torch.int) -> !torch.vtensor<[1],unk>

# python -m e2e_testing.main --config=linalg -v --filter NonzeroCatModule # Failed: 1

# python -m e2e_testing.main --config=linalg -v --filter NonzeroCumsumModule
# python -m e2e_testing.main --config=onnx -v --filter NonzeroCumsumModule # pass
# python -m e2e_testing.main --config=onnx -v --filter NonzeroCumsumBoolModule # passes in torch-mlir, fails in IREE

# python -m e2e_testing.main --config=onnx -v --filter NonzeroLongModule

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment