torch_activation_checkpointing

Tutorial: PyTorch Gradient Checkpointing with Detailed Torch Background and Implementation

In this tutorial, we'll dive deep into gradient checkpointing in PyTorch, a memory optimization technique for neural network training. We'll cover its purpose, its mechanics, and the details of PyTorch's internal implementation, including advanced features like early stopping and selective checkpointing. Along the way we provide background on Torch operations and Torch dispatch, together with practical examples, flowcharts, and step-by-step analysis.

Introduction to Gradient Checkpointing

Gradient checkpointing reduces memory consumption during the training of deep neural networks. Normally, during the forward pass, all intermediate activations are stored in memory for gradient computation in the backward pass. This can lead to high memory usage, especially for large models. Gradient checkpointing addresses this by:

  • Selectively Saving Activations: Only specific tensors (called checkpoints) are saved.
  • Recomputing on Demand: Unsaved activations are recomputed from the nearest checkpoint during the backward pass.

Why Use Gradient Checkpointing?

  • Memory Efficiency: Reduces peak memory usage, enabling training of larger models or batch sizes.
  • Flexibility: Allows fine-grained control over which tensors to save or recompute, balancing memory savings with computational cost.

Key Concepts

  • Checkpoints: Points in the computation graph where activations are saved.
  • Recompute: Process of recalculating unsaved activations during the backward pass.
  • Early Stopping: A feature to halt recomputation once all necessary tensors are available, avoiding extra work.
  • Computation Graph: A directed acyclic graph (DAG) representing operations and tensors in the forward pass.

PyTorch Gradient Checkpointing Basics

PyTorch provides the torch.utils.checkpoint module for gradient checkpointing. We'll focus on the non-reentrant version (use_reentrant=False), which offers greater flexibility, including early stopping and selective checkpointing.

Basic Usage Example

import torch
from torch.utils.checkpoint import checkpoint

# Define a function to checkpoint
def my_function(x):
    y = torch.sin(x)  # Intermediate computation
    z = torch.cos(y)  # Final output
    return z

# Input tensor with gradient tracking
x = torch.randn(100, 100, requires_grad=True)

# Apply checkpointing
output = checkpoint(my_function, x, use_reentrant=False)

# Compute loss and backpropagate
loss = output.sum()
loss.backward()
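
Checkpointing changes when activations are materialized, not the result. As a quick sanity check (a minimal sketch reusing my_function and x from above; the reference tensor x_ref is introduced here only for the comparison), the checkpointed gradients should match a plain, non-checkpointed run:

# Plain (non-checkpointed) reference run on a fresh leaf tensor
x_ref = x.detach().clone().requires_grad_(True)
my_function(x_ref).sum().backward()

# Gradients from the checkpointed run above should agree with the reference
torch.testing.assert_close(x.grad, x_ref.grad)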

Torch Background: Understanding Key Concepts

To fully understand gradient checkpointing, we need to explore the underlying Torch mechanisms. This section addresses background on Torch operations, dispatch, and their roles in checkpointing.

Torch Operations (Torch Ops)

What Are Torch Ops?

  1. Definition: Torch operations (ops) are the fundamental building blocks of PyTorch computations. They represent low-level operations like:

    • Matrix multiplication (torch.matmul)
    • Activation functions (torch.sigmoid)
    • Element-wise operations (torch.mul)
  2. Implementation:

    • Each op is implemented in C++ or CUDA for efficiency
    • Registered with PyTorch's dispatcher
    • Python APIs (e.g., torch.sin) are thin wrappers around these ops
  3. Examples:

    • torch.ops.aten.mm.default: Matrix multiplication
    • torch.ops.aten.sigmoid.default: Sigmoid activation

    These ops are accessible via torch.ops but are typically invoked indirectly through Python APIs
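
As a small illustration (a sketch, not part of the checkpoint API), calling an aten op directly gives the same result as its Python wrapper, since the wrapper simply routes the call to that op:

import torch

t = torch.randn(4)
# torch.sin is a thin wrapper around the aten op it dispatches to
torch.testing.assert_close(torch.sin(t), torch.ops.aten.sin.default(t))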

Computation Graph and Torch Ops

  1. Computation Graph Structure:

    • PyTorch builds a dynamic computation graph during the forward pass
    • Nodes represent tensors or operations
    • Edges represent dependencies
    • Each operation corresponds to a Torch op in the graph

    Example: For y = torch.sin(x); z = torch.cos(y), the graph includes nodes for x, sin, y, cos, and z (a short inspection snippet follows this list)

  2. Role in Checkpointing:

    • Gradient checkpointing operates on this graph by deciding which intermediate tensors to save or recompute
    • Checkpoints are placed at strategic nodes to minimize memory usage
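
To see the recorded graph for the sin/cos example, you can walk the grad_fn chain that autograd builds during the forward pass (these are standard autograd attributes; the printed node names are illustrative):

import torch

x = torch.randn(3, requires_grad=True)
y = torch.sin(x)
z = torch.cos(y)

print(z.grad_fn)                 # the backward node created by cos (CosBackward0)
print(z.grad_fn.next_functions)  # links back to the node created by sin (SinBackward0)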

Selective Activation Checkpointing with Torch Ops

  1. Concept:

    • Define policies for which ops' outputs to:
      • Save (MUST_SAVE)
      • Recompute (PREFER_RECOMPUTE)
    • Implemented by intercepting ops during forward pass
  2. Underlying Mechanism:

    • Uses TorchDispatchMode to intercept ops at the dispatcher level
    • Policy function (policy_fn) examines each op and decides its fate

    Example: Save outputs of expensive ops (like torch.matmul) but recompute lightweight ops (like torch.sigmoid)

  3. Computation Graph Impact:

    • Saved tensors → stored as checkpoints
    • Unsaved tensors → replaced with placeholders and marked for recomputation

Torch Dispatch

What Is Torch Dispatch?

  1. Definition:

    • PyTorch's mechanism for routing operation calls to their implementations
    • A central dispatcher managing op execution, logging, and customization
  2. How It Works:

    graph LR
    A[Op Call] --> B[Python API]
    B --> C[Dispatcher]
    C --> D[Backend/Implementation]
    
    • Op call (e.g., torch.sin) → Python API → Dispatcher → Backend
    • Dispatch can be intercepted using TorchDispatchMode (a minimal logging example follows this list)
  3. Role in PyTorch:

    • Enables core features:
      • Autograd
      • Device placement
      • Custom extensions
    • Provides unified interface across backends
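
To make the dispatcher concrete, here is a minimal logging mode (a sketch built on TorchDispatchMode from torch.utils._python_dispatch; the LoggingMode class name is ours). It prints every op that reaches the dispatcher, which is the same interception point the checkpointing policies use:

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LoggingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"dispatched: {func}")  # e.g. aten.sin.default, aten.cos.default
        return func(*args, **kwargs)

with LoggingMode():
    out = torch.cos(torch.sin(torch.randn(4)))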

How Activation Checkpointing Uses Torch Dispatch

  1. Selective Checkpointing Implementation:

    import functools
    from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
    
    def policy_fn(ctx, op, *args, **kwargs):
        # Save matmul outputs; prefer to recompute everything else during backward
        if op == torch.ops.aten.mm.default:
            return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE
    
    # Selective checkpointing is wired in through context_fn
    # (requires a recent PyTorch with create_selective_checkpoint_contexts)
    output = checkpoint(
        fn, x, y,
        use_reentrant=False,
        context_fn=functools.partial(create_selective_checkpoint_contexts, policy_fn),
    )

    Note: TorchDispatchMode intercepts torch.matmul and marks its output for saving

  2. Implementation Details:

    • Dispatch mode wraps forward pass
    • MUST_SAVE → tensor stored as checkpoint
    • PREFER_RECOMPUTE → placeholder inserted
  3. Benefits:

    • Fine-grained memory control
    • Seamless PyTorch integration

Early Stopping and Avoiding Recomputation

Early Stopping Mechanism

  1. Concept:

    • Halts recomputation when all required tensors are available
    • Controlled by the _CheckpointFrame class; enabled by default and toggleable via set_checkpoint_early_stop (see the sketch after this list)
  2. Implementation Components:

    • Recompute Phase:
      • Reruns function from nearest checkpoint for missing tensors
      • Stores results in _CheckpointFrame.recomputed
    • Early Stopping Trigger:
      • Tracks needed tensors
      • Raises _StopRecomputationError when complete
    • State Management:
      • recomp_counter: Tracks recomputation count
      • is_recomputed: Flags recomputation state
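
Early stopping is on by default for the non-reentrant path. For debugging, it can be disabled with the set_checkpoint_early_stop context manager, which forces the recompute phase to run the whole function (a minimal sketch, assuming a recent PyTorch and reusing my_function and x from the basic example):

from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop

# Disable early stopping: recomputation runs to the end instead of halting
# as soon as the needed activations have been reproduced
with set_checkpoint_early_stop(False):
    out = checkpoint(my_function, x, use_reentrant=False)
out.sum().backward()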

Pack/Unpack Hook Mechanism

Part of PyTorch's saved_tensors_hooks API

  1. pack_hook:

    • Triggered during forward pass
    • PREFER_RECOMPUTE → lightweight placeholder
    • MUST_SAVE → direct tensor save
  2. unpack_hook:

    • Triggered during backward pass
    • Placeholder → triggers recomputation
    • Saved tensor → direct retrieval
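
These hooks belong to the general torch.autograd.graph.saved_tensors_hooks API, which you can also use directly. The sketch below simply logs what autograd saves and retrieves, mirroring (in a much simpler form) what the checkpoint pack/unpack hooks do with placeholders:

import torch

def pack_hook(t):
    print(f"packing tensor of shape {tuple(t.shape)}")
    return t  # checkpointing would return a lightweight placeholder here

def unpack_hook(packed):
    print("unpacking for backward")
    return packed  # checkpointing would trigger recomputation here

x = torch.randn(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = torch.sin(x).sum()
y.backward()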

Implementation Flow Diagram

graph TD
    A[Start Forward Pass] --> B[Apply _CachingTorchDispatchMode with policy_fn]
    B --> C[Run Function with _checkpoint_hook]
    C --> D[For Each Op in Function]
    D --> E{policy_fn Decision}
    E -->|MUST_SAVE| F[Save Output Tensor as Checkpoint]
    E -->|PREFER_RECOMPUTE| G[Execute Op, pack_hook: Save Placeholder]
    F --> H[Save Inputs and Checkpoints]
    G --> H
    H --> I[Complete Forward Pass]

    J[Start Backward Pass] --> K[unpack_hook: Need a Tensor?]
    K --> L{Is Tensor Saved?}
    L -->|Yes, MUST_SAVE| M[Fetch Saved Tensor from Checkpoint]
    L -->|No, Placeholder| N[Run recompute_fn with _CachedTorchDispatchMode]
    N --> O[For Each Op in recompute_fn]
    O --> P{policy_fn Decision}
    P -->|MUST_SAVE| Q[Fetch Saved Tensor, Skip Op Recomputation]
    P -->|PREFER_RECOMPUTE| R[Recompute Op, Save Tensor in _CheckpointFrame.recomputed]
    Q --> S[Next Op or Finish]
    R --> S
    S --> T{Check Early Stop?}
    T -->|Yes, All Tensors Ready| U[Raise _StopRecomputationError]
    T -->|No| V[Continue Recomputation]
    U --> W[Stop Recomputation]
    V --> W
    M --> X[Compute Gradients]
    W --> X
    X --> Y[End Backward Pass]

Advanced Example with Selective Checkpointing

import functools
import torch
from torch.utils.checkpoint import checkpoint, CheckpointPolicy, create_selective_checkpoint_contexts

def fn(x, y):
    a = torch.matmul(x, y)  # Matrix multiplication
    b = torch.sigmoid(a)    # Activation
    c = b * y               # Element-wise multiplication
    return c

def policy_fn(ctx, op, *args, **kwargs):
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

x = torch.randn(20, 20, requires_grad=True)  # square so that b * y is shape-compatible
y = torch.randn(20, 30, requires_grad=True)
output = checkpoint(
    fn, x, y,
    use_reentrant=False,
    context_fn=functools.partial(create_selective_checkpoint_contexts, policy_fn),
)
loss = output.sum()
loss.backward()

Forward Pass Analysis

  1. Initial State:

    • Inputs (x and y) are saved
  2. Policy Application:

    • torch.matmul: MUST_SAVE → a is saved
    • torch.sigmoid: PREFER_RECOMPUTE → b gets a placeholder
    • torch.mul: PREFER_RECOMPUTE → c gets a placeholder
  3. Final State:

    • Checkpoints: x, y, and a
    • Placeholders: b and c

Backward Pass Analysis

  1. Gradient on c:

    • Requires b and y
    • y is saved
    • b needs recomputation
  2. Recompute Process:

    • Reruns fn from x and y
    • Uses saved a (MUST_SAVE)
    • Only b = torch.sigmoid(a) needs recomputation
    • Early stopping after b
  3. Gradient Computation:

    • Uses recomputed b and saved y
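
To confirm that selective checkpointing leaves the gradients unchanged, you can compare the run above against a plain call to fn (a small sketch; x_ref and y_ref are fresh copies introduced only for the comparison):

# Reference run without checkpointing, on detached copies of the inputs
x_ref = x.detach().clone().requires_grad_(True)
y_ref = y.detach().clone().requires_grad_(True)
fn(x_ref, y_ref).sum().backward()

# Gradients should match the selectively checkpointed run
torch.testing.assert_close(x.grad, x_ref.grad)
torch.testing.assert_close(y.grad, y_ref.grad)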

Conclusion

PyTorch's gradient checkpointing combines several sophisticated mechanisms:

  1. Torch Operations and Dispatch:

    • Low-level ops as building blocks
    • Dispatcher for routing operations
    • TorchDispatchMode for policy enforcement
  2. Memory Management:

    • Selective saving of activations
    • Pack/unpack hooks for tensor management
    • Placeholder system for deferred computation
  3. Optimization Features:

    • Early stopping
    • Operation skipping for saved tensors
    • Policy-based selective checkpointing
  4. Integration:

    • Seamless autograd compatibility
    • Flexible policy customization
    • Efficient memory-computation trade-off

This comprehensive system enables training of large neural networks with limited memory resources while maintaining computational efficiency. Understanding these implementation details helps in effectively utilizing gradient checkpointing and potentially customizing it for specific needs.

Source Code References

  1. TorchDispatchMode: Defined in torch/utils/_python_dispatch.py
  2. _CachingTorchDispatchMode: Implemented in torch/utils/checkpoint.py
  3. Pack/Unpack Hooks: Via torch.autograd.graph.saved_tensors_hooks
  4. Recomputation Logic: In torch/utils/checkpoint.py