Explicação sobre o Gemma: arquitetura do PaliGemma

SET 05, 2024
Ju-yeong Ji Gemma DevRel
Ravin Kumar Google Data Scientist Language Applications

Na postagem anterior do Explicação sobre o Gemma, vimos a arquitetura do RecurrentGemma. Nesta postagem do blog, você conhecerá a arquitetura do PaliGemma. Vamos lá!


PaliGemma 3B

O PaliGemma é um modelo de visão-linguagem (VLM) leve, inspirado no PaLI-3 e baseado em componentes abertos, como o modelo de visão SigLIP e o modelo de linguagem Gemma. Pali significa Pathway Language e Image Model. Como o nome indica, esse modelo é capaz de receber entradas de imagem e texto e produzir uma resposta de texto, como você pode ver neste guia de ajuste.


Arquitetura do PaliGemma

O PaliGemma adiciona um modelo de visão extra ao modelo BaseGemma que consiste em um codificador de imagens. Esse codificador, juntamente com os tokens de texto, é transmitido para um modelo Gemma 2B especializado. O modelo de visão e o modelo Gemma são treinados em vários cenários, tanto de forma independente quanto juntos, para produzir a arquitetura conjunta final. Para saber todos os detalhes, consulte a Seção 3.2 do documento do Pali-3.

Joint architecture of the Vision Model and Gemma 2B model
PaliGemmaForConditionalGeneration(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
        (position_embedding): Embedding(256, 1152)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-26): 27 x SiglipEncoderLayer(
            (self_attn): SiglipAttention(
              (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
            )
            (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (mlp): SiglipMLP(
              (activation_fn): PytorchGELUTanh()
              (fc1): Linear(in_features=1152, out_features=4304, bias=True)
              (fc2): Linear(in_features=4304, out_features=1152, bias=True)
            )
            (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
          )
        )
      )
      (post_layernorm): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
    )
  )
  (multi_modal_projector): PaliGemmaMultiModalProjector(
    (linear): Linear(in_features=1152, out_features=2048, bias=True)
  )
  (language_model): GemmaForCausalLM(
    (model): GemmaModel(
      (embed_tokens): Embedding(257216, 2048, padding_idx=0)
      (layers): ModuleList(
        (0-17): 18 x GemmaDecoderLayer(
          (self_attn): GemmaSdpaAttention(
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (k_proj): Linear(in_features=2048, out_features=256, bias=False)
            (v_proj): Linear(in_features=2048, out_features=256, bias=False)
            (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (rotary_emb): GemmaRotaryEmbedding()
          )
          (mlp): GemmaMLP(
            (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
            (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
            (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
            (act_fn): PytorchGELUTanh()
          )
          (input_layernorm): GemmaRMSNorm()
          (post_attention_layernorm): GemmaRMSNorm()
        )
      )
      (norm): GemmaRMSNorm()
    )
    (lm_head): Linear(in_features=2048, out_features=257216, bias=False)
  )
)

vision_tower (SiglipVisionModel)

Esse componente é responsável pelo processamento da imagem de entrada.

Ele usa o SiglipVisionTransformer, que é um tipo de arquitetura de transformador projetada para tarefas de visão.


embeddings (SiglipVisionEmbeddings)

O PaliGemma usa como entrada uma ou mais imagens, que são transformadas em "soft tokens" pelo codificador do SigLIP.

Ele divide a imagem em pedaços menores, semelhante à forma como um modelo de texto processa as palavras em uma frase. O modelo, então, aprende a capturar relacionamentos entre esses pedaços, entendendo efetivamente o conteúdo visual da imagem.


patch_embedding

Ele usa uma camada convolucional (Conv2d) com os parâmetros a seguir.

  • 3: a entrada tem 3 canais (para imagens RGB).

  • 1152: a saída tem 1.152 canais, que é a dimensão de incorporação de cada pedaço.

  • kernel_size=(14, 14): cada pedaço é um quadrado de 14 x 14 pixels.

  • stride=(14, 14): os pedaços são utilizados sem sobreposições (o filtro convolucional move 14 pixels por vez).

  • padding=’valid’: nenhum preenchimento é aplicado, portanto, o tamanho da saída será menor que o da entrada.


position_embedding

Incorporações de posições são adicionadas a cada incorporação de pedaços da imagem para codificar as informações espaciais (ou seja, a localização de cada pedaço na imagem original).

Isso é feito por meio de uma camada de incorporação aprendida (Embedding), que toma como entrada a posição de cada pedaço (até 256 posições) e gera como saída um vetor de tamanho 1152 (o mesmo que a dimensão de incorporação de patches).


encoder (SiglipEncoder)

As incorporações passam por uma série de SiglipEncoderLayers, cada uma consistindo em redes neurais de autoatenção e lineares (feed-forward). Isso ajuda o modelo a capturar relacionamentos entre diferentes partes da imagem.


multi_modal_projector (PaliGemmaMultiModalProjector)

Esse componente projeta a saída da torre de visão em um espaço multimodal. Isso é feito utilizando uma camada linear simples e permite que a visão e as representações de linguagem sejam combinadas de forma eficaz.


language_model (GemmaForCausalLM)

Esse componente é um modelo de linguagem baseado no modelo Gemma 2B.

Ele toma como entrada a representação multimodal do projetor e gera saída de texto.

Para a entrada de texto, cada ponto de verificação foi treinado com vários comprimentos de sequência. Por exemplo, paligemma-3b-mix-224 foi treinado com o comprimento de sequência de 256 (texto de entrada + texto de saída tokenizado pelo tokenizador do Gemma).

O PaliGemma usa o tokenizador do Gemma com 256.000 tokens, mas estende seu vocabulário com 1.024 entradas que representam coordenadas em espaço de imagem normalizado (<loc0000>...<loc1023>) e mais 128 entradas (<seg000>...<seg127>) que são palavras de código utilizadas por um autocodificador variacional quantizado por vetor (VQ-VAE) de segmentação de expressão de referência leve. (256.000 + 1.024 + 128 = 257.216)


Exemplo de segmentação de objetos

Soft tokens adicionais codificam a detecção de objetos e a segmentação de imagens. Veja a seguir um exemplo de saída de paligemma-3b-mix-224. Você pode experimentá-lo na demonstração ao vivo da Hugging Face.

Image of a child and cat on a snowy roof top

Saída do PaliGemma com o prompt "segment floor;cat;person;"

image of output from the PaliGemma with the prompt “segment floor;cat;person;”

A decodificação das saídas do modelo não é intuitiva quando não se tem familiaridade com tarefas de ML e visão computacional.

Os quatro tokens de localização iniciais representam a coordenada da caixa delimitadora, variando de 0 a 1023. Essas coordenadas são independentes da proporção, pois presume-se que a imagem seja redimensionada para 1024 x 1024.

Por exemplo, a saída exibe a localização do gato dentro das coordenadas (382, 637) e (696, 784). Nesse sistema de coordenadas, o canto superior esquerdo é denotado como (0,0), e a coordenada vertical é listada antes da coordenada horizontal.

image showing the output displaying the cat's location within coordinates (382, 637) and (696, 784)

A máscara é codificada com os 16 tokens de segmentação seguintes. Um modelo de rede neural (VQ-VAE) pode reconstruir máscaras a partir de representações quantizadas (índices de codebook) decodificando esses valores. Você pode explorar o código real aqui.

Por fim, você pode obter este lindo resultado na saída do PaliGemma.

image showing object segmentation result, where the floor is shaded blue, the child is shaded red, and the cat is shaded yellow

Resumo

Neste artigo, você conheceu o PaliGemma.

A família Gemma apresenta uma oportunidade única de entender os sistemas modernos de modelos de linguagem grandes ao oferecer uma coleção de modelos de peso abertos com arquiteturas centrais similares, mas projetados para diferentes casos de uso. Esses modelos, lançados pelo Google para pesquisadores, desenvolvedores e usuários finais, abrangem várias funcionalidades e complexidades.

Esperamos que esta visão geral forneça uma compreensão concisa da família de modelos Gemma, destacando sua versatilidade e adequação a uma ampla gama de tarefas.

O servidor Discord da comunidade de desenvolvedores do Google é uma excelente plataforma para demonstrar projetos, estabelecer conexões com outros desenvolvedores e participar de discussões interativas. Considere a participação no servidor para explorar essas empolgantes oportunidades.

Agradecemos a leitura!


Referências


Artigos


Exemplos de código


📋 A série completa sobre a arquitetura do Gemma