- Les modèles de diffusion servent au-delà de la génération d’images, pour des problèmes nécessitant l’échantillonnage de distributions multimodales comme l’audio, la vidéo, la 3D, la conception de protéines ou la planification de trajectoires robotiques, et ce tutoriel relie apprentissage et échantillonnage sous l’angle de l’optimisation
- Le processus d’apprentissage construit des données bruitées via (x_\sigma=x_0+\sigma\epsilon), puis minimise l’erreur quadratique moyenne afin qu’un réseau de neurones (\epsilon_\theta(x,\sigma)) prédise la direction du bruit
- Le denoiser appris s’interprète comme une projection approchée sur l’ensemble de données (\mathcal{K}), et le denoiser idéal est relié au gradient de la fonction de distance au carré lissée par (\sigma)
- L’échantillonnage DDIM peut être vu comme une descente de gradient approchée sur (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2), et le planning (\sigma_t) détermine le nombre d’itérations ainsi que le coût des évaluations du denoiser
- En combinant mise à jour par estimation de gradient et ajout de bruit, DDIM, DDPM et le sampler amélioré des auteurs peuvent être traités ensemble via les paramètres
gametmu, avant de passer à des exemples sur modèle jouet et latent diffusion
Modèles de diffusion vus sous l’angle de l’optimisation
- Les modèles de diffusion excellent pour générer des échantillons à partir de distributions multimodales, et s’appliquent non seulement à des outils texte-vers-image comme Stable Diffusion, mais aussi à l’audio, la vidéo, la génération 3D, la conception de protéines et la planification de trajectoires robotiques
- La base théorique du tutoriel est l’interprétation en optimisation de l’article ICML 2024 et d’un article connexe
- L’implémentation s’appuie principalement sur
smalldiffusion, et le code présenté dans le corps de l’article est simplifié à des fins pédagogiques par rapport à la bibliothèque d’origine
Apprentissage : prédire la direction du bruit
- Un modèle de diffusion apprend l’ensemble de données (\mathcal{K}) à partir d’exemples d’entraînement et vise à générer des échantillons issus de cet ensemble
- Pour des images, (\mathcal{K} \subset \mathbb{R}^{c\times h \times w}) est l’ensemble des valeurs de pixels correspondant à des images réalistes
- Le même cadre s’applique aussi à des domaines discrets comme l’audio, la vidéo, les trajectoires robotiques ou le texte
- La procédure d’apprentissage peut être vue en trois étapes
- On échantillonne (x_0 \sim \mathcal{K}), (\sigma) et (\epsilon \sim N(0,I))
- On crée des données bruitées avec (x_\sigma=x_0+\sigma\epsilon)
- On minimise une perte quadratique pour que (\epsilon_\theta(x_\sigma,\sigma)) prédise (\epsilon)
- Dans le code,
training_loopgénèresigmaetepspour chaque batchx0viagenerate_train_sample, puis optimise la MSE entre la sortie demodel(x0 + sigma * eps, sigma)eteps - (\sigma) n’est pas tiré uniformément sur un intervalle continu, mais à partir d’un planning de (\sigma) discrétisé en (N) valeurs
- La classe
Scheduleencapsule la liste dessigmaspossibles et en échantillonne pour chaque batch pendant l’entraînement - L’exemple principal utilise
ScheduleLogLinear(N, sigma_min=0.02, sigma_max=10) ScheduleDDPMest un planning pour les modèles de diffusion dans l’espace pixel, etScheduleLDMpour les modèles de latent diffusion comme Stable Diffusion
- La classe
Exemple jouet Swissroll
- Le dataset jouet est un ensemble de points en spirale utilisé dans l’un des premiers articles sur la diffusion, Sohl-Dickstein et al. 2015, avec (\mathcal{K}\subset\mathbb{R}^2)
- Sur ce dataset simple, le denoiser est implémenté sous forme de MLP
- L’entrée concatène (x\in\mathbb{R}^2) et un embedding bidimensionnel de (\sigma)
- La sortie est une prédiction du bruit (\epsilon\in\mathbb{R}^2)
- De nombreux modèles de diffusion utilisent un embedding positionnel sinusoïdal pour (\sigma), mais dans cet exemple un simple embedding 2D fonctionne aussi bien
- La configuration d’entraînement de l’exemple utilise
ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)etepochs=15000 - Le denoiser entraîné peut être visualisé comme un champ de vecteurs en traçant (x-\sigma\epsilon_\theta(x,\sigma))
- Quand (\sigma) est grand, le denoiser tend à prédire la moyenne des données
- Quand (\sigma) est faible et que l’entrée (x) est proche des données, il prédit de vrais points du dataset
Interpréter le denoising comme une projection
- La fonction de distance à l’ensemble de données (\mathcal{K}) est définie par (\mathrm{dist}_{\mathcal{K}}(x)=\min{|x-x_0|:x_0\in\mathcal{K}})
- La projection (\mathrm{proj}_{\mathcal{K}}(x)) de (x) est l’ensemble des points de (\mathcal{K}) qui réalisent cette distance
- Si (\mathcal{K}) est fermé, que (x\notin\mathcal{K}) et que la projection est unique, alors le gradient de la fonction de distance au carré vaut (x-\mathrm{proj}_{\mathcal{K}}(x))
- Comme la fonction de distance (\mathrm{dist}_{\mathcal{K}}) n’est pas différentiable partout, on introduit une fonction de distance au carré lissée par (\sigma), en utilisant un softmin à la place de
min - Le gradient de cette fonction de distance lissée pointe vers une moyenne pondérée des points de (\mathcal{K}), avec des poids déterminés par (x)
Denoiser idéal et modèle d’erreur relative
- Le denoiser idéal (\epsilon^*) est celui qui minimise exactement la perte d’apprentissage pour un (\sigma) donné
- Si les données suivent une distribution uniforme discrète sur un ensemble fini (\mathcal{K}), le denoiser idéal admet une expression en forme fermée
- Le poids de chaque point de données dépend de la distance entre (x_\sigma) et ce point
- Sur un petit dataset, on peut le calculer directement avec
IdealDenoiser
- Sur les données jouets, le denoiser idéal pointe vers la moyenne des données quand (\sigma) est grand, et vers le point de données le plus proche quand (\sigma) est petit
- Le théorème central établit que pour tout (\sigma>0) et tout (x\in\mathbb{R}^n), on a la relation (\frac{1}{2}\nabla_x \mathrm{dist}^2_{\mathcal{K}}(x,\sigma)=\sigma\epsilon^*(x,\sigma))
- Le modèle d’erreur relative repose sur la condition que (x-\sigma\epsilon_\theta(x,\sigma)) approxime bien (\mathrm{proj}_{\mathcal{K}}(x))
- Il s’applique lorsque (\sqrt{n}\sigma) estime bien (\mathrm{dist}_{\mathcal{K}}(x)) à un facteur constant près
- On suppose que l’erreur est bornée par (\eta\mathrm{dist}_{\mathcal{K}}(x))
- À faible bruit, sous l’hypothèse de variété, la majeure partie du bruit ajouté est orthogonale à la variété des données, donc le denoising approxime une projection
- À fort bruit, si (\sigma) dépasse le diamètre de (\mathcal{K}), même un denoiser qui prédit la moyenne pondérée des données présente une faible erreur relative
- CIFAR-10 est d’une taille qui permet encore de calculer le denoiser idéal, et les expériences montrent une faible erreur relative entre la projection exacte et la sortie du denoiser idéal le long des trajectoires d’échantillonnage
Échantillonnage : denoising itératif et DDIM
- Une fois le denoiser appris, on peut prédire (x_0) à partir de (x_t) bruité et du niveau de bruit (\sigma_t) via (\hat{x}0^t=x_t-\sigma_t\epsilon\theta(x_t,\sigma_t))
- Le point de départ consiste à choisir (\sigma_T) bien plus grand que le diamètre de (\mathcal{K}), puis à échantillonner indépendamment (x_T) depuis (N(0,\sigma_T)) pour le placer loin de (\mathcal{K})
- À fort bruit, un seul appel au denoiser peut avoir une faible erreur relative tout en gardant une erreur absolue importante, et la prédiction du denoiser idéal est alors proche de la moyenne des données
- L’échantillonnage appelle donc le denoiser de manière répétée selon un planning (\sigma_t), afin de construire une séquence (x_T,\ldots,x_0)
- La mise à jour (x_{t-1}=x_t-(\sigma_t-\sigma_{t-1})\epsilon_\theta(x_t,\sigma_t)) est équivalente à l’algorithme d’échantillonnage DDIM déterministe après un changement de coordonnées
- La preuve de l’équivalence avec DDIM figure dans l’annexe A de l’article
DDIM vu comme minimisation de distance
- DDIM s’interprète comme une descente de gradient approchée sur (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2)
- La taille de pas est (1-\sigma_{t-1}/\sigma_t)
- (\nabla f(x_t)) est estimé par (\epsilon_\theta(x_t,\sigma_t))
- Le planning (\sigma_t) détermine le nombre et la taille des pas de gradient pendant l’échantillonnage
- S’il y a trop peu de pas, (\mathrm{dist}_{\mathcal{K}}(x_t)) peut ne pas diminuer, et la méthode peut ne pas converger
- Utiliser beaucoup de petits pas augmente le nombre d’évaluations du denoiser et donc le coût de calcul
- Un planning admissible est un planning tel que, à chaque itération, (\sqrt{n}\sigma_t) reste à un facteur constant près de (\mathrm{dist}_{\mathcal{K}}(x_t))
- Une suite log-linéaire de (\sigma_t) décroissant géométriquement est un planning admissible
- D’après le théorème, si pour les (x_t) générés par DDIM le gradient (\nabla\mathrm{dist}{\mathcal{K}}(x)) existe et que (\mathrm{dist}{\mathcal{K}}(x_T)=\sqrt{n}\sigma_T), alors (x_t) est généré par une descente de gradient sur la fonction de distance au carré et (\mathrm{dist}_{\mathcal{K}}(x_t)/\sqrt{n}\approx\sigma_t) se maintient
- Dans l’exemple jouet, un sampler DDIM en 20 étapes est implémenté en sous-échantillonnant le planning log-linéaire d’origine, et la plupart des échantillons sont proches des données d’origine, même si des améliorations restent possibles
Sampler amélioré fondé sur l’estimation du gradient
- En exploitant le fait que (\nabla\mathrm{dist}{\mathcal{K}}(x)) est invariant entre (x) et (\mathrm{proj}{\mathcal{K}}(x)), les auteurs utilisent une mise à jour qui mélange l’estimation courante et la précédente
- La mise à jour (\bar{\epsilon}t=\gamma\epsilon\theta(x_t,\sigma_t)+(1-\gamma)\epsilon_\theta(x_{t+1},\sigma_{t+1})) corrige l’erreur de l’étape précédente à l’aide de l’estimation actuelle
- Sur les échantillons du modèle jouet, cette approche converge plus vite que DDIM et produit des échantillons plus proches des données d’origine
- Par rapport à DDIM, ce sampler peut s’interpréter comme l’ajout d’un momentum ; la trajectoire peut dépasser sa cible, mais converger plus vite
- Ajouter du bruit pendant la génération améliore empiriquement la qualité d’échantillonnage
- Pour conserver le planning (\sigma_t) d’origine, on denoise jusqu’à un plus petit (\sigma_{t'}), puis on réinjecte un bruit (w_t\sim N(0,I))
- Quand (\mu=\frac{1}{2}), on retrouve exactement le sampler DDPM
- La mise à jour complète (x_{t-1}=x_t-(\sigma_t-\sigma_{t'})\bar{\epsilon}_t+\eta w_t) généralise trois samplers
- DDIM :
gam=1, mu=0 - DDPM :
gam=1, mu=0.5 - Sampler par estimation de gradient :
gam=2, mu=0
- DDIM :
Modèles plus grands et ressources utiles
- Le code d’entraînement présenté plus haut peut servir non seulement aux données jouets, mais aussi à entraîner un modèle de diffusion d’images from scratch
- L’exemple FashionMNIST est fourni comme cas d’entraînement sur le dataset FashionMNIST, et obtient le 2e meilleur score FID du leaderboard Papers with Code
- Le code d’échantillonnage peut aussi être utilisé sans modification avec des modèles latent diffusion préentraînés
- L’exemple utilise
ScheduleLDM(1000)etModelLatentDiffusion('stabilityai/stable-diffusion-2-1-base') - La condition textuelle est
An astronaut riding a horse, l’échantillonnage se fait sur 50 étapes de (\sigma), puis le latent est décodé
- L’exemple utilise
- L’effet du terme de momentum (\gamma) est illustré par des comparaisons visuelles en génération texte-vers-image haute résolution
- Autres ressources à consulter
- What are diffusion models : introduction aux modèles de diffusion dans une perspective temps discret fondée sur l’inversion d’un processus de Markov
- Generative modeling by estimating gradients of the data distribution : introduction aux modèles de diffusion dans une perspective temps continu fondée sur l’inversion d’équations différentielles stochastiques
- The annotated diffusion model : explication détaillée d’une implémentation PyTorch d’un modèle de diffusion
1 commentaires
Commentaires Hacker News
Je peux répondre aux questions s’il y en a.
J’ai particulièrement apprécié la discussion sur les trajectoires, parce qu’elle motive la compréhension d’un point qui pose problème à beaucoup de gens sur des sujets comme les ordonnanceurs. Ce n’est pas aussi complet que les articles de Song ou de Lilian, mais c’est bien plus accessible, donc je compte le recommander à d’autres.
À titre de référence, un ami avait écrit auparavant une implémentation minimale de diffusion qui, du point de vue DDPM, est un peu plus « complète » et m’a été utile : https://github.com/VSehwag/minimal-diffusion/
Ayant un peu expérimenté avec la procédure d’échantillonnage dans Stable Diffusion, j’aurais aussi aimé voir une comparaison du temps de convergence et du nombre d’étapes par rapport à DDIM. Je me demande s’il existe un lien entre momentum, convergence et erreur. Par exemple, ce serait intéressant de comparer si un sampler avec momentum en 16 étapes est presque équivalent à DDIM en 20 étapes ± un terme d’erreur.
get_sigma_embeds(batches, sigma)ne semble pas utiliser sa première entrée. Je me demande si l’intention était de broadcastersigmasous la forme(batches, 1).Il entre beaucoup plus en profondeur dans les détails mathématiques, tout en fournissant une implémentation minimale de moins de 500 lignes très facile à comprendre.
Ce serait bien que cela soit aussi étendu à la version diffusion transformer qui fait tourner Sora et d’autres modèles de génération vidéo. En combinant cet article avec https://jaykmody.com/blog/gpt-from-scratch/, on pourrait faire un article d’introduction du type « construire un diffusion transformer à partir de zéro ».
À l’inverse, si vous voulez vraiment creuser, je recommande de lire les travaux de Kingma, Gao, Ricky Tian Qi Chen, ainsi que des étudiants de Max Welling (Tomczak est postdoctorant, Hoogeboom, etc.), sans oublier le contributeur discret Aapo Hyvärinen. Voici un exemple d’un travail relativement léger de Kingma & Gao, également lié à l’article SD3 : https://arxiv.org/abs/2303.00848
Ce qui est dommage, c’est que l’accessibilité est limitée par une forte dépendance à la connaissance et à la compréhension des travaux précédents ; mais il est aussi difficile de qualifier cela de critique pertinente, puisque c’est de la recherche, pas du matériel pédagogique destiné au grand public.
n_embd, et le processus de diffusion lui-même peut rester inchangé.[1] https://yang-song.net/blog/2021/score/
[2] https://lilianweng.github.io/posts/2021-07-11-diffusion-mode...
De notre point de vue, les modèles de diffusion sont faciles à entraîner parce qu’ils utilisent un objectif d’apprentissage consistant à prédire le gradient d’une fonction de distance lissée, plutôt que le gradient de la fonction de distance exacte. L’échantillonnage d’un modèle de diffusion ressemble à l’exécution de plusieurs étapes de gradient approximatives.
Pour comprendre plus en profondeur les modèles de diffusion, je recommande de lire tous ces billets et d’apprendre les différentes interprétations.
Cela dit, l’approche de cet article semble permettre des expériences plus intéressantes, comme l’analyse des erreurs du débruiteur.
[1] https://arxiv.org/pdf/2305.03486.pdf
Par exemple, pourquoi un générateur d’images a-t-il du mal à produire des touches de piano ? Pour créer une structure où les touches noires alternent par groupes de deux et de trois, il semble qu’il faudrait mieux représenter des contraintes de distance intermédiaires.