Grain と ArrayRecord を使用した高性能データ パイプラインの構築

2025年10月7日
Jiyang Kang Technical Program Manager
Shivaji Dutta Field Solutions Architect
Ihor Indyk Software Engineer
Felix Chern Software Engineer

GPU や TPU などの強力なアクセラレータで大規模なモデルをトレーニングする場合、最も避けたいのはアクセラレータがデータ待ちでアイドル状態になることです。システム全体の速度は最も遅い部分によって決まり、多くの場合、そのボトルネックはデータ入力パイプラインです。そのため、大規模な機械学習には効率的で再現性の高いデータ パイプラインが不可欠です。このガイドでは、Grain(JAX 用の柔軟なデータ読み込みライブラリ)と ArrayRecord(非常に効率的なファイル形式)を使用して堅牢で高性能なデータ パイプラインを構築することで、この課題を解決する方法を説明します。


コア コンポーネントの理解

Grain: JAX 用の高性能データローダ

Grain は軽量でオープンソースのデータ読み込みライブラリで、JAX ベースのワークロードにおけるこの問題を解決するために特別に設計されています。データの読み込み、前処理、モデルへのフィードを効率的に行うことで、ハードウェアのパフォーマンスを最大限に高めます。

Grain を使用する理由

Grain は、パフォーマンス、再現性、柔軟性という理念のもとに開発されました。その主なメリットは次のとおりです。

  • 優れたパフォーマンス: Grain はスピードを重視して設計されています。効率的なマルチプロセス(.mp_prefetch() メソッドによる)を使用して、データの読み込みと変換を並列実行し、モデル用に準備されたデータのバッファを常に確保します。これにより、アクセラレータのフル稼働状態を維持し、トレーニング時間を最小限に抑えます。

  • 決定論と再現性の保証: Grain は、信頼できる研究に欠かせない、完全な再現性を提供します。シンプルなシードを設定することで、データが常に同じ方法でシャッフルされるようになります。重要なのは、そのデータ イテレータがステートフルであり、チェックポイントとして設定できる点です。つまり、トレーニング ジョブが中断またはプリエンプトされた場合でも、データ ストリーム内のまったく同じポイントから再開できます。

  • 直感的な宣言型 API: シンプルで読みやすいメソッドを連結することでデータ パイプラインを定義します。MapDataset ソースから始めて、.shuffle().map().batch() などの変換を柔軟に追加できます。この宣言型スタイルにより、データ パイプラインの理解、変更、保守が容易になります。

  • 完全なグローバル シャッフルの実現: モデルからベスト パフォーマンスを引き出すには、データを効果的にシャッフルする必要があります。ArrayRecord などのランダム アクセスをサポートするファイル形式と組み合わせることで、Grain はホストメモリに収まらない場合であっても、データセット全体に対して完全なグローバル シャッフルを実行できます。これは、多くの場合、他のデータローダや形式では計算上実現不可能な高度な機能です。


ArrayRecord の概要と使用する理由

TFRecord はよく知られた形式規格ですが、そのシーケンシャルな性質上、完全なグローバル シャッフルは実現できません。ArrayRecord はこの問題を解決するために特別に設計された最新のファイル形式であり、データ効率の新たな可能性を引き出します。

ArrayRecord File Layout

仕組み: スピードと並列処理を重視した設計

ArrayRecord の高いパフォーマンスは、Google の Riegeli ファイル形式に基づいたコア設計に根ざしています。この構造は、大規模なデータ処理において次のような重要なメリットを提供します。

  1. 効率的なランダム アクセス: ArrayRecord には、各レコードを正確な位置にマッピングするメタデータ インデックスが組み込まれています。これは、データセット内の任意のレコードに即座に直接アクセスできるようにする重要な設計上の選択であり、ファイルを最初から読み取る必要が完全になくなります。


2. 大規模な並列処理: レコードはデータチャンクにグループ化されます。この構造は、本質的に並列読み取りを前提に設計されており、複数のプロセスが同じファイルの異なるチャンクを同時に読み取ることで、読み取りスループットを劇的に向上させます。


3. 優れたパフォーマンス: このインデックス化とチャンク化を採用した設計により、ベンチマークでは ArrayRecord が従来の形式よりも桁違いに高い読み取りスループットを達成できることが示されており、今日の大規模なデータセットに最適です。


4. スマートなデータの整合性: この形式は、冗長なチェックを追加するのではなく、基盤となるクラウド ストレージ システム(Google Cloud Storage など)の強力なエラー訂正機能を活用することで、データの整合性をインテリジェントに処理します。これにより、不必要なパフォーマンス オーバーヘッドを発生させることなく、データ破損に対する堅牢な保護を提供します。


ArrayRecord を使用する理由

ArrayRecord の機能は、Grain のような最新のデータローダに必要な高度な機能を直接実現します。

最も重要なメリットは、真に決定論的なグローバル シャッフルを実現できることです。あらゆるレコードに即座にアクセスできるため、データローダはトレーニングが発生するとすぐにデータセット内で完全にランダム化されたインデックスを生成し、その特定の順序でデータを取得できます。TFRecord などのシーケンシャルな形式では計算上不可能なこの機能は、再現性の高い研究や最適なモデル トレーニングに不可欠です。


ArrayRecord と TFRecord の詳細な比較

以下は、ArrayRecord と TFRecord の主要な機能に関する詳細な比較です。

  1. 構造

  • ArrayRecord は、Google の Riegeli ファイル形式をベースとして構築されており、高速デコード、データの整合性、強力な圧縮に重点を置いたレコードのシーケンス保存を目的に設計されています。レコードをチャンクにグループ化し、ファイルの末尾にメタデータ インデックスを含めます。

  • TFRecordバイナリ レコードのシーケンスであり、各レコードは通常 tf.train.Example プロトコル バッファです。


2. ランダム アクセス

  • ArrayRecord は、ネイティブで効率的なランダム アクセスを提供します。ファイル構造にはレコード位置のインデックスが組み込まれているため、ファイル全体を読み取る必要がなく、インデックスを使用して任意のレコードに直接かつ高速にアクセスできます。

  • 一方、TFRecord にはネイティブのランダム アクセス機能がありません。ストリーミング データ用に最適化されたシーケンシャル形式であるため、特定のレコードにアクセスするにはファイルの先頭から反復処理する必要があります。


3. グローバル シャッフル

  • ArrayRecord を使用すると、完全なグローバル シャッフルが可能になります。効率的なランダム アクセスのおかげで、Grain のようなデータローダはシャッフルされた順序でインデックスを生成し、その場でレコードを読み取れます。

  • TFRecord では完全なグローバル シャッフルを実現することは困難です。「グローバル」シャッフルと言われるものの多くは、シャーディングされたファイル名のリストをシャッフルした後に小さなメモリバッファ内でレコードをシャッフルするなど、近似に依存しています。これは完全なグローバル シャッフルではありません。


4. 並列 I/O

  • ArrayRecord は並列 I/O をネイティブにサポートしています。ArrayRecord ファイルの内部チャンク構造のおかげで、複数のプロセスが同一ファイルの異なる部分から並列に読み取ることが本質的に容易になり、データ マネジメントが簡素化します。

  • TFRecord は並列読み取りをサポートしていますが、通常、データセットを多数の小さな TFRecord ファイルにシャーディングし、異なるワーカーが異なるファイルから読み取ることで実現します。結果として、管理するファイル数が膨大になる可能性があります。


5. 統合

  • ArrayRecord は、高性能 I/O 向けに設計されており、Grain のような JAX ベースのローダとシームレスに連携します。また、tfds.data_source を介して TensorFlow エコシステム内でも使用できます。

  • TFRecordTensorFlow の tf.data エコシステムと緊密に統合されています。


6. 主なユースケース

  • ArrayRecord は、特に決定論と完全なグローバル シャッフルが必要な場合(JAX / TPU ワークロードなど)、パフォーマンスが重要な機械学習の高スループット データ読み込みに最適です。

  • TFRecord は TensorFlow の汎用的な大規模データ ストレージに適しており、シーケンシャル読み取りに最適化されています。


TFRecord データセットを ArrayRecord に変換する方法

データセットを変換する方法は、TensorFlow データセット(TFDS)カタログに登録された標準のデータセットであるか、独自のカスタム データセットであるかによって異なります。

方法 1: TFDS カタログの標準データセットの場合

cifar10imagenet2012 などのよく知られているデータセットを使用している場合は、tfds コマンドライン ツールが最も簡単な方法です。

前提条件: TensorFlow データセットをインストールすること

pip install -q --upgrade tfds-nightly
Shell

tfds build CLI の使用

このコマンドはソースデータをダウンロードし、準備ロジックを実行して、出力を目的の形式で保存します。

# Generate the 'cifar10' dataset in ArrayRecord format
tfds build cifar10 --file_format=array_record
Shell

生成された ArrayRecord ファイルは ~/tensorflow_datasets/ ディレクトリに保存され、すぐに使用できます。

方法 2: カスタムまたは独自の TFRecord データセットの場合

独自のカスタム TFRecord データセットを大規模に変換するには、Apache Beam の使用をおすすめします。array_record ライブラリは、この変換を非常にシンプルでスケーラブルにするパッケージ済みの Beam パイプラインを提供します。この方法では、Google Cloud Dataflow などのサービスを使用して多くのワーカーに処理を分散できるため、大規模なデータセットに特におすすめです。

前提条件: Apache Beam と Array Record Beam SDK をインストールすること

pip install -q apache-beam
pip install -q array-record-beam-sdk
Shell

パッケージ済み変換パイプラインの使用

array_record.beam.pipelines モジュールには、変換プロセス全体を処理する専用ユーティリティの convert_tf_to_arrayrecord_disk_match_shards 関数が含まれています。これは TFRecord ファイルを読み取り、対応するシャーディングされた ArrayRecord データセットを書き込みます。

Python スクリプトでの使用方法は次のとおりです。

from apache_beam.options import pipeline_options
from array_record.beam.pipelines import convert_tf_to_arrayrecord_disk_match_shards
 
# 1. Define your input and output patterns.
# This example uses Google Cloud Storage (GCS) paths, which is common for large datasets.
input_pattern = 'gs://your-gcs-bucket/path/to/records-*.tfrecord'
output_path = 'gs://your-gcs-bucket/path/to/converted-records'
 
# Arguments dictionary for the conversion function.
args = {
    'input': input_pattern,
    'output': output_path,
}
 
# 2. Configure pipeline options for execution.
 
# To run locally on your machine (for smaller datasets or testing):
# No options are needed; the local runner is used by default.
local_pipeline_options = pipeline_options.PipelineOptions()
 
 
# To run at scale on Google Cloud Dataflow (for large datasets):
# Uncomment the following lines and fill in your project details.
#
# dataflow_pipeline_options = pipeline_options.PipelineOptions(
#     runner='DataflowRunner',
#     project='your-gcp-project-id',
#     region='your-gcp-region',
#     # A requirements.txt file may be needed for dependencies on Dataflow workers.
#     # requirements_file='requirements.txt',
#     temp_location='gs://your-gcs-bucket/path/to/temp'
# )
 
 
# 3. Define and run the main execution logic.
def main():
  print("Starting TFRecord to ArrayRecord conversion...")
  convert_tf_to_arrayrecord_disk_match_shards(
      args=args,
      # Pass the appropriate options here.
      # Use `local_pipeline_options` for local runs.
      # Use `dataflow_pipeline_options` for cloud runs.
      pipeline_options=local_pipeline_options,
  ).run()
  print(f"Conversion complete. ArrayRecord files written to '{output_path}'.")
 
if __name__ == '__main__':
  main()
Python

このアプローチはこのタスク専用に設計されたテスト済みの高レベル API であり、出力シャードと入力シャードのマッチングといった詳細を自動的に処理するため、手動でパイプラインを作成するよりも強力で堅牢です。


Grain と ArrayRecord パイプラインの構築: 概念的なチュートリアル

データを ArrayRecord 形式に変換したら、Grain の Dataset API を使用して高性能な入力パイプラインを定義できます。このプロセスでは、まずソースを作成し、次に変換メソッドを連結します。

ステップ 1: データソースから MapDataset を作成する

まず、ArrayRecord ファイルを指定して MapDataset を作成します。

import grain 
 
# Path to your generated ArrayRecord files
file_paths = ["~/tensorflow_datasets/cifar10/3.0.2/cifar10-train.array_record-00000-of-00001"]
 
# Create a data source
data_source = grain.sources.ArrayRecordDataSource(file_paths)
 
# Create a MapDataset from the source
dataset = grain.MapDataset.source(data_source)
Python

ステップ 2: 変換を連結する(シャッフル、マッピング、バッチ処理)

次に、変換を MapDataset に適用します。各メソッドは新しい MapDataset を返し、宣言的に互いの呼び出しを連鎖できます。

# Example parsing function
def parse_and_transform(record):
    # Your logic to parse features, augment data, etc.
    return {"record": record}
 
BATCH_SIZE = 32
 
# Chain transformations
# The order of operations matters.
dataset = (
    dataset.shuffle(seed=42)
           .map(parse_and_transform)
           .batch(batch_size=BATCH_SIZE, drop_remainder=True)
)
Python

ステップ 3: DatasetIterator を作成して使用する

最後に、完全に定義されたデータセットからイテレータを作成して、トレーニング スクリプト内のデータをループします。

# Create the stateful iterator
data_iterator = iter(dataset)
 
# You can now loop over the data
for batch in data_iterator:
    # Your training step with the batch...
    pass
 
    # For checkpoint saving/restoration, you can get/set the iterator's state
    # state = data_iterator.get_state()
    # data_iterator.set_state(state)
Python

パフォーマンス設定

データ パイプラインがボトルネックになるのを防ぐには、マルチプロセスを使用し、モデル トレーニングと並行してデータの読み込みと前処理を行うようにする必要があります。Dataset API では、.mp_prefetch() 変換をパイプラインに追加することでこれを実現します。

このメソッドでは、ワーカー プロセスのプールを起動し、バックグラウンドで非同期的にデータバッチを準備してバッファに格納するため、トレーニング ループに必要な際に即座に利用できるようになります。

適用方法は次のとおりです。

# The full pipeline with performance tuning.
dataset = (
    grain.MapDataset.source(data_source)
    .shuffle(seed=42)
    .map(parse_and_transform)
 
    # Convert to an iterable dataset to apply prefetching.
    .to_iter_dataset()
    .batch(batch_size=BATCH_SIZE, drop_remainder=True)
    # Apply multiprocessing and prefetching.
    .mp_prefetch(
        grain.multiprocessing.MultiprocessingOptions(
            num_workers=16   # Number of parallel worker processes.
        )
    )
)
 
# Create the final iterator
data_iterator = iter(dataset)
Python
  • num_workers: データの読み込みに使用する並列の子プロセス数を指定します。アクセラレータがデータ待ちで頻繁にアイドル状態になる場合は、この値を増やすことでスループットが大幅に改善する可能性があります。最適な数は、マシンで利用可能な CPU コア数と map 関数の複雑さによって異なります。


さらに詳しい情報

さらに詳しく学び、ビルドを開始したい場合は、このガイドで説明しているテクノロジーの公式ドキュメントとソースコードをご覧ください。

基盤テクノロジー


実際の例: 大規模な LLM トレーニング

Grain と ArrayRecord を使用してビルドされた高性能かつ決定論的なデータ パイプラインは、大規模モデル トレーニングに不可欠です。その好例が、JAX で書かれた高性能なオープンソースの大規模言語モデルである MaxText です。MaxText はまさにこれらのデータ パイプライン手法を活用して、大規模な TPU および GPU クラスタにデータを効率的に供給します。