Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created December 13, 2024 06:21
Show Gist options
  • Save AmosLewis/7cbf94ef1f18daa35631c3fc08b314db to your computer and use it in GitHub Desktop.
Save AmosLewis/7cbf94ef1f18daa35631c3fc08b314db to your computer and use it in GitHub Desktop.
class NonzeroDecomposeModule(torch.nn.Module):
    """Decompose ``torch.nonzero`` into flatten / cumsum / scatter_add / div-mod ops.

    Steps:
      1. Flatten the input.
      2. Compact the flat positions of its nonzero elements to the front of
         an int64 buffer, using ``cumsum`` + ``scatter_add``.
      3. Turn those flat positions back into per-dimension coordinates with
         row-major strides.

    The returned tensor has shape ``(num_nonzero, t.dim())``, the same layout
    ``torch.nonzero(t)`` produces.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.bool, True),
        ]
    )
    def forward(self, t):
        # Flatten *all* dims. ``flatten(0, 0)`` is a no-op for rank > 1, which
        # would make the scan below run along dim 0 only.
        t_flat = t.flatten()
        nonzero_mask = (t_flat != 0).long()

        # Output slot of each nonzero element. Zeros before the first nonzero
        # get -1, which is clamped to 0; they scatter a 0 and so are harmless.
        destination_indices = torch.cumsum(nonzero_mask, 0) - 1
        destination_indices_clamp = torch.clamp(destination_indices, min=0)

        # Flat position of every element, zeroed where the input is zero, so
        # each output slot receives exactly one nonzero position plus zeros.
        iota = torch.arange(t_flat.size(0)) * nonzero_mask
        scatter_self = torch.zeros_like(t_flat, dtype=torch.int64)
        # Out-of-place scatter_add (rather than in-place scatter_ with
        # reduce='add') keeps the graph free of mutation.
        compacted = torch.scatter_add(
            scatter_self, dim=0, index=destination_indices_clamp, src=iota
        )
        result_flat = compacted[: torch.sum(nonzero_mask)]

        # Convert flat positions into per-dimension coordinates.
        original_shape = t.shape
        input_shape_tensor = torch.tensor(original_shape)
        # Suffix products of the shape: [d0*d1*...*dn-1, d1*...*dn-1, ..., dn-1].
        strides = torch.cumprod(torch.flip(input_shape_tensor, [0]), 0).flip(0)
        one = torch.tensor([1])
        if t.dim() > 1:
            # Row-major strides are the suffix products shifted left by one
            # with 1 appended: [d1*...*dn-1, ..., dn-1, 1]. (The former
            # ``strides[1:-1]`` dropped ``dn-1`` and broke every rank > 1 input.)
            strides = torch.cat([strides[1:], one])
        else:
            # Rank 1: the only stride is 1. Branching here avoids concatenating
            # an empty slice.
            strides = one
        a = result_flat.unsqueeze(1)  # (num_nonzero, 1)
        b = strides.unsqueeze(0)  # (1, rank)
        c = a // b  # (num_nonzero, rank)
        multi_indices = c % input_shape_tensor
        return multi_indices
@register_test_case(module_factory=lambda: NonzeroDecomposeModule())
def NonzeroDecomposeModule_basic(module, tu: TestUtils):
    """Run the nonzero decomposition on a 6-element bool vector with two set entries."""
    mask = torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool)
    module.forward(mask)
@AmosLewis
Copy link
Author

AmosLewis commented Dec 13, 2024

# NOTE: `ScatterAddDynamicModule_basic` is registered once, directly after the
# `ScatterAddDynamicModule` class definition further down. A second
# `@register_test_case` on a function with the same name would be rejected as a
# duplicate test name when this module is imported.

class NonzeroFlattenDynamicModule(torch.nn.Module):
    """Return a dynamically sized 1-D int32 input flattened to 1-D."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.int, True),
        ]
    )
    def forward(self, x):
        flat = torch.flatten(x)
        return flat


@register_test_case(module_factory=lambda: NonzeroFlattenDynamicModule())
def NonzeroFlattenDynamicModule_basic(module, tu: TestUtils):
    """Flatten a 6-element int32 vector."""
    values = torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.int)
    module.forward(values)


