Si alguna vez sentiste curiosidad por saber qué se necesita para crear un modelo de lenguaje desde cero con JAX, esta entrada es para ti. Hace muy poco, realicé un taller sobre este tema en Cloud Next 2025 y recibí excelentes comentarios, así que pensé en escribir una guía para todas las personas que no pudieron asistir.
En este artículo y ejemplo de código, construirás y preentrenarás un modelo GPT-2, con lo que demostrarás cómo JAX permite aprovechar de forma sencilla el poder de las TPU de Google. Puedes ejecutar todo el proyecto de forma gratuita con las TPU en Colab o Kaggle, y aquí encontrarás el notebook completo.
Este es un instructivo práctico, así que supondré que conoces los conceptos generales del aprendizaje automático. Si nunca usaste JAX, la guía para desarrolladores de PyTorch sobre los fundamentos de JAX es un excelente lugar para comenzar.
Primero, veamos rápidamente las herramientas que usaremos.
Antes de comenzar a crear el modelo, hablemos brevemente sobre el ecosistema de JAX. El ecosistema de JAX adopta un enfoque modular, donde JAX Core proporciona las capacidades básicas de procesamiento numérico, y una colección numerosa de bibliotecas para satisfacer diferentes necesidades específicas de la aplicación, por ejemplo, Flax para la construcción de redes neuronales, Orbax para la persistencia del punto de control y el modelo, y Optax para la optimización (vamos a utilizar las tres en este artículo). Las transformaciones de funciones integradas, como la autograduación, la vectorización y la compilación JIT, además del rendimiento sólido y las API fáciles de usar, hacen que JAX sea ideal para entrenar modelos de lenguaje grandes.
OpenAI lanzó anteriormente el código y los pesos del modelo GPT2, que son buenas referencias, y la comunidad hizo muchos aportes, como nanoGPT, para replicar el modelo. Este es un diagrama de la arquitectura del modelo de alto nivel para GPT2:
Vamos a usar NNX (la nueva interfaz de Flax) para crear el modelo GPT2. Para ir directo al grano, nos centraremos en el bloque transformador, que es la clave para los modelos modernos de lenguaje grandes. El bloque transformador captura dependencias de largo alcance en cualquier secuencia y crea una rica comprensión contextual de esta. Un bloque de transformador de GPT2 consta de dos capas LayerNorm, una capa de atención de múltiples hilos de ejecución (MHA), dos capas de abandono, dos capas de proyección lineal y dos conexiones residuales. Entonces, primero definimos estas capas en la función __init__
de la clase TransformerBlock
:
class TransformerBlock(nnx.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
ff_dim: int,
dropout_rate: float,
rngs: nnx.Rngs,
):
self.layer_norm1 = nnx.LayerNorm(
epsilon=1e-6, num_features=embed_dim, rngs=rngs
)
self.mha = nnx.MultiHeadAttention(
num_heads=num_heads, in_features=embed_dim, rngs=rngs
)
self.dropout1 = nnx.Dropout(rate=dropout_rate)
self.layer_norm2 = nnx.LayerNorm(
epsilon=1e-6, num_features=embed_dim, rngs=rngs
)
self.linear1 = nnx.Linear(
in_features=embed_dim, out_features=ff_dim, rngs=rngs
)
self.linear2 = nnx.Linear(
in_features=ff_dim, out_features=embed_dim, rngs=rngs
)
self.dropout2 = nnx.Dropout(rate=dropout_rate)
Luego, ensamblamos estas capas en la función __call__
:
class TransformerBlock(nnx.Module):
def __call__(self, inputs, training: bool = False):
input_shape = inputs.shape
bs, seq_len, emb_sz = input_shape
attention_output = self.mha(
inputs_q=self.layer_norm1(inputs),
mask=causal_attention_mask(seq_len),
decode=False,
)
x = inputs + self.dropout1(
attention_output, deterministic=not training
)
# MLP
mlp_output = self.linear1(self.layer_norm2(x))
mlp_output = nnx.gelu(mlp_output)
mlp_output = self.linear2(mlp_output)
mlp_output = self.dropout2(
mlp_output, deterministic=not training
)
return x + mlp_output
Este código te debería resultar muy familiar si ya usaste cualquier otro marco de aprendizaje automático, como PyTorch o TensorFlow, para entrenar un modelo de lenguaje. Pero algo que realmente me gusta de JAX es que tiene la increíble capacidad de ejecutar automáticamente el código en paralelo a través de SPMD (un programa, múltiples datos), lo cual es necesario porque ejecutaremos el código en varios aceleradores (varios núcleos de TPU). Veamos cómo funciona.
Para ejecutar SPMD, primero debemos asegurarnos de estar usando TPU. Elige el entorno de ejecución de TPU si estás usando Colab o Kaggle (también puedes usar una VM de Cloud TPU).
import jax
jax.devices()
# Free-tier Colab offers TPU v2:
# [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
# TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
# TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
# TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
# TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
# TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
# TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
# TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
Colab y Kaggle ofrecen una TPU v2 o v3, que tiene 8 núcleos de TPU separados. Así se ve una bandeja de TPU v3:
Para entrenar el modelo GPT2 de manera eficiente, ejecutaremos todos los núcleos de TPU juntos a través de SPMD y aprovecharemos el paralelismo de datos en JAX. Para lograrlo, definimos una malla de hardware:
mesh = jax.make_mesh((8, 1), ('batch', 'model'))
La malla sería como una matriz de aceleradores en 2D. En este caso, definimos dos ejes para la malla: el eje batch
y el eje model
. Entonces, en total, tenemos 8 por 1, que son 8 núcleos. Estos ejes determinan cómo dividimos los datos y los parámetros del modelo. Podemos cambiar los ejes más adelante si queremos probar otros esquemas de paralelismo.
Ahora cambiamos la función __init__
indicándole a JAX cómo nos gustaría dividir los parámetros del modelo utilizando el eje ‘model’. Para ello, debemos agregar nnx.with_partitioning
al inicializar los tensores de pesos. Los tensores de pesos 1D, como los tensores de escala/sesgo LayerNorm, se deben fragmentar directamente a lo largo del eje ‘model’; en cambio para los tensores 2D, como MHA y los tensores de núcleo lineal, fragmentamos la segunda dimensión a lo largo del eje model
.
class TransformerBlock(nnx.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
ff_dim: int,
dropout_rate: float,
rngs: nnx.Rngs,
):
self.layer_norm1 = nnx.LayerNorm(
epsilon=1e-6, num_features=embed_dim,rngs=rngs, rngs=rngs,
scale_init=nnx.with_partitioning(
nnx.initializers.ones_init(),
("model"),
),
bias_init=nnx.with_partitioning(
nnx.initializers.zeros_init(),
("model"),
),
)
self.mha = nnx.MultiHeadAttention(
num_heads=num_heads, in_features=embed_dim,
kernel_init=nnx.with_partitioning(
nnx.initializers.xavier_uniform(),
(None, "model"),
),
bias_init=nnx.with_partitioning(
nnx.initializers.zeros_init(),
("model"),
),
)
# Other layers in the block are omitted for brevity
Necesitamos dividir otras capas como esta para activar el paralelismo de tensores de modelo en todo el modelo GPT2. Aunque no usamos el paralelismo de tensores de modelo en este instructivo, recomiendo implementarlo, ya que el tamaño del modelo puede crecer y es posible que tengamos que dividir los parámetros de nuestro modelo en el futuro. Si implementamos esto, podremos cambiar solo una línea de código e inmediatamente ejecutar modelos más grandes. Por ejemplo:
mesh = jax.make_mesh((4, 2), ('batch', 'model'))
A continuación, definimos las funciones loss_fn
y train_step
como lo hicimos en el blog anterior. La función train_step()
calcula los gradientes de la función de pérdida de entropía cruzada y actualiza los pesos a través del optimizador, y se llamará en un bucle para entrenar el modelo. Para lograr el mejor rendimiento, vamos a compilar con JIT ambas funciones utilizando el decorador @nnx.jit
, ya que tienen mucha capacidad de cómputo.
@nnx.jit
def loss_fn(model, batch):
logits = model(batch[0])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch[1]
).mean()
return loss, logits
@nnx.jit
def train_step(
model: nnx.Module,
optimizer: nnx.Optimizer,
metrics: nnx.MultiMetric,
batch,
):
grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(model, batch)
metrics.update(loss=loss, logits=logits, lables=batch[1])
optimizer.update(grads)
Para el optimizador, utilizaremos AdamW de Optax con un programa de decaimiento de coseno. También puedes probar otros optimizadores o programas en Optax.
schedule = optax.cosine_decay_schedule(
init_value=init_learning_rate, decay_steps=max_steps
)
optax_chain = optax.chain(
optax.adamw(learning_rate=schedule, weight_decay=weight_decay)
)
optimizer = nnx.Optimizer(model, optax_chain)
Por último, creamos un ciclo de entrenamiento simple.
while True:
input_batch, target_batch = get_batch("train")
train_step(
model,
optimizer,
train_metrics,
jax.device_put(
(input_batch, target_batch),
NamedSharding(mesh, P("batch", None)),
),
)
step += 1
if step > max_steps:
break
Observa cómo dividimos los datos de entrada a lo largo del eje por lotes utilizando la función jax.device_put
. En este caso, JAX habilitará el paralelismo de datos y unirá todo insertando colectivos de comunicación (AllReduce) automáticamente, además de superponer el cálculo y la comunicación tanto como sea posible. Para obtener información más detallada sobre la computación paralela, consulta el documento de JAX titulado Introducción a la programación paralela.
A esta altura, el modelo debe estar entrenando y podemos observar la pérdida de entrenamiento si se utilizan pesos y sesgos para hacer un seguimiento de la ejecución. Esta es una prueba para entrenar el modelo GPT2 124M:
Se tarda aproximadamente 7 horas en Kaggle TPU v3 (que podemos usar durante 9 horas sin interrupción); pero, si usamos Trillium, el tiempo de entrenamiento se reduce a aproximadamente 1.5 horas (ten en cuenta que Trillium tiene 32G HBM [memoria de alto ancho de banda] por chip, por lo que podemos duplicar el tamaño del lote y reducir a la mitad los pasos de entrenamiento).
Las pérdidas finales son más o menos similares a las de nanoGPT, lo que disfruté y analicé mientras escribía este ejemplo de código.
Si usamos Cloud TPU, también podemos supervisar la utilización de TPU a través del comando ‘tpu-info’ (parte del paquete de depuración de supervisión de Cloud TPU) o el panel de control de pesos y sesgos. ¡Nuestras TPU van a toda máquina!
Después de entrenar el modelo, podemos guardarlo con Orbax:
checkpointer = orbax.PyTreeCheckpointer()
train_state = nnx.pure(nnx.state(model))
checkpointer.save(checkpoint_path, train_state)
Eso es todo, o casi todo, lo que necesitamos para entrenar un modelo GPT2. Puedes encontrar detalles adicionales, como la carga, los hiperparámetros o las métricas de datos, en el notebook completo.
Desde luego, hoy GPT2 es un modelo pequeño y muchos labs de vanguardia están entrenando modelos con cientos de miles de millones de parámetros. Sin embargo, ahora que aprendiste a crear un modelo de lenguaje pequeño con JAX y TPU, ya puedes empezar a explorar Cómo escalar tu modelo.
Además, puedes usar MaxText para entrenar LLM de vanguardia creados previamente o aprender a crear los últimos modelos desde cero usando como referencia los ejemplos de JAX LLM o el modelo Stanford Marin.
¡Tengo muchas ganas de ver los increíbles modelos que crearás con JAX y TPU!