Meskipun JAX terkenal sebagai framework populer untuk pengembangan model AI berskala besar, ia juga diadopsi secara cepat di berbagai domain ilmiah yang lebih luas. Kami sangat antusias melihat penggunaannya yang semakin banyak di bidang yang intensif secara komputasi seperti machine learning berbasis fisika. JAX mendukung transformasi composable, serangkaian fungsi tingkat tinggi. Sebagai contoh, grad mengambil sebuah fungsi sebagai input dan mengembalikan fungsi lain yang menghitung gradiennya—dan yang terpenting, Anda bisa menyusun (mengomposisikan) transformasi ini dengan bebas. Desain inilah yang membuat JAX sangat elegan untuk turunan tingkat tinggi dan transformasi kompleks lainnya.
Baru-baru ini, saya mendapat kesempatan untuk berbicara dengan Zekun Shi dan Min Lin, peneliti dari National University of Singapore dan Sea AI Lab. Pengalaman mereka dengan jelas menggambarkan bagaimana JAX bisa mengatasi tantangan mendasar dalam penelitian ilmiah, terutama seputar jurang komputasional yang dihadapi saat menyelesaikan Partial Differential Equations (PDE) yang kompleks. Perjalanan mereka dari bergulat dengan keterbatasan framework tradisional hingga memanfaatkan diferensiasi otomatis mode Taylor JAX yang unik adalah cerita yang akan menjadi inspirasi bagi banyak peneliti.
Pekerjaan kami berfokus pada bidang komputasi ilmiah yang menantang: menggunakan neural network untuk menyelesaikan PDE tingkat tinggi. Neural network adalah estimator fungsi universal, menjadikannya alternatif yang menjanjikan untuk metode tradisional seperti elemen hingga. Namun, rintangan utama dalam memecahkan PDE dengan neural network adalah Anda perlu mengevaluasi turunan tingkat tinggi, terkadang hingga tingkat keempat atau bahkan lebih tinggi lagi, termasuk turunan parsial campuran.
Framework deep learning standar, yang biasanya dioptimalkan untuk melatih model melalui backpropagation, tidak cocok untuk tugas ini karena menghitung turunan tingkat tinggi sangatlah mahal. Biaya penerapan back-propagation (AD mode mundur) secara berulang untuk turunan tingkat tinggi meningkat secara eksponensial dengan tingkat turunan (k) dan secara polinomial dengan dimensi domain (d). “Kutukan dimensionalitas” dan penskalaan eksponensial dalam tingkat turunan ini membuatnya hampir tidak mungkin untuk menangani masalah dunia nyata yang besar dan kompleks.
Meskipun ada library populer lainnya untuk Deep Learning, penelitian kami membutuhkan kemampuan yang lebih mendasar: Diferensiasi otomatis mode Taylor (AD). JAX adalah terobosan besar bagi kami.
Perbedaan arsitektur utama JAX adalah representasi fungsi dan mekanisme transformasinya yang kuat, diimplementasikan dengan menelusuri kode Python dan dikompilasi untuk performa tinggi. Sistem ini dirancang dengan generalitas yang memungkinkan berbagai aplikasi serbaguna, mulai dari kompilasi tepat waktu hingga menghitung turunan standar. Fleksibilitas yang mendasari inilah yang memungkinkan operasi lanjutan yang tidak mudah dicapai dalam framework lainnya. Bagi kami, aplikasi yang sangat penting adalah dukungan untuk AD mode Taylor, yang kami pelajari merupakan hasil langsung dan kuat dari arsitektur unik ini, membuat JAX diperlengkapi dengan sempurna untuk pekerjaan ilmiah kami. AD mode Taylor memungkinkan komputasi turunan tingkat tinggi yang efisien dengan mendorong perluasan deret Taylor suatu fungsi dan menghitung turunan tingkat tinggi secara efisien dalam sekali jalan, bukan melalui back-propagation yang berulang dan memakan banyak biaya. Ini memungkinkan kami untuk mengembangkan sebuah algoritme, Stochastic Taylor Derivative Estimator (STDE), untuk mengacak dan mengestimasi setiap operator diferensial secara efisien.
Dalam makalah terbaru kami, "Stochastic Taylor Derivative Estimator: Efficient amortization for arbitrary differential operators", yang menerima penghargaan Best Paper Award di NeurIPS 2024, kami mendemonstrasikan bagaimana pendekatan ini bisa digunakan. Kami menunjukkan bahwa dengan menggunakan mode Taylor JAX, kami dapat membuat algoritme untuk mengekstrak turunan parsial tingkat tinggi secara efisien. Ide intinya adalah memanfaatkan AD mode Taylor untuk menghitung kontraksi tensor turunan tingkat tinggi yang muncul dalam PDE secara efisien. Dengan membangun vektor singgung acak khusus (atau "jets"), kami bisa mendapatkan estimasi yang tidak bias dari operator diferensial kompleks secara acak dalam satu proses yang efisien.
Hasilnya sangat dramatis. Dengan menggunakan metode STDE di JAX, kami meraih >peningkatan kecepatan 1000x dan >pengurangan memori 30x dibandingkan dengan metode dasar. Peningkatan efisiensi ini memungkinkan kami menyelesaikan PDE 1 juta dimensi hanya dalam 8 menit pada satu GPU NVIDIA A100, sebuah tugas yang sebelumnya tak terpecahkan.
Hal ini tidak akan mungkin terwujud dengan framework yang hanya ditujukan untuk beban kerja machine learning standar. Framework lain sangat dioptimalkan untuk backpropagation, tetapi kurang fokus pada representasi grafik komputasi secara menyeluruh dibandingkan JAX. Ini membantu JAX unggul dengan operasi seperti transposisi fungsi atau mengimplementasikan diferensiasi mode Taylor tingkat tinggi.
Di luar mode Taylor, desain modular JAX dan dukungannya untuk tipe data umum dan transformasi fungsi sangatlah penting untuk penelitian kami. Dalam makalah lain, “Automatic Functional Differentiation in JAX”, kami bahkan telah menggeneralisasi JAX untuk menangani vektor berdimensi tak terbatas (fungsi dalam ruang Hilbert) dengan menggambarkannya sebagai array khusus dan mendaftarkannya ke dalam JAX. Hal ini membuat kami bisa memanfaatkan kembali mesin yang sudah ada untuk menghitung turunan variasional untuk fungsional dan operator, sebuah fungsionalitas yang sama sekali tidak terjangkau oleh framework lain.
Karena alasan ini, kami telah mengadopsi JAX tidak hanya untuk project ini, tetapi juga untuk berbagai penelitian di berbagai bidang seperti kimia kuantum. Desain fundamentalnya sebagai sistem yang umum, dapat diperluas, dan secara simbolis kuat menjadikannya pilihan ideal untuk mendorong batas-batas komputasi ilmiah. Kami meyakini bahwa komunitas ilmiah perlu mengetahui tentang kemampuan ini.
Pengalaman Zekun dan Min menunjukkan kekuatan dan fleksibilitas JAX. Metode STDE yang mereka kembangkan menggunakan JAX merupakan kontribusi yang signifikan bagi dunia machine learning berbasis fisika, yang memungkinkan kita menangani masalah yang sebelumnya sulit dipecahkan. Kami mendorong Anda untuk membaca makalah pemenang penghargaan mereka untuk mempelajari detail teknis secara lebih dalam dan menjelajahi library STDE open source mereka di GitHub, yang merupakan tambahan luar biasa untuk lanskap alat ilmiah asli JAX.
Cerita seperti ini menyoroti tren yang sedang berkembang: JAX bukan hanya alat untuk deep learning; ia merupakan library dasar untuk pemrograman terdiferensiasi yang mendorong generasi baru penemuan ilmiah. Tim JAX di Google berkomitmen untuk mendukung dan mengembangkan ekosistem yang dinamis ini, dan hal tersebut dimulai dengan mendengarkan langsung dari Anda.
Kami sangat senang bisa bermitra dengan Anda untuk membangun alat komputasi ilmiah generasi berikutnya. Silakan hubungi tim kami untuk membagikan pekerjaan Anda atau mendiskusikan hal-hal yang Anda butuhkan dari JAX.
Terima kasih banyak kepada Zekun dan Min yang telah membagikan perjalanannya yang penuh makna kepada kami.
Referensi
Shi, Z., Hu, Z., Lin, M., & Kawaguchi, K. (2025). Stochastic Taylor Derivative Estimator: Efficient amortization for arbitrary differential operators. Advances in Neural Information Processing Systems, 37.
Lin, M. (2023). Automatic Functional Differentiation in JAX. International Conference on Learning Representations Kedua Belas.