결정성 및 재현 가능성 보장: Grain은 신뢰받는 연구에 필수인 완벽한 재현 가능성을 갖추고 있습니다. 시드를 설정하기만 하면 데이터가 항상 동일한 방식으로 셔플됩니다. 중요한 점은 데이터 반복자의 상태가 저장되며 체크 포인트를 저장할 수 있다는 것입니다. 그래서 학습 작업이 중단되거나 선점되더라도 데이터 스트림의 정확히 같은 지점에서 작업을 다시 시작할 수 있습니다.

2025년 10월 3일
Jiyang Kang Technical Program Manager
Shivaji Dutta Field Solutions Architect
Ihor Indyk Software Engineer
Felix Chern Software Engineer

GPU와 TPU 같은 고성능 가속기에서 대형 모델을 학습시킬 때 가장 달갑지 않은 상황은 가속기가 데이터를 기다리느라 유휴 상태에 들어가는 것입니다. 시스템에서 가장 느린 구성 요소가 전체 시스템의 속도를 좌우하기 마련인데, 이런 병목 현상을 데이터 입력 파이프라인이 일으키는 경우가 많습니다. 대규모 머신러닝에 효율적이고 재현 가능한 데이터 파이프라인이 꼭 필요한 것도 바로 그 때문입니다. 이 가이드는 유연한 JAX용 데이터 로딩 라이브러리인 Grain과 고효율 파일 형식인 ArrayRecord를 사용하여 강력한 고성능 데이터 파이프라인을 구축함으로써 이 문제를 해결하는 방법을 소개합니다.


핵심 구성요소 소개

Grain: JAX용 고성능 데이터 로더

Grain은 JAX 기반 워크로드에서 발생하는 병목 현상을 해결하기 위해 특별히 설계된 가벼운 오픈소스 데이터 로딩 라이브러리입니다. Grain을 사용하면 데이터를 효율적으로 모델에 로드하고, 전처리하고, 공급하여 사용자의 하드웨어 성능을 극대화할 수 있습니다.

Grain의 사용 의의

성능, 재현 가능성, 유연성이라는 철학을 토대로 구축한 Grain은 다음과 같은 주요 이점을 제공합니다.

  • Exceptional performance: Grain is built for speed. It uses efficient multiprocessing (via the .mp_prefetch() method) to run data loading and transformations in parallel, ensuring that a buffer of prepared data is always ready for your model. This keeps your accelerators saturated and minimizes training time.

  • Guaranteed determinism and reproducibility: Grain provides full reproducibility, which is critical for credible research. By setting a simple seed, you ensure the data is always shuffled the same way. Crucially, its data iterators are stateful and can be checkpointed. This means if your training job is interrupted or preempted, you can restart from the exact same point in the data stream.

  • 직관적인 선언형 API: 사용자는 간단하고 읽기 쉬운 여러 가지 메서드를 서로 체인으로 연결하여 데이터 파이프라인을 정의합니다. MapDataset 소스부터 시작해서 .shuffle(), .map(), .batch() 등의 변환을 유동적으로 추가할 수 있습니다. 이러한 선언형 스타일 덕분에 데이터 파이프라인을 이해하고,·변경하고,·유지하기가 쉽습니다.

  • Unlocking true global shuffling: To get the best performance from your models, you need to shuffle your data effectively. When paired with a file format that supports random access, like ArrayRecord, Grain can perform a true global shuffle across your entire dataset, even when it doesn’t fit into host memory. This is a powerful feature that is often computationally impractical with other data loaders and formats.


ArrayRecord의 개념과 사용 의의

TFRecord가 익숙한 표준이기는 하지만, 순차적인 특성상 완전한 글로벌 셔플이 불가능합니다. ArrayRecord는 이 문제를 해결하기 위해 특별히 설계된 파일 형식으로, 새로운 차원의 데이터 효율성을 제공합니다.

ArrayRecord File Layout

속도와 병렬화에 초점을 맞춰 설계된 ArrayRecord의 작동 원리

ArrayRecord가 발휘하는 뛰어난 성능의 비결은 Google의 Riegeli 파일 형식에 기반한 핵심 설계입니다. 이 구조는 대규모 데이터 처리에 여러 가지 핵심 이점을 제공합니다.

  1. Efficient random access: ArrayRecord features a built-in metadata index that maps every record to its precise location. This is the key design choice that enables instant, direct access to any record in the dataset, completely avoiding the need to read a file from the beginning.


2. Massive parallelism: Records are grouped into data chunks. This structure is inherently designed to be read in parallel, allowing multiple processes to read different chunks of the same file simultaneously to dramatically increase read throughput.


3. Exceptional performance: As a result of this indexed and chunked design, benchmarks show ArrayRecord can achieve a read throughput an order of magnitude higher than traditional formats, making it ideal for today's massive datasets.


4. Smart data integrity: The format handles data integrity intelligently by leveraging the powerful error correction in underlying cloud storage systems (like Google Cloud Storage) rather than adding redundant checks. This provides robust protection against corruption without unnecessary performance overhead.


우리가 ArrayRecord를 사용하는 이유

ArrayRecord의 여러 특징은 Grain 같은 최신 데이터 로더에 필요한 고급 기능을 직접적으로 구현합니다.

The most important benefit is achieving true, deterministic global shuffling. Because any record can be accessed instantly, a data loader can generate perfectly randomized indices in the dataset on the fly as the training happens and then fetch data in that specific order. This capability, which is computationally impractical with sequential formats like TFRecord, is vital for reproducible research and optimal model training.


ArrayRecord와 TFRecord의 상세 비교

주요 특징을 중심으로 ArrayRecord와 TFRecord를 자세히 비교해 보겠습니다.

  1. 구조

  • ArrayRecord고속 디코딩, 데이터 무결성, 고효율 압축에 초점을 두고 레코드 시퀀스를 저장하기 위해 설계된 Google Riegeli 파일 형식을 기반으로 구축되었습니다. 레코드를 청크로 그룹화하고 파일 끝부분에 메타데이터 색인을 포함합니다.

  • TFRecord바이너리 레코드 시퀀스로, 각 레코드가 보통 tf.train.Example 프로토콜 버퍼입니다.


2. 랜덤 액세스

  • ArrayRecord효율적인 네이티브 랜덤 액세스를 제공합니다. 파일 구조에 레코드 위치의 자체 인덱스가 포함되어 있어 전체 파일을 판독하지 않고도 색인을 사용해 어떤 레코드에든 빠르게 직접 액세스할 수 있습니다.

  • 반면에, TFRecord에는 네이티브 랜덤 액세스 기능이 없습니다. 데이터 스트리밍에 최적화된 순차적인 포맷이기 때문에 특정 레코드에 액세스하려면 시작 부분부터 파일을 반복해야 합니다.


3. 글로벌 셔플

  • ArrayRecord를 사용하면 완전한 글로벌 셔플이 가능합니다. 효율적인 랜덤 액세스 덕분에 Grain 같은 데이터 로더가 셔플된 순서로 인덱스를 생성하고 실시간으로 레코드를 판독할 수 있습니다.

  • TFRecord의 경우 완전한 글로벌 셔플을 구현하기가 어렵습니다. 분할된 파일 이름 목록을 셔플한 후 소규모 메모리 버퍼 내에서 레코드를 셔플하는 식으로 '글로벌' 셔플을 근사치에 의존하는 경우가 많은데, 그것은 완전한 글로벌 셔플이 아닙니다.


4. 병렬 I/O

  • ArrayRecord는 병렬 I/O를 기본으로 지원합니다. ArrayRecord 파일의 청크로 나눠진 내부 구조 덕분에 복수의 프로세스가 병렬로 동일한 파일의 서로 다른 부분을 판독하는 것이 근본적으로 쉬운데, 이는 데이터 관리를 간소화합니다.

  • TFRecord는 병렬 판독을 지원하지만 주로 데이터 세트를 여러 개의 작은 TFRecord 파일로 분할하고 서로 다른 작업자가 서로 다른 파일을 판독하게 하는 방식을 사용합니다. 그래서 관리해야 하는 파일의 수가 많아질 수 있습니다.


5. 통합

  • 고성능 I/O를 위해 설계한 ArrayRecordGrain 같은 JAX 기반 로더와 함께 사용 시 매끄럽게 작동합니다. tfds.data_source를 통해 TensorFlow 생태계에서도 사용할 수 있습니다.

  • TFRecordTensorFlow's tf.데이터 생태계에 밀접하게 통합되어 있습니다.


6. 주요 사용 사례

  • ArrayRecord is ideal for high-throughput data loading for performance-critical machine learning, especially where determinism and true global shuffling are required (e.g., JAX/TPU workloads).

  • TFRecord is suited for general-purpose, large-scale data storage for TensorFlow and is optimized for sequential reads.


