Building production AI on Google Cloud TPUs with JAX

November 19, 2025
Rakesh Iyer, Senior Software Engineering Manager, Google ML Frameworks
Srikanth Kilaru, Senior Product Manager, Google ML Frameworks

JAX has become a key framework for developing state-of-the-art foundation models across the AI landscape, and not just at Google. Leading LLM providers such as Anthropic, xAI, and Apple are utilizing the open-source JAX framework as one of the tools to build their foundation models.

Today, we are excited to share an overview of the JAX AI Stack: a robust, end-to-end platform that extends JAX, the core numerical library, into an industrial-grade solution for machine learning at any scale.

To showcase the power and design of this ecosystem, we have published a detailed technical report explaining every component. We encourage developers, researchers, and infrastructure engineers to read the full report to see how these tools can be applied to their specific needs.

Below, we outline the architectural philosophy and key components that form a robust and flexible platform for modern AI.

The Architectural Imperative: Modularity and Performance

The JAX AI Stack is built on a philosophy of modular, loosely coupled components, where each library is designed to excel at a single task. This approach lets users build a bespoke ML stack, selecting and combining the best libraries for optimization, data loading, or checkpointing to fit their requirements precisely. This modularity is vital in the rapidly evolving field of AI: it enables rapid innovation, because new libraries and techniques can be developed and integrated without the risk and overhead of modifying a large, monolithic framework.

A modern ML stack must provide a continuum of abstraction: automated high-level optimizations for speed of development, and fine-grained, manual control for when every microsecond counts. The JAX AI Stack is designed to offer this continuum.

The Core “JAX AI Stack”

At the heart of the JAX ecosystem is the “JAX AI Stack,” four key libraries that provide the foundation for model development, all built on the compiler-first design of JAX and XLA.

  • JAX: The foundation for accelerator-oriented array computation. Its pure functional programming model makes transformations composable, allowing workloads to scale effectively across hardware types and cluster sizes.
  • Flax: While JAX provides the functional core, many developers prefer an object-oriented approach for neural networks. Flax bridges this gap, offering a flexible, intuitive API for model authoring and "surgery," familiar to users coming from other frameworks, without sacrificing JAX's performance.
  • Optax: Optimization is critical, and one size does not fit all. Optax provides a library of composable gradient processing and optimization transformations. It allows researchers to declaratively chain standard optimizers (like Adam) with complex techniques like gradient clipping or accumulation in just a few lines of code, rather than manually managing state in a training loop.
  • Orbax: Resilience at scale is critical. Orbax is our "any-scale" checkpointing library. It supports asynchronous distributed checkpointing, ensuring that expensive training runs can withstand hardware failures without losing significant progress. It is designed for resilience at extreme scales and is currently used in training runs spanning tens of thousands of nodes.

The jax-ai-stack is a metapackage that can be installed with the following command:
pip install jax-ai-stack

The JAX AI Stack and Ecosystem Components

The Extended JAX AI Stack

Building on this stable core, a rich ecosystem of specialized libraries provides the end-to-end capabilities needed for the entire ML lifecycle.

Industrial-Scale Infrastructure

Beneath the user-facing libraries lies the infrastructure that lets JAX scale seamlessly from a single TPU or GPU to thousands of accelerators.

  • XLA (Accelerated Linear Algebra): Our domain-specific, hardware-agnostic compiler. Unlike kernel-centric approaches that wait for hand-optimized libraries to catch up to new research, XLA aims to deliver strong out-of-the-box performance by using whole-program analysis to fuse operators and optimize memory layouts. This compiler-centric approach can often provide a high-performance path for new model architectures without the need for hand-written kernels.
  • Pathways: The unified runtime for massive-scale distributed computation. It allows researchers to write code as if they were using a single powerful machine, while Pathways orchestrates the computation across tens of thousands of chips and handles fault tolerance and automatic recovery.
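To see the compiler-first approach in practice, the sketch below lowers and compiles a small matmul-plus-activation function ahead of time and retrieves XLA's optimized HLO as text, where the elementwise ops after the matmul are typically fused. It then shards the batch across a one-dimensional device mesh with JAX's standard `jax.sharding` APIs. This is ordinary JAX sharding, not a Pathways-specific API; it only illustrates the write-one-program style described above. The shapes and the mesh axis name `"data"` are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def mlp_block(x, w):
    # A matmul followed by elementwise ops: a typical candidate for XLA fusion.
    return jax.nn.gelu(x @ w) * 0.5 + 1.0

n_devices = len(jax.devices())
batch = 8 * n_devices  # keep the batch divisible by the number of devices
x = jnp.ones((batch, 128), dtype=jnp.float32)
w = jnp.ones((128, 256), dtype=jnp.float32)

# Lower and compile ahead of time, then get the optimized HLO text XLA produced.
compiled = jax.jit(mlp_block).lower(x, w).compile()
hlo_text = compiled.as_text()

# One program, many devices: a 1-D mesh named "data" over all available devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
# Split the rows of x across the "data" axis; columns are not split.
x_sharded = jax.device_put(x, NamedSharding(mesh, P("data", None)))
# w has no explicit sharding; jit places it as needed and partitions the work.
y = jax.jit(mlp_block)(x_sharded, w)
```

The same `mlp_block` code runs unchanged on one CPU or on many accelerators; only the input sharding changes.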

Advanced Development for Peak Efficiency

To achieve the highest levels of hardware utilization, the ecosystem provides specialized tools that offer deeper control and higher efficiency.

  • Pallas & Tokamax: When you need to surpass automated compilers, Pallas offers an extension to JAX for writing custom kernels for TPUs and GPUs with precise control over memory hierarchy and parallelism. Tokamax complements this as a curated library of state-of-the-art kernels (like FlashAttention), giving you plug-and-play access to peak performance.
  • Qwix: As models grow, quantization becomes essential. Qwix is our comprehensive, non-intrusive quantization library. It lets you apply techniques such as QLoRA or post-training quantization (PTQ) by intercepting JAX functions, so models can be quantized with minimal or no changes to the original model code.
  • Grain: Data pipelines can often become a bottleneck. Grain is a performant, deterministic data loading library. Crucially, it integrates with Orbax to allow the exact state of the data pipeline to be checkpointed alongside the model, guaranteeing bit-for-bit reproducibility even after restarting a massive training job.
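For a sense of what Pallas code looks like, here is a minimal element-wise addition kernel in the style of the Pallas quickstart. The kernel reads and writes through refs, and `interpret=True` runs it in interpreter mode so the sketch also works on a CPU-only machine. On real TPUs or GPUs you would drop that flag and would typically add a grid and block specs to tile larger arrays through on-chip memory; this example keeps everything as a single block for simplicity.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs refer to the kernel's input/output buffers; indexing with [...]
    # loads the whole block, and assigning to o_ref[...] stores the result.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # interpreter mode: runs anywhere, including CPU
    )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.full((8,), 10.0, dtype=jnp.float32)
z = add(x, y)
```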

The Full Path to Production

Other modules that augment the JAX AI Stack offer a mature, end-to-end application layer that bridges the gap from research to widespread deployment.

  • MaxText & MaxDiffusion: These are our flagship, scalable frameworks for LLM and diffusion model training. They serve as well-established and reliable starting points for builders, highly optimized for goodput and Model FLOPs Utilization (MFU) out of the box.
  • Tunix: Once pre-trained, models need alignment. Tunix is our JAX-native library for post-training, offering state-of-the-art algorithms like SFT with LoRA/QLoRA, GRPO, GSPO, DPO, and PPO in a streamlined package. MaxText integration with Tunix provides the most performant and scalable post-training for Google Cloud customers.
  • Inference Solutions: We offer a dual path for deployment. For maximum compatibility, we support the popular vLLM serving framework for any model.
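To illustrate the LoRA technique named in the Tunix entry, the sketch below implements a low-rank adapter in plain JAX: the pretrained weight stays frozen and only two small matrices receive gradients. This is a generic illustration of LoRA, not Tunix's API; the function name `lora_linear`, the dimensions, the rank, and the `alpha` scaling value are hypothetical choices for the example.

```python
import jax
import jax.numpy as jnp

def lora_linear(frozen_w, lora_a, lora_b, x, alpha=16.0):
    # Hypothetical helper: frozen base projection plus a scaled low-rank update.
    rank = lora_a.shape[1]
    return x @ frozen_w + (x @ lora_a @ lora_b) * (alpha / rank)

key = jax.random.PRNGKey(0)
k_w, k_a, k_x = jax.random.split(key, 3)
d_in, d_out, rank = 64, 32, 4

frozen_w = jax.random.normal(k_w, (d_in, d_out))  # pretrained, not trained here
lora = {
    "a": jax.random.normal(k_a, (d_in, rank)) * 0.01,
    "b": jnp.zeros((rank, d_out)),  # zero init: the adapter starts as a no-op
}
x = jax.random.normal(k_x, (8, d_in))
target = jnp.zeros((8, d_out))

def loss(lora_params):
    y = lora_linear(frozen_w, lora_params["a"], lora_params["b"], x)
    return jnp.mean((y - target) ** 2)

# Differentiate only with respect to the adapter parameters.
grads = jax.grad(loss)(lora)
n_trainable = sum(p.size for p in jax.tree_util.tree_leaves(lora))
```

With `b` initialized to zeros, the adapted layer initially reproduces the base weights exactly, and on the first step only `b` receives a nonzero gradient. Here the adapter has 384 trainable values versus 2,048 in the frozen matrix.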

Read the Report, Explore the Stack

The JAX AI Stack is more than just a collection of libraries; it is a modular, production-ready platform, co-designed with Cloud TPUs to tackle the next generation of AI challenges. This deep integration of software and hardware delivers a compelling advantage in both performance and total cost of ownership, as seen across a diverse range of applications.

For large-scale production models, Kakao leveraged the stack to overcome infrastructure limits, achieving a 2.7x throughput increase for their LLMs while optimizing for cost-performance. For cutting-edge generative video models, Lightricks broke through a critical scaling wall with their 13-billion-parameter video model, unlocking linear scalability and accelerating research in ways their previous framework could not. And for pioneering scientific research, Escalante harnesses JAX’s unique composability to combine a dozen models into a single optimization, achieving 3.65x better performance per dollar for their AI-driven protein design.

These examples show how the co-designed JAX and TPU stack provides a powerful, efficient, and flexible foundation for building the future of AI, from production-scale LLMs to the frontiers of scientific discovery.

We invite you to explore the ecosystem deeply, read the technical report to see how these components can work for you, and visit our new central hub to get started at https://jaxstack.ai

There, you will find everything you need to start building with the JAX AI Stack: