Gemma 徹底解説: RecurrentGemma のアーキテクチャ

8月 29, 2024
Ju-yeong Ji Gemma DevRel
Ravin Kumar Google Data Scientist Language Applications

Gemma 徹底解説シリーズの前回の投稿では、最新の Gemma 2 のアーキテクチャについて説明しました。今回の投稿では、RecurrentGemma のアーキテクチャについて説明します。さっそく始めましょう!


RecurrentGemma 2B、9B

RecurrentGemma は、ゲート付き線形再帰とローカル スライディング ウィンドウ アテンションを組み合わせたハイブリッド モデル、Griffin をベースとしています。この変更により、演算と記憶能力が向上し、長いコンテキスト プロンプトに適応できるようになります。

Griffin hybrid model architecture

ただし、Griffin アーキテクチャは内部状態のサイズが固定されているため、膨大な対象の中から小さなものを探す際のパフォーマンスが低下するという欠点があります。本の全文を入力することは可能ですが、このアプローチは最適ではない場合があります。再帰ニューラル ネットワーク(RNN)では、非常に長いシーケンスで遠く離れたもの同士の依存関係を学習するのが難しい場合があり、モデルのコンテキスト ウィンドウは限られます。つまり、予測を行うときには、一定数の先行トークンしか効果的に考慮できません。

さらに、推論時間の最適化という点で、再帰モデルはトランスフォーマーほど注目されていません。また、確立されているトランスフォーマー アーキテクチャに比べて、公開されている研究やコミュニティのサポートもかなり少なめです。

そのため、このモデルは、LLM のコンテキスト ウィンドウが足りなくなる恐れがある場合に非常に有効です。RecurrentGemma は、最新の情報を優先し、戦略的に古いデータを破棄することで、コンテキストが大きくなっても LLM のパフォーマンスを維持できるようになっています。

次に示すのは、RecurrentGemma 2B モデルのアーキテクチャ図です。

Recurrent Gemma 2B model architecture

Griffin は、他のトランスフォーマー ベースラインと同じように、残差パターンと MLP ブロックに従っています。しかし、MQA トランスフォーマー ベースラインとも Hawk モデルとも異なり、Griffin は、再帰ブロックと MQA ブロックをブレンドして使っています。

Layered structure of recurrent and MQA blocks

Griffin で使われているレイヤ構造では、2 つの残差ブロックと再帰ブロックが交互に配置され、続いてローカル MQA アテンション ブロックを組み込んだ残差ブロックが配置されています。

このアーキテクチャのコアパラメータを次の表にまとめます。

Core parameters of the architecture of 2B and 9B models

非埋め込みパラメータと埋め込みパラメータ

非埋め込みパラメータは、アテンション メカニズムやフィードフォワード ネットワークなどのコンポーネントで、モデルの隠れレイヤ全体に分散されます。

注: モデル「2B」の命名は、このパラメータに由来します

通常、埋め込みパラメータは、埋め込みレイヤと呼ばれる専用レイヤにあります。このレイヤは、離散トークン(単語や文字など)を連続ベクトル表現(埋め込み)にマッピングする役割を果たします。

注: 0.7B は、256k(語彙サイズ)x 2560(モデルの幅)の計算結果です


モデルの幅と RNN の幅

モデルの幅は、モデル内の隠れレイヤのサイズを指します。ベースの Gemma モデルと同じように、モデルがどこまで複雑なパターンを表現できるかが、これによって決まります。

再帰ニューラル ネットワーク(RNN)の幅は、リアルゲート付き線形再帰ユニット(RG-LRU)が保持する隠れ状態のサイズです。従来のトランスフォーマーとは異なり、再帰ブロックは、入力長に関係なく、内部状態は固定サイズになります。そのため、RecurrentGemma は少ないメモリで長いシーケンスを処理でき、長文やコードの生成といったタスクの効率が向上します。


MLP 拡張係数

ベースの Gemma モデルのフィードフォワード隠れ次元と同じです。RecurrentGemma モデルでは、簡潔さを優先し、拡張係数として 3 を適用し、MLP 次元を 7680(2560 x 3 の計算結果)にしました。


ローカル アテンション ウィンドウのサイズ

RecurrentGemma が保持する状態は有限サイズで、シーケンスが 2k トークンのローカル アテンション ウィンドウよりも長い場合、拡大されることはありません。つまり、Gemma が自己回帰的に生成するサンプルの最大長は、ホストシステムのメモリ容量が上限となりますが、RecurrentGemma はこの制約を克服できるので、任意の長さのシーケンスを生成できます。

