DeepGEMM : un kernel GEMM FP8 propre et efficace grâce au fine-grained scaling
(github.com/deepseek-ai)DeepGEMM
DeepGEMM est une bibliothèque pour la multiplication générale de matrices (GEMM) en FP8, prenant en charge le fine-grained scaling proposé dans DeepSeek-V3. Cette bibliothèque prend en charge les GEMM groupés classiques ainsi que ceux destinés au Mix-of-Experts (MoE), est écrite en CUDA et ne nécessite pas de compilation à l’installation. Elle prend en charge les Tensor Cores NVIDIA Hopper et utilise une accumulation en deux étapes sur les cœurs CUDA pour corriger les imprécisions de l’accumulation FP8 sur Tensor Core. Elle reprend certains concepts de CUTLASS et CuTe, tout en conservant sa simplicité grâce à une dépendance minimale aux templates et à l’algèbre. Avec une seule fonction kernel centrale d’environ 300 lignes de code, c’est une excellente ressource pour apprendre la multiplication de matrices FP8 sur Hopper et les techniques d’optimisation associées. Malgré sa conception légère, ses performances sont équivalentes ou supérieures à celles de bibliothèques réglées par des experts sur diverses formes de matrices.
Performances
Toutes les formes pouvant être utilisées pour l’inférence de DeepSeek-V3/R1 ont été testées sur H800 SXM5 avec NVCC 12.8. Tous les gains de performance sont calculés par comparaison avec une implémentation optimisée en interne basée sur CUTLASS 3.6. Certaines formes peuvent présenter des performances inférieures, et les PR d’optimisation sont les bienvenues.
GEMM classique (modèles denses)
- Les mesures de performance de DeepGEMM sur différentes tailles de matrices montrent jusqu’à 2,7× d’accélération pour certaines tailles.
GEMM groupé pour modèles MoE (layout contigu)
- Selon le nombre de groupes et la taille des matrices de chaque groupe, le gain peut atteindre 1,2×.
GEMM groupé pour modèles MoE (layout avec masque)
- L’utilisation d’un layout avec masque permet d’obtenir jusqu’à 1,2× d’accélération.
Démarrage rapide
Prérequis
- GPU d’architecture Hopper, prise en charge de
sm_90arequise - Python 3.8 ou version ultérieure
- CUDA 12.3 ou version ultérieure (12.8 ou plus recommandé pour les meilleures performances)
- PyTorch 2.1 ou version ultérieure
- CUTLASS 3.6 ou version ultérieure
Développement
- Description du processus de développement, incluant le clonage des sous-modules, la création de liens symboliques, la compilation JIT et le test de toutes les implémentations GEMM.
Installation
deep_gemmpeut être importé et utilisé dans un projet Python.
Interface
Points d’attention
- Cette bibliothèque ne contient que des kernels GEMM et ne prend en charge que le format NT. La transposition ou les autres opérations de cast FP8 doivent être implémentées séparément.
GEMM dense classique (non groupé)
- Fournit une fonction pour exécuter un GEMM FP8 de base non groupé.
GEMM groupé (layout contigu)
- Conçu pour les scénarios de modèles MoE dans lesquels les experts partagent la même forme.
GEMM groupé (layout avec masque)
- Pendant l’étape de décodage en inférence, un tenseur de masque est fourni pour ne calculer que les parties valides.
Utilitaires
- Fournit diverses fonctions utilitaires et variables d’environnement pour aider à l’optimisation des performances.
Optimisation
Spécialisation persistante des warps
- Suit la conception de CUTLASS en superposant les mouvements de données, les instructions MMA des Tensor Cores et la promotion sur les cœurs CUDA.
Fonctionnalité TMA de Hopper
- Utilise TMA pour accélérer les mouvements de données.
Optimisations détaillées communes
- Améliore les performances à l’aide de diverses techniques d’optimisation.
Scheduler de blocs unifié et optimisé
- Fournit un scheduler pour tous les kernels non groupés et groupés.
Conception entièrement JIT
- Améliore les performances grâce à une conception JIT ne nécessitant pas de compilation à l’installation.
Tailles de blocs non alignées
- Prend en charge des tailles de blocs non alignées afin de maximiser l’utilisation des SM pour certaines formes.
Interleaving FFMA SASS
- Modifie les instructions FFMA pour améliorer le parallélisme au niveau des warps et ainsi accroître les performances.
Remerciements
- DeepGEMM s’inspire du projet CUTLASS et exprime sa gratitude et son respect envers ses développeurs.
Licence
- Distribué sous licence MIT.
Aucun commentaire pour le moment.