Saat melatih model besar pada akselerator yang kuat, seperti GPU dan TPU, hal terakhir yang Anda inginkan adalah akselerator tidak ada aktivitas, menunggu data. Seluruh sistem Anda hanya akan berjalan secepat bagian yang paling lambat, dan sering kali, bottleneck tersebut adalah pipeline input data. Oleh karena itu, untuk machine learning berskala besar, data pipeline yang efisien dan dapat direproduksi sangatlah penting. Panduan ini akan menunjukkan kepada Anda cara mengatasi tantangan ini dengan membangun data pipeline yang kuat dan berkinerja baik menggunakan Grain, library pemuatan data yang fleksibel untuk JAX, dan ArrayRecord, sebuah format file yang sangat efisien.
Grain adalah library pemuatan data open source yang ringan dan dirancang khusus untuk mengatasi masalah ini untuk beban kerja berbasis JAX. Grain memastikan bahwa data dimuat, diproses sebelumnya, dan dimasukkan ke model Anda secara efisien, sehingga Anda bisa memaksimalkan performa hardware Anda.
Grain dibangun berdasarkan filosofi performa, reproduktifitas, dan fleksibilitas. Berikut adalah manfaat utama yang diberikannya:
.mp_prefetch()
method) to run data loading and transformations in parallel, ensuring that a buffer of prepared data is always ready for your model. This keeps your accelerators saturated and minimizes training time..shuffle()
, .map()
, dan .batch()
. Gaya deklaratif ini membuat data pipeline Anda mudah dipahami, dimodifikasi, dan dipelihara.Meskipun TFRecord adalah standar yang familier, sifat sekuensialnya tidak memungkinkan pengacakan global yang sebenarnya. ArrayRecord adalah format file modern yang dirancang khusus untuk mengatasi masalah ini, menawarkan terobosan baru dalam efisiensi data.
Performa tinggi ArrayRecord berakar pada desain intinya, yang berbasis format file Riegeli dari Google. Struktur ini memberikan beberapa keuntungan utama untuk penanganan data berskala besar:
2. Massive parallelism: Records are grouped into data chunks. This structure is inherently designed to be read in parallel, allowing multiple processes to read different chunks of the same file simultaneously to dramatically increase read throughput.
3. Exceptional performance: As a result of this indexed and chunked design, benchmarks show ArrayRecord can achieve a read throughput an order of magnitude higher than traditional formats, making it ideal for today's massive datasets.
4. Smart data integrity: The format handles data integrity intelligently by leveraging the powerful error correction in underlying cloud storage systems (like Google Cloud Storage) rather than adding redundant checks. This provides robust protection against corruption without unnecessary performance overhead.
Fitur ArrayRecord secara langsung mendukung kemampuan lanjutan yang dibutuhkan oleh loader data modern seperti Grain.
The most important benefit is achieving true, deterministic global shuffling. Because any record can be accessed instantly, a data loader can generate perfectly randomized indices in the dataset on the fly as the training happens and then fetch data in that specific order. This capability, which is computationally impractical with sequential formats like TFRecord, is vital for reproducible research and optimal model training.
Berikut adalah uraian detail tentang perbandingan ArrayRecord dan TFRecord pada berbagai fitur utama:
2. Akses Acak
3. Pengacakan Global
4. I/O Paralel
5. Integrasi
6. Kasus Penggunaan Utama
Metode konversi set data Anda bergantung pada apakah set data tersebut merupakan set data standar yang terdaftar di katalog TensorFlow Datasets (TFDS) atau set data khusus dan berpemilik.
Jika Anda menggunakan set data populer, seperti cifar10
atau imagenet2012
, alat command line tfds adalah metode yang paling mudah.
Prasyarat: Instal set data TensorFlow
pip install -q --upgrade tfds-nightly
Menggunakan CLI build tfds
Perintah ini mendownload data sumber, menjalankan logika persiapan, dan menyimpan output dalam format yang Anda inginkan.
# Generate the 'cifar10' dataset in ArrayRecord format
tfds build cifar10 --file_format=array_record
File ArrayRecord yang dihasilkan akan disimpan di direktori ~/tensorflow_datasets/
Anda, siap untuk digunakan.
Untuk konversi berskala besar set data TFRecord khusus milik Anda sendiri, pendekatan yang disarankan adalah menggunakan Apache Beam. Library array_record
menyediakan pipeline Beam yang sudah dipaketkan yang membuat konversi ini sangat sederhana dan skalabel. Metode ini sangat direkomendasikan untuk set data yang sangat besar, karena pemrosesan bisa didistribusikan ke banyak worker menggunakan layanan seperti Google Cloud Dataflow.
Prasyarat: Instal Apache Beam dan Array Record Beam SDK
pip install -q apache-beam
pip install -q array-record-beam-sdk
Menggunakan pipeline konversi yang sudah dipaketkan
Modul array_record.beam.pipelines
berisi fungsi convert_tf_to_arrayrecord_disk_match_shards
, aplikasi utilitas yang dibuat khusus untuk menangani seluruh proses konversi. Ia membaca file TFRecord dan menulis pecahan set data ArrayRecord yang sesuai.
Berikut cara menggunakan fungsi ini dalam skrip Python:
from apache_beam.options import pipeline_options
from array_record.beam.pipelines import convert_tf_to_arrayrecord_disk_match_shards
# 1. Define your input and output patterns.
# This example uses Google Cloud Storage (GCS) paths, which is common for large datasets.
input_pattern = 'gs://your-gcs-bucket/path/to/records-*.tfrecord'
output_path = 'gs://your-gcs-bucket/path/to/converted-records'
# Arguments dictionary for the conversion function.
args = {
'input': input_pattern,
'output': output_path,
}
# 2. Configure pipeline options for execution.
# To run locally on your machine (for smaller datasets or testing):
# No options are needed; the local runner is used by default.
local_pipeline_options = pipeline_options.PipelineOptions()
# To run at scale on Google Cloud Dataflow (for large datasets):
# Uncomment the following lines and fill in your project details.
#
# dataflow_pipeline_options = pipeline_options.PipelineOptions(
# runner='DataflowRunner',
# project='your-gcp-project-id',
# region='your-gcp-region',
# # A requirements.txt file may be needed for dependencies on Dataflow workers.
# # requirements_file='requirements.txt',
# temp_location='gs://your-gcs-bucket/path/to/temp'
# )
# 3. Define and run the main execution logic.
def main():
print("Starting TFRecord to ArrayRecord conversion...")
convert_tf_to_arrayrecord_disk_match_shards(
args=args,
# Pass the appropriate options here.
# Use `local_pipeline_options` for local runs.
# Use `dataflow_pipeline_options` for cloud runs.
pipeline_options=local_pipeline_options,
).run()
print(f"Conversion complete. ArrayRecord files written to '{output_path}'.")
if __name__ == '__main__':
main()
Pendekatan ini lebih kuat dan tangguh daripada menulis pipeline manual karena ini adalah API tingkat tinggi yang telah teruji serta dirancang khusus untuk tugas ini, menangani detail seperti mencocokkan pecahan output dengan pecahan input secara otomatis.
Setelah data Anda berada dalam format ArrayRecord, Anda bisa mendefinisikan pipeline input berkinerja tinggi menggunakan Grain Dataset
API. Prosesnya melibatkan pembuatan sumber dan kemudian metode transformasi berantai.
Pertama, tentukan file ArrayRecord Anda untuk membuat MapDataset
.
import grain
# Path to your generated ArrayRecord files
file_paths = ["~/tensorflow_datasets/cifar10/3.0.2/cifar10-train.array_record-00000-of-00001"]
# Create a data source
data_source = grain.sources.ArrayRecordDataSource(file_paths)
# Create a MapDataset from the source
dataset = grain.MapDataset.source(data_source)
Sekarang, terapkan transformasi ke MapDataset
. Setiap metode menghasilkan MapDataset
baru, yang memungkinkan Anda melakukan panggilan berantai bersama-sama secara deklaratif.
# Example parsing function
def parse_and_transform(record):
# Your logic to parse features, augment data, etc.
return {"record": record}
BATCH_SIZE = 32
# Chain transformations
# The order of operations matters.
dataset = (
dataset.shuffle(seed=42)
.map(parse_and_transform)
.batch(batch_size=BATCH_SIZE, drop_remainder=True)
)
DatasetIterator
Terakhir, buat iterator dari set data yang telah didefinisikan secara lengkap untuk mengulang data dalam skrip pelatihan Anda.
# Create the stateful iterator
data_iterator = iter(dataset)
# You can now loop over the data
for batch in data_iterator:
# Your training step with the batch...
pass
# For checkpoint saving/restoration, you can get/set the iterator's state
# state = data_iterator.get_state()
# data_iterator.set_state(state)
Untuk mencegah data pipeline mengalami bottleneck, Anda harus menggunakan multipemrosesan untuk memuat dan melakukan praproses data secara paralel dengan pelatihan model. Dalam Dataset API, ini dapat dicapai dengan menambahkan transformasi .mp_prefetch()
ke pipeline Anda.
Metode ini memulai kumpulan proses worker untuk menyiapkan kumpulan data secara asinkron di latar belakang dan menyimpannya dalam buffer, sehingga mereka siap pada saat loop pelatihan Anda membutuhkannya.
Berikut ini adalah cara menerapkannya:
# The full pipeline with performance tuning.
dataset = (
grain.MapDataset.source(data_source)
.shuffle(seed=42)
.map(parse_and_transform)
# Convert to an iterable dataset to apply prefetching.
.to_iter_dataset()
.batch(batch_size=BATCH_SIZE, drop_remainder=True)
# Apply multiprocessing and prefetching.
.mp_prefetch(
grain.multiprocessing.MultiprocessingOptions(
num_workers=16 # Number of parallel worker processes.
)
)
)
# Create the final iterator
data_iterator = iter(dataset)
num_workers
: Ini menentukan jumlah proses turunan paralel yang digunakan untuk memuat data. Jika Anda melihat akselerator sering tidak ada aktivitas karena menunggu data, meningkatkan nilai ini bisa meningkatkan performa secara signifikan. Jumlah optimal bergantung pada inti CPU yang tersedia pada mesin Anda dan kompleksitas fungsi peta Anda.Ingin mempelajari lebih dalam dan mulai membangun? Lihat dokumentasi resmi dan kode sumber untuk teknologi yang dibahas dalam panduan ini.
Data pipeline deterministik serta berkinerja yang dibangun dengan Grain dan ArrayRecord sangatlah penting untuk pelatihan model berskala besar. Contoh terbaiknya adalah MaxText, Model Bahasa Besar open source berkinerja tinggi yang ditulis dalam JAX. MaxText memanfaatkan teknik data pipeline ini untuk memasukkan data secara efisien ke cluster TPU dan GPU besar.