AI Edge Torch: inferencia de alto rendimiento de modelos PyTorch en dispositivos móviles

MAY 14, 2024
Cormac Brick Principal Engineer
Advait Jain Software Engineer
Haoliang Zhang Software Engineer

Nos complace anunciar Google AI Edge Torch, una ruta directa desde PyTorch hasta el tiempo de ejecución de TensorFlow Lite (TFLite) con una excelente cobertura del modelo y rendimiento de la CPU. TFLite ya funciona con modelos escritos en Jax, Keras y TensorFlow, y ahora agregamos PyTorch como parte de un compromiso más amplio con la opcionalidad del framework.

Esta nueva oferta ahora está disponible como parte de Google AI Edge, un paquete de herramientas con fácil acceso a tareas de ML listas para usar, frameworks que te permiten crear canales de ML y ejecutar LLM populares y modelos personalizados, todo en el dispositivo. Esta es la primera de una serie de entradas de blog que cubren los lanzamientos de Google AI Edge que ayudarán a los desarrolladores a crear funciones habilitadas para IA y a implementarlas fácilmente en múltiples plataformas.

AI Edge Torch se lanza hoy en versión beta con:

  • Integración directa con PyTorch
  • Excelente rendimiento de la CPU y compatibilidad inicial con la GPU
  • Validación en más de 70 modelos de torchvision, timm, torchaudio y HuggingFace
  • Compatibilidad para el > 70% de los operadores core_aten en PyTorch
  • Compatibilidad con el tiempo de ejecución TFLite existente, sin necesidad de cambiar el código de implementación
  • Compatibilidad para la visualización de Model Explorer en múltiples etapas del flujo de trabajo


Una experiencia simple centrada en PyTorch

Google AI Edge Torch se creó desde cero para proporcionar una gran experiencia a la comunidad de PyTorch, con API que se sienten nativas y proporcionan una ruta de conversión fácil.

import torchvision
import ai_edge_torch
 
# Inicializar modelo
resnet18 = torchvision.models.resnet18().eval()
 
# Convertir
sample_input = (torch.randn(4, 3, 224, 224),)
edge_model = ai_edge_torch.convert(resnet18, sample_input)
 
# Inferencia en Python
output = edge_model(*sample_input)
 
# Exportar a un modelo TFLite para la implementación en el dispositivo
edge_model.export('resnet.tflite'))

Detrás de escena, ai_edge_torch.convert() se integra con TorchDynamo utilizando torch.export, que es la forma de PyTorch 2.x de exportar modelos PyTorch a representaciones de modelos estandarizados destinados a ejecutarse en diferentes entornos. Nuestra implementación actual admite más del 60% de los operadores core_aten, que planeamos aumentar significativamente a medida que avanzamos hacia una versión 1.0 de ai_edge_torch. Incluimos ejemplos que muestran la cuantificación PT2E, el enfoque de cuantificación nativo de PyTorch2, para permitir flujos de trabajo de cuantificación fáciles. Nos entusiasma escuchar a la comunidad de PyTorch para encontrar formas de mejorar la experiencia del desarrollador al llevar la innovación que comienza en PyTorch a un amplio conjunto de dispositivos.


Cobertura y rendimiento

Antes de esta versión, muchos desarrolladores utilizaban rutas proporcionadas por la comunidad, como ONNX2TF, para habilitar los modelos PyTorch en TFLite. Nuestro objetivo al desarrollar AI Edge Torch era reducir la fricción del desarrollador, proporcionar una excelente cobertura del modelo y continuar con nuestra misión de ofrecer el mejor rendimiento de su clase en dispositivos Android.

En cuanto a la cobertura, nuestras pruebas demuestran mejoras significativas sobre el conjunto definido de modelos sobre los flujos de trabajo existentes, particularmente ONNX2TF

Table showing performance improvement in existing workflows over defined set models

En cuanto al rendimiento, nuestras pruebas muestran un rendimiento coherente con la línea de base de ONNX2TF, al tiempo que muestran un rendimiento significativamente mejor que el tiempo de ejecución de ONNX:

Table showing performance with ONNX2TF baseline

Aquí se muestra el rendimiento detallado por modelo en el subconjunto de los modelos cubiertos por ONNX:

Chart showing per model TFLite latency relative to ONNX
Figura: Latencia de inferencia por red en comparación con ONNX, medida en Pixel8, precisión fp32, XNNPACK fijado en 4 subprocesos para ayudar a la reproducibilidad, promedio de 100 ejecuciones después de 20 iteraciones de preparación

