随着 AI 生态系统的持续发展,定义机器学习模型的方式越来越多,保存训练和微调所得模型权重的方法更是层出不穷。面对越来越多的选择,KerasHub 让您可以在不同机器学习框架间混合搭配流行的模型架构及其权重。
例如,Hugging Face Hub 是加载检查点的热门来源。其中许多模型检查点都是通过 Hugging Face transformers
库以 SafeTensors 格式创建的。无论使用哪种机器学习框架创建模型检查点,这些权重都可以加载到 KerasHub 模型中,让您能自由选择运行模型所用的框架(JAX、PyTorch 或 TensorFlow)。
没错,这意味着您可以在 JAX 上运行来自 Mistral 或 Llama 的检查点,甚至用 PyTorch 加载 Gemma,灵活性无出其右。
我们来深入了解其中一些术语,并探讨其实际应用原理。
加载模型时,我们需要两个独立部分:模型架构和模型权重(通常称为“检查点”)。我们分别来具体解释一下。
“模型架构”指的是模型各层的排列方式以及每层中发生的操作。换句话说,就是模型的“结构”。我们使用 PyTorch、JAX 或 Keras 等 Python 框架来表示模型结构。
而“模型权重”指的是模型的“参数”,即训练过程中变化的模型内部数值。这些权重的具体数值决定了训练后模型的特征。
“检查点”就是某一训练时刻模型权重值的快照。被广泛分享和使用的是那些达到优异训练效果的检查点文件。随着模型架构通过微调等方式被进一步优化,会产生新的检查点文件。例如,许多开发者基于 Google 的 gemma-2-2b-it 模型使用自己的数据集进行微调,目前已经有 600 多个此类模型。所有这些微调模型都使用与原始 gemma-2-2b-it 模型相同的架构,但其检查点的权重各不相同。
总结一下:模型架构通过代码定义,而模型权重是训练得到的参数,以检查点文件形式保存。当我们将模型架构与一组模型权重(以检查点文件形式)结合时,就得到了一个可以输出有用结果的实用模型。
Hugging Face 的 transformers 库和 Google 的 KerasHub 库等工具提供了模型架构及实验所需的 API。检查点存储库包括 Hugging Face Hub 和 Kaggle Models 等。
您可以自由组合模型架构库和所选的检查点存储库。例如,您可以将 Hugging Face Hub 的检查点加载到 JAX 模型架构中,然后使用 KerasHub 进行微调。对于其他任务,您可以在 Kaggle Models 找到合适的检查点。这种灵活性和分离性意味着您不会被束缚在单一生态系统中。
前文我们提到过几次 KerasHub,现在我们来详细介绍一下。
KerasHub 是一个 Python 库,能简化模型架构的定义过程。它收录了当今最流行和常用的许多机器学习模型,而且还在不断增加。由于 KerasHub 基于 Keras,所以它支持当前三大主流 Python 机器学习库:PyTorch、JAX 和 TensorFlow。这意味着您可以选择自己喜欢的库来定义模型架构。
此外,由于 KerasHub 支持最常见的检查点格式,您可以轻松从许多检查点存储库加载检查点。例如,Hugging Face 和 Kaggle 上就有数十万个检查点可加载到这些模型架构中。
transformers
库的对比开发者的常见工作流是使用 Hugging Face transformers
库微调模型并上传到 Hugging Face Hub。如果您是 transformers
的用户,也会在 KerasHub 中发现许多熟悉的 API 模式。详情请查阅 KerasHub API 文档。有趣的是,Hugging Face Hub 上许多检查点不仅兼容 transformers
库,也兼容 KerasHub。我们接下来看看它是如何做到的。
Hugging Face 有一个名为 Hugging Face Hub 的模型检查点存储库。这是机器学习社区上传以分享模型检查点的众多平台之一。Hugging Face 上特别流行的是与 KerasHub 兼容的 SafeTensors 格式。
只要模型架构可用,您就可以将这些检查点从 Hugging Face Hub 直接加载到 KerasHub 模型中。想知道您喜欢的模型是否可用?您可以查看 https://keras.io/keras_hub/presets/,获取支持的模型架构列表。别忘了,社区中基于这些架构微调后创建的所有检查点同样兼容!我们最近还发布了一份新指南来详细说明这一过程。
这一切是如何实现的?KerasHub 内置转换器,可以简化使用 Hugging Face transformers
模型的流程。这些转换器会自动将 Hugging Face 模型的检查点转换为与 KerasHub 兼容的格式。这意味着您只需几行代码,就能将各种预训练的 Hugging Face transformer 模型从 Hugging Face Hub 无缝加载到 KerasHub 中。
如果您发现模型架构缺失,可以在 GitHub 上提交拉取请求来添加。
那么如何将 Hugging Face Hub 的检查点加载到 KerasHub 中呢?我们来看几个具体示例。
我们首先需要选择一个机器学习库作为 Keras 的“后端”。以下示例使用的是 JAX,但您也可以选择 JAX、PyTorch 或 TensorFlow 中的任意一个。以下所有示例无论选择哪个后端都能工作。接着,我们导入 keras
、keras_hub
和 huggingface_hub
,并使用 Hugging Face 的用户访问令牌登录,以便访问模型检查点。
import os
os.environ["KERAS_BACKEND"] = "jax" # or "torch" or "tensorflow"
import keras
from keras_hub import models
from huggingface_hub import login
login('HUGGINGFACE_TOKEN')
首先,也许我们想在 JAX 上运行 Mistral 的检查点。在 KerasHub 的可用模型架构列表中,有几个 Mistral 模型可供选择。让我们试试 mistral_0.2_instruct_7b_en
。点击进入后可以看到,应使用 MistralCausalLM
类来调用 from_preset
。在 Hugging Face Hub 方面,对应的模型检查点存储在此处,有 900 多个微调版本。浏览列表会发现一个专注于网络安全的流行微调模型 Lily,路径名为 segolilylabs/Lily-Cybersecurity-7B-v0.2
。我们还需要在该路径之前添加“hf://
”,来指定 KerasHub 应从 Hugging Face Hub 查找。
将以上步骤整合后,代码如下:
# Model checkpoint from Hugging Face Hub
gemma_lm = models.MistralCausalLM.from_preset("hf://segolilylabs/Lily-Cybersecurity-7B-v0.2")
gemma_lm.generate("Lily, how do evil twin wireless attacks work?", max_length=30)
Llama 3.1-8B-Instruct 是一个非常流行的模型,单上个月下载量就超过 500 万次。我们来在 JAX 上运行它的一个微调版本。目前它有超过 1400 个微调检查点,选择非常丰富。其中 xVerify 的微调检查点看起来很有趣,让我们将其加载到 KerasHub 上的 JAX 中。
我们将使用 Llama3CausalLM 类来反映所使用的模型架构。与之前一样,需要将 Hugging Face Hub 中的相应路径添加 hf://
前缀。只需两行代码就能加载和调用一个模型,确实很方便,对吧?
# Model checkpoint from Hugging Face Hub
gemma_lm = models.Llama3CausalLM.from_preset("hf://IAAR-Shanghai/xVerify-8B-I")
gemma_lm.generate("What is the tallest building in NYC?", max_length=100)
最后,我们来将一个微调过的 Gemma-3-4b-it 检查点加载到 JAX 中。我们将使用 Gemma3CausalLM 类,并选择其中一个微调检查点,比如多语言翻译器 EraX。和之前一样,我们将使用带有 Hugging Face Hub 前缀的路径名构建完整路径 hf://erax-ai/EraX-Translator-V1.0
。
# Model checkpoint from Hugging Face Hub
gemma_lm = models.Gemma3CausalLM.from_preset("hf://erax-ai/EraX-Translator-V1.0")
gemma_lm.generate("Translate to German: ", max_length=30)
正如我们所展示的,模型架构不需要和权重绑定在一起,这意味着您可以自由组合来自不同库的架构和权重。
KerasHub 打通了不同框架与检查点存储库之间的壁垒。您可以将 Hugging Face Hub 中的模型检查点(即使是使用基于 PyTorch 的 transformers 库创建的)无缝加载到运行在您所选后端(JAX、TensorFlow 或 PyTorch)上的 Keras 模型中。这让您能利用大量社区微调模型,同时依然可以完全自由地选择运行所用的后端框架。
通过简化架构、权重和框架的混合搭配过程,KerasHub 赋予您简单而强大的灵活性来进行实验和创新。