Explicação sobre o Gemma: novidades do Gemma 2

AGO 22, 2024
Ju-yeong Ji Gemma DevRel
Ravin Kumar Google Data Scientist Language Applications

Na postagem anterior da série Explicação sobre o Gemma, discutimos a arquitetura do Gemma. Nesta postagem, você conhecerá o modelo mais recente, o Gemma 2. Vamos começar!


Gemma 2

Recentemente, lançamos o Gemma 2, nosso inovador pacote de modelos abertos, estabelecendo um novo padrão de desempenho e acessibilidade. Disponível nos tamanhos de parâmetro 2B, 9B e 27B, o Gemma 2 já deixou a sua marca. Nosso modelo 27B subiu rapidamente no placar do LMSYS Chatbot Arena, superando até mesmo os modelos populares com mais do que o dobro de seu tamanho em conversas envolventes do mundo real, estabelecendo-se como um dos modelos abertos mais úteis e mais bem avaliados. Enquanto isso, o modelo Gemma 2 2B demonstra sua excepcional proeza de IA de conversação, superando todos os modelos GPT-3.5 no Chatbot Arena em um tamanho executável em dispositivos de borda.

Os desenvolvedores podem acessar recursos robustos de ajuste com o Gemma 2 em várias plataformas e ferramentas. O ajuste do Gemma 2 é simplificado com soluções baseadas na nuvem, como o Google Cloud, e ferramentas da comunidade, como o Axolotl. A integração total com parceiros como Hugging Face e NVIDIA TensorRT-LLM, bem como nossos JAX e Keras, permite a otimização do desempenho e a implantação eficiente em diversas configurações de hardware.

Estes são os principais parâmetros dos novos modelos:

Core parameters of new Gemma models, August 2024

Principais diferenças

O Gemma 2 compartilha uma base arquitetônica semelhante à dos modelos Gemma originais, incluindo a implementação de Rotary Positioning Embeddings (RoPE) e a não linearidade aproximada de GeGLU. No entanto, ele introduz inovações arquitetônicas que o diferenciam de seus antecessores.


Alternância ente atenção local e global

Em vez de considerar todas as palavras de um texto de uma só vez, ele às vezes se concentra em uma pequena janela de palavras (atenção local) e às vezes considera todas as palavras (atenção global). Essa combinação ajuda o modelo a entender o contexto imediato e o significado geral do texto de forma eficiente.


Soft-capping de logits

Imagine que você esteja treinando um modelo para prever a próxima palavra de uma frase. Às vezes, o modelo pode estar excessivamente confiante quanto a uma palavra específica, mesmo que ela não seja a melhor escolha. O soft-capping de logits evita isso ao limitar a confiança que o modelo pode ter em suas previsões, o que leva a um desempenho geral melhor.


RMSNorm para pré e pós-normalização

Pense nisso como uma maneira de evitar que os cálculos do modelo se tornem grandes ou pequenos demais durante o treinamento. Assim como podemos ajustar o volume de um alto-falante para evitar distorções, o RMSNorm garante que as informações que fluem pelo modelo permaneçam dentro de um intervalo razoável, levando a um treinamento mais estável e eficaz.


Atenção de consulta agrupada (GQA)

Essa técnica ajuda o modelo a processar informações com mais eficiência, especialmente ao lidar com grandes quantidades de texto. Ela melhora a atenção multicabeças (MHA, na sigla em inglês) tradicional ao agrupar consultas, permitindo um processamento mais rápido, principalmente em modelos grandes. É como dividir uma tarefa grande em partes menores e mais gerenciáveis, permitindo que o modelo entenda as relações entre as palavras mais rapidamente sem sacrificar a acurácia.


Gemma 27B

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 4608, padding_idx=0)
    (layers): ModuleList(
      (0-45): 46 x Gemma2DecoderLayer(
        (self_attn): Gemma2SdpaAttention(
          (q_proj): Linear(in_features=4608, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4608, out_features=2048, bias=False)
          (v_proj): Linear(in_features=4608, out_features=2048, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4608, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=4608, out_features=36864, bias=False)
          (up_proj): Linear(in_features=4608, out_features=36864, bias=False)
          (down_proj): Linear(in_features=36864, out_features=4608, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm()
        (post_attention_layernorm): Gemma2RMSNorm()
        (pre_feedforward_layernorm): Gemma2RMSNorm()
        (post_feedforward_layernorm): Gemma2RMSNorm()
      )
    )
    (norm): Gemma2RMSNorm()
  )
  (lm_head): Linear(in_features=4608, out_features=256000, bias=False)
)
Gemma 27B architecture

self_attn

No mecanismo de autoatenção, o Gemma 2 usa a atenção de consulta agrupada (GQA, na sigla em inglês).

k_proj e v_proj compartilham a mesma cabeça com um tamanho de 128 e 16 cabeças (128 x 16 = 2048). Em comparação, q_proj e o_proj têm 32 cabeças (128 x 32 = 4096) em paralelo.


Observe que o modelo Gemma 9B usa a mesma GQA, mas um número diferente de cabeças (8 para k_proj e v_proj, 16 para q_proj e o_proj) e o tamanho da cabeça (256).

(self_attn): Gemma2SdpaAttention(
          (q_proj): Linear(in_features=3584, out_features=4096, bias=False)
          (k_proj): Linear(in_features=3584, out_features=2048, bias=False)
          (v_proj): Linear(in_features=3584, out_features=2048, bias=False)
          (o_proj): Linear(in_features=4096, out_features=3584, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )

O modelo 2B usa 4 para k_proj e v_proj, 8 para q_proj e o_proj e o tamanho da cabeça (256).


pre_feedforward_layernorm e post_feedforward_layernorm

Outra distinção importante é a inclusão de RMSNorm adicional no Gemma 2, o que aumenta a estabilidade do processo de treinamento.


Principais conclusões

Nosso relatório técnico fornece mais detalhes, mas este é um breve resumo das principais conclusões sobre o Gemma 2:


Destilação versus treinamento do zero:

Treinamos os modelos 2B e 9B com destilação de conhecimento a partir do modelo maior (27B).

A destilação de conhecimento a partir de um modelo maior, mesmo com um número igual de tokens de treinamento, leva a melhorias significativas de desempenho.


Atenção de consulta agrupada versus atenção multicabeças:

A substituição da MHA pela GQA resulta em um desempenho comparável, oferecendo eficiência de parâmetros e tempos de inferência menores, o que torna a GQA a escolha preferencial.


Profundidade versus largura do modelo:

Um modelo mais aprofundado apresenta um desempenho ligeiramente superior em comparação com um modelo mais largo com o mesmo número de parâmetros.


O que vem a seguir?

Neste artigo, você conheceu o Gemma 2, a próxima geração de modelos Gemma.

Em nossa próxima série de postagens, você verá o RecurrentGemma, que é um modelo aberto baseado no Griffin.

Se você quiser se aprofundar no fascinante mundo da IA e obter insights dos especialistas que estão moldando o desenvolvimento dessa tecnologia, acesse goo.gle/ai-podcast ou pesquise a série "People of AI Podcast" em qualquer plataforma de podcast.

Não deixe de nos acompanhar! Agradecemos a leitura.



Referências


Artigos


Exemplos de código


📋 A série completa sobre a arquitetura do Gemma