We are excited to announce Google AI Edge Torch, a direct path from PyTorch to the TensorFlow Lite (TFLite) runtime with great model coverage and CPU performance. TFLite already works with models written in JAX, Keras, and TensorFlow, and we are now adding PyTorch as part of a wider commitment to framework optionality.

This new offering is available as part of Google AI Edge, a suite of tools that gives you easy access to ready-to-use ML tasks, frameworks for building ML pipelines, and the ability to run popular LLMs and custom models, all on-device. This is the first in a series of blog posts covering Google AI Edge releases that will help developers build AI-enabled features and easily deploy them on multiple platforms.

AI Edge Torch is released in Beta today, featuring:

Direct PyTorch integration

Excellent CPU performance and initial GPU support

Validated on over 70 models from torchvision, timm, torchaudio and HuggingFace

Support for > 70% of core_aten operators in PyTorch

Compatibility with existing TFLite runtime, with no change to deployment code needed

Support for Model Explorer visualization at multiple stages of the workflow

A simple, PyTorch-centric experience

Google AI Edge Torch was built from the ground up to provide a great experience to the PyTorch community, with APIs that feel native, and provide an easy conversion path.

    import torch
    import torchvision
    import ai_edge_torch

    # Initialize the model in eval mode
    resnet18 = torchvision.models.resnet18().eval()

    # Convert, using a sample input tuple to trace the model
    sample_input = (torch.randn(4, 3, 224, 224),)
    edge_model = ai_edge_torch.convert(resnet18, sample_input)

    # Inference in Python
    output = edge_model(*sample_input)

    # Export to a TFLite model for on-device deployment
    edge_model.export('resnet.tflite')

Under the hood, ai_edge_torch.convert() is integrated with TorchDynamo using torch.export, the PyTorch 2.x way to export PyTorch models into standardized model representations intended to run in different environments. Our current implementation supports more than 60% of core_aten operators, which we plan to increase significantly as we build towards a 1.0 release of ai_edge_torch. We’ve included examples showing PT2E quantization, the quantization approach native to PyTorch 2, to enable easy quantization workflows. We’re excited to hear from the PyTorch community to find ways to improve the developer experience when bringing innovation that starts in PyTorch to a wide set of devices.

Coverage & Performance

Prior to this release, many developers were using community-provided paths such as ONNX2TF to enable PyTorch models on TFLite. Our goal in developing AI Edge Torch was to reduce developer friction, provide great model coverage, and continue our mission of delivering best-in-class performance on Android devices.

On coverage, our tests on the defined set of models demonstrate significant improvements over existing workflows, particularly ONNX2TF.

On performance, our tests show performance consistent with the ONNX2TF baseline, and meaningfully better performance than the ONNX runtime:

The following figure shows detailed per-model performance on the subset of models covered by ONNX:

Figure: Inference latency per network compared to ONNX, measured on Pixel 8 at fp32 precision, with XNNPACK fixed to 4 threads to aid reproducibility; average of 100 runs after a 20-iteration warm-up.
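The measurement protocol in the caption (untimed warm-up iterations followed by an average over timed runs) can be sketched in a few lines of Python. This is a minimal illustration of the protocol only, not the harness used to produce the figure. The benchmark function and the stand-in workload are assumptions, and a real measurement would call the on-device model with the thread count pinned.

```python
import statistics
import time
from typing import Callable


def benchmark(fn: Callable[[], object], warmup: int = 20, runs: int = 100) -> float:
    """Return the mean latency of `fn` in milliseconds, excluding warm-up calls."""
    # Warm-up calls are executed but not timed, so one-time costs such as
    # cache population and lazy initialization do not skew the average.
    for _ in range(warmup):
        fn()

    latencies_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        latencies_ms.append((time.perf_counter() - start) * 1000.0)
    return statistics.mean(latencies_ms)


if __name__ == "__main__":
    # Stand-in workload (hypothetical); replace with a model invocation.
    mean_ms = benchmark(lambda: sum(i * i for i in range(10_000)))
    print(f"mean latency: {mean_ms:.3f} ms")
```

Fixing the warm-up and run counts, as the caption does, keeps results comparable across models and runtimes.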

Early Adoption and Partnerships

In the last few months, we have worked closely with early adoption partners including Shopify, Adobe, and Niantic to improve our PyTorch support. ai_edge_torch is already being used by the team at Shopify to perform on-device background removal for product images and will be available in an upcoming release of the Shopify app.