JAX has become a key framework for developing state-of-the-art foundation models across the AI landscape, and not just at Google. Leading LLM providers such as Anthropic, xAI, and Apple use the open-source JAX framework as one of the tools for building their foundation models.
Today, we are excited to share an overview of the JAX AI Stack: a robust, end-to-end platform that extends JAX, the core numerical library, into an industrial-grade solution for machine learning at any scale.
To showcase the power and design of this ecosystem, we have published a detailed technical report explaining every component. We encourage developers, researchers, and infrastructure engineers to read the full report to understand how these tools can be leveraged for their specific needs.
Below, we outline the architectural philosophy and key components that form a robust and flexible platform for modern AI.
The JAX AI Stack is built on a philosophy of modular, loosely coupled components, where each library is designed to excel at a single task. This approach empowers users to build a bespoke ML stack, selecting and combining the best libraries for optimization, data loading, or checkpointing to precisely fit their requirements. Crucially, this modularity is vital in the rapidly evolving field of AI. It allows for rapid innovation, as new libraries and techniques can be developed and integrated without the risk and overhead of modifying a large, monolithic framework.
A modern ML stack must provide a continuum of abstraction: automated high-level optimizations for speed of development, and fine-grained, manual control for when every microsecond counts. The JAX AI Stack is designed to offer this continuum.
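As a minimal sketch of this continuum in JAX itself: `jax.jit` hands an entire function to the XLA compiler for automatic optimization, while the same transform exposes knobs such as `donate_argnums` for manual control over buffer reuse when memory and latency matter. (The function names here are illustrative, not part of any published API beyond `jax.jit` itself.)

```python
import jax
import jax.numpy as jnp

# High-level, automated path: jit hands the whole function to XLA,
# which fuses the matmul and tanh without any manual tuning.
@jax.jit
def predict(w, x):
    return jnp.tanh(x @ w)

# Fine-grained path: the same transform accepts knobs such as
# donate_argnums, letting the compiler reuse the input buffer `w`
# for the output instead of allocating a fresh one.
update = jax.jit(lambda w, g: w - 0.1 * g, donate_argnums=0)

w = jnp.ones((4, 4))
x = jnp.ones((2, 4))
y = predict(w, x)
w2 = update(w, jnp.zeros_like(w))
print(y.shape)  # (2, 4)
```

Note that on backends without donation support (e.g. CPU), `donate_argnums` is silently ignored; the point is that the fast path and the tuned path share one programming model.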
At the heart of the JAX ecosystem is the “JAX AI Stack”, consisting of four key libraries that provide the foundation for model development, all built on the compiler-first design of JAX and XLA.
The jax-ai-stack is a metapackage that can be installed with the following command:
pip install jax-ai-stack
Building on this stable core, a rich ecosystem of specialized libraries provides the end-to-end capabilities needed for the entire ML lifecycle.
Beneath the user-facing libraries lies the infrastructure that enables JAX to scale seamlessly from a single TPU or GPU to thousands of devices.
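The key idea is that the same program text runs regardless of device count: you describe how arrays are laid out over a device mesh, and the compiler handles the communication. A minimal sketch using JAX's sharding API (here the mesh happens to contain whatever devices the host has, e.g. one CPU; in production it would span many accelerators):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a mesh over all available devices -- one CPU locally,
# thousands of TPUs/GPUs in production. The code does not change.
devices = np.array(jax.devices())
mesh = Mesh(devices, ("data",))

# Place an array sharded along its first axis across the "data" mesh axis.
x = jax.device_put(jnp.arange(8.0), NamedSharding(mesh, P("data")))

@jax.jit
def double(v):
    # The compiled computation runs wherever the shards live;
    # XLA inserts any needed communication automatically.
    return v * 2.0

y = double(x)
```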
To achieve the highest levels of hardware utilization, the ecosystem provides specialized tools that offer deeper control and higher efficiency.
Other modules that augment the JAX AI Stack offer a mature, end-to-end application layer that bridges the gap from research to widespread deployment.
The JAX AI Stack is more than just a collection of libraries; it is a modular, production-ready platform, co-designed with Cloud TPUs to tackle the next generation of AI challenges. This deep integration of software and hardware delivers a compelling advantage in both performance and total cost of ownership, as seen across a diverse range of applications. For large-scale production models, Kakao leveraged the stack to overcome infrastructure limits, achieving a 2.7x throughput increase for their LLMs while optimizing for cost-performance. For cutting-edge generative video models, Lightricks broke through a critical scaling wall with their 13-billion-parameter video model, unlocking linear scalability and accelerating research in ways their previous framework could not. And for pioneering scientific research, Escalante harnesses JAX’s unique composability to combine a dozen models into a single optimization, achieving 3.65x better performance per dollar for their AI-driven protein design. These examples show how the co-designed JAX and TPU stack provides a powerful, efficient, and flexible foundation for building the future of AI, from production-scale LLMs to the frontiers of scientific discovery.
We invite you to explore the ecosystem deeply, read the technical report to see how these components can work for you, and visit our new central hub to get started at https://jaxstack.ai
There, you will find everything you need to start building with the JAX AI Stack: