Grain と ArrayRecord を使用した高性能データ パイプラインの構築

2025年10月3日
Jiyang Kang Technical Program Manager
Shivaji Dutta Field Solutions Architect
Ihor Indyk Software Engineer
Felix Chern Software Engineer

GPU や TPU などの強力なアクセラレータで大規模なモデルをトレーニングする場合、最も避けたいのはアクセラレータがデータ待ちでアイドル状態になることです。システム全体の速度は最も遅い部分によって決まり、多くの場合、そのボトルネックはデータ入力パイプラインです。そのため、大規模な機械学習には効率的で再現性の高いデータ パイプラインが不可欠です。このガイドでは、Grain(JAX 用の柔軟なデータ読み込みライブラリ)と ArrayRecord(非常に効率的なファイル形式)を使用して堅牢で高性能なデータ パイプラインを構築することで、この課題を解決する方法を説明します。


コア コンポーネントの理解

Grain: JAX 用の高性能データローダ

Grain は軽量でオープンソースのデータ読み込みライブラリで、JAX ベースのワークロードにおけるこの問題を解決するために特別に設計されています。データの読み込み、前処理、モデルへのフィードを効率的に行うことで、ハードウェアのパフォーマンスを最大限に高めます。

Grain を使用する理由

Grain は、パフォーマンス、再現性、柔軟性という理念のもとに開発されました。その主なメリットは次のとおりです。

  • Exceptional performance: Grain is built for speed. It uses efficient multiprocessing (via the .mp_prefetch() method) to run data loading and transformations in parallel, ensuring that a buffer of prepared data is always ready for your model. This keeps your accelerators saturated and minimizes training time.

  • Guaranteed determinism and reproducibility: Grain provides full reproducibility, which is critical for credible research. By setting a simple seed, you ensure the data is always shuffled the same way. Crucially, its data iterators are stateful and can be checkpointed. This means if your training job is interrupted or preempted, you can restart from the exact same point in the data stream.

  • 直感的な宣言型 API: シンプルで読みやすいメソッドを連結することでデータ パイプラインを定義します。MapDataset ソースから始めて、.shuffle().map().batch() などの変換を柔軟に追加できます。この宣言型スタイルにより、データ パイプラインの理解、変更、保守が容易になります。

  • Unlocking true global shuffling: To get the best performance from your models, you need to shuffle your data effectively. When paired with a file format that supports random access, like ArrayRecord, Grain can perform a true global shuffle across your entire dataset, even when it doesn’t fit into host memory. This is a powerful feature that is often computationally impractical with other data loaders and formats.


ArrayRecord の概要と使用する理由

TFRecord はよく知られた形式規格ですが、そのシーケンシャルな性質上、完全なグローバル シャッフルは実現できません。ArrayRecord はこの問題を解決するために特別に設計された最新のファイル形式であり、データ効率の新たな可能性を引き出します。

ArrayRecord File Layout

仕組み: スピードと並列処理を重視した設計

ArrayRecord の高いパフォーマンスは、Google の Riegeli ファイル形式に基づいたコア設計に根ざしています。この構造は、大規模なデータ処理において次のような重要なメリットを提供します。

  1. Efficient random access: ArrayRecord features a built-in metadata index that maps every record to its precise location. This is the key design choice that enables instant, direct access to any record in the dataset, completely avoiding the need to read a file from the beginning.


2. Massive parallelism: Records are grouped into data chunks. This structure is inherently designed to be read in parallel, allowing multiple processes to read different chunks of the same file simultaneously to dramatically increase read throughput.


3. Exceptional performance: As a result of this indexed and chunked design, benchmarks show ArrayRecord can achieve a read throughput an order of magnitude higher than traditional formats, making it ideal for today's massive datasets.


4. Smart data integrity: The format handles data integrity intelligently by leveraging the powerful error correction in underlying cloud storage systems (like Google Cloud Storage) rather than adding redundant checks. This provides robust protection against corruption without unnecessary performance overhead.


ArrayRecord を使用する理由

ArrayRecord の機能は、Grain のような最新のデータローダに必要な高度な機能を直接実現します。

The most important benefit is achieving true, deterministic global shuffling. Because any record can be accessed instantly, a data loader can generate perfectly randomized indices in the dataset on the fly as the training happens and then fetch data in that specific order. This capability, which is computationally impractical with sequential formats like TFRecord, is vital for reproducible research and optimal model training.


ArrayRecord と TFRecord の詳細な比較

以下は、ArrayRecord と TFRecord の主要な機能に関する詳細な比較です。

  1. 構造

  • ArrayRecord は、Google の Riegeli ファイル形式をベースとして構築されており、高速デコード、データの整合性、強力な圧縮に重点を置いたレコードのシーケンス保存を目的に設計されています。レコードをチャンクにグループ化し、ファイルの末尾にメタデータ インデックスを含めます。

  • TFRecordバイナリ レコードのシーケンスであり、各レコードは通常 tf.train.Example プロトコル バッファです。


