斯坦福大学推出 Marin 基础模型:首个使用 JAX 开发的完全开放模型

2025年7月16日
Srikanth Kilaru Senior Product Manager Google ML Frameworks
David Hall Research Engineering Lead Stanford HAI

当前 AI 时代令人振奋的一个特点是:强大的基础模型正以开放的方式共享,并推动所有人的创新步伐。这一进展促使我们思考:“开放的下一步是什么?”Marin 项目正是一次拓展“开放”定义的尝试,它希望将“开放”的范畴延伸至模型背后完整的科研过程。

斯坦福大学基础模型研究中心 (CRFM) 推出了 Marin 项目,其设计理念是“开放实验室”。其目标不仅是分享模型本身,更是让整个开发过程完全公开,包括代码、数据集、数据处理方法、实验过程、超参数设置以及训练日志。这种程度的透明性,为现有生态系统增添了独特的价值,提供了一个完全可复现的资源,使研究人员能够深入检验、持续构建并真正信任所开发的模型。斯坦福大学的 Marin 项目致力于推动基础模型研究迈向一个更加透明、更加开放的未来。


AI 模型开放性的范围

The Spectrum of AI Model Openness

开放实验室的首批成果是 Marin-8B-Base 和 Marin-8B-Instruct 模型。秉承项目的核心理念,模型、数据、代码和分词器均采用宽松的 Apache 2.0 许可证发布。这种对完全可复现性的承诺是一项极具挑战的工程难题,要求在大规模分布式系统中对每一个可能引入差异的变量进行精确控制。该项目的成功依赖于一个技术栈,它能够实现大规模可复现性,并能最大化效率,以训练出具有领先性价比的基础模型。


构建开放基础模型的核心挑战

为了让 Marin 项目在创建真正开放、可伸缩且可复现的基础模型方面取得成功,CRFM 团队必须攻克一系列工程上的挑战。团队选择 JAX 作为基础框架,因为其设计理念能够直接应对这些问题,并在此基础上构建了一个新框架 Levanter(见下文),以充分发挥 JAX 的强大能力。以下是一些挑战及其解决方案的示例


在单一加速器上实现最大速度

问题:核心训练循环会被执行数十亿次,因此由 Python 等解释型语言带来的开销会成为严重的性能瓶颈。如果操作是逐步分发执行的,循环还会造成过多的内存流量和额外开销,特别是在 TPU 这类硬件上,其吞吐量依赖于融合操作的高效执行。

解决方案:

  • 为消除解释器开销,Levanter 将整个多阶段训练步骤(前向传递、损失计算、反向传播和参数更新)封装到一个单一函数中,并使用 @jax.jit 装饰器。JAX 的 XLA 编译器将整个过程转换为单个高度优化的机器代码内核,通过操作融合在大规模场景下最大限度地提高硬件利用率。

  • 为了避免冗余计算,我们使用 jax.value_and_grad 在单次传递中计算损失及其梯度。JAX 还使得高级技术(如梯度检查点)的使用变得简单,节省了内存,并使我们能够在几乎没有额外开销的情况下使用更大的批量大小。

  • Levanter 还使用了 JAX 基于 Pallas 的强大 Splash Attention 内核,这是对点积注意力的高度优化实现。而点积注意力正是几乎所有大语言模型中最核心的操作之一。


管理大规模并行的复杂性

问题:训练最先进的模型需要扩展到数千个加速器芯片。手动管理模型和数据的分区方式以及设备的通信方式非常复杂,代码很快就变得难以阅读、调试和适配。

解决方案:

  • JAX 的 @jax.jit 装饰器还无缝支持单程序、多数据 (SPMD) 并行化,实现底层数据分片和通信的自动化。XLA 编译器会自动调度加速器之间的通信,以最大限度地减少在网络上等待的时间,并最大限度地增加用于计算的时间。

  • 为了让 jit 的强大功能更容易、更安全地使用,Levanter 开发了 Haliax,一个用于命名张量的库。通过使用人类可读的名称(如“embed”或“batch”)来引用张量的轴,而非依赖位置索引,代码变得自解释且更稳健。

  • 这种抽象使我们能够通过仅修改配置文件中的几行代码,就定义和调整复杂的分片策略,例如完全分片数据并行 (FSDP) 和张量并行,而无需改动模型代码本身。


构建和管理弹性、经济高效的计算群集

问题:大规模训练需要灵活访问庞大的计算集群。我们高度依赖可抢占的 TPU 实例来控制成本,这意味着我们需要一种方法来轻松地将多个小型、异构的 TPU 切片组合成一个逻辑集群,并且从容应对频繁的中断。

解决方案:

  • 我们利用了 Google Cloud TPU Multislice,该技术允许训练作业将多个 TPU 切片当作一个大型系统来使用。这使得我们可以轻松地将多个小型、可抢占的 TPU 切片拼接成单个强大的计算集群,以用于模型训练。

  • Levanter 使用 Ray 来协调这一过程,从而在训练过程中无缝地动态扩展或缩减 TPU 切片数量,更重要的是,当某个切片被抢占时,作业仍能稳健运行。

  • 得益于 JAX 和 XLA 的强大能力,Levanter 和 Marin 也能在 GPU 上实现类似的高性能结果。


