Felafax BlogTune Llama3 405B on AMD MI300x (notre parcours)
Introduction
- À mesure que les modèles open source grossissent, le besoin d’une infrastructure puissante capable de gérer l’entraînement d’IA à grande échelle augmente
- Felafax a ajusté finement le modèle LLaMA 3.1 405B sur des GPU AMD, démontrant l’efficacité du matériel AMD
- L’ensemble du travail a été publié en open source sur GitHub
- Les GPU AMD MI300X offrent de hautes performances par rapport au matériel IA de NVIDIA
- Le projet a été rendu possible grâce au soutien de TensorWave
Qu’est-ce que JAX et pourquoi l’avoir choisi
- JAX est une puissante bibliothèque de machine learning qui combine une API similaire à NumPy, la différenciation automatique et le compilateur XLA de Google
- Il fournit d’excellentes API pour le parallélisme de modèle, ce qui le rend idéal pour l’entraînement de grands modèles
Avantages de JAX
- Fonctions pures : JAX encourage l’écriture de fonctions pures, ce qui facilite la composition, le débogage et la lecture du code
- Parallélisme avancé : l’API JIT flexible de JAX prend en charge un parallélisme avancé des données et des modèles, essentiel pour l’entraînement à grande échelle
- Base de code propre : la philosophie de conception de JAX encourage l’écriture d’un code portable entre différentes plateformes matérielles
Pourquoi JAX se distingue sur le matériel non NVIDIA
- Approche indépendante du matériel : JAX exploite le compilateur XLA pour compiler les calculs vers une représentation intermédiaire indépendante du matériel
- Optimisation indépendante de la plateforme : le compilateur XLA effectue les optimisations indépendamment du matériel
- Portabilité simple : avec JAX, passer de NVIDIA à AMD nécessite très peu de modifications de code
Configuration de JAX sur GPU AMD
- L’image Docker a été récupérée, le conteneur lancé, puis l’installation vérifiée
- Le modèle LLaMA 405B a été entraîné à l’aide de 8 GPU AMD MI300x
Entraîner LLaMA 405B : performances et scalabilité
- Le modèle LLaMA 405B a été entraîné sur des GPU AMD avec JAX
- Avec l’ajustement fin LoRA, les poids du modèle et les paramètres LoRA ont été réglés avec une précision bfloat16
- Taille du modèle : environ 800 Go de VRAM
- Poids LoRA et état de l’optimiseur : environ 400 Go de VRAM
- Utilisation totale de VRAM : environ 1200 Go
- Vitesse d’entraînement : environ 35 tokens par seconde
- Efficacité mémoire : environ 70 % maintenus
- Scalabilité : avec JAX, la montée en charge sur 8 GPU est presque linéaire
Notre configuration d’entraînement
- LLaMA 3.1 a été converti de PyTorch vers JAX
- Le modèle a été distribué efficacement via le chargement du modèle et le sharding des paramètres
Sharding des paramètres dans JAX
- La fonctionnalité de device mesh de JAX a été utilisée pour répartir efficacement le modèle sur 8 GPU AMD
- Des règles de sharding des paramètres ont été définies afin de shard chaque dimension de tenseur selon les axes du mesh
Implémentation de l’entraînement LoRA
- LoRA réduit le nombre de paramètres entraînables en décomposant les mises à jour de poids en matrices de faible rang
- Une couche
LoRADense a été implémentée pour inclure les paramètres LoRA
- Les paramètres LoRA ont été distribués efficacement afin d’optimiser l’usage mémoire et l’efficacité de calcul
Conclusion
- L’expérience d’ajustement fin du modèle LLaMA 3.1 405B avec des GPU AMD et JAX a été très positive
- Les puissantes capacités de parallélisme de JAX et son approche indépendante du matériel ont permis de distribuer efficacement le modèle
- Cela démontre que les GPU AMD constituent une alternative solide pour l’entraînement d’IA à grande échelle
- Le code complet peut être consulté et exécuté directement depuis le dépôt GitHub
Résumé de GN⁺
- Cet article explique comment entraîner efficacement de grands modèles d’IA avec des GPU AMD et JAX
- Il souligne que le matériel AMD constitue une alternative rentable à NVIDIA
- L’approche indépendante du matériel de JAX améliore la portabilité du code et facilite la maintenance
- Il fournit des informations utiles et du code pratique à celles et ceux qui s’intéressent à l’entraînement de grands modèles
- Parmi les projets aux fonctionnalités similaires figurent CUDA de NVIDIA et PyTorch
1 commentaires
Avis Hacker News
Partage des résultats du fine-tuning du modèle Llama3.1 405B sur 8 GPU AMD MI300x avec JAX
Suggestion d’explorer des moyens de surmonter les contraintes mémoire et d’exécuter une version compilée en JIT
Partage d’expérience sur les GPU AMD et le support ROCm
Partage d’expérience sur des expérimentations côté inférence avec le modèle 405B
torch.cudan’est pas si problématiquerocm:pytorchest aussi simple qu’utiliser le conteneurrocm:jaxQuestion sur l’absence de données de performance
Interrogation sur la raison pour laquelle Obsidian (application de prise de notes) ferait cela
Demande à @dang d’inclure le nom d’utilisateur dans l’URL