Using KerasHub for easy end-to-end machine learning workflows with Hugging Face

June 24, 2025
Yufeng Guo, Developer Advocate
Divyashree Sreepathihalli, Software Engineer
Monica Song, Product Manager

How to load SafeTensors checkpoints across different frameworks


As the AI ecosystem continues to evolve, there are more and more ways to define machine learning models, and even more ways to save the model weights that result from training and fine-tuning. In this growing set of choices, KerasHub allows you to mix and match popular model architectures and their weights across different ML frameworks.

For example, a popular place to load checkpoints from is the Hugging Face Hub. Many of those model checkpoints were created with the Hugging Face transformers library in the SafeTensors format. Regardless of what ML framework was used to create the model checkpoint, those weights can be loaded into a KerasHub model, which allows you to use your choice of framework (JAX, PyTorch, or TensorFlow) to run the model.

Yes, that means you can run a checkpoint from Mistral or Llama on JAX, or even load Gemma with PyTorch – it doesn't get any more flexible than that.

Let's take a look at some of these terms in more detail, and talk about how this works in practice.


Model architecture vs. model weights

When loading models, there are two distinct parts that we need: the model architecture and the model weights (often called "checkpoints"). Let's define each of these in more detail.

When we say "model architecture", we are referring to how the layers of the model are arranged, and the operations that happen within them. Another way to describe this might be to call it the "structure" of the model. We use Python frameworks like PyTorch, JAX, or Keras to express model architectures.

When we talk about "model weights", we are referring to the "parameters" of a model, or numbers in a model that are changed over the course of training. The particular values of these weights are what give a trained model its characteristics.

"Checkpoints" are a snapshot of the values of the model weights at a particular point in the training. The typical checkpoint files that are shared and widely used are the ones where the model has reached a particularly good training outcome. As the same model architecture is further refined with fine-tuning and other techniques, additional new checkpoint files are created. For example, many developers have taken Google's gemma-2-2b-it model and fine-tuned it with their own datasets; over 600 such fine-tuned versions have been published. All of these fine-tuned models use the same architecture as the original gemma-2-2b-it model, but their checkpoints have differing weights.

So there we have it: the model architecture is described with code, while model weights are trained parameters, saved as checkpoint files. When we have a model architecture together with a set of model weights (in the form of a checkpoint file), we create a functioning model that produces useful outputs.
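To make the split concrete, here is a minimal sketch in plain Python rather than a real ML framework. The "architecture" is a fixed function, and each "checkpoint" is a small JSON file of weight values. The names tiny_model, save_checkpoint, and load_checkpoint, along with the numbers, are made up for illustration and are not part of any library.

```python
import json
import os
import tempfile


def tiny_model(weights, x):
    """The "architecture": a fixed computation, y = w * x + b."""
    return weights["w"] * x + weights["b"]


def save_checkpoint(weights, path):
    """Write a snapshot of the current weight values to disk."""
    with open(path, "w") as f:
        json.dump(weights, f)


def load_checkpoint(path):
    """Read a snapshot of weight values back from disk."""
    with open(path) as f:
        return json.load(f)


# Two checkpoints for the same architecture, standing in for a base model
# and a fine-tuned version. The values are invented for this example.
checkpoint_dir = tempfile.mkdtemp()
base_path = os.path.join(checkpoint_dir, "base.json")
tuned_path = os.path.join(checkpoint_dir, "fine_tuned.json")
save_checkpoint({"w": 2.0, "b": 0.0}, base_path)
save_checkpoint({"w": 2.5, "b": 1.0}, tuned_path)

# Same code, different weights, different behavior.
base_output = tiny_model(load_checkpoint(base_path), 4.0)
tuned_output = tiny_model(load_checkpoint(tuned_path), 4.0)
print(base_output, tuned_output)  # 8.0 11.0
```

Real models have billions of weights and use formats such as SafeTensors instead of JSON, but the relationship is the same: the code stays fixed while the checkpoint file changes.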

Different model weights can be loaded into the same model architecture. These different sets of weights are saved as checkpoints.

Tools like Hugging Face's transformers library and Google's KerasHub library provide model architectures and the APIs you need to experiment with them. Examples of checkpoint repositories include Hugging Face Hub and Kaggle Models.

You can mix and match model architecture libraries with your choice of checkpoint repositories. For example, you can load a checkpoint from Hugging Face Hub into a JAX model architecture and fine-tune it with KerasHub. For a different task, you might find a checkpoint on Kaggle Models that's suitable for your needs. This flexibility and separation means you are not boxed into one ecosystem.


What is KerasHub?

So we've mentioned KerasHub a few times. Let's go into it in more detail.

KerasHub is a Python library that helps make defining model architectures easier. It contains many of the most popular and commonly used machine learning models today, and more are being added all the time. Because it's based on Keras, KerasHub supports all three major Python machine learning libraries used today: PyTorch, JAX, and TensorFlow. This means you can have model architectures defined in whichever library you'd like.

Furthermore, since KerasHub supports the most common checkpoint formats, you can easily load checkpoints from many checkpoint repositories. For example, you can find hundreds of thousands of checkpoints on Hugging Face and Kaggle to load into these model architectures.


Comparisons to the Hugging Face transformers library

A common workflow for developers is to use the Hugging Face transformers library to fine-tune a model and upload it to the Hugging Face Hub. If you're a user of transformers, you'll also find many familiar API patterns in KerasHub. Check out the KerasHub API documentation to learn more. An interesting aspect of KerasHub is that many of the checkpoints found on Hugging Face Hub are compatible not only with the transformers library, but also with KerasHub. Let's take a look at how that works.


KerasHub is compatible with Hugging Face Hub

Hugging Face has a model checkpoint repository, called Hugging Face Hub. It's one of the many places where the machine learning community uploads their model checkpoints to share with the world. Especially popular on Hugging Face is the SafeTensors format, which is compatible with KerasHub.

You can load these checkpoints from Hugging Face Hub directly into your KerasHub model, as long as the model architecture is available. Wondering if your favorite model is available? You can check https://keras.io/keras_hub/presets/ for a list of supported model architectures. And don't forget, all the community-created fine-tuned checkpoints of these model architectures are also compatible! We recently created a new guide to help explain the process in more detail.

How does this all work? KerasHub has built-in converters that simplify the use of Hugging Face transformers models. These converters automatically translate Hugging Face model checkpoints into a format that's compatible with KerasHub. This means you can load a wide variety of pretrained Hugging Face transformer models from the Hugging Face Hub directly into KerasHub with just a few lines of code.
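As a rough illustration of what such a conversion involves, the sketch below renames the weights in a checkpoint from one naming scheme to another. This is not KerasHub's converter code: the rename rules, the target names, and the tiny weight values are all hypothetical, and a real converter may also need to reshape or transpose tensors, which this sketch omits.

```python
# Hypothetical rename rules (source pattern -> target pattern).
# Real converters are specific to each model architecture.
RENAME_RULES = [
    ("model.layers.", "decoder_block_"),
    (".self_attn.q_proj.weight", "/attention/query/kernel"),
    (".mlp.up_proj.weight", "/feedforward/up/kernel"),
]


def convert_name(source_name):
    """Apply every rename rule in order to one weight name."""
    target = source_name
    for old, new in RENAME_RULES:
        target = target.replace(old, new)
    return target


def convert_checkpoint(source_weights):
    """Return a new dict with renamed keys and unchanged values."""
    return {convert_name(name): value for name, value in source_weights.items()}


# A toy "checkpoint" with made-up values standing in for real tensors.
source = {
    "model.layers.0.self_attn.q_proj.weight": [[0.1, 0.2], [0.3, 0.4]],
    "model.layers.0.mlp.up_proj.weight": [[0.5, 0.6], [0.7, 0.8]],
}
converted = convert_checkpoint(source)
for name in converted:
    print(name)
```

The point is only that the numbers survive the trip unchanged; what changes is how they are labeled so that the target architecture knows where each one belongs.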

If you notice a missing model architecture, you can add it by filing a pull request on GitHub.


How to load a Hugging Face Hub checkpoint into KerasHub

So how do we get checkpoints from Hugging Face Hub loaded into KerasHub? Let's take a look at some concrete examples.

We'll start by choosing our machine learning library as our Keras "backend". We'll use JAX in the examples shown, but you can choose between JAX, PyTorch, or TensorFlow for any of them. All the examples below work regardless of which one you choose. Then we can import keras, keras_hub, and huggingface_hub, and log in with our Hugging Face User Access Token so we can access the model checkpoints.
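One detail worth making explicit: Keras reads the KERAS_BACKEND environment variable when it is first imported, so the choice has to be made before the import. A small guard like the one below can catch the mistake early. The select_backend helper is hypothetical (it is not part of Keras), and its allowed list only covers the three backends named in this post.

```python
import os
import sys

# The three backends discussed in this post.
SUPPORTED_BACKENDS = ("jax", "torch", "tensorflow")


def select_backend(name):
    """Set KERAS_BACKEND, refusing to do so once keras is already imported."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unknown backend {name!r}; expected one of {SUPPORTED_BACKENDS}"
        )
    if "keras" in sys.modules:
        raise RuntimeError(
            "keras is already imported; set the backend before importing it"
        )
    os.environ["KERAS_BACKEND"] = name
    return name


select_backend("jax")  # must run before `import keras`
```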

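import os
# Select the Keras backend before importing keras.
os.environ["KERAS_BACKEND"] = "jax"  # or "torch" or "tensorflow"

import keras
from keras_hub import models
from huggingface_hub import login

# Replace the placeholder with your own Hugging Face User Access Token.
login("HUGGINGFACE_TOKEN")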

Put a Mistral model on JAX

First up, suppose we want to run a checkpoint from Mistral on JAX. KerasHub's list of available model architectures includes a handful of Mistral models; let's try out mistral_0.2_instruct_7b_en. Its page shows that we should use the MistralCausalLM class and call from_preset. On the Hugging Face Hub side of things, the corresponding model checkpoint has over 900 fine-tuned versions. Browsing that list, there's a popular cybersecurity-focused fine-tuned model called Lily, with the pathname segolilylabs/Lily-Cybersecurity-7B-v0.2. We'll also need to add "hf://" before that path to specify that KerasHub should look at Hugging Face Hub.
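The prefix step is simple string handling, sketched below. The hf_preset helper is a hypothetical convenience function, not part of KerasHub; it just prepends "hf://" unless the path already has it.

```python
HF_PREFIX = "hf://"


def hf_preset(repo_path):
    """Return the Hugging Face Hub path with the "hf://" prefix added once."""
    if repo_path.startswith(HF_PREFIX):
        return repo_path
    return HF_PREFIX + repo_path


preset = hf_preset("segolilylabs/Lily-Cybersecurity-7B-v0.2")
print(preset)  # hf://segolilylabs/Lily-Cybersecurity-7B-v0.2
```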

Putting it all together, we get the following code:

# Model checkpoint from Hugging Face Hub
mistral_lm = models.MistralCausalLM.from_preset("hf://segolilylabs/Lily-Cybersecurity-7B-v0.2")
mistral_lm.generate("Lily, how do evil twin wireless attacks work?", max_length=30)

Running Llama 3.1 on JAX

Llama 3.1-8B-Instruct is a popular model, with over 5 million downloads last month. Let's put a fine-tuned version on JAX. With over 1400 fine-tuned checkpoints, there's no lack of choice. The xVerify fine-tuned checkpoint looks interesting, so let's load that into JAX with KerasHub.

We'll use the Llama3CausalLM class to reflect the model architecture that we are using. As before, we'll need the appropriate path from Hugging Face Hub, prefixed with "hf://". It's pretty amazing that we can load and call a model with just two lines of code, right?

# Model checkpoint from Hugging Face Hub
llama_lm = models.Llama3CausalLM.from_preset("hf://IAAR-Shanghai/xVerify-8B-I")
llama_lm.generate("What is the tallest building in NYC?", max_length=100)

Load Gemma on JAX

Finally, let's load a fine-tuned Gemma-3-4b-it checkpoint into JAX. We'll use the Gemma3CausalLM class, and select one of the fine-tuned checkpoints. How about EraX, a multilingual translator? As before, we'll use the pathname with the Hugging Face Hub prefix to create the full path of "hf://erax-ai/EraX-Translator-V1.0".

# Model checkpoint from Hugging Face Hub
gemma_lm = models.Gemma3CausalLM.from_preset("hf://erax-ai/EraX-Translator-V1.0")
gemma_lm.generate("Translate to German: ", max_length=30)

Flexibility at your fingertips

As we've explored, a model's architecture does not need to be tied to its weights, which means you can combine architectures and weights from different libraries.
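The three examples in this post follow one pattern: pick the KerasHub class that matches the architecture, then pass it an "hf://" path. A sketch of that pattern as a lookup table is shown below. The class names and repository paths are the ones used in this post; the EXAMPLES table, the family keys, and the preset_for and load_example helpers are hypothetical names introduced for illustration. Calling load_example requires keras_hub to be installed, and it downloads the checkpoint.

```python
# Architecture class name and Hugging Face Hub path for each example.
EXAMPLES = {
    "mistral": ("MistralCausalLM", "segolilylabs/Lily-Cybersecurity-7B-v0.2"),
    "llama3": ("Llama3CausalLM", "IAAR-Shanghai/xVerify-8B-I"),
    "gemma3": ("Gemma3CausalLM", "erax-ai/EraX-Translator-V1.0"),
}


def preset_for(family):
    """Build the "hf://" preset path for one of the example families."""
    _, repo_path = EXAMPLES[family]
    return "hf://" + repo_path


def load_example(family):
    """Load one example model (imports keras_hub only when called)."""
    import keras_hub  # lazy import so the table can be inspected without it

    class_name, _ = EXAMPLES[family]
    model_class = getattr(keras_hub.models, class_name)
    return model_class.from_preset(preset_for(family))


for family in EXAMPLES:
    print(family, "->", preset_for(family))
```

With a backend selected and keras_hub installed, load_example("llama3").generate("What is the tallest building in NYC?", max_length=100) would mirror the Llama example from earlier in the post.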

KerasHub bridges the gap between different frameworks and checkpoint repositories. You can take a model checkpoint from Hugging Face Hub, even one created using the PyTorch-based transformers library, and load it into a Keras model running on your choice of backend: JAX, TensorFlow, or PyTorch. This allows you to leverage a vast collection of community fine-tuned models, while still having full choice over which backend framework to run on.

By simplifying the process of mixing and matching architectures, weights, and frameworks, KerasHub empowers you to experiment and innovate with simple, yet powerful flexibility.