대규모 AI 모델 개발용 프레임워크로 높은 인기를 누리며 잘 알려진 JAX는 보다 광범위한 과학 분야에서도 빠르게 도입되고 있습니다. 특히 물리학 기반 머신러닝과 같이 계산 집약적인 분야에서 그 활용도가 높아지고 있으며 앞으로 기대가 큽니다. JAX는 고차 함수 집합인 구성 가능한 변환(composable transformations)을 지원합니다. 예를 들어, grad는 함수를 입력값으로 받아 그 기울기를 계산하는 또 다른 함수를 반환합니다. 그리고 무엇보다 중요한 것은 이러한 변환을 자유롭게 중첩(구성)할 수 있다는 점입니다. 이러한 설계 덕분에 JAX는 고계도함수와 그 밖의 복잡한 변환에 특히 적합합니다.
최근 국립싱가포르대학교와 Sea AI Lab의 연구원인 Zekun Shi, Min Lin과 이야기를 나눌 기회가 있었습니다. 두 사람의 경험은 JAX가 과학 연구의 근본적인 과제, 그중에서도 특히 복잡한 편미분 방정식(PDE)을 풀 때 직면하는 계산 절벽(computational cliff) 문제를 어떻게 해결하게 되었는지를 명확하게 보여주는 사례였습니다. 기존 프레임워크의 한계를 극복하고 JAX의 고유한 테일러 모드 자동 미분(Taylor mode automatic differentiation)을 활용하기까지의 여정에 많은 연구자들이 공감하게 될 것입니다.
저희는 과학 컴퓨팅 중에서도 까다로운 영역인 신경망을 활용하여 고차 편미분 방정식(PDE)을 푸는 것에 중점을 두고 연구를 진행하고 있습니다. 신경망은 '범용 함수 근사기(universal function approximator)'이므로 유한 요소법과 같은 기존 방법에 대한 좋은 대안이 될 수 있습니다. 하지만 신경망을 사용하여 PDE를 풀 때는 혼합 편미분을 비롯한 4차 이상의 고계도함수를 계산해야 한다는 큰 난관이 있습니다.
역전파를 통한 학습 모델에 우선적으로 최적화된 표준 딥 러닝 프레임워크는 고계도함수를 계산할 때 비용이 급증하기 때문에 이 작업에 적합하지 않았습니다. 고계도함수에 역전파(역방향 모드 AD)를 사용하는 비용은 미분 차수(k)에 따라 기하급수적으로 증가하고 도메인 차원(d)에 따라 다항적으로 증가합니다. 이 '차원성의 저주(curse of dimensionality)'와 미분 차수의 기하급수적 확장으로 인해 규모가 크고 복잡한 현실 세계의 문제를 풀기가 사실상 불가능했습니다.
딥 러닝에 널리 사용되는 다른 라이브러리도 있지만, 저희 연구에는 테일러 모드 자동 미분(AD) 같은 보다 근본적인 기능이 필요했습니다. JAX는 정말 획기적인 솔루션이었습니다.
JAX의 핵심적인 아키텍처 특징은 Python 코드 추적을 통해 구현되고 높은 성능을 내도록 컴파일된 강력한 함수 표현 및 변환 메커니즘입니다. 적시 컴파일부터 표준 미분 계산까지 다양한 분야에 적용할 수 있도록 범용성을 고려하여 설계된 시스템으로, 바로 이러한 근본적 유연성 덕분에 다른 프레임워크에서는 구현이 쉽지 않은 고급 연산을 수행할 수 있습니다. 저희에게 가장 중요한 애플리케이션은 테일러 모드 AD 지원이었습니다. JAX의 독보적인 아키텍처가 제공하는 직접적이고 강력한 결과물로, 저희의 과학 연구에 완벽하게 부합했죠. 테일러 모드 AD는 함수의 테일러 급수를 전개하고, 반복적이고 비용이 많이 소요되는 역전파 과정 대신 단 한 번의 전달로 계산을 수행함으로써 고계도함수를 효율적으로 계산할 수 있게 해 줍니다. 이를 통해 저희는 모든 미분 연산자를 효율적으로 무작위화하고 추정하는 '확률 테일러 미분 추정기(STDE)' 알고리즘을 개발할 수 있었습니다.
최근 저희는 NeurIPS 2024에서 최우수 논문상을 수상한 '확률 테일러 미분 추정기: 임의 미분 연산자에 대한 효율적 상각'에서 이 접근법의 활용 가능성을 입증했습니다. 해당 논문에서는 JAX에서 테일러 모드를 사용하여 고차 편미분을 효율적으로 추출하는 알고리즘을 개발할 수 있음을 보여주는데, 여기서 핵심 아이디어는 테일러 모드 AD를 활용하여 편미분 방정식(PDE)에 등장하는 고차 미분 텐서의 축약을 효율적으로 계산하는 것이었습니다. 특수한 무작위 탄젠트 벡터('제트/jet')를 구성함으로써, 한 번의 효율적인 순전파(forward pass)에서 임의의 복잡한 미분 연산자에 대한 편향되지 않은 추정치를 얻을 수 있었습니다.
결과는 놀라웠습니다. JAX에서 저희의 STDE 방식을 사용함으로써 기준 방식 대비 >속도 1,000배 향상, >메모리 30배 절감을 달성했으며, 이러한 효율성 향상 덕분에 이전에는 해결이 불가능했던 100만 차원 편미분 방정식을 NVIDIA A100 GPU 하나에서 단 8분 만에 풀 수 있었습니다.
이런 결과는 표준 머신러닝 워크로드에만 맞춰진 프레임워크였다면 불가능했을 것입니다. 역전파에 고도로 최적화된 다른 프레임워크들도 있지만, 엔드투엔드 계산 그래프 표현은 JAX에 미치지 못하는 수준입니다. 따라서 JAX의 진가는 함수 전치나 고차 테일러 모드 미분 구현과 같은 작업에서 제대로 발휘됩니다.
테일러 모드뿐만 아니라, JAX의 모듈식 설계와 범용적인 데이터 유형 지원 및 함수 변환 지원 기능 역시 저희 연구에 매우 중요합니다. 별도의 논문 'JAX에서의 자동 함수 미분'에서는 JAX를 일반화하여 무한 차원 벡터(힐베르트 공간의 함수)를 사용자 정의 배열로 기술하고 JAX에 등록하여 처리하기도 했습니다. 이를 통해 기존 기기를 재사용하여 함수와 연산자에 대한 변분 도함수를 계산할 수 있었으며, 이는 다른 프레임워크에서는 구현하기 어려운 독보적인 기능입니다.
이와 같은 이유로 저희는 본 프로젝트뿐만 아니라 양자 화학을 비롯한 다양한 연구 분야에서도 JAX를 채택했습니다. 일반적이면서도 확장 가능하며, 강력한 기호 기반 추론 역량을 갖춘 JAX의 근본적인 설계 방식은 과학 컴퓨팅의 한계를 확장하기 위한 최적의 선택이 될 것이라 확신합니다. 또한 이러한 사실을 과학계와 공유하는 것이 매우 중요하다고 생각합니다.
Zekun과 Min의 경험은 JAX의 강력함과 유연성을 보여줍니다. 두 사람이 JAX를 활용하여 개발한 STDE 방법은 물리 기반 머신러닝 분야에 큰 기여를 했으며, 이전에는 해결하기 어려웠던 수준의 문제를 풀 수 있는 길을 열었습니다. 기술적 세부 사항에 대해 더 깊이 알고 싶다면, 수상 경력에 빛나는 논문을 참고해 보시기 바랍니다. 또한, JAX 기반 과학 도구 생태계를 한층 풍성하게 만들어 준 추가 기능인 그들의 오픈소스 GitHub STDE 라이브러리도 함께 살펴보시길 권장해 드립니다.
이러한 사례들은 JAX가 단순한 딥 러닝 도구를 넘어, 차세대 과학적 발견을 가능하게 할 미분가능 프로그래밍의 핵심 라이브러리라는 것을 보여줍니다. Google JAX 팀은 이 활기찬 생태계를 지원하고 발전시키기 위해 최선을 다하고 있으며, 그 첫걸음은 바로 사용자 여러분의 목소리를 직접 듣는 것에서 시작됩니다.
차세대 과학 컴퓨팅 도구를 함께 구축할 수 있게 되어 매우 기쁘게 생각합니다. 연구 내용을 공유하시거나 JAX와 관련한 필요 사항을 논의하고 싶으시다면 언제든 저희 담당자에게 문의해 주세요.
유용한 정보의 탐색 여정을 공유해 주신 Zekun과 Min에게 진심으로 감사의 말씀을 드립니다.
참조
Shi, Z., Hu, Z., Lin, M., & Kawaguchi, K. (2025). 확률 테일러 미분 추정기: 임의 미분 연산자에 대한 효율적 상각 신경 정보 처리 시스템의 발전, 37.
Lin, M. (2023). JAX에서의 자동 함수 미분. 제12회 표현학습국제학회(International Conference on Learning Representations).