1 points par GN⁺ 2024-07-14 | Aucun commentaire pour le moment. | Partager sur WhatsApp
  • AlphaFold3 vise à prédire à partir de la seule séquence des complexes allant au-delà d’une protéine unique, en incluant protéines, acides nucléiques et petites molécules, ce qui rend la représentation d’entrée et la tokenisation bien plus complexes que dans AF2
  • Les entrées se divisent en représentations single/pair au niveau des tokens, représentations au niveau atomique, MSA et templates ; les acides aminés et nucléotides standards sont traités comme 1 token, tandis que les résidus non standards et les autres molécules sont traités à raison de 1 token par atome
  • Le trunk d’apprentissage des représentations améliore de façon itérative la représentation single s et la représentation pair z via le module template, le module MSA et Pairformer, avec pair-bias attention, opérations triangulaires et recycling
  • La prédiction de structure utilise un modèle de diffusion conditionnelle sur les coordonnées atomiques à la place de l’Invariant Point Attention d’AF2, et génère des mises à jour des coordonnées de tous les atomes via augmentation par rotation/déplacement et denoising
  • L’entraînement combine distogram, diffusion et confidence loss, et réapprend même les représentations unfolded dans les zones de faible confiance grâce à la cross-distillation exploitant les résultats d’AF2 et d’AF-Multimer

Périmètre d’entrée d’AlphaFold3 et pipeline global

  • L’objectif d’AlphaFold3 n’est pas seulement de prédire des séquences de protéines individuelles comme AF2 ou de traiter uniquement des complexes protéiques comme AF-Multimer, mais de prédire à partir de la seule séquence des structures où une protéine est liée, éventuellement, à d’autres protéines, à des acides nucléiques et à des petites molécules
  • Le sens de « token » varie selon le type d’entrée
    • Protéine : 1 acide aminé standard = 1 token
    • ADN/ARN : 1 nucléotide standard = 1 token
    • Acides aminés/nucléotides non standards : 1 atome = 1 token
    • Autres molécules : 1 atome = 1 token
  • Une protéine composée de 35 acides aminés standards peut en réalité contenir plus de 600 atomes, mais elle est représentée par 35 tokens, tandis qu’un ligand de 35 atomes est représenté par 35 tokens
  • Le modèle se compose globalement de trois étapes
    • Input Preparation : conversion de la séquence fournie par l’utilisateur et des séquences/structures apparentées retrouvées en tenseurs numériques
    • Representation Learning : mise à jour des représentations single et pair via plusieurs variantes d’attention
    • Structure Prediction : prédiction de la structure par diffusion conditionnelle
  • Les complexes protéiques sont principalement stockés dans deux représentations
    • single representation : représente tous les tokens du complexe eux-mêmes
    • pair representation : représente les relations entre toutes les paires de tokens, comme la distance ou les interactions latentes
  • Les principales dimensions de canal sont c_z=128, c_m=64, c_atom=128, c_atompair=16, c_token=768, c_s=384

Préparation des entrées : transformer la séquence en 6 tenseurs

  • L’entrée fournie par l’utilisateur est convertie en 6 tenseurs destinés au trunk du modèle
    • s : représentation single au niveau des tokens
    • z : représentation pair au niveau des tokens
    • q : représentation single au niveau atomique
    • p : représentation pair au niveau atomique
    • m : représentation MSA
    • t : représentation template
  • Recherche de MSA et de templates

    • AF3 recherche des séquences similaires pour les séquences protéiques et ARN, les assemble en MSA, et inclut les structures apparentées comme templates
    • Un MSA aligne des séquences protéiques similaires trouvées chez plusieurs espèces afin de fournir au modèle les motifs de conservation à certaines positions et les corrélations de variation entre positions différentes
    • Les structures connues de protéines similaires servent à estimer la structure de la protéine query, comme en homology modeling
    • La recherche n’inclut pas d’apprentissage et repose sur des méthodes basées sur les HMM
    • jackhmmer, HHBlits et nhmmer servent à interroger plusieurs bases de données de protéines et d’ARN, et hmmsearch est utilisé pour trouver des séquences similaires dans la Protein Data Bank
    • La taille du MSA est limitée à N_MSA < 2^14 en raison de la complexité de calcul
    • Pour chaque chaîne protéique, les structures de meilleure qualité sont sélectionnées, avec un échantillonnage d’au maximum 4 templates
    • Par rapport à AF-Multimer, le nouvel élément ajouté à la recherche est que les séquences d’ARN sont elles aussi incluses parmi les cibles
  • Mode de représentation des templates

    • À partir de la structure 3D d’un template, la distance euclidienne entre chaque paire de tokens est calculée
    • Pour les tokens comportant plusieurs atomes, un « center atom » représentatif est utilisé
      • Acide aminé : atome
      • Nucléotide standard : atome C1'
    • Les valeurs de distance ne sont pas conservées comme valeurs continues, mais discrétisées en distogram
      • 38 bins de 3.15Å à 50.75Å
      • 1 bin supplémentaire pour les distances supérieures
    • Le distogram inclut aussi les informations de chaîne, le fait que le token correspondant soit ou non resolved dans la structure cristalline, ainsi que des informations de distance locale à l’intérieur de chaque acide aminé
    • La matrice template est masquée pour ne voir que les distances à l’intérieur d’une même chaîne, et la sélection des templates ne cherche pas à obtenir d’informations d’interaction inter-chaînes

Représentations au niveau atomique et Atom Transformer

  • reference conformer et représentations au niveau atomique

    • Pour créer la représentation single au niveau atomique q, un reference conformer est calculé pour chaque acide aminé, nucléotide et ligand.
    • Un conformer est une disposition atomique 3D d’une molécule, générée en échantillonnant les rotations autour des liaisons simples.
    • Pour les acides aminés standards, on utilise des conformers de basse énergie obtenus par lookup, et pour les petites molécules, des conformers 3D sont générés avec RDKit’s ETKDGv3.
    • En combinant la position relative du conformer, la charge atomique, le numéro atomique, les identifiants, etc., on construit la représentation single au niveau atomique c.
    • c sert à initialiser la représentation pair au niveau atomique p, et un masque v est utilisé pour ne contenir que les distances interatomiques calculées dans le reference conformer.
    • q commence comme une copie de c, puis est mise à jour dans l’Atom Transformer.
  • Rôle de l’Atom Transformer

    • L’Atom Transformer est un module qui réalise une attention au niveau atomique et met à jour q à l’aide de p et de la représentation d’origine c.
    • c n’est pas mise à jour et sert plutôt de connexion résiduelle vers la représentation de départ.
    • La structure de base ressemble à celle d’un transformer, avec LayerNorm, attention et transition MLP, mais chaque étape est ajustée par des entrées supplémentaires c et p.
  • Adaptive LayerNorm

    • Adaptive LayerNorm, au lieu d’apprendre des gamma et beta fixes, génère gamma et beta à partir d’une entrée auxiliaire.
    • Dans l’Atom Transformer, la cible du rescaling est q, et les paramètres de rescaling sont prédits à partir de l’entrée auxiliaire c.
  • Attention with Pair Bias

    • L’attention au niveau atomique avec pair bias est une extension de la self-attention.
    • Query, key et value proviennent toutes de la représentation single q, mais après le produit scalaire query-key, on ajoute comme biais une projection linéaire de la représentation pair p.
    • L’information circule de la représentation pair vers q, mais à cette étape, p n’est pas mise à jour à partir de l’information de q.
    • Un gate supplémentaire, obtenu en faisant passer une projection dans une sigmoid, est multiplié au résultat de l’attention et contrôle quelles informations sont conservées dans le flux résiduel.
    • Comme le nombre d’atomes peut être bien plus grand que le nombre de tokens, on utilise une Sequence-local atom attention plutôt qu’une full attention.
    • Des groupes locaux de 32 atomes peuvent porter leur attention sur 128 autres atomes.
  • Conditioned Gating et Transition

    • Conditioned Gating applique aux données un gate généré à partir de la matrice single atomique d’origine c.
    • Conditioned Transition correspond au MLP du transformer et est dite conditionnée parce qu’Adaptive LayerNorm et Conditional Gating dépendent de c.
    • AF3 utilise SwiGLU au lieu de ReLU dans le bloc de transition.
    • La transition basée sur ReLU dans AF2 suit une structure up-projection ×4, ReLU, down-projection.
    • Le SwiGLU d’AF3 applique une non-linéarité swish à l’une des deux up-projections, multiplie les résultats, puis effectue une down-projection.

Agrégation des représentations atomiques en représentations token

  • Comme l’étape d’apprentissage des représentations fonctionne ensuite au niveau token, les représentations atomiques sont agrégées en représentations token.
  • La représentation au niveau atomique est d’abord projetée dans une dimension plus grande, puis on prend la moyenne des atomes appartenant au même token.
  • Cette agrégation par moyenne s’applique quand plusieurs atomes sont reliés à un token, comme pour les acides aminés standards et les nucléotides, tandis que les entrées avec un token par atome sont conservées telles quelles.
  • Des statistiques issues de la MSA sont aussi combinées à l’entrée single au niveau token.
    • type d’acide aminé
    • distribution des acides aminés de la MSA à cette position
    • moyenne des délétions pour ce token
  • Pour les tokens sans MSA, comme les atomes de ligand, ces valeurs sont mises à 0.
  • Le s_inputs ainsi construit devient s_init après projection, puis est mis à jour dans l’étape d’apprentissage des représentations.
  • La représentation pair z_init est un tenseur tridimensionnel qui stocke les relations pour chaque paire de tokens, et chaque z_i,j est un vecteur de dimension c_z=128.
  • L’initialisation de z_i,j additionne des projections de s_i et s_j, un encodage positionnel relatif et les informations de liaison entre tokens spécifiées par l’utilisateur.

Apprentissage des représentations : Template, MSA, Pairformer

  • L’apprentissage des représentations constitue le trunk, qui représente la majeure partie du calcul du modèle, et son objectif est d’améliorer la représentation single au niveau token s et la représentation pair z.
  • La single sequence representation ne désigne pas seulement une séquence protéique unique, mais une séquence obtenue en concaténant tous les atomes ou tokens de la structure.
  • Module Template

    • Chaque template passe par une projection linéaire puis est additionné à une projection linéaire de la représentation pair z.
    • La matrice combinée passe dans un Pairformer Stack.
    • Les résultats de plusieurs templates sont moyennés, puis repassent dans une couche linéaire.
    • La dernière couche linéaire utilise ReLU, l’un des rares endroits où AF3 emploie ReLU comme non-linéarité.
  • Module MSA

    • Le module MSA est très proche de l’Evoformer d’AF2 et améliore simultanément la représentation MSA m et la représentation pair z.
    • Au lieu d’utiliser toutes les lignes de la MSA, il en prend un sous-échantillon, puis ajoute à la MSA une projection de la représentation single.
    • Outer Product Mean est l’opération qui injecte l’information de la MSA dans la représentation pair.
      • pour chaque indice de token i,j, on calcule le produit extérieur de m_s,i et m_s,j pour toutes les séquences évolutives
      • on fait ensuite la moyenne sur l’ensemble des séquences, on aplatit le résultat puis on le projette pour l’ajouter à z_i,j
      • c’est le seul point du modèle où l’information est partagée entre les séquences évolutives
    • Row-wise gated self-attention using only pair bias met à jour la MSA en utilisant la représentation pair.
      • au lieu de créer les scores d’attention avec query et key, on projette la représentation pair z en matrice pour l’utiliser comme score d’attention entre tokens
      • comme cela s’applique indépendamment à chaque ligne de la MSA, aucune information n’est partagée entre les séquences évolutives à cette étape
    • Le module MSA se termine par une triangle update et une triangle attention, qui remettent à jour la représentation pair.

Pairformer et opérations triangulaires

  • Après avoir mis à jour z avec Template et MSA, template et MSA ne sont plus utilisés, et seuls s et z sont donnés en entrée à Pairformer
  • Pairformer génère les s_trunk et z_trunk finaux en répétant 48 blocks
  • Intuition des opérations triangulaires

    • La triangle update et la triangle attention sont des structures conçues pour intégrer au modèle l’intuition de l’inégalité triangulaire
    • Même si z_i,j du tenseur pair n’est pas lui-même une distance physique, il contient la relation entre les tokens i et j, donc les trois relations i-j, j-k et i-k sont mises à jour de façon à rester cohérentes entre elles
    • L’inégalité triangulaire n’est pas imposée directement dans le modèle ; elle est induite par une mise à jour de z_i,j qui considère tous les triplets (i,j,k)
    • z peut être vu comme une directed adjacency matrix, ce qui permet de traiter séparément les directions des outgoing edges et des incoming edges
  • Triangle Updates

    • Dans l’outgoing update, chaque z_i,j est mis à jour à l’aide d’un autre élément z_i,k de la même row et d’une troisième arête z_j,k
    • Dans l’implémentation, on crée trois projections a, b, g de z, puis on additionne sur k la multiplication élément par élément entre la row i et la row j, avant d’appliquer la gate g
    • L’incoming update inverse row et column : z_i,j est alors mis à jour via d’autres éléments z_k,j de la même column et z_k,i
  • Triangle Attention

    • La triangle attention est une forme d’axial attention, où l’on applique indépendamment l’attention aux rows et aux columns d’une matrice 2D, enrichie par le principe triangulaire
    • Dans le cas « starting node », z_i,j et z_i,k sont comparés via query-key, avec z_j,k ajouté comme bias
    • Dans le cas « ending node », l’opération se fait selon les columns, et le score d’attention entre z_i,j et z_k,i reçoit z_k,j comme bias
  • Single Attention with Pair Bias

    • Après l’étape triangulaire et le transition block, la single representation s est mise à jour par single attention with pair bias en utilisant la pair representation z mise à jour
    • Comme cette opération agit au niveau token, elle utilise une full attention plutôt que la block-wise sparse attention employée au niveau atomique

Prédiction de structure : débruiter les coordonnées atomiques par diffusion

  • Fonctionnement de base du modèle de diffusion

    • AF3 effectue la prédiction finale de structure via une diffusion au niveau atomique
    • Un diffusion model ajoute progressivement du bruit aléatoire aux données réelles, puis entraîne le modèle à prédire quel bruit a été ajouté
    • En inference, on part d’un bruit entièrement aléatoire, puis le modèle retire à chaque étape le bruit qu’il a prédit afin de produire un datapoint débruité
    • La diffusion conditionnelle prend en entrée la génération bruitée actuelle, la représentation du timestep courant et un vecteur de condition, pour produire un résultat conforme à cette condition
    • Dans AF3, l’objet à débruiter est la matrice x contenant les coordonnées x,y,z de tous les atomes
  • Augmentation par rotation et translation au lieu de l’IPA d’AF2

    • AF3 n’utilise pas l’Invariant Point Attention d’AF2 et applique à chaque timestep une rotation et une translation aléatoires à l’ensemble du complexe en cours de prédiction
    • Cette augmentation apprend au modèle que toute rotation ou translation représente la même structure valide, dans une approche plus simple que l’IPA d’AF2
    • La rotation est appliquée autour de la moyenne des coordonnées de tous les atomes de la génération courante, et la translation est échantillonnée selon une gaussienne N(0,1) sur chaque dimension
    • Un léger bruit est aussi ajouté aux coordonnées pour favoriser des générations plus variées
    • En inference, plusieurs générations peuvent être notées par le confidence head, puis celle ayant le meilleur score peut être renvoyée
  • Les quatre étapes du Diffusion Module

    • Chaque étape de débruitage utilise plusieurs représentations de conditionnement
      • sorties du trunk s_trunk, z_trunk
      • représentations initiales s_inputs, c_inputs produites par l’input embedder
    • Le processus de diffusion alterne entre l’espace des tokens et celui des atomes et se compose de quatre étapes
        1. préparation d’un tenseur de conditionnement au niveau token
        1. préparation d’un tenseur de conditionnement au niveau atomique, application de l’Atom Transformer, puis agrégation au niveau token
        1. application d’une attention au niveau token
        1. prédiction d’une mise à jour du bruit pour chaque atome via une attention au niveau atomique
    • Au niveau token, le conditionnement combine z_trunk avec le relative positional encoding puis le fait passer par un transition block
    • La single representation combine s_inputs et s_trunk, puis y ajoute un Fourier embedding dépendant du timestep de diffusion
    • À l’étape atomique, les c et p initiaux sont mis à jour avec la représentation courante au niveau token, et les coordonnées actuelles x sont mises à l’échelle par la data variance pour former la coordinate sans dimension r
    • À la dernière étape atomique, une linear layer projette q dans R^3 pour produire r_update, la mise à jour des coordonnées de tous les atomes
    • Cette mise à jour est remise à l’échelle en x_update en tenant compte de la data variance et du noise schedule, puis appliquée aux coordonnées courantes x_l

Fonction de perte et confidence head

  • La loss totale est une somme pondérée de trois termes

L_loss = L_distogram * α_distogram + L_diffusion * α_diffusion + L_confidence * α_confidence

  • L_distogram

    • L_distogram évalue la précision du distogramme prédit au niveau des tokens
    • Lors de la création des coordonnées des tokens à partir des coordonnées atomiques, les coordonnées de l’atome central de chaque token sont utilisées
    • La distance du distogramme est traitée comme une valeur catégorielle, et le distogramme prédit est comparé au distogramme réel via une entropie croisée
  • L_diffusion

    • L_diffusion est une somme pondérée de plusieurs termes appliqués aux positions atomiques
    • L_MSE calcule la mean squared error entre positions pour tous les atomes, et non seulement pour les atomes centraux ; les atomes d’ADN, d’ARN et des ligands sont surpondérés
    • L_bond est un terme MSE supplémentaire visant à améliorer la précision des longueurs de liaison pour les paires d’atomes incluses dans les liaisons protéine-ligand
    • Au début de l’entraînement, α_bond=0, ce terme est donc introduit plus tard
    • L_smooth_LDDT est une loss qui rend la précision locale des distances lisse et différentiable
      • quatre seuils sont utilisés : 4Å, 2Å, 1Å et 0.5Å
      • les paires d’atomes de nucléotides sont ignorées au-delà de 30Å
      • les paires d’atomes de protéines ou de ligands sont ignorées au-delà de 15Å
  • L_confidence

    • L_confidence n’améliore pas directement la précision structurelle, mais entraîne le modèle à estimer la précision de ses propres prédictions
    • Elle se compose de losses correspondant à quatre métriques de confiance
      • pLDDT : précision locale des distances pour les atomes proches
      • PAE : predicted alignment error pour une paire de tokens
      • PDE : predicted distance error entre deux tokens
      • experimentally resolved prediction : prédiction du fait que chaque atome soit résolu ou non dans la structure expérimentale
    • Même si la structure prédite est imprécise et que le PAE est élevé, cette loss PAE peut rester faible si le modèle prédit aussi un PAE élevé
    • La prédiction de confiance est générée à une étape intermédiaire de la diffusion
    • Le gradient de la confidence loss met à jour uniquement la tête de prédiction de confiance, sans affecter le reste du modèle

Techniques d’apprentissage supplémentaires et optimisations

  • Recycling

    • AF3 utilise le weight recycling, comme AF2
    • Au lieu de rendre le modèle plus profond, il réutilise plusieurs fois les mêmes poids pour améliorer progressivement les représentations
    • La diffusion utilise elle aussi les informations de timestep pendant l’inférence et réemploie les mêmes poids à chaque timestep, ce qui intègre intrinsèquement du recycling
  • Cross-distillation

    • AF3 utilise non seulement les données d’entraînement synthétiques qu’il a produites lui-même, mais aussi celles générées par AF2 et AF-Multimer
    • Après le passage à une génération fondée sur la diffusion, un problème est apparu : la forme en « spaghetti », qui permettait dans AF2 de distinguer visuellement les régions peu fiables et désordonnées, a disparu
    • En incluant dans les données d’entraînement d’AF3 les générations d’AF2 et d’AF-Multimer, AF3 apprend à produire des régions dépliées dans les zones où AF2 n’était pas sûr de lui
    • Dans le jeu de données de distillation, les acides nucléiques et les petites molécules qu’AF2 et AF-Multimer ne peuvent pas traiter sont supprimés
    • Une fois la structure prédite par le modèle précédent alignée sur l’original, les molécules supprimées sont réajoutées
    • Si les molécules réajoutées créent des clashes atomiques, la structure entière est exclue, afin d’éviter que le modèle n’apprenne à tolérer ces clashes
  • Cropping et étapes d’entraînement

    • Le modèle lui-même n’impose pas de limite explicite sur la longueur des séquences en entrée, mais plusieurs opérations croissent en N_tokens^3, ce qui augmente les besoins en mémoire et en calcul
    • Pour gagner en efficacité, les protéines sont soumises à un random crop
    • Comme il faut modéliser les interactions entre plusieurs chaînes, le crop doit inclure les chaînes ensemble
    • Trois méthodes de cropping sont utilisées
      • contiguous cropping : sélection d’une séquence continue d’acides aminés dans chaque chaîne
      • spatial cropping : sélection d’acides aminés selon la distance à un atome de référence
      • spatial interface cropping : sélection selon la distance aux atomes de l’interface de liaison
    • Un modèle entraîné avec un random crop de 384 peut aussi être appliqué à des séquences plus longues, mais un fine-tuning répété avec de plus grandes longueurs de séquence est utilisé pour améliorer sa capacité à traiter les longues séquences
  • Clashing et taille de batch

    • La loss d’AF3 n’inclut pas de pénalité de clash pour les atomes qui se chevauchent
    • En théorie, le structure module fondé sur la diffusion peut prédire deux atomes à la même position, mais après entraînement ce problème reste limité
    • Une pénalité de clashing est utilisée pour le ranking des structures générées
    • Le processus de diffusion paraît complexe, mais son coût de calcul est inférieur à celui du trunk
    • Pour améliorer l’efficacité de l’entraînement, la taille de batch est augmentée après le trunk
    • Chaque structure d’entrée passe une fois par l’embedding et le trunk, puis 48 structures indépendantes augmentées par data augmentation sont entraînées en parallèle

