Elemen menarik dari era AI saat ini adalah bagaimana model-model fondasi yang canggih dibagikan secara terbuka dan membantu mempercepat inovasi bagi semua orang. Kemajuan ini menginspirasi kita untuk bertanya, "Apa langkah berikutnya untuk keterbukaan?" Proyek Marin melihat peluang untuk memperluas definisi "terbuka" agar mencakup seluruh proses ilmiah di balik sebuah model.
Proyek Marin CRFM (Center for Research on Foundation Models) Stanford dirancang sebagai "lab terbuka", yang sasarannya bukan hanya untuk membagikan model, melainkan juga untuk membuat seluruh perjalanannya dapat diakses—termasuk kode, set data, metodologi data, eksperimen, hiperparameter, dan log pelatihan. Level transparansi ini melengkapi ekosistem yang ada dengan menyediakan resource yang unik dan dapat direproduksi sepenuhnya yang memberdayakan para peneliti untuk meneliti, mengembangkan, dan memercayai model yang sedang dikembangkan. Proyek Marin Stanford berupaya untuk mendorong masa depan yang lebih transparan dan mudah diakses untuk penelitian model fondasi.
Rilis pertama dari lab terbuka ini adalah model Marin-8B-Base dan Marin-8B-Instruct. Sesuai dengan prinsip proyek, model, data, kode, dan tokenizer semuanya dirilis di bawah lisensi Apache 2.0 yang permisif. Komitmen terhadap reproduksibilitas penuh ini merupakan masalah engineering yang berat, yang membutuhkan kontrol atas setiap sumber varians dalam sistem yang terdistribusi secara masif. Keberhasilan proyek ini bergantung pada technology stack yang dapat memberikan jaminan reproduktifitas ini dalam skala besar, dan memaksimalkan efisiensi untuk melatih model dasar dengan harga/performa terdepan.
Agar proyek Marin berhasil menciptakan model fondasi yang benar-benar terbuka, skalabel, dan dapat direproduksi, tim CRFM harus mengatasi beberapa tantangan engineering. Tim memilih JAX sebagai fondasi karena prinsip desainnya memberikan solusi langsung untuk masalah ini, dan membangun framework baru, Levanter (lihat di bawah), untuk mengasah kekuatan JAX. Berikut adalah beberapa contoh tantangan dan solusinya.
Masalah: Loop pelatihan inti dieksekusi miliaran kali, sehingga overhead dari bahasa pemrograman tafsiran seperti Python menciptakan bottleneck performa yang sangat besar. Jika operasi dijalankan selangkah demi selangkah, loop tersebut juga dapat menimbulkan traffic memori dan overhead yang berlebihan—terutama pada hardware seperti TPU, yang throughput-nya bergantung pada eksekusi operasi fusi yang efisien.
Solusi kami:
@jax.jit
. Compiler XLA JAX mentransformasi seluruh proses ini menjadi satu kernel kode mesin yang sangat optimal, menggabungkan operasi untuk memaksimalkan pemanfaatan hardware dalam skala besar.jax.value_and_grad
untuk menghitung kerugian dan gradiennya dalam satu pass. JAX juga memudahkan penggunaan teknik lanjutan seperti gradient checkpointing, menghemat memori dan memungkinkan kami menggunakan ukuran batch yang lebih besar dengan hampir tanpa overhead.Pallas
yang canggih milik JAX, sebuah implementasi Dot Product Attention yang sangat optimal, salah satu operasi terpenting yang menjadi inti dari hampir semua model bahasa besar.Masalah: Pelatihan model mutakhir membutuhkan penskalaan hingga ribuan chip akselerator. Mengelola secara manual cara model dan data dipartisi dan cara perangkat berkomunikasi sangatlah rumit, dan kodenya dengan cepat menjadi sulit dibaca, di-debug, dan diadaptasi.
Solusi kami:
@jax.jit
JAX juga mendukung paralelisasi Single-Program, Multiple-Data (SPMD) secara mulus, yang mengotomatiskan sharding dan komunikasi data yang mendasarinya. Compiler XLA secara otomatis menjadwalkan komunikasi di antara akselerator untuk meminimalkan waktu tunggu di jaringan dan memaksimalkan waktu komputasi.jit
lebih mudah dan aman digunakan, Levanter mengembangkan Haliax, sebuah library untuk tensor yang diberi nama. Saat sumbu tensor dirujuk dengan nama yang mudah dibaca (seperti "embed" atau "batch"), sebagai ganti indeks posisi, kode menjadi terdokumentasi sendiri dan tangguh.Masalah: Pelatihan skala besar membutuhkan akses fleksibel ke cluster komputasi masif. Kami sangat bergantung pada instance TPU yang dapat di-preempt untuk mengelola biaya, yang berarti kami memerlukan cara untuk dengan mudah menggabungkan banyak slice TPU yang lebih kecil dan berbeda menjadi satu cluster logis dan tangguh terhadap gangguan yang sering terjadi.
Solusi kami:
Masalah: Sasaran inti proyek Marin adalah memungkinkan sains yang dapat diverifikasi. Hal ini membutuhkan pencapaian hasil yang dapat direproduksi, bahkan ketika pelatihan dihentikan sementara, dimulai ulang, atau dipindahkan antar konfigurasi mesin yang berbeda—sebuah tantangan teknis yang signifikan.
Solusi kami:
Masalah: Meskipun JAX menyediakan mesin yang canggih, belum ada framework level tinggi yang memenuhi persyaratan gabungan kami yang ketat untuk keterbacaan, skalabilitas masif, dan determinisme bitwise. Kami membutuhkan sistem yang lengkap dan beropini untuk mengatur seluruh proses pelatihan.
Solusi kami:
jit
) dengan kontrol level rendah (Pallas
).Prinsip, alat, dan library yang dibahas di atas diterapkan dan digunakan selama pelatihan Marin-8B. Arsitektur modelnya adalah transformator bergaya Llama.
Sebagai ganti proses statis dan monolitik, pelatihan Marin-8B merupakan perjalanan adaptif, yang secara internal dijuluki proses "Tootsie". Penggambaran jujur tentang alur kerja penelitian di dunia nyata ini dirinci secara publik. Proses ini mencakup lebih dari 12 triliun token dan melibatkan beberapa fase yang beradaptasi dengan data, teknik, dan bahkan konfigurasi hardware baru yang berbeda—bermigrasi di antara konfigurasi TPU multi-slice skala besar (pod 2x v5e-256 sampai 1x v4-2048) di tengah prosesnya. Tim terus menyaring campuran data, menggabungkan sumber berkualitas lebih tinggi, dan menyesuaikan hiperparameter, seperti laju pembelajaran dan ukuran batch untuk mengoptimalkan performa. Realitas "berantakan" ini merupakan alat pendidikan yang ampuh, dan kemampuan stack JAX dan Levanter untuk menangani pergeseran signifikan ini sambil mempertahankan reproduksibilitas bit-per-bit merupakan demonstrasi yang kuat akan ketangguhannya.
Proyek Marin merupakan undangan terbuka untuk berpartisipasi dalam pengembangan model fondasi di masa depan dan berkontribusi pada ekosistem JAX. Perjalanan Marin merupakan jawaban atas pertanyaan kita, "Apa langkah berikutnya untuk keterbukaan?" Upaya untuk menciptakan 'lab terbuka' ini dimungkinkan oleh kapabilitas teknis ekosistem JAX. Performa, portabilitas, dan desain fondasinya untuk reproduksibilitas merupakan unsur-unsur kunci yang memungkinkan kita untuk membuat 'perjalanan lengkap' penelitian menjadi lebih mudah diakses.
Dengan berbagi segalanya, mulai dari metodologi data hingga log pelatihan, kami bertujuan untuk menyediakan resource yang sepenuhnya dapat direproduksi—resource yang memungkinkan para peneliti untuk meneliti, mengembangkan, dan memercayai penelitian secara mendalam. Kami yakin ini merupakan langkah kolaboratif menuju masa depan AI yang lebih transparan. Kami mengundang Anda untuk bergabung dengan kami di 'lab terbuka' ini—untuk menggunakan Marin, berkontribusi pada penelitian, dan membantu membangun gelombang model fondasi yang inovatif dan tepercaya berikutnya.
Resource utama untuk proyek ini adalah situs web resmi, marin.community. Dari sana, Anda dapat menemukan model yang dirilis di Hugging Face, menjelajahi "lab terbuka" di GitHub, membaca dokumentasi Marin, dan mendalami framework pelatihan Levanter. Anda juga dapat menguji coba Marin dalam kolaborasi dengan contoh inferensi sederhana.
Diskusi aktif juga berlangsung di saluran Discord, tempat Anda dapat berinteraksi langsung dengan para developer lain. Bagi yang baru mengenal ekosistem ini, dokumentasi JAX resmi menyediakan materi yang sangat baik, termasuk panduan memulai.