@kbastani
Created May 5, 2025 11:48
PyTorch/XLA on TPUs: A Research Report

Date: 2025-05-05

Prepared by: AI Research Assistant

Executive Summary

This report synthesizes findings from a series of knowledge organization tasks conducted between May 4 and May 5, 2025. The primary focus is the interplay between PyTorch, the Accelerated Linear Algebra (XLA) compiler, and Google's Tensor Processing Units (TPUs) for accelerating deep learning workloads. The areas investigated were:

- the goals and intended audience of a TPU and PyTorch/XLA report;
- performance comparisons of TPUs against CPUs and GPUs;
- the optimizations performed by XLA;
- the transition from the XRT runtime to the Plugin JRT (PJRT) runtime;
- the trade-offs of Fully Sharded Data Parallel (FSDP);
- the performance improvements and limitations of the TorchDynamo backend;
- how systolic arrays work and where they fall short;
- the accuracy trade-offs of the bfloat16 data type;
- the applicability of, and obstacles to, loop fusion;
- support for experimental quantized operations in XLA.

Each of these components is now better understood individually. Further research is needed to understand how they interact within the PyTorch/XLA ecosystem. In particular, what happens when these technologies are combined is a crucial area for future work.

1. Introduction

This research report aims to consolidate information gathered from a series of knowledge synthesis tasks related to PyTorch/XLA on Google Cloud TPUs. The tasks were initiated by Kenny and executed by this AI Research Assistant. The report provides a comprehensive overview of the key components, optimization techniques, and trade-offs involved in leveraging TPUs with PyTorch/XLA for accelerating deep learning workloads. The findings are based on a review of documentation, code examples, and performance benchmarks, as well as analysis of the underlying hardware and software architectures.

2. Report Objectives and Audience

The initial query sought to establish the purpose, audience, and planned follow-up actions for a "TPU and PyTorch/XLA Report" created on May 5, 2025. The report was requested by Kenny and prepared by Philip. It aimed to summarize the state of the art for TPUs and PyTorch/XLA, including a project map, future directions, possibilities, and drawbacks of using TPUs with PyTorch/XLA. Although the direct audience was Kenny, the report's content suggests a broader goal of informing strategic decisions about using TPUs for deep learning projects. Follow-up actions were not specified, but Philip's concluding question indicates a need for further investigation and refinement.

Table 1: Training AI Models - Different Tools for Different Jobs

This table compares the general approaches to training AI models, showing where PyTorch/XLA fits in.

Feature | Standard PyTorch (CPU/GPU) | PyTorch/XLA (on TPUs/Accelerators)
Best For | General AI tasks, smaller models, prototyping | Very large models, speed-critical training/inference
Hardware | Standard CPUs and GPUs | Specialized AI accelerators (such as Google TPUs)
Performance (Large Models) | Can be slow or memory-limited | Often significantly faster; handles very large datasets and models
Scalability | Good, but can hit limits with huge models | Designed to scale across many accelerator chips
Ease of Use | Familiar PyTorch experience | Still PyTorch, but requires awareness of XLA specifics (e.g., recompilation when tensor shapes change)
Key Technology | Eager (op-by-op) PyTorch execution | XLA compiler optimizes PyTorch graphs for accelerators

Takeaway: PyTorch/XLA isn't replacing standard PyTorch; it's extending it. It lets the PyTorch community tackle massive AI problems much faster by using specialized hardware like TPUs, thanks to the XLA compiler working behind the scenes.

Table 2: PyTorch/XLA - Getting Faster and More Powerful

This shows how recent developments within PyTorch/XLA itself are making it better.

Area of Improvement | Old Way (Conceptual) | New Way (Current State, May 2025) | Benefit
Connecting to Hardware | Older runtime (XRT), more overhead | PJRT runtime | Faster execution, more stable, easier setup
Handling Big Models | Basic distribution (DDP), replicates the full model on every device | FSDP (Fully Sharded Data Parallel) | Train much larger models by sharding parameters across devices
Running the Code | Standard PyTorch execution loop | TorchDynamo backend integration (experimental) | Potential for faster model runs with fewer code changes
Efficiency | Standard Python loops, potential bottlenecks | Optimized constructs (such as while_loop), quantization | Better hardware utilization, faster processing, lower memory use

Takeaway: The tools within PyTorch/XLA are constantly improving. Things like PJRT and FSDP make it more practical and powerful to train state-of-the-art models, addressing previous bottlenecks in speed and scale. Features like TorchDynamo and Quantization promise even more performance gains.

Table 3: The Big Picture - Why PyTorch/XLA Matters for the Future

This table connects PyTorch/XLA to the broader goals of open source and scalable AI.

Aspect | Description | Why It's a Step Forward
Core Compiler (XLA) | A powerful compiler (XLA) that optimizes AI math for different hardware. | Provides a common bridge from popular frameworks (like PyTorch) to specialized hardware. Not tied to just one framework.
Hardware Focus | Currently strong support for Google TPUs. | Demonstrates XLA's power on cutting-edge accelerators. TPUs are a prime example, not the only possibility.
Future Hardware | Architecture allows adding support for other AI accelerators (via plugins). | Opens the door for PyTorch to run efficiently on a wider range of future hardware, promoting choice and innovation.
Open Source Impact | Enables the open PyTorch community to access high-performance computing. | Democratizes large-scale AI training, allowing more researchers and developers to work on cutting-edge problems.
Scalability Standard | Offers standardized ways (like FSDP) to scale training efficiently. | Makes it easier for teams everywhere to adopt best practices for training massive models reliably.

Takeaway: PyTorch/XLA, powered by the XLA compiler, is a crucial development. It gives the huge PyTorch community a pathway to extreme performance and scale, starting strongly with TPUs. Importantly, the underlying XLA technology is designed to be flexible, paving the way for PyTorch to work efficiently with potentially many different kinds of future AI hardware. This promotes open standards and empowers more people to build the next generation of AI.

3. TPU vs. CPU and GPU Performance

The investigation examined how TPUs compare with CPUs and GPUs. Each offers different architectural strengths and weaknesses for machine learning.

CPUs are general-purpose and flexible but are limited by the von Neumann bottleneck. Instructions and data travel between processor and memory over a shared path, so memory bandwidth, rather than arithmetic, often limits throughput.

GPUs achieve higher throughput through massively parallel architectures but still suffer from memory-access bottlenecks.

TPUs were custom-designed for machine learning workloads. By minimizing memory accesses and devoting most of the chip to matrix multiplication, they can often outperform GPUs and CPUs on deep learning tasks.

These comparisons typically measure training time, inference speed, and scalability. One example is the set of speedups reported for TorchDynamo on a Cloud TPU v4-8.
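A back-of-the-envelope calculation helps explain why matrix-multiply-heavy workloads suit TPUs. Arithmetic intensity is FLOPs per byte of memory traffic. For a square matrix multiply it grows with matrix size, while for an element-wise operation it stays constant. Matmuls can therefore keep arithmetic units busy without being throttled by memory bandwidth. The sketch assumes idealized data movement (each operand read once, the result written once) and 2-byte bfloat16 elements. Real kernels move more data, so these figures are upper bounds, not measured TPU numbers.

```python
def matmul_intensity(n: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for multiplying two n x n matrices (idealized traffic)."""
    flops = 2 * n ** 3                        # n^3 multiply-adds, 2 FLOPs each
    bytes_moved = 3 * n * n * bytes_per_elem  # read A and B once, write C once
    return flops / bytes_moved


def elementwise_add_intensity(n: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for adding two n x n matrices element by element."""
    flops = n * n                              # one add per element
    bytes_moved = 3 * n * n * bytes_per_elem   # read both inputs, write the output
    return flops / bytes_moved


if __name__ == "__main__":
    for n in (256, 1024, 4096):
        print(f"n={n:5d}  matmul: {matmul_intensity(n):8.1f} FLOP/byte   "
              f"add: {elementwise_add_intensity(n):.3f} FLOP/byte")
```

The matmul figure doubles each time n doubles, while the element-wise figure stays at 1/6 FLOP per byte. Hardware built around dense matrix units benefits most from the first kind of workload.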

3.1 Investigating CPU/GPU/TPU Specific Models and Batch Sizes

Further investigation shows that performance comparisons are usually made for specific models at specific batch sizes. ResNet18, ResNet50, and BERT are frequently cited as evaluation models. The PJRT runtime document (2025-05-04T00:06:23.7194) reports a >35% improvement in training time on TPU v4 with TorchBench 2.0, but it does not specify the batch size used.

4. XLA Compiler Optimizations and Limitations

The XLA (Accelerated Linear Algebra) compiler plays a critical role in translating PyTorch operations for execution on TPUs. XLA performs various optimizations, including operator fusion, layout assignment, and algebraic simplification. The primary limitation of these optimizations is the potential for unintended graph recompilations due to dynamic factors like changing tensor shapes. Minimizing device-host communication is another key consideration.
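To illustrate why changing tensor shapes cause recompilation, the sketch below models a tracing compiler as a cache keyed on input shape. It also shows one common mitigation: padding inputs to a few fixed "bucket" sizes. ShapeKeyedCompiler and pad_to_bucket are hypothetical names for illustration only and are not part of the PyTorch/XLA API. The bucket sizes are arbitrary assumptions.

```python
from typing import Callable, Dict, List, Sequence, Tuple


class ShapeKeyedCompiler:
    """Toy stand-in for a tracing compiler: it 'compiles' once per distinct input shape.

    Hypothetical illustration only; this is not the PyTorch/XLA API.
    """

    def __init__(self, fn: Callable[[List[float]], float]):
        self.fn = fn
        self.cache: Dict[Tuple[int, ...], Callable[[List[float]], float]] = {}
        self.compilations = 0

    def __call__(self, xs: List[float]) -> float:
        key = (len(xs),)  # the input shape is part of the cache key
        if key not in self.cache:
            self.compilations += 1  # a real compiler would trace and compile a graph here
            self.cache[key] = self.fn
        return self.cache[key](xs)


def pad_to_bucket(xs: List[float], buckets: Sequence[int] = (8, 16, 32, 64)) -> List[float]:
    """Pad with zeros to the smallest bucket that fits, so many lengths share one shape."""
    for size in buckets:
        if len(xs) <= size:
            return xs + [0.0] * (size - len(xs))
    raise ValueError("input is longer than the largest bucket")


if __name__ == "__main__":
    raw = ShapeKeyedCompiler(sum)
    padded = ShapeKeyedCompiler(sum)
    for n in range(3, 21):  # 18 different sequence lengths
        raw([1.0] * n)
        padded(pad_to_bucket([1.0] * n))
    print("compilations without padding:", raw.compilations)    # 18
    print("compilations with bucketing:", padded.compilations)  # 3
```

Zero-padding is safe here only because zeros do not change a sum. For operations such as mean or softmax, the padded positions would need to be masked out.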

5. PJRT Runtime Transition

The transition from the older XRT runtime to the Plugin JRT (PJRT) runtime was driven by the need to address the limitations of XRT's TPU Node architecture, which introduced overhead due to its client/server model. PJRT directly accesses the local device, simplifying runtime configuration and improving performance. While the specific challenges encountered during the transition are not detailed, the move to PJRT resulted in a more efficient and streamlined runtime environment for TPUs.

5.1 PJRT Stability and xm.rendezvous

One notable PJRT change concerns xm.rendezvous, the call that replicas use to synchronize with one another. It was reimplemented using XLA-native collective communication to improve stability on large TPU pods.
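Conceptually, xm.rendezvous is a named synchronization point. Each replica blocks until all replicas reach it. As we understand its documented signature, it can also exchange a small byte payload. The sketch below models only that contract, with Python threads standing in for replicas and threading.Barrier providing the synchronization. ToyRendezvous is a hypothetical name, and the sketch says nothing about how PJRT implements rendezvous with XLA collectives.

```python
import threading
from typing import List


class ToyRendezvous:
    """Hypothetical sketch of rendezvous semantics using threads as stand-in replicas.

    Not how PJRT implements xm.rendezvous; it only shows the contract:
    nobody leaves until everyone has arrived, and everyone sees all payloads.
    """

    def __init__(self, world_size: int):
        self.payloads: List[bytes] = [b""] * world_size
        self.barrier = threading.Barrier(world_size)

    def rendezvous(self, rank: int, payload: bytes = b"") -> List[bytes]:
        self.payloads[rank] = payload
        self.barrier.wait()           # block until every rank has published its payload
        result = list(self.payloads)  # payload of rank i is at position i
        self.barrier.wait()           # keep the slots unchanged until every rank has copied them
        return result


if __name__ == "__main__":
    world_size = 4
    rv = ToyRendezvous(world_size)
    results: List[List[bytes]] = [[] for _ in range(world_size)]

    def replica(rank: int) -> None:
        results[rank] = rv.rendezvous(rank, f"rank{rank}".encode())

    threads = [threading.Thread(target=replica, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(results[0])  # [b'rank0', b'rank1', b'rank2', b'rank3']
```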

6. FSDP Trade-offs

Fully Sharded Data Parallel (FSDP) is a data parallelism technique for training large models on TPUs. In Distributed Data Parallel (DDP), every device holds a full copy of the model. FSDP instead shards model parameters across devices. It trades extra communication (parameters must be gathered before they are used) for a much smaller per-device memory footprint. FSDP is therefore used when a model is too large to fit on a single device.
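The memory-versus-communication trade-off can be estimated with simple arithmetic, under these assumptions:

- Adam optimizer state in FP32, giving 16 bytes per parameter for weights, gradients, and two moments.
- Activations are ignored.
- Ring-style collectives are used.
- Parameters are re-gathered for the backward pass, as happens when they are freed after the forward pass.

The 7-billion-parameter, 8-device figures are illustrative inputs, not measurements. The helper names are hypothetical.

```python
def per_device_state_bytes(num_params: int, world_size: int, sharded: bool,
                           bytes_per_param: int = 16) -> float:
    """Parameter + gradient + optimizer-state bytes held on each device.

    bytes_per_param=16 assumes FP32 weights (4) + FP32 gradients (4) + two FP32 Adam
    moments (8). Activations and temporarily gathered parameters are ignored.
    """
    total = num_params * bytes_per_param
    return total / world_size if sharded else float(total)


def per_step_comm_bytes(param_bytes: float, world_size: int, sharded: bool) -> float:
    """Approximate bytes each device sends per training step with ring collectives."""
    frac = (world_size - 1) / world_size
    if sharded:
        # FSDP: all-gather params for forward, all-gather again for backward, reduce-scatter grads
        return 3 * frac * param_bytes
    # DDP: one all-reduce of gradients (equivalent to reduce-scatter + all-gather)
    return 2 * frac * param_bytes


if __name__ == "__main__":
    n, world = 7_000_000_000, 8
    gb = 1e9
    print(f"DDP  state per device: {per_device_state_bytes(n, world, False) / gb:.0f} GB")  # 112 GB
    print(f"FSDP state per device: {per_device_state_bytes(n, world, True) / gb:.0f} GB")   # 14 GB
    ratio = per_step_comm_bytes(4 * n, world, True) / per_step_comm_bytes(4 * n, world, False)
    print(f"FSDP/DDP communication ratio: {ratio:.2f}")  # 1.50
```

Under these assumptions, FSDP cuts per-device model state by the number of devices while sending roughly 1.5 times as much data per step as DDP.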

7. TorchDynamo Backend Performance and Limitations

TorchDynamo is a Python-level JIT compiler. Its backend for PyTorch/XLA has shown significant performance improvements for specific model architectures and scenarios, with the strongest results in inference: models such as ResNet18, ResNet50, and BERT_pytorch have seen notable speedups. A key limitation for training is that TorchDynamo traces the forward and backward passes into separate graphs, which can be less efficient for PyTorch/XLA training.

8. Systolic Array Architecture

TPUs use systolic arrays, an architectural style that delivers high performance on matrix operations while minimizing memory access. Each TPU contains thousands of multiply-accumulators wired directly to their neighbors to form the array. Values pass from one unit to the next instead of being written back to memory between operations. The reviewed sources do not detail how data dependencies are handled.
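To make the dataflow concrete, the following sketch simulates a small systolic array cycle by cycle. It uses an output-stationary design: each cell keeps its own accumulator while A values flow right and B values flow down. This variant is the easiest to follow. TPU matrix units are usually described as weight-stationary, so this illustrates the principle rather than modeling TPU hardware.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def systolic_matmul(A: Matrix, B: Matrix) -> Tuple[Matrix, int]:
    """Cycle-by-cycle simulation of an output-stationary n x m systolic array computing A @ B.

    Row i of A enters from the left, delayed by i cycles; column j of B enters from the top,
    delayed by j cycles. Each processing element (PE) multiplies the two values passing
    through it, adds the product to its local accumulator, and forwards A right and B down.
    Only edge PEs read inputs; interior PEs receive data from neighbours, never from memory.
    """
    n, k, m = len(A), len(B), len(B[0])
    assert all(len(row) == k for row in A), "inner dimensions must match"
    acc = [[0.0] * m for _ in range(n)]     # one accumulator per PE (stays in place)
    a_reg = [[0.0] * m for _ in range(n)]   # A value held by each PE this cycle
    b_reg = [[0.0] * m for _ in range(n)]   # B value held by each PE this cycle
    cycles = n + m + k - 2                  # last product reaches PE (n-1, m-1) at cycle n+m+k-3
    for t in range(cycles):
        new_a = [[0.0] * m for _ in range(n)]
        new_b = [[0.0] * m for _ in range(n)]
        for i in range(n):
            for j in range(m):
                if j == 0:  # left edge: feed the skewed row i of A
                    kk = t - i
                    new_a[i][j] = A[i][kk] if 0 <= kk < k else 0.0
                else:       # interior: take what the left neighbour held last cycle
                    new_a[i][j] = a_reg[i][j - 1]
                if i == 0:  # top edge: feed the skewed column j of B
                    kk = t - j
                    new_b[i][j] = B[kk][j] if 0 <= kk < k else 0.0
                else:       # interior: take what the upper neighbour held last cycle
                    new_b[i][j] = b_reg[i - 1][j]
        a_reg, b_reg = new_a, new_b
        for i in range(n):
            for j in range(m):
                acc[i][j] += a_reg[i][j] * b_reg[i][j]  # multiply-accumulate in place
    return acc, cycles


if __name__ == "__main__":
    C, cycles = systolic_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]])
    print(C, "in", cycles, "cycles")  # [[19.0, 22.0], [43.0, 50.0]] in 4 cycles
```

The skewed input timing makes A[i][kk] and B[kk][j] arrive at cell (i, j) in the same cycle. This is why no cell ever has to fetch operands from memory itself.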

9. Bfloat16 Data Type

Using the bfloat16 data type on TPUs trades some accuracy for performance. Bfloat16 halves the memory footprint relative to FP32 and allows faster computation. The reviewed sources do not detail the accuracy cost of bfloat16 on TPUs or techniques for mitigating it. They do mention Automatic Mixed Precision (AMP), which leverages the TPU's native support for bfloat16 (2025-05-04T02:10:46.500679). They also note that all multiplies take bfloat16 inputs while all accumulations are performed in FP32 (2025-05-04T00:15:09.724337).
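Bfloat16 keeps FP32's 8-bit exponent, and therefore its range, but has only 7 explicit mantissa bits. That reduced precision is the source of its accuracy cost. The sketch below emulates bfloat16 rounding in pure Python using struct. It contrasts FP32 accumulation, as the sources describe for TPU multiplies, with a hypothetical bfloat16-only accumulation. The function names are illustrative, not a TPU or PyTorch/XLA API, and the emulation ignores NaN and infinity edge cases.

```python
import struct
from typing import Sequence


def to_fp32(x: float) -> float:
    """Round a Python float (FP64) to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bfloat16(x: float) -> float:
    """Round to the nearest bfloat16 value (ties to even) by keeping the top 16 bits of FP32.

    NaN and infinity handling is omitted to keep the sketch short.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1                      # lowest surviving bit, used for ties-to-even
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000   # round, then drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def dot_bf16_inputs(xs: Sequence[float], ys: Sequence[float], accumulate_in_bf16: bool) -> float:
    """Dot product with bfloat16 inputs, accumulated in (emulated) FP32 or, for contrast, bfloat16."""
    acc = 0.0
    for x, y in zip(xs, ys):
        prod = to_bfloat16(x) * to_bfloat16(y)  # exact: two 8-bit significands fit in FP32
        acc = to_bfloat16(acc + prod) if accumulate_in_bf16 else to_fp32(acc + prod)
    return acc


if __name__ == "__main__":
    print(to_bfloat16(3.14159265))             # 3.140625: only about 3 significant decimal digits
    ones = [1.0] * 1000
    print(dot_bf16_inputs(ones, ones, False))  # 1000.0 with FP32 accumulation
    print(dot_bf16_inputs(ones, ones, True))   # 256.0: adding 1.0 no longer changes a bf16 sum of 256
```

The example suggests why FP32 accumulation matters. With bfloat16 inputs, individual products are still exact, but a long sum accumulated in bfloat16 stalls once each addend falls to half a unit in the last place.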

10. Loop Fusion Criteria and Obstacles

Loop fusion is an optimization that improves performance by combining multiple loops into a single loop. Intermediate results then stay in registers or on-chip memory instead of being written to and re-read from main memory. The reviewed sources do not specify which loop operations are eligible for fusion, though they note that optimizing while_loop constructs enables loop fusion (2025-05-05T07:34:44.3571).
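The following sketch shows what fusion buys on a simple element-wise chain, relu(x * w + b). It illustrates the general technique, not XLA's fusion algorithm. Here "traffic" is a hypothetical count of element reads and writes to materialized lists, standing in for memory accesses.

```python
from typing import List, Tuple


def unfused(x: List[float], w: float, b: float) -> Tuple[List[float], int]:
    """relu(x * w + b) as three separate loops; each materializes a full intermediate."""
    traffic = 0
    t1 = [xi * w for xi in x]
    traffic += 2 * len(x)                  # read x, write t1
    t2 = [ti + b for ti in t1]
    traffic += 2 * len(x)                  # read t1, write t2
    y = [max(ti, 0.0) for ti in t2]
    traffic += 2 * len(x)                  # read t2, write y
    return y, traffic


def fused(x: List[float], w: float, b: float) -> Tuple[List[float], int]:
    """The same computation as one loop; intermediates never leave the loop body."""
    y = [max(xi * w + b, 0.0) for xi in x]
    return y, 2 * len(x)                   # read x, write y


if __name__ == "__main__":
    x = [-2.0, -0.5, 0.0, 1.5, 3.0]
    y1, t1 = unfused(x, 2.0, 1.0)
    y2, t2 = fused(x, 2.0, 1.0)
    assert y1 == y2                        # identical results, same operations in the same order
    print(y2, f"memory traffic: {t1} vs {t2} element accesses")  # 30 vs 10
```

The result is unchanged. What fusion removes is the two intermediate arrays and the reads and writes they cost.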

11. Experimental Quantized Operations for XLA

The reviewed sources give no detail on which quantization schemes are supported or how they affect performance or accuracy. They state only that "Quantized operations in XLA provide a high-level abstraction for quantized operations" (2025-05-04T02:04:24.641502).
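The sources do not say which schemes XLA's quantized operations support. The following is therefore only a generic illustration of one common scheme, symmetric per-tensor int8 quantization. It also shows why quantization saves work: multiply-accumulates run on small integers, with a single floating-point rescale at the end. None of these functions correspond to a PyTorch/XLA API.

```python
from typing import List, Tuple


def quantize_int8(values: List[float]) -> Tuple[List[int], float]:
    """Symmetric per-tensor int8 quantization: q = round(v / scale), scale = max|v| / 127."""
    max_abs = max((abs(v) for v in values), default=0.0)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return q, scale


def dequantize(q: List[int], scale: float) -> List[float]:
    """Map int8 codes back to approximate real values."""
    return [qi * scale for qi in q]


def quantized_dot(a: List[float], b: List[float]) -> float:
    """Integer multiply-accumulate, then one rescale back to floating point."""
    qa, sa = quantize_int8(a)
    qb, sb = quantize_int8(b)
    acc = sum(x * y for x, y in zip(qa, qb))  # hardware would use an int32 accumulator here
    return acc * sa * sb


if __name__ == "__main__":
    a = [0.12, -0.5, 0.33, 0.9]
    b = [1.0, 0.25, -0.75, 0.4]
    exact = sum(x * y for x, y in zip(a, b))
    print(f"float dot: {exact:.5f}  int8 dot: {quantized_dot(a, b):.5f}")
```

Each dequantized value is within half a quantization step (scale / 2) of the original. This rounding error is the accuracy cost that any quantization scheme has to manage.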

12. Temporal Breakdown of Key Findings

  • PJRT runtime document indicates a >35% improvement in training time on TPU v4 using TorchBench 2.0.
  • TPUs are compared to CPUs and GPUs in terms of architecture and suitability for machine learning workloads.
  • TorchDynamo has shown significant speedups for inference with models like resnet18, resnet50, BERT_pytorch, and mobilenet_v2.
  • FSDP shards Module parameters across data-parallel workers, addressing memory limitations.
  • Functions passed to scan and scan_layers must be traceable by AOTAutograd, which can limit compatibility with custom Pallas kernels.
  • TPUs leverage the bfloat16 data type to enhance performance.
  • XLA performs optimizations like operator fusion, layout assignment, and algebraic simplification.
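For readers unfamiliar with scan, the following pure-Python sketch gives its reference semantics. It uses the (carry, x) -> (carry, y) convention popularized by jax.lax.scan. We assume PyTorch/XLA's scan follows the same convention, but this is not its implementation. As we understand it, the real function traces the loop body once so it is compiled once instead of being unrolled, which is why that body must be traceable by AOTAutograd. apply_layers_with_scan is a hypothetical helper showing the scan_layers idea: a stack of identical layers that differ only in their weights.

```python
from typing import Callable, List, Optional, Tuple, TypeVar

C = TypeVar("C")  # carry (loop state)
X = TypeVar("X")  # per-iteration input
Y = TypeVar("Y")  # per-iteration output


def scan(fn: Callable[[C, X], Tuple[C, Y]], init: C, xs: List[X]) -> Tuple[C, List[Y]]:
    """Reference semantics: thread a carry through fn while collecting per-step outputs."""
    carry = init
    ys: List[Y] = []
    for x in xs:
        carry, y = fn(carry, x)
        ys.append(y)
    return carry, ys


def apply_layers_with_scan(layer: Callable[[float, float], float],
                           weights: List[float], h: float) -> float:
    """Hypothetical helper: run one layer body per weight, threading the hidden state h."""

    def step(carry: float, w: float) -> Tuple[float, Optional[float]]:
        return layer(carry, w), None  # same body every iteration; only the weight changes

    final, _ = scan(step, h, weights)
    return final


if __name__ == "__main__":
    print(scan(lambda c, x: (c + x, c + x), 0, [1, 2, 3, 4]))                    # (10, [1, 3, 6, 10])
    print(apply_layers_with_scan(lambda h, w: h * w + 1.0, [1.0, 2.0, 3.0], 2.0))  # 22.0
```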

13. Comprehensive List of Terms with Definitions

  • TPU (Tensor Processing Unit): A custom hardware accelerator designed by Google specifically for machine learning workloads.
  • XLA (Accelerated Linear Algebra): A compiler that optimizes computations defined in frameworks like PyTorch for specific hardware, such as TPUs.
  • Systolic Array: A grid of interconnected processors that can perform thousands of multiply-accumulate operations simultaneously and efficiently, minimizing memory access.
  • PJRT (Plugin JRT Runtime): A runtime environment for executing XLA-compiled code on TPUs and other accelerators, offering performance improvements and simpler configuration compared to XRT.
  • XRT: The older TensorFlow-based runtime that PJRT replaced.
  • FSDP (Fully Sharded Data Parallel): A data parallelism technique for sharding model parameters across multiple devices, enabling the training of very large models.
  • DDP (Distributed Data Parallel): A data parallelism technique where each device has a full copy of the model.
  • TorchDynamo: A Python-level JIT compiler designed to make unmodified PyTorch programs faster.
  • Bfloat16: A 16-bit floating-point data type with 1 sign bit, an 8-bit exponent (the same range as FP32), and a 7-bit mantissa. It trades precision for performance and is natively supported on TPUs.
  • Loop Fusion: An optimization technique that combines multiple loops into a single loop to improve performance.
  • AMP (Automatic Mixed Precision): A technique that leverages lower precision data types, such as bfloat16, to improve performance while maintaining accuracy.
  • HBM (High Bandwidth Memory): A type of memory used in TPUs to provide high bandwidth for data access.
  • MXU (Matrix Multiply Unit): A specialized unit within TPUs for performing matrix multiplications.
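  • AOTAutograd (Ahead-of-Time Autograd): A PyTorch component that traces the forward and backward passes ahead of time into graphs that a compiler backend can optimize.
  • Pallas: A kernel language, originating in JAX, for writing custom kernels for TPUs and GPUs. PyTorch/XLA can call Pallas kernels.
  • Scan / Scan Layers: PyTorch/XLA loop primitives. Scan applies a function repeatedly over its inputs while carrying state. Scan layers applies a sequence of identical layers as such a loop, so the layer body is compiled once rather than unrolled.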

14. Outstanding Questions and Future Research

Despite the progress made, several questions remain unanswered and warrant further research:

  • How do the evolution of PJRT and the adoption of TorchDynamo impact the need for custom kernels written with Pallas?
  • How can FSDP be combined with XLA's experimental quantized operations to train extremely large models with a smaller memory footprint and lower computational cost?
  • How does the shift from XRT to PJRT affect the ease of debugging and profiling PyTorch/XLA applications on TPUs, especially when using custom kernels or advanced features like FSDP?
  • What specific algorithms are used for operator fusion?
  • How does XLA determine the optimal layout assignment for TPU memory?
  • What algebraic simplification techniques are used by XLA?
  • What types of dynamic factors cause graph recompilations?
  • How can developers minimize unintended graph recompilations?
  • What are the memory bandwidth limitations that affect XLA's optimizations?
  • What is the overhead associated with XLA compilation?
  • How does XLA handle control flow operations (e.g., loops, conditionals)?
  • What are the limitations of XLA's support for custom kernels?
  • How does XLA interact with the PJRT runtime to manage TPU execution?

Addressing these questions will provide a more complete understanding of the PyTorch/XLA ecosystem on TPUs and guide future optimization efforts.

15. Conclusion

This report provides a comprehensive overview of the key components, optimization techniques, and trade-offs involved in leveraging TPUs with PyTorch/XLA for accelerating deep learning workloads. The findings are based on a review of documentation, code examples, and performance benchmarks, as well as analysis of the underlying hardware and software architectures. Further research is needed to fully grasp the interactions and complexities within the PyTorch/XLA ecosystem.
