Created April 19, 2022 18:16
Save jamesr66a/ef1e85bbc7e9db77245bd7032d7ca971 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py
index e2f033d72a..7b3a97991d 100644
--- a/torch/fx/graph_module.py
+++ b/torch/fx/graph_module.py
@@ -222,6 +222,56 @@ def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
     else:
         setattr(to_module, field, from_obj)
 
+class _WrappedCall:
+    # Stored on a GraphModule class as `_wrapped_call`. Calling it runs the
+    # class's own `__call__` (`cls_call`) if it defined one, else the parent's
+    # via super(), and adds context to errors raised from FX-generated code.
+    def __init__(self, cls, cls_call):
+        self.cls = cls
+        self.cls_call = cls_call
+
+    # Previously, errors raised while running valid symbolically-traced code
+    # on invalid input appeared to come from `File "<eval_with_key_N">`, where
+    # N is some number. This builds a more informative message: the traceback
+    # itself, a note that the error occurred in a traced Module's generated
+    # forward function, and the faulty line (plus one line before it and two
+    # after it) followed by a `~~~ <--- HERE` marker line.
+    @staticmethod
+    def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
+        # `line` may be empty (or None) when the source cannot be found
+        err_lineno = frame_summary.lineno
+        err_line_len = len(frame_summary.line or "")
+        all_src_lines = linecache.getlines(frame_summary.filename)
+
+        # linecache lines already end in "\n", so join them with ""
+        tb_repr = traceback.format_exc()
+        custom_msg = ("Call using an FX-traced Module, "
+                      f"line {err_lineno} of the traced Module's "
+                      "generated forward function:")
+        before_err = "".join(all_src_lines[max(err_lineno - 2, 0) : err_lineno]).rstrip("\n")
+        marker = "~" * err_line_len + "~~~ <--- HERE"
+        after_err = "".join(all_src_lines[err_lineno : err_lineno + 2])
+
+        return "\n".join([tb_repr, custom_msg, before_err, marker, after_err])
+
+    def __call__(self, obj, *args, **kwargs):
+        try:
+            if self.cls_call is not None:
+                return self.cls_call(obj, *args, **kwargs)
+            else:
+                return super(self.cls, obj).__call__(*args, **kwargs)
+        except Exception as e:
+            assert e.__traceback__
+            topmost_framesummary: traceback.FrameSummary = \
+                traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]  # type: ignore[arg-type]
+            if "eval_with_key" in topmost_framesummary.filename:
+                print(_WrappedCall._generate_error_message(topmost_framesummary),
+                      file=sys.stderr)
+                # the full traceback was just printed; re-raise without it
+                raise e.with_traceback(None)
+            else:
+                raise e
+
 @compatibility(is_backward_compatible=True)
 class GraphModule(torch.nn.Module):
     """
@@ -587,51 +637,17 @@ class {module_name}(torch.nn.Module):
         # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
         cls_call = cls.__call__ if "__call__" in vars(cls) else None
 
-        # Previously, if an error occurred when valid
-        # symbolically-traced code was run with an invalid input, the
-        # user would see the source of the error as coming from
-        # `File "<eval_with_key_N">`, where N is some number. We use
-        # this function to generate a more informative error message. We
-        # return the traceback itself, a message explaining that the
-        # error occurred in a traced Module's generated forward
-        # function, and five lines of context surrounding the faulty
-        # line
-        def generate_error_message(frame_summary: traceback.FrameSummary) -> str:
-            # auxiliary variables (for readability)
-            err_lineno = frame_summary.lineno
-            err_line_len = len(frame_summary.line)
-            all_src_lines = linecache.getlines(frame_summary.filename)
-
-            # constituent substrings of the error message
-            tb_repr = traceback.format_exc()
-            custom_msg = ("Call using an FX-traced Module, "
-                          f"line {err_lineno} of the traced Module's "
-                          "generated forward function:")
-            before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
-            marker = "~" * err_line_len + "~~~ <--- HERE"
-            err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])
-
-            # joined message
-            return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
-
-        def wrapped_call(self, *args, **kwargs):
-            try:
-                if cls_call is not None:
-                    return cls_call(self, *args, **kwargs)
-                else:
-                    return super(cls, self).__call__(*args, **kwargs)
-            except Exception as e:
-                assert e.__traceback__
-                topmost_framesummary: traceback.FrameSummary = \
-                    traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]  # type: ignore[arg-type]
-                if "eval_with_key" in topmost_framesummary.filename:
-                    print(generate_error_message(topmost_framesummary),
-                          file=sys.stderr)
-                    raise e.with_traceback(None)
-                else:
-                    raise e
-
-        cls.__call__ = wrapped_call
+        # Create the _WrappedCall only once per class: if this runs again for
+        # the same class, `cls_call` above is the `call_wrapped` installed below,
+        # and wrapping that would make the call recurse through `_wrapped_call`
+        # indefinitely.
+        if '_wrapped_call' not in vars(cls):
+            cls._wrapped_call = _WrappedCall(cls, cls_call)
+
+        def call_wrapped(self, *args, **kwargs):
+            return self._wrapped_call(self, *args, **kwargs)
+
+        cls.__call__ = call_wrapped
 
         return python_code
 
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.