Stanford’s Marin foundation model: The first fully open model developed using JAX

JULY 16, 2025
Srikanth Kilaru, Senior Product Manager, Google ML Frameworks
David Hall, Research Engineering Lead, Stanford HAI

An exciting element of the current AI era is how powerful foundation models are being shared in the open and helping to accelerate innovation for everyone. This progress inspires us to ask, 'What's next for openness?' The Marin project sees an opportunity to expand the definition of 'open' to encompass the entire scientific process behind a model.

Stanford CRFM’s (Center for Research on Foundation Models) Marin project is designed as an 'open lab,' where the goal is not only to share the model but to make the complete journey accessible — including the code, data set, data methodologies, experiments, hyperparameters and training logs. This level of transparency complements the existing ecosystem by providing a unique, fully reproducible resource that empowers researchers to scrutinize, build upon, and trust the models being developed. Stanford’s Marin project seeks to foster a more transparent and accessible future for foundation model research.


The spectrum of AI model openness


The first releases from this open lab are the Marin-8B-Base and Marin-8B-Instruct models. In keeping with the project's principles, the models, data, code, and tokenizer are all released under the permissive Apache 2.0 license. This commitment to complete reproducibility is a formidable engineering problem, requiring control over every source of variance in a massively distributed system. The project's success hinges on a technology stack that can deliver this guarantee of reproducibility at scale while maximizing efficiency, so that a foundation model can be trained with leading price/performance.


Core challenges of building open foundation models

For the Marin project to succeed in creating truly open, scalable, and reproducible foundation models, the CRFM team had to solve several engineering challenges. The team chose JAX as the foundation because its design principles provided direct solutions to these problems, and it built a new framework, Levanter (see below), to harness JAX's power. Here are a few examples of the challenges and their solutions:


Achieving maximum speed on a single accelerator

Problem: The core training loop is executed billions of times, so the overhead from an interpreted language like Python creates a massive performance bottleneck. If operations are dispatched step by step, the loop can also incur excessive memory traffic and overhead—especially on hardware like TPUs, where throughput depends on executing fused operations efficiently.

Our solution:

  • To eliminate interpreter overhead, Levanter encapsulates the entire multi-stage training step (forward pass, loss, backpropagation, and update) into a single function and uses the @jax.jit decorator. JAX's XLA compiler transforms this entire process into a single, highly optimized machine code kernel, fusing operations to maximize hardware utilization at scale.

  • To avoid redundant computation, we use jax.value_and_grad to compute both the loss and its gradients in a single pass (see the sketch after this list). JAX also makes it easy to use advanced techniques like gradient checkpointing, saving memory and enabling us to use larger batch sizes with almost no overhead.

  • Levanter also uses JAX's powerful Pallas-based Splash Attention kernel, a highly optimized implementation of dot-product attention, one of the most critical operations at the heart of nearly all large language models.
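
To make this concrete, here is a minimal sketch of a fused training step. The toy linear model, loss_fn, train_step, and plain SGD update are illustrative assumptions, not Levanter's actual code; the point is how @jax.jit and jax.value_and_grad combine the whole step into one compiled program.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Toy loss: mean squared error of a linear model (illustrative only).
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)

@jax.jit
def train_step(params, batch, lr=1e-3):
    # One fused step: forward pass, backward pass, and update are compiled
    # by XLA into a single optimized program.
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)  # loss + grads in one pass
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)  # plain SGD
    return new_params, loss

params = {"w": jnp.zeros((4, 2))}
batch = {"x": jnp.ones((8, 4)), "y": jnp.zeros((8, 2))}
params, loss = train_step(params, batch)
```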


Managing the complexity of large-scale parallelism

Problem: Training state-of-the-art models requires scaling out to thousands of accelerator chips. Manually managing how the model and data are partitioned and how the devices communicate is immensely complex, and the code quickly becomes difficult to read, debug, and adapt.

Our solution:

  • JAX’s @jax.jit decorator also seamlessly supports Single-Program, Multiple-Data (SPMD) parallelization, which automates the underlying data sharding and communication. The XLA compiler automatically schedules communication between accelerators to minimize time spent waiting on the network and maximize time spent on computation (see the sketch after this list).

  • To make jit’s power even easier and safer to use, the Levanter team developed Haliax, a library for named tensors. Referring to tensor axes by human-readable names (like "embed" or "batch") instead of positional indices makes the code self-documenting and robust.

  • This abstraction allows us to define and modify sophisticated sharding strategies like Fully Sharded Data Parallelism (FSDP) and Tensor Parallelism simply by changing a few lines in a configuration file, without ever touching the model code.
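
For illustration, here is a minimal plain-JAX sketch of SPMD sharding over a named device mesh. It uses core JAX APIs directly rather than Haliax's named-tensor layer, and it assumes 8 available devices; the axis names and array shapes are illustrative only.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange 8 devices into a 2x4 mesh with human-readable axis names.
devices = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, axis_names=("data", "model"))

# Shard activations along the "data" axis and weights along the "model" axis.
x = jax.device_put(jnp.ones((32, 512)), NamedSharding(mesh, P("data", None)))
w = jax.device_put(jnp.ones((512, 2048)), NamedSharding(mesh, P(None, "model")))

@jax.jit
def forward(x, w):
    # jit compiles a single SPMD program; XLA inserts any needed communication.
    return x @ w

y = forward(x, w)  # the result stays sharded across both mesh axes
```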


Building and managing resilient, cost-effective compute clusters

Problem: Large-scale training requires flexible access to massive compute clusters. We rely heavily on preemptible TPU instances to manage costs, which means we need a way to easily combine many smaller, disparate TPU slices into one logical cluster and be resilient to frequent interruptions.

Our solution:

  • We leverage Google Cloud TPU Multislice, a technology that allows a training job to use multiple TPU slices as if they were one large system. This makes it easy to stitch many small, preemptible TPU slices together into a single, powerful compute cluster for training.

  • Levanter uses Ray to orchestrate this process, seamlessly scaling the number of TPU slices up or down during a training job and, crucially, ensuring the job remains resilient if any single slice is preempted (a rough sketch follows this list).

  • Thanks to JAX and XLA, Levanter and Marin achieve similarly high performance on GPUs as well.
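
As a rough, hypothetical illustration of preemption-tolerant orchestration with Ray: the task body, checkpoint path, and slice count below are placeholders, and Levanter's actual Ray integration is considerably more involved. The sketch only shows Ray's built-in task retries, which re-run work when a worker is lost.

```python
import ray

ray.init()  # connect to (or start) the Ray cluster spanning the TPU slices

@ray.remote(max_retries=5)  # Ray re-schedules the task if its worker is preempted
def run_on_slice(slice_id: int, checkpoint_dir: str) -> str:
    # Hypothetical task body: in a real setup this would launch the JAX training
    # process for one TPU slice, resuming from the latest checkpoint.
    return f"slice {slice_id} resumed from {checkpoint_dir}"

futures = [run_on_slice.remote(i, "gs://my-bucket/checkpoints") for i in range(4)]
print(ray.get(futures))
```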


Fostering scientific trust with perfect reproducibility

Problem: A core goal of the Marin project is to enable verifiable science. This requires achieving reproducible results, even when training is paused, restarted, or moved between different machine configurations—a significant technical hurdle.

Our solution:

  • This was a fundamental requirement that drove Levanter's design. We chose JAX specifically for its strong reproducibility guarantees, such as its default use of deterministic pseudo-random number generators (PRNGs); a small example follows this list.

  • This choice was validated during the training of Marin-8B, which involved migrating between different TPU slices and hardware types while successfully maintaining bit-for-bit reproducibility across preemptions.

  • Levanter also includes a robust data loading system built on Google’s TensorStore library. Levanter’s data store offers deterministic, random access to any batch of training data, regardless of job restarts or data source changes, which is critical for supporting advanced training strategies like mid-training. JAX’s determinism and Levanter’s data store also make it easy for interpretability researchers to understand how specific data impacts the model during training.
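
As a small illustration of JAX's explicit, deterministic PRNG handling (the seed, shapes, and keep-probability below are arbitrary):

```python
import jax

# The same seed always yields the same stream, on any host and across restarts.
key = jax.random.PRNGKey(0)
key, init_key, dropout_key = jax.random.split(key, 3)

w = jax.random.normal(init_key, (4, 4))                # identical on every run
mask = jax.random.bernoulli(dropout_key, 0.9, (4, 4))  # and so is every later draw
```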


Creating a cohesive framework

Problem: While JAX provides a powerful engine, no existing high-level framework met our stringent, combined requirements for legibility, massive scalability, and bitwise determinism. We needed a complete, opinionated system to orchestrate the entire training process.

Our solution:

  • We built Levanter, a JAX-native framework, from the ground up to be the system we needed: bitwise deterministic, scalable with advanced distribution strategies, and resilient.

  • We could do this because JAX is more than just a library; it's a "meta-framework" for building new tools. We built upon its mature, high-performance support for TPUs and its seamless integration of high-level abstractions (jit) with low-level control (Pallas).

  • This approach is common in the JAX community, which has produced a vibrant ecosystem of libraries like Flax, Equinox, Orbax, and Optax that work together, allowing teams like ours to build powerful solutions (a small example of this composability follows this list).
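
For example, here is a small sketch of that composability using Optax together with core JAX; the toy parameters, batch, and loss are illustrative assumptions, not Marin or Levanter code.

```python
import jax
import jax.numpy as jnp
import optax

# Toy parameters, batch, and loss, for illustration only.
params = {"w": jnp.zeros((4, 2))}
batch = {"x": jnp.ones((8, 4)), "y": jnp.zeros((8, 2))}

def loss_fn(p, b):
    return jnp.mean((b["x"] @ p["w"] - b["y"]) ** 2)

# Optax supplies the optimizer; JAX supplies the gradients.
optimizer = optax.adamw(learning_rate=3e-4, weight_decay=0.1)
opt_state = optimizer.init(params)

loss, grads = jax.value_and_grad(loss_fn)(params, batch)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```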


A look under the hood: The voyage of Marin-8B

The principles, tools, and libraries discussed above were put to work during the Marin-8B training run. The model architecture is a Llama-style transformer.


Marin-8B-Base: Model architecture at a glance


Rather than a static, monolithic run, the training of Marin-8B was an adaptive journey, internally dubbed the "Tootsie" process. This honest portrayal of a real-world research workflow is documented publicly. The process spanned over 12 trillion tokens and involved multiple phases that adapted to new data, techniques, and even different hardware configurations, migrating between large-scale, multi-slice TPU configurations (from 2x v5e-256 to 1x v4-2048 pods) mid-stream. The team continuously refined the data mixture, incorporating higher-quality sources, and adjusted hyperparameters such as learning rate and batch size to optimize performance. This "messy" reality is a valuable educational tool, and the ability of the JAX and Levanter stack to handle these significant shifts while maintaining bit-for-bit reproducibility is a compelling demonstration of its robustness.


Join the Marin community

The Marin project is an open invitation to participate in the future of foundation model development and contribute to the JAX ecosystem. The journey of Marin represents the answer to our question, "What's next for openness?" This effort to create an 'open lab' is made possible by the technical capabilities of the JAX ecosystem. Its performance, portability, and foundational design for reproducibility are the key ingredients that allow us to make the 'complete journey' of research accessible.

By sharing everything from data methodologies to training logs, we aim to provide a fully reproducible resource—one that empowers researchers to deeply scrutinize, build upon, and trust the work. We believe this is a collaborative step toward a more transparent future for AI. We invite you to join us in this 'open lab'—to use Marin, to contribute to the research, and to help build the next wave of innovative and trustworthy foundation models.

The central resource for the project is the official website, marin.community. From there, you can find the released models on Hugging Face, explore the "open lab" on GitHub, read the Marin documentation, and dive into the Levanter training framework. You can also test drive Marin in a Colab notebook with a simple inference example.
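
For reference, a minimal inference sketch with Hugging Face Transformers might look like the following; the repository id shown is an assumption, so check the model card linked from marin.community for the canonical name.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "marin-community/marin-8b-instruct"  # assumed repository id; verify on the model card
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("What is the Marin project?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```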

Active discussions are taking place on the project's Discord channel, where you can interact directly with other developers. For those new to the ecosystem, the official JAX documentation provides excellent resources, including a Quickstart guide.