Para desenvolvedores e pesquisadores do ecossistema do JAX, o caminho de um modelo pré-treinado para um LLM totalmente alinhado e pronto para produção acaba de ficar mais simples.
Hoje, temos o prazer de apresentar o Tunix, uma nova biblioteca de código aberto nativa do JAX criada especificamente para o pós-treinamento de LLMs. O Tunix preenche uma lacuna crítica ao fornecer um kit de ferramentas abrangente e otimizado para desenvolvedores para o alinhamento de modelos em escala.
Criado para melhorar o desempenho em TPUs, especialmente quando combinado com o MaxText, o Tunix oferece:
Esta versão inicial traz APIs modulares e fáceis de usar para os fluxos de trabalho de pós-treinamento mais comuns, com integração total ao ecossistema do JAX:
PeftTrainer
é independente de modelo e dá suporte tanto ao ajuste de peso total quanto a métodos populares de ajuste com eficiência de parâmetros, como LoRA e QLoRA (por meio de nossa integração à biblioteca qwix).DPOTrainer
simplifica o alinhamento ao implementar a otimização de preferências diretas (DPO). Essa técnica poderosa usa um conjunto de dados simples de respostas preferidas e rejeitadas, eliminando a necessidade de treinar e gerenciar um modelo de recompensa separado.PPOLearner
: fornece o método ator-crítico definitivo para RLHF ao implementar a otimização de políticas proximais (PPO). Isso é essencial para treinar modelos em tarefas complexas e sequenciais, especialmente para fluxos de trabalho agênticos emergentes que envolvem o uso de ferramentas.GRPOLearner
: oferece um algoritmo de RL altamente eficiente e livre de críticas. Ele implementa a otimização de políticas relativas de grupos (GRPO), que normaliza as recompensas em um grupo de respostas geradas para orientar o modelo sem a complexidade e o custo de um modelo crítico separado.Otimização de políticas de sequência de grupos (GSPO-token)
: oferece uma variante do algoritmo GRPO que fornece mais flexibilidade para ajustar a computação de vantagens no nível do token e pode aumentar a estabilidade para o treinamento RL com várias voltas.DistillationTrainer
habilita a compactação do modelo treinando um modelo "aluno" menor e mais eficiente para replicar os resultados de um modelo "professor" maior. Essa é uma técnica fundamental para implantar modelos de alto desempenho em ambientes de produção com restrições de latência ou custo. O Tunix fornece os seguintes algoritmos de destilação prontos para uso:Criamos vários notebooks em python para ajudar os usuários a embarcar no Tunix. Os resultados abaixo demonstram a eficácia da implementação do GRPO do Tunix. No comparativo de mercado de raciocínio matemático GSM8K, o ajuste do modelo Gemma 2 2B-IT com o Tunix resultou em uma melhoria relativa de aproximadamente 12% na acurácia da resposta pass@1. Observamos ganhos promissores em todas as métricas, o que demonstra a capacidade da biblioteca de alinhar o comportamento do modelo de forma rápida e eficaz.
Para levar em conta a natureza estocástica da geração de texto, avaliamos o desempenho usando pass@1 (busca gulosa) e pass@5 (amostragem com diversidade) para medir a exatidão em uma ou cinco tentativas. Nossa avaliação se concentrou em três métricas principais:
Para validação, nossa acurácia de linha de base de pass@1 de cerca de 52% se aproxima muito dos cerca de 51% reportados pelo LM Eval Harness do Eleuther para o modelo de base, confirmando a validade de nossa configuração. Embora a acurácia absoluta dependa da formatação do prompt (por exemplo, uso de <start_answer> versus <answer>), o aumento significativo do desempenho pós-treinamento permanece consistente em diferentes configurações.
Link to Youtube Video (visible only when JS is disabled)
Dos mais importantes laboratórios acadêmicos às startups de IA, o Tunix já está capacitando a próxima onda de desenvolvimento de ML. Estamos desenvolvendo o Tunix em colaboração com nossos parceiros para resolver desafios de alinhamento de modelos e IA agêntica do mundo real. Veja o que eles têm a dizer:
"Minha pesquisa se concentra no aprendizado centrado em dados, que envolve a preparação de dados de alta qualidade para melhorar o desempenho do modelo, especialmente na fase de pós-treinamento de modelos de linguagem grandes (LLMs). Um dos principais desafios é iterar rapidamente as amostras de dados para identificar quais são úteis e quais não são. Para isso, o Tunix é a biblioteca perfeita. Seu design de "caixa branca" dá à minha equipe controle total sobre o ciclo de treinamento e nos permite modificar e adaptar facilmente o código para nossas necessidades específicas de pesquisa. Essa personalização é uma vantagem significativa em relação a outros frameworks e é crucial para acelerar nossa análise iterativa de dados."
— Hongfu Liu, professor assistente de ciência da computação da Brandeis University; diretor sênior de área para NeurIPS; diretor de área para ICLR
"Um dos grandes gargalos do aprendizado por reforço pós-treinamento é a escassez de ambientes com recompensas verificáveis. Os jogos fornecem um ambiente de várias voltas perfeito para resolver isso, e o Tunix é o framework ideal para essa pesquisa. Ele nos permite criar diretamente no JAX, aproveitando as TPUs e a facilidade de carregamento em paralelo. Em comparação com outras alternativas, o Tunix é uma biblioteca leve com uma base de código clara e gerenciável. Ele oferece personalização de alto nível de modelos e hiperparâmetros sem as camadas de abstração excessivas de outros frameworks. Essa abordagem simplificada é crucial para nosso trabalho, e a curva de aprendizado é suave, porque você não precisa ser um especialista em JAX para ter eficácia."
— Hao Zhang, professor assistente, UC San Diego, cocriador do vLLM, Chatbot Arena (LMSys), e inventor da inferência desagregada
A Precur AI é uma startup que está criando um compilador de agentes que transforma fluxos de trabalho em segundo plano em agentes orientados por código confiáveis e eficientes. Hanjun Dai, cofundador e CTO, diz:
"Nossa empresa se concentra em agentes em segundo plano que funcionam 24 horas por dia, 7 dias por semana, sem supervisão. Uma meta fundamental é a robustez dos agentes, por isso fazemos o pós-treinamento de "kernels de agentes" — os modelos otimizados para tarefas de horizonte longo, mas repetitivas. A amplitude do design do Tunix, que abrange SFT, RL e destilação, nos permite manter toda a nossa pilha de desenvolvimento de agentes unificada. Sua integração nativa com o ecossistema do JAX e de TPUs é uma vantagem significativa. A facilidade de personalização com o Flax para o desenvolvimento e o Qwix para a inferência quantizada faz dele um framework claro e poderoso que se encaixa muito facilmente em nosso fluxo de trabalho."
— Hanjun Dai, cofundador e CTO, PreCur AI
Estamos criando o Tunix abertamente e convidamos você a participar de nossa comunidade, experimentar e contribuir.
É um prazer poder compartilhar o Tunix com a comunidade JAX. Mal podemos esperar para ver o que você vai criar.