개발자들이 다양한 계산 작업을 위해 점점 더 많이 JAX를 채택하면서 JAX의 역할은 본래 대규모 AI에 맞춰진 초점을 뛰어넘어 확장되고 있습니다. JAX는 여전히 LLM과 파운데이션 모델을 개발하는 데 사용되는 인기 있는 프레임워크지만, 다양한 과학 영역에서도 그 활용도가 이목을 끌고 있습니다. 특히 관심을 보이고 있는 영역은 로봇 공학으로, 시뮬레이션, 제어, 학습 기반 방법론의 통합 부문에서 JAX가 강력한 기능을 구현하고 있습니다.
최근 노스웨스턴 대학교의 토드 머피(Todd Murphey) 교수님의 지도를 받는 로봇 공학 박사과정생이자 연구자인 맥스 무첸 쑨(Max Muchen Sun) 씨와 즐거운 대화를 나누었습니다. 쑨 씨의 경험은 JAX가 로봇 공학 연구의 중대한 도전, 특히 복잡한 제어 알고리즘의 계산 효율과 모델 기반 및 학습 기반 접근 방식의 매끄러운 조합과 관련한 과제를 어떻게 처리할 수 있는지 분명히 보여줍니다. 전통적인 도구와 씨름하는 것에서 출발해 vmap이나 scan 같은 JAX의 고유 기능을 활용하기까지, 맥스 씨의 탐색 여정은 같은 분야에 있는 많은 사람이 공감할 수 있고 동시에 이들에게 영감을 주는 이야기입니다.
JAX에 관심을 가지게 된 건 계산 효율 측면에서 보여준 강점 때문이었어요. 그 당시 제 멘토였던 이안 에이브러햄(Ian Abraham)(현 예일대 교수) 교수님이 autograd를 사용했고 이후에 저를 JAX로 인도해 주셨죠. 저희는 커버리지 문제를 위한 제어 프레임워크인 에르고딕 제어를 사용하는 연구를 진행했어요. 표준 제어 공식에 비하면 에르고딕 제어의 계산 복잡도는 선천적으로 높아요. 초창기에 저는 실시간 에르고딕 제어를 구현하기 위해서 표준 NumPy를 사용하고 벡터화와 브로드캐스팅 기능을 활용했죠.
JAX의 기능 중 가장 먼저 제 시선을 사로잡은 것은 JAX의 vmap이었어요. 제게 이 기능은 표준 NumPy의 벡터화와 브로드캐스팅 메커니즘을 결합하고 함수 변환과 구성적 추상화를 통해 이 메커니즘을 더 일반화에 성공한 것으로 느껴졌습니다. 이를 통해 훨씬 더 쉽게 문제에 대해 추론하고 병렬화를 구현할 수 있었죠.
다음으로는 scan에 대해 배웠어요. 처음에는 직관성이 떨어졌지만, 결국 동적 시스템의 궤적을 시뮬레이션하는 효율적인 도구로 활용할 수 있었습니다. 궤적을 최적화할 때, 전방 시뮬레이션은 반복적으로 수행해야 하는 핵심 작업인 동시에 종종 계산을 방해하는 걸림돌이 되곤 했어요. scan 기능을 사용하면 표준 NumPy 기반 구현에 비해 궤적 최적화 속도를 최대 두 자릿수까지 올릴 수 있어요. 이처럼 사용이 편리한 동시에 속도 부문에 엄청난 이점을 가지고 있기 때문에 결국 JAX 생태계로 완전히 갈아타게 되었죠.
한편, 제 박사 과정의 중심 주제는 자율 탐사와 다중 에이전트 협력을 위해 모델 기반 제어와 학습 기반 표현을 통합하는 것이에요. 저는 모델 기반 방법을 독립적인 솔루션이 아니라 학습 효율성과 견고성을 향상해 주는 구조로 설정하고 있습니다. JAX가 가진 결합성 부문의 역량은 모델 기반 및 학습 기반 파이프라인을 통합하기에 이상적이죠.
로봇 공학: 과학 및 시스템(Robotics: Science and Systems, RSS) 심사를 통과한 제 최신 논문 중 한 편에서 생성형 모델의 플로 매칭과 로봇 탐사용 모델 기반 최적의 제어를 결합했습니다. 역전파와 비슷하지만 동적 시스템에서 이루어지는 LQR 기반 업데이트를 통해 플로 그래디언트를 사용하여 상태 공간 플로를 제어 장치에 매핑했죠. 처음에는 PyTorch에서 플로 매칭 모듈을 구축했고 LQR에 C++ 언어를 사용했지만, 통합이 느리게 진행됐어요. JAX로 전환하고 나서는 플로 매칭 부분을 vmap과 grad를 사용하여 구현하고 OTT(Optimal Transport Toolbox) 같은 JAX 기반 도구를 활용했죠. 남은 부분은 JAX 네이티브 LQR 파이프라인이었습니다.
IEEE 국제로봇자동화 학술대회(ICRA)에서 발표한 또 다른 최근 논문에서는 시연에서 다중 에이전트 협력을 학습하기 위해 모델 기반 게임 이론 제어 파이프라인을 생성형 궤적 모델과 통합했어요. 종종 계산적으로 비용이 많이 들고 수동 손실 사양을 요구하는 게임 이론 제어를 완전한 솔루션으로 사용하기보다 게임 이론 계산을 조건부 변분 오토인코더(CVAE) 내부에 구조화된 계층으로 임베드했죠. 이렇게 해서 성능을 저해하지 않고 데이터 효율을 개선할 수 있었습니다. 두 구성 요소는 JAX에서 구현했죠. Flax를 포함하는 CVAE와 제어 계층은 처음부터 만들었습니다. JAX로 이 과정을 매끄럽게 진행할 수 있었죠. grad는 평형을 통해 직접 미분할 수 있거든요. 가상 자료 생성을 위해 JAX 기반 iLQGames 솔버도 구축했습니다.
이런 프로젝트를 거치면서 제 스스로 JAX 코드의 상당 부분을 동적 시스템 계산에 다시 사용한다는 사실을 깨달았어요. 특히 LQR 기반 계산에 다시 사용되었죠. LQR을 사용하여 학습 기반 및 모델 기반 제어를 표준적이지 않은 방식으로 사용했기 때문에 독립적인 JAX 네이티브 솔버, LQRax에 패키지로 넣었어요. LQRax는 GPU 가속, vmap, scan 및 grad를 지원해서 벡터화되고 미분된 LQR을 사용할 수 있습니다. 이를 통해 에르고딕 및 게임 이론 제어 같은 예시를 포함해 모델 기반 방법이 학습을 보충하는 방식을 집중 조명했죠.
저는 종종 ML 커뮤니티에서는 일반적이지 않은 다른 방식으로 CPU와 GPU에서 모두 JAX를 사용합니다. 예를 들어, 플로우 매칭 프로젝트에서는 LQR이 CPU에서 더 빠르게 실행되는 반면, 플로우 매칭 그래디언트는 GPU에서 더 빠르게 실행되죠. 보통 로컬에서 모든 계산을 실행하기 때문에 TPU를 사용해 본 적은 없어요. 몇 년 전에 NVIDIA Jetson에서 JAX를 사용해 봤는데, 설치하기가 무척 어려웠습니다. 이제 JAX가 이런 임베디드 플랫폼에서 지원된다니 정말 기쁩니다. 로봇 공학에 매우 중요한 역할을 하게 될 것이라 생각합니다. 모든 계산이 온보드에서 이루어지는 Jetson을 사용하여 사족보행 로봇에서 크라우드 내비게이션 알고리즘을 테스트해 왔는데, 곧 이 프로젝트에 JAX를 통합할 계획이죠.
저는 앞으로도 JAX를 사용하기 시작한 계기와 같은 이유로 JAX를 사용할 것 같습니다. 첫 번째 이유는 계산 효율입니다. 특히 GPU 기반 동시 로드가 로봇 공학에서 점점 더 중요해지고 있기 때문에 이는 매우 중요한 역량이죠. 학습 외에도 대규모 병렬 시뮬레이션 및 체화된 활성 학습과 유사한 실시간 매개변수 업데이트와 같은 새로운 모델 기반 제어 가능성도 실현하게 될 것입니다. 두 번째 이유는 JAX를 사용하면 모델 기반 구조와 학습 파이프라인을 직관적으로 통합할 수 있다는 점입니다. 이러한 유연성을 덕분에 더욱 즐겁게 프로젝트를 진행하고 계속해서 한계를 시험할 수 있는 것 같습니다.
맥스 씨의 경험은 JAX가 로봇 공학 커뮤니티에 제공하는 여러 주요 이점을 잘 보여줍니다. vmap를 통한 병렬 작업의 상당한 시간 단축과 scan을 통한 궤적 시뮬레이션은 실시간 제어와 복잡한 계획에 중요한 역할을 합니다. 더욱이 JAX의 기능적 패러다임과 자동 적분 기능은 고전적인 모델 기반 기술과 현대적 학습 기반 구성요소를 통합하기에 안성맞춤인 기능이랍니다.
저희는 맥스 씨와 같은 이야기가 급속히 성장하는 동시에 성숙해지고 있는 생태계의 지표라고 생각합니다. 맥스 씨의 LQRax 패키지는 다양한 JAX 네이티브 로봇 공학 도구 상자에 추가되는 훌륭한 기능입니다. GitHub의 프로젝트를 살펴보고 직접 사용해 보시는 걸 추천해 드립니다. 시뮬레이션의 세계에서 JAX는 인기 많고 표준적인 MuJoCo 물리 엔진을 직접 JAX에 가져오는 Brax나 MuJoCo XLA (MJX)와 같은 대규모 병렬 엔진과 함께 강력한 토대를 제공합니다. 또한 커뮤니티에서는 제어 중심의 다물체 동역학을 위한 JaxSim 라이브러리처럼 특별한 도구도 확인할 수 있습니다.
Trajax가 길을 닦으며 선도했던 궤적 최적화 부문에서 LQRax는 차세대 제어 시스템을 구축하는 연구자를 위한 현대적인 라이브러리의 역할을 수행하며, 모델 기반 제어와 딥러닝의 간극을 메우는 강력하고 구성이 용이한 도구를 제공함으로써 JAX가 추구하는 이상을 완벽히 구현합니다.
통찰력 넘치는 탐색 여정을 공유해 주신 맥스 씨께 진심으로 감사드립니다. 그와 다른 연구자가 JAX를 계속 활용해 차세대 지능형 로봇 시스템을 어떻게 구축해 나갈지 정말 기대됩니다. Google의 JAX 팀에서도 이 활기찬 생태계를 지원하고 성장에 도움이 될 수 있도록 계속 노력하겠습니다.