Todo sobre Gemma: arquitectura PaliGemma

SEP 05, 2024
Ju-yeong Ji Gemma DevRel
Ravin Kumar Google Data Scientist Language Applications

En la entrada anterior de Todo sobre Gemma, analizamos la arquitectura RecurrentGemma. En esta entrada del blog, exploraremos la arquitectura PaliGemma. Veamos de qué se trata.


PaliGemma 3B

PaliGemma es un modelo de lenguaje-visión (VLM) ligero y de código abierto, inspirado en PaLI-3 y basado en componentes abiertos como el modelo de visión SigLIP y el modelo de lenguaje Gemma. Pali hace referencia a la sigla inglesa correspondiente a Pathway Language and Image Model (lenguaje de ruta y modelo de imágenes). Como su nombre lo indica, este modelo puede tomar entradas de imágenes y texto y producir una respuesta de texto, como se puede ver en esta guía de ajuste.


Arquitectura PaliGemma

PaliGemma agrega un modelo de visión adicional al modelo BaseGemma, que consiste en un codificador de imágenes. Este codificador junto con los tokens de texto se pasa a un modelo Gemma 2B especializado. El modelo de visión y el modelo Gemma se entrenan en varias etapas, tanto de forma independiente como en conjunto, para producir la arquitectura conjunta final. Si quieres obtener más detalles, consulta la sección 3.2 del artículo Pali-3.

