API AI Edge Torch Generative para LLMs personalizados no dispositivo

MAI 29, 2024
Cormac Brick Principal Engineer
Haoliang Zhang Software Engineer

É um grande prazer para nós permitir que os desenvolvedores forneçam novos modelos de IA generativa no dispositivo para dispositivos de borda. Para atender a essa necessidade, anunciamos a API AI Edge Torch Generative, com a qual os desenvolvedores podem criar LLMs de alto desempenho no PyTorch para implantação usando o ambiente de execução do TensorFlow Lite (TFLite). Esta é a segunda de uma série de postagens que abordam os lançamentos para desenvolvedores do Google AI Edge. A primeira postagem da série introduziu o Google AI Edge Torch, que ativa a inferência de alto desempenho de modelos do PyTorch em dispositivos móveis usando o ambiente de execução TFLite.

A API AI Edge Torch Generative permite que os desenvolvedores forneçam novos recursos avançados no dispositivo, tais como resumo, geração de conteúdo e muito mais. Já permitimos que os desenvolvedores forneçam alguns dos LLMs mais populares para dispositivos usando a API MediaPipe LLM Inference. Agora, temos o grande prazer de permitir que os desenvolvedores forneçam qualquer modelo com suporte no dispositivo com um ótimo desempenho. A versão inicial da API AI Edge Torch Generative oferece o seguinte:

  • API de criação fácil de usar para suporte a transformers personalizados

  • Ótimo desempenho em CPU, com suporte a GPU e NPU em breve

  • Totalmente compatível com fluxos de implantação do TFLite existentes, incluindo quantização e ambiente de execução

  • Funciona com modelos como TinyLlama, Phi-2 e Gemma 2B

  • Compatível com as interfaces de tempo de execução do TFLite e do Mediapipe LLM com suporte para Android, iOS e Web

Nesta postagem do blog, nos aprofundaremos no desempenho, na portabilidade, na experiência de criação do desenvolvedor, no pipeline de inferência completo e no conjunto de ferramentas de depuração. Outras documentações e exemplos estão disponíveis aqui.


Desempenho

Como parte de nosso trabalho para fazer com que alguns dos LLMs mais populares funcionem perfeitamente por meio da API MediaPipe LLMInference, nossa equipe criou vários transformers escritos de forma totalmente manual e com desempenho de última geração no dispositivo (blog da API MediaPipe LLM Inference). Alguns temas emergiram desse trabalho: como representar a atenção de forma eficaz, o uso da quantização e a importância de uma boa representação da Cache de KV. A API generativa torna tudo isso fácil de expressar (como veremos na próxima seção) e ainda atinge um desempenho que é 90% mais alto do que o de nossas versões escritas manualmente, com uma velocidade de desenvolvimento muito mais alta.

A tabela a seguir mostra os principais comparativos de mercado em três exemplos de modelo:

On device performance benchmarks across TinyLlama, Gemma 2B and Phi-2 models for Samsung S23 and Pixel 8 Pro

Os comparativos de mercado foram feitos em núcleos grandes, com quatro threads de CPU, e são as implementações de CPU mais rápidas desses modelos de que temos conhecimento atualmente nos dispositivos listados.


Experiência de criação

A biblioteca de criação de núcleos fornece os elementos essenciais de base para modelos de transformer comuns (estilo somente codificador, somente decodificador ou codificador-decodificador etc.). Ela permite que você crie um modelo a partir do zero ou recrie um modelo existente para melhorar o desempenho. Recomendamos a recriação para a maioria dos usuários, pois ela não requer etapas de treinamento/ajuste. Os principais benefícios de criação da API generativa incluem:

  • Um conjunto de elementos essenciais de transformers de núcleo otimizados para conversibilidade, desempenho e portabilidade de plataforma que são fáceis de combinar com ops regulares do PyTorch.

  • Um mecanismo leve de remapeamento.

  • APIs de quantização intuitivas.

  • Exportação de várias assinaturas com preenchimento, decodificação ou assinaturas personalizadas e funciona perfeitamente com tarefas de MP prontas/APIs LLMInference.

Como exemplo, demonstramos aqui como recriar a funcionalidade principal do TinyLLama(1.1B) com cerca de 50 linhas de Python com a nova API generativa.

Etapa 1: Defina a estrutura do modelo

import torch
import torch.nn as nn
 
from ai_edge_torch.generative.layers.attention import TransformerBlock
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.builder as builder
import ai_edge_torch.generative.layers.model_config as cfg
 
 
class TinyLLamma(nn.Module):
 
  def __init__(self, config: cfg.ModelConfig):
    super().__init__()
 
    self.config = config
    # Construct model layers.
    self.lm_head = nn.Linear(
        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
    )
    self.tok_embedding = nn.Embedding(
        config.vocab_size, config.embedding_dim, padding_idx=0
    )
    self.transformer_blocks = nn.ModuleList(
        TransformerBlock(config) for _ in range(config.num_layers)
    )
    self.final_norm = builder.build_norm(
        config.embedding_dim,
        config.final_norm_config,
    )
    self.rope_cache = attn_utils.build_rope_cache(
        size=config.kv_cache_max,
        dim=int(config.attn_config.rotary_percentage * config.head_dim),
        base=10_000,
        condense_ratio=1,
        dtype=torch.float32,
        device=torch.device("cpu"),
    )
    self.mask_cache = attn_utils.build_causal_mask_cache(
        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
    )
    self.config = config

Etapa 2: Defina a função forward do modelo

@torch.inference_mode
  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    B, T = idx.size()
    cos, sin = self.rope_cache
    cos = cos.index_select(0, input_pos)
    sin = sin.index_select(0, input_pos)
    mask = self.mask_cache.index_select(2, input_pos)
    mask = mask[:, :, :, : self.config.kv_cache_max]
 
    # forward the model itself
    x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)
 
    for i, block in enumerate(self.transformer_blocks):
      x = block(x, (cos, sin), mask, input_pos)
 
    x = self.final_norm(x)
    res = self.lm_head(x)  # (b, t, vocab_size)
    return res

Etapa 3: Mapeie os pesos do modelo antigo

A biblioteca permite mapear pesos facilmente com as APIs ModelLoader, por exemplo:

import ai_edge_torch.generative.utilities.loader as loading_utils
 
 
# This map will associate old tensor names with the new model.
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="model.layers.{}.mlp.up_proj",
    ff_down_proj="model.layers.{}.mlp.down_proj",
    ff_gate_proj="model.layers.{}.mlp.gate_proj",
    attn_query_proj="model.layers.{}.self_attn.q_proj",
    attn_key_proj="model.layers.{}.self_attn.k_proj",
    attn_value_proj="model.layers.{}.self_attn.v_proj",
    attn_output_proj="model.layers.{}.self_attn.o_proj",
    pre_attn_norm="model.layers.{}.input_layernorm",
    pre_ff_norm="model.layers.{}.post_attention_layernorm",
    embedding="model.embed_tokens",
    final_norm="model.norm",
    lm_head="lm_head",
)

Depois que essas etapas forem concluídas, você poderá executar algumas entradas de exemplo para verificar a exatidão numérica (consulte o link) do modelo recriado. Se a verificação numérica for aprovada, você poderá prosseguir para a etapa de conversão e quantização.


Conversão e quantização

Com as APIs de conversão fornecidas pelo ai_edge_torch, você pode utilizar a mesma API para converter modelos de transformer (recriadas) em um modelo do TensorFlow Lite altamente otimizado. O processo de conversão inclui as seguintes etapas principais:

1) Exporte para o StableHLO. O modelo do PyTorch é rastreado e compilado para um gráfico FX com ops Aten pelo compilador torch dynamo e, em seguida, baixado para o gráfico StableHLO via ai_edge_torch.

2) O ai_edge_torch executa outras passagens do compilador no StableHLO, incluindo fusão/dobra de ops, entre outras, e gera um flatbuffer do TFLite de alto desempenho (com ops fundidos para SDPA, KVCache).


Quantização

A biblioteca de núcleos da API generativa também fornece um conjunto de APIs de quantização que abrange roteiros comuns de quantização de LLMs. O roteiro recebe um parâmetro adicional para a API de conversão ai_edge_torch, o que cobre automaticamente a quantização. Em versões futuras, esperamos expandir o conjunto de modos de quantização disponíveis.


