AI エコシステムが進化し続けるにつれて、機械学習モデルを定義する方法がますます多様化し、トレーニングとファインチューニングによって得られるモデルの重みを保存する方法もさらに増えています。このように選択肢が拡大する中で、KerasHub は、さまざまな機械学習フレームワーク間で一般的なモデル アーキテクチャとその重みを組み合わせて活用できる機能を提供します。
たとえば、チェックポイントの読み込み元としてよく使われるのに Hugging Face Hub があります。これらのモデル チェックポイントの多くは、SafeTensors 形式の Hugging Face transformers
ライブラリを使用して作成されています。モデル チェックポイントの作成にどの機械学習フレームワークを使用したかにかかわらず、これらの重みを KerasHub モデルに読み込むことが可能です。これにより、任意のフレームワーク(JAX、PyTorch、TensorFlow)を使用してモデルを実行できます。
そうです。つまり、JAX で Mistral や Llama からチェックポイントを実行したり、Gemma に PyTorch を読み込んだりすることもできるのです。これ以上にない柔軟性を提供します。
これらの用語のいくつかを詳しく説明し、実際の動作についても見ていきましょう。
モデルを読み込む際には、モデル アーキテクチャとモデルの重み(多くの場合「チェックポイント」と呼ばれる)という 2 つの異なる部分が必要です。それぞれについて詳しく定義していきます。
「モデル アーキテクチャ」とは、モデルのレイヤーがどのように配置され、その中でどのような操作が行われるかを指します。言い換えれば、モデルの「構造」とも呼べます。PyTorch、JAX、Keras などの Python フレームワークを使用してモデル アーキテクチャを表現します。
「モデルの重み」とはモデルの「パラメータ」、つまりトレーニングの過程で変更されるモデル内の数値を指します。これらの重みの具体的な値がトレーニング済みモデルの特性を左右します。
「チェックポイント」は、トレーニングの特定の時点におけるモデルの重み値のスナップショットです。共有され広く利用されている典型的なチェックポイント ファイルは、モデルが特に優れたトレーニング結果に達した時点のものです。同じモデル アーキテクチャがファインチューニングなどの手法でさらに改良されるにつれて、新たなチェックポイント ファイルが追加で作成されます。たとえば、多くのデベロッパーが Google の gemma-2-2b-it モデルを使用して独自のデータセットでファインチューニングしており、600 以上の例を確認できます。これらのファインチューニング済みモデルはすべて、元の gemma-2-2b-it モデルと同じアーキテクチャを使用していますが、チェックポイントの重みは異なります。
モデル アーキテクチャはコードで記述されますが、モデルの重みはトレーニング済みのパラメータであり、チェックポイント ファイルとして保存されます。モデル アーキテクチャとモデルの重み(チェックポイント ファイル形式)のセットがそろうと、有用な出力を生成する機能的なモデルを作成できます。
Hugging Face の Transformers ライブラリや Google の KerasHub ライブラリといったツールは、モデル アーキテクチャとそれらを試してみるのに必要な API を提供します。チェックポイント リポジトリの例には、Hugging Face Hub や Kaggle Models などがあります。
モデル アーキテクチャ ライブラリとチェックポイント リポジトリを自由に組み合わせて使用することが可能です。たとえば、Hugging Face Hub から JAX モデル アーキテクチャにチェックポイントを読み込み、KerasHub でファインチューニングできます。別のタスクでは、ニーズに適した Kaggle モデルのチェックポイントを検索することもできます。この柔軟性と分離性のおかげで、特定のエコシステムに縛られる必要がなくなります。
KerasHub については何度か触れてきましたが、ここでさらに詳しく説明していきます。
KerasHub はモデル アーキテクチャの定義を容易にする Python ライブラリです。現在最も人気があり広く利用されている機械学習モデルを多く含んでおり、さらに多くのモデルが随時追加されています。KerasHub は Keras をベースとしているため、現在使用されている 3 つの主要な Python 機械学習ライブラリ(PyTorch、JAX、TensorFlow)をすべてサポートしています。つまり、どのライブラリでもモデル アーキテクチャを定義することが可能です。
さらに、KerasHub は最も一般的なチェックポイント形式をサポートしているため、多くのチェックポイント リポジトリからチェックポイントを簡単に読み込めます。たとえば、Hugging Face と Kaggle については、これらのモデル アーキテクチャに読み込める何十万ものチェックポイントが用意されています。
transformers
ライブラリとの比較デベロッパーの一般的なワークフローは、Hugging Face transformers
ライブラリを使用してモデルをファインチューニングし、Hugging Face Hub にアップロードするというものです。また、transformers
を使用している場合は、KerasHub にも馴染みのある API パターンが多く用意されています。詳細については、KerasHub API ドキュメントをご覧ください。KerasHub の興味深い点は、Hugging Face Hub にあるチェックポイントの多くが、transformers
ライブラリだけでなく KerasHub とも互換性があることです。その仕組みを見ていきましょう。
Hugging Face には、Hugging Face Hub と呼ばれるモデル チェックポイント リポジトリがあります。これは、機械学習コミュニティが世界と共有するためにモデル チェックポイントをアップロードする多くの場所の一つです。Hugging Face で特に人気があるのは、KerasHub と互換性のある SafeTensors 形式です。
モデル アーキテクチャが利用可能であれば、これらのチェックポイントを Hugging Face Hub から KerasHub モデルに直接読み込むことが可能です。お気に入りのモデルが利用可能か確認したい場合は、https://keras.io/keras_hub/presets/ を参照してサポートされているモデル アーキテクチャのリストをご確認ください。また、これらのモデル アーキテクチャのコミュニティで作成されたファインチューニング済みのチェックポイントにもすべて互換性があります!先日、このプロセスをより詳しく説明した新しいガイドを作成しました。
どのような仕組みなのかというと、KerasHub には Hugging Face transformers
モデルの使用を簡素化するコンバータが組み込まれています。これらのコンバータが、Hugging Face モデルのチェックポイントを KerasHub と互換性のある形式に変換するプロセスを自動的に処理します。つまり、わずか数行のコードで、さまざまな事前トレーニング済みの Hugging Face Transformer モデルを Hugging Face Hub から KerasHub に直接シームレスに読み込めます。
モデル アーキテクチャが見つからない場合は、GitHub で pull リクエストを送信して追加できます。
それでは、Hugging Face Hub のチェックポイントを KerasHub に読み込むにはどうすればよいでしょうか?具体的な例を見てみましょう。
まず、Keras の「バックエンド」として機械学習ライブラリを選択します。ここでは JAX を使用しますが、JAX、PyTorch、TensorFlow のいずれかを選択できます。以下の例はすべて、どのライブラリを選択しても機能します。次に、keras
、keras _hub
、huggingface_hub
をインポートし、Hugging Face のユーザー アクセス トークンでログインしてモデルのチェックポイントにアクセスできるようにします。
import os
os.environ["KERAS_BACKEND"] = "jax" # or "torch" or "tensorflow"
import keras
from keras_hub import models
from huggingface_hub import login
login('HUGGINGFACE_TOKEN')
まずは、JAX 上で Mistral のチェックポイントを実行してみましょうか。KerasHub の利用可能なモデル アーキテクチャのリストには、いくつかの Mistral モデルが記載されています。ここでは mistral_0.2_instruct_7b_en
を試してみましょう。クリックすると、MistralCausalLM
クラスを使用して from_preset
を呼び出す必要があることがわかります。Hugging Face Hub 側では、対応するモデル チェックポイントがこちらに保存されており、900 以上のファインチューニング済みバージョンが存在しています。そのリストを見てみると、segolilylabs/Lily-Cybersecurity-7B-v0.2
というパス名を持つサイバーセキュリティに特化した人気のファインチューニング済みモデル「Lily」があります。このパスの前に「hf://
」を追加して、KerasHub が Hugging Face Hub を参照するよう指定する必要があります。
すべてをまとめると、次のようなコードになります。
# Model checkpoint from Hugging Face Hub
gemma_lm = models.MistralCausalLM.from_preset("hf://segolilylabs/Lily-Cybersecurity-7B-v0.2")
gemma_lm.generate("Lily, how do evil twin wireless attacks work?", max_length=30)
Llama 3.1-8B-Instruct は人気のモデルで、先月だけで 500 万回以上ダウンロードされています。JAX にファインチューニング済みのバージョンを追加してみましょう。1400 以上のファインチューニング済みチェックポイントがあるので、選択肢は十分です。xVerify のファインチューニング済みチェックポイントが面白そうなので、KerasHub の JAX に読み込んでみましょう。
使用しているモデル アーキテクチャを反映するために、Llama3CausalLM クラスを使用します。先ほどと同様に、「hf://
」を接頭辞にした Hugging Face Hub からの適切なパスが必要です。わずか 2 行のコードでモデルを読み込んで呼び出せるなんて、本当にすごいと思いませんか?
# Model checkpoint from Hugging Face Hub
gemma_lm = models.Llama3CausalLM.from_preset("hf://IAAR-Shanghai/xVerify-8B-I")
gemma_lm.generate("What is the tallest building in NYC?", max_length=100)
最後に、ファインチューニング済みの Gemma-3-4b-it チェックポイントを JAX に読み込みましょう。Gemma3CausalLM クラスを使用して、ファインチューニング済みのチェックポイントから 1 つを選択します。多言語翻訳ツールの EraX はどうでしょうか?今回も同様に、Hugging Face Hub の接頭辞付きのパス名を使用して、フルパス「hf://erax-ai/EraX-Translator-V1.0
」を作成します。
# Model checkpoint from Hugging Face Hub
gemma_lm = models.Gemma3CausalLM.from_preset("hf://erax-ai/EraX-Translator-V1.0")
gemma_lm.generate("Translate to German: ", max_length=30)
これまで説明してきたように、モデル アーキテクチャとその重みを結びつける必要はありません。つまり、さまざまなライブラリのアーキテクチャと重みを組み合わせて使用できます。
KerasHub が異なるフレームワークとチェックポイント リポジトリ間のギャップを埋めてくれます。Hugging Face Hub からモデル チェックポイント(PyTorch ベースの Transformers ライブラリを使用して作成されたものも含む)を取得し、JAX、TensorFlow、PyTorch のいずれかのバックエンドで実行される Keras モデルにシームレスに読み込むことが可能です。これにより、コミュニティによってファインチューニングされたモデルの膨大なコレクションを活用しながら、どのバックエンド フレームワークを実行するかを自由に選択できます。
アーキテクチャ、重み、フレームワークを組み合わせるプロセスを簡素化することで、KerasHub はシンプルでありながら強力な柔軟性を備えた試行とイノベーションを可能にします。