« Créer rapidement des outils faciles à utiliser pour les appliquer au machine learning »
- Combine uniquement Python et Numpy
→ utilise XLA pour compiler et exécuter Numpy sur GPU/TPU
→ permet de compiler en JIT des fonctions Python avec une seule API pour les intégrer facilement à des kernels optimisés par XLA
→ facilite aussi l’exécution sur plusieurs GPU/TPU (vmap, pmap)
- Dépasse largement les performances du couple Python+Numpy classique
1 commentaires
DeepMind a entièrement refactorisé son stack sur la base de Jax
https://deepmind.com/blog/article/using-jax-to-accelerate-our-research