A roboticist's journey with JAX: Finding efficiency in optimal control and simulation

July 29, 2025
Srikanth Kilaru, Senior Product Manager, Google ML Frameworks
Max Muchen Sun, Robotics Researcher, Northwestern University

JAX is increasingly being adopted by developers for a wide range of computational tasks, expanding its role beyond its original focus on large-scale AI. While it remains a popular framework for developing LLMs and foundation models, JAX is also gaining momentum in diverse scientific domains. One area generating particular excitement is robotics, where JAX is enabling powerful capabilities in simulation, control, and the integration of learning-based methods.

Recently, I had the pleasure of speaking with Max Muchen Sun, a Robotics Ph.D. candidate and researcher at Northwestern University advised by Prof. Todd Murphey. His experience clearly illustrates how JAX can address critical challenges in robotics research, particularly around computational efficiency for complex control algorithms and the seamless combination of model-based and learning-based approaches. Max's journey from grappling with traditional tools to leveraging JAX's unique features like vmap and scan is a story many in the field will find relatable and inspiring.


Max's journey: In his own words

My interest in JAX started from a computational efficiency perspective. My mentor at the time, Ian Abraham (now a professor at Yale University), was using autograd and later led me to JAX. We were working on research using ergodic control, which is a control framework for coverage problems. Compared to standard control formulations, the computational complexity of ergodic control is inherently higher. To achieve real-time ergodic control, I initially used standard NumPy and leveraged vectorization and broadcasting features.

The first JAX feature that caught my eye was vmap. To me, it combines the vectorization and broadcasting mechanisms from standard NumPy and generalizes them further through function transformation and compositional abstraction, making it much easier for me to reason about and implement parallelization for the problems I'm solving.
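For example, here is a minimal sketch of the pattern (the cost function is hypothetical, not one from my papers): batching a per-state cost over many states while broadcasting a shared argument, with no manual reshaping.

```python
import jax
import jax.numpy as jnp

# A hypothetical per-state cost: squared distance to a goal.
def cost(x, goal):
    return jnp.sum((x - goal) ** 2)

goal = jnp.array([1.0, 0.0])
states = jnp.ones((128, 2))  # a batch of 128 two-dimensional states

# vmap maps `cost` over the leading axis of `states` while
# broadcasting the same `goal` into every call (in_axes=(0, None)).
batched_cost = jax.vmap(cost, in_axes=(0, None))
print(batched_cost(states, goal).shape)  # (128,)
```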

Then I learned about scan. It was less intuitive at first, but ultimately it became an efficient tool for simulating the trajectories of dynamic systems. In trajectory optimization, forward simulation of the system dynamics is a core operation that must be performed repeatedly and often becomes the computational bottleneck. With scan, trajectory simulation can be sped up by up to two orders of magnitude compared to standard NumPy-based implementations. The ease of use and substantial speed advantage fully pulled me into the JAX ecosystem.
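As an illustration (with hypothetical single-integrator dynamics standing in for a real robot model), a scan-based rollout looks like this:

```python
import jax
import jax.numpy as jnp

def step(x, u, dt=0.05):
    # Hypothetical single-integrator dynamics: x_next = x + u * dt.
    return x + u * dt

def rollout(x0, controls):
    # lax.scan threads the state through the control sequence and
    # returns the whole trajectory without a Python-level loop.
    def body(x, u):
        x_next = step(x, u)
        return x_next, x_next
    _, trajectory = jax.lax.scan(body, x0, controls)
    return trajectory

x0 = jnp.zeros(2)
controls = jnp.ones((100, 2))                # a 100-step control sequence
trajectory = jax.jit(rollout)(x0, controls)  # compiled forward simulation
print(trajectory.shape)  # (100, 2)
```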

Beyond computational efficiency, a central focus of my Ph.D. has been integrating model-based control with learning-based representations for autonomous exploration and multi-agent cooperation. I view model-based methods not as standalone solutions, but as structures that improve learning efficiency and robustness. JAX's composability made it ideal for merging model-based and learning-based pipelines.

In one of my latest papers, accepted to Robotics: Science and Systems (RSS), I combined flow matching from generative models with model-based optimal control for robot exploration, using flow gradients to map state-space flows to controls via an LQR-based update, a process analogous to backpropagation but on dynamic systems. I initially built the flow matching module in PyTorch and used C++ for LQR, but integrating the two was slow. After switching to JAX, I reimplemented the flow matching part using vmap and grad, and leveraged JAX-based tools like OTT (Optimal Transport Tools). The remaining piece was a JAX-native LQR pipeline.
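To give a flavor of that vmap-and-grad composition (a simplified sketch, not the flow matching model from the paper; the potential function here is hypothetical):

```python
import jax
import jax.numpy as jnp

# A hypothetical potential whose gradient defines a state-space flow.
def potential(x, goal):
    return jnp.sum((x - goal) ** 2)

# grad gives the flow at a single state; vmap evaluates it over a batch.
flow = jax.vmap(jax.grad(potential), in_axes=(0, None))

goal = jnp.array([1.0, 0.0])
states = jnp.zeros((64, 2))
print(flow(states, goal).shape)  # (64, 2), one flow vector per state
```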

In another recent paper presented at the IEEE International Conference on Robotics and Automation (ICRA), I integrated a model-based game-theoretic control pipeline into a generative trajectory model to learn multi-agent cooperation from demonstrations. Rather than using game-theoretic control as a full solution, which is often computationally expensive and requires manual loss specification, I embedded the game-theoretic computation as a structured layer inside a conditional variational autoencoder (CVAE). This improved data efficiency without sacrificing performance. Both components were implemented in JAX: the CVAE with Flax, and the control layer from scratch. JAX made the integration seamless, since grad could differentiate through the equilibrium directly. I also built a JAX-based iLQGames solver for generating synthetic data.
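The pattern of differentiating through an embedded solver can be sketched with a toy example (an unrolled gradient-descent inner solve on a quadratic objective, standing in for the actual game-theoretic layer from the paper):

```python
import jax
import jax.numpy as jnp

def equilibrium_layer(theta, n_steps=50, lr=0.1):
    # Hypothetical structured layer: minimize an inner objective
    # parameterized by theta via unrolled gradient descent.
    inner = lambda u: jnp.sum((u - theta) ** 2) + 0.1 * jnp.sum(u ** 2)
    def body(u, _):
        return u - lr * jax.grad(inner)(u), None
    u_star, _ = jax.lax.scan(body, jnp.zeros_like(theta), None, length=n_steps)
    return u_star

def outer_loss(theta, target):
    # The solver output feeds an outer loss, like any other layer.
    return jnp.sum((equilibrium_layer(theta) - target) ** 2)

theta, target = jnp.ones(3), jnp.full(3, 0.5)
# grad differentiates through the unrolled inner solve end to end.
print(jax.grad(outer_loss)(theta, target))
```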

After these projects, I realized I was reusing much of my JAX code for dynamic system calculations, especially LQR-based ones. Since I used LQR to integrate learning-based and model-based control in nonstandard ways, I packaged it into a standalone JAX-native solver, LQRax. It supports GPU acceleration, vmap, scan, and grad, enabling vectorized and differentiable LQR. I included examples like ergodic and game-theoretic control to highlight how model-based methods can complement learning.
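LQRax's own API is best learned from its documentation; purely to illustrate the vectorized-and-differentiable pattern, here is a generic scan-based Riccati recursion that composes with vmap (and would compose with grad just as well):

```python
import jax
import jax.numpy as jnp

def lqr_gains(A, B, Q, R, horizon):
    # Backward Riccati recursion for finite-horizon discrete LQR,
    # expressed as a scan (a generic sketch, not LQRax's API).
    def body(P, _):
        K = jnp.linalg.solve(R + B.T @ P @ B, B.T @ P @ A)
        P_prev = Q + A.T @ P @ (A - B @ K)
        return P_prev, K
    _, Ks = jax.lax.scan(body, Q, None, length=horizon)
    return Ks[::-1]  # scan runs backward in time; reorder forward

def rollout_cost(x0, A, B, Q, R, Ks):
    # Apply the time-varying gains forward in time, accumulating cost.
    def body(x, K):
        u = -K @ x
        stage_cost = x @ Q @ x + u @ R @ u
        return A @ x + B @ u, stage_cost
    _, costs = jax.lax.scan(body, x0, Ks)
    return jnp.sum(costs)

A = jnp.array([[1.0, 0.1], [0.0, 1.0]])
B = jnp.array([[0.0], [0.1]])
Q, R = jnp.eye(2), jnp.eye(1)
Ks = lqr_gains(A, B, Q, R, horizon=50)

# vmap evaluates a whole batch of initial conditions in one call.
x0s = jnp.stack([jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0])])
costs = jax.vmap(rollout_cost, in_axes=(0, None, None, None, None, None))(
    x0s, A, B, Q, R, Ks)
print(costs.shape)  # (2,)
```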

I use JAX on both CPUs and GPUs, often differently from the ML community. For instance, in the flow matching project, LQR runs faster on CPUs, while flow matching gradients are faster on GPUs. I haven't used TPUs since I typically run all the computations locally. A few years ago, I tried JAX on an NVIDIA Jetson, and installation was hard. I'm glad JAX is now supported on these embedded platforms, which is critical for robotics. I've been testing a crowd navigation algorithm on a quadruped robot using a Jetson with all computation done onboard, and I plan to integrate JAX into this project soon.
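Steering work to a specific device is straightforward, since computations follow their inputs. A minimal sketch of the idea (device_put is standard JAX; the workload here is a placeholder):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024))

# Pin the data (and the computation that consumes it) to the CPU...
cpu = jax.devices("cpu")[0]
y_cpu = jnp.sin(jax.device_put(x, cpu))

# ...or to the first accelerator, when one is present.
if jax.default_backend() != "cpu":
    y_dev = jnp.sin(jax.device_put(x, jax.devices()[0]))

print(y_cpu.devices())  # confirms where the result lives
```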

Looking ahead, I’ll continue using JAX for the same reasons I started. First, computational efficiency, especially GPU-based parallelization, is increasingly vital in robotics. Beyond training, it enables new model-based control possibilities like massive parallel simulations and real-time parameter updates, akin to embodied active learning. Second, JAX makes integrating model-based structures into learning pipelines intuitive—whether for dynamics, loss shaping, or differentiable solvers. That flexibility keeps me excited to push further.


Explore the JAX robotics ecosystem: From LQRax to MJX

Max's experience demonstrates several key advantages JAX offers the robotics community. The significant speedups achieved with vmap for parallel operations and scan for trajectory simulations are crucial for real-time control and complex planning. Furthermore, JAX's functional paradigm and automatic differentiation capabilities make it well-suited for integrating classical model-based techniques with modern learning-based components.

We believe stories like Max's are a sign of a rapidly growing and maturing ecosystem. His LQRax package is a great addition to a vibrant landscape of JAX-native robotics tools, and we encourage you to explore the project on GitHub and try it out for yourself. In the world of simulation, JAX provides a powerful foundation with massively parallel engines like Brax and the new MuJoCo XLA (MJX), which brings the popular MuJoCo physics engine directly to JAX. We're also seeing specialized tools from the community, such as the JaxSim library for control-focused multibody dynamics.
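For a taste of MJX, a minimal simulation step follows the pattern from the MJX documentation (this sketch assumes a recent mujoco install with MJX included):

```python
import jax
import mujoco
from mujoco import mjx

# A minimal model: a single sphere falling under gravity.
XML = """
<mujoco>
  <worldbody>
    <body>
      <freejoint/>
      <geom type="sphere" size="0.1"/>
    </body>
  </worldbody>
</mujoco>
"""

model = mujoco.MjModel.from_xml_string(XML)
mjx_model = mjx.put_model(model)    # move the model onto the accelerator
mjx_data = mjx.make_data(mjx_model)

step = jax.jit(mjx.step)            # a physics step as a compiled JAX function
mjx_data = step(mjx_model, mjx_data)
print(mjx_data.qpos)
```

Because mjx.step is an ordinary JAX function, it can be wrapped in vmap to simulate thousands of environments in parallel, which is the property that makes MJX attractive for the control workloads Max describes.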

In the domain of trajectory optimization, where pioneers like Trajax first paved the way, LQRax arrives as a welcome and modern library for researchers building the next generation of control systems. It perfectly embodies the JAX spirit by providing a powerful, composable tool that bridges the gap between model-based control and deep learning.

Sincere thanks to Max for sharing his insightful journey with us. We're excited to see how he and other researchers continue to leverage JAX to build the next generation of intelligent robotic systems. The JAX team at Google is committed to supporting and growing this vibrant ecosystem.