KerasHub の紹介: トレーニング済みモデルのワンストップ ショップ

10月 22, 2024
Divyashree Sreepathihalli Software Engineer
Luciano Martins Developer Advocate Google AI

ディープ ラーニングの世界は急速に進化しており、さまざまなタスクにとって、事前トレーニング済みモデルの重要性はますます高まっています。Keras は、ユーザー フレンドリーな API とアクセシビリティ重視で知られており、この動きの最前線にあるのが、テキストベースのモデルである KerasNLP や、コンピュータ ビジョン モデルである KerasCV などの専用ライブラリです。

一方で、モデルのモダリティの境界線はますます曖昧になっており(画像入力を受け取れる強力なチャット LLM や、テキスト エンコーダを活用するビジョンタスクを考えてみてください)、このように領域を分割し続けることは実用的とは言えません。NLP と CV を分離してしまえば、真のマルチモーダル モデルの開発や展開の妨げになったり、冗長な作業やユーザー エクスペリエンスの断片化などにつながったりする可能性があります。

keras-team/keras-hub, a unified, comprehensive library for pretrained models

この点に対処するため、Keras エコシステムが大きく進化することをお知らせします。KerasHub は事前トレーニング済みモデルを一か所に集めた総合ライブラリであり、最先端の NLP および CV の両方のアーキテクチャに効率的にアクセスすることが可能です。一元管理されたリポジトリである KerasHub では、テキスト分析用の BERT や画像分類用の EfficientNet などの最先端モデルを、すべて使い慣れた一貫性のある Keras フレームワークからシームレスに確認したり、利用したりできます。


一貫したデベロッパー エクスペリエンス

このように一元管理を行うことで、モデルを簡単に探して利用できるようになるだけでなく、協力関係に基づくエコシステムが醸成されます。KerasHub では、モデルを簡単に公開して共有する、LoRA ファイン チューニングによってリソース効率よく学習させる、量子化して最適なパフォーマンスを実現する、安定したマルチホスト トレーニングによって大規模なデータセットに対応することなど、高度な機能を活用できます。しかも、どの機能もモダリティを問わず利用できます。これは、あらゆる人が強力な AI ツールにアクセスし、革新的なマルチモーダル アプリケーションの開発に向かえるようにするための重要な一歩です。


KerasHub の最初のステップ

まずは、システムに KerasHub をインストールしましょう。すると、すぐに利用可能なモデルを集めた多様なコレクションや、一般的なアーキテクチャのさまざまな実装を確認できます。次に、簡単な手順で事前トレーニング済みモデルを読み込み、自分のプロジェクトに組み込んで、特定の要件に応じてファイン チューニングして最適なパフォーマンスを実現しましょう。


KerasHub のインストール

次のコマンドを実行するだけで、Keras 3 の最新の KerasHub リリースをインストールできます。

$ pip install --upgrade keras-hub

これで、利用できるモデルを調べられるようになります。Keras 3 で作業するための標準環境設定をまったく変えることなく、KerasHub を使用し始めることができます。

import os
 
# 利用する Keras 3 バックエンドを定義 - "jax"、"tensorflow"、または "torch"
os.environ["KERAS_BACKEND"] = "jax"
 
# Keras 3 および KerasHub モジュールをインポート
import keras
import keras_hub

KerasHub でコンピュータ ビジョンおよび自然言語のモデルを使う

以上で、KerasHub から Keras 3 エコシステムで公開されているモデルにアクセスして利用する準備が整いました。以下にいくつかの例を示します。


Gemma

Google が開発した Gemma は、最先端でありながらアクセシビリティの高いオープンモデルのコレクションです。Gemma のベースモデルには Gemini モデルと同じ研究技術が使われているため、質問への回答、情報の要約、論理的推論など、さまざまなテキスト生成タスクで高いパフォーマンスを発揮します。さらに、特定のニーズに対応するようにカスタマイズすることもできます。

この例では、Keras と KerasHub を使って Gemma 2 2B のパラメータを読み込み、コンテンツを生成します。Gemma のバリアントの詳細については、Kaggle のGemma モデルカードをご覧ください。

# Kaggle Models から Gemma 2 2B プリセットを読み込む 
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
 
# Gemma 2 2B でコンテンツの生成を始める
gemma_lm.generate("Keras is a", max_length=32)

PaliGemma

PaliGemma は、画像とテキストの両方を理解できるコンパクトでオープンなモデルです。PaliGemma は、PaLI-3 に触発され、SigLIP ビジョンモデルGemma 言語モデルなどのオープンソース コンポーネントをベースに構築されており、画像に関する質問をすると、洞察に満ちた詳しい回答を返してくれます。ビジュアル コンテンツを深く理解できるので、画像や短い動画のキャプションを生成する、物体を識別する、画像内のテキストを読み取るなどといった機能を実現できます。

import os
 
# 利用する Keras 3 バックエンドを定義 - "jax"、"tensorflow"、または "torch"
os.environ["KERAS_BACKEND"] = "jax"
 
# Keras 3 および KerasHub モジュールをインポート
import keras
import keras_hub
from keras.utils import get_file, load_img, img_to_array
 
 
# 224x224 画像にファイン チューニングした PaliGemma 3B をインポート
pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM.from_preset(
    "pali_gemma_3b_mix_224"
)
 
# テスト画像をダウンロードし、KerasHub で利用できるよう準備する
url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
img_path = get_file(origin=url)
img = img_to_array(load_img(image_path))
 
