Comment faire passer votre modèle à l’échelle : une perspective systèmes sur les LLM sur TPU
(jax-ml.github.io)- Optimiser les performances du deep learning à grande échelle peut sembler relever de « l’alchimie », mais en pratique, des principes simples et compréhensibles permettent d’améliorer l’efficacité des modèles
- Des principes relativement simples s’appliquent partout, d’un seul accélérateur à des dizaines de milliers, et les comprendre permet notamment de :
- estimer grossièrement à quel point chaque partie du modèle se rapproche de son optimum théorique
- disposer de bases pour choisir entre différentes techniques de parallélisation selon l’échelle
- estimer le coût et le temps nécessaires à l’entraînement et à l’exécution de grands modèles Transformer
- concevoir des algorithmes qui exploitent les caractéristiques d’un matériel donné
- concevoir le matériel en comprenant clairement les limites des performances algorithmiques actuelles
- Connaissances préalables requises
- Une compréhension de base des LLM et de l’architecture Transformer est nécessaire
- Une compréhension du fonctionnement à grande échelle n’est pas indispensable
- Des bases sur l’entraînement des LLM et une expérience avec JAX sont un plus
- Il est recommandé de consulter un billet de blog sur l’architecture Transformer et des slides sur le passage à l’échelle des LLM avec JAX
- Objectifs
- Développer la capacité à estimer comment paralléliser un modèle sur le matériel disponible
- Développer la capacité à calculer approximativement le temps et le coût de l’entraînement et de l’inférence
Pourquoi s’y intéresser
- Il y a encore 3 à 4 ans, la plupart des chercheurs en ML n’avaient pas besoin de bien connaître ces optimisations à grande échelle
- Aujourd’hui, même les modèles « petits » fonctionnent près des limites du matériel, ce qui rend indispensable la compréhension d’un travail efficace à grande échelle
- L’histoire du ML peut être vue comme une évolution croisée entre innovations systèmes et améliorations logicielles
- Comme les modèles Transformer récents exploitent le matériel jusqu’à ses limites, ne pas comprendre l’efficacité des modèles augmente fortement le risque qu’une nouvelle architecture ou recherche échoue en conditions réelles
- Même si l’on obtient 20 % de gain sur un benchmark, si l’efficacité matérielle chute de 20 %, l’intérêt pratique reste faible
- L’objectif central du passage à l’échelle des modèles est d’obtenir une augmentation linéaire du débit quand on augmente le nombre de puces (accélérateurs)
- On parle alors de « strong scaling »
- Ajouter des puces réduit le temps de calcul, mais introduit un coût de communication entre puces
- Si la communication prend plus de temps que le calcul, on entre dans un régime « communication bound », et le strong scaling devient impossible
- Si l’on comprend suffisamment bien le matériel pour prévoir où ces goulets d’étranglement apparaîtront, on peut concevoir ou restructurer le modèle pour les éviter
- L’objectif de ce livre est d’expliquer comment fonctionne le matériel TPU (et GPU), et comment l’architecture Transformer a évolué pour bien fonctionner sur le matériel actuel
- Il devrait être utile autant aux chercheurs qui conçoivent de nouvelles architectures qu’aux ingénieurs qui cherchent à exécuter rapidement les LLM de la génération actuelle
Vue d’ensemble
- Cet article est organisé comme suit
- La section 1 explique, via l’analyse roofline, les facteurs qui déterminent les limites de performance d’un modèle (communication, calcul, mémoire)
- Les sections 2 et 3 traitent de la structure interne des TPU et GPU ainsi que des modes d’interconnexion entre puces
- Elles répondent notamment aux questions suivantes
- À quelle vitesse une multiplication de matrices d’une taille donnée peut-elle théoriquement être exécutée ?
- À partir de quel point un calcul devient-il limité par la bande passante mémoire ou par la bande passante de communication ?
- Comment un cluster TPU est-il connecté, et combien de temps faut-il approximativement pour déplacer des données d’une puce à une autre ?
- Comment multiplier efficacement des matrices distribuées ?
- Elles répondent notamment aux questions suivantes
- La section 4 détaille les formules de l’architecture Transformer (tailles de matrices, nombre de paramètres, FLOPs)
- Les sections 5 et 7 constituent le cœur du contenu et présentent différentes façons de paralléliser un modèle sur plusieurs puces
- Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
- Des techniques d’économie mémoire comme ZeRO, Rematerialisation, Host offload, Gradient accumulation sont également abordées
- Les sections 6 et 8 prennent l’exemple du modèle LLaMA-3 entraîné et utilisé en inférence sur TPU, avec des estimations concrètes de coût, de temps et de configuration
- Enfin, les sections 9 et 10 expliquent comment profiler un modèle dans JAX, le déboguer et appliquer concrètement le traitement parallèle
Détails : résumé des principales sections du livre
-
Partie 1 : Preliminaries
-
Section 1 : Introduction rapide à l’analyse Roofline
- Les trois facteurs qui contraignent un algorithme : calcul, communication, mémoire
- Comment en déduire une borne supérieure de la vitesse de calcul
-
Section 2 : Une manière de voir les TPU
- Comment un TPU effectue les calculs
- Ce qu’est une architecture en systolic array
- Une compréhension de base de la manière dont les TPU fournissent bande passante mémoire et bande passante de communication
-
Section 3 : Matrices distribuées et multiplication distribuée
- Techniques pour répartir le stockage des paramètres du modèle sur plusieurs puces (sharding)
- Manières de traiter la communication et les goulets d’étranglement lors d’opérations sur des matrices distribuées
-
-
Partie 2 : Transformers
-
Section 4 : Récapitulatif des formules Transformer nécessaires
- La forme concrète des multiplications de matrices dans un Transformer
- Comment calculer le nombre de paramètres, les FLOPs, la taille du cache KV, etc.
- Comprendre dans quelle mesure l’opération d’attention demande plus de calcul que les blocs feed-forward
-
Section 5 : Stratégies de parallélisation pour l’entraînement des Transformer
- Présentation des techniques Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
- Solutions d’économie mémoire comme ZeRO(FSDP), Rematerialisation, Gradient accumulation, Host offload
- Cadre conceptuel pour configurer la parallélisation selon la taille du modèle et le nombre de puces
-
Section 6 : Application de l’entraînement de LLaMA 3 sur TPU
- Estimation du temps et du coût en supposant l’entraînement d’un modèle LLaMA 3 dans un environnement TPU réel
- Exemples concrets de taille de batch, mode de parallélisation, usage mémoire, etc.
-
Section 7 : Tout sur l’inférence Transformer
- En inférence, la latence apparaît comme un nouveau facteur clé
- Les problèmes de mémoire et de communication liés notamment au cache KV
- Discussion sur la façon de répartir et de relier plusieurs puces pour servir le modèle
-
Section 8 : Application du serving de LLaMA 3 sur TPU
- Analyse approximative des compromis entre coût, latence et débit en supposant un serving de LLaMA 3 sur TPU v5e
-
-
Partie 3 : Practical Tutorials
-
Section 9 : Comment profiler du code TPU
- Comprendre la stack JAX+XLA
- Identifier les problèmes réels de dégradation des performances et leurs solutions
- Utiliser le profiler JAX/TensorBoard
-
Section 10 : Programmer les TPU avec JAX
- Utilisation des API de parallélisation de JAX (primitives)
- Assimiler les concepts du calcul parallèle à travers des exemples et des exercices
-
Section 11 : Conclusion et ressources supplémentaires
- Lectures complémentaires sur les TPU et les LLM
- Brève conclusion de l’ensemble, avec une mention des perspectives futures
-
1 commentaires
Commentaires sur Hacker News