The Keras team is happy to announce that Gemma, a family of lightweight, state-of-the-art open models built from the same research and technology that we used to create the Gemini models, is now available in the KerasNLP collection. Thanks to Keras 3, Gemma runs on JAX, PyTorch and TensorFlow. With this release, Keras is also introducing several new features specifically designed for large language models: a new LoRA API (Low-Rank Adaptation) and large-scale model-parallel training capabilities.
If you want to dive directly into code samples, head here:
Gemma models come in portable 2B and 7B parameter sizes, and deliver significant advances against similar open models, and even some larger ones.
Gemma models are offered with a familiar KerasNLP API and a super-readable Keras implementation. You can instantiate the model with a single line of code:
import keras_nlp

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
And run it directly on a text prompt. Yes, tokenization is built-in, although you can easily split it out if needed; read the KerasNLP guide to see how, or see the short sketch after the example below.
gemma_lm.generate("Keras is a", max_length=32)
> "Keras is a popular deep learning framework for neural networks..."
Try it out here: Get started with Gemma models
Thanks to Keras 3, you can choose the backend on which you run the model. Here is how to switch:
os.environ["KERAS_BACKEND"] = "jax" # Or "tensorflow" or "torch".
import keras # import keras after having selected the backend
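You can confirm which backend is active at runtime, for example:

# Sanity check: prints the name of the backend Keras is running on.
print(keras.backend.backend())  # "jax" in this example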
Keras 3 comes with several new features specifically for large language models. Chief among them is a new LoRA API (Low-Rank Adaptation) for parameter-efficient fine-tuning. Here is how to activate it:
gemma_lm.backbone.enable_lora(rank=4)
# Note: rank=4 freezes the original weights of the relevant layers and adds
# a trainable low-rank update, the product of two rank-4 matrices A and B,
# which drastically reduces the number of trainable parameters.
This single line drops the number of trainable parameters from 2.5 billion to 1.3 million!
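From there, fine-tuning proceeds with the standard compile/fit workflow. Here is a minimal sketch, assuming train_data is a small set of prompt-and-response strings you provide yourself; the sequence length and hyperparameters below are illustrative, not prescriptive:

# train_data: hypothetical placeholder for your own list of training strings.
gemma_lm.preprocessor.sequence_length = 128  # shorter sequences to limit memory use
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.AdamW(learning_rate=5e-5, weight_decay=0.01),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(train_data, epochs=1, batch_size=1)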
Try it out here: Fine-tune Gemma models with LoRA.
Keras 3 also supports large-scale model training, and Gemma is the perfect model to try it out on. The new Keras distribution API offers data-parallel and model-parallel distributed training options. The API is meant to be multi-backend, but for the time being it is implemented for the JAX backend only, because of its proven scalability (Gemma models were trained with JAX).
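For reference, the data-parallel option is a one-liner. A minimal sketch: the model is simply replicated on every local accelerator and training batches are split between them.

# Data parallelism: replicate the model, shard the data batches.
data_parallel = keras.distribution.DataParallel(
    devices=keras.distribution.list_devices())
keras.distribution.set_distribution(data_parallel)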
To fine-tune the larger Gemma 7B, a distributed setup is useful, for example a TPUv3 with 8 TPU cores that you can get for free on Kaggle, or an 8-GPU machine from Google Cloud. Here is how to configure the model for distributed training, using model parallelism:
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),                                     # Mesh topology
    ["batch", "model"],                         # Named mesh axes
    devices=keras.distribution.list_devices()   # Actual accelerators
)

# Model configuration: map weight variable names (regexes) to mesh axes.
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (None, "model")
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    None, "model", None)
layout_map["decoder_block.*attention_output.*kernel"] = (
    None, None, "model")
layout_map["decoder_block.*ffw_gating.*kernel"] = ("model", None)
layout_map["decoder_block.*ffw_linear.*kernel"] = (None, "model")

# Set the distribution and load the model.
model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")

# Ready: you can now train with model.fit() or generate text with generate().
This code snippet sets up the 8 accelerators as a 1 x 8 mesh whose two dimensions are called “batch” and “model”. Model weights are sharded on the “model” dimension, i.e. split across the 8 accelerators, while data batches are not partitioned since the “batch” dimension has size 1.
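The same two axes let you scale further. As a purely hypothetical variation, on a 16-accelerator machine you could use a 2 x 8 mesh, so that weights are still sharded 8 ways while each data batch is also split across 2 model replicas:

# Hypothetical 16-device mesh: 2-way data parallelism x 8-way model parallelism.
device_mesh = keras.distribution.DeviceMesh(
    (2, 8), ["batch", "model"],
    devices=keras.distribution.list_devices())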
Try it out here: Fine-tune Gemma models on multiple GPUs/TPUs.
We will soon be publishing a guide showing you how to correctly partition a Transformer model and write the 6 lines of partitioning setup above. It is not very long but it would not fit in this post.
You will have noticed that the partitioning of each layer is defined through a regex on its weight variable names (paths). You can check these names with the following code snippet; we ran it to construct the LayoutMap above.
# This is for the first Transformer block only,
# but they all have the same structure
tlayer = gemma_lm.backbone.get_layer('decoder_block_0')
for variable in tlayer.weights:
    print(f'{variable.path:<58} {str(variable.shape):<16}')
Full GSPMD model parallelism works here with just a few partitioning hints because Keras passes these settings to the powerful XLA compiler, which figures out all the other details of the distributed computation.
We hope you will enjoy playing with Gemma models. Here is also an instruction-tuning tutorial that you might find useful. And by the way, if you want to share your fine-tuned weights with the community, the Kaggle Model Hub now supports uploads of user-tuned weights. Head to the model page for Gemma models on Kaggle and see what others have already created!