解析 Gemma :RecurrentGemma 架构

八月 29, 2024
Ju-yeong Ji Gemma DevRel
Ravin Kumar Google Data Scientist Language Applications

在“解析 Gemma”系列的上一篇博文中,我们讨论了最新的 Gemma 2 架构。在本博文中,您将可以探讨 RecurrentGemma 架构。现在就开始吧!


RecurrentGemma 2B、9B

RecurrentGemma 基于 Griffin,该混合模型可将门控线性循环与局部滑动窗口注意力混合在一起。此更改不仅可以提高计算性能和内存性能,而且更适用于长上下文提示。

Griffin hybrid model architecture

然而,Griffin 架构为固定大小状态,因而会存在 Haystack 性能略微降低的缺点。虽然可以提供一本书中的整个文本作为输入,但可能并非最佳方法。循环神经网络 (RNN) 在学习超长序列的长距离依赖关系时会遇到困难,并且模型的上下文窗口有限。这意味着在进行预测时,它只能有效地考虑一定数量的先前令牌。

此外,与 Transformer 模型相比,循环模型在推理时间优化方面尚未得到如此多的关注。与成熟的 Transformer 架构相比,可用的研究资料和社区支持更少。

因此,如果您担心 LLM 的上下文窗口被耗尽,此模型将非常有价值。RecurrentGemma 会优先考虑最新信息并战略性地丢弃旧数据,以此确保 LLM 在上下文扩展的情况下仍保持强大性能。

以下是 Recurrent Gemma 2B 模型的架构图。

Recurrent Gemma 2B model architecture

Griffin 使用与其他 Transformer 基线相同的残差模式和 MLP 块。然而,与 MQA Transformer 基线和 Hawk 模型不同的是,Griffin 混合使用循环块和 MQA 块。

Layered structure of recurrent and MQA blocks

Griffin 使用分层的结构,先交替使用两个残差块和一个循环块,然后使用包含局部 MQA 注意力块的残差块。

下表总结了架构的核心参数

Core parameters of the architecture of 2B and 9B models

非嵌入参数和嵌入参数

非嵌入参数遍布模型的隐藏层,分布在注意力机制和前馈网络等组件中。

注意:模型“2B”的命名便源自此参数

嵌入参数通常位于被称为嵌入层的专用层中。该层负责将离散令牌(如单词或字符)映射到连续矢量表示(嵌入)中。

注意:0.7B 可以算作 256k(词汇量)x 2,560(模型宽度)


模型宽度和 RNN 宽度

模型宽度指模型中隐藏层的大小,可确定模型表示复杂模式的能力,就像基础 Gemma 模型一样。

循环神经网络 (RNN) 宽度是真实门控线性循环单元 (RG-LRU) 维护的隐藏状态的大小。与传统 Transformer 不同,无论输入长度如何,循环块都能保持固定大小的内部状态。因此,RecurrentGemma 能以更少的内存处理更长的序列,从而提高处理任务的效率,如生成长文章或代码。


MLP 膨胀系数

它与基础 Gemma 模型中的前馈隐藏维度相同。为简单起见,我们在 Recurrent Gemma 模型中应用了 3 的膨胀系数,导致 MLP 维度为 7,680(计算方式:2,560 x 3)。


局部注意力窗口大小

RecurrentGemma 可维护的状态大小有限,并且不会因序列长度超过 2k 令牌的局部注意力窗口而增长。这意味着,虽然 Gemma 以自回归方式生成的样本的最大长度受主机系统内存容量的限制,但 RecurrentGemma 可以生成任意长度的序列,从而克服这一限制。

RecurrentGemmaForCausalLM(
  (model): RecurrentGemmaModel(
    (embed_tokens): Embedding(256000, 2560, padding_idx=0)
    (layers): ModuleList(
      (0-1): 2 x RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaRecurrentBlock(
          (linear_y): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_x): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_out): Linear(in_features=2560, out_features=2560, bias=True)
          (conv_1d): Conv1d(2560, 2560, kernel_size=(4,), stride=(1,), padding=(3,), groups=2560)
          (rg_lru): RecurrentGemmaRglru()
          (act_fn): PytorchGELUTanh()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
      (2): RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaSdpaAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=False)
          (k_proj): Linear(in_features=2560, out_features=256, bias=False)
          (v_proj): Linear(in_features=2560, out_features=256, bias=False)
          (o_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (rotary_emb): RecurrentGemmaRotaryEmbedding()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
 
      :
 
      (23): RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaSdpaAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=False)
          (k_proj): Linear(in_features=2560, out_features=256, bias=False)
          (v_proj): Linear(in_features=2560, out_features=256, bias=False)
          (o_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (rotary_emb): RecurrentGemmaRotaryEmbedding()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
      (24-25): 2 x RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaRecurrentBlock(
          (linear_y): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_x): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_out): Linear(in_features=2560, out_features=2560, bias=True)
          (conv_1d): Conv1d(2560, 2560, kernel_size=(4,), stride=(1,), padding=(3,), groups=2560)
          (rg_lru): RecurrentGemmaRglru()
          (act_fn): PytorchGELUTanh()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
    )
    (final_norm): RecurrentGemmaRMSNorm()
  )
  (lm_head): Linear(in_features=2560, out_features=256000, bias=False)
)

embed_tokens(嵌入层)

将输入文本作为令牌序列,并将每个令牌映射到大小为 2,560 的连续矢量表示。它的词汇量大小为 256,000,与基础 Gemma 模型相同。


layers

解码器层总共有 26 个,按重复模式分组。

该模型从两个残差块和一个循环块 (0-1) 开始。然后,该序列之后是残差块 (2) 和一系列连续块交替出现,直到末端的层 (25)。

Recurrent block architecture

残差块与循环块

在循环块(时间混合块)中,模型采用维度(模型宽度)为 2,560 的输入,并应用输出维度(RNN 宽度)为 2,560 的两个线性层,以创建两个分支。

在第一个分支(右侧),模型应用一个可分离的 Conv1D 小层,时间滤波器维度为 4。接下来是真实门控线性循环单元 (RG-LRU) 层。

在第二个分支(左侧),模型应用的是 GeLU 非线性层。

然后通过元素乘法合并分支,应用输出维度(模型宽度)为 2,560 的最终线性层。

RecurrentGemma-Residual-block

应用 RMSNorm 后,接着应用 MLP 块。


残差块与局部 MQA

在应用两个残差块和一个循环块 (0-1) 后,接着应用一个残差块与一个局部 MQA (2)。使用全局注意力的主要缺点之一在于,其计算复杂度随序列长度呈平方增长。为了解决这个问题,RecurrentGemma 使用了局部滑动窗口注意力,允许每个位置只关注过去固定数量的标记。

在局部 MQA 块(时间混合块)中,模型接受维度(模型宽度)为 2,560 的输入。它使用线性投影(q_projk_projv_projo_proj)来创建查询、键、值和输出表示。请注意,k_projv_projout_features 为 256,因为它们共用大小为 256 的同一头部,而 q_projo_proj 同时使用 10 个头部(256 x 10 = 2,560)。

就像基础 Gemma 模型一样,该模型纳入了可用于旋转式位置编码 (RoPE) 的 rotary_emb (RecurrentGemmaRotaryEmbedding)。

应用 RMSNorm 和 MLP 块的方式与上一个残差块相同。


未来计划

在本文中,您了解了 RecurrentGemma。

在下一篇博文中,您可以探索 PaliGemma,这是一种轻量级的开放视觉语言模型 (VLM)。

敬请关注,感谢您的阅读!


参考文献

论文


代码示例


📋 完整的 Gemma 架构系列