Introducing Keras Hub: Your one-stop shop for pretrained models

OCT 22, 2024
Divyashree Sreepathihalli, Software Engineer
Luciano Martins, Developer Advocate, Google AI

The world of deep learning is rapidly evolving, with pretrained models becoming increasingly crucial for a wide range of tasks. Keras, known for its user-friendly API and focus on accessibility, has been at the forefront of this movement with specialized libraries like KerasNLP for text-based models and KerasCV for computer vision models.

However, as models increasingly blur the lines between modalities – think of powerful chat LLMs with image inputs or vision tasks leveraging text encoders – maintaining these separate domains is less practical. The division between NLP and CV can hinder the development and deployment of truly multimodal models, leading to redundant efforts and a fragmented user experience.

keras-team/keras-hub, a unified, comprehensive library for pretrained models

To address this, we're excited to announce a major evolution in the Keras ecosystem: KerasHub, a unified, comprehensive library for pretrained models, streamlining access to both cutting-edge NLP and CV architectures. KerasHub is a central repository where you can seamlessly explore and utilize state-of-the-art models like BERT for text analysis alongside EfficientNet for image classification, all within a consistent and familiar Keras framework.


A unified developer experience

This unification not only simplifies model discovery and usage but also fosters a more cohesive ecosystem. With KerasHub, you can leverage advanced features like effortless model publishing and sharing, LoRA fine-tuning for resource-efficient adaptation, quantization for optimized performance, and robust multi-host training for tackling large-scale datasets, all applicable across diverse modalities. This marks a significant step towards democratizing access to powerful AI tools and accelerating the development of innovative multimodal applications.
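
To make this concrete, here is a minimal sketch of what two of those features look like in code, using the Gemma preset that appears later in this post (installation and environment setup are covered below; the exact calls assume a recent Keras 3 and KerasHub release):

import keras_hub

# Load a pretrained causal LM from a preset
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")

# Enable LoRA on the backbone for parameter-efficient fine-tuning
gemma_lm.backbone.enable_lora(rank=4)

# Alternatively, quantize the weights to int8 for lighter-weight inference
gemma_lm.quantize("int8")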


First steps with KerasHub

Let's get started by installing KerasHub on your system. From there, you can explore the extensive collection of readily available models and different implementations of popular architectures. You'll then be ready to easily load and incorporate these pre-trained models into your own projects and fine-tune them for optimal performance according to your specific requirements.


Installing KerasHub

To install the latest KerasHub release with Keras 3, simply run:

$ pip install --upgrade keras-hub

Now you can start exploring the available models. The standard environment setup for Keras 3 doesn't change at all when you start using KerasHub:

import os

# Define the Keras 3 backend you want to use - "jax", "tensorflow" or "torch"
os.environ["KERAS_BACKEND"] = "jax"

# Import Keras 3 and KerasHub modules
import keras
import keras_hub
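
If you want to double-check which backend was picked up, Keras reports it at runtime:

# Confirm the active backend selected via KERAS_BACKEND
print(keras.backend.backend())  # e.g. "jax"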

Using computer vision and natural language models with KerasHub

Now you are ready to use KerasHub to access the models available in the Keras 3 ecosystem. Here are some examples:


Gemma

Gemma is a collection of cutting-edge, yet accessible, open models developed by Google. Leveraging the same research and technology behind the Gemini models, Gemma's base models excel at various text generation tasks. These include answering questions, summarizing information, and engaging in logical reasoning. Furthermore, they can be customized to address specific needs.

In this example, you use Keras and KerasHub to load Gemma 2 2B and start generating content. For more details about Gemma variants, take a look at the Gemma model card on Kaggle.

# Load Gemma 2 2B preset from Kaggle models 
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")

# Start generating content with Gemma 2 2B
gemma_lm.generate("Keras is a", max_length=32)
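
As a quick follow-up, generate() also accepts a batch of prompts and returns one completion per prompt (the prompts below are just illustrative):

# Generate completions for several prompts at once
gemma_lm.generate(
    ["Keras is a", "The Keras ecosystem includes"],
    max_length=32,
)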

PaliGemma

PaliGemma is a compact, open model that understands both images and text. Drawing inspiration from PaLI-3 and built on open-source components like the SigLIP vision model and the Gemma language model, PaliGemma can provide detailed and insightful answers to questions about images. This allows for a deeper understanding of visual content, enabling capabilities such as generating captions for images and short videos, identifying objects, and even reading text within images.

import os

# Define the Keras 3 backend you want to use - "jax", "tensorflow" or "torch"
os.environ["KERAS_BACKEND"] = "jax"

# Import Keras 3 and KerasHub modules
import keras
import keras_hub
from keras.utils import get_file, load_img, img_to_array


# Load PaliGemma 3B fine-tuned on 224x224 images
pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM.from_preset(
    "pali_gemma_3b_mix_224"
)

# Download a test image and prepare it for usage with KerasHub
url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
img_path = get_file(origin=url)
img = img_to_array(load_img(img_path))

# Create the prompt with the question about the image
prompt = 'answer where is the cow standing?'

# Generate the contents with PaliGemma
output = pali_gemma_lm.generate(
    inputs={
        "images": img,
        "prompts": prompt,
    }
)
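
The generate() call returns the model's answer as text, which you can print or log:

# Inspect the generated answer
print(output)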

For more details about the available pre-trained models on Keras 3, check out the list of models in Keras on Kaggle.


Stability.ai Stable Diffusion 3

Computer vision models are available too. As an example, you can use Stability AI's Stable Diffusion 3 with KerasHub:

from keras.utils import array_to_img
from keras_hub.models import StableDiffusion3TextToImage

text_to_image = StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium",
    height=1024,
    width=1024,
    dtype="float16",
)

# Generate images with SD3
image = text_to_image.generate(
    "photograph of an astronaut riding a horse, detailed, 8k",
)

# Convert the generated array to a PIL image (it displays inline in a notebook)
img = array_to_img(image)
img
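
Since array_to_img returns a PIL image, you can also save the result to disk (the filename here is just an example):

# Save the generated image to disk
img.save("astronaut_on_horse.png")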

For more details about the available pre-trained computer vision models on Keras 3, check the list of models in Keras.


What changes for KerasNLP developers?

The transition from KerasNLP to KerasHub is straightforward: it only requires updating your imports from keras_nlp to keras_hub.

For example, if you were previously importing keras_nlp to use a BERT model like this:

import keras_nlp

# Load a BERT model 
classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_base_en_uncased", 
    num_classes=2,
)

Adjust the import, and you are ready to go with KerasHub:

import keras_hub

# Load a BERT model 
classifier = keras_hub.models.BertClassifier.from_preset(
    "bert_base_en_uncased", 
    num_classes=2,
)

What changes for KerasCV developers?

If you are a current KerasCV user, updating to KerasHub gives you these benefits:

  • Simplified Model Loading: KerasHub offers a consistent API for loading models, which can simplify your code if you're working with both KerasCV and KerasNLP.

  • Framework Flexibility: If you're interested in exploring different frameworks like JAX or PyTorch, KerasHub makes it easier to use KerasCV and KerasNLP models with them (see the sketch after this list).

  • Centralized Repository: Finding and accessing models is easier with KerasHub's unified model repository, which is also where new architectures will be added in the future.
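
For instance, here is a minimal sketch of the framework flexibility point: the same image classifier preset used later in this post loads on the PyTorch backend simply by setting the environment variable before importing Keras (pick whichever backend you have installed):

import os

# Select the backend before importing keras / keras_hub
os.environ["KERAS_BACKEND"] = "torch"  # or "jax" / "tensorflow"

import keras_hub

# The same preset works regardless of the backend chosen above
classifier = keras_hub.models.ImageClassifier.from_preset("resnet_18_imagenet")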


How do I adapt my code to KerasHub?

Models

KerasCV models are currently being ported to KerasHub. While most are already available, a few are still a work in progress. Please note that the CenterPillar model will not be ported. You should be able to use any vision model in KerasHub with:

import keras_hub

# Load a task model from a preset
model = keras_hub.models.<model_name>.from_preset("preset_name")

# or build a custom model by specifying the backbone and preprocessor
model = keras_hub.models.<model_name>(backbone=backbone, preprocessor=preprocessor)

KerasHub introduces exciting new features for KerasCV developers, offering greater flexibility and expanded capabilities. These include:


Built-in preprocessing

Each model is accompanied by a bespoke preprocessor that addresses routine tasks including resizing, rescaling, and more, streamlining your workflow.

Previously, input preprocessing had to be performed manually before passing the inputs to the model.

import keras_cv

# Manually preprocess the inputs before calling the model, for example:
def preprocess_inputs(image, label):
    # Resize, rescale, or do more preprocessing on the inputs
    return preprocessed_inputs

backbone = keras_cv.models.ResNet50V2Backbone.from_preset(
    "resnet50_v2_imagenet",
)
model = keras_cv.models.ImageClassifier(
    backbone=backbone,
    num_classes=4,
)
output = model(preprocessed_inputs)

Now, preprocessing is integrated into the task models' presets: input images are resized and rescaled by the preprocessor, which is an intrinsic part of the task model. You still have the option to use a custom preprocessor.

classifier = keras_hub.models.ImageClassifier.from_preset('resnet_18_imagenet')
classifier.predict(inputs)
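
If you want full control over preprocessing instead, one option is to disable the built-in preprocessor and feed tensors you have prepared yourself. This is a sketch that assumes the task's from_preset accepts a preprocessor argument, as the KerasNLP task APIs did, and the predict input is a hypothetical placeholder:

import keras_hub

# Disable the built-in preprocessor and pass already resized/rescaled tensors
classifier = keras_hub.models.ImageClassifier.from_preset(
    "resnet_18_imagenet",
    preprocessor=None,
)
classifier.predict(my_preprocessed_images)  # hypothetical, already preprocessed inputs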

Loss functions

Similar to augmentation layers, loss functions previously in KerasCV are now available in Keras through keras.losses.<loss_function>. For example, if you are currently using the FocalLoss function:

import keras_cv

# Previously, the focal loss came from keras_cv
loss = keras_cv.losses.FocalLoss(
    alpha=0.25, gamma=2, from_logits=False, label_smoothing=0
)

You just need to adjust your loss function definition code to use keras.losses instead of keras_cv.losses:

import keras

# Now the same loss is available directly from keras.losses
loss = keras.losses.FocalLoss(
    alpha=0.25, gamma=2, from_logits=False, label_smoothing=0
)

Get started with KerasHub

Dive into the world of KerasHub today.


Join the Keras community and unlock the power of unified, accessible, and efficient deep learning models. The future of AI is multimodal, and KerasHub is your gateway to it!