在 GPU 和 TPU 等功能强大的加速器上训练大型模型时,人们最不希望看到的情况就是加速器因等待数据而处于空闲状态。整个系统的速度取决于最慢的环节,而往往这个瓶颈正是数据输入管道。因此,对于大规模机器学习而言,高效且可复现的数据管道至关重要。本指南将向您展示如何通过使用 Grain(一个灵活的 JAX 数据加载库)和 ArrayRecord(一种高效文件格式)来构建稳健的高性能数据管道,从而解决这一挑战。
Grain 是一个轻量级的开源数据加载库,专为基于 JAX 的工作负载而设计,能够有效解决这一问题。它可确保数据高效地完成加载、预处理并输送至模型,从而最大限度发挥硬件性能。
Grain 基于性能、可复现性与灵活性三大核心理念构建,其主要优势包括:
.mp_prefetch () 方法),实现数据加载与转换的并行执行,确保始终有就绪的数据缓冲区供模型使用。这不仅能使加速器保持饱和,还能最大程度缩短训练时间。.shuffle()、.map() 和 .batch() 等转换操作。这种声明式样式使数据管道易于理解、修改和维护。尽管 TFRecord 是广为熟知的标准格式,但其顺序性质无法实现真正的全局打乱。ArrayRecord 正是为解决这一痛点而设计的现代文件格式,为数据效率开启了全新境界。
ArrayRecord 的高性能植根于其核心设计,该设计基于 Google 的 Riegeli 文件格式。这种结构为大规模数据处理提供了若干关键优势:
2. 大规模并行:记录分组存储于数据块中。这种结构天生支持并行读取,允许多个进程同时读取同一文件的不同数据块,从而显著提升读取吞吐量。
3. 卓越的性能:得益于索引化和分块设计,基准测试显示 ArrayRecord 的读取吞吐量比传统格式高出一个数量级,使其完美适配当今海量数据集的处理需求。
4. 智能数据完整性:该格式通过利用底层云存储系统(如 Google Cloud Storage)的强大纠错功能,而非添加冗余检查,智能地处理数据完整性。这提供了稳健的防损坏保护,且不会带来不必要的性能开销。
ArrayRecord 的特性直接实现了 Grain 等现代数据加载器所需的高级功能。
其最重要的优势在于实现了真正确定性的全局打乱。由于任何记录都可以即时访问,因此数据加载器可以在训练过程中动态生成数据集的完全随机索引,并按此特定顺序获取数据。这种对于 TFRecord 等顺序格式而言计算成本极高的能力,恰恰是可复现研究与模型优化训练的关键所在。
以下是 ArrayRecord 与 TFRecord 在关键特性上的详细对比分析:
2. 随机访问
3. 全局打乱
4. 并行 I/O
5. 集成
6. 主要用例
转换数据集的方法取决于它是 TensorFlow 数据集 (TFDS) 目录中的标准注册数据集还是自定义的专有数据集。
如果您使用的是 cifar10 或 imagenet2012 等知名数据集,tfds 命令行工具是最直接的转换方法。
先决条件:安装 TensorFlow 数据集
pip install -q --upgrade tfds-nightly
使用 tfds build CLI
该命令会执行三大操作:下载源数据、运行预处理逻辑,以及将输出结果保存为指定格式。
# Generate the 'cifar10' dataset in ArrayRecord format
tfds build cifar10 --file_format=array_record
生成的 ArrayRecord 文件将存储于 ~/tensorflow_datasets/ 目录中,立即可用。
对于专属自定义 TFRecord 数据集的大规模转换,建议使用 Apache Beam 框架。array_record 库提供的预制 Beam 流水线能使转换过程变得异常简单且可扩展。该方法特别适用于海量数据集处理,因其可通过 Google Cloud Dataflow 等服务将任务分布于多个工作器中。
先决条件:安装 Apache Beam 和 Array Record Beam SDK
pip install -q apache-beam
pip install -q array-record-beam-sdk
使用预制的转换流水线
array_record.beam.pipelines 模块包含专为转换设计的 convert_tf_to_arrayrecord_disk_match_shards 函数,该实用程序能完整处理转换流程:读取 TFRecord 文件并写入对应的分片式 ArrayRecord 数据集。
在 Python 脚本中使用该程序的方法如下:
from apache_beam.options import pipeline_options
from array_record.beam.pipelines import convert_tf_to_arrayrecord_disk_match_shards
# 1. Define your input and output patterns.
# This example uses Google Cloud Storage (GCS) paths, which is common for large datasets.
input_pattern = 'gs://your-gcs-bucket/path/to/records-*.tfrecord'
output_path = 'gs://your-gcs-bucket/path/to/converted-records'
# Arguments dictionary for the conversion function.
args = {
'input': input_pattern,
'output': output_path,
}
# 2. Configure pipeline options for execution.
# To run locally on your machine (for smaller datasets or testing):
# No options are needed; the local runner is used by default.
local_pipeline_options = pipeline_options.PipelineOptions()
# To run at scale on Google Cloud Dataflow (for large datasets):
# Uncomment the following lines and fill in your project details.
#
# dataflow_pipeline_options = pipeline_options.PipelineOptions(
# runner='DataflowRunner',
# project='your-gcp-project-id',
# region='your-gcp-region',
# # A requirements.txt file may be needed for dependencies on Dataflow workers.
# # requirements_file='requirements.txt',
# temp_location='gs://your-gcs-bucket/path/to/temp'
# )
# 3. Define and run the main execution logic.
def main():
print("Starting TFRecord to ArrayRecord conversion...")
convert_tf_to_arrayrecord_disk_match_shards(
args=args,
# Pass the appropriate options here.
# Use `local_pipeline_options` for local runs.
# Use `dataflow_pipeline_options` for cloud runs.
pipeline_options=local_pipeline_options,
).run()
print(f"Conversion complete. ArrayRecord files written to '{output_path}'.")
if __name__ == '__main__':
main()
该方法比手动编写流水线更强大稳健,因为它是专门为此任务设计的经过验证的高阶 API,能自动处理诸如输出分片与输入分片匹配等技术细节。
当数据转换为 ArrayRecord 格式后,您即可使用 Grain 的 Dataset API 定义高性能输入流水线。该流程包含两个步骤:创建数据源,然后链接转换方法。
首先指向 ArrayRecord 文件以创建 MapDataset。
import grain
# Path to your generated ArrayRecord files
file_paths = ["~/tensorflow_datasets/cifar10/3.0.2/cifar10-train.array_record-00000-of-00001"]
# Create a data source
data_source = grain.sources.ArrayRecordDataSource(file_paths)
# Create a MapDataset from the source
dataset = grain.MapDataset.source(data_source)
现在对 MapDataset 应用转换操作。每个方法都会返回新的 MapDataset,以便您以声明方式将调用链接在一起。
# Example parsing function
def parse_and_transform(record):
# Your logic to parse features, augment data, etc.
return {"record": record}
BATCH_SIZE = 32
# Chain transformations
# The order of operations matters.
dataset = (
dataset.shuffle(seed=42)
.map(parse_and_transform)
.batch(batch_size=BATCH_SIZE, drop_remainder=True)
)
DatasetIterator最后,从完整定义的数据集中创建迭代器,以在训练脚本中循环遍历数据。
# Create the stateful iterator
data_iterator = iter(dataset)
# You can now loop over the data
for batch in data_iterator:
# Your training step with the batch...
pass
# For checkpoint saving/restoration, you can get/set the iterator's state
# state = data_iterator.get_state()
# data_iterator.set_state(state)
为避免数据管道成为性能瓶颈,建议使用多进程技术来加载和预处理数据,同时并行进行模型训练。在 Dataset API 中,只需在流水线末端添加 .mp_prefetch() 转换方法即可实现该功能。
该方法会启动工作器进程池,在后台异步预处理数据批次并存储至缓冲区,确保训练循环需要时可立即取用。
应用方法如下所示:
# The full pipeline with performance tuning.
dataset = (
grain.MapDataset.source(data_source)
.shuffle(seed=42)
.map(parse_and_transform)
# Convert to an iterable dataset to apply prefetching.
.to_iter_dataset()
.batch(batch_size=BATCH_SIZE, drop_remainder=True)
# Apply multiprocessing and prefetching.
.mp_prefetch(
grain.multiprocessing.MultiprocessingOptions(
num_workers=16 # Number of parallel worker processes.
)
)
)
# Create the final iterator
data_iterator = iter(dataset)
num_workers:该参数用于指定数据加载所使用的并行子进程数量。若发现加速器经常因等待数据而处于空闲状态,增加此值可显著提升吞吐量。最佳数值取决于您计算机的可用 CPU 核心数以及映射函数的复杂程度。想要深入探索并开始构建?请查阅本指南中提及技术的官方文档与源代码。
基于 Grain 与 ArrayRecord 构建的高性能确定性数据管道,对于大规模模型训练至关重要。典型范例是 MaxText。这是一个用 JAX 编写的高性能开源大语言模型。MaxText 正是利用这些数据管道技术,成功实现了向大型 TPU 和 GPU 集群的高效数据供给。