Explicação sobre o Gemma: arquitetura do RecurrentGemma

AGO 29, 2024
Ju-yeong Ji Gemma DevRel
Ravin Kumar Google Data Scientist Language Applications

Na postagem anterior da série Explicação sobre o Gemma, discutimos a arquitetura do Gemma 2 mais recente. Nesta postagem, você conhecerá a arquitetura do RecurrentGemma. Vamos começar!


RecurrentGemma 2B, 9B

O RecurrentGemma baseia-se no Griffin, um modelo híbrido que combina recorrências lineares fechadas com atenção de janela deslizante local. Essa mudança melhora a computação e a memória e é mais adequada para prompts de contexto longos.

Griffin hybrid model architecture

No entanto, ele tem a desvantagem do desempenho reduzido nas buscas mais difíceis devido ao estado de tamanho fixo da arquitetura do Griffin. Embora seja possível fornecer todo o texto de um livro como entrada, essa abordagem pode não ser a ideal. As redes neurais recorrentes (RNNs, na sigla em inglês) podem ter dificuldade para aprender dependências de longo alcance em sequências excessivamente longas, e o modelo tem uma janela de contexto limitada. Isso significa que ele só consegue considerar efetivamente um número determinado de tokens anteriores ao fazer previsões.

Além disso, os modelos recorrentes ainda não receberam tanta atenção em termos de otimizações de tempo de inferência em comparação com seus transformadores equivalentes. E há menos pesquisas e apoio da comunidade disponíveis em comparação com a arquitetura de transformadores já estabelecida.

Portanto, esse modelo será extremamente valioso em cenários nos quais haja uma preocupação com o esgotamento da janela de contexto do LLM. Ao priorizar as informações mais recentes e descartar estrategicamente dados mais antigos, o RecurrentGemma garante o alto desempenho do LLM à medida que o contexto se expande.

Abaixo, temos o diagrama da arquitetura do modelo RecurrentGemma 2B.

Recurrent Gemma 2B model architecture

O Griffin segue o mesmo padrão residual e o mesmo bloco de MLP que outra linha de base de transformador. No entanto, ao contrário da linha de base do transformador de MQA e do modelo Hawk, o Griffin usa uma combinação de blocos recorrentes e de MQA.

Layered structure of recurrent and MQA blocks

O Griffin usa uma estrutura em camadas alternando dois blocos residuais com um bloco recorrente, seguido por um bloco residual que incorpora o bloco de atenção de MQA local.

Os principais parâmetros da arquitetura estão resumidos na tabela abaixo.

Core parameters of the architecture of 2B and 9B models

Parâmetros de não incorporação e parâmetros de incorporação

Os parâmetros de não incorporação são distribuídos pelas camadas ocultas do modelo, em componentes como mecanismos de atenção e redes de propagação direta.

Observação: a nomenclatura "2B" do modelo vem desse parâmetro.

Os parâmetros de incorporação geralmente são encontrados na camada dedicada, chamada camada de incorporação. Ela é responsável por mapear tokens distintos (como palavras ou caracteres) em representações vetoriais contínuas (incorporações).

Observação: 0.7B pode ser calculado como 256 k (tamanho do vocabulário) x 2560 (largura do modelo)


Largura do modelo e largura da RNN

Largura do modelo refere-se ao tamanho das camadas ocultas no modelo, determinando a capacidade do modelo de representar padrões complexos, assim como os modelos Gemma de base.

A largura da rede neural recorrente (RNN, na sigla em inglês) é o tamanho do estado oculto mantido pela Real-Gated Linear Recurrent Unit (RG-LRU). Ao contrário dos transformadores tradicionais, o bloco recorrente mantém um estado interno de tamanho fixo, independentemente do tamanho da entrada. Isso permite que o RecurrentGemma processe sequências mais longas com menos memória, tornando-o mais eficiente em tarefas como a geração de artigos ou códigos longos.


Fator de expansão de MLP

É igual às dimensões ocultas de propagação direta no modelo Gemma de base. Para simplificar, aplicamos um fator de expansão de 3 no modelo RecurrentGemma, resultando em uma dimensão de MLP de 7680 (calculada como 2560 x 3).


Tamanho da janela de atenção local

O estado mantido pelo RecurrentGemma tem um tamanho finito e não se expande com sequências mais longas do que a janela de atenção local de tokens de 2k. Isso significa que, embora o tamanho máximo de amostras geradas de forma autorregressiva pelo Gemma seja limitado pela capacidade de memória do sistema host, o RecurrentGemma pode gerar sequências de tamanho arbitrário, superando essa restrição.