RecurrentGemmaForCausalLM(
  (model): RecurrentGemmaModel(
    (embed_tokens): Embedding(256000, 2560, padding_idx=0)
    (layers): ModuleList(
      (0-1): 2 x RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaRecurrentBlock(
          (linear_y): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_x): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_out): Linear(in_features=2560, out_features=2560, bias=True)
          (conv_1d): Conv1d(2560, 2560, kernel_size=(4,), stride=(1,), padding=(3,), groups=2560)
          (rg_lru): RecurrentGemmaRglru()
          (act_fn): PytorchGELUTanh()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
      (2): RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaSdpaAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=False)
          (k_proj): Linear(in_features=2560, out_features=256, bias=False)
          (v_proj): Linear(in_features=2560, out_features=256, bias=False)
          (o_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (rotary_emb): RecurrentGemmaRotaryEmbedding()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
 
      :
 
      (23): RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaSdpaAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=False)
          (k_proj): Linear(in_features=2560, out_features=256, bias=False)
          (v_proj): Linear(in_features=2560, out_features=256, bias=False)
          (o_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (rotary_emb): RecurrentGemmaRotaryEmbedding()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
      (24-25): 2 x RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaRecurrentBlock(
          (linear_y): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_x): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_out): Linear(in_features=2560, out_features=2560, bias=True)
          (conv_1d): Conv1d(2560, 2560, kernel_size=(4,), stride=(1,), padding=(3,), groups=2560)
          (rg_lru): RecurrentGemmaRglru()
          (act_fn): PytorchGELUTanh()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
    )
    (final_norm): RecurrentGemmaRMSNorm()
  )
  (lm_head): Linear(in_features=2560, out_features=256000, bias=False)
)

embed_tokens(埋め込みレイヤ)

入力テキストをトークンのシーケンスとして受け取り、各トークンをサイズ 2560 の連続ベクトル表現にマッピングします。語彙サイズは 256000 で、ベースの Gemma モデルと同じです。


レイヤ

合計で 26 のデコーダレイヤがあり、パターンが繰り返されたグループに分かれています。

モデルの最初には、2 つの残差ブロックと 1 つの再帰ブロック(0-1)があります。このシーケンスの後には、残差ブロック(2)があり、最後のレイヤ(25)まで交互に繰り返される連続ブロックがあります。

Recurrent block architecture

残差ブロックと再帰ブロック

再帰ブロック(時間混合ブロック)では、2560 次元(モデルの幅)の入力を受け取り、2560 次元(RNN の幅)の出力がある 2 つの線形レイヤを並列に適用して、2 つのブランチを作成します。

最初のブランチ(右側)では、時間フィルタ次元が 4 である小さな個別の Conv1D レイヤを適用します。その後に、RG-LRU(リアルゲート付き線形再帰ユニット)レイヤが続きます。

2 番目のブランチ(左側)では、GeLU 非線形性を適用します。

次に、要素ごとに乗算してブランチをマージし、2560 次元(モデルの幅)の出力を持つ最終線形レイヤを適用します。

RecurrentGemma-Residual-block

RMSNorm の適用後に、MLP ブロックが続きます。


残差ブロックとローカル MQA

2 つの残差ブロックと 1 つの再帰ブロック(0-1)の後に、残差ブロックとローカル MQA(2)が続きます。グローバル アテンションを使う場合の主な欠点として、計算量がシーケンス長に対して 2 次関数的に増加することが挙げられます。RecurrentGemma では、これに対処するため、ローカル スライディング ウィンドウ アテンションを使っています。これにより、それぞれの位置が過去の一定数のトークンにのみ注目できるようになります。

ローカル MQA ブロック(時間混合ブロック)では、2560 次元(モデルの幅)の入力を受け取ります。クエリ、キー、値、出力表現は、それぞれ線形射影(q_projk_projv_projo_proj)を使って作成します。なお、k_projv_projout_features は 256 です。これは、サイズが 256 の同じヘッドを共有しているためです。q_projo_proj には、並列に動作する 10 個のヘッド(256 x 10 = 2560)があります。

ベースの Gemma モデルと同じく、ロータリー位置埋め込み(RoPE)に利用する rotary_emb(RecurrentGemmaRotaryEmbedding)が組み込まれています。

前の残差ブロックと同じく、RMSNorm と MLP ブロックを適用しています。


次のステップ

今回の記事では、RecurrentGemma について学びました。

次の投稿では、オープンで軽量な視覚言語モデル(VLM)、PaliGemma について説明します。

今後の投稿にご期待ください。お読みいただきありがとうございました。


参考文献

論文


コードサンプル


📋 Gemma アーキテクチャ徹底解説シリーズ