Siguiendo los pasos de Gemma 1.1 (Kaggle, Hugging Face), CodeGemma (Kaggle, Hugging Face) y el modelo multimodal PaliGemma (Kaggle, Hugging Face), nos complace anunciar el lanzamiento del modelo Gemma 2 en Keras.
Gemma 2 está disponible en dos tamaños, parámetros 9B y 27B, con variantes estándar y personalizadas. Puedes encontrarlos aquí:
Los increíbles resultados de Gemma 2 en los puntos de referencia de LLM ya se describen en otro lugar (consulta goo.gle/gemma2report). En esta entrada, nos gustaría mostrar cómo la combinación de Keras y JAX puede ayudarte a trabajar con modelos grandes.
JAX es un marco de trabajo numérico creado para escalar. Aprovecha el compilador de aprendizaje automático XLA y capacita a los modelos más grandes en Google.
Keras es el marco de trabajo de modelado para ingenieros de ML, que ahora se ejecuta en JAX, TensorFlow o PyTorch. Keras ahora lleva el modelo de potencia a escala paralela a través de su API. Puedes probar los nuevos modelos Gemma 2 en Keras aquí:
Debido a su tamaño, estos modelos solo se pueden cargar y ajustar con total precisión dividiendo sus pesos en múltiples aceleradores. JAX y XLA tienen un amplio soporte para la partición de pesos (paralelismo del modelo SPMD) y Keras agrega la API keras.distribution.ModelParallel
para ayudar a especificar fragmentos capa por capa de una manera simple:
# List accelerators
devices = keras.distribution.list_devices()
# Arrange accelerators in a logical grid with named axes
device_mesh = keras.distribution.DeviceMesh((2, 8), ["batch", "model"], devices)
# Tell XLA how to partition weights (defaults for Gemma)
layout_map = gemma2_lm.backbone.get_layout_map()
# Define a ModelParallel distribution
model_parallel = keras.distribution.ModelParallel(device_mesh, layout_map, batch_dim_name="batch")
# Set is as the default and load the model
keras.distribution.set_distribution(model_parallel)
gemma2_lm = keras_nlp.models.GemmaCausalLM.from_preset(...)
La función gemma2_lm.backbone.get_layout_map()
ayuda a mostrar una configuración de fragmentos capa por capa para todos los pesos del modelo. Sigue las recomendaciones de Gemma (goo.gle/gemma2report). Este es un extracto:
layout_map = keras.distribution.LayoutMap (device_mesh)
layout_map["token_embedding/embeddings"] = ("model", "data")
layout_map["decoder_block.*attention.*(query|key|value).kernel"] =
("model", "data", None)
layout_map["decoder_block.* attention_output.kernel"] = ("model", None, "data")
...
En pocas palabras, para cada capa, esta configuración especifica a lo largo de qué eje o ejes se debe dividir cada bloque de pesos y sobre qué aceleradores se deben colocar las piezas. Es más fácil de entender con una imagen. Tomemos como ejemplo los pesos de "consulta" en la arquitectura de atención del transformador, que tienen la forma (nb heads, embed size, head dim
):
Nota: Las dimensiones de malla para las que no hay divisiones recibirán copias. Este sería el caso, por ejemplo, si el mapa del diseño anterior fuera (“model”, None, None
).
Observa también el parámetro batch_dim_name="batch"
en ModelParallel
. Si el eje "batch" tiene varias filas de aceleradores, como es el caso aquí, también se utilizará el paralelismo de datos. Cada fila de aceleradores cargará y entrenará solo una parte de cada lote de datos, y luego las filas combinarán sus gradientes.
Luego de que se cargue el modelo, aquí hay dos extensiones de código útiles para mostrar los fragmentos de peso que se aplicaron realmente:
for variable in gemma2_lm.backbone.get_layer('decoder_block_1').weights:
print(f'{variable.path:<58} {str(variable.shape):<16} \
{str(variable.value.sharding.spec)}')
#... set an optimizer through gemma2_lm.compile() and then:
gemma2_lm.optimizer.build(gemma2_lm.trainable_variables)
for variable in gemma2_lm.optimizer.variables:
print(f'{variable.path:<73} {str(variable.shape):<16} \
{str(variable.value.sharding.spec)}')
Y si observamos el resultado (a continuación), notamos algo importante: las expresiones regulares en la especificación de diseño coincidieron no solo con los pesos de las capas, sino también con sus correspondientes variables de impulso y velocidad en el optimizador, y además las fragmentaron adecuadamente. Este es un punto importante que se debe verificar al dividir un modelo.
# for layers:
# weight name . . . . . . . . . . shape . . . . . . layout spec
decoder_block_1/attention/query/kernel (16, 3072, 256)
PartitionSpec('model', None, None)
decoder_block_1/ffw_gating/kernel (3072, 24576)
PartitionSpec(None, 'model')
...
# for optimizer vars:
# var name . . . . . . . . . . . .shape . . . . . . layout spec
adamw/decoder_block_1_attention_query_kernel_momentum
(16, 3072, 256) PartitionSpec('model', None, None)
adamw/decoder_block_1_attention_query_kernel_velocity
(16, 3072, 256) PartitionSpec('model', None, None)
...
LoRA es una técnica que congela los pesos del modelo y los reemplaza con adaptadores de baja clasificación, es decir, pequeños.
Keras también tiene APIs sencillas para esto:
gemma2_lm.backbone.enable_lora(rank=4) # Rank picked from empirical testing
Al mostrar los detalles del modelo con model.summary () después de habilitar LoRA, podemos ver que LoRA reduce el número de parámetros entrenables en Gemma 9B de 9 mil millones a 14.5 millones.
El mes pasado, anunciamos que los modelos de Keras estarían disponibles para que los carguen y descarguen los usuarios, tanto en Kaggle como en Hugging Face. Hoy, impulsamos aun más la integración de Hugging Face: ahora puedes cargar cualquier peso personalizado para los modelos compatibles, ya sea que hayan sido entrenados con una versión de Keras del modelo o no. Los pesos se convertirán sobre la marcha para que esto funcione. Esto significa que ahora tienes acceso a las docenas de ajustes de Gemma subidos por los usuarios de Hugging Face, directamente desde KerasNLP. Con el tiempo, esto funcionará con cualquier modelo de Hugging Face Transformers que tenga una implementación KerasNLP correspondiente. Por ahora, Gemma y Llama3 funcionan. Puedes probarlos en el ajuste personalizado Hermes-2-Pro-Llama-3-8B con este Colab:
causal_lm = keras_nlp.models.Llama3CausalLM.from_preset(
"hf://NousResearch/Hermes-2-Pro-Llama-3-8B"
)
PaliGemma es un potente VLM abierto inspirado en PaLI-3. Basado en componentes abiertos que incluyen el modelo de visión SigLIP y el modelo de lenguaje Gemma, PaliGemma está diseñado para lograr un rendimiento de ajuste líder en su clase en una amplia gama de tareas de lenguaje de visión. Esto incluye subtítulos de imágenes y videos, respuestas a preguntas visuales, comprensión del texto en las imágenes, detección de objetos y segmentación de objetos.
Puedes encontrar la implementación de Keras de PaliGemma en GitHub, modelos de Hugging Face y Kaggle.
Esperamos que disfrutes experimentando o compilando con los nuevos modelos Gemma 2 en Keras.