Skip to content

Instantly share code, notes, and snippets.

@jamesr66a
Created May 13, 2021 18:27
Show Gist options
  • Save jamesr66a/5623e9f92300dd513189270bcf4f8b90 to your computer and use it in GitHub Desktop.
Save jamesr66a/5623e9f92300dd513189270bcf4f8b90 to your computer and use it in GitHub Desktop.
# myclass.py
import torch
import torch.fx
class Foo(metaclass=torch.fx.ProxyableClassMeta):
    """Simple holder for a single value ``x``.

    The ``ProxyableClassMeta`` metaclass lets ``torch.fx`` record a call to
    ``Foo(...)`` made with traced (Proxy) arguments as a node in the graph,
    instead of running the constructor during symbolic tracing. Outside of
    tracing, construction behaves like an ordinary class.
    """

    def __init__(self, x):
        # Store the constructor argument unchanged.
        self.x = x

    def __repr__(self):
        # Without this, printing a Foo only shows an opaque object address.
        return f"{type(self).__name__}(x={self.x!r})"
# ser.py
import torch
import torch.fx
from myclass import Foo
def use_foo(x):
    """Construct and return a ``Foo`` wrapping *x*.

    This is the function handed to ``torch.fx.symbolic_trace`` below. When
    *x* is an fx ``Proxy``, the ``Foo(x)`` call is captured as a graph node
    rather than executed.
    """
    result = Foo(x)
    return result
import pickle

# Trace use_foo into a GraphModule; because Foo uses ProxyableClassMeta, the
# constructor call appears as a node in the captured graph.
traced = torch.fx.symbolic_trace(use_foo)

# Persist the traced module so des.py can reload it later.
with open('file.pkl', 'wb') as out_file:
    pickle.dump(traced, file=out_file)
# des.py
import pickle

# Load the GraphModule that ser.py wrote. Open for *reading* ('rb'): opening
# with 'wb' would truncate the pickle ser.py produced, and this script has no
# `traced` object of its own to write anyway.
# NOTE(review): unpickling can execute arbitrary code; only load trusted files.
# Presumably `myclass` must be importable here, since the loaded graph calls
# myclass_Foo — confirm.
with open('file.pkl', 'rb') as f:
    loaded = pickle.load(f)
print(loaded)
# Contents of des.py (as printed by `cat des.py`):
import pickle

# Reload the GraphModule that ser.py pickled and print it, which shows the
# generated forward() code.
# NOTE(review): unpickling can execute arbitrary code; only load trusted files.
with open('file.pkl', 'rb') as in_file:
    loaded = pickle.load(in_file)

print(loaded)

# Expected output:
#
# GraphModule()
# def forward(self, x):
#     foo = myclass_Foo(x); x = None
#     return foo
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment