Todo sobre Gemma: arquitectura RecurrentGemma

AGO 29, 2024
Ju-yeong Ji Gemma DevRel
Ravin Kumar Google Data Scientist Language Applications

En la publicación anterior de la serie Todo sobre Gemma, hablamos sobre la arquitectura de Gemma 2. Esta vez, explorarás la arquitectura RecurrentGemma.


RecurrentGemma 2B, 9B

RecurrentGemma se basa en Griffin, un modelo híbrido que combina recurrencias lineales cerradas con atención de ventana deslizante local. Este cambio mejora el cálculo y la memoria, y es más adecuado para indicaciones contextuales largas.

Griffin hybrid model architecture

Sin embargo, viene con la desventaja de menor rendimiento de «aguja en el pajar» debido al estado de tamaño fijo de la arquitectura Griffin. Si bien es posible proporcionar el texto completo de un libro como entrada, este enfoque podría no ser óptimo. Las redes neuronales recurrentes (RNN) pueden encontrar dificultades para aprender dependencias de largo alcance en secuencias extremadamente largas, y el modelo tiene una ventana contextual limitada. Esto significa que solo puede considerar de manera efectiva un cierto número de tokens anteriores al hacer predicciones.

Además, los modelos recurrentes aún no recibieron tanta atención en términos de optimizaciones de tiempo de inferencia en comparación con sus contrapartes transformadoras. Y hay menos investigación y apoyo de la comunidad disponible en comparación con la arquitectura de transformación bien establecida.

Por lo tanto, este modelo será muy valioso en escenarios en los que te preocupe agotar la ventana contextual de tu LLM. Como prioriza la información más reciente y descarta de manera estratégica los datos más antiguos, RecurrentGemma garantiza que el rendimiento del LLM se mantenga sólido a medida que se expanda el contexto.

A continuación, se muestra el diagrama de arquitectura para el modelo RecurrentGemma 2B.

Recurrent Gemma 2B model architecture

Griffin sigue el mismo patrón residual y bloque MLP que otras líneas de base transformadoras. Sin embargo, a diferencia de la línea de base transformadora MQA y el modelo Hawk, Griffin utiliza una combinación de bloques recurrentes y MQA.

Layered structure of recurrent and MQA blocks

Griffin utiliza una estructura en capas que alterna dos bloques residuales con un bloque recurrente, seguido de un bloque residual que incorpora el bloque de atención MQA local.

Los parámetros principales de la arquitectura se resumen en la siguiente tabla.

Core parameters of the architecture of 2B and 9B models

Parámetros incrustados y no incrustados

Los parámetros no incrustados se distribuyen por las capas ocultas del modelo, en componentes como mecanismos de atención y redes de prealimentación.

Nota: El nombre del modelo "2B" proviene de este parámetro.

Los parámetros incrustados suelen encontrarse en la capa dedicada denominada «capa de incrustación». Esta capa es responsable de asignar tokens discretos (como palabras o caracteres) en representaciones vectoriales continuas (incrustaciones).

Nota: 0.7B se puede calcular como 256,000 (tamaño de vocabulario) x 2,560 (ancho del modelo)


Ancho del modelo y ancho de RNN

El ancho del modelo hace referencia al tamaño de sus capas ocultas, lo que determina la capacidad del modelo para representar patrones complejos, como los modelos Gema de base.

El ancho de la red neuronal recurrente (RNN) es el tamaño del estado oculto mantenido por la RG-LRU (Real-Gated Linear Recurrent Unit). A diferencia de los transformadores tradicionales, el bloque recurrente mantiene un estado interno de tamaño fijo, independientemente de la longitud de entrada. Esto le permite a RecurrentGemma procesar secuencias más largas con menos memoria, lo que lo hace más eficiente para tareas como generar artículos o código largos.


Factor de expansión de MLP

Es igual a las dimensiones ocultas de prealimentación del modelo Gemma de base. Por razones de simplicidad, aplicamos un factor de expansión de 3 en el modelo RecurrentGemma, lo que resultó en una dimensión de MLP de 7,680 (que se calcula como 2,560 x 3).


Tamaño de la ventana de atención local

El estado mantenido por RecurrentGemma tiene un tamaño finito y no crece con secuencias más largas que la ventana de atención local de 2,000 tokens. Esto significa que, si bien la longitud máxima de las muestras generadas de forma autorregresiva por Gemma está limitada por la capacidad de memoria del sistema host, RecurrentGemma puede generar secuencias de longitud arbitraria y superar esta restricción.

