Date: 2025-11-19

JAX has become a key framework for developing state-of-the-art foundation models across the AI landscape, and not just at Google. Leading LLM providers such as Anthropic, xAI, and Apple use the open-source JAX framework as one of the tools to build their foundation models.

Today, we are excited to share an overview of the JAX AI Stack: a robust, end-to-end platform that extends JAX, the core numerical library, into an industrial-grade solution for machine learning at any scale.

To showcase the power and design of this ecosystem, we have published a detailed technical report explaining every component. We urge developers, researchers, and infrastructure engineers to read the full report to understand how these tools can serve their specific needs.

Below, we outline the architectural philosophy and key components that form a robust and flexible platform for modern AI.

The Architectural Imperative: Modularity and Performance

The JAX AI Stack is built on a philosophy of modular, loosely coupled components, where each library is designed to excel at a single task. This approach lets users build a bespoke ML stack, selecting and combining the best libraries for optimization, data loading, or checkpointing to fit their requirements precisely. This modularity is vital in the rapidly evolving field of AI: new libraries and techniques can be developed and integrated without the risk and overhead of modifying a large, monolithic framework.

A modern ML stack must provide a continuum of abstraction: automated high-level optimizations for speed of development, and fine-grained, manual control for when every microsecond counts. The JAX AI Stack is designed to offer this continuum.

The Core “JAX AI Stack”

At the heart of the JAX ecosystem is the “JAX AI Stack”, four key libraries that provide the foundation for model development, all built on the compiler-first design of JAX and XLA.

JAX: The foundation for accelerator-oriented array computation. Its pure functional programming model makes transformations composable, allowing workloads to scale effectively across hardware types and cluster sizes.

Flax: While JAX provides the functional core, many developers prefer an object-oriented approach for neural networks. Flax bridges this gap, offering a flexible, intuitive API for model authoring and "surgery," familiar to users coming from other frameworks, without sacrificing JAX's performance.

Optax: Optimization is critical, and one size does not fit all. Optax provides a library of composable gradient processing and optimization transformations. It allows researchers to declaratively chain standard optimizers (like Adam) with complex techniques like gradient clipping or accumulation in just a few lines of code, rather than manually managing state in a training loop.

Orbax: Resilience at scale is critical. Orbax is our "any-scale" checkpointing library. It supports asynchronous distributed checkpointing, ensuring that expensive training runs can withstand hardware failures without losing significant progress. It is designed for extreme scales and is currently used in training runs spanning tens of thousands of nodes.