Joint architecture of the Vision Model and Gemma 2B model
PaliGemmaForConditionalGeneration(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
        (position_embedding): Embedding(256, 1152)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-26): 27 x SiglipEncoderLayer(
            (self_attn): SiglipAttention(
              (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
            )
            (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (mlp): SiglipMLP(
              (activation_fn): PytorchGELUTanh()
              (fc1): Linear(in_features=1152, out_features=4304, bias=True)
              (fc2): Linear(in_features=4304, out_features=1152, bias=True)
            )
            (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
          )
        )
      )
      (post_layernorm): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
    )
  )
  (multi_modal_projector): PaliGemmaMultiModalProjector(
    (linear): Linear(in_features=1152, out_features=2048, bias=True)
  )
  (language_model): GemmaForCausalLM(
    (model): GemmaModel(
      (embed_tokens): Embedding(257216, 2048, padding_idx=0)
      (layers): ModuleList(
        (0-17): 18 x GemmaDecoderLayer(
          (self_attn): GemmaSdpaAttention(
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (k_proj): Linear(in_features=2048, out_features=256, bias=False)
            (v_proj): Linear(in_features=2048, out_features=256, bias=False)
            (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (rotary_emb): GemmaRotaryEmbedding()
          )
          (mlp): GemmaMLP(
            (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
            (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
            (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
            (act_fn): PytorchGELUTanh()
          )
          (input_layernorm): GemmaRMSNorm()
          (post_attention_layernorm): GemmaRMSNorm()
        )
      )
      (norm): GemmaRMSNorm()
    )
    (lm_head): Linear(in_features=2048, out_features=257216, bias=False)
  )
)

vision_tower (SiglipVisionModel)

Este componente es responsable de procesar la imagen de entrada.

Utiliza SiglipVisionTransformer, que es un tipo de arquitectura de transformador diseñada para tareas de visión.


embeddings (SiglipVisionEmbeddings)

PaliGemma toma como entrada una o más imágenes, que el codificador SigLIP convierte en “tokens blandos”.

Divide la imagen en parches más pequeños, al igual que un modelo de texto procesa las palabras de una oración. Luego, el modelo aprende a capturar las relaciones entre estos parches, y comprende con eficacia el contenido visual de la imagen.


patch_embedding

Utiliza una capa convolucional (Conv2d) con los siguientes parámetros.

  • 3: la entrada tiene 3 canales (para imágenes RGB)

  • 1152: el resultado tiene 1152 canales, que es la dimensión de incorporación de cada parche

  • kernel_size=(14, 14): cada parche es un cuadrado de 14 x 14 píxeles

  • stride=(14, 14): los parches se toman sin superposición (el filtro convolucional se mueve 14 píxeles por vez)

  • padding=’valid’: no se aplica relleno, por lo que el tamaño del resultado será menor que el tamaño de la entrada


position_embedding

Las incorporaciones de posición se agregan a cada incorporación de parche para codificar la información espacial (es decir, dónde se ubicaba cada parche en la imagen original).

Esto se logra mediante una capa de incorporación aprendida (Embedding) que toma como entrada la posición de cada parche (hasta 256 posiciones) y genera un vector de tamaño 1152 (el mismo que la dimensión de incorporación del parche).


encoder (SiglipEncoder)

Las incorporaciones pasan a través de una serie de componentes SiglipEncoderLayer, que constan de redes neuronales de autoatención y prealimentadas. Esto ayuda a que el modelo capture las relaciones entre las diferentes partes de la imagen.


multi_modal_projector (PaliGemmaMultiModalProjector)

Este componente proyecta el resultado de la torre de visión en un espacio multimodal. Esto se logra utilizando una capa lineal simple y permite que las representaciones de la visión y el lenguaje se combinen de manera efectiva.


language_model (GemmaForCausalLM)

Este componente es un modelo de lenguaje basado en el modelo Gemma 2B.

Toma como entrada la representación multimodal del proyector y genera una salida de texto.

Para la entrada de texto, cada punto de control se entrenó con varias longitudes de secuencia. Por ejemplo, paligemma-3b-mix-224 se entrenó con una longitud de secuencia de 256 (texto de entrada + texto de salida tokenizado por el tokenizador de Gemma).

PaliGemma utiliza el tokenizador de Gemma con 256,000 tokens, pero amplía su vocabulario con 1024 entradas que representan coordenadas en el espacio-imagen normalizado (<loc0000>...<loc1023>), y otro con 128 entradas (<seg000>...<seg127>), que son palabras de código utilizadas por un autocodificador variacional con cuantificación vectorial y segmentación por expresión referencia (VQ-VAE) de peso ligero. (256,000 + 1024 + 128 = 257,216)


Ejemplo de segmentación de objetos

Los tokens blandos adicionales codifican la detección de objetos y la segmentación de imágenes. A continuación, se muestra un ejemplo de resultado de paligemma-3b-mix-224. Puedes probarlo desde la demostración en vivo de HuggingFace.

Image of a child and cat on a snowy roof top

Resultado de PaliGemma con la instrucción “segment floor;cat;person;

image of output from the PaliGemma with the prompt “segment floor;cat;person;”

La decodificación de los resultados del modelo no es intuitiva si no conoces las tareas de AA y visión artificial.

Los cuatro tokens de ubicación iniciales representan las coordenadas del cuadro de límite, que van de 0 a 1023. Estas coordenadas son independientes de la relación de aspecto, ya que se supone que la imagen se redimensiona a 1024 x 1024.

Por ejemplo, el resultado muestra la ubicación del gato dentro de las coordenadas (382, 637) y (696, 784). En este sistema de coordenadas, la esquina superior izquierda se indica como (0,0) y la coordenada vertical se enumera antes de la coordenada horizontal.

image showing the output displaying the cat's location within coordinates (382, 637) and (696, 784)

La máscara se codifica con los siguientes 16 tokens de segmentación. Un modelo de red neuronal (VQ-VAE) puede reconstruir máscaras a partir de representaciones cuantificadas (índices de libro de códigos) decodificando esos valores. Puedes explorar el código real aquí.

Por último, puedes obtener este hermoso resultado de PaliGemma.

image showing object segmentation result, where the floor is shaded blue, the child is shaded red, and the cat is shaded yellow

Resumen

En este artículo, aprendiste sobre PaliGemma.

La familia Gemma presenta una oportunidad única para comprender los sistemas modernos de modelos de lenguaje grande, ya que ofrece una colección de modelos de pesos abiertos con arquitecturas centrales similares pero diseñados para diferentes casos de uso. Estos modelos, que Google lanzó para investigadores, desarrolladores y usuarios finales, abarcan diversas funcionalidades y complejidades.

Esperamos que este resumen te permita comprender de forma concisa cómo funciona la familia de modelos Gemma, que se destaca por su versatilidad e idoneidad para una amplia gama de tareas.

El servidor de Discord de la comunidad de desarrolladores de Google es una excelente plataforma para mostrar tus proyectos, establecer conexiones con otros desarrolladores y participar en debates interactivos. Puedes unirte al servidor para explorar estas emocionantes oportunidades.

¡Gracias por leer!


Referencias


Artículos


Ejemplos de código


📋 Serie completa sobre la arquitectura de Gemma