PyTorch est mort. Vive JAX
(neel04.github.io)- Si PyTorch entraîne une perte de productivité et un gaspillage de temps de développement, c’est « non pas parce que le framework est mauvais en soi, mais parce qu’il n’a pas été conçu pour les cas d’usage auxquels on l’applique aujourd’hui »
La philosophie de PyTorch
- La philosophie de PyTorch : dynamique, facile à déboguer et pythonique
- À l’inverse, TensorFlow 1.x visait à devenir un framework statique mais performant, en s’appuyant fortement sur le compilateur XLA
- Les développeurs de TensorFlow ont réalisé que la communauté détestait l’API 1.x, et ont donc décidé d’utiliser Keras comme interface principale tout en réduisant le rôle du compilateur XLA
- PyTorch est resté fidèle à ses racines et a adopté une approche plus dynamique d’« exécution immédiate », où
torch.Tensorest évalué immédiatement, contrairement à l’approche statique et différée de TensorFlow - Cela a porté ses fruits, et une grande partie de la recherche a migré vers PyTorch
- En 2021, avec l’arrivée de GPT-3, la performance et la scalabilité sont devenues des préoccupations majeures
- PyTorch a plutôt bien répondu à cette demande dans une certaine mesure, mais comme il n’avait pas été conçu avec cette philosophie en tête, la dette s’est accumulée et ses fondations ont commencé à vaciller
- Les développeurs de PyTorch ne voulaient accepter aucun compromis et ont choisi de poursuivre simultanément deux voies
- utiliser le compilateur XLA comme backend par défaut, performant et stable
- construire la pile
torch.compilepour laisser à l’utilisateur la liberté d’invoquer un compilateur si nécessaire
- L’absence de stratégie à long terme est un problème grave
- PyTorch ne veut pas s’engager dans une philosophie centrée sur le compilateur comme JAX, mais aucune bonne alternative ne semble se dessiner
- Quelle est la réponse des produits concurrents à ce problème ?
Développement piloté par le compilateur dans JAX
- JAX exploite XLA, la puissante pile de compilation de TensorFlow
- XLA est un compilateur puissant, mais tout cela est abstrait pour l’utilisateur final
- Tant qu’une fonction est pure, on peut utiliser le décorateur
@jax.jitpour la compiler en JIT et la rendre exploitable par XLA - XLA gère en coulisses la validation de l’exactitude du graphe généré, le partitionneur GSPMD pour l’auto-parallélisation avec sharding dans JAX, l’optimisation de graphe, la fusion d’opérateurs et de kernels, l’ordonnancement masquant la latence, le recouvrement asynchrone des communications, la génération de code vers d’autres backends comme triton, etc.
- Tant qu’on respecte les contraintes de JAX, XLA s’en charge automatiquement
- Par exemple, lors de la parallélisation, il n’est pas nécessaire d’utiliser des primitives de communication comme
torch.distributed.barrier() - Le support DDP est possible avec un code simple
- L’approche de XLA consiste à faire en sorte que le calcul suive le sharding. Ainsi, si un tableau d’entrée est shardé selon un axe, XLA le gère automatiquement pour les calculs en aval
- L’idée de « développement piloté par le compilateur » ressemble à la manière dont fonctionne le compilateur Rust
- Les limites de PyTorch
- Mécontentement face au choix des développeurs de PyTorch d’intégrer et de faire dépendre les nouvelles fonctionnalités d’une pile de compilation, au lieu de préserver leur philosophie fondamentale de flexibilité et de liberté
- Selon la roadmap officielle de PyTorch 2.x, il existe un plan clair à long terme pour intégrer totalement XLA à Torch
- C’est une idée terrible. C’est comme dire que forcer du code C++ dans le compilateur Rust offrirait une meilleure expérience que d’utiliser Rust lui-même
- Contrairement à JAX, Torch n’a pas été conçu autour de XLA
- Si PyTorch décide d’utiliser une pile de compilation basée sur XLA, le framework idéal ne serait-il pas justement un framework conçu et construit spécifiquement autour de cela ?
- Même si PyTorch poursuivait une approche « multi-backend » permettant de choisir le backend de compilation voulu, cela n’aggraverait-il pas la fragmentation et ne détruirait-il pas complètement l’API en essayant de respecter les limites de toutes les piles de compilation ?
- Quiconque a déjà utilisé Torch/XLA sur TPU souffre d’un sérieux PTSD
Le multi-backend est un échec
- En essayant de tout faire à la fois, PyTorch échoue misérablement
- La décision de conception « multi-backend » aggrave exponentiellement ce problème
- En théorie, cela donne l’impression qu’on peut choisir la pile que l’on veut ; en pratique, c’est un chaos emmêlé de tracebacks incompréhensibles et de problèmes d’incompatibilité
- Conflit entre les contraintes inter-backends et l’API de PyTorch
- Le problème n’est pas simplement de faire fonctionner ces backends, mais le fait que les contraintes qu’ils imposent s’accordent mal avec l’API flexible et pythonique de PyTorch
- Il existe un arbitrage entre la cohérence de l’API et le respect des limitations des backends
- En conséquence, les développeurs essaient davantage de s’appuyer sur la génération de code plutôt que de réellement s’intégrer à un backend unique et de s’y engager
- L’absence de stratégie chez PyTorch
- Parce que PyTorch refuse les arbitrages significatifs, chaque décision ressemble à un compromis bancal
- Il n’y a ni cohérence ni stratégie d’ensemble
- Au final, cela provoque beaucoup de frustration chez les utilisateurs et donne l’impression d’un bric-à-brac de fonctionnalités qui ne vont pas ensemble
- Il n’y a pas de moyen plus rapide de tuer un écosystème
- Pourquoi il ne faut pas suivre l’approche de JAX
- PyTorch ne devrait pas suivre l’approche « compilateur et backend intégrés » de JAX
- Parce que JAX a été explicitement conçu pour fonctionner avec XLA
- Remplacer le frontend de PyTorch par celui de JAX ne peut pas constituer une stratégie
- Il est pratiquement impossible d’inventer une meilleure API que celle de JAX sur la base de XLA
- Il ne s’agit pas de blâmer les développeurs qui essaient des idées nouvelles et différentes
- Mais si PyTorch veut résister à l’épreuve du temps, il devrait davantage se concentrer sur le renforcement de ses fondations que sur la livraison de nouvelles fonctionnalités tape-à-l’œil qui s’effondrent dès qu’on sort des conditions idéales d’un tutoriel
La fragmentation de PyTorch et la programmation fonctionnelle de JAX
- L’API fonctionnelle de JAX
- Les fonctions JAX doivent être pures, c’est-à-dire sans effets de bord globaux
- Comme une fonction mathématique, à données identiques, elles doivent toujours renvoyer la même sortie, quel que soit le contexte d’exécution
- Grâce à cette philosophie de conception, les fonctions JAX sont composables et interopèrent bien entre elles
- La complexité du développement diminue, et les fonctions sont définies par une signature précise et une tâche concrète bien définie
- Si les types sont respectés, la fonction est garantie de fonctionner immédiatement
- Cela convient bien au calcul scientifique, en particulier au deep learning
- Exemple de l’API optax
- Grâce à l’approche fonctionnelle, optax dispose d’un concept de « chain »
- Il s’agit de plusieurs fonctions appliquées séquentiellement aux gradients
- Le composant fondamental est
GradientTransformation - Cela produit une API à la fois puissante et expressive
- Par exemple, cela rend très simple des opérations comme le clipping des gradients, l’EMA des gradients ou la combinaison d’optimizers
- Les avantages de la conception fonctionnelle
- Un autre résultat très intéressant de cette conception fonctionnelle est
vmap - Cela signifie « vectorized map », ce qui décrit exactement sa fonction
- On peut tout mapper, et tant qu’il s’agit de
vmap, XLA fusionne et optimise automatiquement - Il n’est pas nécessaire de réfléchir à la dimension batch lors de l’écriture des fonctions
- Il suffit de passer tout le code dans
vmap - Cela signifie qu’on a moins besoin d’opérations ein-*
- Il devient plus intuitif de raisonner sur des manipulations de tenseurs 2D/3D, avec une bien meilleure lisibilité
- Comme il suffit d’isoler les composants individuels pour raisonner dessus, il devient plus facile d’écrire du code complexe qui fonctionne correctement
- Tant qu’on respecte les contraintes de pureté et qu’on a la bonne signature, on bénéficie de tous les autres avantages comme la composabilité
- Un autre résultat très intéressant de cette conception fonctionnelle est
- Les problèmes de l’écosystème PyTorch
- Avec torch, quel que soit le stack utilisé (
FSDP+ multi-nœud +torch.compile, etc.), il y a toujours un risque que quelque chose casse - Il faut que plusieurs éléments fonctionnent correctement ensemble, et si un seul composant échoue, il faut déboguer jusqu’à 3 heures du matin
- Comme il est impossible de tester toutes les combinaisons des dizaines de fonctionnalités proposées par PyTorch, il y aura toujours des bugs non détectés pendant le développement
- Il est impossible d’écrire du code qui fonctionne bien sans un effort considérable
- L’écosystème torch est devenu extrêmement volumineux et truffé de bugs
- Faute d’abstraction partagée, de nouvelles bibliothèques et de nouveaux frameworks apparaissent sans avoir été conçus pour s’interfacer avec les autres « solutions »
- Cela dégénère rapidement en chaos de dépendances et de
requirements.txt - 70 à 80 % des issues GitHub ou des discussions sur les forums viennent simplement d’erreurs entre différentes bibliothèques
- Il n’existe presque aucun moyen d’y remédier
- Avec torch, quel que soit le stack utilisé (
- L’absence de solution
- C’est un problème d’OOP et de conception
- On pourrait penser qu’un objet fondamental et typiquement PyTorch comme PyTree aurait aidé à construire une base commune d’abstraction
- Il est également impossible d’adopter un paradigme de programmation fonctionnelle
- Cela convergerait vers une version moins performante de JAX tout en cassant la rétrocompatibilité de toutes les codebases torch existantes
- PyTorch semble complètement dans l’impasse sur ce point
L’avantage de JAX en matière de reproductibilité
- Gestion des seeds
- La gestion des seeds dans PyTorch n’est pas idéale
- En général, il faut exécuter plusieurs lignes de code
- C’est facile à oublier ou à mal configurer
- JAX force à créer des clés explicites et à les passer à toutes les fonctions qui nécessitent de l’aléatoire
- Cette approche élimine complètement le problème, car le RNG est toujours seedé statiquement
- JAX dispose de sa propre version de NumPy (
jax.numpy), il n’est donc pas nécessaire de gérer une seed séparément - Ce genre de petit choix QoL peut nettement améliorer l’expérience utilisateur de l’ensemble du framework
- Portabilité
- L’un des plus gros problèmes quand on utilise une codebase PyTorch est le manque de portabilité
- Une codebase écrite pour CUDA/GPU ne fonctionne pas bien lorsqu’elle s’exécute sur du matériel non Nvidia comme TPU, NPU, GPU AMD, etc.
- Il est difficile de porter du code PyTorch écrit pour un seul nœud vers du multi-nœud
- Le multi-nœud exige souvent des dizaines d’heures de développement et des modifications de code importantes
- L’approche centrée compilateur de JAX offre ici un avantage
- XLA gère le passage entre backends matériels et fonctionne bien sur GPU/TPU/multi-nœud/multi-slice avec des modifications de code minimales
- Cela facilite le support des appareils par les fournisseurs matériels et simplifie le passage d’un appareil à l’autre
- Comme tout le monde n’a pas accès au même matériel, des codebases portables sur différents types de matériel pourraient constituer un petit pas pour rendre le deep learning plus accessible aux débutants et aux intermédiaires
- Auto-scaling
- Une codebase capable de bien s’auto-scaler est très utile pour la reproductibilité
- Dans l’idéal, cela devrait se produire automatiquement avec un minimum de changements de code, indépendamment des frontières réseau
- JAX fait cela très bien
- Quand on écrit du code JAX, il n’est pas nécessaire de spécifier des primitives de communication ni de placer
torch.distributed.barrier()partout - XLA les insère automatiquement en fonction du matériel disponible
- Tous les appareils détectables par JAX sont utilisés automatiquement, quels que soient le réseau, la topologie, la configuration, etc.
- Il synchronise et prépare automatiquement le calcul, puis applique des passes d’optimisation afin de maximiser l’exécution asynchrone des kernels et de minimiser la latence
- La seule chose à faire est de spécifier le sharding des tenseurs que l’on souhaite répartir sur les appareils, par exemple la dimension batch des tableaux d’entrée
- Grâce à l’approche de XLA où « le calcul suit le sharding », il déduit automatiquement le reste
- Il devient possible d’exécuter facilement, comme hobby, des expériences validées à grande échelle afin d’expérimenter et potentiellement d’itérer
- Cela peut faciliter la redécouverte d’idées oubliées et encourager ce type d’expériences, puisqu’on peut les tester facilement comme une fonction à plus grande échelle avec un minimum d’effort
Les inconvénients de JAX
- Structure de gouvernance
- Actuellement, XLA relève de la gouvernance de TensorFlow
- Il y a eu des discussions sur la création d’un organe organisationnel séparé, semblable à celui de PyTorch, mais peu d’efforts concrets ont été réalisés
- La confiance envers Google reste limitée en raison de sa réputation à abandonner les produits impopulaires
- JAX est techniquement un projet DeepMind et joue un rôle central dans l’ensemble de la stratégie IA de Google, mais une structure distincte semblerait offrir d’importants bénéfices à long terme pour tout l’écosystème
- Un organe de gouvernance séparé donnerait une direction au développement du projet
- Il fournirait une structure concrète et permettrait d’éviter en une fois de nombreux problèmes en s’isolant de la bureaucratie notoire de Google
- JAX n’a pas nécessairement besoin d’une structure formelle de ce type, mais il serait bon d’avoir la garantie que son développement se poursuivra longtemps, indépendamment des décisions de la direction de Google
- Cela aiderait clairement à son adoption dans les entreprises et les grands laboratoires de recherche, qui hésitent à investir des ressources dans l’intégration d’un outil qui pourrait un jour ne plus être maintenu
- L’ouverture open source de XLA
- Pendant longtemps, XLA a été un projet closed source
- Toutefois, des efforts ont été engagés pour l’ouvrir, et aujourd’hui OpenXLA affiche des performances bien supérieures aux builds internes de XLA
- Mais la documentation sur les entrailles de XLA reste insuffisante
- La plupart des ressources se limitent à des talks en direct et à quelques articles occasionnels, souvent datés
- Une roadmap publique sur les fonctionnalités à venir permettrait aux gens de suivre les progrès et de contribuer plus facilement, notamment sur les sujets les plus intéressants
- Des mini-billets de blog dans le style d’Edward Yang, analysant chaque étape de la pile de compilation XLA et en expliquant les détails, offriraient aux praticiens un meilleur moyen d’évaluer ce que XLA peut et ne peut pas faire
- Je comprends que cela demande beaucoup de ressources et que ces efforts pourraient être mieux utilisés ailleurs, mais les gens font davantage confiance aux outils qu’ils comprennent, et cela aurait selon moi des effets positifs sur l’ensemble de l’écosystème au bénéfice de tous
- Intégration de l’écosystème
flaxest un point douloureux de l’écosystème JAX- Son API peu intuitive et sa syntaxe concise en font un enfer absolu pour les débutants venant de PyTorch
- Il vaut mieux utiliser
equinox - Il y a eu des tentatives de l’équipe de développement pour corriger les défauts de
flax, mais au final c’est une perte de temps - Si l’on veut une API dans le style de
equinox, autant utiliserequinox flaxn’a pas grand-chose qu’il fasse particulièrement mieux, et ce n’est pas difficile à reproduire avecequinox- Aujourd’hui, une grande partie de l’écosystème JAX est conçue autour de
flax equinox, comme il s’interface fondamentalement avec PyTree, est interopérable avec toutes les bibliothèques, même s’il faut un peu deeqx.partitionet defilter- Il faudrait changer le statu quo.
equinoxdevrait bénéficier d’un support de premier plan partout - C’est une opinion discutable, mais cela relève de l’erreur classique des coûts irrécupérables
equinoxfonctionne mieux de la manière dont le framework JAX aurait toujours dû fonctionner- Comme résumé dans la documentation de
equinox, si l’on compareequinoxetflax,equinoxest meilleur - Il est positif que les mainteneurs de l’écosystème JAX reconnaissent la popularité de
equinoxet s’adaptent en conséquence, mais on aimerait aussi davantage de soutien officiel de la part de Google et de l’équipeflax - Si vous voulez essayer JAX, il vaut mieux utiliser
equinox
- Les angles vifs
- En raison de choix de conception d’API et des limitations de XLA, JAX comporte des « angles vifs » auxquels il faut faire attention
- Une documentation bien rédigée les explique de manière très concise
- Il est recommandé de la lire au moins une fois avant d’utiliser JAX
- Comme toujours, faire le RTFM permet d’économiser énormément de temps et d’énergie
Conclusion
- Cet article de blog visait à corriger le mythe, souvent répété, selon lequel PyTorch serait le mieux adapté aux charges de travail de recherche réelles, en particulier sur GPU. Ce n’est plus le cas
- En réalité, l’auteur va jusqu’à soutenir qu’il serait extrêmement bénéfique pour l’ensemble du domaine de porter tout le code PyTorch vers JAX
- l’auto-parallélisation, la reproductibilité, une API fonctionnelle propre, etc., ne sont pas des détails mineurs et apporteraient beaucoup à de nombreuses codebases de recherche
- Si vous voulez améliorer ne serait-ce qu’un peu ce domaine, envisagez de réécrire vos codebases en JAX
8 commentaires
Le monde continue d’avancer. haha
Comparaison entre PyTorch et TensorFlow en 2022
Je vais m’en tenir à
torchetonnx.Un article écrit par un étudiant en licence... wow
Sans Huggingface, PyTorch serait vraiment mort, lol.
Vive JAX ! Je l’ai essayé récemment, et l’API NNX m’a beaucoup plu.
Le plus gros problème de JAX, c’est que c’est Google. Google est très connu pour abandonner ses projets open source (
Tflite,android things,dart,angular,bazel, etc.). Mêmetensorflowa commencé, à partir d’un certain moment, à être beaucoup moins bien mis à jour. À l’inverse,torchvient de Facebook, qui gère un vaste écosystème open source, et le projet est très bien entretenu ; il est déjà administré par la fondationtorch. Les défauts detorchsont réels sur certains points, mais lorsqu’il s’agit de savoir qui peut faire vivre durablement un projet open source, j’ai l’impression que JAX part déjà avec un risque important.Au moins, on dirait que Dart continuera à bien se porter pendant un moment grâce à Flutter.
Facebook semble malgré tout continuer à contribuer avec une certaine loyauté (?) aux technologies qu’il utilise, comme React ou Django, mais Google donne l’impression d’abandonner comme un vieux chiffon tout ce qui devient un peu dépassé...