Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Last active November 27, 2023 22:20
Show Gist options
  • Save AmosLewis/c0d551bfdd11b24eb10a5f8de10809d1 to your computer and use it in GitHub Desktop.
Save AmosLewis/c0d551bfdd11b24eb10a5f8de10809d1 to your computer and use it in GitHub Desktop.
import shark_turbine.aot as aot
import torch
import torch.nn as nn
class ExMod(nn.Module):
    """Minimal module wrapping a single ``nn.BatchNorm2d`` layer.

    Serves as the repro module exported through ``shark_turbine.aot`` in the
    script below it.

    Args:
        num_features: Number of channels ``C`` in the ``(N, C, H, W)`` input
            that the batch-norm layer normalizes. Defaults to 100, the value
            the repro originally hard-coded, so ``ExMod()`` is unchanged.
    """

    def __init__(self, num_features: int = 100):
        super().__init__()
        # Keep the attribute name ``m``: exported parameter globals are named
        # after it (e.g. ``_params.m.weight`` / ``_params.m.bias``).
        self.m = nn.BatchNorm2d(num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the batch-norm layer to ``x`` and return the result."""
        return self.m(x)
# Repro driver: export ExMod with SHARK-Turbine's AOT API, dump the MLIR to
# bnex.mlir, then compile the exported module.
x = torch.zeros((20, 100, 35, 45))  # (N, C, H, W); C=100 matches ExMod's BatchNorm2d(100)
mod = ExMod().eval()  # nn.Module.eval() switches to inference mode and returns the module itself
export_output = aot.export(mod, x)
export_output.save_mlir('bnex.mlir')
# BUG1 (raised from the save_mlir call above, which writes bnex.mlir in binary mode):
# SHARK-Turbine fork: https://github.com/AmosLewis/SHARK-Turbine/commits/main
# commit ID: 4117974833a7c5b35c81f30a8c9d523f483b514f
# The open file object reaches the MLIR Operation.print() binding as its first
# positional argument, which matches neither overload's first parameter
# (`state` / `large_elements_limit`); it would need to be passed as `file=`.
# TypeError: print(): incompatible function arguments. The following argument types are supported:
# 1. (self: iree.compiler._mlir_libs._mlir.ir._OperationBase, state: mlir::python::PyAsmState, file: object = None, binary: bool = False) -> None
# 2. (self: iree.compiler._mlir_libs._mlir.ir._OperationBase, large_elements_limit: Optional[int] = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, use_local_scope: bool = False, assume_verified: bool = False, file: object = None, binary: bool = False) -> None
# Invoked with: <iree.compiler._mlir_libs._mlir.ir.Operation object at 0x7f9fea26e030>, <_io.BufferedWriter name='bnex.mlir'>; kwargs: binary=True
binary = export_output.compile(save_to=None)
# BUG2 (seen when running this script as tests/aot/batchnorm2d_test.py;
# presumably raised by the compile() call above):
# (iree_venv) ➜ src cd /nodclouddata/chi/src ; /usr/bin/env /nodclouddata/chi/src/SHARK-Turbine/.venv/bin/python3.11 /home/chi/.vscode-server/extensions/ms-python.python-2023.20.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher 38747 -- /nodclouddata/chi/src/SHARK-Turbine/tests/aot/batchnorm2d_test.py
# loc("<eval_with_key>.0 from /nodclouddata/chi/src/SHARK-Turbine/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped":5:0):
# error: 'tensor.cast' op operand type 'tensor<?xui8>' and result type 'tensor<0xi8>' are cast incompatible
# NOTE(review): the only ui8 value in the posted bnex.mlir is the unused
# torch.aten.empty.memory_format result in @forward -- likely culprit, TODO confirm.
@AmosLewis
Copy link
Author

bnex.mlir

// Exported MLIR for ExMod (posted as bnex.mlir, the file name targeted by the
// script's save_mlir call). The module is in eval mode, so batch norm lowers to
//   y = (x - shift) * (1 / sqrt(scale_src + 1e-5)) * m.weight + m.bias
// where shift / scale_src are inlined literals (zeros / ones, presumably the
// default running_mean / running_var buffers) and m.weight / m.bias live in
// util.global parameters.
module @ExMod {
  // Parameters of ExMod.m exported as globals: weight = all ones,
  // bias = all zeros, one value per channel (100).
  util.global private @_params.m.weight {noinline} = dense<1.000000e+00> : tensor<100xf32>
  util.global private @_params.m.bias {noinline} = dense<0.000000e+00> : tensor<100xf32>
  // Public entry point: wraps the builtin-tensor argument as a torch value
  // tensor, calls @forward, and unwraps the result. The torch.args_schema /
  // torch.return_schema attributes appear to be serialized pytree specs of the
  // Python call signature and return value.
  func.func @main(%arg0: tensor<20x100x35x45xf32>) -> tensor<20x100x35x45xf32> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} {
    %0 = torch_c.from_builtin_tensor %arg0 : tensor<20x100x35x45xf32> -> !torch.vtensor<[20,100,35,45],f32>
    %1 = call @forward(%0) : (!torch.vtensor<[20,100,35,45],f32>) -> !torch.vtensor<[20,100,35,45],f32>
    %2 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[20,100,35,45],f32> -> tensor<20x100x35x45xf32>
    return %2 : tensor<20x100x35x45xf32>
  }
  // Batch-norm body on torch value tensors.
  func.func private @forward(%arg0: !torch.vtensor<[20,100,35,45],f32>) -> !torch.vtensor<[20,100,35,45],f32> {
    // Dead code: an empty [0]-shaped tensor with dtype code 0 (ui8 per the
    // result type); %1 has no uses in this function.
    // NOTE(review): this is the only ui8 value in the module, so it is the
    // likely source of BUG2's `tensor<?xui8>` -> `tensor<0xi8>` tensor.cast
    // error -- TODO confirm.
    %int0 = torch.constant.int 0
    %0 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
    %int0_0 = torch.constant.int 0
    %int0_1 = torch.constant.int 0
    %cpu = torch.constant.device "cpu"
    %none = torch.constant.none
    %none_2 = torch.constant.none
    %1 = torch.aten.empty.memory_format %0, %int0_0, %int0_1, %cpu, %none, %none_2 : !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.none, !torch.none -> !torch.vtensor<[0],ui8>
    // Per-channel shift: literal zeros (presumably running_mean), converted
    // to dtype code 6; the types show this is an f32 -> f32 identity cast.
    %2 = torch.vtensor.literal(dense<0.000000e+00> : tensor<100xf32>) : !torch.vtensor<[100],f32>
    %int6 = torch.constant.int 6
    %3 = torch.prims.convert_element_type %2, %int6 : !torch.vtensor<[100],f32>, !torch.int -> !torch.vtensor<[100],f32>
    // Per-channel scale: 1 / sqrt(ones + 1e-5) * 1, built from literal ones
    // (presumably running_var) -- i.e. the inverse standard deviation.
    %4 = torch.vtensor.literal(dense<1.000000e+00> : tensor<100xf32>) : !torch.vtensor<[100],f32>
    %int6_3 = torch.constant.int 6
    %5 = torch.prims.convert_element_type %4, %int6_3 : !torch.vtensor<[100],f32>, !torch.int -> !torch.vtensor<[100],f32>
    %float1.000000e-05 = torch.constant.float 1.000000e-05
    %int1 = torch.constant.int 1
    %6 = torch.aten.add.Scalar %5, %float1.000000e-05, %int1 : !torch.vtensor<[100],f32>, !torch.float, !torch.int -> !torch.vtensor<[100],f32>
    %7 = torch.aten.sqrt %6 : !torch.vtensor<[100],f32> -> !torch.vtensor<[100],f32>
    %8 = torch.aten.reciprocal %7 : !torch.vtensor<[100],f32> -> !torch.vtensor<[100],f32>
    %int1_4 = torch.constant.int 1
    %9 = torch.aten.mul.Scalar %8, %int1_4 : !torch.vtensor<[100],f32>, !torch.int -> !torch.vtensor<[100],f32>
    // Dead code: two more [0]-shaped f32 tensors (%11, %13) with no uses;
    // presumably placeholder outputs of the batch-norm decomposition.
    %int0_5 = torch.constant.int 0
    %10 = torch.prim.ListConstruct %int0_5 : (!torch.int) -> !torch.list<int>
    %none_6 = torch.constant.none
    %none_7 = torch.constant.none
    %none_8 = torch.constant.none
    %false = torch.constant.bool false
    %11 = torch.aten.new_zeros %arg0, %10, %none_6, %none_7, %none_8, %false : !torch.vtensor<[20,100,35,45],f32>, !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[0],f32>
    %int0_9 = torch.constant.int 0
    %12 = torch.prim.ListConstruct %int0_9 : (!torch.int) -> !torch.list<int>
    %none_10 = torch.constant.none
    %none_11 = torch.constant.none
    %none_12 = torch.constant.none
    %false_13 = torch.constant.bool false
    %13 = torch.aten.new_zeros %arg0, %12, %none_10, %none_11, %none_12, %false_13 : !torch.vtensor<[20,100,35,45],f32>, !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[0],f32>
    // Reshape shift (%3) and scale (%9) from [100] to [100,1,1] so they
    // broadcast over dim 1 (the 100-wide dim) of the [20,100,35,45] input.
    %int-1 = torch.constant.int -1
    %14 = torch.aten.unsqueeze %3, %int-1 : !torch.vtensor<[100],f32>, !torch.int -> !torch.vtensor<[100,1],f32>
    %int-1_14 = torch.constant.int -1
    %15 = torch.aten.unsqueeze %14, %int-1_14 : !torch.vtensor<[100,1],f32>, !torch.int -> !torch.vtensor<[100,1,1],f32>
    %int-1_15 = torch.constant.int -1
    %16 = torch.aten.unsqueeze %9, %int-1_15 : !torch.vtensor<[100],f32>, !torch.int -> !torch.vtensor<[100,1],f32>
    %int-1_16 = torch.constant.int -1
    %17 = torch.aten.unsqueeze %16, %int-1_16 : !torch.vtensor<[100,1],f32>, !torch.int -> !torch.vtensor<[100,1,1],f32>
    // Normalize: (x - shift) * scale.
    %int1_17 = torch.constant.int 1
    %18 = torch.aten.sub.Tensor %arg0, %15, %int1_17 : !torch.vtensor<[20,100,35,45],f32>, !torch.vtensor<[100,1,1],f32>, !torch.int -> !torch.vtensor<[20,100,35,45],f32>
    %19 = torch.aten.mul.Tensor %18, %17 : !torch.vtensor<[20,100,35,45],f32>, !torch.vtensor<[100,1,1],f32> -> !torch.vtensor<[20,100,35,45],f32>
    // Multiply by m.weight (loaded from its global), broadcast to [100,1,1].
    %_params.m.weight = util.global.load @_params.m.weight : tensor<100xf32>
    %20 = torch_c.from_builtin_tensor %_params.m.weight : tensor<100xf32> -> !torch.vtensor<[100],f32>
    %int-1_18 = torch.constant.int -1
    %21 = torch.aten.unsqueeze %20, %int-1_18 : !torch.vtensor<[100],f32>, !torch.int -> !torch.vtensor<[100,1],f32>
    %int-1_19 = torch.constant.int -1
    %22 = torch.aten.unsqueeze %21, %int-1_19 : !torch.vtensor<[100,1],f32>, !torch.int -> !torch.vtensor<[100,1,1],f32>
    %23 = torch.aten.mul.Tensor %19, %22 : !torch.vtensor<[20,100,35,45],f32>, !torch.vtensor<[100,1,1],f32> -> !torch.vtensor<[20,100,35,45],f32>
    // Add m.bias (loaded from its global), broadcast to [100,1,1].
    %_params.m.bias = util.global.load @_params.m.bias : tensor<100xf32>
    %24 = torch_c.from_builtin_tensor %_params.m.bias : tensor<100xf32> -> !torch.vtensor<[100],f32>
    %int-1_20 = torch.constant.int -1
    %25 = torch.aten.unsqueeze %24, %int-1_20 : !torch.vtensor<[100],f32>, !torch.int -> !torch.vtensor<[100,1],f32>
    %int-1_21 = torch.constant.int -1
    %26 = torch.aten.unsqueeze %25, %int-1_21 : !torch.vtensor<[100,1],f32>, !torch.int -> !torch.vtensor<[100,1,1],f32>
    %int1_22 = torch.constant.int 1
    %27 = torch.aten.add.Tensor %23, %26, %int1_22 : !torch.vtensor<[20,100,35,45],f32>, !torch.vtensor<[100,1,1],f32>, !torch.int -> !torch.vtensor<[20,100,35,45],f32>
    return %27 : !torch.vtensor<[20,100,35,45],f32>
  }
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment