Keras と JAX を使って 10 分でレコメンデーション システムを作ってトレーニングする

2025年5月13日
Yufeng Guo Developer Advocate
Monica Song Product Manager

本日は、Keras Recommenders のリリースをお知らせします。この新しいライブラリを使うと、最先端のレコメンデーション テクニックを自在に利用できます。


レコメンデーション システムでデジタル エクスペリエンスを強化する

現在、レコメンデーション システムは、テクノロジーとの接点として広く使われています。スマートフォンでアプリを開けば、人気のソーシャル メディア プラットフォームのホームフィード、YouTube のおすすめ動画、お気に入りのゲームのポップアップ広告など、すぐにレコメンデーション モデルに触れることになります。AI の世界が進化し続ける今、パーソナライズされたエクスペリエンスを提供することがこれまで以上に重要になっています。すべてを大規模言語モデルで行うことはできません。現在のトップクラス デジタル エクスペリエンスの多くは、レコメンデーション システムが生み出しています。

Keras Recommenders(KerasRS)API には、ランキングや検索などのタスク向けに設計されたビルディング ブロックが含まれており、デベロッパーが正確で効率的なレコメンデーション システムを作成できるようになっています。たとえば Google では、Google Play のフィードに KerasRS を活用しています。


JAX、TensorFlow、PyTorch で KerasRS をインストールする

最初に、pip で keras-rs パッケージをインストールします。次に、バックエンドを JAX(または TensorFlow か PyTorch)に設定します。これで最先端の専用レコメンデーション システムを作成できます。

import os
os.environ["KERAS_BACKEND"] = "jax"
 
import keras
import keras_rs
 
class SequentialRetrievalModel(keras.Model):
    def __init__(self):
        self.query_model = keras.Sequential([
            keras.layers.Embedding(query_count, embed_dim),
            keras.layers.GRU(embed_dim),
        ])
        self.candidate_model = keras.layers.Embedding(candidate_count, embed_dim)
        self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10)
        self.loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
 
    def call(self, inputs):
        query_embeddings = self.query_model(inputs)
        predictions = self.retrieval(query_embeddings)
        return {"query_embeddings": query_embeddings, "predictions": predictions}
Python

次の例では、おすすめの候補を探す一般的な検索アーキテクチャを示しています。KerasRS には、レコメンデーション タスクに特化したレイヤ、損失、指標など、このアーキテクチャを実装するために必要なものがすべて含まれています。こちらの colab ノートブックも参考になります。

そしてもちろん、これらのビルディング ブロックは、モデルをビルドする model.compile やトレーニング ループを簡単に設定できる model.fit など、標準の Keras API と完全に互換性があります。

model.compile(
    loss=keras_rs.losses.PairwiseHingeLoss(),
    metrics=[keras_rs.metrics.NDCG(k=8, name="ndcg")],
    optimizer=keras.optimizers.Adagrad(learning_rate=3e-4),
)
model.fit(train_ds, validation_data=val_ds, epochs=5)
Python

今後数か月のうちに、keras_rs.layers.DistributedEmbedding クラスをリリースする予定です。これにより、TPU の SparseCore チップを使って、複数のマシンに分散した大量のエンベディングを検索できるようになります。また、このライブラリに継続的に人気のモデル実装を追加し、最先端のレコメンデーション システムをさらに簡単に開発できるようにします。


KerasRS のドキュメントとサンプルを確認する

合わせて注目したいのが、最近リニューアルした keras.io ウェブサイトの Keras Recommenders ドキュメントです。keras.io/keras_rs では、最初のサンプルとして、古典的な Deep and Cross Network(DCN)Two-Tower エンベディング モデルを紹介しています。初めてレコメンダーを作成してトレーニングする過程が順を追って示されています。また、SASRec などのさらに高度なチュートリアルでは、トランスフォーマー モデルをトレーニングするエンドツーエンドの例を紹介します。

使ってみる

専用のレコメンデーション システムの作成に役立つその他の例やドキュメント、ガイドを確認したい方は、今すぐウェブサイトをご覧ください。https://github.com/keras-team/keras-rs でコードを確認したり、貢献したりすることもできます(その際には、ぜひ星 ⭐ をつけてください!)。

皆さんが Keras Recommenders ですばらしいレコメンデーション システムを作ることを楽しみにしています。



謝辞

Keras Recommenders を開発した Fabien Hertschuh と Abheesht Sharma に感謝します。また、この仕組みの実現に尽力しれくれた Keras および ML フレームワーク チーム、ならびにすべての協力者やリーダーの皆さんにも謝意を捧げます。