Explicação sobre o Gemma: uma visão geral das arquiteturas da família de modelos Gemma

AGO 15, 2024
Ju-yeong Ji Gemma DevRel
Ravin Kumar Google Data Scientist Language Applications

O Gemma é uma família de modelos abertos leves e de última geração, desenvolvidos a partir da mesma pesquisa e tecnologia usadas para criar os modelos Gemini.

Diferentes variações do Gemma são projetadas para diferentes casos de uso e modalidades, como:

  • Modalidade única (entrada de texto, saída de texto).

  • Especialização para casos de uso de codificação.

  • Modalidade múltipla (entrada de texto e imagem, saída de texto).

  • Tamanhos variados para diferentes tipos de hardware, necessidades de inferência e outras restrições.

  • Novas arquiteturas.

Como todos esses modelos compartilham um DNA semelhante, a família Gemma apresenta uma maneira única de aprender sobre as arquiteturas e opções de design disponíveis nos sistemas modernos de LLM. Esperamos que isso contribua para um ecossistema rico de modelos abertos e promova uma compreensão maior do funcionamento dos sistemas de LLM.

Esta série cobrirá:

  • Gemma 1 (2B e 7B) – modelos de texto para texto baseados em transformadores.

  • CodeGemma (2B e 7B) – uma versão refinada do Gemma, otimizada para sugestão e geração de código.

  • Gemma 2 (2B, 9B e 27B) – modelos de texto para texto atualizados, treinados com uma arquitetura mais recente com as versões 2B e 9B treinadas por meio de destilação de modelos maiores.

  • RecurrentGemma (2B e 9B) – um modelo desenvolvido a partir da nova arquitetura Griffin, que utiliza uma combinação de atenção local e recorrências lineares para atingir uma inferência rápida ao gerar sequências longas.

  • PaliGemma (3B) – um modelo de visão-linguagem capaz de absorver texto e imagens e fornecer uma saída de texto.


Como usar este guia

Nesta série, vamos:

  • Agrupar as arquiteturas específicas de vários modelos.

  • Explicar como esses parâmetros afetam as gerações de modelos (por exemplo, número de incorporações, multiconsultas versus multicabeças versus consultas agrupadas).

  • Fornecer exemplos de código dos modelos para análise mais detalhada.

Para fornecer informações sobre o modelo, utilizamos o módulo print do Hugging Face Transformers, como no código simples abaixo.

from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b")
print(model)

Você também pode fazer análises detalhadas dentro do modelo com torchinfo ou summary() na API da classe Keras Model.


O que não é este guia

Este guia não é uma introdução à IA. Ele pressupõe conhecimento prático de redes neurais, transformadores e termos associados, como tokens. Se você precisar de uma atualização sobre esses conceitos, veja aqui alguns recursos para começar:

Uma ferramenta prática de aprendizado sobre redes neurais que funciona no navegador.

Uma introdução aos transformadores.


Gemma

O Gemma é um LLM de peso aberto. Ele é fornecido em variantes pré-treinadas, ajustadas por instruções e brutas, em vários tamanhos de parâmetro. Ele é baseado na arquitetura de LLM introduzida pelo Google Research no artigo Attention Is All You Need. Sua principal função é gerar texto tokenword por tokenword, com base em um prompt fornecido por um usuário. Em tarefas como a de tradução, o Gemma pega uma frase de um idioma como entrada e gera a saída da frase equivalente em outro idioma.

Como você verá em breve, o Gemma é um ótimo modelo por si só, mas também é útil em extensões personalizadas, para atender às diferentes necessidades dos usuários.


Arquitetura do Gemma

Primeiro, vejamos o decodificador de transformadores no qual os modelos Gemma se baseiam.

Transformer decoder architecture

Ao contrário da arquitetura original de modelo de transformador codificador-decodificador introduzida no artigo "Attention Is All You Need", o Gemma é um modelo "somente decodificador".

Os principais parâmetros da arquitetura estão resumidos na tabela abaixo.

Core parameters of the architecture

Os modelos são treinados em um comprimento de contexto de 8192 tokens. Isso significa que eles podem processar até aproximadamente 6144 palavras (usando a regra geral de 100 tokens ~= 75 palavras) por vez.

Vale notar que o limite de entrada prático pode variar com base na tarefa e no uso. O motivo é que a geração de texto consome tokens dentro da janela de contexto, reduzindo efetivamente o espaço para novas entradas. Embora o limite técnico de entrada permaneça constante, a saída gerada torna-se parte da entrada subsequente, influenciando as gerações adicionais.


