Si bien JAX es conocido como un framework popular para el desarrollo de modelos de IA a gran escala, también está logrando una rápida adopción en un conjunto más amplio de dominios científicos. Nos entusiasma ver su creciente uso en campos computacionalmente intensivos, como el aprendizaje automático basado en la física. JAX admite transformaciones componibles, un conjunto de funciones de orden superior. Por ejemplo, grad toma una función como entrada y devuelve otra función que calcula su gradiente y, lo que es más importante, puedes anidar (componer) estas transformaciones de manera libre. Este diseño es lo que hace que JAX sea especialmente elegante para derivadas de orden superior y otras transformaciones complejas.
Hace poco, tuve el placer de hablar con Zekun Shi y Min Lin, investigadores de la Universidad Nacional de Singapur y Sea AI Lab. Su experiencia ilustra claramente cómo JAX puede abordar los desafíos fundamentales de la investigación científica, particularmente en torno al colapso computacional que se enfrenta al resolver ecuaciones diferenciales parciales (EDP) complejas. Su recorrido desde lidiar con las limitaciones de los frameworks tradicionales hasta aprovechar la diferenciación automática del modo Taylor de JAX es una historia que resonará entre muchos investigadores.
Nuestro trabajo se centra en un área desafiante de la computación científica: el uso de redes neuronales para resolver PDE de alto orden. Las redes neuronales son aproximadores de funciones universales, lo que las convierte en una alternativa prometedora a los métodos tradicionales, como los elementos finitos. Sin embargo, un obstáculo importante para resolver las PDE con una red neuronal es que se deben evaluar sus derivadas de alto orden, a veces hasta del cuarto orden o más, incluidas las derivadas parciales mixtas.
Los frameworks de aprendizaje profundo estándar, que están optimizados principalmente para modelos de capacitación a través de la propagación hacia atrás, no son adecuados para esta tarea, ya que el cálculo de derivadas de alto orden es increíblemente costoso. El costo de aplicar la propagación hacia atrás (modo hacia atrás AD) repetidamente para derivadas de orden superior escala exponencialmente con el orden de la derivada (k) y polinomialmente con la dimensión de dominio (d). Esta "maldición de la dimensionalidad" y el escalado exponencial en el orden de las derivadas hacen prácticamente imposible abordar problemas grandes, complejos y del mundo real.
Si bien hay otras bibliotecas populares para el aprendizaje profundo, nuestra investigación requirió una capacidad más fundamental: la diferenciación automática (AD) en modo Taylor. JAX representó un cambio radical para nosotros.
La distinción arquitectónica clave de JAX es su potente mecanismo de representación y transformación de funciones, implementado mediante el rastreo de código Python y compilado para un alto rendimiento. Este sistema está diseñado con tal generalidad que permite una gama versátil de aplicaciones, desde la compilación justo a tiempo hasta la computación de derivadas estándar. Es esta flexibilidad subyacente la que permite realizar operaciones avanzadas que no se pueden lograr fácilmente en otros frameworks. Para nosotros, la aplicación crucial fue el soporte para la diferenciación automática del modo Taylor, que aprendimos que es un resultado directo y poderoso de esta arquitectura única, lo que hace que JAX esté perfectamente equipado para nuestro trabajo científico. La diferenciación automática del modo Taylor permite el cálculo eficiente de derivadas de alto orden al impulsar la expansión de la serie de Taylor de una función y calcular eficientemente las derivadas de alto orden en una sola pasada en lugar de a través de la repetida y costosa propagación hacia atrás. Esto nos permitió desarrollar un algoritmo, el Estimador de Derivadas de Taylor Estocástico (STDE), para aleatorizar y estimar eficientemente cualquier operador diferencial.
En nuestro reciente artículo, "Estimador de Derivadas de Taylor Estocástico: amortización eficiente para operadores diferenciales arbitrarios", que recibió el premio al mejor artículo en NeurIPS 2024, demostramos cómo se podría utilizar este enfoque. Demostramos que, al usar el modo Taylor de JAX, podríamos crear un algoritmo para extraer estas derivadas parciales de alto orden de manera eficiente. La idea principal era aprovechar la diferenciación automática del modo Taylor para calcular de manera eficiente las contracciones de los tensores de derivadas de orden superior que aparecen en las PDE. Al construir vectores tangentes aleatorios especiales (o "chorros"), podríamos obtener una estimación imparcial de un operador diferencial arbitrariamente complejo en un único y eficiente paso hacia adelante.
Los resultados fueron espectaculares. Con nuestro método STDE en JAX, logramos una aceleración de >1000 veces y una reducción de memoria de >30 veces en comparación con los métodos de referencia. Esta ganancia de eficiencia nos permitió resolver una PDE de 1 millón de dimensiones en solo 8 minutos en una sola GPU NVIDIA A100, una tarea que antes era casi imposible.
Esto simplemente no habría sido posible con un framework orientado solo a las cargas de trabajo estándar de aprendizaje automático. Otros marcos están altamente optimizados para la propagación hacia atrás, pero se centran menos en la representación de gráficos computacionales de extremo a extremo que JAX. Eso ayuda a JAX a brillar con operaciones como la transposición de una función o la implementación de la diferenciación del modo Taylor de orden superior.
Más allá del modo Taylor, el diseño modular y el soporte de JAX para tipos de datos generales y transformaciones de funciones son fundamentales para nuestra investigación. En otro trabajo, "Diferenciación funcional automática en JAX", incluso generalizamos JAX para manejar vectores de dimensión infinita (funciones en el espacio de Hilbert) describiéndolos como una matriz personalizada y registrándolos con JAX. Esto nos permite reutilizar la maquinaria existente para calcular derivadas variacionales para funcionales y operadores, una funcionalidad que está completamente fuera del alcance de otros frameworks.
Por estas razones, adoptamos JAX no solo para este proyecto, sino también para una amplia gama de investigaciones en áreas como la química cuántica. Su diseño fundamental como un sistema general, extensible y simbólicamente poderoso lo convierte en la opción ideal para ampliar las fronteras de la computación científica. Creemos que es importante que la comunidad científica conozca estas capacidades.
La experiencia de Zekun y Min demuestra el poder y la flexibilidad de JAX. Su método STDE desarrollado utilizando JAX es una contribución significativa al campo del aprendizaje automático basado en la física, lo que permite abordar una clase de problemas que antes eran imposibles. Te recomendamos leer su galardonado artículo para profundizar en los detalles técnicos y explorar la biblioteca STDE de código abierto en GitHub, que es una fantástica adición al panorama de las herramientas científicas nativas de JAX.
Historias como esta resaltan una tendencia creciente: JAX es mucho más que una herramienta para el aprendizaje profundo; es una biblioteca fundamental para la programación diferenciable que está impulsando una nueva generación de descubrimientos científicos. El equipo de JAX en Google se compromete a apoyar y hacer crecer este vibrante ecosistema, y eso comienza con escucharte directamente a ti.
Nos entusiasma asociarnos contigo para crear la próxima generación de herramientas informáticas científicas. Comunícate con el equipo para compartir tu trabajo o contarnos qué necesitas de JAX.
Un sincero agradecimiento a Zekun y Min por compartir su perspicaz recorrido con nosotros.
Referencia
Shi, Z.; Hu, Z.; Lin, M.; y Kawaguchi, K. (2025). Estimador de Derivadas de Taylor Estocástico: amortización eficiente para operadores diferenciales arbitrarios. Avances en sistemas de procesamiento de información neuronal, 37.
Lin, M. (2023). Diferenciación funcional automática en JAX. Duodécima Conferencia Internacional sobre Representaciones de Aprendizaje.