GPU や TPU などの強力なアクセラレータで大規模なモデルをトレーニングする場合、最も避けたいのはアクセラレータがデータ待ちでアイドル状態になることです。システム全体の速度は最も遅い部分によって決まり、多くの場合、そのボトルネックはデータ入力パイプラインです。そのため、大規模な機械学習には効率的で再現性の高いデータ パイプラインが不可欠です。このガイドでは、Grain(JAX 用の柔軟なデータ読み込みライブラリ)と ArrayRecord(非常に効率的なファイル形式)を使用して堅牢で高性能なデータ パイプラインを構築することで、この課題を解決する方法を説明します。
Grain は軽量でオープンソースのデータ読み込みライブラリで、JAX ベースのワークロードにおけるこの問題を解決するために特別に設計されています。データの読み込み、前処理、モデルへのフィードを効率的に行うことで、ハードウェアのパフォーマンスを最大限に高めます。
Grain は、パフォーマンス、再現性、柔軟性という理念のもとに開発されました。その主なメリットは次のとおりです。
.mp_prefetch()
method) to run data loading and transformations in parallel, ensuring that a buffer of prepared data is always ready for your model. This keeps your accelerators saturated and minimizes training time..shuffle()
、.map()
、.batch()
などの変換を柔軟に追加できます。この宣言型スタイルにより、データ パイプラインの理解、変更、保守が容易になります。TFRecord はよく知られた形式規格ですが、そのシーケンシャルな性質上、完全なグローバル シャッフルは実現できません。ArrayRecord はこの問題を解決するために特別に設計された最新のファイル形式であり、データ効率の新たな可能性を引き出します。
ArrayRecord の高いパフォーマンスは、Google の Riegeli ファイル形式に基づいたコア設計に根ざしています。この構造は、大規模なデータ処理において次のような重要なメリットを提供します。
2. Massive parallelism: Records are grouped into data chunks. This structure is inherently designed to be read in parallel, allowing multiple processes to read different chunks of the same file simultaneously to dramatically increase read throughput.
3. Exceptional performance: As a result of this indexed and chunked design, benchmarks show ArrayRecord can achieve a read throughput an order of magnitude higher than traditional formats, making it ideal for today's massive datasets.
4. Smart data integrity: The format handles data integrity intelligently by leveraging the powerful error correction in underlying cloud storage systems (like Google Cloud Storage) rather than adding redundant checks. This provides robust protection against corruption without unnecessary performance overhead.
ArrayRecord の機能は、Grain のような最新のデータローダに必要な高度な機能を直接実現します。
The most important benefit is achieving true, deterministic global shuffling. Because any record can be accessed instantly, a data loader can generate perfectly randomized indices in the dataset on the fly as the training happens and then fetch data in that specific order. This capability, which is computationally impractical with sequential formats like TFRecord, is vital for reproducible research and optimal model training.
以下は、ArrayRecord と TFRecord の主要な機能に関する詳細な比較です。
2. ランダム アクセス
3. グローバル シャッフル
4. 並列 I/O
5. 統合
6. 主なユースケース
データセットを変換する方法は、TensorFlow データセット(TFDS)カタログに登録された標準のデータセットであるか、独自のカスタム データセットであるかによって異なります。
cifar10
や imagenet2012
などのよく知られているデータセットを使用している場合は、tfds コマンドライン ツールが最も簡単な方法です。
前提条件: TensorFlow データセットをインストールすること
pip install -q --upgrade tfds-nightly
tfds build CLI の使用
このコマンドはソースデータをダウンロードし、準備ロジックを実行して、出力を目的の形式で保存します。
# Generate the 'cifar10' dataset in ArrayRecord format
tfds build cifar10 --file_format=array_record
生成された ArrayRecord ファイルは ~/tensorflow_datasets/
ディレクトリに保存され、すぐに使用できます。
独自のカスタム TFRecord データセットを大規模に変換するには、Apache Beam の使用をおすすめします。array_record
ライブラリは、この変換を非常にシンプルでスケーラブルにするパッケージ済みの Beam パイプラインを提供します。この方法では、Google Cloud Dataflow などのサービスを使用して多くのワーカーに処理を分散できるため、大規模なデータセットに特におすすめです。
前提条件: Apache Beam と Array Record Beam SDK をインストールすること
pip install -q apache-beam
pip install -q array-record-beam-sdk
パッケージ済み変換パイプラインの使用
array_record.beam.pipelines
モジュールには、変換プロセス全体を処理する専用ユーティリティの convert_tf_to_arrayrecord_disk_match_shards
関数が含まれています。これは TFRecord ファイルを読み取り、対応するシャーディングされた ArrayRecord データセットを書き込みます。
Python スクリプトでの使用方法は次のとおりです。
from apache_beam.options import pipeline_options
from array_record.beam.pipelines import convert_tf_to_arrayrecord_disk_match_shards
# 1. Define your input and output patterns.
# This example uses Google Cloud Storage (GCS) paths, which is common for large datasets.
input_pattern = 'gs://your-gcs-bucket/path/to/records-*.tfrecord'
output_path = 'gs://your-gcs-bucket/path/to/converted-records'
# Arguments dictionary for the conversion function.
args = {
'input': input_pattern,
'output': output_path,
}
# 2. Configure pipeline options for execution.
# To run locally on your machine (for smaller datasets or testing):
# No options are needed; the local runner is used by default.
local_pipeline_options = pipeline_options.PipelineOptions()
# To run at scale on Google Cloud Dataflow (for large datasets):
# Uncomment the following lines and fill in your project details.
#
# dataflow_pipeline_options = pipeline_options.PipelineOptions(
# runner='DataflowRunner',
# project='your-gcp-project-id',
# region='your-gcp-region',
# # A requirements.txt file may be needed for dependencies on Dataflow workers.
# # requirements_file='requirements.txt',
# temp_location='gs://your-gcs-bucket/path/to/temp'
# )
# 3. Define and run the main execution logic.
def main():
print("Starting TFRecord to ArrayRecord conversion...")
convert_tf_to_arrayrecord_disk_match_shards(
args=args,
# Pass the appropriate options here.
# Use `local_pipeline_options` for local runs.
# Use `dataflow_pipeline_options` for cloud runs.
pipeline_options=local_pipeline_options,
).run()
print(f"Conversion complete. ArrayRecord files written to '{output_path}'.")
if __name__ == '__main__':
main()
このアプローチはこのタスク専用に設計されたテスト済みの高レベル API であり、出力シャードと入力シャードのマッチングといった詳細を自動的に処理するため、手動でパイプラインを作成するよりも強力で堅牢です。
データを ArrayRecord 形式に変換したら、Grain の Dataset
API を使用して高性能な入力パイプラインを定義できます。このプロセスでは、まずソースを作成し、次に変換メソッドを連結します。
まず、ArrayRecord ファイルを指定して MapDataset
を作成します。
import grain
# Path to your generated ArrayRecord files
file_paths = ["~/tensorflow_datasets/cifar10/3.0.2/cifar10-train.array_record-00000-of-00001"]
# Create a data source
data_source = grain.sources.ArrayRecordDataSource(file_paths)
# Create a MapDataset from the source
dataset = grain.MapDataset.source(data_source)
次に、変換を MapDataset
に適用します。各メソッドは新しい MapDataset
を返し、宣言的に互いの呼び出しを連鎖できます。
# Example parsing function
def parse_and_transform(record):
# Your logic to parse features, augment data, etc.
return {"record": record}
BATCH_SIZE = 32
# Chain transformations
# The order of operations matters.
dataset = (
dataset.shuffle(seed=42)
.map(parse_and_transform)
.batch(batch_size=BATCH_SIZE, drop_remainder=True)
)
DatasetIterator
最後に、完全に定義されたデータセットからイテレータを作成して、トレーニング スクリプト内のデータをループします。
# Create the stateful iterator
data_iterator = iter(dataset)
# You can now loop over the data
for batch in data_iterator:
# Your training step with the batch...
pass
# For checkpoint saving/restoration, you can get/set the iterator's state
# state = data_iterator.get_state()
# data_iterator.set_state(state)
データ パイプラインがボトルネックになるのを防ぐには、マルチプロセスを使用し、モデル トレーニングと並行してデータの読み込みと前処理を行うようにする必要があります。Dataset API では、.mp_prefetch()
変換をパイプラインに追加することでこれを実現します。
このメソッドでは、ワーカー プロセスのプールを起動し、バックグラウンドで非同期的にデータバッチを準備してバッファに格納するため、トレーニング ループに必要な際に即座に利用できるようになります。
適用方法は次のとおりです。
# The full pipeline with performance tuning.
dataset = (
grain.MapDataset.source(data_source)
.shuffle(seed=42)
.map(parse_and_transform)
# Convert to an iterable dataset to apply prefetching.
.to_iter_dataset()
.batch(batch_size=BATCH_SIZE, drop_remainder=True)
# Apply multiprocessing and prefetching.
.mp_prefetch(
grain.multiprocessing.MultiprocessingOptions(
num_workers=16 # Number of parallel worker processes.
)
)
)
# Create the final iterator
data_iterator = iter(dataset)
num_workers
: データの読み込みに使用する並列の子プロセス数を指定します。アクセラレータがデータ待ちで頻繁にアイドル状態になる場合は、この値を増やすことでスループットが大幅に改善する可能性があります。最適な数は、マシンで利用可能な CPU コア数と map 関数の複雑さによって異なります。さらに詳しく学び、ビルドを開始したい場合は、このガイドで説明しているテクノロジーの公式ドキュメントとソースコードをご覧ください。
Grain と ArrayRecord を使用してビルドされた高性能かつ決定論的なデータ パイプラインは、大規模モデル トレーニングに不可欠です。その好例が、JAX で書かれた高性能なオープンソースの大規模言語モデルである MaxText です。MaxText はまさにこれらのデータ パイプライン手法を活用して、大規模な TPU および GPU クラスタにデータを効率的に供給します。