When training large models on powerful accelerators like GPUs and TPUs, the last thing you want is for your accelerator to be idle, waiting for data. Your entire system is only as fast as its slowest part, and often, that bottleneck is the data input pipeline. Therefore, for large-scale machine learning, an efficient and reproducible data pipeline is essential. This guide will show you how to solve this challenge by building a robust and performant data pipeline using Grain, a flexible data loading library for JAX, and ArrayRecord, a highly efficient file format.
Grain is a lightweight, open-source data loading library designed specifically to solve this problem for JAX-based workloads. It ensures that data is loaded, preprocessed, and fed to your model efficiently, allowing you to maximize the performance of your hardware.
Grain is built on a philosophy of performance, reproducibility, and flexibility. Here are the key benefits it provides:
1. Performance: Grain can use multiprocessing (for example, through the .mp_prefetch() method) to run data loading and transformations in parallel, ensuring that a buffer of prepared data is always ready for your model. This keeps your accelerators saturated and minimizes training time.
2. Flexibility: Pipelines are expressed by chaining transformations such as .shuffle(), .map(), and .batch(). This declarative style makes your data pipeline easy to understand, modify, and maintain.
While TFRecord is a familiar standard, its sequential nature does not allow a true global shuffle. ArrayRecord is a modern file format designed specifically to solve this problem, offering a new frontier in data efficiency.
ArrayRecord's high performance is rooted in its core design, which is based on Google's Riegeli file format. This structure provides several key advantages for large-scale data handling:
1. Fast random access: Each file keeps an index of its records, so any record can be looked up directly by position instead of scanning the file from the start.
2. Massive parallelism: Records are grouped into data chunks. This structure is inherently designed to be read in parallel, allowing multiple processes to read different chunks of the same file simultaneously and dramatically increase read throughput (see the sketch after this list).
3. Exceptional performance: As a result of this indexed and chunked design, benchmarks show ArrayRecord can achieve read throughput an order of magnitude higher than traditional formats, making it ideal for today's massive datasets.
4. Smart data integrity: The format handles data integrity intelligently by leveraging the powerful error correction in underlying cloud storage systems (like Google Cloud Storage) rather than adding redundant checks. This provides robust protection against corruption without unnecessary performance overhead.
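To make this concrete, here is a minimal sketch of slice-parallel reading, assuming you have a local ArrayRecord file (the path below is a placeholder). It uses Grain's ArrayRecordDataSource purely as a random-access reader and standard Python threads, so it illustrates the access pattern rather than Grain's internal reader.
import concurrent.futures
import grain

# Placeholder path; substitute one of your own ArrayRecord shards.
FILE_PATH = "/tmp/my_dataset.array_record"

def read_slice(path, start, stop):
    # Conservative choice: each worker opens its own reader for its index range.
    source = grain.sources.ArrayRecordDataSource([path])
    return [source[i] for i in range(start, stop)]

num_records = len(grain.sources.ArrayRecordDataSource([FILE_PATH]))
num_workers = 4
step = max(1, -(-num_records // num_workers))  # Ceiling division

with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as pool:
    futures = [
        pool.submit(read_slice, FILE_PATH, start, min(start + step, num_records))
        for start in range(0, num_records, step)
    ]
    records = [record for f in futures for record in f.result()]

print(f"Read {len(records)} records across {num_workers} workers.")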
ArrayRecord's features directly enable the advanced capabilities required by modern data loaders like Grain.
The most important benefit is achieving true, deterministic global shuffling. Because any record can be accessed instantly, a data loader can generate a fully randomized permutation of record indices on the fly as training happens and then fetch data in exactly that order. This capability, which is computationally impractical with sequential formats like TFRecord, is vital for reproducible research and optimal model training.
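The mechanism is easy to sketch. The snippet below is an illustration of the idea, not Grain's internal implementation: a seeded permutation of record indices defines the order for an epoch, and a random-access source then fetches records in exactly that order.
import numpy as np

def epoch_order(num_records, seed, epoch):
    # A deterministic, truly global permutation of every record index.
    rng = np.random.default_rng(seed + epoch)
    return rng.permutation(num_records)

# With a random-access source, records are fetched in exactly this order:
# for idx in epoch_order(num_records=len(data_source), seed=42, epoch=0):
#     record = data_source[idx]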
Here’s a detailed breakdown of how ArrayRecord and TFRecord compare across key features:
1. Random Access: ArrayRecord can fetch any record directly by its index; TFRecord must be read sequentially from the start of a file.
2. Global Shuffling: ArrayRecord enables true, deterministic global shuffling by permuting record indices; TFRecord pipelines approximate shuffling with a limited in-memory shuffle buffer.
3. Parallel I/O: ArrayRecord's chunked layout lets many readers work on the same file at once; TFRecord parallelism is limited to reading separate files or shards.
4. Integration: ArrayRecord is the native format for Grain and is supported by TFDS; TFRecord is the traditional format for tf.data pipelines.
5. Primary Use Case: ArrayRecord targets large-scale, randomly accessed datasets (for example, JAX training with Grain); TFRecord targets sequential streaming workloads in the TensorFlow ecosystem.
The method for converting your dataset depends on whether it is a standard, registered dataset in the TensorFlow Datasets (TFDS) catalog or a custom, proprietary dataset.
If you are using a well-known dataset like cifar10 or imagenet2012, the tfds command-line tool is the most straightforward method.
Prerequisite: Install TensorFlow Datasets
pip install -q --upgrade tfds-nightly
Using the tfds build CLI
This command downloads the source data, runs the preparation logic, and saves the output in your desired format.
# Generate the 'cifar10' dataset in ArrayRecord format
tfds build cifar10 --file_format=array_record
The generated ArrayRecord files will be stored in your ~/tensorflow_datasets/ directory, ready to use.
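If you would rather stay in Python, recent TFDS releases can also expose these ArrayRecord files directly as a random-access data source through tfds.data_source. Availability and behavior depend on your TFDS version, so treat the following as a sketch:
import tensorflow_datasets as tfds

# Opens the prepared dataset as a random-access source backed by
# ArrayRecord files, ready to hand to a data loader such as Grain.
source = tfds.data_source('cifar10', split='train')
print(len(source))    # Number of records
print(source[0])      # A single decoded example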
For large-scale conversion of your own custom TFRecord datasets, the recommended approach is to use Apache Beam. The array_record library provides pre-packaged Beam pipelines that make this conversion simple and scalable. This method is highly recommended for massive datasets, as the processing can be distributed across many workers using a service like Google Cloud Dataflow.
Prerequisites: Install Apache Beam and Array Record Beam SDK
pip install -q apache-beam
pip install -q array-record-beam-sdk
Using the pre-packaged conversion pipeline
The array_record.beam.pipelines module contains the convert_tf_to_arrayrecord_disk_match_shards function, a purpose-built utility that handles the entire conversion process. It reads TFRecord files and writes a corresponding sharded ArrayRecord dataset.
Here is how you would use it in a Python script:
from apache_beam.options import pipeline_options
from array_record.beam.pipelines import convert_tf_to_arrayrecord_disk_match_shards

# 1. Define your input and output patterns.
# This example uses Google Cloud Storage (GCS) paths, which is common for large datasets.
input_pattern = 'gs://your-gcs-bucket/path/to/records-*.tfrecord'
output_path = 'gs://your-gcs-bucket/path/to/converted-records'

# Arguments dictionary for the conversion function.
args = {
    'input': input_pattern,
    'output': output_path,
}

# 2. Configure pipeline options for execution.
# To run locally on your machine (for smaller datasets or testing),
# no options are needed; the local runner is used by default.
local_pipeline_options = pipeline_options.PipelineOptions()

# To run at scale on Google Cloud Dataflow (for large datasets),
# uncomment the following lines and fill in your project details.
#
# dataflow_pipeline_options = pipeline_options.PipelineOptions(
#     runner='DataflowRunner',
#     project='your-gcp-project-id',
#     region='your-gcp-region',
#     # A requirements.txt file may be needed for dependencies on Dataflow workers.
#     # requirements_file='requirements.txt',
#     temp_location='gs://your-gcs-bucket/path/to/temp',
# )

# 3. Define and run the main execution logic.
def main():
    print("Starting TFRecord to ArrayRecord conversion...")
    convert_tf_to_arrayrecord_disk_match_shards(
        args=args,
        # Pass the appropriate options here:
        # `local_pipeline_options` for local runs,
        # `dataflow_pipeline_options` for cloud runs.
        pipeline_options=local_pipeline_options,
    ).run()
    print(f"Conversion complete. ArrayRecord files written to '{output_path}'.")

if __name__ == '__main__':
    main()
This approach is more powerful and robust than writing a manual pipeline because it's a tested, high-level API designed specifically for this task, handling details like matching output shards to input shards automatically.
Once your data is in the ArrayRecord format, you can define your high-performance input pipeline using Grain's Dataset API. The process involves creating a source and then chaining transformation methods.
First, point to your ArrayRecord files to create a MapDataset.
import grain
# Path to your generated ArrayRecord files
file_paths = ["~/tensorflow_datasets/cifar10/3.0.2/cifar10-train.array_record-00000-of-00001"]
# Create a data source
data_source = grain.sources.ArrayRecordDataSource(file_paths)
# Create a MapDataset from the source
dataset = grain.MapDataset.source(data_source)
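At this point the source already behaves like a random-access sequence of serialized records, which is what makes global shuffling possible later in the pipeline. A quick sanity check (assuming the file path above exists on your machine):
# The source supports len() and direct index-based access to raw records.
print(f"Number of records: {len(data_source)}")
first_record = data_source[0]   # Raw serialized bytes of the first record
print(type(first_record), len(first_record))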
Now, apply transformations to the MapDataset. Each method returns a new MapDataset, allowing you to chain calls together declaratively.
# Example parsing function
def parse_and_transform(record):
    # Your logic to parse features, augment data, etc.
    return {"record": record}

BATCH_SIZE = 32

# Chain transformations.
# The order of operations matters.
dataset = (
    dataset.shuffle(seed=42)
    .map(parse_and_transform)
    .batch(batch_size=BATCH_SIZE, drop_remainder=True)
)
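Because every stage of a MapDataset pipeline is itself randomly accessible, you can index into an intermediate stage to spot-check your transformations while developing them. A small, purely illustrative sketch using the toy parse function above:
# Build just the un-batched portion of the pipeline and inspect one element.
debug_ds = grain.MapDataset.source(data_source).map(parse_and_transform)
print(debug_ds[0].keys())              # dict_keys(['record'])
print(type(debug_ds[0]["record"]))     # Raw record bytes from the source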
Creating a DatasetIterator
Finally, create an iterator from your fully defined dataset to loop through the data in your training script.
# Create the stateful iterator
data_iterator = iter(dataset)

# You can now loop over the data
for batch in data_iterator:
    # Your training step with the batch...
    pass

# For checkpoint saving/restoration, you can get/set the iterator's state
# state = data_iterator.get_state()
# data_iterator.set_state(state)
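That state is what makes mid-epoch resumption possible. A minimal sketch of the pattern (in practice you would persist the state alongside your model checkpoint, for example with a checkpointing library):
# Start a fresh iterator, advance a few steps, then capture its position.
data_iterator = iter(dataset)
for _ in range(10):
    _ = next(data_iterator)
state = data_iterator.get_state()

# Later, e.g. after a restart: rebuild the same pipeline and resume
# exactly where the previous run left off.
restored_iterator = iter(dataset)
restored_iterator.set_state(state)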
To prevent your data pipeline from becoming a bottleneck, you should use multiprocessing to load and preprocess data in parallel with model training. In the Dataset API, this is achieved by adding the .mp_prefetch() transformation to your pipeline.
This method starts a pool of worker processes to asynchronously prepare data batches in the background and stores them in a buffer, so they are ready the moment your training loop needs them.
Here's how to apply it:
# The full pipeline with performance tuning.
dataset = (
    grain.MapDataset.source(data_source)
    .shuffle(seed=42)
    .map(parse_and_transform)
    # Convert to an iterable dataset to apply prefetching.
    .to_iter_dataset()
    .batch(batch_size=BATCH_SIZE, drop_remainder=True)
    # Apply multiprocessing and prefetching.
    .mp_prefetch(
        grain.multiprocessing.MultiprocessingOptions(
            num_workers=16  # Number of parallel worker processes.
        )
    )
)

# Create the final iterator
data_iterator = iter(dataset)
num_workers: This specifies the number of parallel child processes to use for data loading. If you notice your accelerator is often idle waiting for data, increasing this value can significantly improve throughput. The optimal number depends on the CPU cores available on your machine and the complexity of your map function.
Want to dive deeper and start building? Check out the official documentation and source code for the technologies discussed in this guide.
The performant and deterministic data pipelines built with Grain and ArrayRecord are critical for large-scale model training. A prime example is MaxText, a high-performance, open-source Large Language Model written in JAX. MaxText leverages these exact data pipeline techniques to efficiently feed data to large TPU and GPU clusters.