Um elemento empolgante da atual era da IA é como modelos de base poderosos estão sendo compartilhados abertamente e ajudando a acelerar a inovação para todos. Esse progresso nos inspira a questionar: "Qual é o próximo passo da abertura?". O projeto Marin vê uma oportunidade de expandir a definição de "aberto" para abranger todo o processo científico por trás de um modelo.
O projeto Marin do CRFM (Center for Research on Foundation Models) da Stanford foi elaborado para ser um "laboratório aberto" no qual a meta não é apenas compartilhar o modelo, mas tornar toda a jornada acessível, incluindo código, conjunto de dados, metodologias de dados, experimentos, hiperparâmetros e registros de treinamento. Esse nível de transparência complementa o ecossistema existente ao fornecer um recurso exclusivo e totalmente reproduzível que capacita os pesquisadores a examinar, desenvolver e confiar nos modelos que estão sendo desenvolvidos. O projeto Marin da Stanford visa promover um futuro mais transparente e acessível para a pesquisa de modelos de base.
Os primeiros lançamentos desse laboratório aberto são os modelos Marin-8B-Base e Marin-8B-Instruct. De acordo com os princípios do projeto, os modelos, os dados, o código e o tokenizador são todos lançados sob a licença permissiva do Apache 2.0. Esse compromisso com a total reprodutibilidade é um problema de engenharia formidável que exige controle sobre todas as origens de variação em um sistema massivamente distribuído. O sucesso do projeto depende de uma pilha de tecnologia capaz de oferecer essa garantia de reprodutibilidade em escala e de maximizar a eficiência para treinar um modelo de base com uma relação ímpar de preço/desempenho.
Para que o projeto Marin tivesse sucesso na criação de modelos de base verdadeiramente abertos, escalonáveis e reproduzíveis, a equipe do CRFM teve que solucionar vários desafios de engenharia. A equipe escolheu o JAX como base porque seus princípios de design forneciam soluções diretas para esses problemas e criaram um novo framework, o Levanter (veja abaixo) para aproveitar o poder de JAX. Veja a seguir alguns exemplos de desafios e suas soluções.
Problema: o ciclo de treinamento principal é executado bilhões de vezes, portanto, o overhead de uma linguagem interpretada, como o Python, cria um gargalo de desempenho enorme. Se as operações forem despachadas passo a passo, o ciclo também poderá gerar um overhead excessivo de tráfego e memória, principalmente em hardware como as TPUs, nas quais a capacidade de processamento depende da execução eficiente de operações mescladas.
Nossa solução:
@jax.jit
. O compilador XLA da JAX transforma todo esse processo em um único kernel de código de máquina altamente otimizado, mesclando operações para maximizar a utilização do hardware em escala.jax.value_and_grad
para calcular a perda e seus gradientes em uma única passagem. O JAX também facilita o uso de técnicas avançadas, como o gradient checkpointing, economizando memória e nos permitindo usar lotes maiores com quase nenhum overhead.Pallas
do JAX, uma implementação altamente otimizada do Dot Product Attention, uma das operações mais críticas no centro de quase todos os grandes modelos de linguagem.Problema: o treinamento de modelos de última geração requer o escalonamento para milhares de chips aceleradores. Gerenciar manualmente a forma como o modelo e os dados são particionados e como os dispositivos se comunicam é imensamente complexo, e o código se torna rapidamente difícil de ler, depurar e adaptar.
Nossa solução:
@jax.jit
do JAX também dá suporte total ao carregamento em paralelo de Single-Program, Multiple-Data (SPMD), que automatiza a fragmentação e a comunicação de dados subjacentes. O compilador XLA agenda automaticamente a comunicação entre os aceleradores para minimizar o tempo gasto em espera na rede e maximizar o tempo gasto em computação.jit
ainda mais simples e seguro de usar, o Levanter desenvolveu o Haliax, uma biblioteca para tensores nomeados. Ao se referir a eixos de tensores com nomes legíveis por humanos (como "embed" ou "batch") em vez de índices de posição, o código se torna autodocumentado e robusto.Problema: o treinamento em larga escala requer acesso flexível a clusters de computação massivos. Dependemos muito de instâncias de TPU preemptivas para gerenciar custos, o que significa que precisamos de uma maneira simples de combinar muitas frações de TPU menores e díspares em um cluster lógico e ser resilientes a interrupções frequentes.
Nossa solução:
Problema: uma meta central do projeto Marin é permitir a ciência verificável. Isso requer atingir resultados reproduzíveis, mesmo quando o treinamento é pausado, reiniciado ou movimentado entre diferentes configurações de máquina, o que é um desafio técnico significativo.
Nossa solução:
Problema: embora o JAX forneça um mecanismo poderoso, nenhum framework de nível alto existente atendeu a nossos requisitos rigorosos e combinados de legibilidade, escalonabilidade massiva e determinismo em termos de bits. Precisávamos de um sistema completo e assertivo para orquestrar todo o processo de treinamento.
Nossa solução:
jit
) com controle de nível baixo (Pallas
).Os princípios, as ferramentas e as bibliotecas discutidos acima foram implementados e colocados em prática durante a execução do treinamento do Marin-8B. A arquitetura do modelo é um transformador no estilo Llama.
Em vez de ser uma execução estática e monolítica, o treinamento do Marin-8B foi uma jornada adaptativa, apelidada internamente de processo "Tootsie". Essa representação honesta de um fluxo de trabalho de pesquisa do mundo real é detalhada em público. O processo abrangeu mais de 12 trilhões de tokens e envolveu várias fases que se adaptaram a novos dados, técnicas e até mesmo diferentes configurações de hardware, migrando entre configurações de TPU com várias frações e em larga escala (pods 2x v5e-256 a 1x v4-2048) no fluxo intermediário. A equipe refinou continuamente a combinação de dados, incorporando origens de qualidade mais alta e hiperparâmetros ajustados, como taxa de aprendizado e tamanho do lote, para otimizar o desempenho. Essa realidade "confusa" é uma ferramenta educacional poderosa, e a capacidade da pilha do JAX e do Levanter de lidar com essas mudanças significativas, mantendo a reprodutibilidade bit a bit, é uma demonstração poderosa de sua robustez.
O projeto Marin é um convite aberto para participar do futuro do desenvolvimento de modelos de base e para contribuir com o ecossistema do JAX. A jornada do Marin representa a resposta à nossa pergunta: "Qual é o próximo passo da abertura?". Esse esforço para criar um "laboratório aberto" é possível graças às capacidades técnicas do ecossistema do JAX. Seu desempenho, sua portabilidade e seu design fundamental para a reprodutibilidade são os principais ingredientes que nos permitem tornar acessível a "jornada completa" da pesquisa.
Ao compartilhar tudo, desde metodologias de dados até registros de treinamento, pretendemos fornecer um recurso totalmente reproduzível, que capacite os pesquisadores a examinar, desenvolver e confiar profundamente no trabalho. Acreditamos que este é um passo colaborativo em direção a um futuro mais transparente para a IA. Convidamos você a se juntar a nós nesse "laboratório aberto", a usar o Marin, a contribuir com a pesquisa e a ajudar a construir a próxima onda de modelos de base inovadores e confiáveis.
O recurso central do projeto é o site oficial, marin.community. A partir daí, você pode encontrar os modelos lançados na Hugging Face, explorar o "laboratório aberto" no GitHub, ler a documentação do Marin e mergulhar no framework de treinamento do Levanter. Você também pode testar o Marin em um colab com um exemplo simples de inferência.
E discussões ativas estão ocorrendo no canal do Discord, onde você pode interagir diretamente com outros desenvolvedores. Para aqueles que são novos no ecossistema, a documentação oficial do JAX fornece excelentes recursos, incluindo um guia de início rápido.