@jamesr66a
Created March 8, 2022 00:31
commit b0703a2d968ccc91760ad738e9a50b3a913969a9
Author: James Reed <[email protected]>
Date: Wed Mar 2 01:03:05 2022 +0000
Serialization fixes for HF tracer
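
Summary: torch.fx serializes a GraphModule by pickling only the generated forward() source and re-tracing it on load. For a module traced with HFTracer, that re-trace runs HFTracer.record again, which previously fabricated fresh random-shaped dummy inputs on every call. This patch caches the recorded dummy inputs on the tracer (input_vals), returns the traced module as a MetadataGraphModule that carries {device, concrete_args, input_vals}, and replays that metadata through a custom __reduce__ / reduce_graph_module pair so the pickle round-trip reproduces the original graph.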
diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py
index b88ae4ae7..aeb345ad3 100644
--- a/src/transformers/utils/fx.py
+++ b/src/transformers/utils/fx.py
@@ -270,6 +270,7 @@ class HFTracer(Tracer):
         self.prev_module = None
         self.recorded_methods = None
+        self.input_vals = None
 
     def _register_leaf_function(self, module: ModuleType, name: str):
         """Registers the function called name in module as a leaf function."""
@@ -404,18 +405,21 @@ class HFTracer(Tracer):
         if method_names is None:
             method_names = self._DEFAULT_METHODS_TO_RECORD
 
-        # Creating a random input shape to generate dummy inputs.
-        batch_size = _generate_random_int()
-        sequence_length = _generate_random_int()
-        shape = [batch_size, sequence_length]
+        if self.input_vals is None:
+            # Creating a random input shape to generate dummy inputs.
+            batch_size = _generate_random_int()
+            sequence_length = _generate_random_int()
+            shape = [batch_size, sequence_length]
 
-        if model.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
-            num_choices = _generate_random_int(low=2, high=5)
-            shape.insert(1, num_choices)
+            if model.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
+                num_choices = _generate_random_int(low=2, high=5)
+                shape.insert(1, num_choices)
 
-        inputs = {}
-        for input_name in input_names:
-            inputs.update(self._generate_dummy_input(model, input_name, shape))
+            inputs = {}
+            for input_name in input_names:
+                inputs.update(self._generate_dummy_input(model, input_name, shape))
+        else:
+            inputs = self.input_vals
 
         cache_names, original_methods = self._monkey_patch_tensor_methods_for_model_recording(model, method_names)
         self.original_methods = original_methods
@@ -427,6 +431,7 @@
         self.recorded_methods = {
             method_name: cache_name for method_name, cache_name in cache_names.items() if hasattr(model, cache_name)
         }
+        return inputs
 
     def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
         if isinstance(attr_val, torch.nn.Parameter):
@@ -464,7 +469,7 @@
         sig = inspect.signature(root.forward)
         input_names = sig.parameters.keys() - concrete_args.keys()
 
-        self.record(root, input_names, method_names=method_names)
+        self.input_vals = self.record(root, input_names, method_names=method_names)
 
         # TODO: adapt the way leaf function are wrapped with the "autowrap function" feature from Tracer.
         autowrap_functions = [patched for (_, _, patched) in self._leaf_functions_register.values()]
@@ -548,6 +553,86 @@ class HFTracer(Tracer):
             return super().create_arg(list(a))
         return super().create_arg(a)
 
+def reduce_graph_module(body: Dict[Any, Any], import_block: str, metadata: Any) -> torch.nn.Module:
+    # BC: attribute name was changed from `code` to `_code` to facilitate
+    # making `code` into a property and adding a docstring to it
+    fn_src = body.get('_code') or body['code']
+    forward = torch.fx.graph_module._forward_from_src(import_block + fn_src, {})
+    return _deserialize_graph_module(forward, body, metadata)
+
+def _deserialize_graph_module(forward, body: Dict[Any, Any], metadata: Any) -> torch.nn.Module:
+    """
+    Deserialize a GraphModule given the dictionary of the original module,
+    using the code to reconstruct the graph. We delete the actual graph before
+    saving the dictionary so that changes to the in-memory graph format do not
+    get serialized.
+    """
+    # We create a dummy class here because symbolic_trace pulls the forward()
+    # function off of the class, rather than the instance
+    class CodeOnlyModule(torch.nn.Module):
+        def __init__(self, body):
+            super().__init__()
+            self.__dict__ = body
+
+    # Try to retrieve the forward source in a backward-compatible way
+    CodeOnlyModule.forward = forward
+
+    tracer_cls = body.get('_tracer_cls')
+    if tracer_cls is None:
+        from torch.fx._symbolic_trace import Tracer
+        tracer_cls = Tracer
+
+    graphmodule_cls_name = body.get('_graphmodule_cls_name', 'GraphModule')
+
+    # This is a workaround for a mypy linter issue related to
+    # passing base class as an argument - https://github.com/python/mypy/issues/5865.
+    cls_tracer: Any = tracer_cls
+
+    class KeepModules(cls_tracer):
+        # we shouldn't trace into any of the submodules,
+        # because they were not traced in the original GraphModule
+        def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool:
+            return True
+
+    com = CodeOnlyModule(body)
+
+    assert isinstance(metadata, dict)
+    com.device = metadata['device']
+
+    km = KeepModules()
+    km.input_vals = metadata['input_vals']
+
+    graph = km.trace(com, concrete_args=metadata['concrete_args'])
+
+    # Manually set Tracer class on the reconstructed Graph, to avoid
+    # referencing the private local subclass KeepModules.
+    graph._tracer_cls = tracer_cls
+    gm = GraphModule(com, graph, class_name=graphmodule_cls_name)
+
+    # The GraphModule constructor only retains attributes referenced by the graph.
+    # In this case, our goal is to return a GraphModule as close to identical as the one
+    # put into the package. If any additional attributes were present in body,
+    # we should keep them.
+    for k, v in body.items():
+        if not hasattr(gm, k):
+            setattr(gm, k, v)
+    return gm
+
+class MetadataGraphModule(torch.fx.GraphModule):
+    def __init__(self, root, graph: torch.fx.Graph, metadata: Any, class_name: str = 'GraphModule'):
+        super().__init__(root, graph, class_name)
+        self.metadata = metadata
+
+    def __reduce__(self):
+        """
+        Serialization of GraphModule. We serialize only the generated code, not
+        the underlying ``Graph``. This is because ``Graph`` does not have on-disk
+        backward-compatibility guarantees, whereas Python source code does.
+        On the deserialization side, we symbolically trace through the generated
+        code to regenerate the underlying ``Graph``.
+        """
+        (_, (dict_without_graph, import_block)) = super().__reduce__()
+        return (reduce_graph_module, (dict_without_graph, import_block, self.metadata))
 
 def symbolic_trace(
     model: PreTrainedModel,
@@ -589,6 +674,7 @@ def symbolic_trace(
     # Tracing.
     tracer = HFTracer()
     traced_graph = tracer.trace(model, concrete_args=concrete_args)
-    traced = torch.fx.GraphModule(model, traced_graph)
+    traced = MetadataGraphModule(
+        model, traced_graph, {'device': model.device, 'concrete_args': concrete_args, 'input_vals': tracer.input_vals})
 
     return traced
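
A minimal usage sketch of the round-trip this patch enables (not part of the diff; the checkpoint name and input_names below are illustrative, and the exact symbolic_trace signature depends on the transformers version this patch targets):

    import pickle

    import torch
    from transformers import BertModel
    from transformers.utils.fx import symbolic_trace

    model = BertModel.from_pretrained("bert-base-uncased")

    # symbolic_trace now returns a MetadataGraphModule carrying
    # {'device', 'concrete_args', 'input_vals'} alongside the graph.
    traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])

    # MetadataGraphModule.__reduce__ pickles only the generated forward()
    # source plus the metadata dict, never the Graph object itself.
    blob = pickle.dumps(traced)

    # Unpickling calls reduce_graph_module, which re-traces the generated
    # code with a KeepModules tracer seeded with the recorded input_vals.
    restored = pickle.loads(blob)
    assert isinstance(restored, torch.fx.GraphModule)
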
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index b6ec0eae8..2ee7bdf80 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -728,6 +728,24 @@ class ModelTesterMixin:
         except RuntimeError:
             self.fail("Couldn't trace module.")
 
+        # Test serialization
+        import pickle
+
+        with tempfile.TemporaryDirectory() as tmp_dir_name:
+            pickle_file_name = os.path.join(tmp_dir_name, "traced_model.pkl")
+
+            try:
+                with open(pickle_file_name, 'wb') as f:
+                    pickle.dump(traced_model, f)
+            except Exception:
+                self.fail("Couldn't save module.")
+
+            try:
+                with open(pickle_file_name, 'rb') as f:
+                    loaded = pickle.load(f)
+            except Exception:
+                self.fail("Couldn't load module.")
+
         def flatten_output(output):
             flatten = []
             for x in output:
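
Two design points worth calling out. First, __reduce__ deliberately serializes the generated Python source rather than the Graph, since (per the docstring above) generated code has on-disk backward-compatibility guarantees that the in-memory Graph format does not. Second, record() replays the cached input_vals instead of drawing fresh random shapes at load time: presumably a re-trace against different dummy inputs could record different tensor-method values than the ones the serialized code was generated from, so the exact trace-time inputs ride along in the metadata.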