Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created April 29, 2022 14:30
Show Gist options
  • Save pashu123/f1b726b7cec0fdda9b4ba140d41bac64 to your computer and use it in GitHub Desktop.
Save pashu123/f1b726b7cec0fdda9b4ba140d41bac64 to your computer and use it in GitHub Desktop.
import torch
from functorch.compile import aot_function, nop
from functorch import make_fx
from torch.nn.utils import _stateless
from torchvision.models import resnet18
class Foo(torch.nn.Module):
    """Tiny demo model: one ``torch.nn.Linear(3, 3)`` layer.

    The layer is registered under the attribute name ``fc``, so the module's
    parameters are named ``fc.weight`` and ``fc.bias``.
    """

    def __init__(self):
        super(Foo, self).__init__()
        linear = torch.nn.Linear(3, 3)
        self.fc = linear

    def forward(self, x):
        """Return ``self.fc(x)``; ``x`` is assumed to have a last dim of 3."""
        out = self.fc(x)
        return out
# Module under test. Its parameters/buffers are read further down, and f()
# below refers to this same global `mod` when it calls functional_call.
mod = Foo()
# Alternative model (torchvision ResNet-18); presumably meant to be swapped in
# together with the commented-out image-shaped input below — confirm.
# mod = resnet18()
def get_sorted_params(named_params):
    """Return the values of ``named_params`` ordered by their (sorted) names.

    Sorting by name gives a deterministic ordering that does not depend on
    the mapping's insertion order.

    Args:
        named_params: mapping of name -> parameter (e.g. a dict built from
            ``module.named_parameters()``).

    Returns:
        A list of the mapping's values, ordered by ascending key.
    """
    return [named_params[name] for name in sorted(named_params)]
# Example input: a tuple of positional arguments for mod (one 3x3 tensor).
inp = (torch.randn(3, 3),)
# Image-shaped input, presumably paired with the resnet18 alternative above.
# inp = (torch.randn(1, 3, 228, 228),)
# Eager forward + backward: accumulates gradients into .grad on mod's
# parameters. NOTE(review): these same parameter tensors are later passed to
# f(); if f() is ever run eagerly, its own backward adds onto these gradients
# unless they are cleared first.
mod(*inp).sum().backward()
# Adam over mod's parameters in name-sorted order. Its state_dict() is handed to
# f(), which rebuilds an Adam over get_sorted_params(...) so the parameter
# indices line up.
optim = torch.optim.Adam(get_sorted_params(dict(mod.named_parameters())), lr=0.01)
# Disabled. NOTE(review): Adam appears to create its per-parameter state lazily
# inside step(), so with this commented out the state passed to f() is
# presumably empty — confirm that is the intended tracing scenario.
# optim.step()
def f(params, buffers, optim_state, args):
    """Run one Adam training step of the global ``mod`` with explicit state.

    Written as a pure-looking function of its inputs so that ``make_fx`` can
    trace forward, backward and the optimizer update into one graph.

    Args:
        params: mapping of parameter name -> tensor for ``mod``.
        buffers: mapping of buffer name -> tensor for ``mod``.
        optim_state: an Adam ``state_dict()`` whose parameter indices follow
            the name-sorted order produced by ``get_sorted_params(params)``.
        args: tuple of positional inputs forwarded to ``mod``.

    Returns:
        ``(params, buffers, state_dict)``. ``step()`` updates the tensors in
        ``params`` in place, so the returned mapping is the one passed in;
        ``state_dict`` is the optimizer state after this step.
    """
    # Local name chosen so it does not shadow the module-level `optim`.
    step_optim = torch.optim.Adam(get_sorted_params(params), lr=0.01)
    # load_state_dict also restores the saved param_groups hyper-parameters
    # (lr, betas, ...), so lr above is effectively a placeholder.
    step_optim.load_state_dict(optim_state)
    # Drop any gradients already attached to these tensors (e.g. from an
    # earlier eager backward pass) so the update uses only this step's
    # gradient instead of an accumulated one.
    step_optim.zero_grad()
    params_and_buffers = {**params, **buffers}
    # NOTE: torch.nn.utils._stateless is a private API; newer PyTorch exposes
    # the same call as torch.func.functional_call.
    _stateless.functional_call(mod, params_and_buffers, args, {}).sum().backward()
    step_optim.step()
    return params, buffers, step_optim.state_dict()
# Trace f (forward, backward and Adam step) with make_fx, feeding it mod's
# current parameters and buffers, the outer optimizer's state_dict and the
# example input, then print the graph of the traced result.
print(make_fx(f)(dict(mod.named_parameters()), dict(mod.named_buffers()), optim.state_dict(), inp).graph)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment