Gemma 徹底解説: Gemma モデルファミリーのアーキテクチャの概要

8月 15, 2024
Ju-yeong Ji Gemma DevRel
Ravin Kumar Google Data Scientist Language Applications

Gemma は最先端かつ軽量のオープンモデル ファミリーで、Gemini モデルと同じ研究技術で構築されています。

Gemma にはさまざまなバリエーションがあり、次のようなさまざまなユースケースやモダリティに対応できるように設計されています。

  • 単一モダリティ(テキスト入力、テキスト出力)

  • コーディングのユースケースに特化したもの

  • マルチ モダリティ(テキストおよび画像入力、テキスト出力)

  • ハードウェアの種類や推論ニーズ、その他の制約に応じたサイズがあります。

  • 「斬新な」アーキテクチャ

どのモデルも同じ DNA を共有しているため、Gemma ファミリーは、現在の LLM システムで使われるさまざまなアーキテクチャやデザインについて学ぶうえで、格好の題材となります。この投稿が、オープンモデル エコシステムのさらなる活性化と、LLM システムの動作の深い理解につながれば幸いです。

このシリーズでは、以下について説明します。

  • Gemma 1(2B、7B) - トランスフォーマー ベースのテキストツーテキスト モデル。

  • CodeGemma(2B および 7B) - コード補完とコード生成向けに最適化した Gemma のファイン チューニング版。

  • Gemma 2 (2B、9B、27B) - 新しいアーキテクチャでトレーニングしたテキストツーテキスト モデルの更新版。2B および 9B バージョンは、大きなモデルを蒸留することでトレーニングを実施しています。

  • RecurrentGemma(2B、9B) - 新しい Griffin アーキテクチャを使って構築したモデル。ローカル アテンションと線形回帰を組み合わせることで、長いシーケンスを高速推論で生成できるようにしたアーキテクチャを用いています。

  • PaliGemma(3B) - テキストと画像を受け取り、テキストを出力する視覚言語モデル。


このガイドの使い方

このシリーズでは、以下の内容を扱います。

  • さまざまなモデル固有のアーキテクチャを比較する

  • パラメータがモデルの生成にどのように影響するかを説明する(例: 埋め込みの数、マルチクエリ対マルチヘッド対グループ化クエリ)

  • モデルのコード例を提供し、詳しく調べられるようにする

モデルに関する情報を提供する際は、Hugging Face Transformers のモデルを print で出力します。次に簡単なコードを示します。

from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b")
print(model)

モデルの内部を調べたい場合は、Keras の Model クラスの API である torchinfosummary() を使うこともできます。


このガイドで扱わない内容

このガイドは AI を紹介するものではありません。ニューラル ネットワークやトランスフォーマー、そしてそれに関連するトークンなどの用語について、実用的な知識を持っている方を対象としています。これらの概念についての復習が必要な方は、以下のリソースをご覧ください。

ブラウザで動作する実践的なニューラル ネットワーク学習ツール

トランスフォーマー入門


Gemma

Gemma は重みが公開されている LLM です。インストラクション チューニング版と基本モデルのそれぞれに、異なるパラメータ サイズのトレーニング済みバリアントがあります。ベースとなっているのは、Google Research が Attention Is All You Need という論文で紹介した LLM アーキテクチャです。主な機能として、ユーザーが提供するプロンプトに基づくトークン単位でのテキスト生成が挙げられます。翻訳のようなタスクでは、Gemma はある言語の文を入力として受け取り、同等の文を別の言語で出力します。

すぐにおわかりいただけると思いますが、Gemma はそれ自体が優れたモデルであるだけでなく、さまざまなユーザーのニーズを満たすカスタム拡張機能も持ち合わせています。


Gemma のアーキテクチャ

まず、Gemma モデルのベースとなっているトランスフォーマー デコーダーに注目しましょう。

Transformer decoder architecture

もともと「Attention Is All You Need」で触れられていたのは、エンコーダ - デコーダ トランスフォーマー モデル アーキテクチャでしたが、Gemma は「デコーダのみ」のモデルです。

このアーキテクチャのコアパラメータを次の表にまとめます。

Core parameters of the architecture

モデルのトレーニングは、8192 トークンのコンテキスト長で行います。つまり、最大で一度に約 6144 ワード(経験則によれば、100 トークンは約 75 ワード)を処理できることになります。

