Criação de pipelines de dados de alto desempenho com Grain e ArrayRecord

7 DE OUTUBRO DE 2025
Jiyang Kang Technical Program Manager
Shivaji Dutta Field Solutions Architect
Ihor Indyk Software Engineer
Felix Chern Software Engineer

Ao treinar modelos grandes em aceleradores poderosos, como GPUs e TPUs, a última coisa que você quer é que o acelerador fique inativo enquanto aguarda os dados. O sistema, como um todo, é tão rápido quanto sua parte mais lenta e, muitas vezes, esse gargalo é o pipeline de entrada de dados. Portanto, para o aprendizado de máquina em larga escala, um pipeline de dados eficiente e reproduzível é essencial. Este guia mostrará como resolver esse desafio criando um pipeline de dados robusto e de alto desempenho com o Grain, uma biblioteca de carregamento de dados flexível para o JAX, e o ArrayRecord, um formato de arquivo altamente eficiente.


Noções básicas sobre os componentes centrais

Grain: um carregador de dados de alto desempenho para o JAX

O Grain é uma biblioteca de carregamento de dados leve e de código aberto projetada especificamente para resolver esse problema para cargas de trabalho baseadas no JAX. Ele garante que os dados sejam carregados, pré-processados e alimentados no modelo de forma eficiente, permitindo maximizar o desempenho do hardware.

Por que usar o Grain?

O Grain foi criado com base em uma filosofia de desempenho, reprodutibilidade e flexibilidade. Veja a seguir os principais benefícios que ele oferece:

  • Desempenho excepcional: o Grain foi criado para oferecer velocidade. Ele usa o multiprocessamento eficiente por meio do método .mp_prefetch() para executar o carregamento de dados e as transformações em paralelo, garantindo que um buffer de dados preparados esteja sempre pronto para o modelo. Isso mantém os aceleradores saturados e minimiza o tempo de treinamento.

  • Determinismo e reprodutibilidade garantidos: o Grain fornece reprodutibilidade total, o que é crucial para pesquisas confiáveis. Ao definir uma propagação simples, você garante que os dados sejam sempre embaralhados da mesma maneira. Fundamentalmente, ele tem iteradores de dados com estado e que podem ter pontos de verificação definidos. Isso significa que, se o job de treinamento for interrompido, de maneira forçada ou não, você poderá recomeçar exatamente no mesmo ponto no fluxo de dados.

  • Uma API declarativa e intuitiva: você define o pipeline de dados encadeando métodos simples e legíveis. A partir de uma origem MapDataset, você pode adicionar, com fluidez, transformações como .shuffle (), .map() e .batch(). Esse estilo declarativo torna o pipeline de dados fácil de entender, modificar e manter.

  • Habilitação do verdadeiro embaralhamento global: para obter o melhor desempenho dos modelos, você precisa embaralhar os dados de forma eficaz. Quando utilizado em conjunto com um formato de arquivo que dá suporte ao acesso aleatório, como ArrayRecord, o Grain pode fazer um verdadeiro embaralhamento global em todo o conjunto de dados, mesmo quando ele não cabe na memória do host. Esse é um recurso poderoso que, muitas vezes, é inviável, do ponto de vista da computação, com outros carregadores e formatos de dados.


O que é o ArrayRecord e por que usá-lo?

Embora o TFRecord seja um padrão familiar, sua natureza sequencial não permite um verdadeiro embaralhamento global. O ArrayRecord é um formato de arquivo moderno, projetado especificamente para resolver esse problema ao oferecer uma nova fronteira em eficiência de dados.

ArrayRecord File Layout

Como funciona: projetado para velocidade e paralelismo

O alto desempenho do ArrayRecord está enraizado em seu design central, baseado no formato de arquivo Riegeli do Google. Essa estrutura oferece várias vantagens importantes para o tratamento de dados em larga escala:

  1. Acesso aleatório eficiente: o ArrayRecord tem um índice de metadados integrado que mapeia cada registro para sua localização precisa. Essa é a principal escolha de design que permite o acesso instantâneo e direto a qualquer registro no conjunto de dados, evitando completamente a necessidade de ler um arquivo desde o início.


