Ajuster finement Llama 405B sur GPU AMD
(publish.obsidian.md)Felafax BlogTune Llama3 405B on AMD MI300x (notre parcours)
Introduction
- À mesure que les modèles open source grossissent, le besoin d’une infrastructure puissante capable de gérer l’entraînement d’IA à grande échelle augmente
- Felafax a ajusté finement le modèle LLaMA 3.1 405B sur des GPU AMD, démontrant l’efficacité du matériel AMD
- L’ensemble du travail a été publié en open source sur GitHub
- Les GPU AMD MI300X offrent de hautes performances par rapport au matériel IA de NVIDIA
- Le projet a été rendu possible grâce au soutien de TensorWave
Qu’est-ce que JAX et pourquoi l’avoir choisi
- JAX est une puissante bibliothèque de machine learning qui combine une API similaire à NumPy, la différenciation automatique et le compilateur XLA de Google
- Il fournit d’excellentes API pour le parallélisme de modèle, ce qui le rend idéal pour l’entraînement de grands modèles
Avantages de JAX
- Fonctions pures : JAX encourage l’écriture de fonctions pures, ce qui facilite la composition, le débogage et la lecture du code
- Parallélisme avancé : l’API JIT flexible de JAX prend en charge un parallélisme avancé des données et des modèles, essentiel pour l’entraînement à grande échelle
- Base de code propre : la philosophie de conception de JAX encourage l’écriture d’un code portable entre différentes plateformes matérielles
Pourquoi JAX se distingue sur le matériel non NVIDIA
- Approche indépendante du matériel : JAX exploite le compilateur XLA pour compiler les calculs vers une représentation intermédiaire indépendante du matériel
- Optimisation indépendante de la plateforme : le compilateur XLA effectue les optimisations indépendamment du matériel
- Portabilité simple : avec JAX, passer de NVIDIA à AMD nécessite très peu de modifications de code
Configuration de JAX sur GPU AMD
- L’image Docker a été récupérée, le conteneur lancé, puis l’installation vérifiée
- Le modèle LLaMA 405B a été entraîné à l’aide de 8 GPU AMD MI300x
Entraîner LLaMA 405B : performances et scalabilité
- Le modèle LLaMA 405B a été entraîné sur des GPU AMD avec JAX
- Avec l’ajustement fin LoRA, les poids du modèle et les paramètres LoRA ont été réglés avec une précision bfloat16
- Taille du modèle : environ 800 Go de VRAM
- Poids LoRA et état de l’optimiseur : environ 400 Go de VRAM
- Utilisation totale de VRAM : environ 1200 Go
- Vitesse d’entraînement : environ 35 tokens par seconde
- Efficacité mémoire : environ 70 % maintenus
- Scalabilité : avec JAX, la montée en charge sur 8 GPU est presque linéaire
Notre configuration d’entraînement
- LLaMA 3.1 a été converti de PyTorch vers JAX
- Le modèle a été distribué efficacement via le chargement du modèle et le sharding des paramètres
Sharding des paramètres dans JAX
- La fonctionnalité de device mesh de JAX a été utilisée pour répartir efficacement le modèle sur 8 GPU AMD
- Des règles de sharding des paramètres ont été définies afin de shard chaque dimension de tenseur selon les axes du mesh
Implémentation de l’entraînement LoRA
- LoRA réduit le nombre de paramètres entraînables en décomposant les mises à jour de poids en matrices de faible rang
- Une couche
LoRADensea été implémentée pour inclure les paramètres LoRA - Les paramètres LoRA ont été distribués efficacement afin d’optimiser l’usage mémoire et l’efficacité de calcul
Conclusion
- L’expérience d’ajustement fin du modèle LLaMA 3.1 405B avec des GPU AMD et JAX a été très positive
- Les puissantes capacités de parallélisme de JAX et son approche indépendante du matériel ont permis de distribuer efficacement le modèle
- Cela démontre que les GPU AMD constituent une alternative solide pour l’entraînement d’IA à grande échelle
- Le code complet peut être consulté et exécuté directement depuis le dépôt GitHub
Résumé de GN⁺
- Cet article explique comment entraîner efficacement de grands modèles d’IA avec des GPU AMD et JAX
- Il souligne que le matériel AMD constitue une alternative rentable à NVIDIA
- L’approche indépendante du matériel de JAX améliore la portabilité du code et facilite la maintenance
- Il fournit des informations utiles et du code pratique à celles et ceux qui s’intéressent à l’entraînement de grands modèles
- Parmi les projets aux fonctionnalités similaires figurent CUDA de NVIDIA et PyTorch
1 commentaires
Avis sur Hacker News
Nous avons récemment affiné le modèle llama3.1 405B sur 8 GPU AMD MI300x avec JAX au lieu de PyTorch
Grâce à l’API de sharding avancée de JAX, nous avons obtenu de bonnes performances, et nous avons détaillé la technique de sharding utilisée dans un billet de blog. Le code est également public : https://github.com/felafax/felafax
Nous sommes une petite startup qui construit une infrastructure IA pour le fine-tuning et le serving de LLM sur du matériel non-NVIDIA (TPU, AMD, Trainium)
Beaucoup d’entreprises essaient d’exécuter PyTorch sur des GPU AMD, mais nous pensons que PyTorch est trop profondément lié à l’écosystème NVIDIA, via
torch.cudaouscaled_dot_product_attentionpar exemple, et qu’il faut donc beaucoup de travail de « dé-NVIDIA-isation »Nous pensons que JAX est mieux adapté au matériel non-NVIDIA, car le code du modèle est compilé en graphe HLO indépendant du matériel, puis le compilateur XLA l’optimise avant d’appliquer des optimisations spécifiques à chaque cible. Le même code JAX de LLaMA3 a fonctionné sur Google TPU et sur GPU AMD sans modification
Notre stratégie consiste à porter d’abord les modèles vers JAX, puis à exploiter le framework JAX et les kernels XLA pour tirer un maximum de performances de backends non-NVIDIA. C’est pourquoi nous avons d’abord porté Llama 3.1 de PyTorch vers JAX, et le même modèle JAX fonctionne bien sur TPU comme sur GPU AMD
Personnellement, si j’utilise PyTorch en priorité, c’est parce que le modèle d’origine a été conçu avec PyTorch. Même si la logique semble identique entre différentes versions d’un modèle, à très grande échelle, de minuscules erreurs en virgule flottante peuvent s’accumuler et provoquer une dérive du modèle
Déboguer ce genre d’écarts de précision sur de gros modèles, c’est presque pire que le dixième cercle de l’enfer
hipblaslt, Composable Kernel FA, etc.Je ne connais pas très bien JAX, mais à mon avis, une grande partie des piètres performances de l’entraînement PyTorch sur MI300x vient de la lenteur des bibliothèques ROCm utilisées en interne
Et par « fonctionne », je ne parle pas d’un état où l’on passe deux semaines à faire marcher les pilotes avant de ne plus jamais oser mettre le serveur à jour
Je serais aussi curieux de connaître les problèmes techniques rencontrés
Pour être clair, ces performances sont assez mauvaises. Cela ressemble probablement à un problème de compilation mal exploitée
Sur le modèle 405B, on obtient 35 tokens/s, ce qui correspond à environ 85 téraflops. Or 8 GPU MI300x représentent environ 10,4 pétaflops, soit un MFU d’environ 0,8 %
C’est 40 à 50 fois moins qu’un entraînement correct à 30-40 % de MFU, donc AMD doit espérer que le goulot d’étranglement se situe dans la stack logicielle
La page GitHub dit qu’on peut ajuster LLaMa3.1 sur Google Cloud TPU avec un coût 30 % inférieur, mais elle ne mentionne pas les performances
Excellent travail. J’ai moi aussi un peu testé les GPU AMD et le support ROCm il y a environ un an, et il était évident qu’AMD avait encore beaucoup de chemin à parcourir pour rattraper Nvidia
L’approche consistant à choisir JAX est intéressante, mais je me demande quelles difficultés vous avez rencontrées en vous éloignant de PyTorch, qui est presque la bibliothèque standard du machine learning
Au départ, notre objectif était de faire du fine-tuning de LLaMA 3 sur TPU, mais PyTorch XLA était trop rudimentaire, donc nous avons décidé de réécrire le modèle en JAX
Comme nous l’avons dit plus haut, nous pensons que JAX est une meilleure plateforme pour les GPU non-NVIDIA, et nous voulons construire une infrastructure pour GPU non-NVIDIA sur JAX+openXLA
Beau travail. Le week-end dernier, je bricolais moi aussi la partie inférence du 405B [0]
Je ne suis pas convaincu que
torch.cudasoit si problématique. PyTorch pour AMD le remappe de toute façon. Cela ressemble davantage à un problème de nommage qu’à un problème fondamentalEn pratique, récupérer le conteneur
rocm:pytorchest aussi simple que récupérer le conteneurrocm:jaxIl n’y a pas beaucoup de chiffres publiés, donc je me demande quel MFU vous avez obtenu
[0] https://x.com/HotAisle/status/1837580046732874026
Il faut que nous calculions le MFU. Les détails sur les GPU et la VRAM sont disponibles dans le dépôt : https://dub.sh/amd-405b-res
Nous prévoyons de relancer l’entraînement le week-end prochain en compilant JIT l’ensemble de l’étape d’entraînement, et nous calculerons alors le MFU
D’après nos mesures chez ZML, le MI300X était 30 % plus rapide que le H100. Ce sont d’excellentes puces
Je me demande s’il existe des fournisseurs cloud auprès desquels on peut louer des hôtes 8xAMD MI300
Pour le travail, j’utilise beaucoup AWS, mais j’aimerais essayer les GPU AMD
Où sont les données de performance ?
Nous n’avons pas pu exécuter la version compilée en JIT du modèle 405B à cause de limitations de code et de VRAM. C’est un point qu’il faut approfondir
L’exécution complète de l’entraînement a été faite en mode eager de JAX, donc il y a encore une large marge d’amélioration des performances
Même en mode eager, l’utilisation GPU était globalement d’environ 30 à 40 %, ce qui est plutôt correct. Avec le JIT, nous pensons pouvoir monter assez facilement à 50-60 % d’utilisation GPU
Il serait intéressant, si possible, d’explorer des moyens de lever les contraintes mémoire afin d’exécuter la version compilée en JIT. Cela pourrait apporter un gain de performance supplémentaire
Il nous faut une étape d’entraînement compilée en JIT, un chargement de données et un sharding mieux optimisés, de l’accumulation de gradient, et l’activation checkpointing
Nous continuons à construire et nous publierons bientôt un nouveau billet de blog après avoir implémenté toutes ces améliorations
Je me demande si AMD se rapproche ne serait-ce qu’un peu d’une situation où elle pourrait vraiment capter de la valeur ici via de grosses commandes de GPU et une pénurie d’offre
Mon impression penche plutôt vers « non »
En face, ils ont une avance considérable, et il y a clairement encore beaucoup de travail côté logiciel. Cela prendra du temps
Pourquoi l’application de prise de notes Obsidian fait-elle ça ?