Last active
November 27, 2023 22:20
-
-
Save AmosLewis/c0d551bfdd11b24eb10a5f8de10809d1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import shark_turbine.aot as aot | |
import torch | |
import torch.nn as nn | |
class ExMod(nn.Module):
    """Minimal module that wraps a single ``nn.BatchNorm2d`` layer.

    Args:
        num_features: Channel count passed to ``nn.BatchNorm2d``. Defaults
            to 100, the value that was previously hard-coded.
    """

    def __init__(self, num_features: int = 100):
        super().__init__()
        # The only submodule; kept under the attribute name ``m``.
        self.m = nn.BatchNorm2d(num_features)

    def forward(self, x):
        """Run ``x`` through the batch-norm layer and return the result."""
        return self.m(x)
# Module under test, switched to evaluation mode (nn.Module.eval() returns
# the module itself, so the call can be chained).
mod = ExMod().eval()

# Zero-filled sample input of shape (20, 100, 35, 45).
x = torch.zeros(20, 100, 35, 45)

# Export the module with shark_turbine.aot using the sample input, then
# write the exported output to 'bnex.mlir'.
export_output = aot.export(mod, x)
export_output.save_mlir('bnex.mlir')
# BUG1:
# iree version: https://github.com/AmosLewis/SHARK-Turbine/commits/main
# commit ID: 4117974833a7c5b35c81f30a8c9d523f483b514f
# TypeError: print(): incompatible function arguments. The following argument types are supported:
# 1. (self: iree.compiler._mlir_libs._mlir.ir._OperationBase, state: mlir::python::PyAsmState, file: object = None, binary: bool = False) -> None
# 2. (self: iree.compiler._mlir_libs._mlir.ir._OperationBase, large_elements_limit: Optional[int] = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, use_local_scope: bool = False, assume_verified: bool = False, file: object = None, binary: bool = False) -> None
# Invoked with: <iree.compiler._mlir_libs._mlir.ir.Operation object at 0x7f9fea26e030>, <_io.BufferedWriter name='bnex.mlir'>; kwargs: binary=True
# Compile the exported output. save_to=None is passed explicitly; presumably
# the compiled artifact is then returned rather than written to a file
# (TODO confirm against the shark_turbine.aot API).
binary = export_output.compile(save_to=None)
# BUG2:
# (iree_venv) ➜ src cd /nodclouddata/chi/src ; /usr/bin/env /nodclouddata/chi/src/SHARK-Turbine/.venv/bin/python3.11 /home/chi/.vscode-server/extensions/ms-python.python-2023.20.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher 38747 -- /nodclouddata/chi/src/SHARK-Turbine/tests/aot/batchnorm2d_test.py
# loc("<eval_with_key>.0 from /nodclouddata/chi/src/SHARK-Turbine/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py:477 in wrapped":5:0):
# error: 'tensor.cast' op operand type 'tensor<?xui8>' and result type 'tensor<0xi8>' are cast incompatible
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
bnex.mlir