API generativa de AI Edge Torch para LLM personalizados en el dispositivo

MAY 29, 2024
Cormac Brick Principal Engineer
Haoliang Zhang Software Engineer

Estamos entusiasmados de permitirles a los desarrolladores integrar sin problemas nuevos modelos de IA generativa en dispositivos perimetrales. Para satisfacer esa necesidad, anunciamos la API generativa de AI Edge Torch, que permite a los desarrolladores crear modelos de lenguaje grandes (LLM) de alto rendimiento en PyTorch para su implementación utilizando el tiempo de ejecución de TensorFlow Lite (TFLite). Esta es la segunda de una serie de entradas de blog que hablan sobre los lanzamientos para desarrolladores de Google AI Edge. En la primera entrada de la serie, se presentó Google AI Edge Torch, que permite la inferencia de alto rendimiento de los modelos de PyTorch en dispositivos móviles utilizando el tiempo de ejecución de TFLite.

Con la API generativa de AI Edge Torch, los desarrolladores pueden integrar nuevas y potentes capacidades en el dispositivo, como el resumen, la generación de contenido y mucho más. Ya permitimos a los desarrolladores integrar algunos de los LLM más populares a dispositivos utilizando la API de inferencia MediaPipe LLM. Ahora estamos entusiasmados por permitir que los desarrolladores integren cualquier modelo compatible en dispositivos con un gran rendimiento. La versión inicial de la API generativa de AI Edge Torch ofrece lo siguiente:

  • API de autoría fácil de usar para compatibilidad con transformadores personalizados

  • Gran rendimiento en la CPU, próximamente compatible con GPU y NPU

  • Totalmente compatible con los flujos de implementación de TFLite existentes, incluida la cuantificación y el tiempo de ejecución

  • Funciona con modelos como TinyLlama, Phi-2 y Gemma 2B

  • Admite las interfaces de tiempo de ejecución de TFLite y MediaPipe LLM, y tiene compatibilidad con Android, iOS y web

En esta entrada de blog, profundizaremos en el rendimiento, la portabilidad, la experiencia de autoría del desarrollador, la canalización de inferencia de extremo a extremo y el conjunto de herramientas de depuración. Puedes encontrar más documentación y ejemplos aquí.


Rendimiento

Como parte de nuestro trabajo para que algunos de los LLM más populares funcionen sin problemas mediante la API de inferencia MediaPipe LLM, nuestro equipo creó varios transformadores completamente escritos a mano con un rendimiento de última generación integrado en el dispositivo (blog sobre la API de inferencia MediaPipe LLM). De este trabajo, surgieron algunos temas: cómo representar la atención de manera efectiva, el uso de la cuantificación y la importancia de una buena representación de la caché KV. La API generativa hace que cada uno de estos sea fácil de expresar (como veremos en la siguiente sección), al tiempo que alcanza un rendimiento equivalente al >90% de nuestras versiones escritas a mano con una velocidad de desarrollo mucho mejor.

La siguiente tabla muestra puntos de referencia clave en tres ejemplos de modelos:

On device performance benchmarks across TinyLlama, Gemma 2B and Phi-2 models for Samsung S23 and Pixel 8 Pro

Estos se comparan en funciones básicas, con cuatro subprocesos de CPU, y son las implementaciones de CPU más rápidas de estos modelos que conocemos actualmente en los dispositivos enumerados.


Experiencia de autoría

La biblioteca de autoría principal proporciona componentes básicos para modelos de transformadores comunes (solo codificador, solo decodificador, estilo codificador-decodificador, etc.). Te permite crear un modelo desde cero o reelaborar un modelo existente para mejorar el rendimiento. Recomendamos a la mayoría de los usuarios que reelaboren un modelo, ya que no requiere entrenamiento ni pasos de ajuste. Los beneficios clave de la creación de API generativas incluyen lo siguiente:

  • Un conjunto de componentes de transformadores principales optimizados para la convertibilidad, el rendimiento y la portabilidad de la plataforma que son fáciles de combinar con las operaciones normales de PyTorch.

  • Un mecanismo fácil de reasignación de pesos.

  • APIs de cuantificación intuitivas.

  • Exportación de firmas múltiples con firmas precargadas, decodificadas o personalizadas, y funciona a la perfección con las tareas de MP previamente empaquetadas o las APIs de inferencia LLM.

Como ejemplo, aquí mostramos cómo reelaborar la funcionalidad principal de TinyLlama (1.1B) con alrededor de 50 líneas de Python con la nueva API generativa.

Paso 1: definir la estructura del modelo

import torch
import torch.nn as nn
 
from ai_edge_torch.generative.layers.attention import TransformerBlock
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.builder as builder
import ai_edge_torch.generative.layers.model_config as cfg
 
 
class TinyLLamma(nn.Module):
 
  def __init__(self, config: cfg.ModelConfig):
    super().__init__()
 
    self.config = config
    # Construct model layers.
    self.lm_head = nn.Linear(
        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
    )
    self.tok_embedding = nn.Embedding(
        config.vocab_size, config.embedding_dim, padding_idx=0
    )
    self.transformer_blocks = nn.ModuleList(
        TransformerBlock(config) for _ in range(config.num_layers)
    )
    self.final_norm = builder.build_norm(
        config.embedding_dim,
        config.final_norm_config,
    )
    self.rope_cache = attn_utils.build_rope_cache(
        size=config.kv_cache_max,
        dim=int(config.attn_config.rotary_percentage * config.head_dim),
        base=10_000,
        condense_ratio=1,
        dtype=torch.float32,
        device=torch.device("cpu"),
    )
    self.mask_cache = attn_utils.build_causal_mask_cache(
        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
    )
    self.config = config

Paso 2: definir la función directa del modelo

@torch.inference_mode
  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    B, T = idx.size()
    cos, sin = self.rope_cache
    cos = cos.index_select(0, input_pos)
    sin = sin.index_select(0, input_pos)
    mask = self.mask_cache.index_select(2, input_pos)
    mask = mask[:, :, :, : self.config.kv_cache_max]
 
    # forward the model itself
    x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)
 
    for i, block in enumerate(self.transformer_blocks):
      x = block(x, (cos, sin), mask, input_pos)
 
    x = self.final_norm(x)
    res = self.lm_head(x)  # (b, t, vocab_size)
    return res

Paso 3: asignar los pesos del modelo anterior

Con la biblioteca, puedes asignar pesos de manera fácil con las APIs ModelLoader, por ejemplo:

import ai_edge_torch.generative.utilities.loader as loading_utils
 
 
# This map will associate old tensor names with the new model.
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="model.layers.{}.mlp.up_proj",
    ff_down_proj="model.layers.{}.mlp.down_proj",
    ff_gate_proj="model.layers.{}.mlp.gate_proj",
    attn_query_proj="model.layers.{}.self_attn.q_proj",
    attn_key_proj="model.layers.{}.self_attn.k_proj",
    attn_value_proj="model.layers.{}.self_attn.v_proj",
    attn_output_proj="model.layers.{}.self_attn.o_proj",
    pre_attn_norm="model.layers.{}.input_layernorm",
    pre_ff_norm="model.layers.{}.post_attention_layernorm",
    embedding="model.embed_tokens",
    final_norm="model.norm",
    lm_head="lm_head",
)

Después de que se realicen esos pasos, puedes ejecutar algunas entradas de muestra para verificar la exactitud numérica (consulta el vínculo) del modelo reelaborado. Si pasa la verificación numérica, puedes continuar con el paso de conversión y cuantificación.


Conversión & Cuantificación

Con las APIs de conversión proporcionadas por ai_edge_torch, puedes aprovechar la misma API para convertir modelos de transformadores (reelaborados) a un modelo de TensorFlow Lite altamente optimizado. El proceso de conversión contiene los siguientes pasos clave:

1) Exportar a StableHLO. El modelo de PyTorch es rastreado y compilado en un gráfico FX con operaciones ATen por el compilador TorchDynamo, y luego llevado a un gráfico de StableHLO por ai_edge_torch.

2) ai_edge_torch ejecuta más pases de compilador en StableHLO, incluida la fusión o el plegado de operaciones, etc., y genera un FlatBuffer de TFLite de alto rendimiento (con operaciones fusionadas para el SDPA y la caché KV).


Cuantificación

La biblioteca principal de la API generativa también proporciona un conjunto de API de cuantificación que abarca recetas de cuantificación de LLM comunes. La receta pasa un parámetro adicional a la API del conversor de ai_edge_torch, que cubre automáticamente la cuantificación. En futuras versiones, esperamos ampliar el conjunto de modos de cuantificación disponibles.


Exportación de firmas múltiples

Identificamos que en escenarios de inferencia reales, los modelos de LLM deben tener funciones de inferencia (prellenado, decodificación) claramente separadas (desagregadas) para lograr el mejor rendimiento de servicio. Esto se basa, en parte, en la observación de que el prellenado y la decodificación pueden tomar diferentes formas de tensores; el prellenado está limitado por el cálculo, mientras que la decodificación está limitada por la memoria. Para LLM grandes, es fundamental evitar duplicar los pesos del modelo entre el prellenado y la decodificación. Logramos esto con la función de firmas múltiples existente en TFLite y ai_edge_torch, que te permite definir fácilmente varios puntos de entrada para el modelo como se muestra a continuación.

def convert_tiny_llama_to_tflite(
    prefill_seq_len: int = 512,
    kv_cache_max_len: int = 1024,
    quantize: bool = True,
):
  pytorch_model = tiny_llama.build_model(kv_cache_max_len=kv_cache_max_len)
 
  # Tensors used to trace the model graph during conversion.
  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
  prefill_input_pos = torch.arange(0, prefill_seq_len)
  decode_token = torch.tensor([[0]], dtype=torch.long)
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
 
  # Set up Quantization for model.
  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
 
  edge_model = (
      ai_edge_torch.signature(
          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
      )
      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
      .convert(quant_config=quant_config)
  )
  edge_model.export(f'/tmp/tiny_llama_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite')

Optimizaciones de rendimiento específicas de LLM

Durante nuestra fase de investigación del rendimiento, encontramos algunos aspectos fundamentales para mejorar el rendimiento de LLM:

1) SDPA y caché KV de alto rendimiento: encontramos que sin suficientes optimizaciones/fusiones de compiladores, el modelo de TFLite convertido no tendrá un gran rendimiento, dadas las operaciones granulares en estas funciones. Para abordar esto, presentamos el límite de funciones de alto nivel y las operaciones compuestas de StableHLO.

2) Aprovechar el delegado XNNPack de TFLite para acelerar aún más el SDPA: es fundamental garantizar que los cálculos pesados de MatMul/Matrix-vector estén bien optimizados. La biblioteca XNNPack tiene un excelente rendimiento para estas primitivas en una amplia gama de CPU móviles.

3) Evitar cálculos innecesarios: los modelos de forma estática pueden inducir más cálculos de los que se requieren mínimamente si los modelos tienen un tamaño de mensaje de entrada fijo y largo en la etapa de prellenado o una longitud de secuencia fija y extensa en la etapa de decodificación.

4) Consumo de memoria en tiempo de ejecución. Introdujimos un mecanismo de almacenamiento en caché/preempaquetado de pesos en el delegado XNNPack de TFLite para reducir significativamente el uso máximo de memoria.


Implementación

La inferencia de LLM generalmente implica muchos pasos previos/posteriores al procesamiento y una orquestación sofisticada, por ejemplo, tokenización, muestreo y lógica de decodificación autorregresiva. Con este fin, brindamos tanto las soluciones basadas en MediaPipe como un ejemplo de inferencia pura de C++.


Uso de la API de inferencia MediaPipe LLM

La API de inferencia MediaPipe LLM es una API de alto nivel que admite la inferencia de LLM mediante una interfaz de entrada/salida de indicaciones. Se encarga de toda la complejidad que conlleva implementar la canalización de LLM en un nivel subyacente y hace que esto sea mucho más fácil y fluido. Para realizar una implementación utilizando la API de inferencia MP LLM, debes asegurarte de convertir los modelos utilizando las firmas de prellenado y decodificación esperadas, y crear un conjunto como se muestra en el siguiente código:

def bundle_tinyllama_q8():
  output_file = "PATH/tinyllama_q8_seq1024_kv1280.task"
  tflite_model = "PATH/tinyllama_prefill_decode_hlfb_quant.tflite"
  tokenizer_model = "PATH/tokenizer.model"
  config = llm_bundler.BundleConfig(
      tflite_model=tflite_model,
      tokenizer_model=tokenizer_model,
      start_token="<s>",
      stop_tokens=["</s>"],
      output_filename=output_file,
      enable_bytes_to_unicode_mapping=False,
  )
  llm_bundler.create_bundle(config)

Inferencia pura de C++ a través del tiempo de ejecución de TFLite

También te proporcionamos un ejemplo de C++ fácil de usar (sin dependencia de MediaPipe) para mostrar cómo ejecutar un ejemplo de generación de texto de extremo a extremo. Los desarrolladores pueden usar este ejemplo como punto de partida para integrar los modelos exportados con sus canalizaciones y requisitos de producción únicos, lo que permite una mejor personalización y flexibilidad.


Compatibilidad multiplataforma

Dado que el tiempo de ejecución de la inferencia principal está en TFLite, toda la canalización se puede integrar fácilmente en tus apps para iOS o Android (incluso en Google Play) sin ninguna modificación. Esto garantizará que los modelos convertidos desde la nueva API generativa se puedan implementar de inmediato con solo agregar algunas dependencias de operaciones personalizadas. En futuras versiones, ofreceremos compatibilidad con GPU para iOS & y Android, y también para aceleradores de AA de destino (TPU, NPU).


Herramientas

El recientemente anunciado Explorador de modelos es una herramienta útil para visualizar modelos grandes como Gemma 2B. La vista jerárquica y la comparación facilitan la visualización de las versiones del modelo originales/reelaboradas/convertidas. Para obtener más detalles sobre esto y conocer cómo puedes visualizar la información de referencia a fin de ajustar el rendimiento, consulta esta entrada de blog.

A continuación se muestra un ejemplo de cómo usamos esto al crear el modelo TinyLlama de PyTorch, que muestra el modelo export() de PyTorch junto con el modelo de TFLite. Con Explorador de modelos, podemos comparar fácilmente cómo se expresa cada capa (p. ej., RMSNorms, SelfAttention).

Una comparación entre TinyLlama de PyTorch y TFLite convertido

Resumen & y lo que viene

La API generativa de AI Edge Torch es una buena socia de los modelos optimizados precompilados disponibles en la API de inferencia MediaPipe LLM para desarrolladores que desean habilitar sus propios modelos de IA generativa en dispositivos. En los próximos meses, saldrán nuevas actualizaciones que incluyen compatibilidad web, cuantificación mejorada y compatibilidad de cálculo ampliada que va más allá de la CPU. También estamos interesados en explorar una integración de marcos aún mejor.

Esta es una versión preliminar anticipada de la biblioteca, que se encuentra en una etapa experimental con el objetivo de interactuar con la comunidad de desarrolladores. Puedes esperar cambios en las APIs, algunas imperfecciones y una compatibilidad limitada para la cuantificación y los modelos. Sin embargo, ya hay mucho con lo que empezar en nuestro repositorio de GitHub: entra y no dudes en compartir PR, problemas y solicitudes de funciones.


En la tercera parte de esta serie, echaremos un vistazo más profundo a la herramienta de visualización Explorador de modelos, que permite a los desarrolladores visualizar, depurar y explorar modelos.



Agradecimientos

Este trabajo es una colaboración entre varios equipos funcionales de Google. Nos gustaría agradecer a todos los miembros del equipo que contribuyeron a este trabajo: Aaron Karp, Advait Jain, Akshat Sharma, Alan Kelly, Andrei Kulik, Arian Afaian, Chun-nien Chan, Chuo-Ling Chang, Cormac Brick, Eric Yang, Frank Barchard, Gunhyun Park, Han Qi, Haoliang Zhang, Ho Ko, Jing Jin, Joe Zoe, Juhyun Lee, Kevin Gleason, Khanh LeViet, Kris Tonthat, Kristen Wright, Lin Chen, Linkun Chen, Lu Wang, Majid Dadashi, Manfei Bai, Mark Sherwood, Matthew Soulanille, Matthias Grundmann, Maxime Brénon, Michael Levesque-Dion, Mig Gerard, Milen Ferev, Mohammadreza Heydary, Na Li, Paul Ruiz, Pauline Sho, Pei Zhang, Ping Yu, Pulkit Bhuwalka, Quentin Khan, Ram Iyengar, Renjie Wu, Rocky Rhodes, Sachin Kotwani, Sandeep Dasgupta, Sebastian Schmidt, Siyuan Liu, Steven Toribio, Suleman Shahid, Tenghui Zhu, T.J. Alumbaugh, Tyler Mullen, Weiyi Wang, Wonjoo Lee, Yi-Chun Kuo, Yishuang Pang, Yu-hui Chen, Zoe Wang y Zichuan Wei.