使用 Grain 和 ArrayRecord 构建高性能数据管道

2025年10月3日
Jiyang Kang Technical Program Manager
Shivaji Dutta Field Solutions Architect
Ihor Indyk Software Engineer
Felix Chern Software Engineer

在 GPU 和 TPU 等功能强大的加速器上训练大型模型时,人们最不希望看到的情况就是加速器因等待数据而处于空闲状态。整个系统的速度取决于最慢的环节,而往往这个瓶颈正是数据输入管道。因此,对于大规模机器学习而言,高效且可复现的数据管道至关重要。本指南将向您展示如何通过使用 Grain(一个灵活的 JAX 数据加载库)和 ArrayRecord(一种高效文件格式)来构建稳健的高性能数据管道,从而解决这一挑战。


了解核心组件

Grain:用于 JAX 的高性能数据加载器

Grain 是一个轻量级的开源数据加载库,专为基于 JAX 的工作负载而设计,能够有效解决这一问题。它可确保数据高效地完成加载、预处理并输送至模型,从而最大限度发挥硬件性能。

为何选择 Grain?

Grain 基于性能、可复现性与灵活性三大核心理念构建,其主要优势包括:

  • Exceptional performance: Grain is built for speed. It uses efficient multiprocessing (via the .mp_prefetch() method) to run data loading and transformations in parallel, ensuring that a buffer of prepared data is always ready for your model. This keeps your accelerators saturated and minimizes training time.

  • Guaranteed determinism and reproducibility: Grain provides full reproducibility, which is critical for credible research. By setting a simple seed, you ensure the data is always shuffled the same way. Crucially, its data iterators are stateful and can be checkpointed. This means if your training job is interrupted or preempted, you can restart from the exact same point in the data stream.

  • 直观的声明式 API:您可以通过将简洁易读的方法链接在一起来定义数据管道。从 MapDataset 源开始,您可以流畅地添加诸如 .shuffle().map().batch() 等转换操作。这种声明式样式使数据管道易于理解、修改和维护。

  • Unlocking true global shuffling: To get the best performance from your models, you need to shuffle your data effectively. When paired with a file format that supports random access, like ArrayRecord, Grain can perform a true global shuffle across your entire dataset, even when it doesn’t fit into host memory. This is a powerful feature that is often computationally impractical with other data loaders and formats.


ArrayRecord 是什么?为何要使用它?

尽管 TFRecord 是广为熟知的标准格式,但其顺序性质无法实现真正的全局打乱。ArrayRecord 正是为解决这一痛点而设计的现代文件格式,为数据效率开启了全新境界。

ArrayRecord File Layout

运作原理:专为速度和并行性而设计

ArrayRecord 的高性能植根于其核心设计,该设计基于 Google 的 Riegeli 文件格式。这种结构为大规模数据处理提供了若干关键优势:

  1. Efficient random access: ArrayRecord features a built-in metadata index that maps every record to its precise location. This is the key design choice that enables instant, direct access to any record in the dataset, completely avoiding the need to read a file from the beginning.


2. Massive parallelism: Records are grouped into data chunks. This structure is inherently designed to be read in parallel, allowing multiple processes to read different chunks of the same file simultaneously to dramatically increase read throughput.


3. Exceptional performance: As a result of this indexed and chunked design, benchmarks show ArrayRecord can achieve a read throughput an order of magnitude higher than traditional formats, making it ideal for today's massive datasets.


4. Smart data integrity: The format handles data integrity intelligently by leveraging the powerful error correction in underlying cloud storage systems (like Google Cloud Storage) rather than adding redundant checks. This provides robust protection against corruption without unnecessary performance overhead.


为什么要使用?

ArrayRecord 的特性直接实现了 Grain 等现代数据加载器所需的高级功能。

The most important benefit is achieving true, deterministic global shuffling. Because any record can be accessed instantly, a data loader can generate perfectly randomized indices in the dataset on the fly as the training happens and then fetch data in that specific order. This capability, which is computationally impractical with sequential formats like TFRecord, is vital for reproducible research and optimal model training.


ArrayRecord 与 TFRecord 的详细对比

以下是 ArrayRecord 与 TFRecord 在关键特性上的详细对比分析:

  1. 结构

  • ArrayRecord 基于 Google 的 Riegeli 文件格式构建,该格式专为存储记录序列而设计,重点关注高速解码、数据完整性强压缩率。它将记录分组为数据块,并在文件末尾嵌入元数据索引。

  • TFRecord二进制记录序列,其中每条记录通常为 tf.train.Example 协议缓冲区。


