What is PyTorch's gradient checkpointing?
PyTorch's gradient checkpointing is a technique for reducing the memory footprint of training deep neural networks by recomputing intermediate activations instead of storing them all. It is particularly useful for training large models that would otherwise require more GPU memory than is available.
### How Gradient Checkpointing Works
1. **Standard Training Process**:
- During the forward pass, activations (outputs of layers) are computed and stored for each layer.
- During the backward pass, these stored activations are used to compute gradients.
2. **Memory Bottleneck**:
- Storing all activations can consume a significant amount of memory, especially in very deep networks.
- This limits the maximum batch size and model size that can be trained on a given GPU.
3. **Gradient Checkpointing**:
- Instead of storing all activations during the forward pass, only a subset of them (checkpoints) are stored.
- During the backward pass, activations that were not stored are recomputed as needed.
- This trades increased computation time (due to recomputation) for reduced memory usage, as the sketch after this list illustrates.
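
As a rough illustration of this trade-off, the following sketch (layer sizes and segment count are made up for the example) uses `torch.utils.checkpoint.checkpoint_sequential` to keep activations only at segment boundaries of an `nn.Sequential` stack and recompute everything else during the backward pass:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# A deep stack of identical layers
layers = nn.Sequential(*[nn.Linear(512, 512) for _ in range(16)])
x = torch.randn(8, 512, requires_grad=True)

# Split the stack into 4 segments: only the inputs to each segment are kept,
# and activations inside a segment are recomputed during the backward pass.
out = checkpoint_sequential(layers, 4, x, use_reentrant=False)
out.sum().backward()
```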
### Benefits
- **Reduced Memory Usage**: By storing fewer activations, memory usage during training is significantly reduced. This allows for training larger models or using larger batch sizes.
- **Feasibility of Training Deep Models**: Enables the training of extremely deep networks that would otherwise be infeasible due to memory constraints.
### Trade-offs
- **Increased Computation Time**: Recomputing activations during the backward pass increases the overall computation time.
- **Complexity in Implementation**: While PyTorch provides utilities to facilitate gradient checkpointing, it adds some complexity to the training pipeline.
### Implementation in PyTorch
PyTorch provides a straightforward API for gradient checkpointing through the `torch.utils.checkpoint` module. Here's how you can implement gradient checkpointing in a model:
1. **Import Required Module**:
```python
import torch
from torch.utils.checkpoint import checkpoint
```
2. **Define a Function for Checkpointing**:
- This function should encapsulate the part of the model you want to checkpoint.
```python
def checkpointed_forward(module, *inputs):
    # use_reentrant=False selects the non-reentrant variant recommended in recent PyTorch releases
    return checkpoint(module, *inputs, use_reentrant=False)
```
3. **Apply Checkpointing in the Model**:
- Modify the model's forward method to use the checkpointed function for specific layers.
```python
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layer1 = torch.nn.Linear(10, 10)
        self.layer2 = torch.nn.Linear(10, 10)
        self.layer3 = torch.nn.Linear(10, 10)

    def forward(self, x):
        x = self.layer1(x)
        x = checkpointed_forward(self.layer2, x)
        x = self.layer3(x)
        return x
```
4. **Training the Model**:
- Train the model as usual. The checkpointed layers will use less memory but may take longer to compute due to recomputation.
```python
model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()

# Placeholder data; substitute your own dataset and DataLoader
from torch.utils.data import DataLoader, TensorDataset
dataloader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randn(64, 10)), batch_size=8)
num_epochs = 5

# Example training loop
for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
```
### Example
Here is a more complete example of how gradient checkpointing might be used in a simple neural network:
```python
import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedModel(torch.nn.Module):
    def __init__(self):
        super(CheckpointedModel, self).__init__()
        self.layer1 = torch.nn.Linear(1000, 1000)
        self.layer2 = torch.nn.Linear(1000, 1000)
        self.layer3 = torch.nn.Linear(1000, 1000)
        self.layer4 = torch.nn.Linear(1000, 1000)

    def forward(self, x):
        x = self.layer1(x)
        # Activations of layer2 and layer3 are not stored; they are
        # recomputed during the backward pass.
        x = checkpoint(self.layer2, x, use_reentrant=False)
        x = checkpoint(self.layer3, x, use_reentrant=False)
        x = self.layer4(x)
        return x

# Example usage
model = CheckpointedModel()
inputs = torch.randn(10, 1000)
outputs = model(inputs)
outputs.sum().backward()  # triggers recomputation of the checkpointed layers
```
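
To check the memory effect directly, a minimal sketch along these lines (assuming a CUDA device is available and reusing `CheckpointedModel` from above) reports peak GPU memory for one forward/backward pass; running the same measurement on an otherwise identical model without the `checkpoint` calls shows the savings:

```python
import torch

def peak_memory_mb(model, batch):
    """One forward/backward pass; returns peak GPU memory in MB."""
    torch.cuda.reset_peak_memory_stats()
    out = model(batch)
    out.sum().backward()
    return torch.cuda.max_memory_allocated() / 1024**2

device = torch.device("cuda")
batch = torch.randn(64, 1000, device=device)

checkpointed = CheckpointedModel().to(device)
print(f"peak memory with checkpointing: {peak_memory_mb(checkpointed, batch):.1f} MB")
```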
### Conclusion
Gradient checkpointing in PyTorch is a powerful technique to manage memory consumption, making it possible to train larger and deeper neural networks. While it does introduce some additional computation overhead, the trade-off is often worth it for the ability to train models that would otherwise be limited by memory constraints. PyTorch's `torch.utils.checkpoint` module provides a user-friendly way to implement this technique in your models.