2. ランダム アクセス

  • ArrayRecord は、ネイティブで効率的なランダム アクセスを提供します。ファイル構造にはレコード位置のインデックスが組み込まれているため、ファイル全体を読み取る必要がなく、インデックスを使用して任意のレコードに直接かつ高速にアクセスできます。

  • 一方、TFRecord にはネイティブのランダム アクセス機能がありません。ストリーミング データ用に最適化されたシーケンシャル形式であるため、特定のレコードにアクセスするにはファイルの先頭から反復処理する必要があります。


3. グローバル シャッフル

  • ArrayRecord を使用すると、完全なグローバル シャッフルが可能になります。効率的なランダム アクセスのおかげで、Grain のようなデータローダはシャッフルされた順序でインデックスを生成し、その場でレコードを読み取れます。

  • TFRecord では完全なグローバル シャッフルを実現することは困難です。「グローバル」シャッフルと言われるものの多くは、シャーディングされたファイル名のリストをシャッフルした後に小さなメモリバッファ内でレコードをシャッフルするなど、近似に依存しています。これは完全なグローバル シャッフルではありません。


4. 並列 I/O

  • ArrayRecord は並列 I/O をネイティブにサポートしています。ArrayRecord ファイルの内部チャンク構造のおかげで、複数のプロセスが同一ファイルの異なる部分から並列に読み取ることが本質的に容易になり、データ マネジメントが簡素化します。

  • TFRecord は並列読み取りをサポートしていますが、通常、データセットを多数の小さな TFRecord ファイルにシャーディングし、異なるワーカーが異なるファイルから読み取ることで実現します。結果として、管理するファイル数が膨大になる可能性があります。


5. 統合

  • ArrayRecord は、高性能 I/O 向けに設計されており、Grain のような JAX ベースのローダとシームレスに連携します。また、tfds.data_source を介して TensorFlow エコシステム内でも使用できます。

  • TFRecordTensorFlow の tf.data エコシステムと緊密に統合されています。


6. 主なユースケース

  • ArrayRecord is ideal for high-throughput data loading for performance-critical machine learning, especially where determinism and true global shuffling are required (e.g., JAX/TPU workloads).

  • TFRecord is suited for general-purpose, large-scale data storage for TensorFlow and is optimized for sequential reads.


TFRecord データセットを ArrayRecord に変換する方法

データセットを変換する方法は、TensorFlow データセット(TFDS)カタログに登録された標準のデータセットであるか、独自のカスタム データセットであるかによって異なります。

方法 1: TFDS カタログの標準データセットの場合

cifar10imagenet2012 などのよく知られているデータセットを使用している場合は、tfds コマンドライン ツールが最も簡単な方法です。

前提条件: TensorFlow データセットをインストールすること

pip install -q --upgrade tfds-nightly
Shell

tfds build CLI の使用

このコマンドはソースデータをダウンロードし、準備ロジックを実行して、出力を目的の形式で保存します。

# Generate the 'cifar10' dataset in ArrayRecord format
tfds build cifar10 --file_format=array_record
Shell

生成された ArrayRecord ファイルは ~/tensorflow_datasets/ ディレクトリに保存され、すぐに使用できます。

方法 2: カスタムまたは独自の TFRecord データセットの場合

独自のカスタム TFRecord データセットを大規模に変換するには、Apache Beam の使用をおすすめします。array_record ライブラリは、この変換を非常にシンプルでスケーラブルにするパッケージ済みの Beam パイプラインを提供します。この方法では、Google Cloud Dataflow などのサービスを使用して多くのワーカーに処理を分散できるため、大規模なデータセットに特におすすめです。

前提条件: Apache Beam と Array Record Beam SDK をインストールすること

pip install -q apache-beam
pip install -q array-record-beam-sdk
Shell

パッケージ済み変換パイプラインの使用

array_record.beam.pipelines モジュールには、変換プロセス全体を処理する専用ユーティリティの convert_tf_to_arrayrecord_disk_match_shards 関数が含まれています。これは TFRecord ファイルを読み取り、対応するシャーディングされた ArrayRecord データセットを書き込みます。

Python スクリプトでの使用方法は次のとおりです。

from apache_beam.options import pipeline_options
from array_record.beam.pipelines import convert_tf_to_arrayrecord_disk_match_shards
 
# 1. Define your input and output patterns.
# This example uses Google Cloud Storage (GCS) paths, which is common for large datasets.
input_pattern = 'gs://your-gcs-bucket/path/to/records-*.tfrecord'
output_path = 'gs://your-gcs-bucket/path/to/converted-records'
 
# Arguments dictionary for the conversion function.
args = {
    'input': input_pattern,
    'output': output_path,
}
 
# 2. Configure pipeline options for execution.
 
