虽然 JAX 作为一个大型 AI 模型开发的热门框架早已广为人知,但它也正在更广泛的科学领域中迅速被采纳。令人格外欣喜的是,我们发现它在诸如物理驱动的机器学习等计算密集型领域中的应用日益广泛。JAX 支持可组合的转换,即一组高阶函数。例如,grad 可以将一个函数作为输入,并返回另一个能计算该函数梯度的函数。更重要的是,您可以自由地嵌套(组合)这些转换。这一设计使得 JAX 在处理高阶导数和其他复杂转换时尤为简洁。
最近,我有幸与来自新加坡国立大学和 Sea AI Lab 的研究人员 Zekun Shi 和 Min Lin 进行交流。他们的经历清楚地展示了 JAX 如何应对科学研究中的根本性挑战,尤其是在求解复杂偏微分方程 (PDE) 时所面临的计算瓶颈。他们从传统框架的局限中不断摸索,最终利用 JAX 独特的泰勒模式自动微分,这一历程必将引起许多研究人员的共鸣。
我们的研究聚焦于科学计算中一个颇具挑战性的领域:利用神经网络求解高阶偏微分方程。神经网络是通用函数逼近器,这使其成为传统方法(如有限元法)的有力替代方案。然而,使用神经网络解决偏微分方程的一个主要障碍在于需要计算其高阶导数,有时甚至高达四阶或更高,包括混合偏导数。
标准深度学习框架主要针对通过反向传播训练模型进行优化,但并不适合这项任务,因为计算高阶导数的费用极其高昂。反复使用反向传播(反向模式 AD)计算高阶导数的费用会随着导数阶数 (k) 呈指数级增长,并随着域维度 (d) 呈多项式级增长。这种“维数灾难”和导数阶数的指数级增长使得用户几乎无法解决大型且复杂的现实问题。
虽然还有其他一些热门的深度学习库,但我们的研究需要一种更基础的能力:泰勒模式自动微分 (AD),而 JAX 对我们来说是一个巨大的变革性突破。
JAX 的关键架构特色在于其强大的函数表示和转换机制,该机制通过跟踪 Python 代码实现,并经过高性能编译。该系统的设计具有极高的通用性,支持从即时编译到计算标准导数的各种应用。正是这种底层灵活性使得其他框架难以实现的高级运算成为可能。对我们来说,关键的应用是对泰勒模式 AD 的支持,我们了解到这是这一独特架构的直接而强大的成果,使 JAX 能够完美地满足我们的科学工作需求。泰勒模式 AD 通过推进函数的泰勒级数展开,实现高阶导数的高效计算,并一次性高效地计算高阶导数,而无需通过重复且昂贵的反向传播。这使我们能够开发一种名为 Stochastic Taylor Derivative Estimator (STDE) 的算法,用于高效地随机化和估计任何微分算子。
在我们最近发表的论文“《随机泰勒导数估计器:适用于任意微分算子的高效摊销方法》”(Stochastic Taylor Derivative Estimator: Efficient amortization for arbitrary differential operators) 中,我们展示了如何使用这种方法,该论文获得了 NeurIPS 2024 最佳论文奖。我们演示了通过使用 JAX 的泰勒模式,可以构建一种算法,高效地提取这些高阶偏导数。核心思想是利用泰勒模式 AD 来高效计算 PDE 中出现的高阶导数张量的收缩。通过构造特殊的随机切向量(或称为“jet”),我们能够在一次高效的前向传播中,得到对任意复杂微分算子的无偏估计。
结果令人瞩目。在 JAX 中使用我们的 STDE 方法,与基线方法相比,我们实现了 >1,000 倍的速度提升和 >30 倍的内存节省。这一效率提升使我们能够在单个 NVIDIA A100 GPU 上仅用 8 分钟就解决 100 万维的 PDE,而这在以前是难以解决的任务。
对于仅面向标准机器学习工作任务的框架来说,这根本无法实现。其他框架虽然在反向传播上高度优化,但在端到端计算图的表达方面却不如 JAX 注重。这正是 JAX 的优势所在,使其在诸如函数转置或实现高阶泰勒模式微分等操作上表现尤为突出。
除了泰勒模式之外,JAX 的模块化设计,以及对通用数据类型和函数变换的支持,对我们的研究也至关重要。在另一篇论文“《JAX 中的自动函数微分》”(Automatic Functional Differentiation in JAX) 中,我们甚至将 JAX 普及到能够处理无限维向量(Hilbert 空间中的函数),方法是将其描述为自定义数组并在 JAX 中进行注册。这使我们能够复用现有的计算机制来求解泛函和算子的变分导数——而这种功能在其他框架中是完全无法实现的。
出于这些原因,我们不仅在本项目中采用了 JAX,还在量子化学等领域的广泛研究中采用了 JAX。作为一个通用、可扩展并且具有强大符号能力的系统,JAX 的核心设计使其成为推动科学计算前沿的理想选择。我们认为,让科学界了解这些能力至关重要。
Zekun 和 Min 的经历证明了 JAX 的强大功能与灵活性。他们使用 JAX 开发的 STDE 方法对基于物理的机器学习领域做出了重大贡献,使得解决一类以前难以解决的问题成为可能。我们鼓励您阅读他们那篇获奖论文,以更深入地了解技术细节,并在 GitHub 上探索他们开源的 STDE 库,您可借此对 JAX 原生科学工具领域产生更多了解。
此类故事凸显了一个日益明显的趋势:JAX 不仅仅是一个用于深度学习的工具;它还是一款用于可微分编程的基础库,正在推动新一代的科学发现。Google 的 JAX 团队致力于支持并壮大这一充满活力的生态系统,而这一切都从倾听您的声音开始。
我们非常高兴与您携手构建下一代科学计算工具。请联系我们的团队,分享您的工作成果,或探讨您对 JAX 的需求。
衷心感谢 Zekun 和 Min 与我们分享他们富有见地的历程。
参照
Shi, Z., Hu, Z., Lin, M., & Kawaguchi, K. (2025)。《随机泰勒导数估计器:适用于任意微分算子的高效摊销方法》(Stochastic Taylor Derivative Estimator: Efficient amortization for arbitrary differential operators)。 《神经信息处理系统的发展》(Advances in Neural Information Processing Systems), 37.
Lin, M. (2023)。《JAX 中的自动函数微分》(Automatic Functional Differentiation in JAX)。 第十二届国际学习表征会议.