Micro Diffusion - un petit modèle de diffusion pour apprendre
(github.com/Siwoo4985)Bonjour, il y a environ un mois, j’ai implémenté et publié un modèle de diffusion de texte from scratch. En parcourant récemment GeekNews, j’y ai repensé et je me décide enfin à le partager ici.
Pourquoi je l’ai créé
En voyant MicroGPT d’Andrej Karpathy, je me suis dit : « On peut donc expliquer le cœur de GPT avec un code aussi court. » Comme je voulais justement étudier aussi les modèles de diffusion, j’ai lancé ce projet pédagogique en me disant : « Ce serait intéressant de créer du code permettant de comprendre la diffusion de la même manière. »
AR vs Diffusion : quelle différence ?
À l’ère des LLM, la génération de texte se fait presque entièrement de manière autoregressive (AR). Autrement dit, on prédit les tokens un par un, de gauche à droite.
La diffusion discrète fait l’inverse. On considère toute la séquence d’un seul coup, puis on la reconstruit progressivement à partir du bruit (masquage).
Prenons le nom "emma" comme exemple :
Forward (entraînement - ajout de bruit) :
t=0 : e m m a ← original
t=25: e _ m a ← masquage partiel
t=50: _ _ m _ ← davantage masqué
t=100: _ _ _ _ ← complètement masqué
Reverse (génération - suppression du bruit) :
t=100: _ _ _ _ ← on part d’un état vide
t=75: _ m _ _ ← reconstruction à partir des positions les plus certaines
t=50: e m _ a
t=0 : e m m a ← terminé
Si l’AR revient à « écrire un mot lettre par lettre », la diffusion se rapproche davantage de « résoudre une grille de mots croisés ».
Structure de l’implémentation
Il existe trois versions, que vous pouvez choisir selon le niveau de difficulté.
train_minimal.py— MLP à 2 couches / NumPy uniquement (le plus simple)train_pure.py— MLP à 3 couches + skip connection / NumPy uniquementtrain.py— Transformer à 4 couches / PyTorch
Les trois versions partagent la même boucle de diffusion. Seule l’architecture du débruiteur change.
Les données d’entraînement sont constituées de 32 000 prénoms anglais, et le nombre de paramètres est de l’ordre de 170K à 239K.
Vous pouvez l’exécuter immédiatement :
pip install numpy # version minimale
python train_minimal.py
N’hésitez pas à faire des retours ou à proposer des PR !
Aucun commentaire pour le moment.