隆重推出 KerasHub:适用于预训练模型的一站式商店

十月 22, 2024
Divyashree Sreepathihalli Software Engineer
Luciano Martins Developer Advocate Google AI

深度学习行业正在迅速发展,预训练模型对于处理广泛的任务变得越来越重要。Keras 以其人性化的 API 和注重可访问性而闻名,不仅一直处于这一运动的最前沿,而且拥有专门的库,如用于文本类模型的 KerasNLP 和用于计算机视觉模型的 KerasCV。

然而,由于模型导致模态之间的界限变得越来越模糊(想一想利用文本编码器处理图像输入或视觉任务的强大聊天 LLM),维护这些单独的域变得不太实际。NLP 和 CV 之间的分歧可能会阻碍真正多模态模型的开发和部署,从而导致工作变得冗余以及用户体验不完整。

keras-team/keras-hub, a unified, comprehensive library for pretrained models

为了解决这个问题,我们很高兴地宣布 Keras 生态系统的重大变革:KerasHub 是一个统一、全面的预训练模型库,可简化对尖端 NLP 和 CV 架构的访问。KerasHub 是一个中央存储区,您可以其中无缝探索并利用 BERT 等最先进的模型进行文本分析,以及利用 EfficientNet 进行图像分类,一切操作都在一致且熟悉的 Keras 框架内完成。


统一的开发者体验

这种统一不仅可以简化模型的发现和使用,而且还有助于打造更具凝聚力的生态系统。借助 KerasHub,您可以利用各项高级功能,如轻松发布和共享模型、通过 LoRA 微调实现高效采用、进行量化以优化性能以及通过强大的多主机培训来处理大规模数据集,所有这些都适用于各种模态。这标志着向实现强大 AI 工具民主化和加速创新多模态应用开发迈出了重要一步。


使用 KerasHub 的第一步

让我们开始在您的系统上安装 KerasHub。在 KerasHub 中,您可以探索大量可用的模型和各种不同的热门架构部署方式。然后,您就可以轻松加载这些预训练模型并将其整合到您自己的项目中,随后再根据您的特定需求进行微调,以获得最佳性能。


安装 KerasHub

要使用 Keras 3 安装最新的 KerasHub 版本,只需运行以下内容:

$ pip install --upgrade keras-hub

现在,您可以开始探索可用的模型。为开始使用 Keras 3 设置的标准环境同样完全适用于开始使用 KerasHub:

import os
 
# 定义要使用的 Keras 3 后端 - "jax", "tensorflow" or "torch"
os.environ["KERAS_BACKEND"] = "jax"
 
# 导入 Keras 3 和 KerasHub 模块
import keras
import keras_hub

通过 KerasHub 使用计算机视觉和自然语言模型

现在,您可以从 KerasHub 开始访问和使用 Keras 3 生态系统中的模型。以下是一些示例:


Gemma

Gemma 是由 Google 开发的一系列尖端且易于访问的开放模型。Gemma 的基础模型利用 Gemini 模型所用的研究成果和技术,擅长处理各种文本生成任务,其中包括回答问题、总结信息以及进行逻辑推理。此外,您还可以根据特定需求自定义模型。

在此示例中,您可以使用 Keras 和 KerasHub 加载内容,并开始使用 Gemma2 2B 参数生成内容。有关 Gemma 变体的更多详细信息,请查看 Kaggle 上的 Gemma 模型卡

# 从 Kaggle 模型加载 Gemma 2 2B 预设 
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
 
# 开始利用 Gemma 2 2B 生成内容
gemma_lm.generate("Keras is a", max_length=32)

PaliGemma

PaliGemma 是一款紧凑型的开放模型,可以理解图像和文本。PaliGemma 由我们从 PaLI-3 中汲取灵感开发而来,并以 SigLIP 视觉模型Gemma 语言模型等开源组件为基础,可以为有关图像的问题提供详细而有见地的答案。这样一来,该模型可以更深入地了解视觉内容,并具备为图像和短视频生成字幕、识别对象甚至阅读图像中的文本等功能。

import os
 
# 定义您要使用的 Keras 3 后端 - "jax", "tensorflow" or "torch"
os.environ["KERAS_BACKEND"] = "jax"
 
# 导入 Keras 3 和 KerasHub 模块
import keras
import keras_hub
from keras.utils import get_file, load_img, img_to_array
 
 
# 导入利用 224x224 图像微调的 PaliGemma 3B
pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM.from_preset(
    "pali_gemma_3b_mix_224"
)
 
# 下载一张测试图像并准备将其用于 KerasHub
url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
img_path = get_file(origin=url)
img = img_to_array(load_img(image_path))
 
# 利用关于图像的问题来创建提示。
prompt = 'answer where is the cow standing?'
 