# 画像についての質問のプロンプトを作成
prompt = 'answer where is the cow standing?'
 
# Generate the contents with PaliGemma
output = pali_gemma_lm.generate(
    inputs={
        "images": img,
        "prompts": prompt,
    }
)

Keras 3 で利用できる事前トレーニング済みモデルの詳細については、Kaggle の Keras モデル一覧をご覧ください。


Stability.ai Stable Diffusion 3

コンピュータ ビジョンのモデルも利用できます。例として、KerasHub で stability.ai Stable Diffusion 3 を使ってみます。

from PIL import Image
from keras.utils import array_to_img
from keras_hub.models import StableDiffusion3TextToImage
 
text_to_image = StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium",
    height=1024,
    width=1024,
    dtype="float16",
)
 
# SD3 で画像を生成
image = text_to_image.generate(
    "photograph of an astronaut riding a horse, detailed, 8k",
)
 
# 生成した画像を表示
img = array_to_img(image)
img

Keras 3 で利用できる事前トレーニング済みコンピュータ ビジョンのモデルの詳細については、Keras モデル一覧をご覧ください。


KerasNLP デベロッパーにとっての変更内容

KerasNLP から KerasHub への移行は簡単です。import 文を更新し、keras_nlpkeras_hub にするだけです。

例: 次のように、keras_nlp をインポートして BERT モデルを使っていた場合

import keras_nlp
 
# BERT モデルの読み込み 
classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_base_en_uncased", 
    num_classes=2,
)

インポートを調整すれば、KerasHub を使うことができます。

import keras_hub
 
# BERT モデルの読み込み 
classifier = keras_hub.models.BertClassifier.from_preset(
    "bert_base_en_uncased", 
    num_classes=2,
)

KerasCV デベロッパーにとっての変更内容

KerasCV をお使いの方は、KerasHub に更新することで、次のようなメリットを得られます。

  • 簡単なモデルの読み込み: KerasHub では、一貫した API でモデルを読み込むことができます。そのため、KerasCV と KerasNLP の両方を使っている場合、コードを簡略化できます。

  • フレームワークの柔軟性: JAX や PyTorch などのさまざまなフレームワークを検討したい場合は、KerasHub を使うと、KerasCV モデルと KerasNLP モデルを簡単に利用できます。

  • リポジトリの一元管理: KerasHub ではモデルのリポジトリを一元管理できるので、簡単にモデルを探してアクセスできます。また、新しいアーキテクチャも同じ場所に追加されます。


コードを KerasHub に対応させる方法

モデル

現在、KerasCV モデルを KerasHub に移植する作業が行われています。ほとんどはすでに利用できますが、一部はまだ進行中です。なお、Centerpillar モデルは移植されません。次のようにすると、KerasHub で任意のビジョンモデルを利用できるはずです。

import keras_hub
 
# プリセットを使ってモデルを読み込む
Model = keras_hub.models.<model_name>.from_preset('preset_name`)
 
# または、バックボーンとプリプロセッサを指定してカスタムモデルを読み込む
Model = keras_hub.models.<model_name>(backbone=backbone, preprocessor=preprocessor)

KerasHub では、KerasCV デベロッパー向けに期待の新機能が導入され、柔軟性が高まり、機能が拡張されます。以下のような機能が含まれます。


前処理の組み込み

それぞれのモデルに、サイズ変更、スケーリング変更などのルーチンタスクに対応する専用のプリプロセッサが付属しているので、ワークフローを効率化できます。

これまでの入力前処理は、モデルに入力する前に手動で行う必要がありました。

# 入力前処理の例
def preprocess_inputs(image, label):
    # 入力のサイズやスケールの変更など、前処理を行う
    return preprocessed_inputs
backbone = keras_cv.models.ResNet50V2Backbone.from_preset(
    "resnet50_v2_imagenet",
)
model = keras_cv.models.ImageClassifier(
    backbone=backbone,
    num_classes=4,
)
output = model(preprocessed_input)

現在のところ、タスクモデルの前処理は、既成のプリセットに組み込まれています。入力に前処理を適用すると、プリプロセッサがサンプル画像のサイズとスケールの変更を行います。プリプロセッサはタスクモデルの必須コンポーネントですが、カスタムのプリプロセッサを使うこともできます。

classifier = keras_hub.models.ImageClassifier.from_preset('resnet_18_imagenet')
classifier.predict(inputs)

損失関数

オーグメンテーション レイヤと同様に、これまで KerasCV で使っていた損失関数は、Keras の keras.losses.<loss_function> から利用できます。たとえば、次のようにして FocalLoss 関数を使っているとします。

import keras
import keras_cv
 
keras_cv.losses.FocalLoss(
    alpha=0.25, gamma=2, from_logits=False, label_smoothing=0, **kwargs
)

この場合、損失関数を定義するコードを調整し、keras_cv.losses ではなく keras.losses を使うようにするだけです。

import keras
 
keras.losses.FocalLoss(
    alpha=0.25, gamma=2, from_logits=False, label_smoothing=0, **kwargs
)

KerasHub を使ってみる

さっそく KerasHub の世界に飛び込んでみましょう。


Keras コミュニティに参加して、アクセシビリティの高い一元化された効率的なディープ ラーニング モデルの力を存分に活用しましょう。マルチモーダルこそ AI の未来であり、KerasHub はそこにつながる入り口です!