2. Paralelismo massivo: os registros são agrupados em blocos de dados. Essa estrutura é inerentemente projetada para ser lida em paralelo, permitindo que vários processos leiam diferentes blocos do mesmo arquivo ao mesmo tempo para aumentar drasticamente a capacidade de processamento da leitura.


3. Desempenho excepcional: como resultado desse design indexado e fragmentado, os comparativos de mercado mostram que o ArrayRecord pode alcançar uma capacidade de processamento de leitura com uma ordem de magnitude mais alta do que os formatos tradicionais, o que o torna ideal para os enormes conjuntos de dados de hoje.


4. Integridade de dados inteligente: o formato lida com a integridade de dados de forma inteligente, aproveitando a poderosa correção de erros de sistemas de armazenamento em nuvem subjacentes (como o Google Cloud Storage), em vez de adicionar verificações redundantes. Isso fornece proteção robusta contra corrupção sem overhead desnecessário sobre o desempenho.


Por que estamos usando?

Os recursos do ArrayRecord ativam diretamente os recursos avançados exigidos pelos carregadores de dados modernos, como o Grain.

O benefício mais importante é alcançar um embaralhamento global verdadeiro e determinista. Como qualquer registro pode ser acessado instantaneamente, um carregador de dados pode gerar índices perfeitamente aleatórios no conjunto de dados em tempo real durante o treinamento e, em seguida, buscar dados nessa ordem específica. Essa capacidade, que é inviável do ponto de vista computacional com os formatos sequenciais, como o TFRecord, é vital para a pesquisa reproduzível e o treinamento ideal de modelos.


ArrayRecord versus TFRecord: uma comparação detalhada

Segue uma análise detalhada de como o ArrayRecord e o TFRecord se comparam quanto aos principais recursos:

  1. Estrutura

  • O ArrayRecord foi criado com base no formato de arquivo Riegeli do Google, projetado para armazenar sequências de registros com foco em decodificação de alta velocidade, integridade de dados e compactação forte. Ele agrupa os registros em blocos e inclui um índice de metadados no final do arquivo.

  • O TFRecord é uma sequência de registros binários na qual cada registro é, tipicamente, um buffer de protocolo tf.train.Example.


2. Acesso aleatório

  • O ArrayRecord oferece acesso aleatório nativo e eficiente. Sua estrutura de arquivo inclui um índice interno de posições de registro, permitindo o acesso direto e rápido a qualquer registro por seu índice, sem a necessidade de ler todo o arquivo.

  • O TFRecord, por outro lado, não tem acesso aleatório nativo. Como um formato sequencial otimizado para o streaming de dados, o acesso a um registro específico requer a iteração pelo arquivo, desde o início.


3. Embaralhamento global

  • Com o ArrayRecord, o verdadeiro embaralhamento global se torna possível. Graças a seu acesso aleatório eficiente, um carregador de dados como o Grain pode gerar índices em ordem embaralhada e ler registros em tempo real.

  • Com o TFRecord, é difícil de alcançar o verdadeiro embaralhamento global. Seu embaralhamento "global" geralmente depende de aproximações, como ao embaralhar uma lista de nomes de arquivos fragmentados e, em seguida, embaralhar registros dentro de um pequeno buffer de memória. Esse não é um verdadeiro embaralhamento global.


4. E/S paralela

  • O ArrayRecord dá suporte nativo à E/S paralela. A estrutura interna em blocos de um arquivo ArrayRecord torna inerentemente fácil a leitura em paralelo de vários processos de diferentes partes do mesmo arquivo, o que simplifica o gerenciamento de dados.

  • O TFRecord dá suporte à leitura paralela, mas ela é normalmente alcançada pela fragmentação do conjunto de dados em muitos arquivos TFRecord pequenos e com diferentes workers lendo diferentes arquivos. Isso pode resultar em um grande número de arquivos para gerenciar.


5. Integração

  • O ArrayRecord foi projetado para E/S de alto desempenho e funciona perfeitamente com carregadores baseados em JAX, como o Grain. Ele também pode ser utilizado dentro do ecossistema do TensorFlow via tfds.data_source.

  • O TFRecord é altamente integrado ao ecossistema tf.data do TensorFlow.


6. Caso de uso principal

  • O ArrayRecord é ideal para o carregamento de dados com alta capacidade de processamento para o aprendizado de máquina no qual o desempenho é crucial, especialmente quando o determinismo e o verdadeiro embaralhamento global são necessários (por exemplo, em cargas de trabalho do JAX/de TPU).

  • O TFRecord é adequado para o armazenamento de dados de uso geral e em larga escala para o TensorFlow e é otimizado para leituras sequenciais.


Como converter conjuntos de dados TFRecord em ArrayRecord

O método para converter um conjunto de dados depende de ele ser o conjunto padrão registrado no catálogo TensorFlow Datasets (TFDS) ou um conjunto personalizado e proprietário.

Método 1: para conjuntos de dados padrão no catálogo TFDS

Se você estiver usando um conjunto de dados bem conhecido, como cifar10 ou imagenet2012, a ferramenta de linha de comando tfds será o método mais simples.

Pré-requisito: instalar os conjuntos de dados do TensorFlow

pip install -q --upgrade tfds-nightly
Shell

Com o uso da CLI tfds build

Esse comando faz o download dos dados de origem, executa a lógica de preparação e salva a saída no formato desejado.

# Generate the 'cifar10' dataset in ArrayRecord format
tfds build cifar10 --file_format=array_record
Shell

Os arquivos ArrayRecord gerados serão armazenados no diretório ~/tensorflow_datasets/, prontos para uso.

Método 2: para conjuntos de dados TFRecord personalizados ou proprietários

Para a conversão em larga escala de seus próprios conjuntos de dados TFRecord personalizados, a abordagem recomendada é usar o Apache Beam. A biblioteca array_record fornece pipelines do Beam pré-empacotados que tornam essa conversão incrivelmente simples e escalonável. Esse método é altamente recomendado para conjuntos de dados muito grandes, já que o processamento pode ser distribuído entre muitos workers por meio de um serviço como o Google Cloud Dataflow.

Pré-requisitos: instalar o Apache Beam e o SDK Array Record Beam

pip install -q apache-beam
pip install -q array-record-beam-sdk
Shell

Com o uso do pipeline de conversão pré-empacotado

O módulo array_record.beam.pipelines contém a função convert_tf_to_arrayrecord_disk_match_shards, um utilitário criado especificamente para lidar com todo o processo de conversão. Ele lê arquivos TFRecord e grava um conjunto de dados ArrayRecord fragmentado correspondente.

Veja como você o usaria em um script do Python:

from apache_beam.options import pipeline_options
from array_record.beam.pipelines import convert_tf_to_arrayrecord_disk_match_shards
 
# 1. Define your input and output patterns.
# This example uses Google Cloud Storage (GCS) paths, which is common for large datasets.
input_pattern = 'gs://your-gcs-bucket/path/to/records-*.tfrecord'
output_path = 'gs://your-gcs-bucket/path/to/converted-records'
 
# Arguments dictionary for the conversion function.
args = {
    'input': input_pattern,
    'output': output_path,
}
 
# 2. Configure pipeline options for execution.
 
# To run locally on your machine (for smaller datasets or testing):
# No options are needed; the local runner is used by default.
local_pipeline_options = pipeline_options.PipelineOptions()
 
 
# To run at scale on Google Cloud Dataflow (for large datasets):
# Uncomment the following lines and fill in your project details.
#
# dataflow_pipeline_options = pipeline_options.PipelineOptions(
#     runner='DataflowRunner',
#     project='your-gcp-project-id',
#     region='your-gcp-region',
#     # A requirements.txt file may be needed for dependencies on Dataflow workers.
#     # requirements_file='requirements.txt',
#     temp_location='gs://your-gcs-bucket/path/to/temp'
# )
 
 
# 3. Define and run the main execution logic.
def main():
  print("Starting TFRecord to ArrayRecord conversion...")
  convert_tf_to_arrayrecord_disk_match_shards(
      args=args,
      # Pass the appropriate options here.
      # Use `local_pipeline_options` for local runs.
      # Use `dataflow_pipeline_options` for cloud runs.
      pipeline_options=local_pipeline_options,
  ).run()
  print(f"Conversion complete. ArrayRecord files written to '{output_path}'.")
 
if __name__ == '__main__':
  main()
Python

Essa abordagem é mais poderosa e robusta do que escrever um pipeline manual porque se trata de uma API testada e de alto nível projetada especificamente para essa tarefa, lidando com detalhes como a correspondência automática entre fragmentos de saída e fragmentos de entrada.


Criação de um pipeline Grain e ArrayRecord: tutorial conceitual

Depois que seus dados estiverem no formato ArrayRecord, você poderá definir o pipeline de entrada de alto desempenho usando a API Dataset do Grain. O processo envolve a criação de uma origem e, em seguida, o encadeamento de métodos de transformação.

Etapa 1: crie um MapDataset a partir de uma origem de dados

Primeiro, aponte para os arquivos ArrayRecord para criar um MapDataset.

import grain 
 
# Path to your generated ArrayRecord files
file_paths = ["~/tensorflow_datasets/cifar10/3.0.2/cifar10-train.array_record-00000-of-00001"]
 
# Create a data source
data_source = grain.sources.ArrayRecordDataSource(file_paths)
 
# Create a MapDataset from the source
dataset = grain.MapDataset.source(data_source)
Python

Etapa 2: encadeie as transformações (Shuffle, Map, Batch)

Agora, aplique transformações ao MapDataset. Cada método retorna um novo MapDataset, permitindo encadear as chamadas de forma declarativa.

# Example parsing function
def parse_and_transform(record):
    # Your logic to parse features, augment data, etc.
    return {"record": record}
 
BATCH_SIZE = 32
 
# Chain transformations
# The order of operations matters.
dataset = (
    dataset.shuffle(seed=42)
           .map(parse_and_transform)
           .batch(batch_size=BATCH_SIZE, drop_remainder=True)
)
Python

Etapa 3: crie e use o DatasetIterator

Por fim, crie um iterador a partir do conjunto de dados totalmente definido para fazer a repetição pelos dados no script de treinamento.

# Create the stateful iterator
data_iterator = iter(dataset)
 
# You can now loop over the data
for batch in data_iterator:
    # Your training step with the batch...
    pass
 
    # For checkpoint saving/restoration, you can get/set the iterator's state
    # state = data_iterator.get_state()
    # data_iterator.set_state(state)
Python

Configurações de desempenho

Para evitar que o pipeline de dados se torne um gargalo, você deve usar o multiprocessamento para carregar e pré-processar os dados paralelamente ao treinamento do modelo. Na API Dataset, isso é feito pela adição da transformação .mp_prefetch() ao pipeline.

Esse método inicia um grupo de processos worker para preparar lotes de dados de forma assíncrona em segundo plano e armazená-los em um buffer, para que estejam prontos quando o ciclo de treinamento precisar deles.

Veja como aplicá-lo:

# The full pipeline with performance tuning.
dataset = (
    grain.MapDataset.source(data_source)
    .shuffle(seed=42)
    .map(parse_and_transform)
 
    # Convert to an iterable dataset to apply prefetching.
    .to_iter_dataset()
    .batch(batch_size=BATCH_SIZE, drop_remainder=True)
    # Apply multiprocessing and prefetching.
    .mp_prefetch(
        grain.multiprocessing.MultiprocessingOptions(
            num_workers=16   # Number of parallel worker processes.
        )
    )
)
 
# Create the final iterator
data_iterator = iter(dataset)
Python
  • num_workers: especifica o número de processos filhos paralelos a serem usados para o carregamento de dados. Se você perceber que o acelerador fica inativo muitas vezes, à espera de dados, o aumento desse valor poderá melhorar significativamente a capacidade de processamento. O número ideal depende dos núcleos de CPU disponíveis na máquina e da complexidade da função map.


Explore mais

Quer se aprofundar e começar a criar? Confira a documentação oficial e o código-fonte das tecnologias discutidas neste guia.

Tecnologias de base


Exemplo do mundo real: treinamento de LLM em larga escala

Os pipelines de dados determinísticos e de alto desempenho criados com o Grain e o ArrayRecord são cruciais para o treinamento de modelos em larga escala. Um excelente exemplo é o MaxText, um LLM de código aberto de alto desempenho escrito em JAX. O MaxText aproveita essas técnicas exatas de pipeline de dados para alimentar dados com eficiência em grandes clusters de TPU e GPU.