如果您对如何使用 JAX 从头开始构建语言模型感到好奇,那么这篇帖子非常适合您。我最近在 2025 年 Google Cloud Next 大会上举办了一场关于此主题的研讨会,并获得了一些很好的反馈,所以我想为所有无法参会的人编写一份指南。
在本文和代码示例中,您将构建并预训练 GPT-2 模型,了解 JAX 如何直接利用 Google TPU 的强大功能。您可以使用 Colab 或 Kaggle 中的 TPU 免费运行整个项目,并在此处找到完整的笔记本。
这是一个实践教程,所以我会假定您已熟悉一般的机器学习概念。如果您不熟悉 JAX,可以从《PyTorch 开发者指南:JAX 基础知识》入手。
首先,让我们快速了解一下将要用到的工具。
在开始构建模型之前,让我们先简要介绍一下 JAX 生态系统。JAX 生态系统采用模块化方法,通过 JAX 核心提供核心数值处理功能,而一系列丰富的库则在此基础上构建而成,以满足不同应用的特定需求,如用于构建神经网络的 Flax、用于检查点和模型持久性的 Orbax 以及用于优化的 Optax(在本文中,这 3 个工具都将被用到)。内置函数转换,如 autograd、矢量化和 JIT 编译,加上强大的性能和易于使用的 API,使 JAX 非常适合训练大型语言模型。
OpenAI 此前发布了 GPT2 模型代码和权重,这些都是很好的参考资料,并且社区也付出了很多努力来复制该模型,例如 nanoGPT。以下是 GPT2 的高级模型架构图:
我们将使用 NNX(新的 Flax 接口)来构建 GPT2 模型。为了简洁起见,我们重点关注 Transformer 块,这是现代大型语言模型的关键所在。Transformer 块会捕获任何序列的长距离依赖关系,并构建对其丰富的上下文理解。GPT2 Transformer 块由 2 个 LayerNorm 层、1 个多头注意力 (MHA) 层、2 个 Dropout 层、2 个线性投影层和 2 个残差连接组成。因此,我们首先需要在 TransformerBlock
类的 __init__
函数中定义这些层:
class TransformerBlock(nnx.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
ff_dim: int,
dropout_rate: float,
rngs: nnx.Rngs,
):
self.layer_norm1 = nnx.LayerNorm(
epsilon=1e-6, num_features=embed_dim, rngs=rngs
)
self.mha = nnx.MultiHeadAttention(
num_heads=num_heads, in_features=embed_dim, rngs=rngs
)
self.dropout1 = nnx.Dropout(rate=dropout_rate)
self.layer_norm2 = nnx.LayerNorm(
epsilon=1e-6, num_features=embed_dim, rngs=rngs
)
self.linear1 = nnx.Linear(
in_features=embed_dim, out_features=ff_dim, rngs=rngs
)
self.linear2 = nnx.Linear(
in_features=ff_dim, out_features=embed_dim, rngs=rngs
)
self.dropout2 = nnx.Dropout(rate=dropout_rate)
接下来,我们需要在 __call__
函数中对这些层进行组合:
class TransformerBlock(nnx.Module):
def __call__(self, inputs, training: bool = False):
input_shape = inputs.shape
bs, seq_len, emb_sz = input_shape
attention_output = self.mha(
inputs_q=self.layer_norm1(inputs),
mask=causal_attention_mask(seq_len),
decode=False,
)
x = inputs + self.dropout1(
attention_output, deterministic=not training
)
# MLP
mlp_output = self.linear1(self.layer_norm2(x))
mlp_output = nnx.gelu(mlp_output)
mlp_output = self.linear2(mlp_output)
mlp_output = self.dropout2(
mlp_output, deterministic=not training
)
return x + mlp_output
如果您使用过任何其他 ML 框架(如 PyTorch 或 TensorFlow)来训练语言模型,那么您对这段代码应该非常熟悉。但我非常喜欢 JAX 的一点在于,它具有通过 SPMD(单程序多数据)自动并行运行代码的强大功能。这一点很有必要,因为我们将在多个加速器(多个 TPU 核心)上运行代码。让我们来看看它的工作原理。
要执行 SPMD,首先我们需要确保自己使用的是 TPU。如果您使用的是 Colab 或 Kaggle,请选择 TPU 运行时(您也可以使用 Cloud TPU 虚拟机)。
import jax
jax.devices()
# 免费版 Colab 会提供 TPU v2:
# [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
# TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
# TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
# TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
# TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
# TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
# TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
# TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
Colab 和 Kaggle 提供 TPU v2 或 v3,其中含有 8 个独立的 TPU 核心。TPU v3 托盘的外观如下所示:
为了高效训练 GPT2 模型,我们将通过 SPMD 让所有 TPU 核心协同运行,并利用 JAX 中的数据并行。为此,我们定义了一个硬件网格:
mesh = jax.make_mesh((8, 1), ('batch', 'model'))
我们可以将网格视为加速器的 2D 矩阵。在本例中,我们为网格定义了两个轴:batch
轴和 model
轴。因此,我们总共有 8 x 1 个核心,也就是 8 个核心。这些轴决定了我们如何划分数据和模型参数。如果之后想尝试其他并行方案,我们可以对这些轴进行调整。
现在,我们通过告诉 JAX 如何使用“model”轴划分模型参数来更改 __init__
函数。这是通过在初始化权重张量时添加 nnx.with_partitioning
来实现的:对于像 LayerNorm scale/bias 张量这样的 1D 权重张量,我们直接沿着“model”轴对它们进行分片;对于像 MHA 和线性内核张量这样的 2D 权重张量,我们沿着 model
轴对第二维度进行分片。
class TransformerBlock(nnx.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
ff_dim: int,
dropout_rate: float,
rngs: nnx.Rngs,
):
self.layer_norm1 = nnx.LayerNorm(
epsilon=1e-6, num_features=embed_dim,rngs=rngs, rngs=rngs,
scale_init=nnx.with_partitioning(
nnx.initializers.ones_init(),
("model"),
),
bias_init=nnx.with_partitioning(
nnx.initializers.zeros_init(),
("model"),
),
)
self.mha = nnx.MultiHeadAttention(
num_heads=num_heads, in_features=embed_dim,
kernel_init=nnx.with_partitioning(
nnx.initializers.xavier_uniform(),
(None, "model"),
),
bias_init=nnx.with_partitioning(
nnx.initializers.zeros_init(),
("model"),
),
)
# 为简洁起见,此处省略了块中的其他层
我们需要像这样划分其他层,以便为整个 GPT2 模型启用模型张量并行。即使我们在本教程中不会使用模型张量并行,实现这一点仍然是比较好的做法,因为模型大小可能会增加,我们将来可能需要对模型参数进行划分。实现后,我们只需更改一行代码即可立即运行更大的模型。例如,
mesh = jax.make_mesh((4, 2), ('batch', 'model'))
接下来,我们需要定义与上一篇博客类似的 loss_fn
和 train_step
函数。train_step()
函数会计算交叉熵损失函数的梯度,并通过优化器更新权重,然后在循环中被调用来训练模型。为了获得最佳性能,我们使用 @nnx.jit
装饰器对这两个函数进行 JIT 编译,因为它们属于计算密集型函数。
@nnx.jit
def loss_fn(model, batch):
logits = model(batch[0])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch[1]
).mean()
return loss, logits
@nnx.jit
def train_step(
model: nnx.Module,
optimizer: nnx.Optimizer,
metrics: nnx.MultiMetric,
batch,
):
grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(model, batch)
metrics.update(loss=loss, logits=logits, lables=batch[1])
optimizer.update(grads)
schedule = optax.cosine_decay_schedule(
init_value=init_learning_rate, decay_steps=max_steps
)
optax_chain = optax.chain(
optax.adamw(learning_rate=schedule, weight_decay=weight_decay)
)
optimizer = nnx.Optimizer(model, optax_chain)
最后,我们需要创建一个简单的训练循环。
while True:
input_batch, target_batch = get_batch("train")
train_step(
model,
optimizer,
train_metrics,
jax.device_put(
(input_batch, target_batch),
NamedSharding(mesh, P("batch", None)),
),
)
step += 1
if step > max_steps:
break
请注意我们使用 jax.device_put
函数沿着 batch 轴对输入数据进行划分的方式。在这种情况下,JAX 将实现数据并行,并通过自动插入通信集合 (AllReduce) 将所有内容整合在一起,同时尽可能多地实现计算与通信的重叠。有关并行计算更深入的讨论,请参阅 JAX 的并行编程简介文档。
模型此时应处于训练状态,如果使用权重和偏差来跟踪运行情况,我们便可以观察训练损失。以下是训练 GPT2 124M 模型的测试运行结果:
如果使用 Kaggle TPU v3(我们可以连续使用 9 个小时),训练时间大约为 7 个小时;但如果使用 Trillium,训练时间将减少到大约 1.5 个小时(请注意,Trillium 的每个芯片配备 32G 高带宽内存 (HBM),因此我们可以将批次大小加倍,并将训练步骤减半)。
最终的损失情况与 nanoGPT 的损失情况大致相符。我真的很喜欢 nanoGPT,并在编写此代码示例时对其进行了研究。
如果使用 Cloud TPU,我们还可以通过“tpu-info”命令(Cloud TPU 监控调试软件包的一部分)或权重和偏差仪表板监控 TPU 利用率。我们的 TPU 正在全力运行!
完成模型训练后,我们可以使用 Orbax 保存模型:
checkpointer = orbax.PyTreeCheckpointer()
train_state = nnx.pure(nnx.state(model))
checkpointer.save(checkpoint_path, train_state)
就是这样。这就是我们训练 GPT2 模型所需了解的全部内容。您可以在完整的笔记本中找到其他详细信息,如数据加载、超参数、指标。
当然,GPT2 如今还是一个小模型,许多前沿实验室正在利用数十亿个参数训练模型。但是,现在您已经学习了如何使用 JAX 和 TPU 构建小型语言模型,您已经准备好深入了解如何扩展模型。
此外,您既可以使用 MaxText 来培训预构建的尖端 LLM,也可以学习如何通过参考 JAX LLM 示例或 Stanford Marin 模型来从头开始构建最新的模型。
我迫不及待想要看到您使用 JAX 和 TPU 构建的出色模型!