Created December 13, 2024 06:21
Save AmosLewis/7cbf94ef1f18daa35631c3fc08b314db to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class NonzeroDecomposeModule(torch.nn.Module):
    """Decomposition of ``torch.nonzero`` built from cumsum / scatter_add / div / mod.

    The input is flattened, and the flat positions of its non-zero elements are
    compacted to the front of a buffer with ``scatter_add``. Those flat positions
    are then converted back into per-dimension (row-major) indices. The result is
    intended to match ``torch.nonzero(t)``: an int64 tensor of shape
    ``(num_nonzero, t.dim())``.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.bool, True),
        ]
    )
    def forward(self, t):
        # Flatten *all* dimensions so multi-dimensional inputs are compacted over
        # their row-major flat positions (flatten(0, 0) would be a no-op).
        t_flat = t.flatten()
        nonzero_mask = (t_flat != 0).long()

        # Output slot of each element = (number of non-zeros up to and including
        # it) - 1. Elements before the first non-zero get -1, so clamp to keep the
        # scatter index in range; those elements contribute 0 (see `iota`), so the
        # clamp cannot corrupt slot 0.
        destination_indices = torch.cumsum(nonzero_mask, 0) - 1
        destination_indices_clamp = torch.clamp(destination_indices, min=0)

        # Flat position of every element, zeroed for zero elements so their
        # scatter into an occupied/clamped slot adds nothing.
        iota = torch.arange(t_flat.size(0)) * nonzero_mask

        scatter_self = torch.zeros_like(t_flat, dtype=torch.int64)
        compacted = torch.scatter_add(
            scatter_self, dim=0, index=destination_indices_clamp, src=iota
        )
        # The first num_nonzero slots now hold the non-zero flat positions in
        # ascending order.
        result_flat = compacted[: torch.sum(nonzero_mask)]

        # --- Convert flat positions to multi-dimensional indices ---
        # int64 explicitly so a 0-d input (empty shape) does not yield a float
        # tensor. NOTE(review): 0-d input is outside the annotated [-1] signature.
        input_shape_tensor = torch.tensor(t.shape, dtype=torch.int64)
        # Suffix products of the shape: [d0*...*dn-1, d1*...*dn-1, ..., dn-1].
        suffix_products = torch.cumprod(
            torch.flip(input_shape_tensor, [0]), 0
        ).flip(0)
        # Row-major strides: drop the first suffix product (keeping all the
        # others) and append 1 -> [d1*...*dn-1, ..., dn-1, 1]. For 1-D input this
        # is just [1].
        strides = torch.cat([suffix_products[1:], torch.tensor([1])])

        # (num_nonzero, 1) // (1, ndim) -> (num_nonzero, ndim); the modulo wraps
        # each column into its dimension's extent.
        quotients = result_flat.unsqueeze(1) // strides.unsqueeze(0)
        multi_indices = quotients % input_shape_tensor
        return multi_indices
@register_test_case(module_factory=lambda: NonzeroDecomposeModule())
def NonzeroDecomposeModule_basic(module, tu: TestUtils):
    """Run the nonzero decomposition on a 1-D bool mask.

    The mask has two adjacent ``True`` entries at positions 2 and 3,
    surrounded by ``False`` values on both sides.
    """
    mask = torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool)
    module.forward(mask)
Author AmosLewis commented on Dec 13, 2024:
python -m e2e_testing.main --config=onnx -v --filter AtenNonzero1DModule_one_nonzero # fail
# python -m e2e_testing.main --config=linalg -v --filter NonzeroDecomposeModule_basic # Passed: 1
# python -m e2e_testing.main --config=onnx -v --filter NonzeroDecomposeModule_basic # Failed: 1
# python -m e2e_testing.main --config=linalg -v --filter NonzeroFlattenDynamicModule # Passed: 1
# python -m e2e_testing.main --config=onnx -v --filter ScatterAddDynamicModule_basic # Passed: 1
# python -m e2e_testing.main --config=onnx -v --filter NonzeroCatModule # Passed: 1
# python -m e2e_testing.main --config=linalg -v --filter NonzeroCatModule # Failed: 1
# tensor with unknown dtype "torch.aten.cat"(%31, %4) : (!torch.list<vtensor>, !torch.int) -> !torch.vtensor<[1],unk>
# python -m e2e_testing.main --config=linalg -v --filter NonzeroCatModule # Failed: 1
# python -m e2e_testing.main --config=linalg -v --filter NonzeroCumsumModule
# python -m e2e_testing.main --config=onnx -v --filter NonzeroCumsumModule # pass
# python -m e2e_testing.main --config=onnx -v --filter NonzeroCumsumBoolModule # passes in torch-mlir, fails in IREE
# python -m e2e_testing.main --config=onnx -v --filter NonzeroLongModule
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.