RecurrentGemmaForCausalLM(
  (model): RecurrentGemmaModel(
    (embed_tokens): Embedding(256000, 2560, padding_idx=0)
    (layers): ModuleList(
      (0-1): 2 x RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaRecurrentBlock(
          (linear_y): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_x): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_out): Linear(in_features=2560, out_features=2560, bias=True)
          (conv_1d): Conv1d(2560, 2560, kernel_size=(4,), stride=(1,), padding=(3,), groups=2560)
          (rg_lru): RecurrentGemmaRglru()
          (act_fn): PytorchGELUTanh()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
      (2): RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaSdpaAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=False)
          (k_proj): Linear(in_features=2560, out_features=256, bias=False)
          (v_proj): Linear(in_features=2560, out_features=256, bias=False)
          (o_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (rotary_emb): RecurrentGemmaRotaryEmbedding()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
 
      :
 
      (23): RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaSdpaAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=False)
          (k_proj): Linear(in_features=2560, out_features=256, bias=False)
          (v_proj): Linear(in_features=2560, out_features=256, bias=False)
          (o_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (rotary_emb): RecurrentGemmaRotaryEmbedding()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
      (24-25): 2 x RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaRecurrentBlock(
          (linear_y): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_x): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_out): Linear(in_features=2560, out_features=2560, bias=True)
          (conv_1d): Conv1d(2560, 2560, kernel_size=(4,), stride=(1,), padding=(3,), groups=2560)
          (rg_lru): RecurrentGemmaRglru()
          (act_fn): PytorchGELUTanh()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
    )
    (final_norm): RecurrentGemmaRMSNorm()
  )
  (lm_head): Linear(in_features=2560, out_features=256000, bias=False)
)

embed_tokens (camada de incorporação)

Recebe o texto de entrada como uma sequência de tokens e mapeia cada token para uma representação vetorial contínua de tamanho 2560. Tem um tamanho de vocabulário de 256.000, que é igual ao dos modelos Gemma de base.


layers

Há 26 camadas de decodificadores no total, agrupadas em padrões de repetição.

O modelo começa com dois blocos residuais com um bloco recorrente (0-1). Essa sequência é, então, seguida por um bloco residual (2) e uma série de blocos contínuos que se alternam até o final da camada (25).

Recurrent block architecture

Bloco residual com um bloco recorrente

No bloco recorrente (bloco de combinação temporal), o modelo recebe a entrada de dimensão (largura do modelo) 2560 e aplica duas camadas lineares com dimensão de saída (largura da RNN) 2560 em paralelo, criando duas ramificações.

Na primeira ramificação (lado direito), ele aplica uma pequena camada Conv1D separável com uma dimensão de filtro temporal de 4. E a camada da RG-LRU vem em seguida.

Na segunda ramificação (lado esquerdo), ele aplica uma não linearidade de GeLU.

Depois, ele mescla as ramificações por meio da multiplicação de elementos e aplica uma camada linear final com dimensão de saída (largura do modelo) 2560.

RecurrentGemma-Residual-block

Depois de aplicar o RMSNorm, vem o bloco de MLP.


Bloco residual com um MQA local

Depois de ter dois blocos residuais com um bloco recorrente (0-1), segue um bloco residual com uma MQA local (2). Uma das principais desvantagens de usar a atenção global é que sua complexidade computacional aumenta quadraticamente no tamanho da sequência. Para resolver isso, o RecurrentGemma usa uma atenção de janela deslizante local. Isso permite que cada posição atenda apenas a um número fixo de tokens no passado.

No bloco de MQA local (bloco de combinação temporal), o modelo recebe a entrada de dimensão (largura do modelo) 2560. Ele usa projeções lineares (q_proj, k_proj, v_proj, o_proj) para criar representações de consulta, chave, valor e saída. Observe que out_features para k_proj e v_proj é 256, pois eles compartilham a mesma cabeça com um tamanho de 256, enquanto q_proj e o_proj têm 10 cabeças (256 x 10 = 2560) em paralelo.

Ele incorpora rotary_emb (RecurrentGemmaRotaryEmbedding) para Rotary Positional Embeddings (RoPE), assim como os modelos Gemma de base.

A aplicação de RMSNorm e do bloco de MLP é igual à do bloco residual anterior.


O que vem a seguir?

Neste artigo, você conheceu o RecurrentGemma.

Na próxima postagem, você conhecerá o PaliGemma, um modelo de visão-linguagem (VLM, na sigla em inglês) leve e aberto.

Não deixe de nos acompanhar! Agradecemos a leitura.


Referências

Artigos


Exemplos de código


📋 A série completa sobre a arquitetura do Gemma