Modelo fundacional de Marin de Stanford: el primer modelo totalmente abierto desarrollado con JAX

16 DE JULIO DE 2025
Srikanth Kilaru Senior Product Manager Google ML Frameworks
David Hall Research Engineering Lead Stanford HAI

Un elemento emocionante de la era actual de la IA es cómo los potentes modelos fundacionales se comparten abiertamente y ayudan a acelerar la innovación para todos. Este progreso nos inspira a preguntarnos: "¿qué sigue en cuanto a la apertura?". El proyecto Marin ve una oportunidad para ampliar la definición de "abierto" a fin de abarcar todo el proceso científico detrás de un modelo.

El proyecto Marin de Stanford CRFM (Centro de Investigación sobre Modelos Fundacionales) está diseñado como un " laboratorio abierto", donde el objetivo no es solo compartir el modelo, sino hacer que el recorrido completo sea accesible, incluido el código, el conjunto de datos, las metodologías de datos, los experimentos, los hiperparámetros y los registros de entrenamiento. Este nivel de transparencia complementa el ecosistema existente al proporcionar un recurso único y totalmente reproducible que permite a los investigadores analizar, desarrollar y confiar en los modelos que se están desarrollando. El proyecto Marin de Stanford busca fomentar un futuro más transparente y accesible para la investigación de modelos fundacionales.


El espectro de la apertura del modelo de IA

The Spectrum of AI Model Openness

Los primeros lanzamientos de este laboratorio abierto son los modelos Marin-8B-Base y Marin-8B-Instruct. De acuerdo con los principios del proyecto, los modelos, los datos, el código y el tokenizador se publican bajo la licencia permisiva Apache 2.0. Este compromiso con la reproducibilidad completa es un problema de ingeniería formidable, que requiere control sobre todas las fuentes de variación en un sistema distribuido masivamente. El éxito del proyecto depende de una pila de tecnología que pueda ofrecer esta garantía de reproducibilidad a escala y maximizar la eficiencia para entrenar un modelo fundacional con una relación precio/rendimiento líder.


Desafíos centrales de la creación de modelos fundacionales abiertos

Para que el proyecto Marin tuviera éxito en la creación de modelos fundacionales verdaderamente abiertos, escalables y reproducibles, el equipo de CRFM tuvo que resolver varios desafíos de ingeniería. Eligió a JAX como base porque sus principios de diseño proporcionaban soluciones directas a estos problemas, y creó un nuevo framework, Levanter (ver más abajo), para aprovechar el poder de JAX. Estos son algunos ejemplos de desafíos y sus soluciones:


Lograr la velocidad máxima en un solo acelerador

Problema: el ciclo de entrenamiento central se ejecuta miles de millones de veces, por lo que la sobrecarga de un lenguaje interpretado como Python crea un cuello de botella de rendimiento masivo. Si las operaciones se envían paso a paso, el bucle también puede incurrir en tráfico de memoria excesivo y sobrecarga, especialmente en hardware como TPU, donde el rendimiento depende de la ejecución eficiente de las operaciones fusionadas.

Nuestra solución:

  • Para eliminar la sobrecarga del intérprete, Levanter encapsula todo el paso de capacitación de múltiples etapas (pase hacia adelante, pérdida, retropropagación y actualización) en una sola función y utiliza el decorador @ jax.jit. El compilador XLA de JAX transforma todo este proceso en un núcleo de código máquina único y altamente optimizado, fusionando las operaciones para maximizar la utilización del hardware a escala.

  • Para evitar cálculos redundantes, utilizamos parajax.value_and_grad para calcular tanto la pérdida como sus gradientes en una sola pasada. JAX también facilita el uso de técnicas avanzadas como el punto de control de gradiente, lo que ahorra memoria y nos permite usar lotes más grandes casi sin gastos generales.

  • Levanter también utiliza el potente núcleo Splash Attention basado en Pallas de JAX, una implementación altamente optimizada de Dot Product Attention, una de las operaciones más críticas en el corazón de casi todos los modelos de lenguaje grandes.


