当前 AI 时代令人振奋的一个特点是:强大的基础模型正以开放的方式共享,并推动所有人的创新步伐。这一进展促使我们思考:“开放的下一步是什么?”Marin 项目正是一次拓展“开放”定义的尝试,它希望将“开放”的范畴延伸至模型背后完整的科研过程。
斯坦福大学基础模型研究中心 (CRFM) 推出了 Marin 项目,其设计理念是“开放实验室”。其目标不仅是分享模型本身,更是让整个开发过程完全公开,包括代码、数据集、数据处理方法、实验过程、超参数设置以及训练日志。这种程度的透明性,为现有生态系统增添了独特的价值,提供了一个完全可复现的资源,使研究人员能够深入检验、持续构建并真正信任所开发的模型。斯坦福大学的 Marin 项目致力于推动基础模型研究迈向一个更加透明、更加开放的未来。
开放实验室的首批成果是 Marin-8B-Base 和 Marin-8B-Instruct 模型。秉承项目的核心理念,模型、数据、代码和分词器均采用宽松的 Apache 2.0 许可证发布。这种对完全可复现性的承诺是一项极具挑战的工程难题,要求在大规模分布式系统中对每一个可能引入差异的变量进行精确控制。该项目的成功依赖于一个技术栈,它能够实现大规模可复现性,并能最大化效率,以训练出具有领先性价比的基础模型。
为了让 Marin 项目在创建真正开放、可伸缩且可复现的基础模型方面取得成功,CRFM 团队必须攻克一系列工程上的挑战。团队选择 JAX 作为基础框架,因为其设计理念能够直接应对这些问题,并在此基础上构建了一个新框架 Levanter(见下文),以充分发挥 JAX 的强大能力。以下是一些挑战及其解决方案的示例
问题:核心训练循环会被执行数十亿次,因此由 Python 等解释型语言带来的开销会成为严重的性能瓶颈。如果操作是逐步分发执行的,循环还会造成过多的内存流量和额外开销,特别是在 TPU 这类硬件上,其吞吐量依赖于融合操作的高效执行。
解决方案:
@jax.jit
装饰器。JAX 的 XLA 编译器将整个过程转换为单个高度优化的机器代码内核,通过操作融合在大规模场景下最大限度地提高硬件利用率。jax.value_and_grad
在单次传递中计算损失及其梯度。JAX 还使得高级技术(如梯度检查点)的使用变得简单,节省了内存,并使我们能够在几乎没有额外开销的情况下使用更大的批量大小。Pallas
的强大 Splash Attention 内核,这是对点积注意力的高度优化实现。而点积注意力正是几乎所有大语言模型中最核心的操作之一。问题:训练最先进的模型需要扩展到数千个加速器芯片。手动管理模型和数据的分区方式以及设备的通信方式非常复杂,代码很快就变得难以阅读、调试和适配。
解决方案:
@jax.jit
装饰器还无缝支持单程序、多数据 (SPMD) 并行化,实现底层数据分片和通信的自动化。XLA 编译器会自动调度加速器之间的通信,以最大限度地减少在网络上等待的时间,并最大限度地增加用于计算的时间。jit
的强大功能更容易、更安全地使用,Levanter 开发了 Haliax,一个用于命名张量的库。通过使用人类可读的名称(如“embed”或“batch”)来引用张量的轴,而非依赖位置索引,代码变得自解释且更稳健。问题:大规模训练需要灵活访问庞大的计算集群。我们高度依赖可抢占的 TPU 实例来控制成本,这意味着我们需要一种方法来轻松地将多个小型、异构的 TPU 切片组合成一个逻辑集群,并且从容应对频繁的中断。
解决方案:
问题:Marin 项目的核心目标是实现可验证的科学研究。这便要求结果可复现,即使训练过程中发生暂停、重启,或在不同机器配置之间迁移,也必须保持结果一致。这是一个重大的技术挑战。
解决方案:
问题:虽然 JAX 提供了强大的引擎,但现有的高级框架无法同时满足我们在可读性、大规模可伸缩性以及逐位确定性方面的严格要求。我们需要一个完整且理念一致的系统来协调整个培训过程。
解决方案:
jit
) 与低层控制 (Pallas
) 之间的无缝集成,构建了 Levanter。上述所讨论的原则、工具与库在 Marin-8B 的训练过程中得到了实际应用与验证。该模型的架构采用的是 Llama 风格的 Transformer 架构。
Marin-8B 的训练并非一次静态、单一的过程,而是一次适应性的旅程,团队内部戏称为“Tootsie”过程。这一真实世界研究工作流程的坦诚展示已在公开资料中详细描述。整个训练过程涵盖了超过 12 万亿个 token,分为多个阶段,不断适应新的数据、技术,甚至中途更换了不同的硬件配置,例如在训练中途从 2x v5e-256 迁移到 1x v4-2048 的大规模多切片 TPU 配置。与此同时,团队持续优化数据混合策略,引入更高质量的数据源,并调整诸如学习率和批次大小等超参数,以提升模型表现。这种“混乱”的真实研发过程是一份宝贵的教学资源,而 JAX 与 Levanter 堆栈能够在经历如此重大变动的同时,仍保持逐位精确的可复现性,也充分展示了其强大的稳健性。
Marin 项目是一次开放的邀请,旨在邀请更多人参与基础模型开发的未来,并为 JAX 生态系统贡献力量。Marin 的旅程代表了我们对“开放的下一步是什么?”这一问题的回答。JAX 生态系统的技术能力使创建“开放实验室”的构想成为可能。其性能、可移植性,以及为可复现性而设计的基础架构,是让我们能够将研究的“完整历程”开放呈现的关键要素。
我们分享了从数据处理方法到训练日志的一切内容,旨在提供一个完全可复现的资源。这一资源将使研究人员能够深入检验、持续构建并信赖我们的工作。我们相信,这是迈向更加透明的 AI 未来的重要一步。我们诚邀您加入我们的“开放实验室”,使用 Marin,参与研究,共同打造新一代创新且值得信赖的基础模型。
该项目的核心资源均位于官方网站 marin.community。通过该网站,您可以找到在 Hugging Face 上发布的模型,探索 GitHub 上的“开放实验室”,阅读 Marin 文档,并深入了解 Levanter 训练框架。您还可以通过一个简单的推理示例在 Colab 中试用 Marin。
此外,Discord 频道正在进行活跃讨论,您可以在那里与其他开发者直接交流。对于刚接触这一生态系统的新手而言,JAX 官方文档提供了丰富的资源,其中包括一份实用的 快速入门指南。