Criação de pipelines de dados de alto desempenho com Grain e ArrayRecord

3 DE OUTUBRO DE 2025
Jiyang Kang Technical Program Manager
Shivaji Dutta Field Solutions Architect
Ihor Indyk Software Engineer
Felix Chern Software Engineer

Ao treinar modelos grandes em aceleradores poderosos, como GPUs e TPUs, a última coisa que você quer é que o acelerador fique inativo enquanto aguarda os dados. O sistema, como um todo, é tão rápido quanto sua parte mais lenta e, muitas vezes, esse gargalo é o pipeline de entrada de dados. Portanto, para o aprendizado de máquina em larga escala, um pipeline de dados eficiente e reproduzível é essencial. Este guia mostrará como resolver esse desafio criando um pipeline de dados robusto e de alto desempenho com o Grain, uma biblioteca de carregamento de dados flexível para o JAX, e o ArrayRecord, um formato de arquivo altamente eficiente.


Noções básicas sobre os componentes centrais

Grain: um carregador de dados de alto desempenho para o JAX

O Grain é uma biblioteca de carregamento de dados leve e de código aberto projetada especificamente para resolver esse problema para cargas de trabalho baseadas no JAX. Ele garante que os dados sejam carregados, pré-processados e alimentados no modelo de forma eficiente, permitindo maximizar o desempenho do hardware.

Por que usar o Grain?

O Grain foi criado com base em uma filosofia de desempenho, reprodutibilidade e flexibilidade. Veja a seguir os principais benefícios que ele oferece:

  • Exceptional performance: Grain is built for speed. It uses efficient multiprocessing (via the .mp_prefetch() method) to run data loading and transformations in parallel, ensuring that a buffer of prepared data is always ready for your model. This keeps your accelerators saturated and minimizes training time.

  • Guaranteed determinism and reproducibility: Grain provides full reproducibility, which is critical for credible research. By setting a simple seed, you ensure the data is always shuffled the same way. Crucially, its data iterators are stateful and can be checkpointed. This means if your training job is interrupted or preempted, you can restart from the exact same point in the data stream.

  • Uma API declarativa e intuitiva: você define o pipeline de dados encadeando métodos simples e legíveis. A partir de uma origem MapDataset, você pode adicionar, com fluidez, transformações como .shuffle (), .map() e .batch(). Esse estilo declarativo torna o pipeline de dados fácil de entender, modificar e manter.

  • Unlocking true global shuffling: To get the best performance from your models, you need to shuffle your data effectively. When paired with a file format that supports random access, like ArrayRecord, Grain can perform a true global shuffle across your entire dataset, even when it doesn’t fit into host memory. This is a powerful feature that is often computationally impractical with other data loaders and formats.


O que é o ArrayRecord e por que usá-lo?

Embora o TFRecord seja um padrão familiar, sua natureza sequencial não permite um verdadeiro embaralhamento global. O ArrayRecord é um formato de arquivo moderno, projetado especificamente para resolver esse problema ao oferecer uma nova fronteira em eficiência de dados.

ArrayRecord File Layout

Como funciona: projetado para velocidade e paralelismo

O alto desempenho do ArrayRecord está enraizado em seu design central, baseado no formato de arquivo Riegeli do Google. Essa estrutura oferece várias vantagens importantes para o tratamento de dados em larga escala:

  1. Efficient random access: ArrayRecord features a built-in metadata index that maps every record to its precise location. This is the key design choice that enables instant, direct access to any record in the dataset, completely avoiding the need to read a file from the beginning.


2. Massive parallelism: Records are grouped into data chunks. This structure is inherently designed to be read in parallel, allowing multiple processes to read different chunks of the same file simultaneously to dramatically increase read throughput.


3. Exceptional performance: As a result of this indexed and chunked design, benchmarks show ArrayRecord can achieve a read throughput an order of magnitude higher than traditional formats, making it ideal for today's massive datasets.


