Modelo de base Marin da Stanford: o primeiro modelo totalmente aberto desenvolvido usando JAX

16 DE JULHO DE 2025
Srikanth Kilaru Senior Product Manager Google ML Frameworks
David Hall Research Engineering Lead Stanford HAI

Um elemento empolgante da atual era da IA é como modelos de base poderosos estão sendo compartilhados abertamente e ajudando a acelerar a inovação para todos. Esse progresso nos inspira a questionar: "Qual é o próximo passo da abertura?". O projeto Marin vê uma oportunidade de expandir a definição de "aberto" para abranger todo o processo científico por trás de um modelo.

O projeto Marin do CRFM (Center for Research on Foundation Models) da Stanford foi elaborado para ser um "laboratório aberto" no qual a meta não é apenas compartilhar o modelo, mas tornar toda a jornada acessível, incluindo código, conjunto de dados, metodologias de dados, experimentos, hiperparâmetros e registros de treinamento. Esse nível de transparência complementa o ecossistema existente ao fornecer um recurso exclusivo e totalmente reproduzível que capacita os pesquisadores a examinar, desenvolver e confiar nos modelos que estão sendo desenvolvidos. O projeto Marin da Stanford visa promover um futuro mais transparente e acessível para a pesquisa de modelos de base.


O espectro de abertura de modelos de IA

The Spectrum of AI Model Openness

Os primeiros lançamentos desse laboratório aberto são os modelos Marin-8B-Base e Marin-8B-Instruct. De acordo com os princípios do projeto, os modelos, os dados, o código e o tokenizador são todos lançados sob a licença permissiva do Apache 2.0. Esse compromisso com a total reprodutibilidade é um problema de engenharia formidável que exige controle sobre todas as origens de variação em um sistema massivamente distribuído. O sucesso do projeto depende de uma pilha de tecnologia capaz de oferecer essa garantia de reprodutibilidade em escala e de maximizar a eficiência para treinar um modelo de base com uma relação ímpar de preço/desempenho.


Principais desafios da criação de modelos de base abertos

Para que o projeto Marin tivesse sucesso na criação de modelos de base verdadeiramente abertos, escalonáveis e reproduzíveis, a equipe do CRFM teve que solucionar vários desafios de engenharia. A equipe escolheu o JAX como base porque seus princípios de design forneciam soluções diretas para esses problemas e criaram um novo framework, o Levanter (veja abaixo) para aproveitar o poder de JAX. Veja a seguir alguns exemplos de desafios e suas soluções.


Atingir a velocidade máxima em um único acelerador

Problema: o ciclo de treinamento principal é executado bilhões de vezes, portanto, o overhead de uma linguagem interpretada, como o Python, cria um gargalo de desempenho enorme. Se as operações forem despachadas passo a passo, o ciclo também poderá gerar um overhead excessivo de tráfego e memória, principalmente em hardware como as TPUs, nas quais a capacidade de processamento depende da execução eficiente de operações mescladas.

Nossa solução:

  • Para eliminar o overhead do interpretador, o Levanter encapsula toda a etapa de treinamento em várias fases (passagem progressiva, perda, retropropagação e atualização) em uma única função e usa o decorator @jax.jit. O compilador XLA da JAX transforma todo esse processo em um único kernel de código de máquina altamente otimizado, mesclando operações para maximizar a utilização do hardware em escala.

  • Para evitar cálculos redundantes, usamos jax.value_and_grad para calcular a perda e seus gradientes em uma única passagem. O JAX também facilita o uso de técnicas avançadas, como o gradient checkpointing, economizando memória e nos permitindo usar lotes maiores com quase nenhum overhead.

  • O Levanter também usa o poderoso kernel Splash Attention baseado em Pallas do JAX, uma implementação altamente otimizada do Dot Product Attention, uma das operações mais críticas no centro de quase todos os grandes modelos de linguagem.


Gerenciamento da complexidade do paralelismo em larga escala