Gestionar la complejidad del paralelismo a gran escala

Problema: entrenar modelos de vanguardia requiere escalar a miles de chips aceleradores. Administrar manualmente cómo se particionan el modelo y los datos y cómo se comunican los dispositivos es inmensamente complejo, y el código se vuelve rápidamente difícil de leer, depurar y adaptar.

Nuestra solución:

  • El decorador @ jax.jitde JAX también admite sin problemas la paralelización de un solo programa y datos múltiples (SPMD) que automatiza la fragmentación y comunicación de datos subyacentes. El compilador XLA programa automáticamente la comunicación entre aceleradores para minimizar el tiempo de espera en la red y maximizar el tiempo dedicado al cálculo.

  • Para hacer que el poder de JIT fuera aún más fácil y seguro de usar, Levanter desarrolló Haliax, una biblioteca para tensores con nombre. Al referirse a ejes tensores con nombres legibles por humanos (como "incrustar" o "lote") en lugar de índices posicionales, el código se vuelve autodocumentado y sólido.

  • Esta abstracción nos permite definir y modificar estrategias sofisticadas de fragmentación, como el paralelismo de datos completamente fragmentados (FSDP) y el paralelismo de tensores, simplemente cambiando unas pocas líneas en un archivo de configuración, sin tocar nunca el código del modelo.


Creación y administración de clústeres informáticos resilientes y rentables

Problema: la capacitación a gran escala requiere un acceso flexible a clústeres informáticos masivos. Dependemos en gran medida de las instancias de TPU interrumpibles para administrar los costos, lo que significa que necesitamos una manera de combinar fácilmente muchas porciones de TPU más pequeñas y dispares en un grupo lógico y ser resistentes a las interrupciones frecuentes.

Nuestra solución:

  • Aprovechamos Google Cloud TPU Multislice, una tecnología que permite a una tarea de entrenamiento usar múltiples cortes de TPU como si fueran un sistema grande. Esto hace que sea fácil coser muchas rebanadas de TPU interrumpibles y pequeñas en un solo y potente grupo de cómputo para el entrenamiento.

  • Levanter utiliza Ray para orquestar este proceso, escalando sin problemas el número de segmentos de TPU hacia arriba o hacia abajo durante una tarea de entrenamiento y, lo que es más importante, garantizando que la tarea siga siendo resistente si se interrumpe un segmento.

  • Gracias a JAX y XLA, Levanter y Marin también pudieron obtener resultados similares de alto rendimiento en las GPU.


Fomentar la confianza científica con una reproducibilidad perfecta

Problema: un objetivo central del proyecto Marin es permitir la ciencia verificable. Esto requiere lograr resultados reproducibles, incluso cuando el entrenamiento se pausa, se reinicia o se mueve entre diferentes configuraciones de la máquina, un obstáculo técnico significativo.

Nuestra solución:

  • Esta elección se validó durante el entrenamiento de Marin-8B, que implicó la migración entre diferentes segmentos de TPU y tipos de hardware mientras se mantenía con éxito la reproducibilidad bit a bit en todas las interrupciones temporales.

  • Levanter también incluye un sólido sistema de carga de datos basado en la biblioteca Tensorstore de Google. El almacén de datos de Levanter ofrece acceso determinista y aleatorio a cualquier lote de datos de entrenamiento, independientemente de los reinicios laborales o los cambios en la fuente de datos, lo cual es fundamental para respaldar estrategias de capacitación avanzadas, como la capacitación intermedia. El determinismo de JAX y el almacén de datos de Levanter también facilitan a los investigadores de interpretabilidad comprender cómo los datos específicos afectan el modelo durante la capacitación.


Crear un framework cohesivo

Problema: si bien JAX proporciona un motor potente, ningún framework de alto nivel existente cumplía con nuestros estrictos requisitos combinados de legibilidad, escalabilidad masiva y determinismo bit a bit. Necesitábamos un sistema completo y basado en convenciones para orquestar todo el proceso de entrenamiento.

