@sublee
Last active April 1, 2019 06:27
A demonstration of the problems from parallel checkpoints

This gist demonstrates two problems (a deadlock and a recursion limit error) described in pytorch#18566. Both scripts were tested with PyTorch 1.0.1.post2.

Each script has a docstring explaining what it demonstrates.
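For context, torch.utils.checkpoint.checkpoint trades memory for compute: it runs a function without saving intermediate activations and recomputes them during the backward pass via a nested backward, which is exactly the behavior both scripts exercise. A minimal usage sketch (the module and tensor names here are illustrative):

import torch
from torch.utils.checkpoint import checkpoint

m = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)

# Forward without storing intermediate activations.
y = checkpoint(m, x)

# The backward pass recomputes m(x) and then runs a nested backward
# through the recomputed subgraph.
y.sum().backward()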

"""It demonstrates a deadlock.
(Tested with PyTorch 1.0.1.post2)
There's a custom autograd function called RaceCondition. It performs
a nested backward() during its backward pass. Its backward pass also
acquires a global lock.
We will make two autograd branch through RaceCondition. These two
branches are independent so there should be no reason to lead to a
deadlock but it happens.
When you execute it, you will see only "forward done!". But you will
never see "backward done!" beacuse one of the branches cannot acquire
the lock.
"""
import threading

import torch

lock = threading.Lock()


class RaceCondition(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # A deadlock occurs: the nested backward() re-enters the autograd
        # engine while the lock is held, and the other branch's backward
        # blocks forever waiting for the same lock.
        with lock:
            grad_output.requires_grad_().backward()
        return grad_output


a = torch.ones(1, requires_grad=True)
b = torch.ones(1, requires_grad=True)

a = RaceCondition.apply(a)
b = RaceCondition.apply(b)
c = a + b
print('forward done!')

c.backward()
print('backward done!')  # This line is never reached.
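If the critical section does not actually need to span the nested backward, the deadlock can be sidestepped by releasing the lock before re-entering autograd. This is only a sketch of that idea for the toy RaceCondition above (reusing the lock and imports from the script), not a fix for pytorch#18566; SafeBackward is a hypothetical name:

class SafeBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # Run the nested backward *outside* the critical section, so the
        # other branch's backward never waits on a lock that is held
        # across a re-entrant autograd call.
        grad_output.requires_grad_().backward()
        with lock:
            pass  # protect only non-autograd state here
        return grad_output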
"""It demonstrates a recursion limit error.
(Tested with PyTorch 1.0.1.post2)
We have one input tensor and a small model. We will stack model output
tensors with checkpointing several times.
In the backward pass, the nth checkpoint recursively calls the (n-1)th
checkpoint until the first one. If there are stacked checkpoints more
than the recursion limit, RecursionError occurs.
"""
import sys

import torch
from torch.utils.checkpoint import checkpoint

sys.setrecursionlimit(100)

a = torch.ones(1, requires_grad=True)
m = torch.nn.Linear(1, 1)

# Stack as many checkpointed outputs as the recursion limit allows.
tensors = []
for _ in range(sys.getrecursionlimit()):
    tensors.append(checkpoint(m, a))
c = torch.stack(tensors).sum()

# RecursionError occurs.
c.backward()
'''
Traceback (most recent call last):
  File "pytorch-checkpoint-recursionlimit.py", line 16, in <module>
    c.backward()
  File ".../lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File ".../lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File ".../lib/python3.6/site-packages/torch/autograd/function.py", line 76, in apply
    return self._forward_cls.backward(self, *args)
  File ".../lib/python3.6/site-packages/torch/utils/checkpoint.py", line 70, in backward
    with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
  File ".../lib/python3.6/contextlib.py", line 159, in helper
    return _GeneratorContextManager(func, args, kwds)
  File ".../lib/python3.6/contextlib.py", line 60, in __init__
    self.gen = func(*args, **kwds)
RecursionError: maximum recursion depth exceeded while calling a Python object
'''
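In this small example the RecursionError can be postponed, not truly fixed, by raising Python's recursion limit before calling backward(); the structural problem of recursive checkpoint backwards remains. A workaround sketch:

import sys

# Raise the limit *before* calling backward(). The required limit grows
# with the number of stacked checkpoints, so this does not scale
# indefinitely.
sys.setrecursionlimit(10000)
c.backward()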