4. Smart data integrity: The format handles data integrity intelligently by leveraging the powerful error correction in underlying cloud storage systems (like Google Cloud Storage) rather than adding redundant checks. This provides robust protection against corruption without unnecessary performance overhead.


Por que estamos usando?

Os recursos do ArrayRecord ativam diretamente os recursos avançados exigidos pelos carregadores de dados modernos, como o Grain.

The most important benefit is achieving true, deterministic global shuffling. Because any record can be accessed instantly, a data loader can generate perfectly randomized indices in the dataset on the fly as the training happens and then fetch data in that specific order. This capability, which is computationally impractical with sequential formats like TFRecord, is vital for reproducible research and optimal model training.


ArrayRecord versus TFRecord: uma comparação detalhada

Segue uma análise detalhada de como o ArrayRecord e o TFRecord se comparam quanto aos principais recursos:

  1. Estrutura

  • O ArrayRecord foi criado com base no formato de arquivo Riegeli do Google, projetado para armazenar sequências de registros com foco em decodificação de alta velocidade, integridade de dados e compactação forte. Ele agrupa os registros em blocos e inclui um índice de metadados no final do arquivo.

  • O TFRecord é uma sequência de registros binários na qual cada registro é, tipicamente, um buffer de protocolo tf.train.Example.


2. Acesso aleatório

  • O ArrayRecord oferece acesso aleatório nativo e eficiente. Sua estrutura de arquivo inclui um índice interno de posições de registro, permitindo o acesso direto e rápido a qualquer registro por seu índice, sem a necessidade de ler todo o arquivo.

  • O TFRecord, por outro lado, não tem acesso aleatório nativo. Como um formato sequencial otimizado para o streaming de dados, o acesso a um registro específico requer a iteração pelo arquivo, desde o início.


3. Embaralhamento global

  • Com o ArrayRecord, o verdadeiro embaralhamento global se torna possível. Graças a seu acesso aleatório eficiente, um carregador de dados como o Grain pode gerar índices em ordem embaralhada e ler registros em tempo real.

  • Com o TFRecord, é difícil de alcançar o verdadeiro embaralhamento global. Seu embaralhamento "global" geralmente depende de aproximações, como ao embaralhar uma lista de nomes de arquivos fragmentados e, em seguida, embaralhar registros dentro de um pequeno buffer de memória. Esse não é um verdadeiro embaralhamento global.


4. E/S paralela

  • O ArrayRecord dá suporte nativo à E/S paralela. A estrutura interna em blocos de um arquivo ArrayRecord torna inerentemente fácil a leitura em paralelo de vários processos de diferentes partes do mesmo arquivo, o que simplifica o gerenciamento de dados.

  • O TFRecord dá suporte à leitura paralela, mas ela é normalmente alcançada pela fragmentação do conjunto de dados em muitos arquivos TFRecord pequenos e com diferentes workers lendo diferentes arquivos. Isso pode resultar em um grande número de arquivos para gerenciar.


5. Integração

  • O ArrayRecord foi projetado para E/S de alto desempenho e funciona perfeitamente com carregadores baseados em JAX, como o Grain. Ele também pode ser utilizado dentro do ecossistema do TensorFlow via tfds.data_source.

  • O TFRecord é altamente integrado ao ecossistema tf.data do TensorFlow.


6. Caso de uso principal

  • ArrayRecord is ideal for high-throughput data loading for performance-critical machine learning, especially where determinism and true global shuffling are required (e.g., JAX/TPU workloads).

  • TFRecord is suited for general-purpose, large-scale data storage for TensorFlow and is optimized for sequential reads.


Como converter conjuntos de dados TFRecord em ArrayRecord

O método para converter um conjunto de dados depende de ele ser o conjunto padrão registrado no catálogo TensorFlow Datasets (TFDS) ou um conjunto personalizado e proprietário.

Método 1: para conjuntos de dados padrão no catálogo TFDS

Se você estiver usando um conjunto de dados bem conhecido, como cifar10 ou imagenet2012, a ferramenta de linha de comando tfds será o método mais simples.

Pré-requisito: instalar os conjuntos de dados do TensorFlow

pip install -q --upgrade tfds-nightly
Shell

Com o uso da CLI tfds build

Esse comando faz o download dos dados de origem, executa a lógica de preparação e salva a saída no formato desejado.

# Generate the 'cifar10' dataset in ArrayRecord format
tfds build cifar10 --file_format=array_record
Shell

Os arquivos ArrayRecord gerados serão armazenados no diretório ~/tensorflow_datasets/, prontos para uso.

Método 2: para conjuntos de dados TFRecord personalizados ou proprietários

Para a conversão em larga escala de seus próprios conjuntos de dados TFRecord personalizados, a abordagem recomendada é usar o Apache Beam. A biblioteca array_record fornece pipelines do Beam pré-empacotados que tornam essa conversão incrivelmente simples e escalonável. Esse método é altamente recomendado para conjuntos de dados muito grandes, já que o processamento pode ser distribuído entre muitos workers por meio de um serviço como o Google Cloud Dataflow.

Pré-requisitos: instalar o Apache Beam e o SDK Array Record Beam

pip install -q apache-beam
pip install -q array-record-beam-sdk
Shell

Com o uso do pipeline de conversão pré-empacotado

O módulo array_record.beam.pipelines contém a função convert_tf_to_arrayrecord_disk_match_shards, um utilitário criado especificamente para lidar com todo o processo de conversão. Ele lê arquivos TFRecord e grava um conjunto de dados ArrayRecord fragmentado correspondente.

Veja como você o usaria em um script do Python:

from apache_beam.options import pipeline_options
from array_record.beam.pipelines import convert_tf_to_arrayrecord_disk_match_shards
 
# 1. Define your input and output patterns.
# This example uses Google Cloud Storage (GCS) paths, which is common for large datasets.
input_pattern = 'gs://your-gcs-bucket/path/to/records-*.tfrecord'
output_path = 'gs://your-gcs-bucket/path/to/converted-records'
 
# Arguments dictionary for the conversion function.
args = {
    'input': input_pattern,
    'output': output_path,
}
 
# 2. Configure pipeline options for execution.
 
# To run locally on your machine (for smaller datasets or testing):
# No options are needed; the local runner is used by default.
local_pipeline_options = pipeline_options.PipelineOptions()
 
 
# To run at scale on Google Cloud Dataflow (for large datasets):
# Uncomment the following lines and fill in your project details.
#
# dataflow_pipeline_options = pipeline_options.PipelineOptions(
#     runner='DataflowRunner',
#     project='your-gcp-project-id',
#     region='your-gcp-region',
#     # A requirements.txt file may be needed for dependencies on Dataflow workers.
#     # requirements_file='requirements.txt',
#     temp_location='gs://your-gcs-bucket/path/to/temp'
# )
 
 
# 3. Define and run the main execution logic.
def main():
  print("Starting TFRecord to ArrayRecord conversion...")
  convert_tf_to_arrayrecord_disk_match_shards(
      args=args,
      # Pass the appropriate options here.
      # Use `local_pipeline_options` for local runs.
      # Use `dataflow_pipeline_options` for cloud runs.
      pipeline_options=local_pipeline_options,
  ).run()
  print(f"Conversion complete. ArrayRecord files written to '{output_path}'.")
 
if __name__ == '__main__':
  main()
Python

Essa abordagem é mais poderosa e robusta do que escrever um pipeline manual porque se trata de uma API testada e de alto nível projetada especificamente para essa tarefa, lidando com detalhes como a correspondência automática entre fragmentos de saída e fragmentos de entrada.


Criação de um pipeline Grain e ArrayRecord: tutorial conceitual

Depois que seus dados estiverem no formato ArrayRecord, você poderá definir o pipeline de entrada de alto desempenho usando a API Dataset do Grain. O processo envolve a criação de uma origem e, em seguida, o encadeamento de métodos de transformação.

Etapa 1: crie um MapDataset a partir de uma origem de dados

Primeiro, aponte para os arquivos ArrayRecord para criar um MapDataset.

import grain 
 
# Path to your generated ArrayRecord files
file_paths = ["~/tensorflow_datasets/cifar10/3.0.2/cifar10-train.array_record-00000-of-00001"]
 
# Create a data source
data_source = grain.sources.ArrayRecordDataSource(file_paths)
 
# Create a MapDataset from the source
dataset = grain.MapDataset.source(data_source)
Python

Etapa 2: encadeie as transformações (Shuffle, Map, Batch)

Agora, aplique transformações ao MapDataset. Cada método retorna um novo MapDataset, permitindo encadear as chamadas de forma declarativa.

# Example parsing function
def parse_and_transform(record):
    # Your logic to parse features, augment data, etc.
    return {"record": record}
 
BATCH_SIZE = 32
 
# Chain transformations
# The order of operations matters.
dataset = (
    dataset.shuffle(seed=42)
           .map(parse_and_transform)
           .batch(batch_size=BATCH_SIZE, drop_remainder=True)
)
Python

Step 3: Create and use the DatasetIterator

Por fim, crie um iterador a partir do conjunto de dados totalmente definido para fazer a repetição pelos dados no script de treinamento.

# Create the stateful iterator
data_iterator = iter(dataset)
 
# You can now loop over the data
for batch in data_iterator:
    # Your training step with the batch...
    pass
 
    # For checkpoint saving/restoration, you can get/set the iterator's state
    # state = data_iterator.get_state()
    # data_iterator.set_state(state)
Python

Configurações de desempenho

Para evitar que o pipeline de dados se torne um gargalo, você deve usar o multiprocessamento para carregar e pré-processar os dados paralelamente ao treinamento do modelo. Na API Dataset, isso é feito pela adição da transformação .mp_prefetch() ao pipeline.

Esse método inicia um grupo de processos worker para preparar lotes de dados de forma assíncrona em segundo plano e armazená-los em um buffer, para que estejam prontos quando o ciclo de treinamento precisar deles.

Veja como aplicá-lo:

# The full pipeline with performance tuning.
dataset = (
    grain.MapDataset.source(data_source)
    .shuffle(seed=42)
    .map(parse_and_transform)
 
    # Convert to an iterable dataset to apply prefetching.
    .to_iter_dataset()
    .batch(batch_size=BATCH_SIZE, drop_remainder=True)
    # Apply multiprocessing and prefetching.
    .mp_prefetch(
        grain.multiprocessing.MultiprocessingOptions(
            num_workers=16   # Number of parallel worker processes.
        )
    )
)
 
# Create the final iterator
data_iterator = iter(dataset)
Python
  • num_workers: especifica o número de processos filhos paralelos a serem usados para o carregamento de dados. Se você perceber que o acelerador fica inativo muitas vezes, à espera de dados, o aumento desse valor poderá melhorar significativamente a capacidade de processamento. O número ideal depende dos núcleos de CPU disponíveis na máquina e da complexidade da função map.


Explore mais

Quer se aprofundar e começar a criar? Confira a documentação oficial e o código-fonte das tecnologias discutidas neste guia.

Tecnologias de base


Exemplo do mundo real: treinamento de LLM em larga escala

Os pipelines de dados determinísticos e de alto desempenho criados com o Grain e o ArrayRecord são cruciais para o treinamento de modelos em larga escala. Um excelente exemplo é o MaxText, um LLM de código aberto de alto desempenho escrito em JAX. O MaxText aproveita essas técnicas exatas de pipeline de dados para alimentar dados com eficiência em grandes clusters de TPU e GPU.