While JAX is well known as a popular framework for large-scale AI model development, it is also gaining rapid adoption across a wider set of scientific domains. We are particularly excited to see its growing use in computationally intensive fields like physics-informed machine learning. JAX is built around composable transformations: higher-order functions that take a function and return a transformed one. For example, grad takes a function as input and returns another function that computes its gradient. Crucially, you can nest (compose) these transformations freely. This design is what makes JAX especially elegant for higher-order derivatives and other complex transformations. Recently, I had the pleasure of speaking with Zekun Shi and Min Lin, researchers from the National University of Singapore and Sea AI Lab. Their experience clearly illustrates how JAX can address fundamental challenges in scientific research, particularly the computational cliff faced when solving complex Partial Differential Equations (PDEs). Their journey from grappling with the limitations of traditional frameworks to harnessing JAX's Taylor mode automatic differentiation is a story that will resonate with many researchers.
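As a minimal sketch of what composable transformations look like in practice (our illustration, not code from the researchers), nesting jax.grad on jnp.sin yields its first three derivatives, and the resulting function can itself be wrapped in jax.vmap and jax.jit:

```python
import jax
import jax.numpy as jnp

# jax.grad takes a scalar function and returns a new function for its derivative.
df = jax.grad(jnp.sin)                        # x -> cos(x)
d2f = jax.grad(jax.grad(jnp.sin))             # x -> -sin(x)
d3f = jax.grad(jax.grad(jax.grad(jnp.sin)))   # x -> -cos(x)

# Transformations compose: vectorize the second derivative over a batch
# of points with vmap, then compile the whole thing with jit.
batched_d2f = jax.jit(jax.vmap(d2f))
xs = jnp.linspace(0.0, 1.0, 5)

print(df(0.5), d2f(0.5), d3f(0.5))
print(batched_d2f(xs))  # approximately -sin(xs)
```

Each call to jax.grad simply produces another ordinary Python function, which is why the nesting and the vmap/jit wrapping work without any special cases.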

A new approach to solving PDEs: In the researchers' own words

Our work focuses on a challenging area of scientific computing: using neural networks to solve high-order PDEs. Neural networks are universal function approximators, which makes them a promising alternative to traditional methods like finite elements. However, a major hurdle in solving PDEs with a neural network is that you need to evaluate its high-order derivatives, sometimes up to the fourth order or even higher, including mixed partial derivatives. Standard deep learning frameworks are primarily optimized for training models via backpropagation, and they are not well-suited to this task, because computing high-order derivatives with them is extremely expensive. The cost of applying backpropagation (reverse-mode AD) repeatedly to obtain high-order derivatives scales exponentially with the derivative order k and polynomially with the domain dimension d. This "curse of dimensionality", combined with exponential scaling in derivative order, makes it practically impossible to tackle large, complex, real-world problems.

Compute graph scales exponentially in derivative order k
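To make the conventional approach concrete, here is a small illustrative sketch using a hypothetical toy function u in place of a neural network. The Laplacian is computed by materializing the full Hessian with jax.hessian and taking its trace, and a k-th derivative is obtained by nesting jax.grad k times. Counting the equations in the traced program (via jax.make_jaxpr) is used here only as a rough proxy for compute-graph size; the exact counts depend on the JAX version, so none are shown.

```python
import jax
import jax.numpy as jnp

def u(x):
    # Hypothetical toy field u: R^d -> R standing in for a neural network;
    # its exact Laplacian is -sum(sin(x)).
    return jnp.sum(jnp.sin(x))

def laplacian_via_hessian(f):
    # Conventional route: materialize the full d x d Hessian with nested AD
    # (jax.hessian is forward-over-reverse), then take its trace.
    # Only the d diagonal entries are needed, but all d**2 are computed.
    hess = jax.hessian(f)
    return lambda x: jnp.trace(hess(x))

def nth_derivative(f, k):
    # k-fold nested reverse mode for a scalar function of one variable.
    for _ in range(k):
        f = jax.grad(f)
    return f

x = jnp.linspace(0.1, 1.0, 8)
print(laplacian_via_hessian(u)(x), -jnp.sum(jnp.sin(x)))

# Rough proxy for graph size: number of equations in the traced program
# for each derivative order k = 1..4 of a hypothetical 1-D function g.
g = lambda t: jnp.tanh(t) ** 2
sizes = [len(jax.make_jaxpr(nth_derivative(g, k))(0.3).jaxpr.eqns) for k in range(1, 5)]
print(sizes)
```

Comparing the printed sizes across k gives an informal view of how quickly repeated differentiation inflates the program, independent of the d**2 cost of the Hessian route.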

While there are other popular libraries for deep learning, our research required a more fundamental capability: Taylor mode automatic differentiation (AD). JAX was a game-changer for us. The key architectural distinction of JAX is its powerful function representation and transformation mechanism, implemented by tracing Python code and compiling it for high performance. This system is designed with such generality that it enables a versatile range of applications, from just-in-time compilation to computing standard derivatives. It is this underlying flexibility that allows for advanced operations not easily achievable in other frameworks. For us, the crucial application was support for Taylor mode AD, which we learned is a direct result of this architecture, making JAX perfectly equipped for our scientific work. Taylor mode AD propagates the Taylor coefficients of an input path forward through a function, so high-order derivatives are obtained in a single forward pass rather than through repeated, costly backpropagation. This enabled us to develop an algorithm, the Stochastic Taylor Derivative Estimator (STDE), to efficiently randomize and estimate any differential operator.
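The sketch below illustrates the basic idea with jax.experimental.jet, JAX's experimental Taylor-mode API, which takes primal inputs plus a list of input Taylor terms and returns the output's terms in one pass. It assumes jet's documented convention that the series terms are derivatives of the input path with respect to t (not divided by factorials), so pushing the path x0 + t through f returns the first through fourth derivatives of f at x0. The helper taylor_derivatives is hypothetical and is not the researchers' code.

```python
import jax.numpy as jnp
from jax.experimental.jet import jet

def taylor_derivatives(f, x0, order):
    # Push the input path x(t) = x0 + t through f in ONE forward pass.
    # In jet's convention the series holds derivatives of the path w.r.t. t,
    # so the first term is 1 and all higher terms are 0.
    series_in = [jnp.ones_like(x0)] + [jnp.zeros_like(x0)] * (order - 1)
    f0, series_out = jet(f, (x0,), (series_in,))
    # series_out[k-1] is the k-th derivative of f at x0.
    return [f0] + list(series_out)

x0 = jnp.float32(0.5)
derivs = taylor_derivatives(jnp.sin, x0, order=4)
print(derivs)  # approximately [sin, cos, -sin, -cos, sin] evaluated at 0.5
```

Because jet lives under jax.experimental, its interface may change between releases.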

Taylor mode for a second-order derivative: no exponential scaling.
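The same machinery extends to differential operators on multivariate functions. Here is a hedged sketch, assuming the Laplacian as the target operator: a second-order Taylor pass along the path x + t v returns the directional second derivative v^T H v without ever forming the Hessian H. Summing over the d standard basis vectors gives the exact Laplacian, while averaging over random Rademacher directions gives an unbiased Hutchinson-style estimate, a randomized estimator in the spirit of what the text describes. The toy field u, its weights w, and the helper names are hypothetical, and this is not the STDE implementation itself.

```python
import jax
import jax.numpy as jnp
from jax.experimental.jet import jet

D = 16
# Hypothetical fixed weights for the toy field below.
w = jnp.linspace(0.5, 1.5, D)

def u(x):
    # Hypothetical scalar field u: R^D -> R standing in for a neural-network
    # PDE solution. Its Hessian is diag(-sin(x)) + w w^T (non-diagonal).
    s = jnp.dot(w, x)
    return jnp.sum(jnp.sin(x)) + 0.5 * s * s

def second_directional(f, x, v):
    # One second-order Taylor-mode pass along the path x + t*v. In jet's
    # convention the input terms are (dx/dt, d^2x/dt^2) = (v, 0), and the
    # second output term is d^2/dt^2 f(x + t v) at t = 0, i.e. v^T H v.
    _, (_, d2) = jet(f, (x,), ((v, jnp.zeros_like(x)),))
    return d2

def laplacian_exact(f, x):
    # tr(H) as the sum of D directional second derivatives along basis vectors.
    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    return jnp.sum(jax.vmap(lambda e: second_directional(f, x, e))(basis))

def laplacian_stochastic(f, x, key, num_samples):
    # Hutchinson-style estimate: E[v^T H v] = tr(H) when E[v v^T] = I,
    # which holds for Rademacher (+1/-1) directions.
    vs = jax.random.rademacher(key, (num_samples, x.shape[0]), dtype=x.dtype)
    return jnp.mean(jax.vmap(lambda v: second_directional(f, x, v))(vs))

x = jnp.linspace(-1.0, 1.0, D)
key = jax.random.PRNGKey(0)
print(laplacian_exact(u, x))
print(laplacian_stochastic(u, x, key, num_samples=64))
print(jnp.trace(jax.hessian(u)(x)))  # reference from the full Hessian
```

Each sample costs one second-order Taylor pass, so the number of sampled directions, rather than the dimension D, controls the cost of the stochastic version, trading variance for compute.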