Jika Anda merasa penasaran dengan hal-hal yang dibutuhkan untuk membangun model bahasa dari awal dengan JAX, postingan ini cocok bagi Anda. Baru-baru ini saya mengadakan lokakarya tentang topik ini di Cloud Next 2025 dan mendapatkan beberapa masukan yang luar biasa, jadi saya memutuskan untuk menulis panduan bagi semua orang yang tidak bisa datang ke acara tersebut.
Dalam artikel dan contoh kode ini, Anda akan membuat dan melatih model GPT-2, menunjukkan bagaimana JAX mempermudah pemanfaatan kekuatan TPU Google. Anda bisa menjalankan seluruh project ini secara gratis dengan menggunakan TPU di Colab atau Kaggle, dan Anda dapat menemukan notebook lengkapnya di sini.
Ini adalah tutorial langsung, jadi saya akan mengasumsikan Anda sudah familier dengan konsep machine learning secara umum. Jika JAX adalah hal baru bagi Anda, panduan developer PyTorch untuk dasar-dasar JAX adalah tempat yang tepat untuk memulai.
Pertama, mari kita lihat sekilas alat-alat yang akan digunakan.
Sebelum kita mulai membangun model, mari kita bahas secara singkat tentang ekosistem JAX. Ekosistem JAX mengambil pendekatan modular, dengan JAX core yang menyediakan kemampuan pemrosesan numerik inti, dan koleksi library berlimpah yang dibangun di atasnya untuk melayani berbagai kebutuhan spesifik aplikasi, misalnya, Flax untuk membangun neural network, Orbax untuk checkpoint dan persistensi model, dan Optax untuk pengoptimalan (kami akan menggunakan ketiganya dalam artikel ini). Transformasi fungsi bawaan, seperti autograd, vektorisasi, dan kompilasi JIT, ditambah performa yang kuat dan API yang mudah digunakan, membuat JAX sempurna untuk pelatihan Model Bahasa Besar.
OpenAI sebelumnya telah merilis kode dan bobot model GPT2, yang merupakan referensi yang baik, dan ada banyak upaya komunitas, seperti nanoGPT, untuk mereplikasi model tersebut. Berikut adalah diagram arsitektur model tingkat tinggi untuk GPT2:
Kita akan menggunakan NNX (antarmuka Flax yang baru) untuk membangun model GPT2. Untuk mempersingkatnya, mari kita fokus pada blok transformer, yang merupakan kunci untuk model bahasa besar modern. Blok transformer menangkap dependensi jarak jauh dalam setiap urutan dan membangun pemahaman kontekstual yang kaya. Blok transformer GPT2 terdiri dari 2 lapisan LayerNorm, 1 lapisan Multi-Head Attention (MHA), 2 lapisan dropout, 2 lapisan proyeksi linier, dan 2 koneksi sisa. Jadi, pertama-tama kita menentukan lapisan ini dalam fungsi __init__
pada class TransformerBlock
:
class TransformerBlock(nnx.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
ff_dim: int,
dropout_rate: float,
rngs: nnx.Rngs,
):
self.layer_norm1 = nnx.LayerNorm(
epsilon=1e-6, num_features=embed_dim, rngs=rngs
)
self.mha = nnx.MultiHeadAttention(
num_heads=num_heads, in_features=embed_dim, rngs=rngs
)
self.dropout1 = nnx.Dropout(rate=dropout_rate)
self.layer_norm2 = nnx.LayerNorm(
epsilon=1e-6, num_features=embed_dim, rngs=rngs
)
self.linear1 = nnx.Linear(
in_features=embed_dim, out_features=ff_dim, rngs=rngs
)
self.linear2 = nnx.Linear(
in_features=ff_dim, out_features=embed_dim, rngs=rngs
)
self.dropout2 = nnx.Dropout(rate=dropout_rate)
Berikutnya, kita susun semua lapisan ini dalam fungsi __call__
:
class TransformerBlock(nnx.Module):
def __call__(self, inputs, training: bool = False):
input_shape = inputs.shape
bs, seq_len, emb_sz = input_shape
attention_output = self.mha(
inputs_q=self.layer_norm1(inputs),
mask=causal_attention_mask(seq_len),
decode=False,
)
x = inputs + self.dropout1(
attention_output, deterministic=not training
)
# MLP
mlp_output = self.linear1(self.layer_norm2(x))
mlp_output = nnx.gelu(mlp_output)
mlp_output = self.linear2(mlp_output)
mlp_output = self.dropout2(
mlp_output, deterministic=not training
)
return x + mlp_output
Kode ini akan terlihat sangat familier jika Anda pernah menggunakan framework ML lainnya, seperti PyTorch atau TensorFlow, untuk melatih model bahasa. Namun satu hal yang sangat saya sukai dari JAX adalah kemampuannya yang luar biasa untuk secara otomatis menjalankan kode secara paralel melalui SPMD (Single Program Multiple Data), yang diperlukan karena kita akan menjalankan kode pada beberapa akselerator (beberapa inti TPU). Mari kita lihat cara kerjanya.
Untuk melakukan SPMD, pertama-tama kita perlu memastikan bahwa kita menggunakan TPU. Pilih runtime TPU jika Anda menggunakan Colab atau Kaggle (Anda juga bisa menggunakan Cloud TPU VM).
import jax
jax.devices()
# Free-tier Colab offers TPU v2:
# [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
# TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
# TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
# TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
# TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
# TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
# TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
# TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
Colab dan Kaggle menawarkan TPU v2 atau v3, yang memiliki 8 inti TPU terpisah. Berikut adalah tampilan tray TPU v3:
Untuk melatih model GPT2 secara efisien, kita akan menjalankan semua inti TPU secara bersamaan melalui SPMD dan memanfaatkan paralelisme data di JAX. Untuk mencapai hal ini, kita menetapkan sebuah mesh hardware:
mesh = jax.make_mesh((8, 1), ('batch', 'model'))
Ibaratkan mesh sebagai matriks 2D akselerator. Dalam kasus ini, kita menetapkan 2 sumbu untuk mesh - sumbu batch
dan sumbu model
. Jadi secara total kita memiliki 8 kali 1, yang berarti 8 inti. Sumbu ini menentukan cara kita mempartisi data dan parameter model. Kita bisa mengubah sumbu ini nanti jika kita ingin bereksperimen dengan skema paralelisme lainnya.
Sekarang kita mengubah fungsi __init__
dengan memberi tahu JAX bahwa kita ingin mempartisi parameter model menggunakan sumbu ‘model’. Ini dilakukan dengan menambahkan nnx.with_partitioning
ketika menginisialisasi tensor bobot: untuk tensor bobot 1D seperti tensor skala/bias LayerNorm, kita secara langsung mempartisinya di sepanjang sumbu ‘model’; untuk tensor bobot 2D seperti tensor kernel MHA dan Linear, kita mempartisi dimensi ke-2 di sepanjang sumbu model
.
class TransformerBlock(nnx.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
ff_dim: int,
dropout_rate: float,
rngs: nnx.Rngs,
):
self.layer_norm1 = nnx.LayerNorm(
epsilon=1e-6, num_features=embed_dim,rngs=rngs, rngs=rngs,
scale_init=nnx.with_partitioning(
nnx.initializers.ones_init(),
("model"),
),
bias_init=nnx.with_partitioning(
nnx.initializers.zeros_init(),
("model"),
),
)
self.mha = nnx.MultiHeadAttention(
num_heads=num_heads, in_features=embed_dim,
kernel_init=nnx.with_partitioning(
nnx.initializers.xavier_uniform(),
(None, "model"),
),
bias_init=nnx.with_partitioning(
nnx.initializers.zeros_init(),
("model"),
),
)
# Other layers in the block are omitted for brevity
Kita perlu mempartisi lapisan lain seperti ini agar kita bisa mengaktifkan paralelisme tensor model untuk seluruh model GPT2. Meskipun kita tidak menggunakan paralelisme tensor model dalam tutorial ini, sebaiknya kita tetap mengimplementasikannya karena ukuran model dapat bertambah dan kita mungkin perlu mempartisi parameter model kita di masa depan. Dengan mengimplementasikan ini, kita bisa mengubah hanya satu baris kode dan langsung menjalankan model yang lebih besar. Sebagai contoh,
mesh = jax.make_mesh((4, 2), ('batch', 'model'))
Berikutnya, kita menetapkan fungsi loss_fn
dan train_step
yang serupa dengan blog sebelumnya. Fungsi train_step()
menghitung gradien dari fungsi cross-entropy loss dan mengupdate bobot melalui optimizer, dan fungsi ini akan dipanggil secara berulang untuk melatih model. Untuk mencapai performa terbaik, kami melakukan kompilasi JIT pada kedua fungsi tersebut menggunakan dekorator @nnx.jit
, karena keduanya membutuhkan komputasi intensif.
@nnx.jit
def loss_fn(model, batch):
logits = model(batch[0])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch[1]
).mean()
return loss, logits
@nnx.jit
def train_step(
model: nnx.Module,
optimizer: nnx.Optimizer,
metrics: nnx.MultiMetric,
batch,
):
grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(model, batch)
metrics.update(loss=loss, logits=logits, lables=batch[1])
optimizer.update(grads)
Untuk optimizer, kami menggunakan AdamW dari Optax dengan cosine decay schedule. Anda juga bisa bereksperimen dengan optimizer atau schedule lain di Optax.
schedule = optax.cosine_decay_schedule(
init_value=init_learning_rate, decay_steps=max_steps
)
optax_chain = optax.chain(
optax.adamw(learning_rate=schedule, weight_decay=weight_decay)
)
optimizer = nnx.Optimizer(model, optax_chain)
Terakhir, kita membuat loop pelatihan sederhana.
while True:
input_batch, target_batch = get_batch("train")
train_step(
model,
optimizer,
train_metrics,
jax.device_put(
(input_batch, target_batch),
NamedSharding(mesh, P("batch", None)),
),
)
step += 1
if step > max_steps:
break
Perhatikan cara kami mempartisi data input di sepanjang sumbu batch menggunakan fungsi jax.device_put
. Dalam kasus ini JAX akan mengaktifkan paralelisme data dan menyatukan semuanya dengan menyisipkan komunikasi kolektif (AllReduce) secara otomatis, dan melakukan overlap komputasi serta komunikasi sebanyak mungkin. Untuk diskusi yang lebih mendalam mengenai komputasi paralel, silakan lihat dokumentasi Pengantar pemrograman paralel JAX.
Pada titik ini, model seharusnya mulai dilatih dan kita bisa mengamati nilai loss dalam pelatihan jika Weights and Biases digunakan untuk melacak prosesnya. Berikut adalah uji coba untuk melatih model GPT2 124M:
Dibutuhkan ~7 jam di Kaggle TPU v3 (yang bisa kita gunakan selama 9 jam tanpa gangguan), tetapi jika kita menggunakan Trillium, waktu pelatihan turun menjadi ~1,5 jam (perhatikan bahwa Trillium memiliki 32G HBM (High Bandwidth Memory) per chip, sehingga kita dapat menggandakan ukuran batch dan memangkas langkah pelatihan menjadi separuhnya).
Nilai akhir loss kira-kira setara dengan nanoGPT, yang benar-benar saya nikmati dan pelajari saat menulis contoh kode ini.
Jika kita menggunakan Cloud TPU, kita juga bisa memantau penggunaan TPU melalui perintah ‘tpu-info’ (bagian dari paket Cloud TPU Monitoring Debugging) atau dasbor Weights and Biases. TPU kita akan bekerja sangat cepat!
Setelah model dilatih, kita bisa menyimpannya menggunakan Orbax:
checkpointer = orbax.PyTreeCheckpointer()
train_state = nnx.pure(nnx.state(model))
checkpointer.save(checkpoint_path, train_state)
Selesai. Itulah yang kita perlukan untuk melatih model GPT2. Anda bisa menemukan detail tambahan, seperti pemuatan data, hyperparameter, metrik, dalam notebook lengkap.
Tentu saja GPT2 adalah model yang kecil saat ini dan banyak lab terdepan yang melatih model dengan ratusan miliar parameter. Namun sekarang setelah Anda mempelajari cara membangun model bahasa kecil dengan JAX dan TPU, Anda siap untuk mendalami Cara menskalakan model Anda.
Selain itu, Anda bisa menggunakan MaxText untuk melatih LLM termutakhir yang telah dibuat sebelumnya atau belajar membangun model terbaru dari awal dengan menggunakan referensi contoh LLM JAX atau model Stanford Marin.
Saya tidak sabar untuk melihat model luar biasa yang Anda bangun dengan JAX dan TPU!