2 points par GN⁺ 2024-09-24 | 1 commentaires | Partager sur WhatsApp

Felafax BlogTune Llama3 405B on AMD MI300x (notre parcours)

Introduction

  • À mesure que les modèles open source grossissent, le besoin d’une infrastructure puissante capable de gérer l’entraînement d’IA à grande échelle augmente
  • Felafax a ajusté finement le modèle LLaMA 3.1 405B sur des GPU AMD, démontrant l’efficacité du matériel AMD
  • L’ensemble du travail a été publié en open source sur GitHub
  • Les GPU AMD MI300X offrent de hautes performances par rapport au matériel IA de NVIDIA
  • Le projet a été rendu possible grâce au soutien de TensorWave

Qu’est-ce que JAX et pourquoi l’avoir choisi

  • JAX est une puissante bibliothèque de machine learning qui combine une API similaire à NumPy, la différenciation automatique et le compilateur XLA de Google
  • Il fournit d’excellentes API pour le parallélisme de modèle, ce qui le rend idéal pour l’entraînement de grands modèles

Avantages de JAX

  • Fonctions pures : JAX encourage l’écriture de fonctions pures, ce qui facilite la composition, le débogage et la lecture du code
  • Parallélisme avancé : l’API JIT flexible de JAX prend en charge un parallélisme avancé des données et des modèles, essentiel pour l’entraînement à grande échelle
  • Base de code propre : la philosophie de conception de JAX encourage l’écriture d’un code portable entre différentes plateformes matérielles

Pourquoi JAX se distingue sur le matériel non NVIDIA

  • Approche indépendante du matériel : JAX exploite le compilateur XLA pour compiler les calculs vers une représentation intermédiaire indépendante du matériel
  • Optimisation indépendante de la plateforme : le compilateur XLA effectue les optimisations indépendamment du matériel
  • Portabilité simple : avec JAX, passer de NVIDIA à AMD nécessite très peu de modifications de code

Configuration de JAX sur GPU AMD

  • L’image Docker a été récupérée, le conteneur lancé, puis l’installation vérifiée
  • Le modèle LLaMA 405B a été entraîné à l’aide de 8 GPU AMD MI300x

Entraîner LLaMA 405B : performances et scalabilité

  • Le modèle LLaMA 405B a été entraîné sur des GPU AMD avec JAX
  • Avec l’ajustement fin LoRA, les poids du modèle et les paramètres LoRA ont été réglés avec une précision bfloat16
  • Taille du modèle : environ 800 Go de VRAM
  • Poids LoRA et état de l’optimiseur : environ 400 Go de VRAM
  • Utilisation totale de VRAM : environ 1200 Go
  • Vitesse d’entraînement : environ 35 tokens par seconde
  • Efficacité mémoire : environ 70 % maintenus
  • Scalabilité : avec JAX, la montée en charge sur 8 GPU est presque linéaire

Notre configuration d’entraînement

  • LLaMA 3.1 a été converti de PyTorch vers JAX
  • Le modèle a été distribué efficacement via le chargement du modèle et le sharding des paramètres

Sharding des paramètres dans JAX

  • La fonctionnalité de device mesh de JAX a été utilisée pour répartir efficacement le modèle sur 8 GPU AMD
  • Des règles de sharding des paramètres ont été définies afin de shard chaque dimension de tenseur selon les axes du mesh

Implémentation de l’entraînement LoRA

  • LoRA réduit le nombre de paramètres entraînables en décomposant les mises à jour de poids en matrices de faible rang
  • Une couche LoRADense a été implémentée pour inclure les paramètres LoRA
  • Les paramètres LoRA ont été distribués efficacement afin d’optimiser l’usage mémoire et l’efficacité de calcul

Conclusion

  • L’expérience d’ajustement fin du modèle LLaMA 3.1 405B avec des GPU AMD et JAX a été très positive
  • Les puissantes capacités de parallélisme de JAX et son approche indépendante du matériel ont permis de distribuer efficacement le modèle
  • Cela démontre que les GPU AMD constituent une alternative solide pour l’entraînement d’IA à grande échelle
  • Le code complet peut être consulté et exécuté directement depuis le dépôt GitHub

Résumé de GN⁺

  • Cet article explique comment entraîner efficacement de grands modèles d’IA avec des GPU AMD et JAX
  • Il souligne que le matériel AMD constitue une alternative rentable à NVIDIA
  • L’approche indépendante du matériel de JAX améliore la portabilité du code et facilite la maintenance
  • Il fournit des informations utiles et du code pratique à celles et ceux qui s’intéressent à l’entraînement de grands modèles
  • Parmi les projets aux fonctionnalités similaires figurent CUDA de NVIDIA et PyTorch

1 commentaires

 
GN⁺ 2024-09-24
Avis sur Hacker News
  • Nous avons récemment affiné le modèle llama3.1 405B sur 8 GPU AMD MI300x avec JAX au lieu de PyTorch
    Grâce à l’API de sharding avancée de JAX, nous avons obtenu de bonnes performances, et nous avons détaillé la technique de sharding utilisée dans un billet de blog. Le code est également public : https://github.com/felafax/felafax
    Nous sommes une petite startup qui construit une infrastructure IA pour le fine-tuning et le serving de LLM sur du matériel non-NVIDIA (TPU, AMD, Trainium)
    Beaucoup d’entreprises essaient d’exécuter PyTorch sur des GPU AMD, mais nous pensons que PyTorch est trop profondément lié à l’écosystème NVIDIA, via torch.cuda ou scaled_dot_product_attention par exemple, et qu’il faut donc beaucoup de travail de « dé-NVIDIA-isation »
    Nous pensons que JAX est mieux adapté au matériel non-NVIDIA, car le code du modèle est compilé en graphe HLO indépendant du matériel, puis le compilateur XLA l’optimise avant d’appliquer des optimisations spécifiques à chaque cible. Le même code JAX de LLaMA3 a fonctionné sur Google TPU et sur GPU AMD sans modification
    Notre stratégie consiste à porter d’abord les modèles vers JAX, puis à exploiter le framework JAX et les kernels XLA pour tirer un maximum de performances de backends non-NVIDIA. C’est pourquoi nous avons d’abord porté Llama 3.1 de PyTorch vers JAX, et le même modèle JAX fonctionne bien sur TPU comme sur GPU AMD

    • Il n’y a pas eu de vrai problème pour exécuter PyTorch sur des GPU AMD sans modifier le code CUDA. Le billet de blog de MosaicML mérite aussi le détour : https://www.databricks.com/blog/training-llms-scale-amd-mi25...
    • Je me demande comment vous validez la fidélité du portage JAX de Llama 3.1
      Personnellement, si j’utilise PyTorch en priorité, c’est parce que le modèle d’origine a été conçu avec PyTorch. Même si la logique semble identique entre différentes versions d’un modèle, à très grande échelle, de minuscules erreurs en virgule flottante peuvent s’accumuler et provoquer une dérive du modèle
      Déboguer ce genre d’écarts de précision sur de gros modèles, c’est presque pire que le dixième cercle de l’enfer
    • Je me demande si JAX a ses propres implémentations de multiplication de matrices ou de FlashAttention, ou s’il utilise les implémentations ROCm comme PyTorch. Par exemple hipblaslt, Composable Kernel FA, etc.
      Je ne connais pas très bien JAX, mais à mon avis, une grande partie des piètres performances de l’entraînement PyTorch sur MI300x vient de la lenteur des bibliothèques ROCm utilisées en interne
    • Je me demande si cela fonctionne aussi sur des cartes grand public comme la 7900 XTX
      Et par « fonctionne », je ne parle pas d’un état où l’on passe deux semaines à faire marcher les pilotes avant de ne plus jamais oser mettre le serveur à jour
    • Puisqu’il s’agit d’une migration, je me demande s’il existe des chiffres réels comparés à la version PyTorch du même modèle. Le tableau comparatif de l’article semble porter davantage sur des aspects techniques
      Je serais aussi curieux de connaître les problèmes techniques rencontrés
  • Pour être clair, ces performances sont assez mauvaises. Cela ressemble probablement à un problème de compilation mal exploitée
    Sur le modèle 405B, on obtient 35 tokens/s, ce qui correspond à environ 85 téraflops. Or 8 GPU MI300x représentent environ 10,4 pétaflops, soit un MFU d’environ 0,8 %
    C’est 40 à 50 fois moins qu’un entraînement correct à 30-40 % de MFU, donc AMD doit espérer que le goulot d’étranglement se situe dans la stack logicielle

    • C’est exactement ce que je voulais demander moi aussi
      La page GitHub dit qu’on peut ajuster LLaMa3.1 sur Google Cloud TPU avec un coût 30 % inférieur, mais elle ne mentionne pas les performances
  • Excellent travail. J’ai moi aussi un peu testé les GPU AMD et le support ROCm il y a environ un an, et il était évident qu’AMD avait encore beaucoup de chemin à parcourir pour rattraper Nvidia
    L’approche consistant à choisir JAX est intéressante, mais je me demande quelles difficultés vous avez rencontrées en vous éloignant de PyTorch, qui est presque la bibliothèque standard du machine learning

    • Nous avons publié il y a quelques semaines un Show HN qui raconte notre parcours : https://news.ycombinator.com/item?id=41512142
      Au départ, notre objectif était de faire du fine-tuning de LLaMA 3 sur TPU, mais PyTorch XLA était trop rudimentaire, donc nous avons décidé de réécrire le modèle en JAX
      Comme nous l’avons dit plus haut, nous pensons que JAX est une meilleure plateforme pour les GPU non-NVIDIA, et nous voulons construire une infrastructure pour GPU non-NVIDIA sur JAX+openXLA
    • Sur mon système Debian 12, je n’arrive pas à faire fonctionner AMD ROCm, donc Ollama utilise probablement le CPU au lieu du GPU. Il reste clairement beaucoup de chemin à parcourir
  • Beau travail. Le week-end dernier, je bricolais moi aussi la partie inférence du 405B [0]
    Je ne suis pas convaincu que torch.cuda soit si problématique. PyTorch pour AMD le remappe de toute façon. Cela ressemble davantage à un problème de nommage qu’à un problème fondamental
    En pratique, récupérer le conteneur rocm:pytorch est aussi simple que récupérer le conteneur rocm:jax
    Il n’y a pas beaucoup de chiffres publiés, donc je me demande quel MFU vous avez obtenu
    [0] https://x.com/HotAisle/status/1837580046732874026

    • Bien
      Il faut que nous calculions le MFU. Les détails sur les GPU et la VRAM sont disponibles dans le dépôt : https://dub.sh/amd-405b-res
      Nous prévoyons de relancer l’entraînement le week-end prochain en compilant JIT l’ensemble de l’étape d’entraînement, et nous calculerons alors le MFU
  • D’après nos mesures chez ZML, le MI300X était 30 % plus rapide que le H100. Ce sont d’excellentes puces

  • Je me demande s’il existe des fournisseurs cloud auprès desquels on peut louer des hôtes 8xAMD MI300
    Pour le travail, j’utilise beaucoup AWS, mais j’aimerais essayer les GPU AMD

    • Pour information, notre entreprise loue des 8xMI300x, donc n’hésitez pas à nous contacter
    • Oracle en propose. D’autres suivront probablement, mais à mon avis il est plus raisonnable de passer par de petits acteurs
  • Où sont les données de performance ?

    • Nous avons ajouté au dépôt GitHub des données sur l’utilisation des GPU et de la VRAM : https://github.com/felafax/felafax?tab=readme-ov-file#amd-40...
      Nous n’avons pas pu exécuter la version compilée en JIT du modèle 405B à cause de limitations de code et de VRAM. C’est un point qu’il faut approfondir
      L’exécution complète de l’entraînement a été faite en mode eager de JAX, donc il y a encore une large marge d’amélioration des performances
      Même en mode eager, l’utilisation GPU était globalement d’environ 30 à 40 %, ce qui est plutôt correct. Avec le JIT, nous pensons pouvoir monter assez facilement à 50-60 % d’utilisation GPU
  • Il serait intéressant, si possible, d’explorer des moyens de lever les contraintes mémoire afin d’exécuter la version compilée en JIT. Cela pourrait apporter un gain de performance supplémentaire

    • D’accord. Il reste encore beaucoup de performances à aller chercher
      Il nous faut une étape d’entraînement compilée en JIT, un chargement de données et un sharding mieux optimisés, de l’accumulation de gradient, et l’activation checkpointing
      Nous continuons à construire et nous publierons bientôt un nouveau billet de blog après avoir implémenté toutes ces améliorations
  • Je me demande si AMD se rapproche ne serait-ce qu’un peu d’une situation où elle pourrait vraiment capter de la valeur ici via de grosses commandes de GPU et une pénurie d’offre
    Mon impression penche plutôt vers « non »

    • Je vois bien l’ironie. Mais à ce stade, si vous ne voulez pas confier tout le matériel et tout le logiciel de l’IA à une source unique, il faut commencer à avancer vers des alternatives
      En face, ils ont une avance considérable, et il y a clairement encore beaucoup de travail côté logiciel. Cela prendra du temps
  • Pourquoi l’application de prise de notes Obsidian fait-elle ça ?

    • Ce n’est pas le cas. Cette entreprise utilise simplement Obsidian Publish pour publier sa documentation