机器人专家的 JAX 探索历程:通过最优控制和模拟获得效率

2025年7月29日
Srikanth Kilaru Senior Product Manager Google ML Frameworks
Max Muchen Sun Robotics Researcher Northwestern University

JAX 越来越多地被开发者用于广泛的计算任务,它的用途得到拓展,不再像从前那样仅集中应用于大规模 AI。虽然 JAX 仍然是开发 LLM 和基础模型的热门框架,但也在不同的科学领域得到推广,其中一个特别激动人心的领域是机器人技术——JAX 在模拟、控制和学习驱动型方法的集成方面展现了强大的功能。

最近,我有幸与美国西北大学 Todd Murphey 教授指导的机器人学博士候选人兼研究员 Max Muchen Sun 进行了交流。他的经验清晰展现了 JAX 如何解决机器人研究中的关键挑战,特别是在复杂控制算法的计算效能提升,以及模型驱动型与学习驱动型方法的流畅融合方面。从使用传统工具到利用 JAX 的独特功能(如 vmap scan),Max 的探索经历想必能引起业内人士的共鸣,并对大家有所启发。


以下是 Max 对探索历程的自述:

我对 JAX 的兴趣始于计算效率。当时我的导师 Ian Abraham(现任耶鲁大学教授)正在使用 autograd,后来引导我接触了 JAX。我们当时在使用遍历控制进行研究,这是用于覆盖问题的控制框架。与标准控制表达式相比,遍历控制的计算复杂性本质上更高。为了实现实时遍历控制,我最初使用了标准的 NumPy,并利用了矢量化及广播功能。

JAX 首先引起我注意的功能是 vmap。对我来说,这一功能结合了标准 NumPy 的矢量化和广播机制,并通过函数转换和组合抽象进一步泛化,使我更容易实现对目标问题的推理和并行处理。

然后我了解了 scan。这种功能起初不太直观,但我逐渐发现它作为模拟动态系统轨迹工具很有效。在轨迹优化中,系统动力学的前向仿真是必须重复执行的核心操作,经常成为计算瓶颈。与基于 NumPy 的标准实现相比,使用 scan 可以让轨迹模拟加速两个数量级。这种易用性和实质性的速度优势,使我完全转向了 JAX 生态系统。

另一方面,我博士论文的一个关注点是将模型驱动型的控制与学习驱动型的表征相结合,以实现自主探索和多智能体合作。我认为模型驱动型方法不是独立的解决方案,而是提高学习效率和稳健性的结构。JAX 的可组合性使其成为融合模型驱动型和学习驱动型流水线的理想选择。

近期的一篇论文中(该论文已被机器人:科学与系统 (RSS) 会议接收),我将来自生成模型的流匹配结合模型驱动型最优控制,用于机器人探索功能。我使用流梯度,通过基于 LQR 的更新将状态空间流映射到控制指令。其原理与反向传播类似,只不过是应用在动态系统上。我最初在 PyTorch 中构建流匹配模块,并将 C++ 应用于 LQR,但集成速度很慢。切换到 JAX 后,我使用 vmapgrad 重新实现了流匹配部分,并利用了基于 JAX 的工具,如 OTT(最优传输工具箱)。最终仅剩 LQR 流水线部分需转为 JAX 原生实现。

在另一篇近期的论文中(此论文发表于 IEEE 机器人与自动化国际会议 (ICRA)),我将模型驱动型博弈论控制流水线集成到生成式轨迹模型中,以从演示中学习多智能体合作。我没有使用博弈论控制作为完整的解决方案,这种方式通常计算成本昂贵,且需要手动确定损失规范;相反,我用博弈论计算作为结构化层,嵌入到条件变分自编码器 (CVAE) 中。这在不牺牲性能的情况下提高了数据效率。这两个组件都是在 JAX 中实现的,CVAE 使用 Flax 实现,而控制层则是零开始实现的。借助 JAX,上述过程非常流畅:grad 可直接对平衡点求导。我还构建了一个基于 JAX 的 iLQGames 求解器,用于生成合成数据。

完成这些项目后,我发现多数 JAX 代码可复用于动态系统计算(尤其是基于 LQR 的计算)。由于我使用 LQR 以非标准方式集成了学习驱动型和模型驱动型的控制,因此我将其打包到一个独立的 JAX 原生求解器 LQRax 中。这个工具支持 GPU 加速、vmapscan grad,从而实现矢量化和可微分 LQR。我加入了遍历和博弈论控制等例子,以突出模型驱动型方法如何补充学习驱动型的方法。

我在 CPU 和 GPU 上应用 JAX 的方式通常与 ML 社区不同。例如,在流匹配项目中,LQR 在 CPU 上运行更快,而流匹配梯度在 GPU 上运行更快。我没有使用 TPU,因为我通常在本地运行所有计算。几年前,我在 Nvidia Jetson 上尝试了 JAX,但安装非常困难。很高兴看到这些嵌入式平台现在已经能支持 JAX,这对机器人技术至关重要。我一直在使用 Jetson 在四足机器人上测试人群导航算法,所有计算都在机器人本体上完成,我计划很快将 JAX 集成到这个项目中。

展望未来,我将继续使用 JAX,因为我选择它的初衷始终没变。首先是 JAX 的计算效率,而这点在机器人技术中越来越重要,特别是基于 GPU 的并行处理。除了训练之外,它还为模型驱动型控制创造了新的可能性,如大规模并行模拟和实时参数更新,类似于具身主动学习。其次,JAX 使模型驱动型结构集成到学习流水线中变得直观——无论是用于动态机制、损失调整还是可微分求解器。这种灵活性让我对 JAX 进一步的发展和潜力充满期待。


探索 JAX 机器人生态系统:从 LQRax 到 MJX

Max 的经验表明,JAX 为机器人社区提供了几个关键优势。使用 vmap 能实现并行操作显著加速,通过 scan 则能实现轨迹模拟的显著加速,这两点对于实时控制和复杂规划至关重要。此外,JAX 的功能范式自动微分功能,使其非常适合用来集成经典的模型驱动型技术与现代的学习驱动型组件。

我们相信,类似 Max 这样的经历意味着,JAX 生态系统正在快速发展且日趋成熟。他的 LQRax 软件包为充满活力的 JAX 原生机器人工具领域注入了新力量。我们鼓励您在 GitHub 上探索该项目并亲自尝试。在模拟世界中,JAX 通过 Brax 和新的 MuJoCo XLA (MJX) 等大规模并行引擎,提供了强大的基础;后者 (MJX) 将流行的标准 MuJoCo 物理引擎直接引入了 JAX。我们还看到了来自社区的专用工具,例如用于以控件为中心的多体动态的 JaxSim 库。

在轨迹优化领域,像 Trajax 这样的先驱技术率先铺平了道路。LQRax 作为一个受欢迎的现代库,为研究人员构建下一代控制系统提供有力支持。它完美地体现了 JAX 的精髓,提供了一个强大、可组合的工具,弥合了模型驱动型控制和深度学习之间的差距。

衷心感谢 Max 与我们分享他发人深思的历程。我们很高兴看到他和其他研究人员继续利用 JAX 构建下一代智能机器人系统。Google 的 JAX 团队也将继续支持这个充满活力的生态系统,帮助其不断发展壮大。