Se você já se perguntou o que é preciso para criar um modelo de linguagem a partir do zero com o JAX, esta postagem é para você. Recentemente, realizei um workshop sobre esse tema, no Cloud Next 2025, e recebi ótimos feedbacks, por isso pensei em escrever um guia para todos que não puderam participar.
Neste artigo e no exemplo de código, você vai criar e pré-treinar um modelo GPT-2, mostrando como o JAX facilita o aproveitamento do poder das TPUs do Google. Você pode executar todo o projeto gratuitamente usando as TPUs no Colab ou no Kaggle e pode encontrar o notebook completo aqui.
Este é um tutorial prático, portanto, vou considerar que você tenha familiaridade com conceitos gerais de aprendizado de máquina. Se o JAX é uma novidade para você, o guia do desenvolvedor do PyTorch sobre os fundamentos do JAX é um ótimo lugar para começar.
Primeiro, vamos dar uma olhada rápida nas ferramentas que serão utilizadas.
Antes de começarmos a criar o modelo, vamos falar rapidamente sobre o ecossistema do JAX. O ecossistema do JAX adota uma abordagem modular, com o núcleo do JAX fornecendo os principais recursos de processamento numérico e uma vasta coleção de bibliotecas criadas com base nele para atender a diferentes necessidades específicas de aplicativos, por exemplo, o Flax para a criação de redes neurais, o Orbax para a definição de pontos de verificação e persistência de modelos e o Optax para otimização (usaremos todos os três neste artigo). Transformações de funções integradas, como autograd, vetorização e compilação JIT, além de um sólido desempenho e APIs fáceis de usar, tornam o JAX perfeito para treinar modelos de linguagem grandes.
Há algum tempo, a OpenAI lançou o código e os pesos do modelo GPT2, que são boas referências, e há muitos esforços da comunidade, como o nanoGPT, para replicar o modelo. Este é um diagrama geral da arquitetura de modelo para o GPT2:
Vamos usar o NNX (a nova interface do Flax) para criar o modelo GPT2. Para resumir, vamos nos concentrar no bloco do transformador, que é a chave para os modelos de linguagem grandes modernos. O bloco do transformador captura dependências de longo alcance em qualquer sequência e cria uma compreensão contextual abrangente. Um bloco de transformador do GPT2 consiste em duas camadas LayerNorm, uma camada atenção multicabeças (MHA, na sigla em inglês), duas camadas dropout, duas camadas de projeção linear e duas conexões residuais. Então, primeiro definimos essas camadas na função __init__
da classe TransformerBlock
:
class TransformerBlock(nnx.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
ff_dim: int,
dropout_rate: float,
rngs: nnx.Rngs,
):
self.layer_norm1 = nnx.LayerNorm(
epsilon=1e-6, num_features=embed_dim, rngs=rngs
)
self.mha = nnx.MultiHeadAttention(
num_heads=num_heads, in_features=embed_dim, rngs=rngs
)
self.dropout1 = nnx.Dropout(rate=dropout_rate)
self.layer_norm2 = nnx.LayerNorm(
epsilon=1e-6, num_features=embed_dim, rngs=rngs
)
self.linear1 = nnx.Linear(
in_features=embed_dim, out_features=ff_dim, rngs=rngs
)
self.linear2 = nnx.Linear(
in_features=ff_dim, out_features=embed_dim, rngs=rngs
)
self.dropout2 = nnx.Dropout(rate=dropout_rate)
Em seguida, montamos essas camadas na função __call__
:
class TransformerBlock(nnx.Module):
def __call__(self, inputs, training: bool = False):
input_shape = inputs.shape
bs, seq_len, emb_sz = input_shape
attention_output = self.mha(
inputs_q=self.layer_norm1(inputs),
mask=causal_attention_mask(seq_len),
decode=False,
)
x = inputs + self.dropout1(
attention_output, deterministic=not training
)
# MLP
mlp_output = self.linear1(self.layer_norm2(x))
mlp_output = nnx.gelu(mlp_output)
mlp_output = self.linear2(mlp_output)
mlp_output = self.dropout2(
mlp_output, deterministic=not training
)
return x + mlp_output
Esse código deverá parecer muito familiar se você já tiver usado qualquer outro framework de ML, como o PyTorch ou o TensorFlow, para treinar um modelo de linguagem. Mas uma das coisas de que eu realmente gosto no JAX é que ele tem a incrível capacidade de executar o código em paralelo automaticamente via SPMD (Single Program Multiple Data), o que é necessário porque vamos executar o código em vários aceleradores (vários núcleos de TPU). Vejamos como isso funciona.
Para executar o SPMD, primeiro precisamos ter certeza de que estamos usando TPUs. Escolha o ambiente de execução de TPU se estiver usando o Colab ou o Kaggle (você também pode usar uma VM do Cloud TPU).
import jax
jax.devices()
# Free-tier Colab offers TPU v2:
# [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
# TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
# TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
# TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
# TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
# TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
# TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
# TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
O Colab e o Kaggle oferecem uma TPU v2 ou v3, que tem oito núcleos de TPU separados. Veja como é uma bandeja do TPU v3:
Para treinar o modelo GPT2 de forma eficiente, executaremos todos os núcleos da TPU juntos via SPMD e utilizaremos o paralelismo de dados no JAX. Para isso, definimos uma malha de hardware:
mesh = jax.make_mesh((8, 1), ('batch', 'model'))
Pense na malha como uma matriz 2D de aceleradores. Neste caso, definimos dois eixos para a malha: o eixo batch
e o eixo model
. Então, no total, temos 8 x 1, que são oito núcleos. Esses eixos determinam como particionamos os dados e os parâmetros do modelo. Podemos mudar os eixos mais tarde se quisermos experimentar outros esquemas de paralelismo.
Agora, alteramos a função __init__
dizendo ao JAX como gostaríamos de particionar os parâmetros do modelo usando o eixo "model". Isso é feito pela adição de nnx.with_partitioning
durante a inicialização os tensores de peso: fragmentamos os tensores de peso 1D, como tensores de escala/viés do LayerNorm, diretamente no eixo "model" enquanto os tensores de peso 2D, como os tensores de kernel Linear e MHA, têm a 2ª dimensão fragmentada no eixo model
.
class TransformerBlock(nnx.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
ff_dim: int,
dropout_rate: float,
rngs: nnx.Rngs,
):
self.layer_norm1 = nnx.LayerNorm(
epsilon=1e-6, num_features=embed_dim,rngs=rngs, rngs=rngs,
scale_init=nnx.with_partitioning(
nnx.initializers.ones_init(),
("model"),
),
bias_init=nnx.with_partitioning(
nnx.initializers.zeros_init(),
("model"),
),
)
self.mha = nnx.MultiHeadAttention(
num_heads=num_heads, in_features=embed_dim,
kernel_init=nnx.with_partitioning(
nnx.initializers.xavier_uniform(),
(None, "model"),
),
bias_init=nnx.with_partitioning(
nnx.initializers.zeros_init(),
("model"),
),
)
# Other layers in the block are omitted for brevity
Precisamos particionar outras camadas como essa para que possamos ativar o paralelismo do tensor model para todo o modelo GPT2. Mesmo que não usemos o paralelismo do tensor model neste tutorial, ainda é uma boa ideia implementar isso porque o tamanho do modelo pode aumentar e talvez precisemos particionar nossos parâmetros do modelo no futuro. Ter isso já implementado nos permite alterar apenas uma linha de código e executar imediatamente os modelos maiores. Por exemplo,
mesh = jax.make_mesh((4, 2), ('batch', 'model'))
Em seguida, definimos as funções loss_fn
e train_step
de forma semelhante ao blog anterior. A função train_step()
calcula os gradientes da função de perda de entropia cruzada e atualiza os pesos por meio do otimizador, e ela será chamada em uma repetição para treinar o modelo. Para alcançar o melhor desempenho possível, estamos compilando em JIT as duas funções usando o decorador @nnx.jit
, uma vez que elas exigem muito da computação.
@nnx.jit
def loss_fn(model, batch):
logits = model(batch[0])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch[1]
).mean()
return loss, logits
@nnx.jit
def train_step(
model: nnx.Module,
optimizer: nnx.Optimizer,
metrics: nnx.MultiMetric,
batch,
):
grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(model, batch)
metrics.update(loss=loss, logits=logits, lables=batch[1])
optimizer.update(grads)
Para o otimizador, estamos usando o AdamW da Optax com um schedule de redução de cosseno. Você também pode experimentar outros otimizadores ou schedules no Optax.
schedule = optax.cosine_decay_schedule(
init_value=init_learning_rate, decay_steps=max_steps
)
optax_chain = optax.chain(
optax.adamw(learning_rate=schedule, weight_decay=weight_decay)
)
optimizer = nnx.Optimizer(model, optax_chain)
Por fim, criamos uma repetição de treinamento simples.
while True:
input_batch, target_batch = get_batch("train")
train_step(
model,
optimizer,
train_metrics,
jax.device_put(
(input_batch, target_batch),
NamedSharding(mesh, P("batch", None)),
),
)
step += 1
if step > max_steps:
break
Observe como particionamos os dados de entrada no eixo batch usando a função jax.device_put
. Neste caso, o JAX ativará o paralelismo de dados e unirá tudo ao inserir coletivos de comunicação (AllReduce) automaticamente e fará a sobreposição de computação e comunicação o máximo possível. Para ver uma discussão mais aprofundada sobre computação paralela, consulte a documentação de introdução à programação paralela do JAX.
Neste ponto, o modelo deverá estar em treinamento, e nós poderemos observar a perda do treinamento se o Weights and Biases for utilizado para rastrear a execução. Esta é uma execução de teste para o treinamento do modelo GPT2 124M:
Ela leva aproximadamente sete horas na TPU v3 do Kaggle, que podemos usar por nove horas ininterruptas. Mas, se usarmos o Trillium, o tempo de treinamento cai para cerca de 1,5 hora (observe que o Trillium tem 32 G de memória HBM por chip, portanto, podemos dobrar o tamanho do lote e reduzir pela metade as etapas de treinamento).
As perdas finais são mais ou menos iguais às do nanoGPT, o que achei muito bom e analisei enquanto escrevia este exemplo de código.
Se usarmos Cloud TPUs, também podemos monitorar a utilização da TPU por meio do comando "tpu-info" (parte do pacote Cloud TPU Monitoring Debugging) ou do painel Weights and Biases. Nossas TPUs vão até tremer!
Depois que o modelo for treinado, podemos salvá-lo usando o Orbax:
checkpointer = orbax.PyTreeCheckpointer()
train_state = nnx.pure(nnx.state(model))
checkpointer.save(checkpoint_path, train_state)
E pronto. Isso é praticamente tudo de que precisamos para treinar um modelo GPT2. Você pode encontrar mais detalhes, como carregamento de dados, hiperparâmetros e métricas, no notebook completo.
É claro que o GPT2 é um modelo pequeno, hoje, e muitos laboratórios estão treinando modelos com centenas de bilhões de parâmetros. Mas, agora que você aprendeu a criar um modelo de linguagem pequeno com o JAX e uma TPU, já pode se aprofundar no artigo sobre como escalonar um modelo (em inglês).
Além disso, você pode usar o MaxText para treinar LLMs de ponta pré-criados ou aprender a criar os mais recentes modelos a partir do zero consultando os exemplos de LLM do JAX ou o modelo Stanford Marin.
Mal posso esperar para ver os modelos incríveis que você vai criar com o JAX e as TPUs!