Presentamos los modelos de Gemma en Keras

FEB 21, 2024
Martin Görner Product Manager Keras

El equipo de Keras se complace en anunciar que Gemma, una familia de modelos abiertos ligeros y de última generación creados a partir de la misma investigación y tecnología que utilizamos para crear los modelos Gemini, ya está disponible en la colección KerasNLP. Gracias a Keras 3, Gemma funciona con JAX, PyTorch y TensorFlow. Con esta versión, Keras también presenta varias funciones nuevas diseñadas específicamente para modelos de lenguaje grandes: una nueva API LoRA (Low Rank Adaptation) y capacidades de entrenamiento paralelo de modelos a gran escala.

Si deseas sumergirte directamente en los ejemplos de código, dirígete aquí:

Primeros pasos

Los modelos de Gemma vienen en tamaños de parámetros portátiles de 2B y 7B, y ofrecen avances significativos en comparación con modelos abiertos similares, e incluso algunos más grandes. Por ejemplo:

  • Gemma 7B obtiene una nueva mejor puntuación en su clase: 64,3% de respuestas correctas en el punto de referencia de comprensión del lenguaje MMLU (en comparación con los 62,5% de Mistral-7B y 54,8% de Llama2-13B).
  • Gemma agrega más de 11 puntos porcentuales a la puntuación de referencia GSM8K relacionada con problemas matemáticos de la escuela primaria (46,4% de Gemma 7B en comparación con los 35,4% de Mistral-7B y los 28,7% de Llama2-13B).
  • Más de 6,1 puntos porcentuales de respuestas correctas en HumanEval, un desafío de codificación (32,3% de Gemma 7B en comparación con los 26,2% de Mistral 7B y los 18,3% de Llama2 13B).

Los modelos de Gemma se ofrecen con una API de KerasNLP familiar y una implementación de Keras superlegible. Es posible crear una instancia del modelo con una sola línea de código:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

Y ejecútalo directamente en un mensaje de texto (Sí. La tokenización está integrada, aunque puedes dividirla fácilmente si fuera necesario). Consulta la guía de PNL de Keras para ver cómo hacerlo.

gemma_lm.generate("Keras is a", max_length=32)
> "Keras is a popular deep learning framework for neural networks..."

Pruébalo aquí: Comienza a usar los modelos de Gemma

Gracias a Keras 3, puedes elegir el backend en el que ejecutas el modelo. A continuación, te mostramos cómo cambiarlo:

os.environ["KERAS_BACKEND"] = "jax" # Or "tensorflow" or "torch".
import keras # import keras after having selected the backend

Keras 3 viene con varias características nuevas específicamente aplicables a modelos de lenguaje grandes. La principal es una nueva API LoRA (Low Rank Adaptation) disponible para lograr un ajuste fino para aumentar la eficacia de los parámetros. A continuación, te mostramos cómo activarla:

gemma_lm.backbone.enable_lora(rank=4)
# Note: rank=4 replaces the weights matrix of relevant layers with the
# product AxB of two matrices of rank 4, which reduces the number of
# trainable parameters.

¡Esta línea única reduce el número de parámetros entrenables de 2.500 millones a 1,3 millones!

Pruébalo aquí: Perfecciona los modelos de Gemma con LoRA.

Ajuste fino de los modelos de Gemma en varias GPU/TPU

Keras 3 también es compatible con el entrenamiento de modelos a gran escala y Gemma es el modelo perfecto para probarlo. La nueva API de distribución de Keras ofrece opciones de entrenamiento distribuido paralelo de datos y modelos. La nueva API está destinada a ser multibackend, pero, por el momento, se implementa solo para el backend de JAX, debido a su probada escalabilidad (se entrenaron los modelos Gemma con JAX).

Para afinar el Gemma 7B más grande, resulta útil una configuración distribuida, por ejemplo, un TPUv3 con 8 núcleos de TPU que puedes obtener gratis en Kaggle, o una máquina de 8 GPU de Google Cloud. A continuación, se muestra cómo configurar el modelo para el entrenamiento distribuido, para lo que se utiliza el paralelismo de modelos:

device_mesh = keras.distribution.DeviceMesh(
   (1, 8), # Mesh topology
   ["batch", "model"], # named mesh axes
   devices=keras.distribution.list_devices() # actual accelerators
)
 
 
# Model config
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (None, "model")
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
   None, "model", None)
layout_map["decoder_block.*attention_output.*kernel"] = (
   None, None, "model")
layout_map["decoder_block.*ffw_gating.*kernel"] = ("model", None)
layout_map["decoder_block.*ffw_linear.*kernel"] = (None, "model")
 
 
# Set the model config and load the model
model_parallel = keras.distribution.ModelParallel(
   device_mesh, layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
# Ready: you can now train with model.fit() or generate text with generate()

Lo que hace este fragmento de código es configurar los 8 aceleradores en una matriz de 1 x 8, donde las dos dimensiones se llaman “batch” y “model”. Los pesos del modelo se dividen en la dimensión “model”, aquí dividida entre los 8 aceleradores, mientras que los lotes de datos no se dividen, ya que la dimensión “batch” es 1.

Perfecciona los modelos de Gemma en varias GPU/TPU

Lo que viene

Pronto publicaremos una guía en la que te mostraremos cómo dividir correctamente un modelo de transformador y escribir las 6 líneas de configuración de partición anteriores. No es muy larga, pero no cabría en este post.

Habrás notado que las particiones de capa se definen a través de regexes en los nombres de capa. Puedes verificar los nombres de las capas con este fragmento de código. Ejecutamos esto para construir el LayoutMap anterior.

# This is for the first Transformer block only,
# but they all have the same structure
tlayer = gemma_lm.backbone.get_layer('decoder_block_0')
for variable in tlayer.weights:
 print(f'{variable.path:<58}  {str(variable.shape):<16}')

El paralelismo completo de modelos de GSPMD funciona aquí con solo unas pocas sugerencias de partición, porque Keras pasa estas configuraciones al potente compilador XLA, que resuelve todos los demás detalles del cálculo distribuido.

Esperamos que disfrutes jugando con los modelos de Gemma. Aquí también hay un instructivo con instrucciones de ajuste, que puede resultarte útil. Y, por cierto, si quieres compartir tus pesos ajustados con la comunidad, el Kaggle Model Hub ahora admite cargas de pesos ajustadas por el usuario. ¡Visita la página de modelos de Gemma en Kaggle y ve lo que otras personas ya crearon!