JAX を幅広いコンピューティング タスクに採用するデベロッパーが増えており、その役割は、大規模 AI に焦点を当てる当初の役割を超えて拡大しています。JAX は、引き続き LLM と基盤モデルの開発に広く利用されているフレームワークである一方で、多様な科学分野でも躍進しています。特に期待感を醸成している分野の一つはロボット工学で、JAX はシミュレーション、制御、学習ベースメソッドの統合における優れた機能を実現しています。
先日、ノースウェスタン大学で Todd Murphey 教授の指導を受けるロボット学博士課程の研究者 Max Muchen Sun 氏とお話しをする機会がありました。彼の経験は、JAX がロボット工学研究における重要な課題、特に複雑な制御アルゴリズムの計算効率や、モデルベース アプローチと学習ベース アプローチのシームレスな組み合わせに関する課題にどのように対処できるかを明確に示すものです。従来のツールに対応したり、vmap や scan といった JAX の独自の機能を活用したりする Max 氏のストーリーには、この分野の多くの人が共感し、刺激を受けるでしょう。
私の JAX への関心は、計算効率の観点から始まりました。当時のメンターであった Ian Abraham 氏(現在はエール大学教授)が autograd を使用しており、後に JAX に導いてくれました。私たちは、カバレッジの問題の制御フレームワークであるエルゴード制御を使用した研究に取り組んでいました。標準的な制御定式化と比較して、エルゴード制御の計算は本質的にさらに複雑です。リアルタイムのエルゴード制御を実現するため、私は当初、標準 NumPy を使用し、ベクトル化とブロードキャストの機能を活用していました。
私が最初に注目した JAX の機能は、JAX の vmap でした。標準 NumPy のベクトル化とブロードキャストのメカニズムを組み合わせて、関数変換と構成的抽象化によってさらに一般化することで、解決しようとしている問題の並列化の推論と実装がはるかに容易になります。
その後、scan について学びました。最初はそれほど直感的ではありませんでしたが、最終的には動的システムの軌道をシミュレートするための効率的なツールになりました。軌道の最適化において、システム ダイナミクスのフォワード シミュレーションは、繰り返し実行する必要がある中心的オペレーションであり、計算のボトルネックとなることが多くあります。scan を使用すると、標準 NumPy ベースの実装と比較して、軌道シミュレーションを最大 2 桁高速化できます。この使いやすさと実質的なスピードという利点により、私は JAX エコシステムに完全に夢中になりました。
一方、私の博士課程の重点は、自律的な探索とマルチエージェントの協調のために、モデルベースの制御と学習ベースの表現を統合することでした。モデルベースのメソッドは、スタンドアロンのソリューションとしてではなく、学習効率と堅牢性を向上させる構造として捉えています。JAX のコンポーザビリティは、モデルベースのパイプラインと学習ベースのパイプラインの統合に最適です。
Robotics: Science and Systems(RSS)で採択された私の最新の論文の一つでは、生成モデルからのフロー マッチングとモデルベースの最適制御を組み合わせてロボット探索に適用しました。フロー勾配を使って、LQR ベースの更新を介して状態空間フローを制御にマッピングしていますが、これは動的システムにおける誤差逆伝播法のようなアプローチです。当初、PyTorch でフロー マッチング モジュールをビルドし、LQR には C++ を使用しましたが、統合に時間がかかりました。JAX に切り替えて、vmap と grad を使用してフロー マッチング部分を再実装し、OTT(Optimal Transport Toolbox)などの JAX ベースのツールを活用しました。残りの部分は、JAX ネイティブの LQR パイプラインでした。
IEEE International Conference on Robotics and Automation(ICRA)で発表した別の最近の論文では、モデルベースのゲーム理論的制御パイプラインを生成軌道モデルに統合し、デモンストレーションからマルチエージェントの協調を学習しました。完全なソリューションとしてゲーム理論的制御を使用すると、計算コストが高く、手動での損失関数の指定が必要となることが多くなりますが、そうではなく、ゲーム理論に基づく計算を構造化されたレイヤとして条件付き変分オートエンコーダ(CVAE)の内部に組み込みました。これにより、パフォーマンスを犠牲にすることなくデータ効率が向上しました。いずれのコンポーネントも JAX で実装されました。CVAE は Flax を使用しての、制御レイヤはゼロからの実装です。JAX により、grad が均衡を通じて直接微分できるようになったため、処理がシームレスになりました。また、合成データを生成するための JAX ベースの iLQGames ソルバーも構築しました。
これらのプロジェクトを経て、私は JAX コードの多くを動的システム計算、特に LQR ベースの計算に再利用していることに気づきました。学習ベースとモデルベースの制御を統合するために LQR を非標準的な方法で使用していたため、これをスタンドアロンの JAX ネイティブ ソルバー LQRax にパッケージ化しました。これは、GPU アクセラレーション、vmap、scan、grad をサポートし、ベクトル化された微分可能な LQR を実現します。エルゴード的制御やゲーム理論的制御などの例を盛り込み、モデルベースのメソッドがどのように学習を補完できるかを強調しました。
私は CPU と GPU の両方で JAX を使用していますが、多くの場合、ML コミュニティとは異なる方法で使用しています。たとえば、フロー マッチング プロジェクトでは、LQR の計算は CPU 上の方が高速ですが、フロー マッチングの勾配は GPU 上の方が高速です。私は通常、すべての計算をローカルで実行するため、TPU を使用したことはありません。数年前、NVIDIA Jetson 上で JAX を試してみましたが、インストールが大変でした。JAX がこれらの組み込みプラットフォームでサポートされるようになったことを嬉しく思います。これはロボット工学にとって非常に重要です。私は、すべての計算をオンボードで実行する Jetson を搭載した四足歩行ロボットでの群衆ナビゲーション アルゴリズムをテストしてきましたが、近日中にこのプロジェクトに JAX を統合する予定です。
今後も使い始めたときと同じ理由で JAX を使い続ける予定です。まず、計算効率、特に GPU ベースの並列化は、ロボット工学においてますます重要になっています。トレーニングにとどまらず、大規模な並列シミュレーションやリアルタイムのパラメータ更新など、新しいモデルベースの制御の可能性が現実のものとなります。これは、具現化されたアクティブ ラーニングに類似したアプローチです。2 つ目に、JAX は、ダイナミクス、損失形状、または微分可能なソルバーのいずれの場合でも、モデルベースの構造を学習パイプラインに統合するプロセスを直感的にします。この柔軟性が、さらに先へ進む意欲をかき立ててくれます。
Max 氏の体験は、JAX がロボット工学コミュニティに提供するいくつかの重要な利点を実証しています。vmap による並列処理の高速化と scan による軌道シミュレーションの高速化は、リアルタイム制御と複雑な計画に不可欠です。さらに、JAX の関数型パラダイムと自動微分機能は、従来のモデルベースの手法と最新の学習ベースのコンポーネントを統合するのにも適しています。
Max 氏のようなストーリーは、エコシステムが急速に成長し、成熟している兆候だと考えています。彼の LQRax パッケージは、JAX ネイティブのロボットツールの活気ある環境にぴったりです。GitHub でプロジェクトを探索し、実際に試してみることをおすすめします。シミュレーションの世界では、JAX は、Brax や新しい MuJoCo XLA(MJX)といった大規模並列エンジンを使用した優れた基盤を提供します。MJX は、一般的で標準的な MuJoCo 物理エンジンを JAX に直接組み込みます。コミュニティから専門的なツールも登場しており、たとえば、制御に焦点を当てた多体動力学向けの JaxSim ライブラリなどがあります。
Trajax のような先駆的ライブラリが道を拓いてきた軌道最適化の分野では、次世代の制御システムを開発する研究者にとって歓迎すべき最新のライブラリとして LQRax が登場しています。モデルベース制御とディープ ラーニングのギャップを埋める優れたコンポーザブル ツールを提供することで、JAX の精神を完璧に体現しています。
有益な体験談を共有してくれた Max 氏に心から感謝します。彼や他の研究者がどのように JAX を活用して次世代のインテリジェント ロボット システムを構築し続けていくのかを楽しみにしています。Google の JAX チームは、この活気に満ちたエコシステムをサポートし成長させることに全力を注ぎます。