This gist demonstrates the problems (a deadlock and a recursion limit error) caused by pytorch#18566. Both scripts were tested with PyTorch 1.0.1.post2.
Each script has a docstring that explains it.
"""It demonstrates a deadlock. | |
(Tested with PyTorch 1.0.1.post2) | |
There's a custom autograd function called RaceCondition. It performs | |
a nested backward() during its backward pass. Its backward pass also | |
acquires a global lock. | |
We will make two autograd branch through RaceCondition. These two | |
branches are independent so there should be no reason to lead to a | |
deadlock but it happens. | |
When you execute it, you will see only "forward done!". But you will | |
never see "backward done!" beacuse one of the branches cannot acquire | |
the lock. | |
""" | |
import threading

import torch

lock = threading.Lock()


class RaceCondition(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # A deadlock occurs here.
        with lock:
            grad_output.requires_grad_().backward()
        return grad_output
a = torch.ones(1, requires_grad=True)
b = torch.ones(1, requires_grad=True)

a = RaceCondition.apply(a)
b = RaceCondition.apply(b)
c = a + b
print('forward done!')

c.backward()
print('backward done!')  # This line is never reached.
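A side note, not part of the original demonstration: the deadlock looks consistent with the nested backward() being re-entered on the very thread that already holds the lock. One way to probe that hypothesis (an assumption, not a confirmed explanation) is to swap the non-reentrant threading.Lock for a threading.RLock, which the same thread can acquire repeatedly; under this assumption the script may reach "backward done!". It is only a probe, not a recommended fix.

import threading

import torch

# Assumption: the nested backward() re-enters on the same thread, so a
# reentrant lock can be re-acquired where a plain Lock would block forever.
lock = threading.RLock()


class RaceCondition(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        with lock:  # re-acquirable by the thread that already holds it
            grad_output.requires_grad_().backward()
        return grad_output


a = RaceCondition.apply(torch.ones(1, requires_grad=True))
b = RaceCondition.apply(torch.ones(1, requires_grad=True))
(a + b).backward()
print('backward done!')  # may be reached if the same-thread assumption holds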
"""It demonstrates a recursion limit error. | |
(Tested with PyTorch 1.0.1.post2) | |
We have one input tensor and a small model. We will stack model output | |
tensors with checkpointing several times. | |
In the backward pass, the nth checkpoint recursively calls the (n-1)th | |
checkpoint until the first one. If there are stacked checkpoints more | |
than the recursion limit, RecursionError occurs. | |
""" | |
import sys

import torch
from torch.utils.checkpoint import checkpoint

sys.setrecursionlimit(100)

a = torch.ones(1, requires_grad=True)
m = torch.nn.Linear(1, 1)

tensors = []
for _ in range(sys.getrecursionlimit()):
    tensors.append(checkpoint(m, a))

c = torch.stack(tensors).sum()

# RecursionError occurs.
c.backward()
'''
Traceback (most recent call last):
  File "pytorch-checkpoint-recursionlimit.py", line 16, in <module>
    c.backward()
  File ".../lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File ".../lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File ".../lib/python3.6/site-packages/torch/autograd/function.py", line 76, in apply
    return self._forward_cls.backward(self, *args)
  File ".../lib/python3.6/site-packages/torch/utils/checkpoint.py", line 70, in backward
    with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
  File ".../lib/python3.6/contextlib.py", line 159, in helper
    return _GeneratorContextManager(func, args, kwds)
  File ".../lib/python3.6/contextlib.py", line 60, in __init__
    self.gen = func(*args, **kwds)
RecursionError: maximum recursion depth exceeded while calling a Python object
'''
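A workaround sketch, not part of the original demonstration: each stacked checkpoint consumes a handful of Python stack frames during the recursive backward, so raising the recursion limit comfortably above the number of checkpoints lets the same graph finish. It only postpones the failure and does not address the recursive behaviour reported in pytorch#18566. The limit of 10000 below is an arbitrary assumption, chosen to leave plenty of headroom for 100 checkpoints.

import sys

import torch
from torch.utils.checkpoint import checkpoint

# Assumption: a limit far above the number of checkpoints (100 here)
# leaves enough headroom for the per-checkpoint recursion in backward.
sys.setrecursionlimit(10000)

a = torch.ones(1, requires_grad=True)
m = torch.nn.Linear(1, 1)

tensors = [checkpoint(m, a) for _ in range(100)]
c = torch.stack(tensors).sum()

c.backward()  # expected to complete without hitting the recursion limit
print('backward done!')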