class NonzeroCatModule(torch.nn.Module):
    """Concatenate an always-empty slice of the input with the constant ``[1]``."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.int, True),
        ]
    )
    def forward(self, a):
        # a[1:1] is empty for any length. The cat mixes the int32 input slice
        # with an int64 constant, so the result's dtype comes from promotion.
        empty_slice = a[1:1]
        tail = torch.tensor([1])
        return torch.cat([empty_slice, tail])


@register_test_case(module_factory=lambda: NonzeroCatModule())
def NonzeroCatModule_basic(module, tu: TestUtils):
    """Feed a single-element int32 vector."""
    single = torch.tensor([6], dtype=torch.int)
    module.forward(single)


class NonzeroCumsumModule(torch.nn.Module):
    """Running sum along dim 0 of a dynamically sized 1-D int64 input."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.int64, True),
        ]
    )
    def forward(self, x):
        running_total = x.cumsum(0)
        return running_total


@register_test_case(module_factory=lambda: NonzeroCumsumModule())
def NonzeroCumsumModule_basic(module, tu: TestUtils):
    """Cumsum over a 6-element int64 vector of zeros and ones."""
    values = torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.int64)
    module.forward(values)


class NonzeroCumsumBoolModule(torch.nn.Module):
    """Cast a 1-D bool input to int64, then take its running sum along dim 0."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.bool, True),
        ]
    )
    def forward(self, x):
        as_int64 = x.long()
        return as_int64.cumsum(0)


@register_test_case(module_factory=lambda: NonzeroCumsumBoolModule())
def NonzeroCumsumBoolModule_basic(module, tu: TestUtils):
    """Cumsum over a 6-element bool vector with two set entries."""
    mask = torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool)
    module.forward(mask)


class NonzeroLongModule(torch.nn.Module):
    """Cast a dynamically sized 1-D bool input to int64."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.bool, True),
        ]
    )
    def forward(self, x):
        widened = x.long()
        return widened


@register_test_case(module_factory=lambda: NonzeroLongModule())
def NonzeroLongModule_basic(module, tu: TestUtils):
    """Cast a 6-element bool vector with two set entries."""
    mask = torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool)
    module.forward(mask)


class ScatterAddDynamicModule(torch.nn.Module):
    """Call ``aten.scatter_add`` along dim 0 on three dynamically sized 1-D int64 tensors."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.int64, True),
            ([-1], torch.int64, True),
            ([-1], torch.int64, True),
        ]
    )
    def forward(self, input, index, src):
        scatter_dim = 0
        return torch.ops.aten.scatter_add(input, scatter_dim, index, src)


@register_test_case(module_factory=lambda: ScatterAddDynamicModule())
def ScatterAddDynamicModule_basic(module, tu: TestUtils):
    """Scatter ``src`` into an all-zero base where every index targets slot 0."""
    base = torch.tensor([0, 0, 0, 0, 0, 0])
    index = torch.tensor([0, 0, 0, 0, 0, 0])
    src = torch.tensor([0, 0, 0, 3, 0, 0])
    module.forward(base, index, src)

# NOTE: `ScatterAddDynamicModule` is defined once, above. Do not add a second
# copy here: a later class statement with the same name shadows the first, and
# the registered test factory (a lambda, resolved at call time) silently picks
# up whichever copy comes last, so edits to the other copy would be ignored.

@AmosLewis
Copy link
Author

# python -m e2e_testing.main --config=onnx -v --filter AtenNonzero1DModule_one_nonzero # Failed
# python -m e2e_testing.main --config=linalg -v --filter NonzeroDecomposeModule_basic # Passed: 1

# python -m e2e_testing.main --config=onnx -v --filter NonzeroDecomposeModule_basic # Failed: 1

# python -m e2e_testing.main --config=linalg -v --filter NonzeroFlattenDynamicModule # Passed: 1

# python -m e2e_testing.main --config=onnx -v --filter ScatterAddDynamicModule_basic #  Passed: 1

# python -m e2e_testing.main --config=onnx -v --filter NonzeroCatModule # Passed: 1
# python -m e2e_testing.main --config=linalg -v --filter NonzeroCatModule # Failed: 1
# error: tensor with unknown dtype: "torch.aten.cat"(%31, %4) : (!torch.list<vtensor>, !torch.int) -> !torch.vtensor<[1],unk>

# python -m e2e_testing.main --config=linalg -v --filter NonzeroCatModule # Failed: 1

# python -m e2e_testing.main --config=linalg -v --filter NonzeroCumsumModule
# python -m e2e_testing.main --config=onnx -v --filter NonzeroCumsumModule # pass
# python -m e2e_testing.main --config=onnx -v --filter NonzeroCumsumBoolModule # passes in torch-mlir, fails in IREE

# python -m e2e_testing.main --config=onnx -v --filter NonzeroLongModule

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment