FlashAttention-3 : une Attention plus rapide et plus précise grâce à l’asynchronisme et à la basse précision
(together.ai)-
Importance de l’Attention
- L’Attention est la couche centrale de l’architecture Transformer et constitue un goulot d’étranglement dans les grands modèles de langage et les applications à long contexte.
- FlashAttention et FlashAttention-2 ont ouvert la voie à une approche qui accélère l’Attention sur GPU en minimisant les lectures/écritures mémoire.
- Cela a permis d’augmenter fortement la longueur de contexte des LLM.
-
Principales techniques de FlashAttention-3
- Exploitation de l’asynchronisme : utilisation de l’asynchronisme des Tensor Cores et du TMA pour chevaucher complètement calcul et déplacement de données.
- Calcul par blocs : exécution alternée des multiplications de matrices par blocs et des opérations de softmax.
- Traitement en basse précision : amélioration des performances grâce à la prise en charge de la basse précision FP8.
-
Gains de performances de FlashAttention-3
- Efficacité d’utilisation du GPU : exploite jusqu’à 75 % des performances maximales du GPU H100, avec une vitesse 1,5 à 2 fois supérieure à celle de la version précédente.
- Performances en basse précision : l’utilisation de FP8 augmente la vitesse de traitement et réduit l’usage mémoire.
- Traitement de longs contextes : l’accélération du mécanisme d’Attention permet de traiter plus efficacement des textes plus longs.
-
Résumé de FlashAttention
- FlashAttention réorganise le calcul de l’Attention et utilise le tiling ainsi que le recalcul pour augmenter fortement la vitesse tout en réduisant l’utilisation mémoire.
- Grâce au tiling, il charge des blocs d’entrée, exécute l’Attention sur ces blocs, puis met à jour la sortie correspondante.
- En n’écrivant pas la matrice intermédiaire d’Attention en mémoire, il réduit le volume des lectures/écritures mémoire.
-
Nouvelles fonctionnalités matérielles des GPU Hopper
- WGMMA : fournit un haut débit en exploitant les nouveaux Tensor Cores.
- TMA : unité matérielle qui accélère les transferts de données entre la mémoire globale et la mémoire partagée.
- Basse précision FP8 : double le débit des Tensor Cores grâce à l’utilisation de FP8.
-
Asynchronisme : chevauchement de GEMM et Softmax
- Pourquoi ce chevauchement est nécessaire : exécuter GEMM et softmax en parallèle pour maximiser les performances.
- Ordonnancement ping-pong : deux groupes de warps exécutent alternativement GEMM et softmax afin d’améliorer les performances.
- Chevauchement au sein d’un groupe de warps : exécution parallèle de GEMM et softmax au sein du même groupe de warps pour accroître le débit.
-
Basse précision : réduction des erreurs de quantification via un traitement incohérent
- Traitement incohérent : réduction des erreurs de quantification à l’aide de la transformation de Hadamard.
- Résultats expérimentaux : ce traitement incohérent réduit les erreurs de quantification d’un facteur 2,6.
-
Benchmarks d’Attention
- FP16 : environ 1,6 à 1,8 fois plus rapide que FlashAttention-2.
- FP8 : atteint jusqu’à 1,2 PFLOPS.
Synthèse de GN⁺
- FlashAttention-3 améliore fortement les performances du mécanisme d’Attention en exploitant les nouvelles fonctionnalités matérielles des GPU.
- Sa capacité à traiter efficacement de longs contextes permet de maximiser les performances des grands modèles de langage.
- Son intégration probable dans des frameworks majeurs comme PyTorch devrait avoir un impact important sur la recherche et les applications en IA.
- Parmi les projets offrant des fonctionnalités similaires figurent Triton et cuDNN.
1 commentaires
Commentaires Hacker News
Il semble que Tri Dao ait commencé à travailler sur FA3 en avril 2022
On se demande à quel point l’algorithme Flash Attention dépend du matériel
On se demande si les compilateurs pourront trouver d’eux-mêmes des optimisations comme FlashAttention
Ceux qui souhaitent un portage vers ROCm/AMD MI300x sont invités à se manifester
TMA (Tensor Memory Accelerator) est une unité matérielle qui accélère les transferts de données entre la mémoire globale et la mémoire partagée
FlashAttention-3 est optimisé pour les GPU Hopper (par exemple H100)
Il est mentionné que les fonctions d’activation comme sigmoid sont très lentes dans les LLM modernes
On se demande pourquoi Flash Attention est 5 fois plus lent avec masquage variable que sans
On se demande si FlashAttention peut remplacer l’opération d’attention dans les LLM
llama.cppa ajouté le support de Flash Attention, on se demande s’il s’agissait simplement d’utiliser les kernels CUDA fournis par Flash AttentionDu matériel coûteux est nécessaire