Created
May 5, 2022 07:22
-
-
Save pashu123/35d80539aa5e292a70283a0682d5cc7e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Third-party imports: core torch, the functorch tracing utility, and the
# torch-mlir compiler stack (torch_mlir / torch_mlir_e2e_test ship with the
# torch-mlir project, not plain PyPI torch).
import torch
from functorch import make_fx
# NOTE(review): _stateless is a private torch API; it moved to
# torch.func / torch.nn.utils.stateless in newer releases — confirm the
# torch version this gist targets.
from torch.nn.utils import _stateless
import torch_mlir
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
class Foo(torch.nn.Module):
    """Minimal one-layer module: a single 3 -> 3 affine (Linear) transform."""

    def __init__(self):
        super().__init__()
        # in_features=3, out_features=3; bias is enabled by default.
        self.fc = torch.nn.Linear(3, 3)

    def forward(self, x):
        # x is expected to have trailing dimension 3; returns x @ W.T + b.
        return self.fc(x)
mod = Foo() | |
def get_sorted_params(named_params):
    """Return parameter values ordered by parameter name.

    Sorting by name gives a deterministic parameter ordering regardless of
    dict insertion order, so the traced optimizer always sees the same list.
    (Dict keys are unique, so the tuple sort never compares the values.)
    """
    return [param for _name, param in sorted(named_params.items())]
inp = (torch.randn(3, 3), ) | |
def forward(params, buffers, args):
    """Run one functional SGD training step on the module `mod`.

    Args:
        params: dict of parameter name -> tensor (from mod.named_parameters()).
        buffers: dict of buffer name -> tensor (from mod.named_buffers()).
        args: positional-arg tuple forwarded to mod's forward.

    Returns:
        The (params, buffers) dicts, whose tensors have been updated in place
        by the SGD step.
    """
    params_and_buffers = {**params, **buffers}
    # Run mod statelessly with the supplied state and backprop a scalar loss.
    _stateless.functional_call(mod, params_and_buffers, args,
                               {}).sum().backward()
    # NOTE(review): a fresh optimizer is created on every call, so any optimizer
    # state (e.g. momentum) would be lost between steps; harmless for plain SGD.
    optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
    # optim.load_state_dict(optim_state)
    optim.step()
    return params, buffers
# Trace one training step into a torch.fx graph; params/buffers are passed as
# dicts and the example input as a tuple.
fx_graph = make_fx(forward)(dict(mod.named_parameters()),
                            dict(mod.named_buffers()), inp)
# Reset the codegen so the graph is emitted as plain positional code (dropping
# the pytree wrapping), then recompile before TorchScript scripting.
fx_graph.graph.set_codegen(torch.fx.graph.CodeGen())
fx_graph.recompile()
# Convert the torch.fx GraphModule into a TorchScript graph.
ts_graph = torch.jit.script(fx_graph)
print(ts_graph.graph)
# Lower the scripted graph to the torch-mlir "torch" dialect.
# Example inputs appear flattened as (bias(3), weight(3,3), x(3,3)) —
# TODO(review): confirm this order matches the traced signature.
module = torch_mlir.compile(
    ts_graph, (torch.ones(3), torch.ones(3, 3), torch.ones(3, 3)),
    output_type=torch_mlir.OutputType.TORCH)
module.dump()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment