Ao treinar modelos grandes em aceleradores poderosos, como GPUs e TPUs, a última coisa que você quer é que o acelerador fique inativo enquanto aguarda os dados. O sistema, como um todo, é tão rápido quanto sua parte mais lenta e, muitas vezes, esse gargalo é o pipeline de entrada de dados. Portanto, para o aprendizado de máquina em larga escala, um pipeline de dados eficiente e reproduzível é essencial. Este guia mostrará como resolver esse desafio criando um pipeline de dados robusto e de alto desempenho com o Grain, uma biblioteca de carregamento de dados flexível para o JAX, e o ArrayRecord, um formato de arquivo altamente eficiente.
O Grain é uma biblioteca de carregamento de dados leve e de código aberto projetada especificamente para resolver esse problema para cargas de trabalho baseadas no JAX. Ele garante que os dados sejam carregados, pré-processados e alimentados no modelo de forma eficiente, permitindo maximizar o desempenho do hardware.
O Grain foi criado com base em uma filosofia de desempenho, reprodutibilidade e flexibilidade. Veja a seguir os principais benefícios que ele oferece:
.mp_prefetch()
method) to run data loading and transformations in parallel, ensuring that a buffer of prepared data is always ready for your model. This keeps your accelerators saturated and minimizes training time..shuffle ()
, .map()
e .batch()
. Esse estilo declarativo torna o pipeline de dados fácil de entender, modificar e manter.Embora o TFRecord seja um padrão familiar, sua natureza sequencial não permite um verdadeiro embaralhamento global. O ArrayRecord é um formato de arquivo moderno, projetado especificamente para resolver esse problema ao oferecer uma nova fronteira em eficiência de dados.
O alto desempenho do ArrayRecord está enraizado em seu design central, baseado no formato de arquivo Riegeli do Google. Essa estrutura oferece várias vantagens importantes para o tratamento de dados em larga escala:
2. Massive parallelism: Records are grouped into data chunks. This structure is inherently designed to be read in parallel, allowing multiple processes to read different chunks of the same file simultaneously to dramatically increase read throughput.
3. Exceptional performance: As a result of this indexed and chunked design, benchmarks show ArrayRecord can achieve a read throughput an order of magnitude higher than traditional formats, making it ideal for today's massive datasets.
4. Smart data integrity: The format handles data integrity intelligently by leveraging the powerful error correction in underlying cloud storage systems (like Google Cloud Storage) rather than adding redundant checks. This provides robust protection against corruption without unnecessary performance overhead.
Os recursos do ArrayRecord ativam diretamente os recursos avançados exigidos pelos carregadores de dados modernos, como o Grain.
The most important benefit is achieving true, deterministic global shuffling. Because any record can be accessed instantly, a data loader can generate perfectly randomized indices in the dataset on the fly as the training happens and then fetch data in that specific order. This capability, which is computationally impractical with sequential formats like TFRecord, is vital for reproducible research and optimal model training.
Segue uma análise detalhada de como o ArrayRecord e o TFRecord se comparam quanto aos principais recursos:
2. Acesso aleatório
3. Embaralhamento global
4. E/S paralela
5. Integração
6. Caso de uso principal
O método para converter um conjunto de dados depende de ele ser o conjunto padrão registrado no catálogo TensorFlow Datasets (TFDS) ou um conjunto personalizado e proprietário.
Se você estiver usando um conjunto de dados bem conhecido, como cifar10
ou imagenet2012
, a ferramenta de linha de comando tfds será o método mais simples.
Pré-requisito: instalar os conjuntos de dados do TensorFlow
pip install -q --upgrade tfds-nightly
Com o uso da CLI tfds build
Esse comando faz o download dos dados de origem, executa a lógica de preparação e salva a saída no formato desejado.
# Generate the 'cifar10' dataset in ArrayRecord format
tfds build cifar10 --file_format=array_record
Os arquivos ArrayRecord gerados serão armazenados no diretório ~/tensorflow_datasets/
, prontos para uso.
Para a conversão em larga escala de seus próprios conjuntos de dados TFRecord personalizados, a abordagem recomendada é usar o Apache Beam. A biblioteca array_record
fornece pipelines do Beam pré-empacotados que tornam essa conversão incrivelmente simples e escalonável. Esse método é altamente recomendado para conjuntos de dados muito grandes, já que o processamento pode ser distribuído entre muitos workers por meio de um serviço como o Google Cloud Dataflow.
Pré-requisitos: instalar o Apache Beam e o SDK Array Record Beam
pip install -q apache-beam
pip install -q array-record-beam-sdk
Com o uso do pipeline de conversão pré-empacotado
O módulo array_record.beam.pipelines
contém a função convert_tf_to_arrayrecord_disk_match_shards
, um utilitário criado especificamente para lidar com todo o processo de conversão. Ele lê arquivos TFRecord e grava um conjunto de dados ArrayRecord fragmentado correspondente.
Veja como você o usaria em um script do Python:
from apache_beam.options import pipeline_options
from array_record.beam.pipelines import convert_tf_to_arrayrecord_disk_match_shards
# 1. Define your input and output patterns.
# This example uses Google Cloud Storage (GCS) paths, which is common for large datasets.
input_pattern = 'gs://your-gcs-bucket/path/to/records-*.tfrecord'
output_path = 'gs://your-gcs-bucket/path/to/converted-records'
# Arguments dictionary for the conversion function.
args = {
'input': input_pattern,
'output': output_path,
}
# 2. Configure pipeline options for execution.
# To run locally on your machine (for smaller datasets or testing):
# No options are needed; the local runner is used by default.
local_pipeline_options = pipeline_options.PipelineOptions()
# To run at scale on Google Cloud Dataflow (for large datasets):
# Uncomment the following lines and fill in your project details.
#
# dataflow_pipeline_options = pipeline_options.PipelineOptions(
# runner='DataflowRunner',
# project='your-gcp-project-id',
# region='your-gcp-region',
# # A requirements.txt file may be needed for dependencies on Dataflow workers.
# # requirements_file='requirements.txt',
# temp_location='gs://your-gcs-bucket/path/to/temp'
# )
# 3. Define and run the main execution logic.
def main():
print("Starting TFRecord to ArrayRecord conversion...")
convert_tf_to_arrayrecord_disk_match_shards(
args=args,
# Pass the appropriate options here.
# Use `local_pipeline_options` for local runs.
# Use `dataflow_pipeline_options` for cloud runs.
pipeline_options=local_pipeline_options,
).run()
print(f"Conversion complete. ArrayRecord files written to '{output_path}'.")
if __name__ == '__main__':
main()
Essa abordagem é mais poderosa e robusta do que escrever um pipeline manual porque se trata de uma API testada e de alto nível projetada especificamente para essa tarefa, lidando com detalhes como a correspondência automática entre fragmentos de saída e fragmentos de entrada.
Depois que seus dados estiverem no formato ArrayRecord, você poderá definir o pipeline de entrada de alto desempenho usando a API Dataset
do Grain. O processo envolve a criação de uma origem e, em seguida, o encadeamento de métodos de transformação.
Primeiro, aponte para os arquivos ArrayRecord para criar um MapDataset
.
import grain
# Path to your generated ArrayRecord files
file_paths = ["~/tensorflow_datasets/cifar10/3.0.2/cifar10-train.array_record-00000-of-00001"]
# Create a data source
data_source = grain.sources.ArrayRecordDataSource(file_paths)
# Create a MapDataset from the source
dataset = grain.MapDataset.source(data_source)
Agora, aplique transformações ao MapDataset
. Cada método retorna um novo MapDataset
, permitindo encadear as chamadas de forma declarativa.
# Example parsing function
def parse_and_transform(record):
# Your logic to parse features, augment data, etc.
return {"record": record}
BATCH_SIZE = 32
# Chain transformations
# The order of operations matters.
dataset = (
dataset.shuffle(seed=42)
.map(parse_and_transform)
.batch(batch_size=BATCH_SIZE, drop_remainder=True)
)
DatasetIterator
Por fim, crie um iterador a partir do conjunto de dados totalmente definido para fazer a repetição pelos dados no script de treinamento.
# Create the stateful iterator
data_iterator = iter(dataset)
# You can now loop over the data
for batch in data_iterator:
# Your training step with the batch...
pass
# For checkpoint saving/restoration, you can get/set the iterator's state
# state = data_iterator.get_state()
# data_iterator.set_state(state)
Para evitar que o pipeline de dados se torne um gargalo, você deve usar o multiprocessamento para carregar e pré-processar os dados paralelamente ao treinamento do modelo. Na API Dataset, isso é feito pela adição da transformação .mp_prefetch()
ao pipeline.
Esse método inicia um grupo de processos worker para preparar lotes de dados de forma assíncrona em segundo plano e armazená-los em um buffer, para que estejam prontos quando o ciclo de treinamento precisar deles.
Veja como aplicá-lo:
# The full pipeline with performance tuning.
dataset = (
grain.MapDataset.source(data_source)
.shuffle(seed=42)
.map(parse_and_transform)
# Convert to an iterable dataset to apply prefetching.
.to_iter_dataset()
.batch(batch_size=BATCH_SIZE, drop_remainder=True)
# Apply multiprocessing and prefetching.
.mp_prefetch(
grain.multiprocessing.MultiprocessingOptions(
num_workers=16 # Number of parallel worker processes.
)
)
)
# Create the final iterator
data_iterator = iter(dataset)
num_workers
: especifica o número de processos filhos paralelos a serem usados para o carregamento de dados. Se você perceber que o acelerador fica inativo muitas vezes, à espera de dados, o aumento desse valor poderá melhorar significativamente a capacidade de processamento. O número ideal depende dos núcleos de CPU disponíveis na máquina e da complexidade da função map.Quer se aprofundar e começar a criar? Confira a documentação oficial e o código-fonte das tecnologias discutidas neste guia.
Os pipelines de dados determinísticos e de alto desempenho criados com o Grain e o ArrayRecord são cruciais para o treinamento de modelos em larga escala. Um excelente exemplo é o MaxText, um LLM de código aberto de alto desempenho escrito em JAX. O MaxText aproveita essas técnicas exatas de pipeline de dados para alimentar dados com eficiência em grandes clusters de TPU e GPU.