GPU와 TPU 같은 고성능 가속기에서 대형 모델을 학습시킬 때 가장 달갑지 않은 상황은 가속기가 데이터를 기다리느라 유휴 상태에 들어가는 것입니다. 시스템에서 가장 느린 구성 요소가 전체 시스템의 속도를 좌우하기 마련인데, 이런 병목 현상을 데이터 입력 파이프라인이 일으키는 경우가 많습니다. 대규모 머신러닝에 효율적이고 재현 가능한 데이터 파이프라인이 꼭 필요한 것도 바로 그 때문입니다. 이 가이드는 유연한 JAX용 데이터 로딩 라이브러리인 Grain과 고효율 파일 형식인 ArrayRecord를 사용하여 강력한 고성능 데이터 파이프라인을 구축함으로써 이 문제를 해결하는 방법을 소개합니다.
Grain은 JAX 기반 워크로드에서 발생하는 병목 현상을 해결하기 위해 특별히 설계된 가벼운 오픈소스 데이터 로딩 라이브러리입니다. Grain을 사용하면 데이터를 효율적으로 모델에 로드하고, 전처리하고, 공급하여 사용자의 하드웨어 성능을 극대화할 수 있습니다.
성능, 재현 가능성, 유연성이라는 철학을 토대로 구축한 Grain은 다음과 같은 주요 이점을 제공합니다.
.mp_prefetch()
method) to run data loading and transformations in parallel, ensuring that a buffer of prepared data is always ready for your model. This keeps your accelerators saturated and minimizes training time..shuffle()
, .map()
, .batch()
등의 변환을 유동적으로 추가할 수 있습니다. 이러한 선언형 스타일 덕분에 데이터 파이프라인을 이해하고,·변경하고,·유지하기가 쉽습니다.TFRecord가 익숙한 표준이기는 하지만, 순차적인 특성상 완전한 글로벌 셔플이 불가능합니다. ArrayRecord는 이 문제를 해결하기 위해 특별히 설계된 파일 형식으로, 새로운 차원의 데이터 효율성을 제공합니다.
ArrayRecord가 발휘하는 뛰어난 성능의 비결은 Google의 Riegeli 파일 형식에 기반한 핵심 설계입니다. 이 구조는 대규모 데이터 처리에 여러 가지 핵심 이점을 제공합니다.
2. Massive parallelism: Records are grouped into data chunks. This structure is inherently designed to be read in parallel, allowing multiple processes to read different chunks of the same file simultaneously to dramatically increase read throughput.
3. Exceptional performance: As a result of this indexed and chunked design, benchmarks show ArrayRecord can achieve a read throughput an order of magnitude higher than traditional formats, making it ideal for today's massive datasets.
4. Smart data integrity: The format handles data integrity intelligently by leveraging the powerful error correction in underlying cloud storage systems (like Google Cloud Storage) rather than adding redundant checks. This provides robust protection against corruption without unnecessary performance overhead.
ArrayRecord의 여러 특징은 Grain 같은 최신 데이터 로더에 필요한 고급 기능을 직접적으로 구현합니다.
The most important benefit is achieving true, deterministic global shuffling. Because any record can be accessed instantly, a data loader can generate perfectly randomized indices in the dataset on the fly as the training happens and then fetch data in that specific order. This capability, which is computationally impractical with sequential formats like TFRecord, is vital for reproducible research and optimal model training.
주요 특징을 중심으로 ArrayRecord와 TFRecord를 자세히 비교해 보겠습니다.
2. 랜덤 액세스
3. 글로벌 셔플
4. 병렬 I/O
5. 통합
6. 주요 사용 사례
데이터 세트 변환 메서드는 해당 데이터 세트가 TensorFlow Dataset(TFDS) 카탈로그에 등록된 표준 데이터 세트인지 아니면 커스텀 독점 데이터 세트인지에 따라 달라집니다.
cifar10
또는 imagenet2012
처럼 잘 알려진 데이터 세트를 사용하고 있다면 tfds 명령줄 도구가 가장 직관적인 메서드입니다.
전제 조건: TensorFlow 데이터 세트 설치
pip install -q --upgrade tfds-nightly
tfds build 명령줄 도구 사용
이 명령은 소스 데이터를 다운로드하여 준비 로직을 실행하고 결과물을 사용자가 원하는 포맷으로 저장합니다.
# Generate the 'cifar10' dataset in ArrayRecord format
tfds build cifar10 --file_format=array_record
생성된 ArrayRecord 파일은 ~/tensorflow_datasets/
디렉터리에 저장되며 바로 사용 가능합니다.
자체 커스텀 TFRecord 데이터 세트를 대규모로 변환할 때 권장하는 접근법은 Apache Beam을 사용하는 것입니다. array_record
라이브러리는 사전에 패키지로 만든 Beam 파이프라인을 제공하므로 변환이 놀라울 정도로 간단하며 확장도 가능합니다. 방대한 데이터 세트에 이 메서드를 적극 권장하는 이유는 Google Cloud Dataflow 같은 서비스를 사용하는 여러 작업자에 처리를 분산시킬 수 있기 때문입니다.
전제 조건: Apache Beam 및 Array Record Beam SDK 설치
pip install -q apache-beam
pip install -q array-record-beam-sdk
사전에 패키지로 만든 변환 파이프라인 사용
array_record.beam.pipelines
모듈에 포함된 convert_tf_to_arrayrecord_disk_match_shards
함수는 특정한 목적으로 개발한 유틸리티로, 전체 변환 프로세스를 담당합니다. 이 유틸리티는 TFRecord 파일을 판독하여 분할된 해당 ArrayRecord 데이터 세트를 작성합니다.
이를 Python 스크립트에서 사용하는 방법은 다음과 같습니다.
from apache_beam.options import pipeline_options
from array_record.beam.pipelines import convert_tf_to_arrayrecord_disk_match_shards
# 1. Define your input and output patterns.
# This example uses Google Cloud Storage (GCS) paths, which is common for large datasets.
input_pattern = 'gs://your-gcs-bucket/path/to/records-*.tfrecord'
output_path = 'gs://your-gcs-bucket/path/to/converted-records'
# Arguments dictionary for the conversion function.
args = {
'input': input_pattern,
'output': output_path,
}
# 2. Configure pipeline options for execution.
# To run locally on your machine (for smaller datasets or testing):
# No options are needed; the local runner is used by default.
local_pipeline_options = pipeline_options.PipelineOptions()
# To run at scale on Google Cloud Dataflow (for large datasets):
# Uncomment the following lines and fill in your project details.
#
# dataflow_pipeline_options = pipeline_options.PipelineOptions(
# runner='DataflowRunner',
# project='your-gcp-project-id',
# region='your-gcp-region',
# # A requirements.txt file may be needed for dependencies on Dataflow workers.
# # requirements_file='requirements.txt',
# temp_location='gs://your-gcs-bucket/path/to/temp'
# )
# 3. Define and run the main execution logic.
def main():
print("Starting TFRecord to ArrayRecord conversion...")
convert_tf_to_arrayrecord_disk_match_shards(
args=args,
# Pass the appropriate options here.
# Use `local_pipeline_options` for local runs.
# Use `dataflow_pipeline_options` for cloud runs.
pipeline_options=local_pipeline_options,
).run()
print(f"Conversion complete. ArrayRecord files written to '{output_path}'.")
if __name__ == '__main__':
main()
이 접근법은 변환 작업을 위해 특별히 설계되어 테스트를 거친 고수준 API이며, 출력 분할을 자동으로 입력 분할과 매치하는 등의 세부 사항을 처리해 주므로 분할 파이프라인을 작성하는 것보다 더 효과적이고 확실합니다.
데이터가 ArrayRecord 포맷으로 변환되면 Grain Dataset
API를 사용해서 고성능 입력 파이프라인을 정의할 수 있습니다. 이 프로세스에는 소스를 생성하고 여러 가지 변환 메서드를 체인으로 연결하는 작업이 포함됩니다.
우선 ArrayRecord 파일을 지정하여 MapDataset
를 생성합니다.
import grain
# Path to your generated ArrayRecord files
file_paths = ["~/tensorflow_datasets/cifar10/3.0.2/cifar10-train.array_record-00000-of-00001"]
# Create a data source
data_source = grain.sources.ArrayRecordDataSource(file_paths)
# Create a MapDataset from the source
dataset = grain.MapDataset.source(data_source)
이제 MapDataset
에 변환을 적용합니다. 각 메서드가 새로운 MapDataset
를 반환하므로 사용자는 선언을 통해 호출을 체인으로 연결할 수 있습니다.
# Example parsing function
def parse_and_transform(record):
# Your logic to parse features, augment data, etc.
return {"record": record}
BATCH_SIZE = 32
# Chain transformations
# The order of operations matters.
dataset = (
dataset.shuffle(seed=42)
.map(parse_and_transform)
.batch(batch_size=BATCH_SIZE, drop_remainder=True)
)
DatasetIterator
마지막으로, 완전히 정의한 데이터세트에서 반복자를 생성하여 학습 스크립트의 데이터를 루프합니다.
# Create the stateful iterator
data_iterator = iter(dataset)
# You can now loop over the data
for batch in data_iterator:
# Your training step with the batch...
pass
# For checkpoint saving/restoration, you can get/set the iterator's state
# state = data_iterator.get_state()
# data_iterator.set_state(state)
데이터 파이프라인이 일으키는 병목 현상을 방지하려면 다중 처리를 사용해서 데이터 로드 및 전처리와 모델 학습을 동시에 처리해야 합니다. 이를 위해 Dataset API에서 파이프라인에 .mp_prefetch()
변환을 추가합니다.
이 메서드는 작업자 프로세스 풀을 시작해서 백그라운드에서 데이터 배치를 비동기적으로 준비하여 사용자의 학습 루프에 데이터 배치가 필요할 때 바로 사용할 수 있도록 버퍼로 저장합니다.
적용 방법은 다음과 같습니다.
# The full pipeline with performance tuning.
dataset = (
grain.MapDataset.source(data_source)
.shuffle(seed=42)
.map(parse_and_transform)
# Convert to an iterable dataset to apply prefetching.
.to_iter_dataset()
.batch(batch_size=BATCH_SIZE, drop_remainder=True)
# Apply multiprocessing and prefetching.
.mp_prefetch(
grain.multiprocessing.MultiprocessingOptions(
num_workers=16 # Number of parallel worker processes.
)
)
)
# Create the final iterator
data_iterator = iter(dataset)
num_workers
: 데이터 로드에 사용할 병렬 하위 프로세스 수를 지정합니다. 가속기가 데이터를 기다리느라 유휴 상태에 들어가는 일이 잦을 경우, 이 값을 높이면 처리량이 눈에 띄게 개선됩니다. 최적의 값은 머신에서 이용 가능한 CPU 코어와 매핑 기능의 복잡성에 따라 달라집니다.자세한 정보를 확인하고 구축을 시작하고 싶으신가요? 이 가이드에서 다룬 기술의 공식 문서와 소스 코드를 살펴보세요.
Grain과 ArrayRecord로 구축된 충분한 성능의 결정적 데이터 파이프라인은 대규모 모델 학습에 꼭 필요합니다. 주된 예시로는 JAX로 작성한 고성능 오픈 소스 대규모 언어 모델인 MaxText를 들 수 있습니다. MaxText는 정확히 이 데이터 파이프라인 기법을 활용해서 대형 TPU 및 GPU 클러스터에 효율적으로 데이터를 공급합니다.