commit b0703a2d968ccc91760ad738e9a50b3a913969a9
Author: James Reed <[email protected]>
Date:   Wed Mar 2 01:03:05 2022 +0000

    Serialization fixes for HF tracer
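The patch below makes HF-traced GraphModules picklable by recording the dummy input values used at trace time and replaying them when the module is re-traced during deserialization. A minimal sketch of the intended round trip, assuming a transformers checkout with this commit applied (the model name and input_names are illustrative, and the symbolic_trace signature may differ across transformers versions):

import pickle

from transformers import BertModel
from transformers.utils.fx import symbolic_trace

model = BertModel.from_pretrained("bert-base-uncased")

# With this patch, symbolic_trace returns a MetadataGraphModule carrying
# the device, concrete_args, and recorded input_vals needed for re-tracing.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])

with open("traced_model.pkl", "wb") as f:
    pickle.dump(traced, f)

with open("traced_model.pkl", "rb") as f:
    loaded = pickle.load(f)
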
diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py
index b88ae4ae7..aeb345ad3 100644
--- a/src/transformers/utils/fx.py
+++ b/src/transformers/utils/fx.py
@@ -270,6 +270,7 @@ class HFTracer(Tracer):
         self.prev_module = None
         self.recorded_methods = None
+        self.input_vals = None
 
     def _register_leaf_function(self, module: ModuleType, name: str):
         """Registers the function called name in module as a leaf function."""
@@ -404,18 +405,21 @@ class HFTracer(Tracer):
         if method_names is None:
             method_names = self._DEFAULT_METHODS_TO_RECORD
 
-        # Creating a random input shape to generate dummy inputs.
-        batch_size = _generate_random_int()
-        sequence_length = _generate_random_int()
-        shape = [batch_size, sequence_length]
+        if self.input_vals is None:
+            # Creating a random input shape to generate dummy inputs.
+            batch_size = _generate_random_int()
+            sequence_length = _generate_random_int()
+            shape = [batch_size, sequence_length]
 
-        if model.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
-            num_choices = _generate_random_int(low=2, high=5)
-            shape.insert(1, num_choices)
+            if model.__class__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
+                num_choices = _generate_random_int(low=2, high=5)
+                shape.insert(1, num_choices)
 
-        inputs = {}
-        for input_name in input_names:
-            inputs.update(self._generate_dummy_input(model, input_name, shape))
+            inputs = {}
+            for input_name in input_names:
+                inputs.update(self._generate_dummy_input(model, input_name, shape))
+        else:
+            inputs = self.input_vals
 
         cache_names, original_methods = self._monkey_patch_tensor_methods_for_model_recording(model, method_names)
         self.original_methods = original_methods
@@ -427,6 +431,7 @@ class HFTracer(Tracer):
         self.recorded_methods = {
             method_name: cache_name for method_name, cache_name in cache_names.items() if hasattr(model, cache_name)
         }
+        return inputs
 
     def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
         if isinstance(attr_val, torch.nn.Parameter):
@@ -464,7 +469,7 @@ class HFTracer(Tracer):
         sig = inspect.signature(root.forward)
         input_names = sig.parameters.keys() - concrete_args.keys()
 
-        self.record(root, input_names, method_names=method_names)
+        self.input_vals = self.record(root, input_names, method_names=method_names)
 
         # TODO: adapt the way leaf function are wrapped with the "autowrap function" feature from Tracer.
         autowrap_functions = [patched for (_, _, patched) in self._leaf_functions_register.values()]
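The hunks above thread the recorded dummy inputs through the tracer: record() now returns the inputs it generated, trace() stashes them on self.input_vals, and a later trace with the same tracer (as happens at unpickle time) replays them instead of generating fresh random shapes. A standalone sketch of that caching pattern, with illustrative names only:

class RecordingTracer:
    def __init__(self):
        self.input_vals = None  # populated after the first trace

    def record(self, make_dummy_inputs):
        if self.input_vals is None:
            # First trace: generate fresh dummy inputs.
            inputs = make_dummy_inputs()
        else:
            # Re-trace (e.g. during deserialization): replay saved inputs.
            inputs = self.input_vals
        return inputs

tracer = RecordingTracer()
tracer.input_vals = tracer.record(lambda: {"input_ids": [[101, 2023, 102]]})
# The second call ignores the factory and replays the recorded values.
assert tracer.record(lambda: {}) == {"input_ids": [[101, 2023, 102]]}
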
@@ -548,6 +553,86 @@ class HFTracer(Tracer):
             return super().create_arg(list(a))
         return super().create_arg(a)
+def reduce_graph_module(body: Dict[Any, Any], import_block: str, metadata: Any) -> torch.nn.Module:
+    # BC: attribute name was changed from `code` to `_code` to facilitate
+    # making `code` into a property and adding a docstring to it
+    fn_src = body.get('_code') or body['code']
+    forward = torch.fx.graph_module._forward_from_src(import_block + fn_src, {})
+    return _deserialize_graph_module(forward, body, metadata)
+
+def _deserialize_graph_module(forward, body: Dict[Any, Any], metadata: Any) -> torch.nn.Module:
+    """
+    Deserialize a GraphModule given the dictionary of the original module,
+    using the code to reconstruct the graph. We delete the actual graph before
+    saving the dictionary so that changes to the in-memory graph format do not
+    get serialized.
+    """
+    # We create a dummy class here because symbolic_trace pulls the forward()
+    # function off of the class, rather than the instance
+    class CodeOnlyModule(torch.nn.Module):
+        def __init__(self, body):
+            super().__init__()
+            self.__dict__ = body
+
+    # Try to retrieve the forward source in a backward-compatible way
+    CodeOnlyModule.forward = forward
+
+    tracer_cls = body.get('_tracer_cls')
+    if tracer_cls is None:
+        from torch.fx._symbolic_trace import Tracer
+        tracer_cls = Tracer
+
+    graphmodule_cls_name = body.get('_graphmodule_cls_name', 'GraphModule')
+
+    # This is a workaround for a mypy linter issue related to
+    # passing base class as an argument - https://github.com/python/mypy/issues/5865.
+    cls_tracer: Any = tracer_cls
+
+    class KeepModules(cls_tracer):
+        # we shouldn't trace into any of the submodules,
+        # because they were not traced in the original GraphModule
+        def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool:
+            return True
+
+    com = CodeOnlyModule(body)
+
+    assert isinstance(metadata, dict)
+    com.device = metadata['device']
+
+    km = KeepModules()
+    km.input_vals = metadata['input_vals']
+
+    graph = km.trace(com, concrete_args=metadata['concrete_args'])
+
+    # Manually set Tracer class on the reconstructed Graph, to avoid
+    # referencing the private local subclass KeepModules.
+    graph._tracer_cls = tracer_cls
+    gm = GraphModule(com, graph, class_name=graphmodule_cls_name)
+
+    # The GraphModule constructor only retains attributes referenced by the graph.
+    # In this case, our goal is to return a GraphModule as close to identical as
+    # the one put into the package. If any additional attributes were present in
+    # body, we should keep them.
+    for k, v in body.items():
+        if not hasattr(gm, k):
+            setattr(gm, k, v)
+    return gm
+
+class MetadataGraphModule(torch.fx.GraphModule):
+    def __init__(self, root, graph: torch.fx.Graph, metadata: Any, class_name: str = 'GraphModule'):
+        super().__init__(root, graph, class_name)
+        self.metadata = metadata
+
+    def __reduce__(self):
+        """
+        Serialization of GraphModule. We serialize only the generated code, not
+        the underlying ``Graph``. This is because ``Graph`` does not have on-disk
+        backward-compatibility guarantees, whereas Python source code does.
+        On the deserialization side, we symbolically trace through the generated
+        code to regenerate the underlying ``Graph``.
+        """
+        (_, (dict_without_graph, import_block)) = super().__reduce__()
+        return (reduce_graph_module, (dict_without_graph, import_block, self.metadata))
 
 
 def symbolic_trace(
     model: PreTrainedModel,
@@ -589,6 +674,7 @@ def symbolic_trace(
     # Tracing.
     tracer = HFTracer()
     traced_graph = tracer.trace(model, concrete_args=concrete_args)
-    traced = torch.fx.GraphModule(model, traced_graph)
+    traced = MetadataGraphModule(
+        model, traced_graph, {'device': model.device, 'concrete_args': concrete_args, 'input_vals': tracer.input_vals})
 
     return traced
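
MetadataGraphModule piggybacks on GraphModule.__reduce__ but swaps in reduce_graph_module so the metadata travels with the pickle. A generic, transformers-independent sketch of the __reduce__ protocol it relies on (all names here are illustrative): pickle stores the returned callable plus its argument tuple, then calls the callable at load time.

import pickle

def rebuild(state, metadata):
    # Called by pickle at load time with the arguments returned by __reduce__.
    obj = Box.__new__(Box)
    obj.__dict__.update(state)
    obj.metadata = metadata
    return obj

class Box:
    def __init__(self, payload, metadata):
        self.payload = payload
        self.metadata = metadata

    def __reduce__(self):
        # Like MetadataGraphModule.__reduce__: a top-level callable plus the
        # arguments needed to rebuild the object, metadata included.
        return (rebuild, (self.__dict__.copy(), self.metadata))

box = pickle.loads(pickle.dumps(Box([1, 2, 3], {"device": "cpu"})))
assert box.payload == [1, 2, 3] and box.metadata == {"device": "cpu"}
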
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index b6ec0eae8..2ee7bdf80 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -728,6 +728,24 @@ class ModelTesterMixin:
             except RuntimeError:
                 self.fail("Couldn't trace module.")
 
+            # Test serialization
+            import pickle
+
+            with tempfile.TemporaryDirectory() as tmp_dir_name:
+                pickle_file_name = os.path.join(tmp_dir_name, "traced_model.pkl")
+
+                try:
+                    with open(pickle_file_name, 'wb') as f:
+                        pickle.dump(traced_model, f)
+                except Exception:
+                    self.fail("Couldn't save module.")
+
+                try:
+                    with open(pickle_file_name, 'rb') as f:
+                        loaded = pickle.load(f)
+                except Exception:
+                    self.fail("Couldn't load module.")
+
             def flatten_output(output):
                 flatten = []
                 for x in output:
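
For context, the deserialization path regenerates the Graph by re-tracing the reconstructed forward() with a tracer that treats every submodule as a leaf, mirroring _deserialize_graph_module above. A minimal, self-contained illustration of that retrace idea using plain torch.fx (no transformers involved):

import torch
from torch.fx import GraphModule, Tracer

class KeepModules(Tracer):
    # Never trace into submodules: they were leaves in the original trace,
    # so the regenerated Graph keeps call_module nodes for them.
    def is_leaf_module(self, m: torch.nn.Module, qualname: str) -> bool:
        return True

net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
graph = KeepModules().trace(net)
gm = GraphModule(net, graph)
print(gm(torch.randn(2, 4)).shape)  # torch.Size([2, 4])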