# 利用 PaliGemma 生成内容
output = pali_gemma_lm.generate(
    inputs={
        "images": img,
        "prompts": prompt,
    }
)

有关 Keras 3 上可用的预训练模型的更多详细信息,请在 Kaggle 上查看 Keras 中的模型列表


Stability.ai Stable Diffusion 3

您也可以使用计算机视觉模型。例如,您可以结合使用 stability.ai Stable Diffusion 3 与 KerasHub:

from PIL import Image
from keras.utils import array_to_img
from keras_hub.models import StableDiffusion3TextToImage
 
text_to_image = StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium",
    height=1024,
    width=1024,
    dtype="float16",
)
 
# 利用 SD3 生成图像
image = text_to_image.generate(
    "photograph of an astronaut riding a horse, detailed, 8k",
)
 
# Display the generated image
img = array_to_img(image)
img

有关 Keras 3 上可用的预训练计算机视觉模型的更多详细信息,请查看 Keras 中的模型列表


对于 KerasNLP 开发者而言,有哪些变化?

从 KerasNLP 过渡到 KerasHub 是一个简单的过程,只需将 import 语句从 keras_nlp 更新为 keras_hub 即可。

示例:以前,如果您需要导入 keras_nlp 才能使用 BERT 模型,如下所示

import keras_nlp
 
# 加载 BERT 模型 
classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_base_en_uncased", 
    num_classes=2,
)

现在,您只需调整 import 即可使用 KerasHub:

import keras_hub
 
# 加载 BERT 模型 
classifier = keras_hub.models.BertClassifier.from_preset(
    "bert_base_en_uncased", 
    num_classes=2,
)

对于 KerasCV 开发者而言,有哪些变化?

如果您当前是 KerasCV 用户,更新到 KerasHub 将为您带来以下好处:

  • 简化模型加载:KerasHub 为加载模型提供一致的 API,如果您同时使用 KerasCV 和 KerasNLP,则可以简化代码。

  • 框架灵活性:如果您有兴趣探索 JAX 或 PyTorch 等不同框架,KerasHub 可以让您更轻松地使用 KerasCV 和 KerasNLP 模型。

  • 集中式存储区:利用 KerasHub 的统一模型存储区,您可以更轻松地查找和访问模型,并在未来在其中添加新的架构。


如何使我的代码适应 KerasHub?

模型

目前,我们正在将 KerasCV 模型迁移到 KerasHub。虽然大多数模型已经可用,但有些仍在迁移中。请注意,Centerpillar 模型不会被迁移。在 KerasHub 中,您可以通过输入以下内容来使用视觉模型:

import keras_hub
 
# 使用预设加载模型
Model = keras_hub.models.<model_name>.from_preset('preset_name`)
 
# or load a custom model by specifying the backbone and preprocessor
Model = keras_hub.models.<model_name>(backbone=backbone, preprocessor=preprocessor)

KerasHub 为 KerasCV 开发者引入了令人兴奋的新功能,从而提供了更大的灵活性和扩展功能。其中包括:


内置预处理

每个模型都附带一个定制的预处理器,可解决常规任务,包括调整大小、重新缩放等,从而简化工作流。

在此之前,输入预处理是在向模型提供输入之前手动执行的。

# 预处理输入示例
def preprocess_inputs(image, label):
    # 重新调整大小或对输入进行更多预处理
    return preprocessed_inputs
backbone = keras_cv.models.ResNet50V2Backbone.from_preset(
    "resnet50_v2_imagenet",
)
model = keras_cv.models.ImageClassifier(
    backbone=backbone,
    num_classes=4,
)
output = model(preprocessed_input)

目前,任务模型的预处理已集成到确定的预设中。预处理器会对输入进行预处理并对样本图像进行大小调整和重新缩放。虽然预处理器是任务模型的固有组件,但开发者仍然可以选择使用个性化的预处理器。

classifier = keras_hub.models.ImageClassifier.from_preset('resnet_18_imagenet')
classifier.predict(inputs)

损失函数

与增强层类似,以前在 KerasCV 中使用的损失函数现在可在 Keras 中通过 keras.losses.<loss_function> 进行使用。例如,如果您当前正在使用 FocalLoss 函数

import keras
import keras_cv
 
keras_cv.losses.FocalLoss(
    alpha=0.25, gamma=2, from_logits=False, label_smoothing=0, **kwargs
)

您只需要调整损失函数定义代码即可使用 keras.losses 而不是 keras_cv.losses

import keras
 
keras.losses.FocalLoss(
    alpha=0.25, gamma=2, from_logits=False, label_smoothing=0, **kwargs
)

开始使用 KerasHub

立即深入了解 KerasHub:


加入 Keras 社区,释放统一、可访问和高效深度学习模型的力量。AI 未来的发展方向是多模态 AI,KerasHub 便是通往多模态 AI 的门户!