解析 Gemma:Gemma 模型系列架构概述

八月 15, 2024
Ju-yeong Ji Gemma DevRel
Ravin Kumar Google Data Scientist Language Applications

Gemma 是最前沿的轻量级开放式模型系列,基于与打造 Gemini 模型相同的研究和技术构建而成。

Gemma 的不同变体专为不同的用例和模态而设计,例如:

  • 单一模态(文本输入,文本输出)

  • 专为编码用例而设计

  • 多模态(文本和图像输入,文本输出)

  • 不同规模的模型适用于不同类型的硬件、推理需求和其他限制。

  • “新”架构

由于所有这些模型都共享相似的基础,Gemma 系列提供了一种独特的方式供人们了解现代 LLM 系统中可用的架构和设计选择。我们希望这将有助于建立丰富的开放模型生态系统,并促进对 LLM 系统工作原理的更深入理解。

本系列文章将介绍:

  • Gemma 1(2B 和 7B)- 基于 Transformer 的文本到文本模型。

  • CodeGemma(2B 和 7B)- Gemma 的微调版本,针对代码补全和生成进行了优化。

  • Gemma 2(2B、9B 和 27B)- 使用更新架构训练的文本到文本模型;2B 和 9B 版本是通过从更大规模的模型进行蒸馏训练得到的。

  • RecurrentGemma(2B 和 9B)- 建立在新型 Griffin 架构之上的模型。这种架构结合了局部注意力机制和线性递归,可在生成长序列时实现快速推理。

  • PaliGemma (3B) - 一种视觉语言模型,可以接收文本和图像输入,并提供文本输出。


如何使用本指南

在本系列文章中,我们将

  • 汇总各种模型的具体架构

  • 解释这些参数如何影响模型生成(例如,嵌入数量;多查询、多头查询与分组查询的对比)

  • 提供模型的代码示例以便进一步探索

为了提供模型信息,我们会使用 Hugging Face Transformer 的打印模块,如下方的简单代码所示。

from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b")
print(model)

您还可以使用 Keras 模型类 API 中的 torchinfosummary() 探索模型内部。


本指南非基础指南

本指南并非 AI 基础入门指南。读者应已对神经网络、Transformer 及相关术语(如令牌)有一定的理解和应用能力。如果您需要回顾这些概念,请使用以下资源:

一款可以在浏览器中使用的神经网络学习实操工具

Transformer 介绍


Gemma

Gemma 是开放权重的 LLM。该模型既提供指令微调变体也提供原始预训练变体,并且不同变体有不同的参数规模。Gemma 的基础架构是 Google 研究在《您只需集中注意力》(Attention Is All You Need) 论文中提出的 LLM 架构。其主要功能是根据用户的提示逐词(以 token 为单位)生成文本。在翻译等任务中,Gemma 接受一种语言的句子作为输入,并输出另一种语言中的等价句子。

继续阅读您将发现,Gemma 不仅自身是一个高质量的模型,而且非常适合自定义扩展以满足不同用户的需求。


Gemma 架构

首先,让我们了解一下 Gemma 模型所基于的 Transformer 解码器

Transformer decoder architecture

与在《您只需集中注意力》(Attention Is All You Need) 论文中提出的原始编码器-解码器 Transformer 模型架构不同,Gemma 是一个纯粹的“仅解码器”模型。

下表总结了架构的核心参数

Core parameters of the architecture

模型在长度为 8,192 个令牌的上下文上进行训练。这意味着模型一次最多可以处理约 6,144 个单词(按照 100 个令牌约当于 75 个单词的经验法则)。

值得注意的是,实际的输入限制可能会因任务和使用情况而异。这是因为文本生成会在上下文窗口内消耗令牌,实际上减少了新输入的空间。尽管技术上的输入限制保持不变,但生成的输出会成为后续输入的一部分,从而影响后续的生成过程。


d_model(2B:2,048,7B:3,072)

d_model 表示作为解码器输入的嵌入(单词或子单词的矢量表征;单词和子单词也称为令牌)的大小,并决定了解码器层内部表征的大小。

d_model x Num heads x Head size
“d_model x Num head x Head size”定义了 self_attn 中的参数数量

D_model 值越大,便意味着模型有更多的“空间”来表示不同单词及其关系的细微差别。这可以带来更好的性能,特别是对于复杂的语言任务。然而,增加 d_model 也会使模型变得更大,并使训练和使用时的计算成本变得更高。


层数(2B:18,7B:28)

Transformer 由多个堆叠层组成。更深层次的模型有更多的层,因此有更多的参数,可以学习更复杂的模式。然而,这些额外的参数意味着它们也更容易过度拟合,并且需要更多的计算资源。

这种增强的表征能力可能导致模型学习到噪声或特定的训练数据模式,后者缺乏泛化到新示例的能力。

此外,更深层次的模型通常需要更多的训练数据来避免过度拟合。在可用数据有限的情况下,模型可能缺乏足够的示例来学习可泛化的表征,从而导致模型只是记住了训练数据。


前馈隐藏层维度(2B:32,768,7B:49,152)

每个 Transformer 层在注意力机制之后都包含前馈网络。这个网络有自己的维度大小,通常比 d_model 的大小要大,以提高模型的表达能力。

前馈网络的实现方式是多层感知器 (MLP)。这是一种神经网络,用于进一步转换嵌入并提取更复杂的模式。

multi-layer perceptron (MLP) neural network achitecture