Exportação de várias assinaturas

Identificamos que, em cenários reais de inferência, os modelos de LLM precisam ter funções de inferência claramente separadas (desagregadas), como preenchimento e decodificação, para atingir o melhor desempenho. Isso se baseia parcialmente na observação de que o preenchimento/a decodificação podem assumir diferentes formas de tensor: o preenchimento é vinculado à computação, enquanto a decodificação é vinculada à memória. Para LLMs grandes, é fundamental evitar a duplicação de pesos de modelo entre o preenchimento e a decodificação. Conseguimos isso usando o recurso de várias assinatura existente no TFLite e no ai_edge_torch, que permite definir facilmente vários pontos de entrada para o modelo, conforme mostrado abaixo.

def convert_tiny_llama_to_tflite(
    prefill_seq_len: int = 512,
    kv_cache_max_len: int = 1024,
    quantize: bool = True,
):
  pytorch_model = tiny_llama.build_model(kv_cache_max_len=kv_cache_max_len)
 
  # Tensors used to trace the model graph during conversion.
  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
  prefill_input_pos = torch.arange(0, prefill_seq_len)
  decode_token = torch.tensor([[0]], dtype=torch.long)
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
 
  # Set up Quantization for model.
  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
 
  edge_model = (
      ai_edge_torch.signature(
          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
      )
      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
      .convert(quant_config=quant_config)
  )
  edge_model.export(f'/tmp/tiny_llama_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite')

Otimizações de desempenho específicas de LLMs

Durante nossa fase de investigação de desempenho, encontramos alguns aspectos críticos para a melhoria do desempenho de LLMs:

1) SDPA e KVCache de alto desempenho: descobrimos que sem otimizações/fusões suficientes do compilador, o modelo do TFLite convertido não terá um ótimo desempenho, dadas as ops granulares nessas funções. Para resolver isso, introduzimos o limite de função de nível alto e as ops compostas do StableHLO.

2) Uso do delegado XNNPack do TFLite para acelerar ainda mais o SDPA: é fundamental garantir que os cálculos pesados de MatMul/Matriz-vetor sejam bem otimizados. A biblioteca XNNPack tem excelente desempenho para esses primitivos em uma ampla gama de CPUs móveis.

3) Evite os cálculos desnecessários: os modelos em forma estática podem induzir mais computação do que o minimamente necessário se tiverem um tamanho de mensagem de entrada fixo longo na fase de preenchimento ou um comprimento de sequência fixo grande na fase de decodificação.

4) Consumo de memória em tempo de execução. Introduzimos um mecanismo de armazenamento em cache/pré-empacotamento de pesos no delegado XNNPack do TFLite para reduzir significativamente o uso máximo da memória.


Implementação

A inferência de LLM normalmente envolve muitas etapas de pré/pós-processamento e orquestração sofisticada, por exemplo, tokenização, amostragem e lógica de decodificação autorregressiva. Para esse fim, fornecemos as soluções baseadas em MediaPipe e um exemplo de inferência em C++ puro.


Use a API MediaPipe LLM Inference

A API MediaPipe LLM Inference é uma API de nível alto com suporte à inferência de LLM usando uma interface de entrada e saída de prompts. Ela lida com toda a complexidade da implementação do pipeline de LLM nos bastidores e torna a implantação muito mais fácil e fluente. Para implantar usando a API MP LLM Inference, você precisa converter modelos usando as assinaturas de preenchimento e decodificação esperadas e criar um pacote conforme mostrado no código abaixo:

def bundle_tinyllama_q8():
  output_file = "PATH/tinyllama_q8_seq1024_kv1280.task"
  tflite_model = "PATH/tinyllama_prefill_decode_hlfb_quant.tflite"
  tokenizer_model = "PATH/tokenizer.model"
  config = llm_bundler.BundleConfig(
      tflite_model=tflite_model,
      tokenizer_model=tokenizer_model,
      start_token="<s>",
      stop_tokens=["</s>"],
      output_filename=output_file,
      enable_bytes_to_unicode_mapping=False,
  )
  llm_bundler.create_bundle(config)

Inferência em C++ puro via ambiente de execução do TFLite