Problema: o treinamento de modelos de última geração requer o escalonamento para milhares de chips aceleradores. Gerenciar manualmente a forma como o modelo e os dados são particionados e como os dispositivos se comunicam é imensamente complexo, e o código se torna rapidamente difícil de ler, depurar e adaptar.

Nossa solução:

  • O decorador @jax.jit do JAX também dá suporte total ao carregamento em paralelo de Single-Program, Multiple-Data (SPMD), que automatiza a fragmentação e a comunicação de dados subjacentes. O compilador XLA agenda automaticamente a comunicação entre os aceleradores para minimizar o tempo gasto em espera na rede e maximizar o tempo gasto em computação.

  • Para tornar o poder do jit ainda mais simples e seguro de usar, o Levanter desenvolveu o Haliax, uma biblioteca para tensores nomeados. Ao se referir a eixos de tensores com nomes legíveis por humanos (como "embed" ou "batch") em vez de índices de posição, o código se torna autodocumentado e robusto.

  • Essa abstração nos permite definir e modificar estratégias sofisticadas de fragmentação, como Fully Sharded Data Parallelism (FSDP) e paralelismo de tensores, simplesmente alterando algumas linhas em um arquivo de configuração, sem jamais tocar no código do modelo.


Criação e gerenciamento de clusters de computação resilientes e econômicos

Problema: o treinamento em larga escala requer acesso flexível a clusters de computação massivos. Dependemos muito de instâncias de TPU preemptivas para gerenciar custos, o que significa que precisamos de uma maneira simples de combinar muitas frações de TPU menores e díspares em um cluster lógico e ser resilientes a interrupções frequentes.

Nossa solução:

  • Utilizamos o Google Cloud TPU Multislice, uma tecnologia que permite que um job de treinamento use várias frações de TPU como se fossem um grande sistema. Isso facilita a junção de muitas frações pequenas de TPU preemptivas em um único e poderoso cluster de computação para treinamento.

  • O Levanter usa o Ray para orquestrar esse processo, escalonando de maneira otimizada o número de frações de TPU para cima ou para baixo durante um job de treinamento e, o que é mais importante, garantindo que o job permaneça resiliente se qualquer fração sofrer uma interrupção forçada.

  • Graças ao JAX e ao XLA, o Levanter e o Marin também conseguiram obter resultados de alto desempenho semelhantes em GPUs.


Promoção da confiança científica com total reprodutibilidade

Problema: uma meta central do projeto Marin é permitir a ciência verificável. Isso requer atingir resultados reproduzíveis, mesmo quando o treinamento é pausado, reiniciado ou movimentado entre diferentes configurações de máquina, o que é um desafio técnico significativo.

Nossa solução:

  • Esse era um requisito fundamental para o projeto do Levanter. Escolhemos o JAX especificamente por suas sólidas garantias de reprodutibilidade, como o uso padrão de geradores de números pseudoaleatórios (PRNGs, na sigla em inglês) determinísticos.

  • Essa escolha foi validada durante o treinamento do Marin-8B, que envolveu a migração entre diferentes frações de TPU e tipos de hardware, mantendo a reprodutibilidade bit por bit entre as preempções.

  • O Levanter também inclui um sistema robusto de carregamento de dados criado com base na biblioteca Tensorstore do Google. O repositório de dados do Levanter oferece acesso determinístico e aleatório a qualquer lote de dados de treinamento, independentemente de reinicializações de jobs ou mudanças de origens de dados, que são essenciais para o suporte a estratégias avançadas de treinamento, como o treinamento intermediário. O determinismo do JAX e o repositório de dados do Levanter também facilitam para os pesquisadores de interpretabilidade a compreensão de como dados específicos impactam o modelo durante o treinamento.


Criação de um framework coeso