在 Gemma 中,标准的 ReLU 非线性激活函数由 GeGLU 激活函数(门控线性单元 (GLU) 的变体)所取代。GeGLU 将激活分为两部分:一个 S 形函数部分和一个线性投影。S 形函数部分的输出与线性投影逐元素相乘,形成非线性激活函数。

GeGLU activation function example

头数(2B:8,7B:16)

每个 Transformer 层包含多个并行工作的注意力机制。这些“头”使模型能够同时关注输入序列的不同方面。增加头的数量可以增强模型捕获数据中不同关系的能力。


KV 头数(2B:1,7B:16)

7B 模型使用多头注意力 (MHA),2B 模型则使用多查询注意力 (MQA)。在 MQA 中,所有头都使用相同的键和值投影,这意味着每个头都关注相同的底层表征,但使用不同的查询投影。

原始的 MHA 具有更丰富的表征学习,但计算成本更高。MQA 则提供一种已被证明有效的高效替代方案。


头的大小(2B:256,7B:256)

头的大小指的是多头注意力机制中每个注意力头的维度。该值的计算方法是将嵌入维度除以头的数量。例如,如果嵌入维度为 2,048,并且有 8 个头,则每个头的大小为 256。


词汇量(2B:256,128,7B:256,128)

该参数定义了模型理解且可以处理的唯一令牌(单词、子单词或字符)的数量。Gemma 的分词器基于 SentencePiece。词汇量在训练前就已经确定。然后,SentencePiece 会根据选定的词汇量和训练数据学习最优的子单词分割。Gemma 拥有庞大的 256,000 词汇量,可以处理多样化的文本输入,并有可能提高在各种任务上的性能,例如处理多语言文本输入。


Gemma 7B

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 3072, padding_idx=0)
    (layers): ModuleList(
      (0-27): 28 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=3072, out_features=4096, bias=False)
          (k_proj): Linear(in_features=3072, out_features=4096, bias=False)
          (v_proj): Linear(in_features=3072, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=3072, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=3072, out_features=24576, bias=False)
          (up_proj): Linear(in_features=3072, out_features=24576, bias=False)
          (down_proj): Linear(in_features=24576, out_features=3072, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
    )
    (norm): GemmaRMSNorm()
  )
  (lm_head): Linear(in_features=3072, out_features=256000, bias=False)
)
Gemma 7B architecture

embed_tokens(嵌入层)

该层将输入令牌(单词或子单词)转换为模型可以处理的密集型数字表征(嵌入)。其词汇量为 256,000,并可创建维度为 3,072 的嵌入。


layers

这是模型的核心部分,由 28 个堆叠的 GemmaDecoderLayer 组件构成。每一层都优化了令牌嵌入,以捕捉单词及其上下文之间的复杂关系。


self_attn

在自注意力 (self-attention) 机制中,模型在生成下一个单词时为输入中的单词分配不同的权重。该模型利用缩放的点积注意机制,并使用线性投影(q_projk_projv_projo_proj)来生成查询、键、值和输出表征。

q_projk_projv_proj 的所有 out_features 值都是相同的 4,096,因为该模型使用了多头注意力 (MHA)。该机制有 16 个并行的头,每个头的尺寸是 256,总共为 4,096 (256 x 16)。

此外,该模型通过使用 rotary_emb (GemmaRotaryEmbedding) 进行位置编码(又名 RoPE),来更有效地利用位置信息。

最后,o_proj 层将注意力输出投影回原始维度 (3072)。


请注意,Gemma 2B 模型使用多查询注意力 (MQA)

Multi-Query Attention (MQA) architecture used in Gemma 2B model

k_projv_proj 共享相同的头,大小为 256,因此 out_features 为 256。相比之下,q_projo_proj 有 8 个并行的头 (256 x 8 = 2048)。

(self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )


mlp

此参数利用 gate_projup_proj 作为门控机制,然后通过 down_proj 将维度降回到 3,072。


input_layernorm、post_attention_layernorm 和 norm

这些归一化层稳定了训练过程,并提高了模型的有效学习能力。


lm_head

此最后一层将经优化的嵌入 (3,072) 映射回整个词汇空间 (256,000) 的概率分布上,以预测下一个令牌。


CodeGemma(2B 和 7B)

CodeGemma 模型是经过微调的 Gemma 模型,针对代码补全和编程聊天辅助进行了优化。CodeGemma 模型基于超过 5,000 亿个主要为代码的令牌进行训练。CodeGemma 还添加了中间填充功能,可在两段现有文本之间进行补全。

CodeGemma 突出了 Gemma 检查点的微调能力。通过额外的训练,模型在特定任务上变得更加专业,能够学习比纯后缀补全更为复杂的补全。


CodeGemma 的使用

您可以使用 4 个用户定义的令牌:3 个令牌用于 FIM;1 个 "<|file_separator|>" 令牌用于支持多文件上下文。

BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

假设您正试图像下面的屏幕截图那样补全代码。

Code snippet example - CodeGemma (2B and 7B)

输入提示应如下所示

<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>

模型将提供 "sys" 作为建议补全的代码。

您可以在 CodeGemma/快速入门上探索更多关于 CodeGemma 的信息。


未来计划

本文介绍了 Gemma 架构。

在接下来的一系列帖子中,我们将探索最新的模型:Gemma 2。该模型在安全措施方面有显著提升,并在推理期间的性能和效率上超越了前一版本。

敬请关注,感谢您的阅读!



参考文献


论文

代码示例

Gemma

CodeGemma


📋 The complete Gemma architecture series