Conception d’AF3 du point de vue du ML

  • Une structure similaire à la Retrieval-Augmented Generation

    • La recherche de MSA et de templates dans AF3 a une nature similaire à la RAG des modèles de langage
    • Dans le domaine d’AlphaFold, l’usage de templates structurels existait depuis bien avant le terme RAG, sous le nom de homology modeling
    • AF3 a réduit par rapport à AF2 l’importance du traitement des MSA, mais conserve malgré tout les MSA et les templates
    • Certains modèles de prédiction de protéines comme ESMFold suppriment le retrieval et utilisent une inférence entièrement paramétrique
  • Pair-Bias Attention

    • Pair-Bias Attention, un composant majeur d’AF2, est utilisé plus largement dans AF3
    • Les query, key et value proviennent de la même source, mais un terme de biais issu d’une autre source est ajouté à la carte d’attention
    • Il s’agit d’un mode de partage d’information plus léger qu’une cross-attention complète
    • Comme la représentation par paires ressemble naturellement à une carte d’attention, cette structure peut être particulièrement adaptée à la modélisation des protéines
  • Réduction du self-supervised training

    • Les modèles de la famille ESM ont montré leur force en remplaçant les embeddings MSA par du pre-training self-supervised
    • AF2 comportait une tâche supplémentaire consistant à prédire les masked tokens des MSA, mais elle a été supprimée dans AF3
    • AF3 a réduit le compute consacré au traitement des MSA et n’utilise pas de pre-training de language modeling self-supervised sur les MSA
    • Parmi les explications possibles : un pre-training massif était inefficace du point de vue du compute, un petit module MSA était meilleur que des embeddings préentraînés, ou bien la combinaison entre des embeddings préentraînés et une structure hybride atom-token mêlant acides aminés, ADN/ARN et ligands ne fonctionnait pas bien
  • Mélange de classification et de régression

    • AF3 utilise, comme AF2, à la fois une loss MSE et une loss de classification par bins
    • Une caractéristique de la loss de classification est que se tromper d’un seul bin du distogram ne donne pas plus de crédit que se tromper de très loin
    • La justification de ce choix de conception n’est pas claire, mais il est possible que les gradients aient été plus stables qu’avec plusieurs losses MSE
  • Des éléments qui rappellent une architecture récurrente

    • AF3 comporte de nombreux éléments qui font davantage penser à un réseau récurrent qu’à un transformer classique
    • Le gating contrôle le flux d’information dans le residual stream, d’une manière proche des gates d’un LSTM ou d’un GRU
    • Le recycling et la diffusion appliquent de manière répétée les mêmes poids afin d’améliorer progressivement la prédiction
    • À l’image de l’adaptive compute time, ces mises à jour itératives sont liées à une structure capable d’appliquer davantage de traitement aux entrées difficiles
    • Les ablations d’AF2 ont montré l’importance du recycling, mais l’importance du gating a été beaucoup moins discutée

Aucun commentaire pour le moment.

Aucun commentaire pour le moment.