Un elemento emocionante de la era actual de la IA es cómo los potentes modelos fundacionales se comparten abiertamente y ayudan a acelerar la innovación para todos. Este progreso nos inspira a preguntarnos: "¿qué sigue en cuanto a la apertura?". El proyecto Marin ve una oportunidad para ampliar la definición de "abierto" a fin de abarcar todo el proceso científico detrás de un modelo.
El proyecto Marin de Stanford CRFM (Centro de Investigación sobre Modelos Fundacionales) está diseñado como un " laboratorio abierto", donde el objetivo no es solo compartir el modelo, sino hacer que el recorrido completo sea accesible, incluido el código, el conjunto de datos, las metodologías de datos, los experimentos, los hiperparámetros y los registros de entrenamiento. Este nivel de transparencia complementa el ecosistema existente al proporcionar un recurso único y totalmente reproducible que permite a los investigadores analizar, desarrollar y confiar en los modelos que se están desarrollando. El proyecto Marin de Stanford busca fomentar un futuro más transparente y accesible para la investigación de modelos fundacionales.
Los primeros lanzamientos de este laboratorio abierto son los modelos Marin-8B-Base y Marin-8B-Instruct. De acuerdo con los principios del proyecto, los modelos, los datos, el código y el tokenizador se publican bajo la licencia permisiva Apache 2.0. Este compromiso con la reproducibilidad completa es un problema de ingeniería formidable, que requiere control sobre todas las fuentes de variación en un sistema distribuido masivamente. El éxito del proyecto depende de una pila de tecnología que pueda ofrecer esta garantía de reproducibilidad a escala y maximizar la eficiencia para entrenar un modelo fundacional con una relación precio/rendimiento líder.
Para que el proyecto Marin tuviera éxito en la creación de modelos fundacionales verdaderamente abiertos, escalables y reproducibles, el equipo de CRFM tuvo que resolver varios desafíos de ingeniería. Eligió a JAX como base porque sus principios de diseño proporcionaban soluciones directas a estos problemas, y creó un nuevo framework, Levanter (ver más abajo), para aprovechar el poder de JAX. Estos son algunos ejemplos de desafíos y sus soluciones:
Problema: el ciclo de entrenamiento central se ejecuta miles de millones de veces, por lo que la sobrecarga de un lenguaje interpretado como Python crea un cuello de botella de rendimiento masivo. Si las operaciones se envían paso a paso, el bucle también puede incurrir en tráfico de memoria excesivo y sobrecarga, especialmente en hardware como TPU, donde el rendimiento depende de la ejecución eficiente de las operaciones fusionadas.
Nuestra solución:
@
jax.jit. El compilador XLA de JAX transforma todo este proceso en un núcleo de código máquina único y altamente optimizado, fusionando las operaciones para maximizar la utilización del hardware a escala.parajax.value_and_grad
para calcular tanto la pérdida como sus gradientes en una sola pasada. JAX también facilita el uso de técnicas avanzadas como el punto de control de gradiente, lo que ahorra memoria y nos permite usar lotes más grandes casi sin gastos generales.Pallas
de JAX, una implementación altamente optimizada de Dot Product Attention, una de las operaciones más críticas en el corazón de casi todos los modelos de lenguaje grandes.Problema: entrenar modelos de vanguardia requiere escalar a miles de chips aceleradores. Administrar manualmente cómo se particionan el modelo y los datos y cómo se comunican los dispositivos es inmensamente complejo, y el código se vuelve rápidamente difícil de leer, depurar y adaptar.
Nuestra solución:
@ jax.jitde
JAX también admite sin problemas la paralelización de un solo programa y datos múltiples (SPMD) que automatiza la fragmentación y comunicación de datos subyacentes. El compilador XLA programa automáticamente la comunicación entre aceleradores para minimizar el tiempo de espera en la red y maximizar el tiempo dedicado al cálculo.JIT
fuera aún más fácil y seguro de usar, Levanter desarrolló Haliax, una biblioteca para tensores con nombre. Al referirse a ejes tensores con nombres legibles por humanos (como "incrustar" o "lote") en lugar de índices posicionales, el código se vuelve autodocumentado y sólido.Problema: la capacitación a gran escala requiere un acceso flexible a clústeres informáticos masivos. Dependemos en gran medida de las instancias de TPU interrumpibles para administrar los costos, lo que significa que necesitamos una manera de combinar fácilmente muchas porciones de TPU más pequeñas y dispares en un grupo lógico y ser resistentes a las interrupciones frecuentes.
Nuestra solución:
Problema: un objetivo central del proyecto Marin es permitir la ciencia verificable. Esto requiere lograr resultados reproducibles, incluso cuando el entrenamiento se pausa, se reinicia o se mueve entre diferentes configuraciones de la máquina, un obstáculo técnico significativo.
Nuestra solución:
Problema: si bien JAX proporciona un motor potente, ningún framework de alto nivel existente cumplía con nuestros estrictos requisitos combinados de legibilidad, escalabilidad masiva y determinismo bit a bit. Necesitábamos un sistema completo y basado en convenciones para orquestar todo el proceso de entrenamiento.
Nuestra solución:
JIT
) con control de bajo nivel (Pallas
).Los principios, herramientas y bibliotecas discutidos anteriormente se implementaron y pusieron en práctica durante la capacitación de Marin-8B. La arquitectura del modelo es un transformador estilo Llama.
En lugar de una carrera estática y monolítica, el entrenamiento de Marin-8B fue un recorrido adaptativo, internamente denominado el proceso "Tootsie". Esta representación honesta de un flujo de trabajo de investigación del mundo real se detalla en el público. El proceso abarcó más de 12 billones de tokens e involucró múltiples fases que se adaptaron a nuevos datos, técnicas e incluso diferentes configuraciones de hardware, migrando entre configuraciones de TPU de múltiples segmentos a gran escala (pods de 2x v5e-256 a 1x v4-2048) durante la ejecución. El equipo ajustó continuamente la mezcla de datos, incorporando fuentes de mayor calidad y ajustando hiperparámetros como la tasa de aprendizaje y el tamaño del lote para optimizar el rendimiento. Esta realidad "desordenada" es una poderosa herramienta educativa, y la capacidad de la pila de JAX y Levanter para manejar estos cambios significativos mientras mantiene la reproducibilidad bit a bit es una poderosa demostración de su solidez.
El proyecto Marin es una invitación abierta a participar en el futuro del desarrollo de modelos fundacionales y contribuir al ecosistema JAX. El recorrido de Marín representa la respuesta a nuestra pregunta: "¿qué sigue en cuanto a la apertura?". Esta iniciativa por crear un "laboratorio abierto" es posible gracias a las capacidades técnicas del ecosistema JAX. Su rendimiento, portabilidad y diseño fundacional para la reproducibilidad son los ingredientes clave que nos permiten hacer accesible el 'viaje completo' de la investigación.
Al compartir todo, desde metodologías de datos hasta registros de entrenamiento, nuestro objetivo es proporcionar un recurso totalmente reproducible, que permita a los investigadores analizar, desarrollar y confiar profundamente en el trabajo. Creemos que este es un paso colaborativo hacia un futuro más transparente para la IA. Te invitamos a unirte a nosotros en este "laboratorio abierto": para usar Marin, para contribuir a la investigación y para ayudar a crear la próxima ola de modelos fundacionales innovadores y confiables.
El recurso central del proyecto es el sitio web oficial, marin.community. A partir de ahí, puedes encontrar los modelos lanzados en Hugging Face, explorar el "laboratorio abierto" en GitHub, leer la documentación de Marin y sumergirte en el framework de entrenamiento de Levanter. También puedes probar la conducción de Marin en un Colab con un simple ejemplo de inferencia.
Además, se están produciendo debates activos en el canal de Discord, donde puedes interactuar directamente con otros desarrolladores. Para aquellos que son nuevos en el ecosistema, la documentación oficial de JAX proporciona excelentes recursos, incluida una guía de inicio rápido.