2. 随机访问

  • ArrayRecord 提供原生高效的随机访问能力。其文件结构包括内置的记录位置索引,支持通过索引值直接快速访问任意记录,而无需读取整个文件。

  • 反之,TFRecord 缺乏原生随机访问支持。作为针对流式数据优化的顺序格式,访问特定记录需要从头开始遍历文件。


3. 全局打乱

  • 使用 ArrayRecord 可实现真正的全局打乱。凭借其高效的随机访问特性,Grain 等数据加载器可生成乱序索引序列并实现动态读取。

  • TFRecord 难以实现真正的全局打乱。其所谓的“全局”打乱通常依赖于近似方案,例如先打乱分片文件名列表,再在小型内存缓冲区内部打乱记录。这并非真正的全局打乱。


4. 并行 I/O

  • ArrayRecord 原生支持并行 I/O,其内部数据块结构使多个进程能轻松并行读取同一文件的不同部分,极大简化数据管理。

  • TFRecord 虽支持并行读取,但通常需要将数据集分割为大量小型 TFRecord 文件,通过不同工作器读取不同文件实现。这可能会导致需要管理大量文件。


5. 集成

  • ArrayRecord专为高性能 I/O 而设计,可与 Grain 等基于 JAX 的加载器无缝配合使用,同时也能通过 tfds.data_source 在 TensorFlow 生态系统中使用。

  • TFRecord 则与 TensorFlow 的 tf.data 生态系统紧密集成。


6. 主要用例

  • ArrayRecord is ideal for high-throughput data loading for performance-critical machine learning, especially where determinism and true global shuffling are required (e.g., JAX/TPU workloads).

  • TFRecord is suited for general-purpose, large-scale data storage for TensorFlow and is optimized for sequential reads.


如何将 TFRecord 数据集转换为 ArrayRecord

转换数据集的方法取决于它是 TensorFlow 数据集 (TFDS) 目录中的标准注册数据集还是自定义的专有数据集。

方法 1:针对 TFDS 目录中的标准数据集

如果您使用的是 cifar10imagenet2012 等知名数据集,tfds 命令行工具是最直接的转换方法。

先决条件:安装 TensorFlow 数据集

pip install -q --upgrade tfds-nightly
Shell

使用 tfds build CLI

该命令会执行三大操作:下载源数据、运行预处理逻辑,以及将输出结果保存为指定格式。

# Generate the 'cifar10' dataset in ArrayRecord format
tfds build cifar10 --file_format=array_record
Shell

生成的 ArrayRecord 文件将存储于 ~/tensorflow_datasets/ 目录中,立即可用。

方法 2:针对自定义或专有的 TFRecord 数据集

对于专属自定义 TFRecord 数据集的大规模转换,建议使用 Apache Beam 框架。array_record 库提供的预制 Beam 流水线能使转换过程变得异常简单且可扩展。该方法特别适用于海量数据集处理,因其可通过 Google Cloud Dataflow 等服务将任务分布于多个工作器中。

先决条件:安装 Apache Beam 和 Array Record Beam SDK

pip install -q apache-beam
pip install -q array-record-beam-sdk
Shell

使用预制的转换流水线

array_record.beam.pipelines 模块包含专为转换设计的 convert_tf_to_arrayrecord_disk_match_shards 函数,该实用程序能完整处理转换流程:读取 TFRecord 文件并写入对应的分片式 ArrayRecord 数据集。

在 Python 脚本中使用该程序的方法如下:

from apache_beam.options import pipeline_options
from array_record.beam.pipelines import convert_tf_to_arrayrecord_disk_match_shards
 
# 1. Define your input and output patterns.
# This example uses Google Cloud Storage (GCS) paths, which is common for large datasets.
input_pattern = 'gs://your-gcs-bucket/path/to/records-*.tfrecord'
output_path = 'gs://your-gcs-bucket/path/to/converted-records'
 
# Arguments dictionary for the conversion function.
args = {
    'input': input_pattern,
    'output': output_path,
}
 
# 2. Configure pipeline options for execution.
 