# To run locally on your machine (for smaller datasets or testing):
# No options are needed; the local runner is used by default.
local_pipeline_options = pipeline_options.PipelineOptions()
 
 
# To run at scale on Google Cloud Dataflow (for large datasets):
# Uncomment the following lines and fill in your project details.
#
# dataflow_pipeline_options = pipeline_options.PipelineOptions(
#     runner='DataflowRunner',
#     project='your-gcp-project-id',
#     region='your-gcp-region',
#     # A requirements.txt file may be needed for dependencies on Dataflow workers.
#     # requirements_file='requirements.txt',
#     temp_location='gs://your-gcs-bucket/path/to/temp'
# )
 
 
# 3. Define and run the main execution logic.
def main():
  print("Starting TFRecord to ArrayRecord conversion...")
  convert_tf_to_arrayrecord_disk_match_shards(
      args=args,
      # Pass the appropriate options here.
      # Use `local_pipeline_options` for local runs.
      # Use `dataflow_pipeline_options` for cloud runs.
      pipeline_options=local_pipeline_options,
  ).run()
  print(f"Conversion complete. ArrayRecord files written to '{output_path}'.")
 
if __name__ == '__main__':
  main()
Python

このアプローチはこのタスク専用に設計されたテスト済みの高レベル API であり、出力シャードと入力シャードのマッチングといった詳細を自動的に処理するため、手動でパイプラインを作成するよりも強力で堅牢です。


Grain と ArrayRecord パイプラインの構築: 概念的なチュートリアル

データを ArrayRecord 形式に変換したら、Grain の Dataset API を使用して高性能な入力パイプラインを定義できます。このプロセスでは、まずソースを作成し、次に変換メソッドを連結します。

ステップ 1: データソースから MapDataset を作成する

まず、ArrayRecord ファイルを指定して MapDataset を作成します。

import grain 
 
# Path to your generated ArrayRecord files
file_paths = ["~/tensorflow_datasets/cifar10/3.0.2/cifar10-train.array_record-00000-of-00001"]
 
# Create a data source
data_source = grain.sources.ArrayRecordDataSource(file_paths)
 
# Create a MapDataset from the source
dataset = grain.MapDataset.source(data_source)
Python

ステップ 2: 変換を連結する(シャッフル、マッピング、バッチ処理)

次に、変換を MapDataset に適用します。各メソッドは新しい MapDataset を返し、宣言的に互いの呼び出しを連鎖できます。

# Example parsing function
def parse_and_transform(record):
    # Your logic to parse features, augment data, etc.
    return {"record": record}
 
BATCH_SIZE = 32
 
# Chain transformations
# The order of operations matters.
dataset = (
    dataset.shuffle(seed=42)
           .map(parse_and_transform)
           .batch(batch_size=BATCH_SIZE, drop_remainder=True)
)
Python

Step 3: Create and use the DatasetIterator

最後に、完全に定義されたデータセットからイテレータを作成して、トレーニング スクリプト内のデータをループします。

# Create the stateful iterator
data_iterator = iter(dataset)
 
# You can now loop over the data
for batch in data_iterator:
    # Your training step with the batch...
    pass
 
    # For checkpoint saving/restoration, you can get/set the iterator's state
    # state = data_iterator.get_state()
    # data_iterator.set_state(state)
Python

パフォーマンス設定

データ パイプラインがボトルネックになるのを防ぐには、マルチプロセスを使用し、モデル トレーニングと並行してデータの読み込みと前処理を行うようにする必要があります。Dataset API では、.mp_prefetch() 変換をパイプラインに追加することでこれを実現します。

このメソッドでは、ワーカー プロセスのプールを起動し、バックグラウンドで非同期的にデータバッチを準備してバッファに格納するため、トレーニング ループに必要な際に即座に利用できるようになります。

適用方法は次のとおりです。

# The full pipeline with performance tuning.
dataset = (
    grain.MapDataset.source(data_source)
    .shuffle(seed=42)
    .map(parse_and_transform)
 
    # Convert to an iterable dataset to apply prefetching.
    .to_iter_dataset()
    .batch(batch_size=BATCH_SIZE, drop_remainder=True)
    # Apply multiprocessing and prefetching.
    .mp_prefetch(
        grain.multiprocessing.MultiprocessingOptions(
            num_workers=16   # Number of parallel worker processes.
        )
    )
)
 
# Create the final iterator
data_iterator = iter(dataset)
Python
  • num_workers: データの読み込みに使用する並列の子プロセス数を指定します。アクセラレータがデータ待ちで頻繁にアイドル状態になる場合は、この値を増やすことでスループットが大幅に改善する可能性があります。最適な数は、マシンで利用可能な CPU コア数と map 関数の複雑さによって異なります。


さらに詳しい情報

さらに詳しく学び、ビルドを開始したい場合は、このガイドで説明しているテクノロジーの公式ドキュメントとソースコードをご覧ください。

基盤テクノロジー


実際の例: 大規模な LLM トレーニング

Grain と ArrayRecord を使用してビルドされた高性能かつ決定論的なデータ パイプラインは、大規模モデル トレーニングに不可欠です。その好例が、JAX で書かれた高性能なオープンソースの大規模言語モデルである MaxText です。MaxText はまさにこれらのデータ パイプライン手法を活用して、大規模な TPU および GPU クラスタにデータを効率的に供給します。