解析 Gemma:Gemma 2 的新功能

八月 22, 2024
Ju-yeong Ji Gemma DevRel
Ravin Kumar Google Data Scientist Language Applications

在“解析 Gemma”系列的上一篇博文中,我们讨论了 Gemma 的架构。在本文中,我们将探讨最新发布的 Gemma 2 模型。现在就开始吧!


Gemma 2

最近,我们新发布了具有开创性的开放式模型套件 Gemma 2,这款套件在性能和无障碍功能方面确立了新的标准。Gemma 2 甫一亮相,就在业内产生了深远影响,目前可选的参数规模有 2B、9B 和 27B 三种。27B 模型在 LMSYS 聊天机器人竞技场排行榜上的排名迅速上升,在吸引人的现实对话中,其表现甚至超越了规模超过其两倍的流行模型,成为排名最高、最实用的开放式模型之一。与此同时,Gemma 2 2B 模型以可在边缘设备上运行的大小在对话式 AI 领域展现出卓越实力,性能表现超越聊天机器人竞技场上所有的 GPT-3.5 模型。

开发者可以跨平台和工具利用 Gemma 2 强大的调整功能。我们通过 Google Cloud 等基于云的解决方案和 Axolotl 等社区工具简化了 Gemma 2 的微调过程。与 Hugging Face 和 NVIDIA TensorRT-LLM 等合作伙伴的方案以及我们的 JAX 和 Keras 无缝集成,实现了跨多种硬件配置的性能优化和高效部署。

以下是这款新模型的核心参数:

Core parameters of new Gemma models, August 2024

关键特色

Gemma 2 与原始版本的 Gemma 模型具有相似的架构基础,包括旋转位置嵌入 (RoPE) 和近似 GeGLU 非线性实现。然而,与旧款相比,新颖的架构创新赋予了它独有的特色。


交替使用局部注意力和全局注意力

Gemma 2 不是一次性考虑文本中的所有单词,而是有时关注一个小的单词窗口(局部注意力),有时考虑所有单词(全局注意力)。这种组合有助于模型有效地理解文本的直接上下文和整体含义。


Logit 软上限

假设您正在训练一个模型来预测句子中的下一个单词。有时,模型可能会对某个词过于自信,即使其并非最佳选择亦然。Logit 软上限可通过限制模型对其预测的信心来防止这种情况,从而提高整体性能。


归一化前和归一化后使用 RMSNorm

我们可以将此视为在训练期间防止模型的计算变得过大或过小的一种方法。就像我们可以通过调整扬声器的音量来防止声音失真一样,RMSNorm 可确保流经模型的信息保持在合理的范围内,从而实现更稳定且有效的训练。


分组查询注意力 (GQA)

这种方法有助于模型更有效地处理信息,特别是在处理大量文本时。通过对查询进行分组组合来改进传统的多头注意力 (MHA),从而提高处理速度,特别是对于大型模型而言。这类似于将大型任务拆分为更小、更易于管理的块,使模型能够在不牺牲准确性的情况下更快地理解单词之间的关系。


Gemma 27B

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 4608, padding_idx=0)
    (layers): ModuleList(
      (0-45): 46 x Gemma2DecoderLayer(
        (self_attn): Gemma2SdpaAttention(
          (q_proj): Linear(in_features=4608, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4608, out_features=2048, bias=False)
          (v_proj): Linear(in_features=4608, out_features=2048, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4608, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=4608, out_features=36864, bias=False)
          (up_proj): Linear(in_features=4608, out_features=36864, bias=False)
          (down_proj): Linear(in_features=36864, out_features=4608, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm()
        (post_attention_layernorm): Gemma2RMSNorm()
        (pre_feedforward_layernorm): Gemma2RMSNorm()
        (post_feedforward_layernorm): Gemma2RMSNorm()
      )
    )
    (norm): Gemma2RMSNorm()
  )
  (lm_head): Linear(in_features=4608, out_features=256000, bias=False)
)
Gemma 27B architecture

self_attn

在自注意力 (self-attention) 机制中,Gemma 2 使用了分组查询注意力 (GQA)

k_projv_proj 共享相同的头部,大小为 128 个和 16 个头 (128 x 16 = 2048)。相比之下,q_projo_proj 有 32 个并行头 (128 x 32 = 4096)。


请注意,Gemma 9B 模型使用相同的 GQA,但头的数量(k_projv_proj 为 8 个,q_projo_proj 为 16 个)和大小 (256) 不同

(self_attn): Gemma2SdpaAttention(
          (q_proj): Linear(in_features=3584, out_features=4096, bias=False)
          (k_proj): Linear(in_features=3584, out_features=2048, bias=False)
          (v_proj): Linear(in_features=3584, out_features=2048, bias=False)
          (o_proj): Linear(in_features=4096, out_features=3584, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )

2B 模型使用的是:k_projv_proj 为 4 个,q_projo_proj 为 8 个,头的大小为 256


pre_feedforward_layernorm 和 post_feedforward_layernorm

另一个重要的区别是 Gemma 2 中额外使用了 RMSNorm,可以增强训练过程的稳定性。


重要研究结果

我们在技术报告中提供了详细信息,关于 Gemma 2 的主要研究结果简要总结如下:


知识蒸馏与从零开始训练:

我们使用知识蒸馏的方法,通过从较大的模型 (27B) 提取知识来训练 2B 和 9B 模型。

从更大的模型中提取知识,即使使用相同数量的训练 token,也能显著提高性能。


分组查询注意力与多头注意力:

用 GQA 取代 MHA,可在实现同等性能的情况下,提高参数效率并缩短推理时间,这也使 GQA 成为首选。


模型深度与宽度:

在参数计数相同的情况下,较深模型的性能表现略优于较宽的模型。


未来计划

在本文中,我们介绍了下一代 Gemma 模型 Gemma 2。

在下一系列的帖子中,我们将探讨 RecurrentGemma,这是一个基于 Griffin 的开放式模型。

如果您想深入了解迷人的 AI 世界,并从该领域的前沿专家那里获得见解,请访问 goo.gle/ai-podcast,或在任何播客平台上搜索“People of AI Podcast”节目。

敬请关注,感谢您的阅读!



参考文献


论文


代码示例


📋 完整的 Gemma 架构系列