GPU や TPU などの強力なアクセラレータで大規模なモデルをトレーニングする場合、最も避けたいのはアクセラレータがデータ待ちでアイドル状態になることです。システム全体の速度は最も遅い部分によって決まり、多くの場合、そのボトルネックはデータ入力パイプラインです。そのため、大規模な機械学習には効率的で再現性の高いデータ パイプラインが不可欠です。このガイドでは、Grain(JAX 用の柔軟なデータ読み込みライブラリ)と ArrayRecord(非常に効率的なファイル形式)を使用して堅牢で高性能なデータ パイプラインを構築することで、この課題を解決する方法を説明します。
Grain は軽量でオープンソースのデータ読み込みライブラリで、JAX ベースのワークロードにおけるこの問題を解決するために特別に設計されています。データの読み込み、前処理、モデルへのフィードを効率的に行うことで、ハードウェアのパフォーマンスを最大限に高めます。
Grain は、パフォーマンス、再現性、柔軟性という理念のもとに開発されました。その主なメリットは次のとおりです。
.mp_prefetch() メソッドによる)を使用して、データの読み込みと変換を並列実行し、モデル用に準備されたデータのバッファを常に確保します。これにより、アクセラレータのフル稼働状態を維持し、トレーニング時間を最小限に抑えます。.shuffle()、.map()、.batch() などの変換を柔軟に追加できます。この宣言型スタイルにより、データ パイプラインの理解、変更、保守が容易になります。TFRecord はよく知られた形式規格ですが、そのシーケンシャルな性質上、完全なグローバル シャッフルは実現できません。ArrayRecord はこの問題を解決するために特別に設計された最新のファイル形式であり、データ効率の新たな可能性を引き出します。
ArrayRecord の高いパフォーマンスは、Google の Riegeli ファイル形式に基づいたコア設計に根ざしています。この構造は、大規模なデータ処理において次のような重要なメリットを提供します。
2. 大規模な並列処理: レコードはデータチャンクにグループ化されます。この構造は、本質的に並列読み取りを前提に設計されており、複数のプロセスが同じファイルの異なるチャンクを同時に読み取ることで、読み取りスループットを劇的に向上させます。
3. 優れたパフォーマンス: このインデックス化とチャンク化を採用した設計により、ベンチマークでは ArrayRecord が従来の形式よりも桁違いに高い読み取りスループットを達成できることが示されており、今日の大規模なデータセットに最適です。
4. スマートなデータの整合性: この形式は、冗長なチェックを追加するのではなく、基盤となるクラウド ストレージ システム(Google Cloud Storage など)の強力なエラー訂正機能を活用することで、データの整合性をインテリジェントに処理します。これにより、不必要なパフォーマンス オーバーヘッドを発生させることなく、データ破損に対する堅牢な保護を提供します。
ArrayRecord の機能は、Grain のような最新のデータローダに必要な高度な機能を直接実現します。
最も重要なメリットは、真に決定論的なグローバル シャッフルを実現できることです。あらゆるレコードに即座にアクセスできるため、データローダはトレーニングが発生するとすぐにデータセット内で完全にランダム化されたインデックスを生成し、その特定の順序でデータを取得できます。TFRecord などのシーケンシャルな形式では計算上不可能なこの機能は、再現性の高い研究や最適なモデル トレーニングに不可欠です。
以下は、ArrayRecord と TFRecord の主要な機能に関する詳細な比較です。
2. ランダム アクセス
3. グローバル シャッフル
4. 並列 I/O
5. 統合
6. 主なユースケース
データセットを変換する方法は、TensorFlow データセット(TFDS)カタログに登録された標準のデータセットであるか、独自のカスタム データセットであるかによって異なります。
cifar10 や imagenet2012 などのよく知られているデータセットを使用している場合は、tfds コマンドライン ツールが最も簡単な方法です。
前提条件: TensorFlow データセットをインストールすること
pip install -q --upgrade tfds-nightly
tfds build CLI の使用
このコマンドはソースデータをダウンロードし、準備ロジックを実行して、出力を目的の形式で保存します。
# Generate the 'cifar10' dataset in ArrayRecord format
tfds build cifar10 --file_format=array_record
生成された ArrayRecord ファイルは ~/tensorflow_datasets/ ディレクトリに保存され、すぐに使用できます。
独自のカスタム TFRecord データセットを大規模に変換するには、Apache Beam の使用をおすすめします。array_record ライブラリは、この変換を非常にシンプルでスケーラブルにするパッケージ済みの Beam パイプラインを提供します。この方法では、Google Cloud Dataflow などのサービスを使用して多くのワーカーに処理を分散できるため、大規模なデータセットに特におすすめです。
前提条件: Apache Beam と Array Record Beam SDK をインストールすること
pip install -q apache-beam
pip install -q array-record-beam-sdk
パッケージ済み変換パイプラインの使用
array_record.beam.pipelines モジュールには、変換プロセス全体を処理する専用ユーティリティの convert_tf_to_arrayrecord_disk_match_shards 関数が含まれています。これは TFRecord ファイルを読み取り、対応するシャーディングされた ArrayRecord データセットを書き込みます。
Python スクリプトでの使用方法は次のとおりです。
from apache_beam.options import pipeline_options
from array_record.beam.pipelines import convert_tf_to_arrayrecord_disk_match_shards
# 1. Define your input and output patterns.
# This example uses Google Cloud Storage (GCS) paths, which is common for large datasets.
input_pattern = 'gs://your-gcs-bucket/path/to/records-*.tfrecord'
output_path = 'gs://your-gcs-bucket/path/to/converted-records'
# Arguments dictionary for the conversion function.
args = {
'input': input_pattern,
'output': output_path,
}
# 2. Configure pipeline options for execution.
# To run locally on your machine (for smaller datasets or testing):
# No options are needed; the local runner is used by default.
local_pipeline_options = pipeline_options.PipelineOptions()
# To run at scale on Google Cloud Dataflow (for large datasets):
# Uncomment the following lines and fill in your project details.
#
# dataflow_pipeline_options = pipeline_options.PipelineOptions(
# runner='DataflowRunner',
# project='your-gcp-project-id',
# region='your-gcp-region',
# # A requirements.txt file may be needed for dependencies on Dataflow workers.
# # requirements_file='requirements.txt',
# temp_location='gs://your-gcs-bucket/path/to/temp'
# )
# 3. Define and run the main execution logic.
def main():
print("Starting TFRecord to ArrayRecord conversion...")
convert_tf_to_arrayrecord_disk_match_shards(
args=args,
# Pass the appropriate options here.
# Use `local_pipeline_options` for local runs.
# Use `dataflow_pipeline_options` for cloud runs.
pipeline_options=local_pipeline_options,
).run()
print(f"Conversion complete. ArrayRecord files written to '{output_path}'.")
if __name__ == '__main__':
main()
このアプローチはこのタスク専用に設計されたテスト済みの高レベル API であり、出力シャードと入力シャードのマッチングといった詳細を自動的に処理するため、手動でパイプラインを作成するよりも強力で堅牢です。
データを ArrayRecord 形式に変換したら、Grain の Dataset API を使用して高性能な入力パイプラインを定義できます。このプロセスでは、まずソースを作成し、次に変換メソッドを連結します。
まず、ArrayRecord ファイルを指定して MapDataset を作成します。
import grain
# Path to your generated ArrayRecord files
file_paths = ["~/tensorflow_datasets/cifar10/3.0.2/cifar10-train.array_record-00000-of-00001"]
# Create a data source
data_source = grain.sources.ArrayRecordDataSource(file_paths)
# Create a MapDataset from the source
dataset = grain.MapDataset.source(data_source)
次に、変換を MapDataset に適用します。各メソッドは新しい MapDataset を返し、宣言的に互いの呼び出しを連鎖できます。
# Example parsing function
def parse_and_transform(record):
# Your logic to parse features, augment data, etc.
return {"record": record}
BATCH_SIZE = 32
# Chain transformations
# The order of operations matters.
dataset = (
dataset.shuffle(seed=42)
.map(parse_and_transform)
.batch(batch_size=BATCH_SIZE, drop_remainder=True)
)
DatasetIterator を作成して使用する最後に、完全に定義されたデータセットからイテレータを作成して、トレーニング スクリプト内のデータをループします。
# Create the stateful iterator
data_iterator = iter(dataset)
# You can now loop over the data
for batch in data_iterator:
# Your training step with the batch...
pass
# For checkpoint saving/restoration, you can get/set the iterator's state
# state = data_iterator.get_state()
# data_iterator.set_state(state)
データ パイプラインがボトルネックになるのを防ぐには、マルチプロセスを使用し、モデル トレーニングと並行してデータの読み込みと前処理を行うようにする必要があります。Dataset API では、.mp_prefetch() 変換をパイプラインに追加することでこれを実現します。
このメソッドでは、ワーカー プロセスのプールを起動し、バックグラウンドで非同期的にデータバッチを準備してバッファに格納するため、トレーニング ループに必要な際に即座に利用できるようになります。
適用方法は次のとおりです。
# The full pipeline with performance tuning.
dataset = (
grain.MapDataset.source(data_source)
.shuffle(seed=42)
.map(parse_and_transform)
# Convert to an iterable dataset to apply prefetching.
.to_iter_dataset()
.batch(batch_size=BATCH_SIZE, drop_remainder=True)
# Apply multiprocessing and prefetching.
.mp_prefetch(
grain.multiprocessing.MultiprocessingOptions(
num_workers=16 # Number of parallel worker processes.
)
)
)
# Create the final iterator
data_iterator = iter(dataset)
num_workers: データの読み込みに使用する並列の子プロセス数を指定します。アクセラレータがデータ待ちで頻繁にアイドル状態になる場合は、この値を増やすことでスループットが大幅に改善する可能性があります。最適な数は、マシンで利用可能な CPU コア数と map 関数の複雑さによって異なります。さらに詳しく学び、ビルドを開始したい場合は、このガイドで説明しているテクノロジーの公式ドキュメントとソースコードをご覧ください。
Grain と ArrayRecord を使用してビルドされた高性能かつ決定論的なデータ パイプラインは、大規模モデル トレーニングに不可欠です。その好例が、JAX で書かれた高性能なオープンソースの大規模言語モデルである MaxText です。MaxText はまさにこれらのデータ パイプライン手法を活用して、大規模な TPU および GPU クラスタにデータを効率的に供給します。