Nuestra solución:

  • Creamos Levanter, un framework nativo de JAX, desde cero para que fuera el sistema que necesitábamos: determinista a nivel de bits, escalable con estrategias de distribución avanzadas y resistente.

  • Podríamos hacer esto porque JAX es más que una biblioteca; es un "meta-framework" para crear nuevas herramientas. Nos basamos en su soporte maduro y de alto rendimiento para TPU y su integración perfecta de abstracciones de alto nivel (JIT) con control de bajo nivel (Pallas).

  • Este enfoque es común en la comunidad JAX, que ha producido un ecosistema vibrante de bibliotecas como Flax, Equinox, Orbax y Optax que trabajan juntas, lo que permite a equipos como el nuestro crear soluciones potentes.


Una mirada bajo el capó: el viaje del Marin-8B

Los principios, herramientas y bibliotecas discutidos anteriormente se implementaron y pusieron en práctica durante la capacitación de Marin-8B. La arquitectura del modelo es un transformador estilo Llama.


Marin-8B-Base: cómo modelar la arquitectura de un vistazo

Marin 8B-Base model architecture at a glance

En lugar de una carrera estática y monolítica, el entrenamiento de Marin-8B fue un recorrido adaptativo, internamente denominado el proceso "Tootsie". Esta representación honesta de un flujo de trabajo de investigación del mundo real se detalla en el público. El proceso abarcó más de 12 billones de tokens e involucró múltiples fases que se adaptaron a nuevos datos, técnicas e incluso diferentes configuraciones de hardware, migrando entre configuraciones de TPU de múltiples segmentos a gran escala (pods de 2x v5e-256 a 1x v4-2048) durante la ejecución. El equipo ajustó continuamente la mezcla de datos, incorporando fuentes de mayor calidad y ajustando hiperparámetros como la tasa de aprendizaje y el tamaño del lote para optimizar el rendimiento. Esta realidad "desordenada" es una poderosa herramienta educativa, y la capacidad de la pila de JAX y Levanter para manejar estos cambios significativos mientras mantiene la reproducibilidad bit a bit es una poderosa demostración de su solidez.


Únete a la comunidad de Marin

El proyecto Marin es una invitación abierta a participar en el futuro del desarrollo de modelos fundacionales y contribuir al ecosistema JAX. El recorrido de Marín representa la respuesta a nuestra pregunta: "¿qué sigue en cuanto a la apertura?". Esta iniciativa por crear un "laboratorio abierto" es posible gracias a las capacidades técnicas del ecosistema JAX. Su rendimiento, portabilidad y diseño fundacional para la reproducibilidad son los ingredientes clave que nos permiten hacer accesible el 'viaje completo' de la investigación.

Al compartir todo, desde metodologías de datos hasta registros de entrenamiento, nuestro objetivo es proporcionar un recurso totalmente reproducible, que permita a los investigadores analizar, desarrollar y confiar profundamente en el trabajo. Creemos que este es un paso colaborativo hacia un futuro más transparente para la IA. Te invitamos a unirte a nosotros en este "laboratorio abierto": para usar Marin, para contribuir a la investigación y para ayudar a crear la próxima ola de modelos fundacionales innovadores y confiables.

El recurso central del proyecto es el sitio web oficial, marin.community. A partir de ahí, puedes encontrar los modelos lanzados en Hugging Face, explorar el "laboratorio abierto" en GitHub, leer la documentación de Marin y sumergirte en el framework de entrenamiento de Levanter. También puedes probar la conducción de Marin en un Colab con un simple ejemplo de inferencia.

Además, se están produciendo debates activos en el canal de Discord, donde puedes interactuar directamente con otros desarrolladores. Para aquellos que son nuevos en el ecosistema, la documentación oficial de JAX proporciona excelentes recursos, incluida una guía de inicio rápido.