Apresentamos os modelos Gemma no Keras

FEV 21, 2024
Martin Görner Product Manager Keras

A equipe do Keras tem o prazer de anunciar que o Gemma, uma família de modelos abertos leves e de última geração, construídos usando a mesma pesquisa e tecnologia que usamos para criar os modelos Gemini, agora está disponível na coleção KerasNLP. Graças ao Keras 3, o Gemma é executado com JAX, PyTorch e TensorFlow. Com esta versão, o Keras também introduz vários novos recursos projetados especificamente para grandes modelos de linguagem: uma nova API LoRA (Low Rank Adaptation) e recursos de treinamento paralelo de modelos em grande escala.

Se você quiser passar diretamente para os exemplos de código, acesse:

Criar um agora

Os modelos Gemma têm tamanhos de parâmetros 2B e 7B portáveis e avanços significativos em relação a modelos abertos similares e até mesmo a alguns maiores. Por exemplo:

  • O Gemma 7B atinge uma nova pontuação de 64,3% em respostas corretas, a melhor da categoria, no comparativo de mercado em compreensão de linguagem MMLU (contra 62,5% para o Mistral-7B e 54,8% para o Llama2-13B).
  • O Gemma adiciona 11 pontos percentuais à pontuação do comparativo de mercado do GSM8K para problemas matemáticos no nível do ensino fundamental (46,4% para o Gemma 7B contra 35,4% para o Mistral-7B e 28,7% para o Llama2-13B).
  • E mais 6,1 pontos percentuais de respostas corretas no HumanEval, um desafio de codificação (32,3% para o Gemma 7B contra 26,2% para o Mistral 7B e 18,3% para o Llama2 13B).

Os modelos Gemma são oferecidos com uma API KerasNLP popular e uma implementação do Keras superlegível. Você pode instanciar o modelo com uma única linha de código:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

E é possível executá-lo diretamente em um prompt de texto. Sim, a tokenização é integrada, embora você possa dividi-la facilmente, se necessário. Leia o guia do KerasNLP (em inglês) para saber como.

gemma_lm.generate("Keras is a", max_length=32)
> "Keras is a popular deep learning framework for neural networks..."

Experimente aqui: Primeiros passos com os modelos Gemma (em inglês).

Graças ao Keras 3, você pode escolher o back-end para executar o modelo. Veja como fazer a mudança:

os.environ["KERAS_BACKEND"] = "jax" # Or "tensorflow" or "torch".
import keras # import keras after having selected the backend

O Keras 3 tem vários recursos novos específicos para grandes modelos de linguagem. O principal deles é uma nova API LoRA (Low Rank Adaptation) para o ajuste eficiente de parâmetros. Veja como ativá-la:

gemma_lm.backbone.enable_lora(rank=4)
# Note: rank=4 replaces the weights matrix of relevant layers with the
# product AxB of two matrices of rank 4, which reduces the number of
# trainable parameters.

Essa linha única reduz o número de parâmetros treináveis de 2,5 bilhões para 1,3 milhão!

Experimente aqui: Ajuste dos modelos Gemma com o LoRA (em inglês).

Como ajustar os modelos Gemma em várias GPUs/TPUs

O Keras 3 também dá suporte ao treinamento de modelos em larga escala, e o Gemma é o modelo perfeito para experimentá-lo. A nova API de distribuição do Keras oferece opções de treinamento paralelo de dados e de modelos distribuído. O objetivo é que a nova API seja compatível com vários back-ends, mas, por enquanto, ela é implementada apenas para o back-end JAX, devido à sua escalonabilidade comprovada (os modelos Gemma foram treinados com o JAX).

Para ajustar o Gemma 7B maior, uma configuração distribuída é útil. Por exemplo, uma TPUv3 com 8 núcleos de TPU, que você pode obter gratuitamente no Kaggle, ou uma máquina de 8 GPUs do Google Cloud. Veja como configurar o modelo para treinamento distribuído usando o paralelismo de modelos:

device_mesh = keras.distribution.DeviceMesh(
   (1, 8), # Mesh topology
   ["batch", "model"], # named mesh axes
   devices=keras.distribution.list_devices() # actual accelerators
)
 
 
# Model config
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (None, "model")
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
   None, "model", None)
layout_map["decoder_block.*attention_output.*kernel"] = (
   None, None, "model")
layout_map["decoder_block.*ffw_gating.*kernel"] = ("model", None)
layout_map["decoder_block.*ffw_linear.*kernel"] = (None, "model")
 
 
# Set the model config and load the model
model_parallel = keras.distribution.ModelParallel(
   device_mesh, layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
# Ready: you can now train with model.fit() or generate text with generate()

O que esse snippet de código faz é configurar os 8 aceleradores em uma matriz 1 x 8, na qual as duas dimensões são chamadas de "batch" e "model". Os pesos dos modelos são fragmentados na dimensão "model", dividida aqui entre os 8 aceleradores, enquanto os lotes de dados não são particionados, uma vez que a dimensão "batch" é 1.

Experimente aqui: Ajuste dos modelos Gemma em várias GPUs/TPUs (em inglês).

O que vem por aí

Em breve, publicaremos um guia mostrando como particionar corretamente um modelo Transformer e escrever as 6 linhas de configuração de particionamento acima. Isso não é muito longo, mas não caberia nesta postagem.

Você deve ter notado que os particionamentos de camadas são definidos por meio de regexes nos nomes das camadas. Você pode verificar os nomes das camadas com este snippet de código. Nós o executamos para construir o LayoutMap acima.

# This is for the first Transformer block only,
# but they all have the same structure
tlayer = gemma_lm.backbone.get_layer('decoder_block_0')
for variable in tlayer.weights:
 print(f'{variable.path:<58}  {str(variable.shape):<16}')

O paralelismo completo de modelos GSPMD funciona aqui com apenas alguns hints de particionamento porque o Keras transmite essas configurações para o poderoso compilador XLA, que determina todos os outros detalhes da computação distribuída.

Esperamos que você goste de experimentar os modelos Gemma. Aqui, você tem também um tutorial de ajuste de instruções (em inglês) que pode ser útil. E, a propósito, se você quiser compartilhar seus pesos ajustados com a comunidade, o Kaggle Model Hub agora dá suporte a uploads de pesos ajustados pelo usuário. Acesse a página de modelos Gemma no Kaggle e veja o que outras pessoas já criaram!