実際の入力の上限値は、タスクや用法によって異なる可能性があることに注意しましょう。つまり、生成されたテキストもコンテキスト ウィンドウのトークンを消費するので、実質的に新しい入力に利用できる領域が減ることになります。厳密な意味での入力の上限値は一定ですが、生成された出力が後続の入力の一部となるので、以降の生成に影響することになります。


d_model(2B: 2048、7B: 3072)

d_model は、デコーダへの入力に使う埋め込み(単語またはサブワード、すなわちトークンのベクトル表現)のサイズを表します。デコーダレイヤの内部表現のサイズも、これによって決まります。

d_model x Num heads x Head size
「d_model x ヘッド数 x ヘッドサイズ」で self_attn のパラメータ数が決まる

d_model 値が大きくなるほど、モデルがさまざまな単語のニュアンスとその関係を表現できる「余地」が増えることになり、とりわけ複雑な言語タスクでパフォーマンスが向上する可能性があります。しかしながら、d_model が大きくなると、モデルも大きくなり、トレーニングや利用に必要な計算量も増えることになります。


レイヤ(2B: 18、7B: 28)

トランスフォーマーは、複数のレイヤを積み重ねた構造になっています。深いモデルにはそれだけ多くのレイヤがあるので、パラメータの数も多くなり、複雑なパターンを学べるようになります。ただし、パラメータ数が増えると過学習が起こりやすくなり、必要な計算リソースも増加します。

モデルの表現力が向上すると、ノイズやトレーニング データの特定のパターンを学習し、汎用化して新しい例に対応する能力が低下する場合もあります。

さらに、多くの場合は、モデルが深くなるほど、過学習を回避するためにトレーニング データを増やさなければならなくなります。利用できるデータが限られている場合、十分なサンプルがないために汎用表現を学習することができず、トレーニング データを覚えるだけになってしまうこともあります。


フィードフォワード隠れ次元(2B: 32768、7B: 49152)

トランスフォーマーの各レイヤには、アテンション メカニズムの後にフィードフォワード ネットワークが含まれています。このネットワークには独自の次元があります。モデルの表現力を高めるため、ほとんどの場合は d_model よりも大きなサイズになっています。

これは、ニューラル ネットワークの一種である多層パーセプトロン(MLP)として実装されています。多層パーセプトロンは、埋め込みを再変換し、さらに複雑なパターンを抽出します。

multi-layer perceptron (MLP) neural network achitecture

Gemma では、標準的な ReLU 非線形関数の代わりに、GeGLU 活性化関数が使われています。GeGLU は GLU(ゲート リニア ユニット)の一種で、活性化をシグモイド部分と線形射影という 2 つの部分に分割しています。シグモイド部分の出力を要素ごとに線形射影と掛け合せることで、非線形活性化関数が得られます。

GeGLU activation function example

ヘッド数(2B: 8、7B: 16)

トランスフォーマーの各レイヤには、並列に動作する複数のアテンション メカニズムが含まれています。これを「ヘッド」と言い、これによってモデルが入力シーケンスの異なる特徴に同時に注目することができます。ヘッドの数を増やせば、モデルがデータの多様な関係を理解できるようになります。


KV ヘッド数(2B: 1、7B: 16)

7B モデルはマルチヘッド アテンション(MHA)を、2B モデルはマルチクエリ アテンション(MQA)を利用します。MQA では、同じキーと値の射影が共有されます。つまり、各ヘッドは同じ基礎表現に注目しますが、クエリの射影は異なります。

オリジナルの MHA の方が高度な特徴表現学習が可能ですが、計算コストは高くなります。MQA は効率的な代替手段を提供します。この手法の効率性は証明されています


ヘッドサイズ(2B: 256、7B: 256)

マルチヘッド アテンション機構の各アテンション ヘッドの次元を表します。埋め込みの次元数をヘッド数で割って計算します。たとえば、埋め込み次元数が 2048、ヘッド数が 8 の場合、各ヘッドのサイズは 256 になります。


語彙サイズ(2B: 256128、7B: 256128)

モデルが理解および処理できる固有のトークン(単語、サブワード、文字のいずれか)数を定義します。Gemma のトークナイザーは、SentencePiece がベースになっています。語彙のサイズは、トレーニング前にあらかじめ決められています。SentencePiece は、選択された語彙サイズとトレーニング データに基づいて、最適なサブワード分割を学習します。Gemma の語彙は大きく、256k のサイズなので、多様なテキスト入力を処理でき、多言語テキスト入力の処理などのさまざまなタスクのパフォーマンスが向上する可能性があります。


