O JAX está sendo cada vez mais adotado pelos desenvolvedores para uma ampla gama de tarefas computacionais, o que expande seu papel além do foco original na IA em larga escala. Embora continue sendo um framework popular para o desenvolvimento de LLMs e modelos de base, o JAX também está ganhando força em diversos domínios científicos. Uma área que de grande interesse em particular é a robótica, na qual o JAX está habilitando recursos poderosos em simulação, controle e integração de métodos baseados em aprendizado.
Recentemente, tive o prazer de falar com Max Muchen Sun, candidato a um PhD em robótica e pesquisador da Northwestern University, orientado pelo Prof. Todd Murphey. Sua experiência ilustra claramente como o JAX pode enfrentar desafios críticos na pesquisa em robótica, particularmente em torno da eficiência computacional para algoritmos de controle complexos e da combinação perfeita de abordagens baseadas em modelo e em aprendizado. A jornada de Max, desde sua batalha com as ferramentas tradicionais até o uso dos recursos exclusivos do JAX, como o vmap e o scan, é uma história com a qual muitos profissionais da área se identificarão e se sentirão inspirados.
Meu interesse no JAX começou a partir de uma perspectiva de eficiência computacional. Meu mentor na época, Ian Abraham (hoje professor da Universidade de Yale), estava usando o autograd e, mais tarde, me direcionou para o JAX. Estávamos trabalhando em pesquisa usando o controle ergódico, que é um framework de controle para problemas de cobertura. Em comparação com as formulações de controle padrão, a complexidade computacional do controle ergódico é inerentemente maior. Para obter esse controle em tempo real, usei inicialmente o NumPy padrão e aproveitei os recursos de vetorização e broadcast.
O primeiro recurso do JAX que chamou minha atenção foi o vmap. Na minha opinião, ele combina os mecanismos de vetorização e broadcast do NumPy padrão e os generaliza ainda mais por meio da transformação de funções e da abstração composicional, facilitando muito para mim o raciocínio e a implementação do carregamento em paralelo para os problemas que estou resolvendo.
Depois, eu aprendi sobre o scan. Ele era menos intuitivo no início, mas acabou se tornando uma ferramenta eficiente para simular as trajetórias de sistemas dinâmicos. Na otimização de trajetórias, a simulação direta da dinâmica do sistema é uma operação central que deve ser realizada repetidamente e, muitas vezes, se torna o gargalo computacional. Com o scan, a simulação de trajetórias pode ser acelerada até duas ordens de magnitude em comparação com as implementações padrão baseadas em NumPy. A facilidade de uso e a vantagem substancial de velocidade me atraíram para o ecossistema do JAX.
Por outro lado, um foco central do meu doutorado tem sido integrar o controle baseado em modelo com representações baseadas em aprendizado para a exploração autônoma e a cooperação multiagentes. Vejo os métodos baseados em modelo não como soluções autônomas, mas como estruturas para melhorar a eficiência e a robustez do aprendizado. A composição do JAX o tornou ideal para mesclar pipelines baseados em modelo e em aprendizado.
Em um dos meus últimos artigos aceitos na Robotics: Science and Systems (RSS), combinei a correspondência de fluxos de modelos generativos com o controle ideal baseado em modelo para a exploração robótica, usando gradientes de fluxo para mapear fluxos de estado-espaço para controles por meio de uma atualização baseada em LQR, análoga à retropropagação, mas em sistemas dinâmicos. Inicialmente, criei o módulo de correspondência de fluxos no PyTorch e usei o C++ para o LQR, mas a integração foi lenta. Mudei para o JAX e implementei novamente a parte de correspondência de fluxos usando o vmap e o grad e utilizei ferramentas baseadas no JAX, como o OTT (Optimal Transport Toolbox). A parte restante era um pipeline de LQR nativo do JAX.
Em outro artigo recente apresentado na IEEE International Conference on Robotics and Automation (ICRA), integrei um pipeline de controle por teoria de jogos baseado em modelo a um modelo de trajetória generativa para aprender a cooperação multiagentes a partir de demonstrações. Em vez de usar o controle por teoria de jogos como uma solução completa, que costuma ser computacionalmente cara e exigir especificação de perdas manual, incorporei a computação por teoria de jogos como uma camada estruturada dentro de um autocodificador variacional condicional (CVAE, na sigla em inglês). Isso melhorou a eficiência dos dados sem sacrificar o desempenho. Ambos os componentes foram implementados em JAX — o CVAE com o Flax e a camada de controle a partir do zero. O JAX fez isso com perfeição: o grad conseguiu fazer a diferenciação no equilíbrio diretamente. Também criei um solucionador iLQGames baseado no JAX para gerar dados sintéticos.
Após esses projetos, percebi que estava reutilizando grande parte do meu código JAX para cálculos de sistemas dinâmicos, especialmente os baseados em LQR. Como usei o LQR para integrar o controle baseado em aprendizado e em modelo de maneiras diferentes do padrão, empacotei-o em um solucionador autônomo nativo do JAX, o LQRax. Ele dá suporte à aceleração de GPU, ao vmap, ao scan e ao grad, permitindo o LQR vetorizado e diferenciável. Incluí exemplos como controle ergódico e por teoria de jogos para destacar como os métodos baseados em modelo podem complementar o aprendizado.
Eu uso o JAX em CPUs e GPUs, muitas vezes de forma diferente da comunidade de aprendizado de máquina. Por exemplo, no projeto de correspondência de fluxos, o LQR é executado mais rapidamente em CPUs, enquanto os gradientes de correspondência de fluxos são mais rápidos em GPUs. Não usei TPUs porque normalmente executo todos os cálculos localmente. Alguns anos atrás, experimentei o JAX em um Nvidia Jetson, e a instalação foi difícil. É ótimo que o JAX agora tenha suporte nessas plataformas incorporadas, o que é fundamental para a robótica. Estive testando um algoritmo de navegação em multidões em um robô quadrúpede usando um Jetson com toda computação feita de forma integrada, e pretendo integrar o JAX a esse projeto em breve.
No futuro, continuarei usando o JAX pelos mesmos motivos pelos quais comecei a fazer isso. Primeiro, porque a eficiência computacional, em especial o carregamento em paralelo baseado em GPU, é cada vez mais vital na robótica. Além do treinamento, ela abre novas possibilidades de controle baseado em modelo, como simulações paralelas massivas e atualizações de parâmetros em tempo real, semelhante ao aprendizado ativo incorporado. Em segundo lugar, o JAX torna intuitiva a integração de estruturas baseadas em modelo a pipelines de aprendizado, seja para dinâmica, modelagem de perdas ou solucionadores diferenciáveis. Essa flexibilidade me anima muito a ir mais longe.
A experiência de Max demonstra várias vantagens importantes que o JAX oferece à comunidade da robótica. As acelerações significativas alcançadas com o vmap para operações paralelas e o scan para simulações de trajetórias são cruciais para o controle em tempo real e o planejamento complexo. Além disso, o paradigma funcional e os recursos de diferenciação automática do JAX o tornam adequado para integrar técnicas clássicas baseadas em modelo com componentes modernos baseados em aprendizado.
Acreditamos que histórias como a de Max são um sinal de um ecossistema em rápida expansão e amadurecimento. O pacote LQRax é um ótimo complemento para um cenário vibrante de ferramentas de robótica nativas do JAX, e encorajamos você a explorar o projeto no GitHub e a experimentá-lo. No mundo da simulação, o JAX fornece uma base poderosa com mecanismos massivamente paralelos, como o Brax e o novo MuJoCo XLA (MJX), que traz o popular mecanismo de física MuJoCo padrão diretamente para o JAX. Também estamos vendo ferramentas especializadas da comunidade, como a biblioteca JaxSim para dinâmicas multicorpos focadas em controle.
No domínio da otimização de trajetórias, no qual pioneiros como o Trajax abriram o caminho, o LQRax chega como uma biblioteca moderna muito bem-vinda para pesquisadores que estão criando a próxima geração de sistemas de controle. Ele incorpora perfeitamente o espírito do JAX, fornecendo uma ferramenta poderosa de composição que faz uma ponte entre o controle baseado em modelo e o aprendizado profundo.
Agradecemos sinceramente ao Max por compartilhar sua jornada inspiradora conosco. Mal podemos esperar para ver como ele e outros pesquisadores continuarão utilizando o JAX para criar a próxima geração de sistemas robóticos inteligentes. A equipe do JAX do Google se compromete em apoiar e expandir esse vibrante ecossistema.