RecurrentGemmaForCausalLM(
  (model): RecurrentGemmaModel(
    (embed_tokens): Embedding(256000, 2560, padding_idx=0)
    (layers): ModuleList(
      (0-1): 2 x RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaRecurrentBlock(
          (linear_y): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_x): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_out): Linear(in_features=2560, out_features=2560, bias=True)
          (conv_1d): Conv1d(2560, 2560, kernel_size=(4,), stride=(1,), padding=(3,), groups=2560)
          (rg_lru): RecurrentGemmaRglru()
          (act_fn): PytorchGELUTanh()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
      (2): RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaSdpaAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=False)
          (k_proj): Linear(in_features=2560, out_features=256, bias=False)
          (v_proj): Linear(in_features=2560, out_features=256, bias=False)
          (o_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (rotary_emb): RecurrentGemmaRotaryEmbedding()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
 
      :
 
      (23): RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaSdpaAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=False)
          (k_proj): Linear(in_features=2560, out_features=256, bias=False)
          (v_proj): Linear(in_features=2560, out_features=256, bias=False)
          (o_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (rotary_emb): RecurrentGemmaRotaryEmbedding()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
      (24-25): 2 x RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaRecurrentBlock(
          (linear_y): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_x): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_out): Linear(in_features=2560, out_features=2560, bias=True)
          (conv_1d): Conv1d(2560, 2560, kernel_size=(4,), stride=(1,), padding=(3,), groups=2560)
          (rg_lru): RecurrentGemmaRglru()
          (act_fn): PytorchGELUTanh()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
    )
    (final_norm): RecurrentGemmaRMSNorm()
  )
  (lm_head): Linear(in_features=2560, out_features=256000, bias=False)
)

embed_tokens (capa de incrustación)

Toma el texto de entrada como una secuencia de tokens y asigna cada token a una representación vectorial continua con un tamaño de 2,560. Tiene un tamaño de vocabulario de 256,000, que es el mismo que los modelos Gemma de base.


layers

Hay 26 capas del decodificador en total, que se agrupan en patrones repetidos.

El modelo comienza con dos bloques residuales con un bloque recurrente (0-1). A esta secuencia le sigue un bloque residual (2) y una serie de bloques continuos que se alternan hasta el final de la capa (25).

Recurrent block architecture

Bloque residual con un bloque recurrente

En el bloque recurrente (bloque de combinación temporal), el modelo toma la entrada de dimensión (ancho del modelo) 2,560 y aplica dos capas lineales con dimensión de salida (ancho RNN) 2,560 en paralelo, lo que crea dos ramas.

En la primera rama (lado derecho), se aplica una capa pequeña separable Conv1D con una dimensión de filtro temporal de 4. Y la sigue la RG-LRU (Real-Gated Linear Recurrent Unit).

En la segunda rama (lado izquierdo), se aplica una no linealidad GeLU.

Luego, se combinan las ramas por multiplicación de elementos y se aplica una capa lineal final con dimensión de salida (ancho del modelo) 2,560.

RecurrentGemma-Residual-block

Después de implementar RMSNorm, sigue el bloque MLP.


Bloque residual con un MQA local

Después de tener dos bloques residuales con un bloque recurrente (0-1), sigue un bloque residual con un MQA local (2). Una de las desventajas clave de usar la atención global es que su complejidad computacional crece cuadráticamente en la longitud de la secuencia. Para resolver ese problema, RecurrentGemma utiliza una ventana deslizante de atención local. Permite que cada posición aborde solo una cantidad fija de tokens anteriores.

En el bloque MQA (bloque de combinación temporal) local, el modelo toma la entrada de la dimensión (ancho del modelo) 2,560. Usa proyecciones lineales (q_proj, k_proj, v_proj, o_proj) para crear representaciones de consultas, claves, valores y salida. Ten en cuenta que out_features para k_proj y v_proj es 256, ya que comparten el mismo hilo de ejecución con un tamaño de 256, mientras que q_proj y o_proj tiene 10 hilos de ejecución (256 x 10 = 2,560) en paralelo.

Incorpora rotary_emb (RecurrentGemmaRotaryEmbedding) para incrustaciones posicionales giratorias (RoPE) al igual que los modelos Gemma de base.

La aplicación de RMSNorm y del bloque MLP es la misma que la del bloque residual anterior.


Lo que viene

En este artículo, aprendiste sobre RecurrentGemma.

En la próxima publicación, explorarás PaliGemma, que es un modelo de lenguaje de visión (VLM) ligero abierto.

¡No te pierdas las novedades y gracias por leer!


Referencias

Artículos


Ejemplos de código


📋 Serie completa sobre la arquitectura de Gemma