以完美可复现性构建科学信任

问题:Marin 项目的核心目标是实现可验证的科学研究。这便要求结果可复现,即使训练过程中发生暂停、重启,或在不同机器配置之间迁移,也必须保持结果一致。这是一个重大的技术挑战。

解决方案:

  • 这是推动 Levanter 设计的基本要求。我们之所以选择 JAX,是因为它提供强大的可复现性保障,例如其默认使用确定性伪随机数生成器 (PRNG)

  • 这一选择在 Marin-8B 的训练过程中得到了验证:即便在不同 TPU 切片和硬件类型之间进行迁移,系统仍成功保持了在多次抢占下的逐位精确可复现性。

  • Levanter 还包括基于 Google Tensorstore 库构建的强大数据加载系统。Levanter 的数据存储提供对任何批次训练数据的确定性随机访问,无论作业是否重启、数据源是否变化。这对于实现诸如训练中期干预等高级训练策略至关重要。JAX 的确定性与 Levanter 的数据存储区相结合,也极大地方便了可解释性研究,使研究人员能够清楚理解特定数据在培训期间对模型的影响。


构建统一的框架

问题:虽然 JAX 提供了强大的引擎,但现有的高级框架无法同时满足我们在可读性、大规模可伸缩性以及逐位确定性方面的严格要求。我们需要一个完整且理念一致的系统来协调整个培训过程。

解决方案:

  • 我们从零开始构建了 Levanter,一个原生于 JAX 的框架,旨在满足我们的需求:具备逐位确定性、可伸缩性、先进的分布式策略,以及弹性。

  • 我们之所以能够做到这一点,是因为 JAX 并不仅仅是一个库,它更像是一个用于构建新工具的“元框架”。我们基于其成熟的高性能 TPU 支持,以及其在高级抽象 (jit) 与低层控制 (Pallas) 之间的无缝集成,构建了 Levanter。

  • 这种构建方式在 JAX 社区中十分常见。该社区已孕育出一个充满活力的生态系统,例如 FlaxEquinoxOrbaxOptax 等库,它们彼此兼容,使得我们这样的团队能够构建强大的解决方案。


探索内核:Marin-8B 的旅程

上述所讨论的原则、工具与库在 Marin-8B 的训练过程中得到了实际应用与验证。该模型的架构采用的是 Llama 风格的 Transformer 架构。


Marin-8B-Base:模型架构概览

Marin 8B-Base model architecture at a glance

Marin-8B 的训练并非一次静态、单一的过程,而是一次适应性的旅程,团队内部戏称为“Tootsie”过程。这一真实世界研究工作流程的坦诚展示已在公开资料中详细描述。整个训练过程涵盖了超过 12 万亿个 token,分为多个阶段,不断适应新的数据、技术,甚至中途更换了不同的硬件配置,例如在训练中途从 2x v5e-256 迁移到 1x v4-2048 的大规模多切片 TPU 配置。与此同时,团队持续优化数据混合策略,引入更高质量的数据源,并调整诸如学习率和批次大小等超参数,以提升模型表现。这种“混乱”的真实研发过程是一份宝贵的教学资源,而 JAX 与 Levanter 堆栈能够在经历如此重大变动的同时,仍保持逐位精确的可复现性,也充分展示了其强大的稳健性。


加入 Marin 社区

Marin 项目是一次开放的邀请,旨在邀请更多人参与基础模型开发的未来,并为 JAX 生态系统贡献力量。Marin 的旅程代表了我们对“开放的下一步是什么?”这一问题的回答。JAX 生态系统的技术能力使创建“开放实验室”的构想成为可能。其性能、可移植性,以及为可复现性而设计的基础架构,是让我们能够将研究的“完整历程”开放呈现的关键要素。

我们分享了从数据处理方法到训练日志的一切内容,旨在提供一个完全可复现的资源。这一资源将使研究人员能够深入检验、持续构建并信赖我们的工作。我们相信,这是迈向更加透明的 AI 未来的重要一步。我们诚邀您加入我们的“开放实验室”,使用 Marin,参与研究,共同打造新一代创新且值得信赖的基础模型。

该项目的核心资源均位于官方网站 marin.community。通过该网站,您可以找到在 Hugging Face 上发布的模型,探索 GitHub 上的“开放实验室”,阅读 Marin 文档,并深入了解 Levanter 训练框架。您还可以通过一个简单的推理示例在 Colab 中试用 Marin

此外,Discord 频道正在进行活跃讨论,您可以在那里与其他开发者直接交流。对于刚接触这一生态系统的新手而言,JAX 官方文档提供了丰富的资源,其中包括一份实用的 快速入门指南