TFRecord 데이터 세트를 ArrayRecord로 변환하는 방법

데이터 세트 변환 메서드는 해당 데이터 세트가 TensorFlow Dataset(TFDS) 카탈로그에 등록된 표준 데이터 세트인지 아니면 커스텀 독점 데이터 세트인지에 따라 달라집니다.

메서드 1: TFDS 카탈로그에 등록된 표준 데이터 세트인 경우

cifar10 또는 imagenet2012처럼 잘 알려진 데이터 세트를 사용하고 있다면 tfds 명령줄 도구가 가장 직관적인 메서드입니다.

전제 조건: TensorFlow 데이터 세트 설치

pip install -q --upgrade tfds-nightly
Shell

tfds build 명령줄 도구 사용

이 명령은 소스 데이터를 다운로드하여 준비 로직을 실행하고 결과물을 사용자가 원하는 포맷으로 저장합니다.

# Generate the 'cifar10' dataset in ArrayRecord format
tfds build cifar10 --file_format=array_record
Shell

생성된 ArrayRecord 파일은 ~/tensorflow_datasets/ 디렉터리에 저장되며 바로 사용 가능합니다.

메서드 2: 커스텀 또는 독점 TFRecord 데이터 세트인 경우

자체 커스텀 TFRecord 데이터 세트를 대규모로 변환할 때 권장하는 접근법은 Apache Beam을 사용하는 것입니다. array_record 라이브러리는 사전에 패키지로 만든 Beam 파이프라인을 제공하므로 변환이 놀라울 정도로 간단하며 확장도 가능합니다. 방대한 데이터 세트에 이 메서드를 적극 권장하는 이유는 Google Cloud Dataflow 같은 서비스를 사용하는 여러 작업자에 처리를 분산시킬 수 있기 때문입니다.

전제 조건: Apache Beam 및 Array Record Beam SDK 설치

pip install -q apache-beam
pip install -q array-record-beam-sdk
Shell

사전에 패키지로 만든 변환 파이프라인 사용

array_record.beam.pipelines 모듈에 포함된 convert_tf_to_arrayrecord_disk_match_shards 함수는 특정한 목적으로 개발한 유틸리티로, 전체 변환 프로세스를 담당합니다. 이 유틸리티는 TFRecord 파일을 판독하여 분할된 해당 ArrayRecord 데이터 세트를 작성합니다.

이를 Python 스크립트에서 사용하는 방법은 다음과 같습니다.

from apache_beam.options import pipeline_options
from array_record.beam.pipelines import convert_tf_to_arrayrecord_disk_match_shards
 
# 1. Define your input and output patterns.
# This example uses Google Cloud Storage (GCS) paths, which is common for large datasets.
input_pattern = 'gs://your-gcs-bucket/path/to/records-*.tfrecord'
output_path = 'gs://your-gcs-bucket/path/to/converted-records'
 
# Arguments dictionary for the conversion function.
args = {
    'input': input_pattern,
    'output': output_path,
}
 
# 2. Configure pipeline options for execution.
 
# To run locally on your machine (for smaller datasets or testing):
# No options are needed; the local runner is used by default.
local_pipeline_options = pipeline_options.PipelineOptions()
 
 
# To run at scale on Google Cloud Dataflow (for large datasets):
# Uncomment the following lines and fill in your project details.
#
# dataflow_pipeline_options = pipeline_options.PipelineOptions(
#     runner='DataflowRunner',
#     project='your-gcp-project-id',
#     region='your-gcp-region',
#     # A requirements.txt file may be needed for dependencies on Dataflow workers.
#     # requirements_file='requirements.txt',
#     temp_location='gs://your-gcs-bucket/path/to/temp'
# )
 
 
# 3. Define and run the main execution logic.
def main():
  print("Starting TFRecord to ArrayRecord conversion...")
  convert_tf_to_arrayrecord_disk_match_shards(
      args=args,
      # Pass the appropriate options here.
      # Use `local_pipeline_options` for local runs.
      # Use `dataflow_pipeline_options` for cloud runs.
      pipeline_options=local_pipeline_options,
  ).run()
  print(f"Conversion complete. ArrayRecord files written to '{output_path}'.")
 
if __name__ == '__main__':
  main()
Python

이 접근법은 변환 작업을 위해 특별히 설계되어 테스트를 거친 고수준 API이며, 출력 분할을 자동으로 입력 분할과 매치하는 등의 세부 사항을 처리해 주므로 분할 파이프라인을 작성하는 것보다 더 효과적이고 확실합니다.


