Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created May 5, 2022 07:22
Show Gist options
  • Save pashu123/35d80539aa5e292a70283a0682d5cc7e to your computer and use it in GitHub Desktop.
Save pashu123/35d80539aa5e292a70283a0682d5cc7e to your computer and use it in GitHub Desktop.
import torch
import torch_mlir
from functorch import make_fx
from torch.nn.utils import _stateless
from torch.utils import _pytree as pytree
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
class Foo(torch.nn.Module):
    """Toy model consisting of a single 3 -> 3 fully connected layer."""

    def __init__(self):
        super(Foo, self).__init__()
        self.fc = torch.nn.Linear(in_features=3, out_features=3)

    def forward(self, x):
        """Apply the linear layer ``self.fc`` to ``x`` and return the result."""
        out = self.fc(x)
        return out
# Module instance whose parameters/buffers are handed to `forward` and swapped
# back into it through `_stateless.functional_call`.
mod = Foo()
def get_sorted_params(named_params):
    """Return the values of ``named_params`` ordered by their names.

    Args:
        named_params: mapping from parameter name to tensor.

    Returns:
        A list of the mapping's values, sorted by key.
    """
    ordered_names = sorted(named_params)
    return [named_params[name] for name in ordered_names]
# Positional inputs for the model: a tuple holding one random 3x3 tensor,
# passed to `forward` as `args` when tracing.
inp = (torch.randn(3, 3), )
def forward(params, buffers, args):
    """Run one SGD training step of ``mod`` using explicitly supplied state.

    The step is written functionally (state in, state out) so that ``make_fx``
    can trace the forward pass, backward pass and optimizer update together.

    Args:
        params: dict mapping parameter names to tensors, used in place of
            ``mod``'s own parameters.
        buffers: dict mapping buffer names to tensors, used in place of
            ``mod``'s own buffers.
        args: tuple of positional inputs for ``mod``'s forward.

    Returns:
        The ``(params, buffers)`` dicts; the tensors in ``params`` have been
        updated in place by the SGD step.
    """
    params_and_buffers = {**params, **buffers}
    # Run mod with the supplied tensors, reduce its output to a scalar loss and
    # backpropagate so the tensors in `params` receive gradients.
    loss = _stateless.functional_call(mod, params_and_buffers, args, {}).sum()
    loss.backward()
    # A fresh optimizer per call; parameters are passed in name-sorted order.
    optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
    optim.step()
    return params, buffers
# Trace one full training step (forward, backward, SGD update) into a torch.fx
# graph, then convert that graph into a TorchScript graph.
trace_args = (dict(mod.named_parameters()), dict(mod.named_buffers()), inp)
fx_graph = make_fx(forward)(*trace_args)
# Swap in the plain fx CodeGen and regenerate the code (presumably so the graph
# takes flat positional tensors that torch.jit.script can compile).
fx_graph.graph.set_codegen(torch.fx.graph.CodeGen())
fx_graph.recompile()
ts_graph = torch.jit.script(fx_graph)
print(ts_graph.graph)

# The example inputs must line up one-to-one, in order and shape, with the
# graph's inputs, which come from pytree-flattening the trace arguments.
# Derive them from that same flattening rather than hard-coding shapes, so they
# cannot drift out of sync with the traced signature.
flat_trace_args, _ = pytree.tree_flatten(trace_args)
example_inputs = tuple(torch.ones_like(t) for t in flat_trace_args)
module = torch_mlir.compile(
    ts_graph, example_inputs,
    output_type=torch_mlir.OutputType.TORCH)
module.dump()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment