@soulitzer
soulitzer / get_source_partition.py
Created June 27, 2023 15:33
get_source_partitions produces different results in export aten_graph=True
import torch
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
import pprint
m = torch.nn.Linear(10, 10)
def fn(x):
    return m(x)
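The preview cuts off here; a hedged sketch of how the comparison presumably continues (the export calls, example input, and partition query below are assumptions, not the gist's exact code):

# Hedged sketch (assumption): export fn with and without aten_graph and
# compare what get_source_partitions reports for torch.nn.Linear.
gm_torch, _ = torch._dynamo.export(fn, torch.randn(2, 10), aten_graph=False)
gm_aten, _ = torch._dynamo.export(fn, torch.randn(2, 10), aten_graph=True)
pprint.pprint(get_source_partitions(gm_torch.graph, [torch.nn.Linear]))
pprint.pprint(get_source_partitions(gm_aten.graph, [torch.nn.Linear]))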
@soulitzer
soulitzer / gist:4da95730e40d814aa0c64cdff5c48571
Last active June 19, 2023 20:17
fx graph before inlining modules
(pytorch1) jw3468-mbp:pytorch1 jw3468$ python test/dynamo/test_modules.py -k TestTemp -v
test_dynamo_inline_module_nn_AdaptiveAvgPool1d_cpu_float32 (__main__.TestTempCPU) ...
opcode       name       target     args          kwargs
-----------  ---------  ---------  ------------  --------
placeholder  l_args_0_  L_args_0_  ()            {}
call_module  m          m          (l_args_0_,)  {}
output       output     output     ((m,),)       {}
stats [('calls_captured', 1), ('unique_graphs', 1)]
ok
@soulitzer
soulitzer / gist:c747d3e9cb1e241f6c7b9b57c0f84a9b
Created June 19, 2023 20:12
fx graph after inlining module call
test_dynamo_inline_module_nn_AdaptiveAvgPool1d_cpu_float32 (__main__.TestTempCPU) ...
opcode         name                 target                                                               args                       kwargs
-------------  -------------------  -------------------------------------------------------------------  -------------------------  --------
placeholder    l_args_0_            L_args_0_                                                            ()                         {}
call_function  adaptive_avg_pool1d  <built-in method adaptive_avg_pool1d of type object at 0x103ccc7e8>  (l_args_0_, 3)             {}
output         output               output                                                               ((adaptive_avg_pool1d,),)  {}
inline_call []
stats [('calls_captured', 1), ('unique_graphs', 1)]
ok
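For context, a hedged guess at the kind of test behind both dumps (assumption: the real test in test/dynamo/test_modules.py differs in detail): compile a function that calls an nn.AdaptiveAvgPool1d instance with a backend that prints the captured graph. Before inlining, dynamo records a call_module node for the pool; after inlining, the underlying adaptive_avg_pool1d call shows up directly as a call_function node.

import torch

pool = torch.nn.AdaptiveAvgPool1d(3)

def print_graph(gm, example_inputs):
    # Hypothetical debug backend: renders tables like the ones above.
    gm.graph.print_tabular()
    return gm.forward

@torch.compile(backend=print_graph, fullgraph=True)
def f(x):
    return pool(x)

f(torch.randn(2, 4, 8))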
@soulitzer
soulitzer / gist:ff8fd8c3ccbb2c90dfe3df6d7713b167
Created June 16, 2023 15:25
activation checkpoint error when debug=True with python and cpp stacktrace
======================================================================
ERROR: test_checkpoint_detects_non_determinism (__main__.TestAutograd)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/jw3468/local/a/pytorch/torch/utils/checkpoint.py", line 1071, in unpack_hook_with_error_cb
return unpack_hook(holder)
File "/home/jw3468/local/a/pytorch/torch/utils/checkpoint.py", line 1053, in unpack_hook
frame.check_recomputed_tensors_match(gid)
File "/home/jw3468/local/a/pytorch/torch/utils/checkpoint.py", line 826, in check_recomputed_tensors_match
raise CheckpointError(
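A minimal sketch of the sort of non-determinism this test exercises (assumption: the real test differs; the debug kwarg follows the gist title): the checkpointed function saves a tensor whose shape changes between the original forward and the recomputation, and debug mode reports the metadata mismatch.

import torch
from torch.utils.checkpoint import checkpoint

runs = [0]

def fn(x):
    # Each call saves a tensor of a different shape, so the recomputation
    # during backward cannot match the original forward pass.
    runs[0] += 1
    return (x * torch.ones(runs[0])).sum()

x = torch.tensor(1.0, requires_grad=True)
out = checkpoint(fn, x, use_reentrant=False, debug=True)
out.backward()  # raises CheckpointError like the traceback above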
@soulitzer
soulitzer / gist:3d5e19c7cceae8e22f9bdd625ec39dd4
Created June 16, 2023 15:24
activation checkpoint error when debug=True with python stacktrace only
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
tensor at position 1:
saved metadata: {'shape': torch.Size([1]), 'dtype': torch.float32, 'device': device(type='cpu')}
recomputed metadata: {'shape': torch.Size([2]), 'dtype': torch.float32, 'device': device(type='cpu')}
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/local/pytorch1/test/test_autograd.py", line 5692, in test_checkpoint_detects_non_determinism
import torch
from typing import Callable, Any
import contextlib
from torch.utils._python_dispatch import TorchDispatchMode
from typing import Dict, Tuple, Optional, Set
import weakref
_cloned: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
_original: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
_tid_to_weakhandle: Dict[Tuple[int, int], weakref.ReferenceType] = dict()
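The preview stops mid-file; a hedged guess at how these maps might be populated (an assumption only, the gist's actual logic is not shown): pair each clone with its source, storing weak handles as values so the two maps cannot keep the pair alive through each other.

def record_clone(original: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: remember which clone belongs to which original.
    # Weak values avoid the cross-reference leak that strong values in two
    # WeakKeyDictionaries would create.
    clone = original.clone()
    _cloned[original] = weakref.ref(clone)
    _original[clone] = weakref.ref(original)
    return clone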
@soulitzer
soulitzer / global_fwd_tb.py
Last active August 16, 2022 20:50
How to expose anomaly mode traceback as a global
# Implements Alban's idea of making available the forward traceback
# corresponding to the execution of the current backward node as a global
import torch
from torch import autograd
from torch.utils._python_dispatch import TorchDispatchMode
current_metadata = None
# Set up hooks so that during backward the global is properly set/unset
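A hedged sketch of one way to wire this up with today's Node hook APIs (assumption: the gist's actual setup may differ): walk the autograd graph from an output and register pre/post hooks on every node that set and clear the global; under detect_anomaly the forward traceback is stored in each node's metadata.

def install_hooks(out):
    # Hypothetical helper: hook every node reachable from `out` so that
    # current_metadata points at the metadata of the node being executed.
    seen = set()

    def visit(node):
        if node is None or node in seen:
            return
        seen.add(node)

        def pre_hook(grad_outputs, node=node):
            global current_metadata
            current_metadata = node.metadata  # anomaly mode saves the forward traceback here

        def post_hook(grad_inputs, grad_outputs, node=node):
            global current_metadata
            current_metadata = None

        node.register_prehook(pre_hook)
        node.register_hook(post_hook)
        for next_node, _ in node.next_functions:
            visit(next_node)

    visit(out.grad_fn)

With autograd.detect_anomaly() active during the forward pass, calling install_hooks(loss) before loss.backward() keeps current_metadata pointing at the forward traceback of whichever node is currently running.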
import syft as sy
import numpy as np
from uuid import uuid1
sy.logger.remove()
_syf_owner = None
_syf_user = None
def syf_login():
    global _syf_owner
[... slack complaining message is too long]
File ~/Users/jw3468/local/install/miniconda3/envs/pysyft/lib/python3.9/site-packages/urllib3/connectionpool.py:340, in HTTPConnectionPool._raise_timeout(self, err, url, timeout_value)
339 if isinstance(err, SocketTimeout):
--> 340 raise ReadTimeoutError(
341 self, url, "Read timed out. (read timeout=%s)" % timeout_value
342 )
344 # See the above comment about EAGAIN in Python 3. In Python 2 we have
345 # to specifically catch it and throw the timeout error
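A hedged guess at how syf_login continues (assumption: the credentials and port below are placeholders, and the read timeout above suggests the domain node was simply unreachable):

def syf_login():
    global _syf_owner
    if _syf_owner is None:
        # Placeholder credentials for a locally deployed domain node.
        _syf_owner = sy.login(email="info@openmined.org", password="changethis", port=8081)
    return _syf_owner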
from functools import wraps
import torch
from torch._decomp import decomposition_table
import torch.nn.functional as F
from torch.utils._pytree import tree_map
from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs
# Goals:
# - we want something reusable that can compose with any subclass
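A minimal sketch of one way to meet that goal (assumption: DecompositionMode is a hypothetical name, not necessarily the gist's approach): a TorchDispatchMode that reroutes ops through decomposition_table, so a subclass such as LoggingTensor observes the decomposed sequence instead of the fused op.

class DecompositionMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        decomp = decomposition_table.get(func)
        if decomp is not None:
            # Ops issued by the decomposition still dispatch through the
            # tensor subclass, so it sees the decomposed ops.
            return decomp(*args, **kwargs)
        return func(*args, **kwargs)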