对于 JAX 生态系统中的开发者和研究人员来说,从预训练模型到交付完全对齐、可用于生产环境的 LLM 比之前要轻松许多。
今天,我们很高兴推出 Tunix,这是一个专门为 LLM 后训练构建的全新开源 JAX 原生库。Tunix 提供用于大规模对齐模型,且便于开发者使用的综合性工具包,填补了一处关键空白。
Tunix 专为提升 TPU 上的性能而构建(特别是与 MaxText 相结合时效果倍增),可提供:
此初始版本为最常见的后训练工作流提供模块化且易于使用的 API,与 JAX 生态系统无缝集成:
PeftTrainer 不受限于特定模型,支持全权重微调和流行的参数高效微调方法,如 LoRA 和 QLoRA(通过我们与 qwix 库的集成)。DPOTrainer 通过实施直接偏好优化 (DPO) 简化对齐过程。这种强大的技术使用包含偏好和不予采用的回复的简单数据集,使您无需训练和管理单独的奖励模型。PPOLearner:通过实施近端策略优化 (PPO),为 RLHF 提供黄金标准的执行者-批评者方法。这对于采用复杂连续的任务来训练模型至关重要,特别是对于涉及工具使用的新兴智能体工作流。GRPOLearner:提供高效、无批评者的强化学习算法。它采用组相对策略优化 (GRPO),后者将一组生成的回复中的奖励标准化,以此指导模型,免去了使用单独批评者模型带来的复杂性和成本。组序列策略优化 (GSPO-Token):提供 GRPO 算法的变体,为调整 Token 级别优势计算提供更好的灵活性,并可以提高多轮强化学习训练的稳定性。DistillationTrainer 通过训练更小、更有效的“学生”模型来复制更大“教师”模型的输出,从而实现模型压缩。这是在具有严格延迟或成本限制的生产环境中部署高性能模型的关键技术。Tunix 提供以下开箱即用的蒸馏算法:我们制作了一些 Python 笔记本,帮助用户开始使用 Tunix。以下结果证明了 Tunix GRPO 实现的有效性。在 GSM8K 数学推理基准上,使用 Tunix 对 Gemma 2 2B-IT 模型进行微调,使 Pass@1 答案准确率相对提高了约 12%。我们在所有指标中观察到正面结果,这展示出该库能够快速有效地对齐模型行为。
为考虑文本生成的随机性,我们同时使用 Pass@1(贪婪搜索)和 Pass@5 (多样性采样)来评估性能,以衡量一次或五次尝试的正确性。我们的评估侧重于三个关键指标:
作为验证,我们的基准 Pass@1 准确率为约 52%,与 Eleuther 的 LM Eval Harness 报告的基础模型约 51% 的准确率非常接近,这证实了我们此设置的有效性。虽然绝对准确性对提示的格式(例如,使用 <start_answer> 与 <answer>)很敏感,但在不同的设置下,训练后仍能获得一致的显著性能提升。
Link to Youtube Video (visible only when JS is disabled)
从领先的学术实验室到 AI 初创公司,Tunix 已经在推动新一轮机器学习发展。我们正在与合作伙伴携手开发 Tunix,以解决模型对齐和智能体 AI 方面的真实世界挑战。以下是他们的感触:
“我的研究重点是以数据为中心的学习,这涉及准备高质量的数据来提高模型性能,特别是在大型语言模型 (LLM) 的训练后阶段。其中一个关键的挑战是快速迭代数据样本,确定哪些有用,哪些没用。在这方面,Tunix 是完美的库。其“白盒”设计使我的团队能够完全控制训练循环,让我们能轻松修改和调整代码,满足特定研究需求。与其他框架相比,这种可定制性具有显著优势,对于加速迭代数据分析至关重要。”
——刘洪甫,布兰迪斯大学计算机科学助理教授;NeurIPS 资深领域主席;ICLR 领域主席
“后训练强化学习的一个主要瓶颈是缺乏可验证奖励的环境,而游戏拥有完美的多轮交互环境来解决这个问题,其中 Tunix 是这项研究的理想框架。它使我们能够直接在 JAX 上构建,利用 TPU 和轻松执行并行化。与其他替代方案相比,Tunix 是一个轻量级的库,具有整洁且易于管理的代码库。我们可以利用它对模型和超参数展开高级自定义,免于穿越其他框架的过多抽象层。这种精简的方法对我们的工作至关重要,而且我们发现这种学习曲线很温和,因为并非要成为 JAX 专家才能取得成效。”
——张昊,加州大学圣地亚哥分校助理教授,vLLM、聊天机器人竞技场 (LMSys) 联合创始人以及分离服务的发明者
Precur AI 是一家构建智能体编译器的初创公司,这种编译器可将后台工作流转换为代码驱动、可靠且高效的智能体。该公司联合创始人兼首席技术官 Hanjun Dai 表示:
“我们公司专注于构建在没有监督的情况下全天候运行的后台智能体。要实现的一个关键目标是智能体稳健性,因此我们对“智能体内核”进行了后训练。智能体内核是指针对长期但重复的任务进行优化的模型。Tunix 的适用范围广泛,涵盖 SFT、强化学习和蒸馏,使我们能够保持整个智能体开发堆栈的统一。它与 JAX 和 TPU 生态系统的原生集成是一项重大优势。我们可以轻松定制:使用 Flax 进行开发,使用 Qwix 进行量化服务。因此,Tunix 这个整洁而强大的框架非常适合我们的工作流。”
——PreCur AI 联合创始人兼首席技术官 Hanjun Dai
我们正在公开构建 Tunix,欢迎您加入我们的社区,试用该库并做出贡献。
我们很高兴向 JAX 社区分享 Tunix,并期待看到您所构建的内容。