# To run locally on your machine (for smaller datasets or testing):
# No options are needed; the local runner is used by default.
local_pipeline_options = pipeline_options.PipelineOptions()
 
 
# To run at scale on Google Cloud Dataflow (for large datasets):
# Uncomment the following lines and fill in your project details.
#
# dataflow_pipeline_options = pipeline_options.PipelineOptions(
#     runner='DataflowRunner',
#     project='your-gcp-project-id',
#     region='your-gcp-region',
#     # A requirements.txt file may be needed for dependencies on Dataflow workers.
#     # requirements_file='requirements.txt',
#     temp_location='gs://your-gcs-bucket/path/to/temp'
# )
 
 
# 3. Define and run the main execution logic.
def main():
  print("Starting TFRecord to ArrayRecord conversion...")
  convert_tf_to_arrayrecord_disk_match_shards(
      args=args,
      # Pass the appropriate options here.
      # Use `local_pipeline_options` for local runs.
      # Use `dataflow_pipeline_options` for cloud runs.
      pipeline_options=local_pipeline_options,
  ).run()
  print(f"Conversion complete. ArrayRecord files written to '{output_path}'.")
 
if __name__ == '__main__':
  main()
Python

该方法比手动编写流水线更强大稳健,因为它是专门为此任务设计的经过验证的高阶 API,能自动处理诸如输出分片与输入分片匹配等技术细节。


构建 Grain 与 ArrayRecord 流水线:概念性演示

当数据转换为 ArrayRecord 格式后,您即可使用 Grain 的 Dataset API 定义高性能输入流水线。该流程包含两个步骤:创建数据源,然后链接转换方法。

第 1 步:从数据源创建 MapDataset

首先指向 ArrayRecord 文件以创建 MapDataset

import grain 
 
# Path to your generated ArrayRecord files
file_paths = ["~/tensorflow_datasets/cifar10/3.0.2/cifar10-train.array_record-00000-of-00001"]
 
# Create a data source
data_source = grain.sources.ArrayRecordDataSource(file_paths)
 
# Create a MapDataset from the source
dataset = grain.MapDataset.source(data_source)
Python

第 2 步:链接转换(打乱、映射、批处理)

现在对 MapDataset 应用转换操作。每个方法都会返回新的 MapDataset,以便您以声明方式将调用链接在一起。

# Example parsing function
def parse_and_transform(record):
    # Your logic to parse features, augment data, etc.
    return {"record": record}
 
BATCH_SIZE = 32
 
# Chain transformations
# The order of operations matters.
dataset = (
    dataset.shuffle(seed=42)
           .map(parse_and_transform)
           .batch(batch_size=BATCH_SIZE, drop_remainder=True)
)
Python

Step 3: Create and use the DatasetIterator

最后,从完整定义的数据集中创建迭代器,以在训练脚本中循环遍历数据。

# Create the stateful iterator
data_iterator = iter(dataset)
 
# You can now loop over the data
for batch in data_iterator:
    # Your training step with the batch...
    pass
 
    # For checkpoint saving/restoration, you can get/set the iterator's state
    # state = data_iterator.get_state()
    # data_iterator.set_state(state)
Python

性能配置设置

为避免数据管道成为性能瓶颈,建议使用多进程技术来加载和预处理数据,同时并行进行模型训练。在 Dataset API 中,只需在流水线末端添加 .mp_prefetch() 转换方法即可实现该功能。

该方法会启动工作器进程池,在后台异步预处理数据批次并存储至缓冲区,确保训练循环需要时可立即取用。

应用方法如下所示:

# The full pipeline with performance tuning.
dataset = (
    grain.MapDataset.source(data_source)
    .shuffle(seed=42)
    .map(parse_and_transform)
 
    # Convert to an iterable dataset to apply prefetching.
    .to_iter_dataset()
    .batch(batch_size=BATCH_SIZE, drop_remainder=True)
    # Apply multiprocessing and prefetching.
    .mp_prefetch(
        grain.multiprocessing.MultiprocessingOptions(
            num_workers=16   # Number of parallel worker processes.
        )
    )
)
 
# Create the final iterator
data_iterator = iter(dataset)
Python
  • num_workers该参数用于指定数据加载所使用的并行子进程数量。若发现加速器经常因等待数据而处于空闲状态,增加此值可显著提升吞吐量。最佳数值取决于您计算机的可用 CPU 核心数以及映射函数的复杂程度。


进一步探索

想要深入探索并开始构建?请查阅本指南中提及技术的官方文档与源代码。

基础技术


真实场景示例:大规模 LLM 训练

基于 Grain 与 ArrayRecord 构建的高性能确定性数据管道,对于大规模模型训练至关重要。典型范例是 MaxText。这是一个用 JAX 编写的高性能开源大语言模型。MaxText 正是利用这些数据管道技术,成功实现了向大型 TPU 和 GPU 集群的高效数据供给。