d_model (2B: 2048, 7B: 3072)

d_model representa o tamanho das incorporações (representações vetoriais de palavras ou subpalavras, também conhecidas como tokens) usadas como entrada para o decodificador, além de determinar o tamanho da representação interna, dentro das camadas do decodificador.

d_model x Num heads x Head size
"d_model x número de cabeças x tamanho da cabeça" define o número de parâmetros em self_attn.

Um valor mais alto de d_model significa que o modelo tem mais "espaço" para representar as nuances de diferentes palavras e seus relacionamentos. Isso pode levar a um desempenho melhor, especialmente para tarefas de linguagem complexas. No entanto, o aumento de d_model também torna o modelo maior e mais caro, em termos de computação, para treinar e usar.


Camadas (2B: 18, 7B: 28)

Os transformadores consistem em várias camadas empilhadas. Os modelos mais aprofundados têm mais camadas e, portanto, mais parâmetros, e podem aprender padrões mais intrincados. No entanto, esses parâmetros adicionais significam que eles também são mais propensos a sobreajuste e exigem mais recursos de computação.

Essa capacidade aumentada de representação pode resultar no ruído de aprendizado do modelo ou em padrões de dados de treinamento específicos que não têm a capacidade de generalizar para novos exemplos.

Além disso, os modelos mais aprofundados geralmente exigem mais dados de treinamento para evitar o sobreajuste. Nos casos em que os dados disponíveis são limitados, o modelo pode não ter exemplos suficientes para aprender uma representação generalizável, levando à memorização de dados de treinamento.


Dimensões ocultas de propagação direta (2B: 32768, 7B: 49152)

Cada camada de transformador inclui uma rede de propagação direta depois do mecanismo de atenção. Essa rede tem sua própria dimensionalidade, muitas vezes maior do que o tamanho de d_model, para aumentar a potência expressiva do modelo.

Isso é implementado como um perceptron multicamadas (MLP, na sigla em inglês), um tipo de rede neural, para transformar ainda mais as incorporações e extrair padrões mais intrincados.

multi-layer perceptron (MLP) neural network achitecture

No Gemma, a não linearidade da ReLU padrão é substituída pela função de ativação GeGLU, uma variação de GLU (unidade linear fechada). A GeGLU divide a ativação em duas partes: uma parte sigmoide e uma projeção linear. A saída da parte sigmoide é multiplicada, em termos de elementos, com a projeção linear, resultando em uma função de ativação não linear.

GeGLU activation function example

Número de cabeças (2B: 8, 7B: 16)

Cada camada de transformador contém vários mecanismos de atenção trabalhando em paralelo. Essas "cabeças" permitem que o modelo se concentre em diferentes aspectos da sequência de entrada simultaneamente. Aumentar o número de cabeças pode aumentar a capacidade do modelo de capturar diversos relacionamentos nos dados.


Número de cabeças KV (2B: 1, 7B: 16)

O modelo 7B usa a atenção multicabeças (MHA, na sigla em inglês), enquanto o modelo 2B usa a atenção multiconsultas (MQA, na sigla em inglês). A MQA compartilha as mesmas projeções de chave e valor, o que significa que cada cabeça se concentra na mesma representação subjacente, mas com projeções de consulta diferentes.

A MHA original oferece um aprendizado de representações mais rico, mas tem custos de computação mais altos. A MQA fornece uma alternativa eficiente que se mostrou eficaz.


Tamanho da cabeça (2B: 256, 7B: 256)

Refere-se à dimensionalidade de cada cabeça de atenção dentro do mecanismo de atenção multicabeças. É calculado dividindo-se a dimensão da incorporação pelo número de cabeças. Por exemplo, se a dimensão da incorporação for 2048 e houver 8 cabeças, cada cabeça terá um tamanho de 256.


Tamanho de vocabulário (2B: 256128, 7B: 256128)

Define o número de tokens exclusivos (palavras, subpalavras ou caracteres) que o modelo entende e é capaz de processar. O tokenizador do Gemma é baseado em SentencePiece. O tamanho de vocabulário é predeterminado antes do treinamento. O SentencePiece, então, aprende a segmentação ideal de subpalavras com base no tamanho de vocabulário escolhido e nos dados de treinamento. O grande vocabulário de 256 mil do Gemma permite que ele lide com diversas entradas de texto e potencialmente melhore o desempenho em várias tarefas, por exemplo, ao processar entradas de texto multilíngues.


Gemma 7B

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 3072, padding_idx=0)
    (layers): ModuleList(
      (0-27): 28 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=3072, out_features=4096, bias=False)
          (k_proj): Linear(in_features=3072, out_features=4096, bias=False)
          (v_proj): Linear(in_features=3072, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=3072, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=3072, out_features=24576, bias=False)
          (up_proj): Linear(in_features=3072, out_features=24576, bias=False)
          (down_proj): Linear(in_features=24576, out_features=3072, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
    )
    (norm): GemmaRMSNorm()
  )
  (lm_head): Linear(in_features=3072, out_features=256000, bias=False)
)
Gemma 7B architecture

embed_tokens (camada de incorporação)

Essa camada converte os tokens de entrada (palavras ou subpalavras) em representações numéricas densas (incorporações) que o modelo é capaz de processar. Ela tem um tamanho de vocabulário de 256.000 e cria incorporações de dimensão 3072.


layers

Esse é o coração do modelo, e consiste em 28 blocos GemmaDecoderLayer empilhados. Cada uma dessas camadas refina as incorporações de token para capturar relacionamentos complexos entre as palavras e seu contexto.


self_attn

No mecanismo de autoatenção, o modelo atribui pesos diferentes às palavras na entrada ao criar a próxima palavra. Com base em um mecanismo de atenção dot-product ajustado, o modelo emprega projeções lineares (q_proj, k_proj, v_proj e o_proj) para gerar representações de consulta, chave, valor e saída.

Todos os valores de out_features são os mesmos 4096 para q_proj, k_proj e v_proj, já que esse modelo usa a atenção multicabeças (MHA). Eles têm 16 cabeças com um tamanho de 256 em paralelo, totalizando 4096 (256 x 16).

Além disso, o modelo utiliza as informações posicionais de forma mais eficaz ao empregar o rotary_emb (GemmaRotaryEmbedding) para a codificação posicional (também conhecida como RoPE).

Por fim, a camada o_proj projeta a saída de atenção novamente na dimensão original (3072).


Observe que o modelo Gemma 2B usa a atenção multiconsultas (MQA).

Multi-Query Attention (MQA) architecture used in Gemma 2B model

k_proj e v_proj compartilham a mesma cabeça com um tamanho de 256, resultando em out_features de 256. Em comparação, q_proj e o_proj têm 8 cabeças (256 x 8 = 2048) em paralelo.

(self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )


mlp

Utiliza gate_proj e up_proj para um mecanismo de fechamento, seguido por down_proj para reduzir a dimensão novamente para 3072.


input_layernorm, post_attention_layernorm e norm

Essas camadas de normalização estabilizam o treinamento e elevam a capacidade do modelo de aprender com eficiência.


lm_head

Essa camada final mapeia as incorporações refinadas (3072) novamente para uma distribuição de probabilidade para o próximo token sobre o espaço de vocabulário (256000).


CodeGemma (2B e 7B)

Os modelos CodeGemma são modelos Gemma ajustados e otimizados para sugestão de código e assistência de chat de codificação. Os modelos CodeGemma são treinados em mais de 500 bilhões de tokens principalmente de código. Além disso, o CodeGemma adiciona o recurso fill-in-the-middle (FIM, na sigla em inglês), permitindo sugestões entre duas partes de texto existente.

O CodeGemma destaca a ajustabilidade dos pontos de verificação do Gemma. Por meio de treinamento adicional, os modelos se especializam em uma determinada tarefa, aprendendo uma sugestão mais complexa do que a sugestão pura de sufixo.


Uso do CodeGemma

Você pode usar 4 tokens definidos pelo usuário: 3 para FIM e um token "<|file_separator|>" para suporte ao contexto multiarquivos.

BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

Imagine que você esteja tentando obter sugestões de código como na tela abaixo.

Code snippet example - CodeGemma (2B and 7B)

E que o prompt de entrada deva ficar assim:

<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>

O modelo fornecerá "sys" como a sugestão de código.

Saiba mais sobre o CodeGemma em CodeGemma / Quickstart.


O que vem a seguir?

Este artigo discutiu a arquitetura do Gemma.

Em nossa próxima série de postagens, você verá o modelo mais recente, o Gemma 2. Com melhorias substanciais nas medidas de segurança, esse modelo supera seu antecessor em termos de desempenho e eficiência durante a inferência.

Não deixe de nos acompanhar! Agradecemos a leitura.



Referências


Artigos

Exemplos de código

Gemma

CodeGemma


📋 The complete Gemma architecture series