Model dasar Marin Stanford: Model terbuka penuh pertama yang dikembangkan menggunakan JAX

16 JULI 2025
Srikanth Kilaru Senior Product Manager Google ML Frameworks
David Hall Research Engineering Lead Stanford HAI

Elemen menarik dari era AI saat ini adalah bagaimana model-model fondasi yang canggih dibagikan secara terbuka dan membantu mempercepat inovasi bagi semua orang. Kemajuan ini menginspirasi kita untuk bertanya, "Apa langkah berikutnya untuk keterbukaan?" Proyek Marin melihat peluang untuk memperluas definisi "terbuka" agar mencakup seluruh proses ilmiah di balik sebuah model.

Proyek Marin CRFM (Center for Research on Foundation Models) Stanford dirancang sebagai "lab terbuka", yang sasarannya bukan hanya untuk membagikan model, melainkan juga untuk membuat seluruh perjalanannya dapat diakses—termasuk kode, set data, metodologi data, eksperimen, hiperparameter, dan log pelatihan. Level transparansi ini melengkapi ekosistem yang ada dengan menyediakan resource yang unik dan dapat direproduksi sepenuhnya yang memberdayakan para peneliti untuk meneliti, mengembangkan, dan memercayai model yang sedang dikembangkan. Proyek Marin Stanford berupaya untuk mendorong masa depan yang lebih transparan dan mudah diakses untuk penelitian model fondasi.


Spektrum keterbukaan model AI

The Spectrum of AI Model Openness

Rilis pertama dari lab terbuka ini adalah model Marin-8B-Base dan Marin-8B-Instruct. Sesuai dengan prinsip proyek, model, data, kode, dan tokenizer semuanya dirilis di bawah lisensi Apache 2.0 yang permisif. Komitmen terhadap reproduksibilitas penuh ini merupakan masalah engineering yang berat, yang membutuhkan kontrol atas setiap sumber varians dalam sistem yang terdistribusi secara masif. Keberhasilan proyek ini bergantung pada technology stack yang dapat memberikan jaminan reproduktifitas ini dalam skala besar, dan memaksimalkan efisiensi untuk melatih model dasar dengan harga/performa terdepan.


Tantangan inti dalam membangun model fondasi terbuka

Agar proyek Marin berhasil menciptakan model fondasi yang benar-benar terbuka, skalabel, dan dapat direproduksi, tim CRFM harus mengatasi beberapa tantangan engineering. Tim memilih JAX sebagai fondasi karena prinsip desainnya memberikan solusi langsung untuk masalah ini, dan membangun framework baru, Levanter (lihat di bawah), untuk mengasah kekuatan JAX. Berikut adalah beberapa contoh tantangan dan solusinya.


Mencapai kecepatan maksimum pada satu akselerator

Masalah: Loop pelatihan inti dieksekusi miliaran kali, sehingga overhead dari bahasa pemrograman tafsiran seperti Python menciptakan bottleneck performa yang sangat besar. Jika operasi dijalankan selangkah demi selangkah, loop tersebut juga dapat menimbulkan traffic memori dan overhead yang berlebihan—terutama pada hardware seperti TPU, yang throughput-nya bergantung pada eksekusi operasi fusi yang efisien.

Solusi kami:

  • Untuk mengeliminasi overhead interpreter, Levanter merangkum seluruh langkah pelatihan multitahap (forward pass, loss, backpropagation, dan update) ke dalam satu fungsi dan menggunakan dekorator @jax.jit. Compiler XLA JAX mentransformasi seluruh proses ini menjadi satu kernel kode mesin yang sangat optimal, menggabungkan operasi untuk memaksimalkan pemanfaatan hardware dalam skala besar.

  • Untuk menghindari komputasi yang redundan, kami menggunakan jax.value_and_grad untuk menghitung kerugian dan gradiennya dalam satu pass. JAX juga memudahkan penggunaan teknik lanjutan seperti gradient checkpointing, menghemat memori dan memungkinkan kami menggunakan ukuran batch yang lebih besar dengan hampir tanpa overhead.

  • Levanter juga menggunakan kernel Splash Attention berbasis Pallas yang canggih milik JAX, sebuah implementasi Dot Product Attention yang sangat optimal, salah satu operasi terpenting yang menjadi inti dari hampir semua model bahasa besar.


Mengelola kompleksitas paralelisme skala besar

Masalah: Pelatihan model mutakhir membutuhkan penskalaan hingga ribuan chip akselerator. Mengelola secara manual cara model dan data dipartisi dan cara perangkat berkomunikasi sangatlah rumit, dan kodenya dengan cepat menjadi sulit dibaca, di-debug, dan diadaptasi.

Solusi kami:

  • Dekorator @jax.jit JAX juga mendukung paralelisasi Single-Program, Multiple-Data (SPMD) secara mulus, yang mengotomatiskan sharding dan komunikasi data yang mendasarinya. Compiler XLA secara otomatis menjadwalkan komunikasi di antara akselerator untuk meminimalkan waktu tunggu di jaringan dan memaksimalkan waktu komputasi.

  • Untuk membuat kekuatan jit lebih mudah dan aman digunakan, Levanter mengembangkan Haliax, sebuah library untuk tensor yang diberi nama. Saat sumbu tensor dirujuk dengan nama yang mudah dibaca (seperti "embed" atau "batch"), sebagai ganti indeks posisi, kode menjadi terdokumentasi sendiri dan tangguh.

  • Abstraksi ini memungkinkan kita untuk mendefinisikan dan memodifikasi strategi sharding canggih seperti Fully Sharded Data Parallelism (FSDP) dan Tensor Parallelism hanya dengan mengubah beberapa baris dalam file konfigurasi, tanpa pernah menyentuh kode model.


Membangun dan mengelola cluster komputasi yang tangguh dan hemat biaya

Masalah: Pelatihan skala besar membutuhkan akses fleksibel ke cluster komputasi masif. Kami sangat bergantung pada instance TPU yang dapat di-preempt untuk mengelola biaya, yang berarti kami memerlukan cara untuk dengan mudah menggabungkan banyak slice TPU yang lebih kecil dan berbeda menjadi satu cluster logis dan tangguh terhadap gangguan yang sering terjadi.

Solusi kami:

  • Kami memanfaatkan Google Cloud TPU Multislice, sebuah teknologi yang memungkinkan tugas pelatihan menggunakan beberapa slice TPU seolah-olah merupakan satu sistem besar. Hal ini memudahkan penggabungan beberapa slice TPU kecil yang dapat di-preempt menjadi satu cluster komputasi tunggal yang andal untuk pelatihan.

  • Levanter menggunakan Ray untuk mengatur proses ini, dengan mulus meningkatkan atau menurunkan jumlah slice TPU selama tugas pelatihan dan, yang terpenting, memastikan tugas tetap tangguh jika ada slice tunggal yang di-preempt.

  • Berkat JAX dan XLA, Levanter dan Marin juga mampu memperoleh hasil performa tinggi serupa pada GPU.


Membangun kepercayaan ilmiah dengan reproduksibilitas sempurna

Masalah: Sasaran inti proyek Marin adalah memungkinkan sains yang dapat diverifikasi. Hal ini membutuhkan pencapaian hasil yang dapat direproduksi, bahkan ketika pelatihan dihentikan sementara, dimulai ulang, atau dipindahkan antar konfigurasi mesin yang berbeda—sebuah tantangan teknis yang signifikan.

Solusi kami:

  • Ini merupakan persyaratan mendasar yang mendorong desain Levanter. Kami memilih JAX khususnya karena jaminan reproduksibilitas yang kuat, seperti penggunaan default generator bilangan pseudo-acak (PRNG) deterministik.

  • Pilihan ini divalidasi selama pelatihan Marin-8B, yang melibatkan migrasi antara berbagai slice TPU dan jenis hardware sambil berhasil mempertahankan reproduksibilitas bit demi bit di seluruh preemption.

  • Levanter juga mencakup sistem pemuatan data yang tangguh yang dibangun di atas library Tensorstore Google. Penyimpanan data Levanter menawarkan akses acak dan deterministik ke setiap batch data pelatihan, terlepas dari pemulaian ulang tugas atau perubahan sumber data—penting untuk mendukung strategi pelatihan lanjutan seperti di tengah pelatihan. Determinisme JAX dan penyimpanan data Levanter juga memudahkan peneliti interpretabilitas untuk memahami bagaimana data tertentu memengaruhi model selama pelatihan.


Membuat framework yang kohesif

Masalah: Meskipun JAX menyediakan mesin yang canggih, belum ada framework level tinggi yang memenuhi persyaratan gabungan kami yang ketat untuk keterbacaan, skalabilitas masif, dan determinisme bitwise. Kami membutuhkan sistem yang lengkap dan beropini untuk mengatur seluruh proses pelatihan.

Solusi kami:

  • Kami membangun Levanter, framework native JAX, dari awal hingga menjadi sistem yang kami butuhkan: deterministik bitwise, skalabel dengan strategi distribusi tingkat lanjut, dan tangguh.

  • Kita bisa melakukan ini karena JAX lebih dari sekadar library; ia merupakan "meta-framework" untuk membangun alat-alat baru. Kami membangunnya berdasarkan dukungannya yang matang dan berperforma tinggi untuk TPU serta integrasinya yang mulus antara abstraksi level tinggi (jit) dengan kontrol level rendah (Pallas).

  • Pendekatan ini umum dalam komunitas JAX, yang telah menghasilkan ekosistem library yang dinamis, seperti Flax, Equinox, Orbax, dan Optax yang bekerja bersama, memungkinkan tim seperti kami untuk membangun solusi yang hebat.


Melihat dari balik layar: Pelayaran Marin-8B

Prinsip, alat, dan library yang dibahas di atas diterapkan dan digunakan selama pelatihan Marin-8B. Arsitektur modelnya adalah transformator bergaya Llama.


Sekilas pandang Marin-8B-Base: Arsitektur model

Marin 8B-Base model architecture at a glance

Sebagai ganti proses statis dan monolitik, pelatihan Marin-8B merupakan perjalanan adaptif, yang secara internal dijuluki proses "Tootsie". Penggambaran jujur tentang alur kerja penelitian di dunia nyata ini dirinci secara publik. Proses ini mencakup lebih dari 12 triliun token dan melibatkan beberapa fase yang beradaptasi dengan data, teknik, dan bahkan konfigurasi hardware baru yang berbeda—bermigrasi di antara konfigurasi TPU multi-slice skala besar (pod 2x v5e-256 sampai 1x v4-2048) di tengah prosesnya. Tim terus menyaring campuran data, menggabungkan sumber berkualitas lebih tinggi, dan menyesuaikan hiperparameter, seperti laju pembelajaran dan ukuran batch untuk mengoptimalkan performa. Realitas "berantakan" ini merupakan alat pendidikan yang ampuh, dan kemampuan stack JAX dan Levanter untuk menangani pergeseran signifikan ini sambil mempertahankan reproduksibilitas bit-per-bit merupakan demonstrasi yang kuat akan ketangguhannya.


Bergabunglah dengan komunitas Marin

Proyek Marin merupakan undangan terbuka untuk berpartisipasi dalam pengembangan model fondasi di masa depan dan berkontribusi pada ekosistem JAX. Perjalanan Marin merupakan jawaban atas pertanyaan kita, "Apa langkah berikutnya untuk keterbukaan?" Upaya untuk menciptakan 'lab terbuka' ini dimungkinkan oleh kapabilitas teknis ekosistem JAX. Performa, portabilitas, dan desain fondasinya untuk reproduksibilitas merupakan unsur-unsur kunci yang memungkinkan kita untuk membuat 'perjalanan lengkap' penelitian menjadi lebih mudah diakses.

Dengan berbagi segalanya, mulai dari metodologi data hingga log pelatihan, kami bertujuan untuk menyediakan resource yang sepenuhnya dapat direproduksi—resource yang memungkinkan para peneliti untuk meneliti, mengembangkan, dan memercayai penelitian secara mendalam. Kami yakin ini merupakan langkah kolaboratif menuju masa depan AI yang lebih transparan. Kami mengundang Anda untuk bergabung dengan kami di 'lab terbuka' ini—untuk menggunakan Marin, berkontribusi pada penelitian, dan membantu membangun gelombang model fondasi yang inovatif dan tepercaya berikutnya.

Resource utama untuk proyek ini adalah situs web resmi, marin.community. Dari sana, Anda dapat menemukan model yang dirilis di Hugging Face, menjelajahi "lab terbuka" di GitHub, membaca dokumentasi Marin, dan mendalami framework pelatihan Levanter. Anda juga dapat menguji coba Marin dalam kolaborasi dengan contoh inferensi sederhana.

Diskusi aktif juga berlangsung di saluran Discord, tempat Anda dapat berinteraksi langsung dengan para developer lain. Bagi yang baru mengenal ekosistem ini, dokumentasi JAX resmi menyediakan materi yang sangat baik, termasuk panduan memulai.