使用 Keras 微调 Gemma 2 - 以及 Hugging Face 更新

六月 27, 2024
Martin Görner Product Manager Keras

最新 27B 参数 Keras 模型:Gemma 2

继 Gemma 1.1 (KaggleHugging Face)、CodeGemma (KaggleHugging Face) 和 PaliGemma 多模态模型 (KaggleHugging Face) 之后,我们很高兴宣布在 Keras 中发布 Gemma 2 模型。

Gemma 2 有两种规模(9B 和 27B 参数),并设有两种变体(标准变体和指令调整变体),您可以在此处找到它们:

Gemma 2 在 LLM 基准测试中的一流结果在其他地方也有介绍(请见 goo.gle/gemma2report)。在这篇博文中,我们想展示 Keras 和 JAX 的结合可如何帮助您使用这些大型模型。

JAX 是为扩展规模而构建的数字框架。它可利用 XLA 机器学习编译器,并被用于训练 Google 最大的模型。

Keras 是面向 ML 工程师的建模框架,该框架现在运行在 JAX、TensorFlow 或PyTorch 上。Keras 现可通过能够带来愉快使用体验的 Keras API 提供功率模型并行扩展。您可以在此处尝试 Keras 中的全新 Gemma 2 模型:


利用模型并行处理在 TPU/GPU 上进行分布式微调

由于模型规模较大,如要完全精确地加载和微调这些模型,就必须将其权重分割至多个加速器。JAX 和 XLA 可为权重分区提供广泛支持(SPMD 模型并行处理),Keras 还会提供 keras.distribution.ModelParallel API,帮助您以简单的方式逐层指定分片:

# 列出加速器
devices = keras.distribution.list_devices()


# 在具有命名轴的逻辑网格中布置加速器
device_mesh = keras.distribution.DeviceMesh((2, 8), ["batch", "model"], devices)


# 告诉 XLA 如何进行权重分区(Gemma 的默认值)
layout_map = gemma2_lm.backbone.get_layout_map()


# 定义 ModelParallel 分布
model_parallel = keras.distribution.ModelParallel(device_mesh, layout_map, batch_dim_name="batch")


# 设置为默认值并加载模型
keras.distribution.set_distribution(model_parallel)
gemma2_lm = keras_nlp.models.GemmaCausalLM.from_preset(...)

gemma2_lm.backbone.get_layout_map() 函数是为模型的所有权重传回逐层分片配置的帮助程序。它遵循相关 Gemma 论文 (goo.gle/gemma2report) 中的建议。以下是相关内容的摘录:

layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = ("model", "data")
layout_map["decoder_block.*attention.*(query|key|value).kernel"] =
("model", "data", None)
layout_map["decoder_block.*attention_output.kernel"] = ("model", None, "data")

简而言之,对于每一层,此配置可指定沿着哪个或哪些轴来分割每个权重块,以及在哪些加速器上放置这些经过分割的权重块。结合图片会更加容易理解。让我们以 Transformer 注意力架构中的“查询”权重为例,其形状为 (nb heads, embed size, head dim):

Weight partitioning example for the query (or key or value) weights in the Transformer attention architecture.
Transformer 注意力架构中查询(或键或值)权重的权重分区示例。

注意:没有分块的网格维度将收到副本。例如,如果上面的布局图为 ("model", None, None),就属于这种情况。

另请注意 ModelParallel 中的 batch_dim_name="batch" 参数。如果“batch”轴上有多行加速器(即本文所示情况),则也会使用数据并行处理。每行加速器将仅在每个数据批次的一部分上加载和训练,然后各行将合并其梯度。

模型加载完毕后,可使用以下两个易于使用的代码片段来显示实际应用的权重分片:

for variable in gemma2_lm.backbone.get_layer('decoder_block_1').weights:
    print(f'{variable.path:<58}  {str(variable.shape):<16} \
{str(variable.value.sharding.spec)}')
#…通过 gemma2_lm.compile() 设置优化器,然后:
gemma2_lm.optimizer.build(gemma2_lm.trainable_variables)
for variable in gemma2_lm.optimizer.variables:
    print(f'{variable.path:<73}  {str(variable.shape):<16} \
{str(variable.value.sharding.spec)}')

如果我们查看输出(如下所示),我们会注意到一些重要的事情:布局规范中的正则表达式不仅匹配层权重,还匹配优化器中相应的动量和速度变量,并对其进行了适当的分片。这是对模型进行分区时需要检查的重点。

# 对于层:
# 权重名称 . . . . . . . . . . 形状 . . . . . . 布局规范
decoder_block_1/attention/query/kernel (16, 3072, 256)
PartitionSpec('model', None, None)
decoder_block_1/ffw_gating/kernel (3072, 24576)
PartitionSpec(None, 'model')

# 对于优化器变量:
# 变量名称 . . . . . . . . . . . .形状 . . . . . . 布局规范
adamw/decoder_block_1_attention_query_kernel_momentum
(16, 3072, 256)   PartitionSpec('model', None, None)
adamw/decoder_block_1_attention_query_kernel_velocity
(16, 3072, 256)   PartitionSpec('model', None, None)

利用 LoRA 在有限的硬件上训练模型

LoRA 是一种可冻结模型权重并使用低秩(即小型)适配器替换这些权重的技术。

LoRA (Low Rank Adaptation)

Keras 还为此提供了简单易用的 API:

gemma2_lm.backbone.enable_lora(rank=4) # 秩是从实证检验中选取的

启用 LoRA 后,使用 model.summary() 显示模型详细信息。这时我们可以看到 LoRA 将 Gemma 9B 中的可训练参数数量从 90 亿减少到了 1,450 万。


Hugging Face 更新

上个月,我们宣布将在 Kaggle 和 Hugging Face 上提供 Keras 模型,以供用户下载和上传。如今,我们正在进一步推进 Hugging Face 集成:您现可为受支持的模型加载任何经过微调的权重,无论这些模型是否使用相应模型的 Keras 版本进行训练。这将通过实时转换权重来实现。这意味着您现在可以直接通过 KerasNLP 访问 Hugging Face 用户上传的数十种 Gemma 微调。最终,不仅仅是 Gemma,所有具有相应 KerasNLP 实现的 Hugging Face Transformer 模型都会如此。目前,这种情况只适用于 Gemma 和 Llama3。您可以使用以下 ColabHermes-2-Pro-Llama-3-8B 微调上尝试一番:

causal_lm = keras_nlp.models.Llama3CausalLM.from_preset(
   "hf://NousResearch/Hermes-2-Pro-Llama-3-8B"
)

利用 Keras 3 探索 PaliGemma

PaliGemma 是一款功能强大的开放式 VLM,灵感来自PaLI-3。PaliGemma 建立在包括 SigLIP 视觉模型和 Gemma 语言模型在内的开放组件之上,旨在为各种视觉语言任务提供领先水平的微调性能。这包括图像说明文字、可视化问答、理解图像中的文本、对象检测和对象分割。


您可以在 GitHubHugging Face 模型Kaggle 上找到 PaliGemma 的 Keras 实现。

我们希望您会喜欢使用 Keras 中的全新 Gemma 2 模型来进行实验或构建!