使用 Keras 和 JAX,在 10 分钟内即可完成 Recommender 系统的构建和训练

2025年5月13日
Yufeng Guo Developer Advocate
Monica Song Product Manager

今天,我们很高兴地宣布推出 Keras Recommenders,这个新库让最先进的推荐技术触手可及。


借助推荐系统提升数字体验

推荐系统为您在当下与技术的许多交互提供支持。在手机上打开任何应用,您可能会立即发现自己正在与推荐模型进行交互,例如常用社交媒体平台上的主页动态、YouTube 上的视频推荐,甚至是您最喜爱的游戏中弹出的广告。随着 AI 世界的不断发展,打造个性化体验比以往任何时候都更加重要。大型语言模型并非万能,而 Recommender 系统现如今在打造众多顶级数字体验方面发挥着关键作用。

为帮助开发者创建高性能和准确的 Recommender 系统,Keras Recommenders (KerasRS) 提供了一组 API,其中包含专为排名和检索等任务而设计的基本模块。例如,Google 使用 KerasRS 来帮助优化 Google Play 中的动态消息。


使用 JAX、TensorFlow 或 PyTorch 安装 KerasRS

首先,通过 pip 安装 keras-rs 软件包。然后,将后端设置为 JAX、TensorFlow 或 PyTorch。随后,您便可以制作最先进的专属 Recommender 系统。

import os
os.environ["KERAS_BACKEND"] = "jax"
 
import keras
import keras_rs
 
class SequentialRetrievalModel(keras.Model):
    def __init__(self):
        self.query_model = keras.Sequential([
            keras.layers.Embedding(query_count, embed_dim),
            keras.layers.GRU(embed_dim),
        ])
        self.candidate_model = keras.layers.Embedding(candidate_count, embed_dim)
        self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10)
        self.loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
 
    def call(self, inputs):
        query_embeddings = self.query_model(inputs)
        predictions = self.retrieval(query_embeddings)
        return {"query_embeddings": query_embeddings, "predictions": predictions}
Python

在此示例中,我们展示了一个热门检索架构,其中我们确定了一组候选推荐。KerasRS 提供实施此架构所需的一切,包括专为 Recommender 任务设计的专业层、损失和指标。您也可以按照 Colab 笔记本中的说明执行这些操作。

当然,所有这些基本模块都能与 model.compile 的标准 Keras API 一起使用,以构建您的模型和 model.fit,从而轻松配置训练循环。

model.compile(
    loss=keras_rs.losses.PairwiseHingeLoss(),
    metrics=[keras_rs.metrics.NDCG(k=8, name="ndcg")],
    optimizer=keras.optimizers.Adagrad(learning_rate=3e-4),
)
model.fit(train_ds, validation_data=val_ds, epochs=5)
Python

在接下来的几个月里,我们计划发布 keras_rs.layers.DistributedEmbedding 类,用于利用 TPU 上的 SparseCore 芯片执行跨计算机分布的大型嵌入式查找。此外,我们将持续在库中添加热门模型实现,以便更轻松地构建先进的 Recommender 系统。


探索 KerasRS 文档和示例

我们还想在最近经过重新设计的 keras.io 网站上重点介绍有关 Keras Recommenders 的所有文档。在 keras.io/keras_rs 上,您可以找到涉及经典深度交叉网络 (DCN)双塔嵌入模型的入门示例,其中显示了编写和培训首个 Recommender 系统的分步流程。我们还提供更高级的教程(如 SASRec),以展示训练 Transformer 模型的端到端示例。

开始使用

立即访问我们的网站,获取更多示例、文档和指南,以构建您自己的推荐系统。您也可以在 https://github.com/keras-team/keras-rs 上浏览代码并提供想法(浏览代码时,欢迎点亮一颗星 ⭐!)。

我们期待看到使用 Keras Recommenders 构建的所有优秀推荐系统。



致谢

特别感谢 Fabien Hertschuh 和 Abheesht Sharma 构建 Keras Recommenders,以及 Keras 团队、ML 框架团队和所有协作者的鼎力支持。