Problema: embora o JAX forneça um mecanismo poderoso, nenhum framework de nível alto existente atendeu a nossos requisitos rigorosos e combinados de legibilidade, escalonabilidade massiva e determinismo em termos de bits. Precisávamos de um sistema completo e assertivo para orquestrar todo o processo de treinamento.

Nossa solução:

  • Criamos o Levanter, um framework nativo do JAX, a partir do zero para ser o sistema de que precisávamos: determinístico em termos de bits, escalonável com estratégias avançadas de distribuição e resiliente.

  • Pudemos fazer isso porque o JAX é mais do que apenas uma biblioteca; é um "metaframework" para a criação de novas ferramentas. A criação foi feita com base em seu suporte maduro e de alto desempenho a TPUs e sua integração total de abstrações de nível alto (jit) com controle de nível baixo (Pallas).

  • Essa abordagem é comum na comunidade JAX, que produziu um ecossistema vibrante de bibliotecas como Flax, Equinox, Orbax e Optax, que trabalham juntas, permitindo que equipes como a nossa criem soluções poderosas.


Nos bastidores: a jornada do Marin-8B

Os princípios, as ferramentas e as bibliotecas discutidos acima foram implementados e colocados em prática durante a execução do treinamento do Marin-8B. A arquitetura do modelo é um transformador no estilo Llama.


Marin-8B-Base: visão geral da arquitetura do modelo

Marin 8B-Base model architecture at a glance

Em vez de ser uma execução estática e monolítica, o treinamento do Marin-8B foi uma jornada adaptativa, apelidada internamente de processo "Tootsie". Essa representação honesta de um fluxo de trabalho de pesquisa do mundo real é detalhada em público. O processo abrangeu mais de 12 trilhões de tokens e envolveu várias fases que se adaptaram a novos dados, técnicas e até mesmo diferentes configurações de hardware, migrando entre configurações de TPU com várias frações e em larga escala (pods 2x v5e-256 a 1x v4-2048) no fluxo intermediário. A equipe refinou continuamente a combinação de dados, incorporando origens de qualidade mais alta e hiperparâmetros ajustados, como taxa de aprendizado e tamanho do lote, para otimizar o desempenho. Essa realidade "confusa" é uma ferramenta educacional poderosa, e a capacidade da pilha do JAX e do Levanter de lidar com essas mudanças significativas, mantendo a reprodutibilidade bit a bit, é uma demonstração poderosa de sua robustez.


Participe da comunidade do Marin

O projeto Marin é um convite aberto para participar do futuro do desenvolvimento de modelos de base e para contribuir com o ecossistema do JAX. A jornada do Marin representa a resposta à nossa pergunta: "Qual é o próximo passo da abertura?". Esse esforço para criar um "laboratório aberto" é possível graças às capacidades técnicas do ecossistema do JAX. Seu desempenho, sua portabilidade e seu design fundamental para a reprodutibilidade são os principais ingredientes que nos permitem tornar acessível a "jornada completa" da pesquisa.

Ao compartilhar tudo, desde metodologias de dados até registros de treinamento, pretendemos fornecer um recurso totalmente reproduzível, que capacite os pesquisadores a examinar, desenvolver e confiar profundamente no trabalho. Acreditamos que este é um passo colaborativo em direção a um futuro mais transparente para a IA. Convidamos você a se juntar a nós nesse "laboratório aberto", a usar o Marin, a contribuir com a pesquisa e a ajudar a construir a próxima onda de modelos de base inovadores e confiáveis.

O recurso central do projeto é o site oficial, marin.community. A partir daí, você pode encontrar os modelos lançados na Hugging Face, explorar o "laboratório aberto" no GitHub, ler a documentação do Marin e mergulhar no framework de treinamento do Levanter. Você também pode testar o Marin em um colab com um exemplo simples de inferência.

E discussões ativas estão ocorrendo no canal do Discord, onde você pode interagir diretamente com outros desenvolvedores. Para aqueles que são novos no ecossistema, a documentação oficial do JAX fornece excelentes recursos, incluindo um guia de início rápido.