Building High-Performance Data Pipelines with Grain and ArrayRecord

OCT. 3, 2025
Jiyang Kang, Technical Program Manager
Shivaji Dutta, Field Solutions Architect
Ihor Indyk, Software Engineer
Felix Chern, Software Engineer

When training large models on powerful accelerators like GPUs and TPUs, the last thing you want is for your accelerator to be idle, waiting for data. Your entire system is only as fast as its slowest part, and often, that bottleneck is the data input pipeline. Therefore, for large-scale machine learning, an efficient and reproducible data pipeline is essential. This guide will show you how to solve this challenge by building a robust and performant data pipeline using Grain, a flexible data loading library for JAX, and ArrayRecord, a highly efficient file format.


Understanding the core components

Grain: A high-performance data loader for JAX

Grain is a lightweight, open-source data loading library designed specifically to solve this problem for JAX-based workloads. It ensures that data is loaded, preprocessed, and fed to your model efficiently, allowing you to maximize the performance of your hardware.

Why use Grain?

Grain is built on a philosophy of performance, reproducibility, and flexibility. Here are the key benefits it provides:

  • Exceptional performance: Grain is built for speed. It uses efficient multiprocessing (via the .mp_prefetch() method) to run data loading and transformations in parallel, ensuring that a buffer of prepared data is always ready for your model. This keeps your accelerators saturated and minimizes training time.

  • Guaranteed determinism and reproducibility: Grain provides full reproducibility, which is critical for credible research. By setting a simple seed, you ensure the data is always shuffled the same way. Crucially, its data iterators are stateful and can be checkpointed, so if your training job is interrupted or preempted, you can restart from the exact same point in the data stream (the short sketch after this list demonstrates this).

  • An intuitive, declarative API: You define your data pipeline by chaining together simple, readable methods. Starting with a MapDataset source, you can fluidly add transformations like .shuffle(), .map(), and .batch(). This declarative style makes your data pipeline easy to understand, modify, and maintain.

  • Unlocking true global shuffling: To get the best performance from your models, you need to shuffle your data effectively. When paired with a file format that supports random access, like ArrayRecord, Grain can perform a true global shuffle across your entire dataset, even when it doesn’t fit into host memory. This is a powerful feature that is often computationally impractical with other data loaders and formats.
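Before diving into the details, here is a minimal, hedged sketch of the declarative API and the checkpointable iterator described above. It uses an in-memory list as a stand-in for a real data source; the numbers, seed, and batch size are arbitrary:

import grain

# Build a small pipeline declaratively; a plain Python list works as a source.
dataset = (
    grain.MapDataset.source(list(range(100)))
    .shuffle(seed=42)          # same seed -> same order on every run
    .map(lambda x: x * 2)
    .batch(batch_size=10)
)

# The iterator is stateful and can be checkpointed.
iterator = iter(dataset)
first_batch = next(iterator)

state = iterator.get_state()   # capture the current position in the data stream
second_batch = next(iterator)
iterator.set_state(state)      # rewind: the next batch is identical to second_batch
Python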


What is ArrayRecord and why use it?

While TFRecord is a familiar standard, its sequential nature makes a true global shuffle impractical. ArrayRecord is a modern file format designed specifically to solve this problem, pairing efficient random access with very high read throughput.

ArrayRecord File Layout

How it works: Designed for speed and parallelism

ArrayRecord's high performance is rooted in its core design, which is based on Google's Riegeli file format. This structure provides several key advantages for large-scale data handling:

  1. Efficient random access: ArrayRecord features a built-in metadata index that maps every record to its precise location. This is the key design choice that enables instant, direct access to any record in the dataset, completely avoiding the need to read the file from the beginning.

  2. Massive parallelism: Records are grouped into data chunks, a structure that is inherently parallel-friendly: multiple processes can read different chunks of the same file simultaneously, dramatically increasing read throughput (see the short writer sketch after this list).

  3. Exceptional performance: As a result of this indexed and chunked design, benchmarks show ArrayRecord can achieve read throughput an order of magnitude higher than traditional formats, making it ideal for today's massive datasets.

  4. Smart data integrity: The format handles data integrity intelligently by relying on the strong error correction of the underlying cloud storage systems (such as Google Cloud Storage) rather than adding redundant checks. This provides robust protection against corruption without unnecessary performance overhead.
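To make the indexed, chunked layout concrete, here is a small, hedged sketch that writes an ArrayRecord file with the array_record Python API. The 'group_size' writer option shown here is an assumption about the option string; it controls how many records are packed into each chunk (smaller groups favor random access, larger groups favor compression):

from array_record.python.array_record_module import ArrayRecordWriter

# Write 1,000 records; each record is just a byte string.
writer = ArrayRecordWriter("/tmp/demo.array_record", "group_size:1")
for i in range(1000):
    writer.write(f"record-{i}".encode())
writer.close()  # finalizes the chunks and appends the metadata index
Python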


Why are we using it?

ArrayRecord's features directly enable the advanced capabilities required by modern data loaders like Grain.

The most important benefit is achieving true, deterministic global shuffling. Because any record can be accessed instantly, a data loader can generate a randomized permutation of all dataset indices on the fly during training and fetch records in exactly that order. This capability, which is computationally impractical with sequential formats like TFRecord, is vital for reproducible research and optimal model training.
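As a rough illustration of the idea (in practice Grain's .shuffle() does this for you; the file path below is a placeholder), a global shuffle is nothing more than random access driven by a permutation of indices:

import numpy as np
import grain

# Open an ArrayRecord file (placeholder path) as a random-access source.
source = grain.sources.ArrayRecordDataSource(
    ["/data/my_dataset.array_record-00000-of-00010"])

num_records = len(source)     # record count comes from the metadata index
one_record = source[12345]    # jump straight to record 12345, no sequential scan

# A true global shuffle: permute all indices, then fetch in that order.
rng = np.random.default_rng(seed=42)
for index in rng.permutation(num_records)[:8]:   # first few shuffled records
    record_bytes = source[index]
Python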


ArrayRecord vs. TFRecord: A detailed comparison

Here’s a detailed breakdown of how ArrayRecord and TFRecord compare across key features:

  1. Structure

  • ArrayRecord is built on the Riegeli file format from Google, which is designed for storing sequences of records with a focus on high-speed decoding, data integrity, and strong compression. It groups records into chunks and includes a metadata index at the end of the file.

  • TFRecord is a sequence of binary records, where each record is typically a tf.train.Example protocol buffer.


  2. Random Access

  • ArrayRecord offers native and efficient random access. Its file structure includes a built-in index of record positions, allowing for direct and fast access to any record by its index without needing to read the entire file.

  • TFRecord, on the other hand, lacks native random access. As a sequential format optimized for streaming data, accessing a specific record requires iterating through the file from the beginning.


  3. Global Shuffling

  • With ArrayRecord, true global shuffling is possible. Thanks to its efficient random access, a data loader like Grain can generate indices in a shuffled order and read records on the fly.

  • With TFRecord, true global shuffling is difficult to achieve. "Global" shuffling often relies on approximations, like shuffling a list of sharded filenames and then shuffling records within a small memory buffer. This is not a true global shuffle.


  4. Parallel I/O

  • ArrayRecord natively supports parallel I/O. The internal chunked structure of an ArrayRecord file makes it inherently easy for multiple processes to read from different parts of the same file in parallel, which simplifies data management.

  • TFRecord supports parallel reading, but it is typically achieved by sharding the dataset into many small TFRecord files and having different workers read from different files. This can result in a large number of files to manage.


  5. Integration

  • ArrayRecord is designed for high-performance I/O and works seamlessly with JAX-based loaders like Grain. It is also usable within the TensorFlow ecosystem via tfds.data_source (a short example follows this comparison).

  • TFRecord is tightly integrated with TensorFlow's tf.data ecosystem.


  6. Primary Use Case

  • ArrayRecord is ideal for high-throughput data loading for performance-critical machine learning, especially where determinism and true global shuffling are required (e.g., JAX/TPU workloads).

  • TFRecord is suited for general-purpose, large-scale data storage for TensorFlow and is optimized for sequential reads.
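As a hedged sketch of point 5, a TFDS dataset that has been prepared in the ArrayRecord format (the tfds build command in the next section does this for cifar10) can be opened as a random-access data source and handed straight to Grain; the dataset name and split here are assumptions for illustration:

import tensorflow_datasets as tfds
import grain

# Random-access view over the ArrayRecord files prepared by TFDS.
source = tfds.data_source("cifar10", split="train")

num_examples = len(source)   # length and indexing work without a full scan
example = source[0]          # a dict of decoded features, e.g. "image" and "label"

# The same source plugs directly into a Grain pipeline.
dataset = grain.MapDataset.source(source).shuffle(seed=42).batch(batch_size=32)
Python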


How to convert TFRecord datasets to ArrayRecord

The method for converting your dataset depends on whether it is a standard, registered dataset in the TensorFlow Datasets (TFDS) catalog or a custom, proprietary dataset.

Method 1: For standard datasets in the TFDS catalog

If you are using a well-known dataset like cifar10 or imagenet2012, the tfds command-line tool is the most straightforward method.

Prerequisite: Install TensorFlow datasets

pip install -q --upgrade tfds-nightly
Shell

Using the tfds build CLI

This command downloads the source data, runs the preparation logic, and saves the output in your desired format.

# Generate the 'cifar10' dataset in ArrayRecord format
tfds build cifar10 --file_format=array_record
Shell

The generated ArrayRecord files will be stored in your ~/tensorflow_datasets/ directory, ready to use.

Method 2: For custom or proprietary TFRecord datasets

For large-scale conversion of your own custom TFRecord datasets, the recommended approach is to use Apache Beam. The array_record library provides pre-packaged Beam pipelines that make this conversion incredibly simple and scalable. This method is highly recommended for massive datasets, as the processing can be distributed across many workers using a service like Google Cloud Dataflow.

Prerequisites: Install Apache Beam and Array Record Beam SDK

pip install -q apache-beam
pip install -q array-record-beam-sdk
Shell

Using the pre-packaged conversion pipeline

The array_record.beam.pipelines module contains the convert_tf_to_arrayrecord_disk_match_shards function, a purpose-built utility that handles the entire conversion process. It reads TFRecord files and writes a corresponding sharded ArrayRecord dataset.

Here is how you would use it in a Python script:

from apache_beam.options import pipeline_options
from array_record.beam.pipelines import convert_tf_to_arrayrecord_disk_match_shards

# 1. Define your input and output patterns.
# This example uses Google Cloud Storage (GCS) paths, which is common for large datasets.
input_pattern = 'gs://your-gcs-bucket/path/to/records-*.tfrecord'
output_path = 'gs://your-gcs-bucket/path/to/converted-records'

# Arguments dictionary for the conversion function.
args = {
    'input': input_pattern,
    'output': output_path,
}

# 2. Configure pipeline options for execution.

# To run locally on your machine (for smaller datasets or testing):
# No options are needed; the local runner is used by default.
local_pipeline_options = pipeline_options.PipelineOptions()


# To run at scale on Google Cloud Dataflow (for large datasets):
# Uncomment the following lines and fill in your project details.
#
# dataflow_pipeline_options = pipeline_options.PipelineOptions(
#     runner='DataflowRunner',
#     project='your-gcp-project-id',
#     region='your-gcp-region',
#     # A requirements.txt file may be needed for dependencies on Dataflow workers.
#     # requirements_file='requirements.txt',
#     temp_location='gs://your-gcs-bucket/path/to/temp'
# )


# 3. Define and run the main execution logic.
def main():
  print("Starting TFRecord to ArrayRecord conversion...")
  convert_tf_to_arrayrecord_disk_match_shards(
      args=args,
      # Pass the appropriate options here.
      # Use `local_pipeline_options` for local runs.
      # Use `dataflow_pipeline_options` for cloud runs.
      pipeline_options=local_pipeline_options,
  ).run()
  print(f"Conversion complete. ArrayRecord files written to '{output_path}'.")

if __name__ == '__main__':
  main()
Python

This approach is more powerful and robust than writing a manual pipeline because it's a tested, high-level API designed specifically for this task, handling details like matching output shards to input shards automatically.


Building a Grain and ArrayRecord pipeline: A conceptual walkthrough

Once your data is in the ArrayRecord format, you can define your high-performance input pipeline using Grain's Dataset API. The process involves creating a source and then chaining transformation methods.

Step 1: Create a MapDataset from a Data Source

First, point to your ArrayRecord files to create a MapDataset.

import os

import grain

# Path to your generated ArrayRecord files.
# Note: "~" is not expanded by the file reader, so expand it explicitly.
file_paths = [
    os.path.expanduser(
        "~/tensorflow_datasets/cifar10/3.0.2/cifar10-train.array_record-00000-of-00001"
    )
]

# Create a data source with random access to every record
data_source = grain.sources.ArrayRecordDataSource(file_paths)

# Create a MapDataset from the source
dataset = grain.MapDataset.source(data_source)
Python

Step 2: Chain Transformations (Shuffle, Map, Batch)

Now, apply transformations to the MapDataset. Each method returns a new MapDataset, allowing you to chain calls together declaratively.

# Example parsing function
def parse_and_transform(record):
    # Your logic to parse features, augment data, etc.
    return {"record": record}

BATCH_SIZE = 32

# Chain transformations
# The order of operations matters.
dataset = (
    dataset.shuffle(seed=42)
           .map(parse_and_transform)
           .batch(batch_size=BATCH_SIZE, drop_remainder=True)
)
Python
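The stub above simply wraps the raw bytes. For ArrayRecord files produced from TFRecord data or by tfds build, each record is a serialized tf.train.Example proto. The following is a hedged sketch of a more realistic parser for the cifar10 files generated earlier; the feature names ("image", "label") and the PNG encoding follow the TFDS cifar10 schema and should be adapted to your own data:

import numpy as np
import tensorflow as tf  # used here only to parse the proto and decode the image

def parse_and_transform(record_bytes):
    # Each record is a serialized tf.train.Example (assumed schema: cifar10).
    example = tf.train.Example.FromString(record_bytes)
    features = example.features.feature
    image = tf.io.decode_png(features["image"].bytes_list.value[0]).numpy()
    label = features["label"].int64_list.value[0]
    return {
        "image": image.astype(np.float32) / 255.0,  # scale pixels to [0, 1]
        "label": np.int64(label),
    }
Python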

Step 3: Create and use the DatasetIterator

Finally, create an iterator from your fully defined dataset to loop through the data in your training script.

# Create the stateful iterator
data_iterator = iter(dataset)

# You can now loop over the data
for batch in data_iterator:
    # Your training step with the batch...
    pass

    # For checkpoint saving/restoration, you can get/set the iterator's state
    # state = data_iterator.get_state()
    # data_iterator.set_state(state)
Python

Performance configuration settings

To prevent your data pipeline from becoming a bottleneck, you should use multiprocessing to load and preprocess data in parallel with model training. In the Dataset API, this is achieved by adding the .mp_prefetch() transformation to your pipeline.

This method starts a pool of worker processes to asynchronously prepare data batches in the background and stores them in a buffer, so they are ready the moment your training loop needs them.

Here's how to apply it:

# The full pipeline with performance tuning.
dataset = (
    grain.MapDataset.source(data_source)
    .shuffle(seed=42)
    .map(parse_and_transform)
    
    # Convert to an iterable dataset to apply prefetching.
    .to_iter_dataset()
    .batch(batch_size=BATCH_SIZE, drop_remainder=True)
    # Apply multiprocessing and prefetching.
    .mp_prefetch(
        grain.multiprocessing.MultiprocessingOptions(
            num_workers=16   # Number of parallel worker processes.
        )
    )
)

# Create the final iterator
data_iterator = iter(dataset)
Python
  • num_workers: This specifies the number of parallel child processes to use for data loading. If you notice your accelerator is often idle waiting for data, increasing this value can significantly improve throughput. The optimal number depends on the CPU cores available on your machine and the complexity of your map function.


Explore further

Want to dive deeper and start building? Check out the official documentation and source code for the technologies discussed in this guide.



Real-world example: Large-scale LLM training

The performant and deterministic data pipelines built with Grain and ArrayRecord are critical for large-scale model training. A prime example is MaxText, a high-performance, open-source large language model (LLM) implementation written in JAX. MaxText leverages these exact data pipeline techniques to efficiently feed data to large TPU and GPU clusters.