2 points par GN⁺ 2024-09-24 | 1 commentaires | Partager sur WhatsApp

Felafax BlogTune Llama3 405B on AMD MI300x (notre parcours)

Introduction

  • À mesure que les modèles open source grossissent, le besoin d’une infrastructure puissante capable de gérer l’entraînement d’IA à grande échelle augmente
  • Felafax a ajusté finement le modèle LLaMA 3.1 405B sur des GPU AMD, démontrant l’efficacité du matériel AMD
  • L’ensemble du travail a été publié en open source sur GitHub
  • Les GPU AMD MI300X offrent de hautes performances par rapport au matériel IA de NVIDIA
  • Le projet a été rendu possible grâce au soutien de TensorWave

Qu’est-ce que JAX et pourquoi l’avoir choisi

  • JAX est une puissante bibliothèque de machine learning qui combine une API similaire à NumPy, la différenciation automatique et le compilateur XLA de Google
  • Il fournit d’excellentes API pour le parallélisme de modèle, ce qui le rend idéal pour l’entraînement de grands modèles

Avantages de JAX

  • Fonctions pures : JAX encourage l’écriture de fonctions pures, ce qui facilite la composition, le débogage et la lecture du code
  • Parallélisme avancé : l’API JIT flexible de JAX prend en charge un parallélisme avancé des données et des modèles, essentiel pour l’entraînement à grande échelle
  • Base de code propre : la philosophie de conception de JAX encourage l’écriture d’un code portable entre différentes plateformes matérielles

Pourquoi JAX se distingue sur le matériel non NVIDIA

  • Approche indépendante du matériel : JAX exploite le compilateur XLA pour compiler les calculs vers une représentation intermédiaire indépendante du matériel
  • Optimisation indépendante de la plateforme : le compilateur XLA effectue les optimisations indépendamment du matériel
  • Portabilité simple : avec JAX, passer de NVIDIA à AMD nécessite très peu de modifications de code

Configuration de JAX sur GPU AMD

  • L’image Docker a été récupérée, le conteneur lancé, puis l’installation vérifiée
  • Le modèle LLaMA 405B a été entraîné à l’aide de 8 GPU AMD MI300x

Entraîner LLaMA 405B : performances et scalabilité

  • Le modèle LLaMA 405B a été entraîné sur des GPU AMD avec JAX
  • Avec l’ajustement fin LoRA, les poids du modèle et les paramètres LoRA ont été réglés avec une précision bfloat16
  • Taille du modèle : environ 800 Go de VRAM
  • Poids LoRA et état de l’optimiseur : environ 400 Go de VRAM
  • Utilisation totale de VRAM : environ 1200 Go
  • Vitesse d’entraînement : environ 35 tokens par seconde
  • Efficacité mémoire : environ 70 % maintenus
  • Scalabilité : avec JAX, la montée en charge sur 8 GPU est presque linéaire

Notre configuration d’entraînement

  • LLaMA 3.1 a été converti de PyTorch vers JAX
  • Le modèle a été distribué efficacement via le chargement du modèle et le sharding des paramètres

Sharding des paramètres dans JAX

  • La fonctionnalité de device mesh de JAX a été utilisée pour répartir efficacement le modèle sur 8 GPU AMD
  • Des règles de sharding des paramètres ont été définies afin de shard chaque dimension de tenseur selon les axes du mesh

Implémentation de l’entraînement LoRA

  • LoRA réduit le nombre de paramètres entraînables en décomposant les mises à jour de poids en matrices de faible rang
  • Une couche LoRADense a été implémentée pour inclure les paramètres LoRA
  • Les paramètres LoRA ont été distribués efficacement afin d’optimiser l’usage mémoire et l’efficacité de calcul

Conclusion

  • L’expérience d’ajustement fin du modèle LLaMA 3.1 405B avec des GPU AMD et JAX a été très positive
  • Les puissantes capacités de parallélisme de JAX et son approche indépendante du matériel ont permis de distribuer efficacement le modèle
  • Cela démontre que les GPU AMD constituent une alternative solide pour l’entraînement d’IA à grande échelle
  • Le code complet peut être consulté et exécuté directement depuis le dépôt GitHub

Résumé de GN⁺

  • Cet article explique comment entraîner efficacement de grands modèles d’IA avec des GPU AMD et JAX
  • Il souligne que le matériel AMD constitue une alternative rentable à NVIDIA
  • L’approche indépendante du matériel de JAX améliore la portabilité du code et facilite la maintenance
  • Il fournit des informations utiles et du code pratique à celles et ceux qui s’intéressent à l’entraînement de grands modèles
  • Parmi les projets aux fonctionnalités similaires figurent CUDA de NVIDIA et PyTorch

1 commentaires

 
GN⁺ 2024-09-24
Avis Hacker News
  • Partage des résultats du fine-tuning du modèle Llama3.1 405B sur 8 GPU AMD MI300x avec JAX

    • D’excellentes performances ont été obtenues grâce à l’API avancée de sharding de JAX
    • Liens vers l’article de blog et le code open source fournis : lien GitHub
    • Il s’agit d’une startup qui construit une infrastructure IA pour fine-tuner et servir des LLM sur TPU, AMD et Trainium plutôt que sur du matériel NVIDIA
    • Ils estiment que beaucoup d’entreprises essaient de faire fonctionner PyTorch sur des GPU AMD, mais que c’est une voie difficile
    • PyTorch est profondément lié à l’écosystème NVIDIA, ce qui nécessite de nombreuses adaptations pour fonctionner sur du matériel non NVIDIA
    • Ils pensent que JAX est mieux adapté au matériel non NVIDIA
    • Avec JAX, le code des modèles de ML est compilé en graphes HLO indépendants du matériel, puis le compilateur XLA applique des optimisations spécifiques au matériel
    • Le même code JAX peut s’exécuter sur les TPU de Google et sur des GPU AMD sans modification
    • La stratégie de l’entreprise consiste à porter les modèles vers JAX et à exploiter les kernels XLA pour extraire les performances maximales sur des backends non NVIDIA
    • Ils ont d’abord porté Llama 3.1 de PyTorch vers JAX, et désormais le même modèle JAX fonctionne bien sur TPU et GPU AMD
    • Ils aimeraient avoir des retours sur leur vision et sur le dépôt
  • Suggestion d’explorer des moyens de surmonter les contraintes mémoire et d’exécuter une version compilée en JIT

    • Cela pourrait apporter des gains de performance supplémentaires
  • Partage d’expérience sur les GPU AMD et le support ROCm

    • Il y a un an, quelqu’un a essayé les GPU AMD et le support ROCm, mais a eu le sentiment qu’AMD était encore loin de rattraper NVIDIA
    • Le choix de JAX est une approche intéressante, mais il se demande quelles difficultés ils ont rencontrées en s’éloignant de PyTorch
  • Partage d’expérience sur des expérimentations côté inférence avec le modèle 405B

    • Il pense que torch.cuda n’est pas si problématique
    • Il estime que ce n’est qu’une question de nom, puisque la version AMD de PyTorch traduit cela
    • Utiliser le conteneur rocm:pytorch est aussi simple qu’utiliser le conteneur rocm:jax
    • Il souligne que peu de données de performance ont été publiées
    • Il s’interroge sur les chiffres de MFU (Model FLOPs Utilization)
  • Question sur l’absence de données de performance

    • Doute sur la possibilité d’extraire de la valeur de grosses commandes de GPU AMD
    • Il en retire l’impression que la réponse est « non »
  • Interrogation sur la raison pour laquelle Obsidian (application de prise de notes) ferait cela

    • Au départ, il pensait qu’il s’agissait d’un billet publié par Obsidian
    • Il s’interroge sur le fait qu’on ne distingue toujours pas GitHub.com et GitHub.io
  • Demande à @dang d’inclure le nom d’utilisateur dans l’URL

    • Ce post concerne un blog créé par un utilisateur, et non Obsidian lui-même