Cada vez más desarrolladores adoptan JAX para una amplia gama de tareas computacionales, expandiendo su papel más allá de su enfoque original en la IA a gran escala. Si bien sigue siendo un framework popular para el desarrollo de LLM y modelos de fundación, JAX también está ganando impulso en diversos dominios científicos. Un área que genera especial entusiasmo es la robótica, donde JAX está habilitando potentes capacidades en simulación, control e integración de métodos basados en el aprendizaje.
Recientemente, tuve el placer de hablar con Max Muchen Sun, un candidato al doctorado en Robótica e investigador de la Universidad de Northwestern asesorado por el profesor Todd Murphey. Su experiencia ilustra de manera clara cómo JAX puede abordar desafíos fundamentales de la investigación en robótica, particularmente en torno a la eficiencia computacional para algoritmos de control complejos y la combinación perfecta de enfoques basados en modelos y en el aprendizaje. El recorrido de Max, desde lidiar con herramientas tradicionales hasta aprovechar las características únicas de JAX, como vmap y scan, es una historia con la que muchos en el campo se sentirán identificados y se inspirarán.
Mi interés en JAX comenzó desde una perspectiva de eficiencia computacional. Mi mentor en ese momento, Ian Abraham (ahora profesor en la Universidad de Yale), estaba usando autograd y luego me guio hacia JAX. Estábamos trabajando en la investigación utilizando el control ergódico, que es un framework de control para problemas de cobertura. En comparación con las formulaciones de control estándares, la complejidad computacional del control ergódico es inherentemente mayor. Para lograr el control ergódico en tiempo real, inicialmente utilicé el estándar NumPy y aproveché las funciones de vectorización y transmisión.
La primera función de JAX que me llamó la atención fue vmap de JAX. Para mí, combina los mecanismos de vectorización y difusión del estándar NumPy, y los generaliza aún más a través de la transformación de funciones y la abstracción compositiva, lo que me facilita mucho razonar e implementar la paralelización para los problemas que estoy resolviendo.
Luego, aprendí sobre el escaneo. Al principio, era menos intuitivo, pero finalmente se convirtió en una herramienta eficiente para simular las trayectorias de los sistemas dinámicos. En la optimización de la trayectoria, la simulación hacia adelante de la dinámica del sistema es una operación central que debe realizarse repetidamente y, a menudo, se convierte en el cuello de botella computacional. Con escaneo, la simulación de trayectoria se puede acelerar hasta dos órdenes de magnitud en comparación con las implementaciones estándares basadas en NumPy. La facilidad de uso y la ventaja sustancial de velocidad me atrajeron completamente al ecosistema de JAX.
Por otro lado, un enfoque central de mi doctorado ha sido integrar el control basado en modelos con representaciones basadas en el aprendizaje para la exploración autónoma y la cooperación multiagente. Considero que los métodos basados en modelos no son soluciones independientes, sino estructuras para mejorar la eficiencia y la solidez del aprendizaje. La componibilidad de JAX lo hizo ideal para fusionar flujos de procesamiento basados en modelos y en aprendizaje.
En uno de mis últimos artículos aceptados en Robotics: Science and Systems (RSS), combiné la coincidencia de flujos de modelos generativos con un control óptimo basado en modelos para la exploración de robots, utilizando gradientes de flujo para asignar flujos de estado-espacio a los controles a través de una actualización basada en LQR, análoga a la retropropagación, pero en sistemas dinámicos. Inicialmente, construí el módulo de coincidencia de flujo en PyTorch y usé C++ para LQR, pero la integración fue lenta. Cambiando a JAX, volví a implementar la parte de coincidencia de flujo utilizando vmap y grad, y aproveché herramientas basadas en JAX como OTT (Optimal Transport Toolbox). La pieza restante era un flujo de procesamiento LQR nativo de JAX.
En otro artículo reciente presentado en la Conferencia Internacional sobre Robótica y Automatización (ICRA) del IEEE, integré un flujo de procesamiento de control teórico de juegos basado en modelos en un modelo de trayectoria generativa para aprender sobre la cooperación multiagente a partir de demostraciones. En lugar de utilizar el control de la teoría de juegos como una solución completa, a menudo computacionalmente costosa y que requiere una especificación de pérdida manual, incorporé el cálculo de la teoría de juegos como una capa estructurada dentro de un autoencoder variacional condicional (CVAE). Esto mejoró la eficiencia de los datos sin sacrificar el rendimiento. Ambos componentes se implementaron en JAX—CVAE con Flax y la capa de control desde cero. JAX lo hizo sin problemas: grad podía diferenciarse directamente a través del equilibrio. También compilé un solucionador iLQGames basado en JAX para generar datos sintéticos.
Después de estos proyectos, me di cuenta de que estaba reutilizando gran parte de mi código JAX para cálculos dinámicos del sistema, especialmente los basados en LQR. Como usé LQR para integrar el control basado en el aprendizaje y en modelos de manera no estándar, lo empaqueté en un solucionador nativo de JAX independiente: LQRax. Es compatible con la aceleración de GPU, vmap, scan y grad, lo que permite contar con un LQR vectorizado y diferenciable. Incluí ejemplos, como el control ergódico y la teoría de juegos, para resaltar cómo los métodos basados en modelos pueden complementar el aprendizaje.
Utilizo JAX tanto en CPU como en GPU, a menudo de manera diferente a la comunidad de AA. Por ejemplo, en el proyecto de coincidencia de flujo, LQR se ejecuta más rápido en las CPU, mientras que los gradientes de coincidencia de flujo son más rápidos en las GPU. No usé TPU, ya que normalmente ejecuto todos los cálculos de manera local. Hace unos años, probé JAX en un Nvidia Jetson y la instalación fue difícil. Me alegro de que JAX ahora sea compatible con estas plataformas integradas, lo cual es fundamental para la robótica. Estuve probando un algoritmo de navegación de multitudes en un robot cuadrúpedo usando un Jetson con todos los cálculos realizados a bordo y planeo integrar JAX en este proyecto pronto.
De cara al futuro, seguiré usando JAX por las mismas razones por las que empecé a usarlo. En primer lugar, la eficiencia computacional, especialmente la paralelización basada en GPU, es cada vez más vital en robótica. Más allá de la capacitación, permite nuevas posibilidades de control basadas en modelos, como simulaciones paralelas masivas y actualizaciones de parámetros en tiempo real, similares al aprendizaje activo incorporado. En segundo lugar, JAX hace que la integración de estructuras basadas en modelos en flujos de procesamiento de aprendizaje sea intuitiva, ya sea para la dinámica, la configuración de pérdidas o los solucionadores diferenciables. Esa flexibilidad me mantiene con entusiasmo para seguir adelante.
La experiencia de Max demuestra varias ventajas clave que JAX ofrece a la comunidad robótica. Los aceleramientos significativos logrados con vmap para operaciones paralelas y el escaneo para simulaciones de trayectoria son cruciales para el control en tiempo real y la planificación compleja. Además, el paradigma funcional y las funciones de diferenciación automática de JAX lo hacen adecuado para integrar técnicas clásicas basadas en modelos con componentes modernos basados en el aprendizaje.
Creemos que historias como la de Max son un signo de un ecosistema en rápido crecimiento y maduración. Su paquete de LQRax es una gran incorporación a un panorama vibrante de herramientas de robótica nativas de JAX, y te animamos a explorar el proyecto en GitHub y probarlo por ti mismo. En el mundo de la simulación, JAX proporciona una base poderosa con motores masivamente paralelos como Brax y el nuevo MuJoCo XLA (MJX), que lleva el popular y estándar motor de física MuJoCo directamente a JAX. También estamos viendo herramientas especializadas de la comunidad, como la biblioteca JaxSim para dinámicas multicuerpo centradas en el control.
En el ámbito de la optimización de la trayectoria, donde pioneros como Trajax allanaron el camino por primera vez, LQRax llega como una biblioteca bienvenida y moderna para los investigadores que están creando la próxima generación de sistemas de control. Combina perfectamente el espíritu de JAX con una herramienta potente y componible que cierra la brecha entre el control basado en modelos y el aprendizaje profundo.
Un sincero agradecimiento a Max por compartir su interesante recorrido con nosotros. Nos entusiasma ver cómo él y otros investigadores continúan aprovechando JAX para crear la próxima generación de sistemas robóticos inteligentes. El equipo de JAX en Google se compromete a apoyar y hacer crecer este vibrante ecosistema.