隆重推出 Tunix:用于 LLM 后训练的 JAX 原生库

2025年9月30日
Srikanth Kilaru Senior Product Manager Google ML Frameworks
Tianshu Bao Senior Staff Software Engineer Google ML Frameworks
Tunix logo
download

对于 JAX 生态系统中的开发者和研究人员来说,从预训练模型到交付完全对齐、可用于生产环境的 LLM 比之前要轻松许多。

今天,我们很高兴推出 Tunix,这是一个专门为 LLM 后训练构建的全新开源 JAX 原生库。Tunix 提供用于大规模对齐模型,且便于开发者使用的综合性工具包,填补了一处关键空白。

Tunix 专为提升 TPU 上的性能而构建(特别是与 MaxText 相结合时效果倍增),可提供:

  • 完整的算法套件:在单个统一的库中获得可在生产环境使用的训练器,用于监督微调 (SFT)、偏好微调、知识蒸馏以及 PPO、GRPO 和 GSPO 等高级强化学习方法。
  • “白盒”设计:完全接管您的工作。Tunix 经过专门的设计来优化开发者体验,使您无需穿越多个抽象层即可轻松自定义训练循环和其他训练后代码。
  • 无缝 JAX 集成:作为 JAX 原生库,Tunix 是一个功能强大且易于使用的解决方案,可与您已使用的开源模型对齐。

此初始版本的功能

此初始版本为最常见的后训练工作流提供模块化且易于使用的 API,与 JAX 生态系统无缝集成:

  • 监督微调 (SFT)PeftTrainer 不受限于特定模型,支持全权重微调和流行的参数高效微调方法,如 LoRA 和 QLoRA(通过我们与 qwix 库的集成)。
  • 偏好微调:DPOTrainer 通过实施直接偏好优化 (DPO) 简化对齐过程。这种强大的技术使用包含偏好和不予采用的回复的简单数据集,使您无需训练和管理单独的奖励模型。
  • 强化学习 (RL):Tunix 提供一套强化学习训练器,使模型行为与人的偏好和说明对齐:
    • PPOLearner通过实施近端策略优化 (PPO),为 RLHF 提供黄金标准的执行者-批评者方法。这对于采用复杂连续的任务来训练模型至关重要,特别是对于涉及工具使用的新兴智能体工作流。
    • GRPOLearner提供高效、无批评者的强化学习算法。它采用组相对策略优化 (GRPO),后者将一组生成的回复中的奖励标准化,以此指导模型,免去了使用单独批评者模型带来的复杂性和成本。
    • 组序列策略优化 (GSPO-Token)提供 GRPO 算法的变体,为调整 Token 级别优势计算提供更好的灵活性,并可以提高多轮强化学习训练的稳定性。
  • 知识蒸馏:DistillationTrainer 通过训练更小、更有效的“学生”模型来复制更大“教师”模型的输出,从而实现模型压缩。这是在具有严格延迟或成本限制的生产环境中部署高性能模型的关键技术。Tunix 提供以下开箱即用的蒸馏算法:
    • 基于 Logit 的蒸馏:利用教师模型的最终输出概率作为“软目标”来指导学生模型。
    • 注意力转移:利用教师模型的注意力特征来指导学生模型。
  • PyPI 包:Tunix 已作为包发布在 PyPI 上,可通过以下命令安装:
    • pip install google-tunix
  • 示例:所有受支持算法的示例,在 Tunix 存储库中规范部署了一些领先的开源模型。
  • 智能体 AI:Tunix 支持对使用 LLM 进行推理并与外部环境交互的智能体进行训练。

量化结果

我们制作了一些 Python 笔记本,帮助用户开始使用 Tunix。以下结果证明了 Tunix GRPO 实现的有效性。在 GSM8K 数学推理基准上,使用 Tunix 对 Gemma 2 2B-IT 模型进行微调,使 Pass@1 答案准确率相对提高了约 12%。我们在所有指标中观察到正面结果,这展示出该库能够快速有效地对齐模型行为。

Tunix-table

为考虑文本生成的随机性,我们同时使用 Pass@1(贪婪搜索)和 Pass@5 (多样性采样)来评估性能,以衡量一次或五次尝试的正确性。我们的评估侧重于三个关键指标:

  • 答案准确率:含正确最终数字答案的预测所占的百分比。
  • 答案(部分)准确率:一个更灵活的指标,表示模型的答案在正确答案的 10% 误差以内(0.9 和 1.1 之间的比率)。
  • 格式准确率:模型正确使用所需推理和答案 Token 的样本所占的百分比。

作为验证,我们的基准 Pass@1 准确率为约 52%,与 Eleuther 的 LM Eval Harness 报告的基础模型约 51% 的准确率非常接近,这证实了我们此设置的有效性。虽然绝对准确性对提示的格式(例如,使用 <start_answer> 与 <answer>)很敏感,但在不同的设置下,训练后仍能获得一致的显著性能提升。

Link to Youtube Video (visible only when JS is disabled)

深受研究人员和创新人士的信赖

从领先的学术实验室到 AI 初创公司,Tunix 已经在推动新一轮机器学习发展。我们正在与合作伙伴携手开发 Tunix,以解决模型对齐和智能体 AI 方面的真实世界挑战。以下是他们的感触:

“我的研究重点是以数据为中心的学习,这涉及准备高质量的数据来提高模型性能,特别是在大型语言模型 (LLM) 的训练后阶段。其中一个关键的挑战是快速迭代数据样本,确定哪些有用,哪些没用。在这方面,Tunix 是完美的库。其“白盒”设计使我的团队能够完全控制训练循环,让我们能轻松修改和调整代码,满足特定研究需求。与其他框架相比,这种可定制性具有显著优势,对于加速迭代数据分析至关重要。”

——刘洪甫,布兰迪斯大学计算机科学助理教授;NeurIPS 资深领域主席;ICLR 领域主席
“后训练强化学习的一个主要瓶颈是缺乏可验证奖励的环境,而游戏拥有完美的多轮交互环境来解决这个问题,其中 Tunix 是这项研究的理想框架。它使我们能够直接在 JAX 上构建,利用 TPU 和轻松执行并行化。与其他替代方案相比,Tunix 是一个轻量级的库,具有整洁且易于管理的代码库。我们可以利用它对模型和超参数展开高级自定义,免于穿越其他框架的过多抽象层。这种精简的方法对我们的工作至关重要,而且我们发现这种学习曲线很温和,因为并非要成为 JAX 专家才能取得成效。”

——张昊,加州大学圣地亚哥分校助理教授,vLLM、聊天机器人竞技场 (LMSys) 联合创始人以及分离服务的发明者

Precur AI 是一家构建智能体编译器的初创公司,这种编译器可将后台工作流转换为代码驱动、可靠且高效的智能体。该公司联合创始人兼首席技术官 Hanjun Dai 表示:

“我们公司专注于构建在没有监督的情况下全天候运行的后台智能体。要实现的一个关键目标是智能体稳健性,因此我们对“智能体内核”进行了后训练。智能体内核是指针对长期但重复的任务进行优化的模型。Tunix 的适用范围广泛,涵盖 SFT、强化学习和蒸馏,使我们能够保持整个智能体开发堆栈的统一。它与 JAX 和 TPU 生态系统的原生集成是一项重大优势。我们可以轻松定制:使用 Flax 进行开发,使用 Qwix 进行量化服务。因此,Tunix 这个整洁而强大的框架非常适合我们的工作流。”

——PreCur AI 联合创始人兼首席技术官 Hanjun Dai

社群与协作:欢迎参与

我们正在公开构建 Tunix,欢迎您加入我们的社区,试用该库并做出贡献。

  • 为 Tunix 做出贡献:我们正在积极寻找合作者,并很乐意为您的贡献提供支持。无论您对开发新的智能体功能或环境、增强算法还是建立研究合作伙伴关系感兴趣,都可以使用此表格告诉我们您希望如何参与进来。
  • GitHub 存储库和文档:您可以在我们的 GitHub 存储库tunix.readthedocs.io 上找到源代码、问题跟踪器、深度文档并加入讨论。
  • 实操示例:最好的入门方法是运行代码。我们准备了一套 Python 笔记本,您可以运行这些笔记本,开始使用我们的每个核心训练器。
  • MaxTextMaxText 是一个仅用 Python/JAX 编写的高性能、高度可扩展的开源 LLM 库和参考实现,针对 Google Cloud TPU 和 GPU 进行训练。

我们很高兴向 JAX 社区分享 Tunix,并期待看到您所构建的内容。