Grain 및 ArrayRecord 파이프라인 구축의 개념 둘러보기

데이터가 ArrayRecord 포맷으로 변환되면 Grain Dataset API를 사용해서 고성능 입력 파이프라인을 정의할 수 있습니다. 이 프로세스에는 소스를 생성하고 여러 가지 변환 메서드를 체인으로 연결하는 작업이 포함됩니다.

1단계: 데이터 소스에서 MapDataset 생성하기

우선 ArrayRecord 파일을 지정하여 MapDataset를 생성합니다.

import grain 
 
# Path to your generated ArrayRecord files
file_paths = ["~/tensorflow_datasets/cifar10/3.0.2/cifar10-train.array_record-00000-of-00001"]
 
# Create a data source
data_source = grain.sources.ArrayRecordDataSource(file_paths)
 
# Create a MapDataset from the source
dataset = grain.MapDataset.source(data_source)
Python

2단계: 변환을 체인으로 연결하기(셔플, 매핑, 배치)

이제 MapDataset에 변환을 적용합니다. 각 메서드가 새로운 MapDataset를 반환하므로 사용자는 선언을 통해 호출을 체인으로 연결할 수 있습니다.

# Example parsing function
def parse_and_transform(record):
    # Your logic to parse features, augment data, etc.
    return {"record": record}
 
BATCH_SIZE = 32
 
# Chain transformations
# The order of operations matters.
dataset = (
    dataset.shuffle(seed=42)
           .map(parse_and_transform)
           .batch(batch_size=BATCH_SIZE, drop_remainder=True)
)
Python

Step 3: Create and use the DatasetIterator

마지막으로, 완전히 정의한 데이터세트에서 반복자를 생성하여 학습 스크립트의 데이터를 루프합니다.

# Create the stateful iterator
data_iterator = iter(dataset)
 
# You can now loop over the data
for batch in data_iterator:
    # Your training step with the batch...
    pass
 
    # For checkpoint saving/restoration, you can get/set the iterator's state
    # state = data_iterator.get_state()
    # data_iterator.set_state(state)
Python

성능 구성 설정

데이터 파이프라인이 일으키는 병목 현상을 방지하려면 다중 처리를 사용해서 데이터 로드 및 전처리와 모델 학습을 동시에 처리해야 합니다. 이를 위해 Dataset API에서 파이프라인에 .mp_prefetch() 변환을 추가합니다.

이 메서드는 작업자 프로세스 풀을 시작해서 백그라운드에서 데이터 배치를 비동기적으로 준비하여 사용자의 학습 루프에 데이터 배치가 필요할 때 바로 사용할 수 있도록 버퍼로 저장합니다.

적용 방법은 다음과 같습니다.

# The full pipeline with performance tuning.
dataset = (
    grain.MapDataset.source(data_source)
    .shuffle(seed=42)
    .map(parse_and_transform)
 
    # Convert to an iterable dataset to apply prefetching.
    .to_iter_dataset()
    .batch(batch_size=BATCH_SIZE, drop_remainder=True)
    # Apply multiprocessing and prefetching.
    .mp_prefetch(
        grain.multiprocessing.MultiprocessingOptions(
            num_workers=16   # Number of parallel worker processes.
        )
    )
)
 
# Create the final iterator
data_iterator = iter(dataset)
Python
  • num_workers: 데이터 로드에 사용할 병렬 하위 프로세스 수를 지정합니다. 가속기가 데이터를 기다리느라 유휴 상태에 들어가는 일이 잦을 경우, 이 값을 높이면 처리량이 눈에 띄게 개선됩니다. 최적의 값은 머신에서 이용 가능한 CPU 코어와 매핑 기능의 복잡성에 따라 달라집니다.


추가 정보

자세한 정보를 확인하고 구축을 시작하고 싶으신가요? 이 가이드에서 다룬 기술의 공식 문서와 소스 코드를 살펴보세요.

기반 기술


실제 예시: 대규모 LLM 학습

Grain과 ArrayRecord로 구축된 충분한 성능의 결정적 데이터 파이프라인은 대규모 모델 학습에 꼭 필요합니다. 주된 예시로는 JAX로 작성한 고성능 오픈 소스 대규모 언어 모델인 MaxText를 들 수 있습니다. MaxText는 정확히 이 데이터 파이프라인 기법을 활용해서 대형 TPU 및 GPU 클러스터에 효율적으로 데이터를 공급합니다.