Adopción anticipada y asociaciones

En los últimos meses, trabajamos en estrecha colaboración con los primeros socios de adopción, incluidos Shopify, Adobe y Niantic, para mejorar nuestro soporte de PyTorch. ai_edge_torch ya está siendo utilizado por el equipo de Shopify para realizar la eliminación en segundo plano en el dispositivo para las imágenes de los productos y estará disponible en una próxima versión de la aplicación Shopify.

Quote image with text reads "Converting PyTorch models to run locally on Android was complex. Google's new tools simplify this, enabling fast creation of mobile-ready PyTorch models - Mustapha Ali, Shopify, Director of Engineering

Asociaciones y delegados de silicio

También trabajamos en estrecha colaboración con socios para trabajar en el soporte de hardware en CPU, GPU y aceleradores, que incluye ARM, Google Tensor G3, MediaTek, Qualcomm y Samsung System LSI. A través de estas asociaciones, mejoramos el rendimiento y la cobertura, y validamos los archivos TFLite generados por PyTorch en los delegados del acelerador.

También estamos encantados de anunciar conjuntamente el nuevo delegado TensorFlow Lite de Qualcomm, que ahora está disponible abiertamente aquí para que lo use cualquier desarrollador. Los delegados TFLite son módulos de software complementarios que ayudan a acelerar la ejecución en GPU y aceleradores de hardware. Este nuevo delegado de QNN admite la mayoría de los modelos en nuestro conjunto de pruebas PyTorch Beta, al tiempo que proporciona compatibilidad para un amplio conjunto de silicio de Qualcomm, y proporciona aceleraciones promedio significativas en relación con la CPU(20x) y la GPU(5x) utilizando las unidades de procesamiento neuronal y DSP de Qualcomm. Para que sea más fácil de probar, Qualcomm también lanzó recientemente su nuevo AI Hub. Qualcomm AI Hub es un servicio en la nube que permite a los desarrolladores probar modelos TFLite contra un amplio grupo de dispositivos Android y proporciona visibilidad de las ganancias de rendimiento disponibles en diferentes dispositivos utilizando el delegado QNN.


Lo que viene

En los próximos meses continuaremos iterando de forma abierta, con lanzamientos que amplían la cobertura del modelo, mejoran el soporte de la GPU y habilitan nuevos modos de cuantificación a medida que creamos una versión 1.0. En la parte 2 de esta serie, echaremos un vistazo más profundo a la API generativa de AI Edge Torch, que permite a los desarrolladores llevar modelos GenAI personalizados al límite con un gran rendimiento.

Nos gustaría agradecer a todos nuestros clientes de acceso anticipado por sus valiosos comentarios, que nos ayudaron a detectar los primeros errores y garantizar una experiencia fluida para los desarrolladores. También nos gustaría agradecer a los socios de hardware y a los colaboradores del ecosistema de XNNPACK que nos han ayudado a mejorar el rendimiento en una amplia gama de dispositivos. También nos gustaría agradecer a la comunidad de PyTorch en general por su orientación y apoyo.



Agradecimientos

Nos gustaría agradecer a todos los miembros del equipo que colaboraron con este trabajo: Aaron Karp, Advait Jain, Akshat Sharma, Alan Kelly, Arian Arfaian, Chun-nien Chan, Chuo-Ling Chang, Claudio Basille, Cormac Brick, Dwarak Rajagopal, Eric Yang, Gunhyun Park, Han Qi, Haoliang Zhang, Jing Jin, Juhyun Lee, Jun Jiang, Kevin Gleason, Khanh LeViet, Kris Tonthat, Kristen Wright, Lu Wang, Luke Boyer, Majid Dadashi, Maria Lyubimtseva, Mark Sherwood, Matthew Soulanille, Matthias Grundmann, Meghna Johar, Milad Mohammadi, Na Li, Paul Ruiz, Pauline Sho, Ping Yu, Pulkit Bhuwalka, Ram Iyengar, Sachin Kotwani, Sandeep Dasgupta, Sharbani Roy, Shauheen Zahirazami, Siyuan Liu, Vamsi Manchala, Vitalii Dziuba, Weiyi Wang, Wonjoo Lee, Yishuang Pang, Zoe Wang y el equipo de StableHLO.