Também fornecemos um exemplo em C++ fácil de usar (sem dependência do MediaPipe) para mostrar como executar um exemplo de geração de texto de ponta a ponta. Os desenvolvedores podem usar esse exemplo como ponto de partida para integrar os modelos exportados a seus pipelines e requisitos de produção exclusivos, o que permite uma personalização melhor e mais flexibilidade.


Suporte multiplataforma

Como o ambiente de execução da inferência de núcleo está no TFLite, todo o pipeline pode ser facilmente integrado a apps Android (incluídos no Google Play) ou iOS sem modificações. Isso garantirá que os modelos convertidos da nova API generativa sejam imediatamente implantáveis pela simples adição de algumas dependências de ops personalizadas. Em versões futuras, levaremos o suporte a GPU para o Android e o iOS e também lidaremos com os aceleradores de ML (TPU, NPU).


Conjunto de ferramentas

O recém-anunciado Model Explorer é uma ferramenta útil para visualizar modelos grandes, como o Gemma 2B. A visualização hierárquica e a comparação lado a lado facilitam a visualização das versões originais/recriadas/convertidas de modelos. Para obter mais detalhes e saber como você pode visualizar informações de comparativos de mercado para o ajuste de desempenho, confira esta postagem do blog.

Segue um exemplo de como usamos isso na criação do modelo do PyTorch TinyLlama, mostrando o modelo export() do PyTorch ao lado do modelo do TFLite. Usando o Model Explorer, podemos comparar facilmente como cada camada (por exemplo, RMSNorms, SelfAttention) é expressada.

Uma comparação lado a lado entre o TinyLlama PyTorch e o TFLite convertido

Resumo e o que vem a seguir

A API AI Edge Torch Generative é um poderoso complemento para modelos otimizados prontos, disponíveis na API Mediapipe LLM Inference para desenvolvedores que desejam ativar seus próprios modelos de IA generativa no dispositivo. Aguarde as atualizações nos próximos meses, incluindo suporte a Web, quantização aprimorada e suporte a computação mais amplo, além da CPU. Também estamos interessados em explorar uma integração de framework ainda melhor.

Esse é um pré-lançamento antecipado da biblioteca, que está em fase experimental, com o objetivo de engajamento com a comunidade de desenvolvedores. Portanto saiba que as APIs provavelmente sofrerão alterações, que haverá algumas arestas a serem aparadas e que o suporte a quantização e modelos será limitado. Mas você já tem muito com que começar em nosso repositório do GitHub, então entre e fique à vontade para compartilhar PRs, problemas e solicitações de recursos.


Na parte 3 desta série, veremos mais detalhadamente a ferramenta de visualização Model Explorer, que permite que os desenvolvedores visualizem, depurem e explorem modelos.



Agradecimentos

Este trabalho é uma colaboração entre várias equipes funcionais do Google. Gostaríamos de agradecer a todos os membros da equipe que contribuíram para que ele fosse possível: Aaron Karp, Advait Jain, Akshat Sharma, Alan Kelly, Andrei Kulik, Arian Afaian, Chun-nien Chan, Chuo-Ling Chang, Cormac Brick, Eric Yang, Frank Barchard, Gunhyun Park, Han Qi, Haoliang Zhang, Ho Ko, Jing Jin, Joe Zoe, Juhyun Lee, Kevin Gleason, Khanh LeViet, Kris Tonthat, Kristen Wright, Lin Chen, Linkun Chen, Lu Wang, Majid Dadashi, Manfei Bai, Mark Sherwood, Matthew Soulanille, Matthias Grundmann, Maxime Brénon, Michael Levesque-Dion, Mig Gerard, Milen Ferev, Mohammadreza Heydary, Na Li, Paul Ruiz, Pauline Sho, Pei Zhang, Ping Yu, Pulkit Bhuwalka, Quentin Khan, Ram Iyengar, Renjie Wu, Rocky Rhodes, Sachin Kotwani, Sandeep Dasgupta, Sebastian Schmidt, Siyuan Liu, Steven Toribio, Suleman Shahid, Tenghui Zhu, T.J. Alumbaugh, Tyler Mullen, Weiyi Wang, Wonjoo Lee, Yi-Chun Kuo, Yishuang Pang, Yu-hui Chen, Zoe Wang e Zichuan Wei.