Keras での Gemma モデルの概要

2月 21, 2024
Martin Görner Product Manager Keras

Keras チームは、Gemini モデルの作成に使用したのと同じ研究と技術で構築された軽量で最先端のオープンモデルのファミリーである GemmaKerasNLP コレクションで利用できるようになったことをお知らせします。Keras 3 のおかげで、Gemma は JAX、PyTorch、TensorFlow で動作します。このリリースでは、Keras は新しい LoRA API(Low Rank Adaptation)と大規模なモデルの並行トレーニング機能など、大規模言語モデル向けに特別に設計されたいくつかの新しい機能も導入しています。

コードサンプルを直接活用する場合は、次をご確認ください。

スタートガイド

Gemma モデルには、ポータブルな 2B と 7B のパラメータ サイズがあり、同様のオープンモデル、さらにはいくつかの大きなモデルに大きな進歩をもたらします。以下に例を示します。

  • Gemma 7B は MMLU 言語理解ベンチマークで新たなクラス最高の 64.3% の正解率を記録(Mistral-7B では 62.5%、Llama2-13B では 54.8%)
  • Gemma では、GSM8K の小学生向け算数問題のベンチマーク スコアが 11% アップ(Gemma 7B は 46.4%、Mistral-7B は 35.4%、Llama2-13B は 28.7%)
  • コーディング チャレンジである HumanEval の正解率が 6.1% アップ(Gemma 7B では 32.3%、Mistral 7B では 26.2%、Llama2 13B では 18.3%)

Gemma モデルは、使い慣れた KerasNLP API と非常に読みやすい Keras の実装を備えています。モデルは 1 行のコードでインスタンス化できます。

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

テキスト プロンプトで直接実行します。トークン化は組み込まれていますが、必要に応じて簡単に分割できます。方法については、Keras NLP ガイドをご覧ください。

gemma_lm.generate("Keras is a", max_length=32)
> "Keras is a popular deep learning framework for neural networks..."

試してみる: Gemma モデルを使ってみる

Keras 3 のおかげで、モデルを実行するバックエンドを選択できます。切り替える方法は次のとおりです。

os.environ["KERAS_BACKEND"] = "jax" # Or "tensorflow" or "torch".
import keras # バックエンドを選択してから keras をインポートします

Keras 3 には、大規模言語モデルに特化したいくつかの新機能が搭載されています。特に重要なのは、パラメータを効率的に微調整するための新しい LoRA API(Low Rank Adaptation)です。これを有効にする方法は次のとおりです。

gemma_lm.backbone.enable_lora(rank=4)
# 注: rank=4 は関連するレイヤの重み付けのマトリックスを
# rank 4 の 2 つの行列の製品 AxB に置き換え、
# トレーニング可能なパラメータの数を減らします。

この 1 行で、トレーニング可能なパラメータの数が 25 億個から 130 万個に減少します。

試してみる: LoRA で Gemma モデルを微調整する

複数の GPU / TPU で Gemma モデルを微調整する

Keras 3 は大規模なモデル トレーニングもサポートしており、Gemma はそれを試すのに最適なモデルです。新しい Keras Distribution API は、データ並列およびモデル並列の分散トレーニング オプションを提供します。この新しい API はマルチバックエンドであることを前提にしていますが、実証済みのスケーラビリティ(Gemma モデルは JAX でトレーニングされました)のため、当面は JAX バックエンドのみに実装されます。

より大きな Gemma 7B を微調整するには、分散設定が便利です。たとえば、Kaggle で無料で入手できる 8 つの TPU コアを搭載した TPUv3 や、Google Cloud の 8 つの GPU マシンなどがあります。モデル並列処理を使用して分散トレーニング用にモデルを構成する方法は次のとおりです。

device_mesh = keras.distribution.DeviceMesh(
   (1, 8), # Mesh topology
   ["batch", "model"], # named mesh axes
   devices=keras.distribution.list_devices() # actual accelerators
)
 
 
# モデル構成
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (None, "model")
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
   None, "model", None)
layout_map["decoder_block.*attention_output.*kernel"] = (
   None, None, "model")
layout_map["decoder_block.*ffw_gating.*kernel"] = ("model", None)
layout_map["decoder_block.*ffw_linear.*kernel"] = (None, "model")
 
 
# モデル構成を設定してモデルを読み込みます
model_parallel = keras.distribution.ModelParallel(
   device_mesh, layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
# 準備: model.fit() を使用してトレーニング可能になりgenerate() でテキストを生成できるようになりました

このコード スニペットは、8 つのアクセラレータを 1 x 8 の行列に設定します。ここでは、2 つのディメンションを「バッチ」と「モデル」と呼びます。モデルの重みは「モデル」ディメンションでシャーディングされ、ここでは 8 つのアクセラレータ間で分割されますが、「バッチ」ディメンションが 1 であるため、データバッチは分割されません。

試してみる: 複数の GPU / TPU で Gemma モデルを微調整する

次のトピック

近日中に、Transformer モデルを正しくパーティショ二ングし、上述の 6 行のパーティション設定を記述する方法を示すガイドを公開します。それほど長文ではありませんが、この投稿には収まりません。

レイヤのパーティショニングは、レイヤ名の正規表現によって定義されていることに気づくでしょう。このコード スニペットでレイヤ名を確認できます。これは、上述の LayoutMap を構築するために実行されました。

# これは最初の Transformer ブロックのみのためですが
 # すべてが同じ構造です
tlayer = gemma_lm.backbone.get_layer('decoder_block_0')
for variable in tlayer.weights:
 print(f'{variable.path:<58}  {str(variable.shape):<16}')

Keras はこれらの設定をパワフルな XLA コンパイラに渡して分散コンピューティングの他のすべての詳細を把握するため、完全な GSPMD モデルの並列処理は、ここではわずかなパーティショニングのヒントで機能します。

Gemma モデルを試し、お楽しみいただければ幸いです。有用な指示チューニング チュートリアルもあります。ちなみに、微調整された重みをコミュニティと共有したい場合、Kaggle モデルハブはユーザーが調整した重みのアップロードをサポートするようになりました。Kaggle の Gemma モデルのモデルページに移動して、他のユーザーが作成したものをご確認ください。