@pashu123
Created June 20, 2024 11:25
module attributes {torch.debug_module_name = "SumModule"} {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @forward(%arg0: tensor<1048576xf32>) -> tensor<f32> {
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<f32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<f32>) -> tensor<f32>
    %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%arg0 : tensor<1048576xf32>) outs(%1 : tensor<f32>) {
    ^bb0(%in: f32, %out: f32):
      %3 = arith.addf %in, %out : f32
      linalg.yield %3 : f32
    } -> tensor<f32>
    return %2 : tensor<f32>
  }
}
import torch
import torch.nn as nn
from torch_mlir import torchscript


class SumModule(nn.Module):
    """Module that lowers to the linalg-on-tensors IR shown above."""

    def __init__(self):
        super(SumModule, self).__init__()

    def forward(self, x):
        # Reduce the whole input tensor to a single scalar sum.
        return torch.sum(x)


# Example usage:
if __name__ == "__main__":
    sum_module = SumModule()
    input_tensor = torch.randn(1024 * 1024)
    output = sum_module(input_tensor)
    print("Sum of elements:", output.item())

    # Lower the module to the linalg-on-tensors dialect via torch-mlir.
    module = torchscript.compile(
        sum_module, input_tensor, output_type="linalg-on-tensors"
    )
    module.dump()
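module.dump() prints the IR to the console. If you instead want the lowered IR as text (for example, to save it to a .mlir file), a minimal sketch is below, assuming the object returned by torchscript.compile supports str() conversion as MLIR Python Module objects do; the file name is illustrative.

# Sketch: capture the lowered IR as a string instead of dumping it.
# Assumes str(module) works on the compile result (standard for MLIR
# Python Module objects); "sum_module.linalg.mlir" is an arbitrary name.
mlir_text = str(module)
with open("sum_module.linalg.mlir", "w") as f:
    f.write(mlir_text)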