Para los desarrolladores e investigadores que trabajan en el ecosistema de JAX, el camino de un modelo preentrenado a un modelo de lenguaje grande (LLM) totalmente alineado y listo para la producción ahora es mucho más simple.
Hoy, nos complace presentar Tunix, una nueva biblioteca de código abierto nativa de JAX, creada específicamente para el entrenamiento posterior de LLM. Tunix cierra una brecha crítica, ya que proporciona un conjunto de herramientas integral y fácil de usar para alinear modelos a escala.
Tunix, que se diseñó para optimizar el rendimiento en TPU, especialmente cuando se combina con MaxText, ofrece lo siguiente:
Esta versión inicial incluye APIs modulares y fáciles de usar para los flujos de trabajo de posentrenamiento más comunes, que se integran a la perfección con el ecosistema de JAX:
PeftTrainer
se adapta a cualquier modelo y admite métodos de ajuste de peso completo y de ajuste de parámetros populares como LoRA y QLoRA (a través de nuestra integración con la biblioteca qwix).DPOTrainer
agiliza la alineación mediante la implementación de la optimización de preferencias directas (DPO). Esta técnica poderosa utiliza un conjunto simple de datos de respuestas preferidas y rechazadas, con lo que evita la necesidad de entrenar y administrar un modelo de recompensa separado.PPOLearner
: Proporciona el método estándar actor-crítico para RLHF mediante la implementación de la optimización de políticas proximales (PPO). Estas características son fundamentales para los modelos de entrenamiento en tareas complejas y secuenciales, especialmente para flujos de trabajo ágiles emergentes que involucran el uso de herramientas.GRPOLearner
: Ofrece un algoritmo de RL muy eficiente y sin críticos. Implementa la optimización de políticas relativas de grupo (GRPO), que normaliza las recompensas en un grupo de respuestas generadas para guiar el modelo sin la complejidad y el costo de un modelo crítico separado.Optimización de políticas de secuencia de grupo (GSPO-token)
: Ofrece una variante del algoritmo GRPO que proporciona una mejor flexibilidad para ajustar el cálculo de ventajas a nivel del token y puede mejorar la estabilidad para el entrenamiento de RL de varios turnos.DistillationTrainer
permite la compresión de modelos entrenando a un modelo “estudiante” más pequeño y eficiente para replicar los resultados de un modelo “maestro” más grande. Esta es una técnica fundamental para implementar modelos de alto rendimiento en entornos de producción con una latencia ajustada o restricciones de costos. Tunix proporciona los siguientes algoritmos de destilación listos para usar:Creamos varios notebooks de Python para ayudar a los usuarios a incorporar Tunix. Los resultados que se incluyen a continuación demuestran la efectividad de la implementación de GRPO de Tunix. En el punto de referencia de razonamiento matemático GSM8K, ajustar el modelo Gemma 2 2B-IT con Tunix dio como resultado una mejora relativa de ~ 12% en la precisión de la respuesta pass@1. Observamos avances prometedores en todas las métricas, mostrando la capacidad de la biblioteca para alinear de manera rápida y eficaz el comportamiento del modelo.
Para dar cuenta de la naturaleza estocástica de la generación de texto, evaluamos el rendimiento utilizando pass@1 (búsqueda voraz) y pass@5 (muestreo con diversidad) para medir la corrección en uno o cinco intentos. Nuestra evaluación se centró en tres métricas clave:
Para la validación, nuestra exactitud inicial de pass@1 de ~52% se alinea estrechamente con el ~51% informado por el arnés de evaluación de LM de Eleuther para el modelo base, lo que confirma la validez de nuestra configuración. Si bien la exactitud absoluta es sensible al formato de las indicaciones (p. ej., si se usa <START_answer> o <a answer>), el aumento significativo del rendimiento posterior al entrenamiento sigue siendo constante en diferentes entornos.
Link to Youtube Video (visible only when JS is disabled)
Tunix ya está potenciando la próxima ola de desarrollo de ML, desde laboratorios académicos líderes hasta nuevas empresas de IA. Estamos desarrollando Tunix en colaboración con nuestros socios para resolver los desafíos del mundo real en el campo de la alineación de modelos y la IA agéntica. Esto es lo que opinan nuestros colaboradores:
“Mi investigación se centra en el aprendizaje centrado en datos, que implica la preparación de datos de alta calidad para mejorar el rendimiento del modelo, especialmente en la fase posterior al entrenamiento de modelos de lenguaje grandes (LLM). Un desafío clave es iterar rápidamente en las muestras de datos para identificar cuáles son útiles y cuáles no. Para ello, Tunix es la biblioteca perfecta. Su diseño de tipo “white-box” le da a mi equipo un control total sobre el ciclo de entrenamiento, lo que nos permite modificar y adaptar fácilmente el código a nuestras necesidades de investigación específicas. Esta personalización es una importante ventaja sobre otros marcos y es fundamental para acelerar nuestro análisis de datos iterativo”.
— Hongfu Liu, profesor asistente de Ciencias de la Computación, Universidad Brandeis; catedrático de área senior para NeurIPS; catedrático de área para ICLR
“Uno de los principales cuellos de botella en el aprendizaje por refuerzo posterior al entrenamiento es la escasez de entornos con recompensas verificables. Los videojuegos proporcionan un entorno perfecto de varios turnos para resolver este problema y Tunix es el marco de trabajo ideal para esta investigación. Nos permite compilar directamente sobre JAX, aprovechando las TPU y la carga en paralelo sencilla. En comparación con otras alternativas, Tunix es una biblioteca ligera con una base de código limpia y manejable. Ofrece una personalización de alto nivel de modelos e hiperparámetros, sin las capas de abstracción excesivas de otros marcos de trabajo. Este enfoque simplificado es crucial para nuestro trabajo y descubrimos que la curva de aprendizaje es suave, ya que no es necesario tener experiencia en JAX para lograr eficacia”.
— Hao Zhang, profesor asistente, UC San Diego, cocreador de vLLM, Chatbot Arena (LMSys) e inventor de servicio desagregado
Precur AI es una startup que crea un compilador de agentes que transforma los flujos de trabajo en segundo plano en agentes basados en código confiables y eficientes. Hanjun Dai, cofundador y director de tecnología, comenta lo siguiente:
“Nuestra empresa se centra en los agentes en segundo plano que funcionan las 24 horas del día, los 7 días de la semana, sin supervisión. Un objetivo clave es la solidez de los agentes, por lo que posentrenamos los "núcleos de agentes": los modelos optimizados para tareas de largo plazo pero repetitivas. La amplitud de diseño de Tunix, que abarca SFT, RL y destilación, nos permite mantener unificada toda nuestra pila de desarrollo de agentes. La integración nativa con el ecosistema de JAX y TPU es una ventaja importante. Debido a la facilidad de personalización con Flax para el desarrollo y Qwix para el servicio, es un marco de trabajo limpio y potente que se adapta muy fácilmente a nuestro flujo de trabajo”.
— Hanjun Dai, cofundador y director de tecnología de PreCur AI
Estamos compilando Tunix con código abierto y te invitamos a unirte a nuestra comunidad, probarlo y colaborar.
Nos complace compartir Tunix con la comunidad de JAX y esperamos ver tus creaciones.