When training large models on powerful accelerators such as GPUs and TPUs, the last thing you want is an accelerator sitting idle while it waits for data. The whole system runs only as fast as its slowest link, and that bottleneck is often the data input pipeline. Efficient, reproducible data pipelines are therefore essential for large-scale machine learning. This guide shows how to solve that challenge by building robust, high-performance pipelines with Grain, a flexible data loading library for JAX, and ArrayRecord, an efficient file format.
Grain is a lightweight, open-source data loading library designed for JAX-based workloads that tackles exactly this problem. It ensures data is loaded, preprocessed, and delivered to the model efficiently, so you get the most out of your hardware.
Grain is built around three core ideas: performance, reproducibility, and flexibility. Its main advantages include:
1. Performance: Grain can run data loading and transformations in parallel across worker processes (for example, via the .mp_prefetch() method), ensuring that a buffer of prepared data is always ready for your model. This keeps your accelerators saturated and minimizes training time.
2. Reproducibility: pipelines are deterministic, and the iterator's state can be saved and restored, so a training run can be reproduced or resumed exactly.
3. Flexibility: pipelines are expressed by chaining transformations such as .shuffle(), .map(), and .batch(), as sketched below. This declarative style makes data pipelines easy to understand, modify, and maintain.
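To make the declarative style concrete, here is a minimal sketch of such a chain; my_source is a placeholder for any random-access data source and the lambda stands in for your own parsing logic (the full, working pipeline is built step by step later in this guide):
import grain
# A minimal sketch of Grain's declarative chaining; my_source is a placeholder.
dataset = (
    grain.MapDataset.source(my_source)  # any random-access data source
    .shuffle(seed=0)                    # deterministic reordering of records
    .map(lambda record: record)         # plug in your parsing/augmentation here
    .batch(batch_size=8)                # group records into batches
)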
TFRecord is the widely known standard format, but its sequential nature rules out true global shuffling. ArrayRecord is a modern file format designed to remove exactly that pain point, unlocking a new level of data efficiency.
The high performance of ArrayRecord is rooted in its core design, which builds on Google's Riegeli file format. This structure provides several key advantages for large-scale data processing (a short writer/reader sketch follows the list):
1. Indexed random access: every record can be located and read directly by its index, without scanning the file from the beginning.
2. Massive parallelism: Records are grouped into data chunks. This structure is inherently designed to be read in parallel, allowing multiple processes to read different chunks of the same file simultaneously to dramatically increase read throughput.
3. Exceptional performance: As a result of this indexed and chunked design, benchmarks show ArrayRecord can achieve a read throughput an order of magnitude higher than traditional formats, making it ideal for today's massive datasets.
4. Smart data integrity: The format handles data integrity intelligently by leveraging the powerful error correction in underlying cloud storage systems (like Google Cloud Storage) rather than adding redundant checks. This provides robust protection against corruption without unnecessary performance overhead.
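For a hands-on feel of the format itself, the sketch below writes and reads a tiny ArrayRecord file using the low-level Python bindings shipped with the array_record package; treat the exact import path and the 'group_size:1' options string as assumptions to confirm against the package's documentation.
from array_record.python.array_record_module import ArrayRecordReader, ArrayRecordWriter
# Write a few records; each record is simply a byte string.
writer = ArrayRecordWriter('/tmp/example.array_record', 'group_size:1')
for i in range(3):
    writer.write(f'record {i}'.encode())
writer.close()
# Read them back; the file also knows exactly how many records it holds.
reader = ArrayRecordReader('/tmp/example.array_record')
print(reader.num_records())  # 3
print(reader.read_all())     # [b'record 0', b'record 1', b'record 2']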
The properties of ArrayRecord directly enable the advanced features that modern data loaders such as Grain depend on.
The most important benefit is true, deterministic global shuffling. Because any record can be accessed instantly, a data loader can generate a fully randomized ordering of indices over the entire dataset on the fly during training and then fetch records in exactly that order. This capability, which is computationally impractical with sequential formats like TFRecord, is vital for reproducible research and optimal model training.
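The sketch below shows what this random access looks like through Grain: an ArrayRecordDataSource exposes the dataset's length and indexed record lookup, so records can be fetched in any order (the file path is a placeholder):
import grain
# Placeholder path; point this at real ArrayRecord shards.
source = grain.sources.ArrayRecordDataSource(['/path/to/data.array_record'])
num_records = len(source)       # total record count, known up front
first = source[0]               # fetch any record by index ...
last = source[num_records - 1]  # ... in any order, at roughly equal cost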
Here is a point-by-point comparison of ArrayRecord and TFRecord on the key characteristics:
1. Random access: ArrayRecord lets you read any record directly by index; TFRecord must be read sequentially.
2. Global shuffling: ArrayRecord enables true, deterministic global shuffling; with TFRecord only approximate, buffer-based shuffling is practical.
3. Parallel I/O: ArrayRecord's chunked layout is built for parallel reads of the same file; TFRecord is typically consumed by sequential readers.
4. Integration: ArrayRecord integrates directly with Grain and TFDS; TFRecord is the established format of the TensorFlow tf.data ecosystem.
5. Primary use case: ArrayRecord targets large-scale training that needs random access and global shuffling (for example, JAX pipelines built with Grain); TFRecord remains a fit for sequential streaming workloads.
How you convert a dataset depends on whether it is a standard dataset registered in the TensorFlow Datasets (TFDS) catalog or your own custom dataset.
If you are working with a well-known dataset such as cifar10 or imagenet2012, the tfds command-line tool is the most direct way to convert it.
Prerequisite: install TensorFlow Datasets
pip install -q --upgrade tfds-nightly
Using the tfds build CLI
This command does three things: it downloads the source data, runs the preprocessing logic, and saves the output in the requested format.
# Generate the 'cifar10' dataset in ArrayRecord format
tfds build cifar10 --file_format=array_record
The resulting ArrayRecord files are stored under the ~/tensorflow_datasets/ directory and are ready to use immediately.
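As a quick sanity check, recent TFDS releases can hand the converted data back as a random-access data source that plugs directly into Grain; the snippet below assumes a TFDS version that provides tfds.data_source and that cifar10 was built in ArrayRecord format as shown above:
import tensorflow_datasets as tfds
# Returns a random-access data source backed by the ArrayRecord files.
source = tfds.data_source('cifar10', split='train')
print(len(source))       # number of training examples
print(source[0].keys())  # feature keys, e.g. 'image' and 'label' for cifar10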
For converting your own custom TFRecord datasets at scale, the recommended approach is the Apache Beam framework. The array_record library ships a pre-built Beam pipeline that makes the conversion remarkably simple and scalable. This approach is particularly well suited to massive datasets, because the work can be distributed across many workers using a service such as Google Cloud Dataflow.
Prerequisites: install Apache Beam and the ArrayRecord Beam SDK
pip install -q apache-beam
pip install -q array-record-beam-sdk
Using the pre-built conversion pipeline
The array_record.beam.pipelines module includes the convert_tf_to_arrayrecord_disk_match_shards function, built specifically for this conversion. The utility handles the entire process: it reads the TFRecord files and writes a correspondingly sharded ArrayRecord dataset.
Here is how to use it from a Python script:
from apache_beam.options import pipeline_options
from array_record.beam.pipelines import convert_tf_to_arrayrecord_disk_match_shards
# 1. Define your input and output patterns.
# This example uses Google Cloud Storage (GCS) paths, which is common for large datasets.
input_pattern = 'gs://your-gcs-bucket/path/to/records-*.tfrecord'
output_path = 'gs://your-gcs-bucket/path/to/converted-records'
# Arguments dictionary for the conversion function.
args = {
    'input': input_pattern,
    'output': output_path,
}
# 2. Configure pipeline options for execution.
# To run locally on your machine (for smaller datasets or testing):
# No options are needed; the local runner is used by default.
local_pipeline_options = pipeline_options.PipelineOptions()
# To run at scale on Google Cloud Dataflow (for large datasets):
# Uncomment the following lines and fill in your project details.
#
# dataflow_pipeline_options = pipeline_options.PipelineOptions(
#     runner='DataflowRunner',
#     project='your-gcp-project-id',
#     region='your-gcp-region',
#     # A requirements.txt file may be needed for dependencies on Dataflow workers.
#     # requirements_file='requirements.txt',
#     temp_location='gs://your-gcs-bucket/path/to/temp'
# )
# 3. Define and run the main execution logic.
def main():
    print("Starting TFRecord to ArrayRecord conversion...")
    convert_tf_to_arrayrecord_disk_match_shards(
        args=args,
        # Pass the appropriate options here.
        # Use `local_pipeline_options` for local runs.
        # Use `dataflow_pipeline_options` for cloud runs.
        pipeline_options=local_pipeline_options,
    ).run()
    print(f"Conversion complete. ArrayRecord files written to '{output_path}'.")

if __name__ == '__main__':
    main()
This approach is more robust than hand-writing a pipeline, because it is a proven, higher-level API built specifically for this task and it automatically handles details such as matching the output sharding to the input sharding.
With your data in ArrayRecord format, you can define a high-performance input pipeline using Grain's Dataset API. The process has two steps: create a data source, then chain transformation methods.
Start by pointing at your ArrayRecord files to create a MapDataset.
import grain
# Path to your generated ArrayRecord files
file_paths = ["~/tensorflow_datasets/cifar10/3.0.2/cifar10-train.array_record-00000-of-00001"]
# Create a data source
data_source = grain.sources.ArrayRecordDataSource(file_paths)
# Create a MapDataset from the source
dataset = grain.MapDataset.source(data_source)
Now apply transformations to the MapDataset. Each method returns a new MapDataset, so you can chain the calls declaratively.
# Example parsing function
def parse_and_transform(record):
    # Your logic to parse features, augment data, etc.
    return {"record": record}

BATCH_SIZE = 32

# Chain transformations
# The order of operations matters.
dataset = (
    dataset.shuffle(seed=42)
    .map(parse_and_transform)
    .batch(batch_size=BATCH_SIZE, drop_remainder=True)
)
Finally, create a DatasetIterator from the fully defined dataset to loop over the data in your training script.
# Create the stateful iterator
data_iterator = iter(dataset)
# You can now loop over the data
for batch in data_iterator:
    # Your training step with the batch...
    pass
# For checkpoint saving/restoration, you can get/set the iterator's state
# state = data_iterator.get_state()
# data_iterator.set_state(state)
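If you want to persist that state alongside your model checkpoints, one simple approach is sketched below; it assumes the object returned by get_state() can be pickled and uses a hypothetical STATE_PATH (Grain also offers dedicated checkpointing integrations, which are worth checking in its documentation):
import pickle
# Hypothetical location; adapt to your checkpointing setup.
STATE_PATH = '/tmp/data_iterator_state.pkl'
# Save the iterator state together with a model checkpoint.
with open(STATE_PATH, 'wb') as f:
    pickle.dump(data_iterator.get_state(), f)
# Later, restore it so training resumes from the same position in the data.
with open(STATE_PATH, 'rb') as f:
    data_iterator.set_state(pickle.load(f))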
To keep the data pipeline from becoming the performance bottleneck, it is recommended to use multiprocessing so that data is loaded and preprocessed in parallel with model training. With the Dataset API, all it takes is adding the .mp_prefetch() transformation at the end of the pipeline.
This method starts a pool of worker processes that preprocess batches asynchronously in the background and keep them in a buffer, so a batch is ready the moment the training loop asks for one.
Here is how to apply it:
# The full pipeline with performance tuning.
dataset = (
    grain.MapDataset.source(data_source)
    .shuffle(seed=42)
    .map(parse_and_transform)
    # Convert to an iterable dataset to apply prefetching.
    .to_iter_dataset()
    .batch(batch_size=BATCH_SIZE, drop_remainder=True)
    # Apply multiprocessing and prefetching.
    .mp_prefetch(
        grain.multiprocessing.MultiprocessingOptions(
            num_workers=16  # Number of parallel worker processes.
        )
    )
)
# Create the final iterator
data_iterator = iter(dataset)
num_workers: this parameter sets how many parallel child processes are used for data loading. If you find your accelerators regularly sitting idle waiting for data, increasing it can significantly improve throughput. The best value depends on the CPU cores available on your machine and on how heavy your map function is.
Want to dig deeper and start building? Check out the official documentation and source code for the technologies covered in this guide.
High-performance, deterministic data pipelines built on Grain and ArrayRecord are essential for training models at scale. A prime example is MaxText, a high-performance, open-source large language model written in JAX. MaxText relies on exactly these techniques to keep data flowing efficiently to large TPU and GPU clusters.