Gemma 7B

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 3072, padding_idx=0)
    (layers): ModuleList(
      (0-27): 28 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=3072, out_features=4096, bias=False)
          (k_proj): Linear(in_features=3072, out_features=4096, bias=False)
          (v_proj): Linear(in_features=3072, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=3072, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=3072, out_features=24576, bias=False)
          (up_proj): Linear(in_features=3072, out_features=24576, bias=False)
          (down_proj): Linear(in_features=24576, out_features=3072, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
    )
    (norm): GemmaRMSNorm()
  )
  (lm_head): Linear(in_features=3072, out_features=256000, bias=False)
)
Gemma 7B architecture

embed_tokens(埋め込みレイヤ)

このレイヤは、入力トークン(単語またはサブワード)を、モデルが処理できる高密度数値表現(埋め込み)に変換します。語彙サイズは 256,000 で、3072 次元の埋め込みが作成されます。


レイヤ

これはモデルの中核で、28 層の GemmaDecoderLayer ブロックを積み重ねたものです。各レイヤでトークン埋め込みを処理し、単語とそのコンテキストとの間の複雑な関係を理解します。


self_attn

セルフアテンション メカニズムでは、入力ワードに異なる重みを割り当てることで、次の単語を生成します。モデルは、スケーリング ドット積アテンション メカニズムと線形射影(q_projk_projv_projo_proj)を使って、クエリ、キー、値、出力表現を生成します。

このモデルはマルチヘッド アテンション(MHA)を使うため、q_projk_projv_proj のすべての out_features 値は同じ 4096 です。サイズ 256 のヘッド 16 個が並列に動作するので、合計の数は 4096 個(256 x 16)になります。

さらにモデルは、位置エンコーディング(別名 RoPE)に rotary_emb(GemmaRotaryEmbedding)を使うことで、位置情報を有効活用しています。

最後に o_proj レイヤで、アテンション出力を元の次元(3072)に射影します。


Gemma 2B モデルは、マルチクエリ アテンション(MQA)を使っていることに注意してください。

Multi-Query Attention (MQA) architecture used in Gemma 2B model

k_projv_proj は、サイズ 256 の同じヘッドを共有するため、out_features は 256 になります。対照的に、q_projo_proj には、並列に動作する 8 つのヘッドがあります(256 x 8 = 2048)。

(self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )


mlp

ゲート メカニズムに gate_projup_proj を利用し、その後、down_proj を使って次元を減らし、3072 に戻します。


input_layernorm、post_attention_layernorm および norm

これらの正規化レイヤにより、トレーニングが安定し、モデルが効果的に学習できるようになります。


lm_head

この最終レイヤで、処理済みの埋め込み(3072)を、語彙空間(256000)における次のトークンの確率分布にマッピングします。


CodeGemma(2B および 7B)

CodeGemma モデルは、コード補完とチャットによるコーディング支援向けに最適化したファイン チューニング済み Gemma モデルです。CodeGemma モデルは、5,000 億以上の主要コードのトークンでトレーニングしています。さらに、CodeGemma にはフィルインザミドル(FIM)機能が追加されており、2 つの既存テキスト間で補完を行うことができます。

CodeGemma からわかるのは、Gemma チェックポイントのファイン チューニング能力です。追加のトレーニングを行うことで、モデルを特定のタスクに特化させて、最後尾の補完にとどまらない複雑な補完を学習させることができます。


CodeGemma の利用

4 つのユーザー定義トークンを利用できます。FIM 用に 3 つのトークン、マルチファイル コンテキスト サポート用に "<|file_separator|>" トークンを利用できます。

BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

次の画面のようなコード補完について考えてみましょう。

Code snippet example - CodeGemma (2B and 7B)

入力プロンプトは次のようになります。

<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>

モデルは、コード補完の候補として「sys」を提示します。

CodeGemma の詳細については、CodeGemma / クイックスタートをご覧ください。


次のステップ

この記事では、Gemma のアーキテクチャについて説明しました。

次回以降の投稿では、最新モデルである Gemma 2 について説明します。安全対策が大幅に強化されており、推論時のパフォーマンスと効率の面で前のモデルを凌駕しています。

今後の投稿にご期待ください。お読みいただきありがとうございました。



参考文献


論文

コードサンプル

Gemma

CodeGemma


📋 The complete Gemma architecture series