Todo sobre Gemma: novedades de Gemma 2

AGO 22, 2024
Ju-yeong Ji Gemma DevRel
Ravin Kumar Google Data Scientist Language Applications

En la publicación anterior de la serie Todo sobre Gemma, hablamos sobre la arquitectura de Gemma. Esta vez, explorarás el modelo más reciente, Gemma 2.


Gemma 2

Hace poco, lanzamos Gemma 2, nuestro nuevo y pionero grupo de modelos abiertos que establece un nuevo estándar de rendimiento y accesibilidad. Disponible en tamaños de parámetros 2B, 9B y 27B, Gemma 2 ya dejó huella. Nuestro modelo 27B ascendió rápidamente en la tabla de clasificación de LMSYS Chatbot Arena, en la cual superó incluso a modelos populares el doble de grandes por su capacidad de interacción y sus conversaciones del mundo real, con lo cual se convirtió en uno de los modelos abiertos mejor posicionados y más útiles. Por otro lado, el modelo Gemma 2 2B muestra su destreza excepcional para las conversaciones con IA, con lo cual superó a todos los modelos GPT-3.5 en Chatbot Arena y con un tamaño que le permite ejecutarse en dispositivos perimetrales.

Los desarrolladores pueden acceder a funciones de ajuste sólidas con Gemma 2 en diversas plataformas y herramientas. Ajuste de Gemma 2 se simplificó con soluciones basadas en la nube, como Google Cloud, y herramientas comunitarias populares, como Axolotl. La integración perfecta con socios como Hugging Face y NVIDIA TensorRT-LLM, así como JAX y Keras, permite la optimización del rendimiento y la implementación eficiente en diversas configuraciones de hardware.

Estos son los parámetros principales de los nuevos modelos:

Core parameters of new Gemma models, August 2024

Principales diferencias

Gemma 2 comparte una base arquitectónica similar con los modelos originales de Gemma, incluida la implementación de Rotary Positional Embedding (RoPE) y la no linealidad aproximada de GeGLU. Sin embargo, introduce innovaciones arquitectónicas que lo diferencian de sus predecesores.


Alternación entre atención local y global

En lugar de considerar todas las palabras de un texto a la vez, a veces pone el foco en un pequeño segmento (atención local) y, en otros casos, en el total (atención global). Esta combinación ayuda al modelo a comprender con eficacia tanto el contexto inmediato como el significado general del texto.


Logit soft-capping

Imagina que estás entrenando un modelo para predecir la siguiente palabra en una oración. A veces, el modelo puede estar demasiado seguro de una palabra en particular, incluso si no es la mejor opción. Para que eso no suceda, Logit soft-capping limita la confianza del modelo en sus predicciones, lo que lleva a un mejor rendimiento general.


RMSNorm para normalización previa y posterior

Piensa en esto como una forma de evitar que los cálculos del modelo se vuelvan demasiado grandes o demasiado pequeños durante el entrenamiento. Al igual que podríamos ajustar el volumen de un altavoz para evitar la distorsión, RMSNorm garantiza que la información que fluye a través del modelo se mantenga dentro de un rango razonable, lo que lleva a un entrenamiento más estable y efectivo.


Atención de consultas agrupadas (GQA)

Esta técnica ayuda al modelo a procesar información de manera más eficaz, especialmente, cuando se usan grandes cantidades de texto. Mejora el tradicional método de atención de hilos de ejecución múltiples (MHA) agrupando consultas, lo que habilita un procesamiento más rápido, sobre todo para los modelos grandes. Es como dividir una tarea grande en varias tareas más pequeñas y manejables, lo que le permite al modelo comprender más fácilmente las relaciones entre las palabras sin sacrificar la precisión.


Gemma 27B

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 4608, padding_idx=0)
    (layers): ModuleList(
      (0-45): 46 x Gemma2DecoderLayer(
        (self_attn): Gemma2SdpaAttention(
          (q_proj): Linear(in_features=4608, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4608, out_features=2048, bias=False)
          (v_proj): Linear(in_features=4608, out_features=2048, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4608, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=4608, out_features=36864, bias=False)
          (up_proj): Linear(in_features=4608, out_features=36864, bias=False)
          (down_proj): Linear(in_features=36864, out_features=4608, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm()
        (post_attention_layernorm): Gemma2RMSNorm()
        (pre_feedforward_layernorm): Gemma2RMSNorm()
        (post_feedforward_layernorm): Gemma2RMSNorm()
      )
    )
    (norm): Gemma2RMSNorm()
  )
  (lm_head): Linear(in_features=4608, out_features=256000, bias=False)
)
Gemma 27B architecture

self_attn

En el mecanismo de autoatención, Gemma 2 usa Atención de consultas agrupadas (GQA).

k_proj y v_proj comparten el mismo hilo de ejecución con un tamaño de 128 y 16 hilos de ejecución (128 x 16 = 2048). Por el contrario, q_proj y o_proj tienen 32 hilos de ejecución (128 x 32 = 4096) en paralelo.


Ten en cuenta que el modelo Gemma 9B usa la misma GQA, pero diferente número de hilos de ejecución (8 para k_proj y v_proj, 16 para q_proj y o_proj) y tamaño de hilo de ejecución (256)

(self_attn): Gemma2SdpaAttention(
          (q_proj): Linear(in_features=3584, out_features=4096, bias=False)
          (k_proj): Linear(in_features=3584, out_features=2048, bias=False)
          (v_proj): Linear(in_features=3584, out_features=2048, bias=False)
          (o_proj): Linear(in_features=4096, out_features=3584, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )

El modelo 2B usa 4 para k_proj y v_proj, 8 para q_proj y o_proj y tamaño de hilo de ejecución (256)


pre_feedforward_layernorm y post_feedforward_layernorm

Otra diferencia significativa es la inclusión de RMSNorm adicional en Gemma 2, que mejora la estabilidad del proceso de entrenamiento.


Principales descubrimientos

Nuestro informe técnico proporciona amplios detalles, pero aquí hay un breve resumen de los principales hallazgos de Gemma 2:


Destilar vs. entrenar desde cero:

Entrenamos los modelos 2B y 9B con la destilación de conocimiento del modelo más grande (27B).

Destilar el conocimiento de un modelo más grande, incluso con un número igual de tokens de entrenamiento, conduce a mejoras significativas en el rendimiento.


Atención de consultas agrupadas vs. atención de hilos de ejecución múltiples:

Reemplazamos los resultados MHA con GQA en rendimiento comparable ofreciendo eficacia de parámetros y tiempos de inferencia más rápidos, por lo que GQA es la opción preferida.


Profundidad vs. amplitud del modelo:

Un modelo más profundo muestra un rendimiento ligeramente superior en comparación con un modelo más amplio con el mismo recuento de parámetros.


Lo que viene

En este artículo, aprendiste sobre Gemma 2, la nueva generación de modelos Gemma.

En nuestra próxima serie de publicaciones, examinarás RecurrentGemma, que es un modelo abierto basado en Griffin.

Si quieres conocer más sobre el fascinante mundo de la IA y ver estadísticas de los expertos que le están dando forma a su desarrollo, ve a goo.gle/ai-podcast o busca el programa People of AI Podcast en cualquier plataforma de podcasts.

¡No te pierdas las novedades y gracias por leer!



Referencias


Artículos


Ejemplos de código


📋 Serie completa sobre la arquitectura de Gemma