Training
- (Ho et al., 2020) Denoising loss: $L_{\text{simple}} = \mathbb{E}_{t, x_0, \epsilon}\big[\lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2\big]$
- This objective can be seen as a reweighted form of $L_{\text{vlb}}$ (without the terms affecting $\Sigma_\theta$). The authors found that optimizing this reweighted objective resulted in much better sample quality than optimizing $L_{\text{vlb}}$ directly, and explain this by drawing a connection to generative score matching (Song & Ermon, 2019; 2020).
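A minimal sketch of this loss in PyTorch, assuming an image batch `x0` and a user-supplied `model(x_t, t)` that predicts the added noise (all names here are illustrative, not the paper's code):

```python
# Minimal sketch of the DDPM "simple" denoising loss (Ho et al., 2020).
import torch
import torch.nn.functional as F

T = 1000
betas = torch.linspace(1e-4, 0.02, T)               # linear variance schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # \bar{alpha}_t

def l_simple(model, x0):
    """E_{t, x0, eps} [ || eps - eps_theta(x_t, t) ||^2 ]"""
    b = x0.shape[0]
    t = torch.randint(0, T, (b,), device=x0.device)        # uniform timesteps
    eps = torch.randn_like(x0)
    a_bar = alphas_cumprod.to(x0.device)[t].view(b, 1, 1, 1)
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps   # sample from q(x_t | x_0)
    return F.mse_loss(model(x_t, t), eps)                  # noise matching
```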
"Improved denoising diffusion probabilistic models" (TLDR: they learn $\Sigma_\theta(x_t, t)$)
Learning the variance
- We want to fit $p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\big(\mu_\theta(x_t, t), \Sigma_\theta(x_t, t)\big)$
- The variational lower bound loss $L_{\text{vlb}} = L_0 + L_1 + \dots + L_T$ is derived from the VDM, with $L_{t-1} = D_{\mathrm{KL}}\big(q(x_{t-1} \mid x_t, x_0) \,\|\, p_\theta(x_{t-1} \mid x_t)\big)$ for $1 < t \le T$ and $L_0 = -\log p_\theta(x_0 \mid x_1)$
- To fit $\mu_\theta$, one can simply reparametrize it as a noise prediction network $\epsilon_\theta$ and use $L_{\text{simple}}$
- However, to train $\Sigma_\theta$, one must use the full $L_{\text{vlb}}$
- They optimize the hybrid objective $L_{\text{hybrid}} = L_{\text{simple}} + \lambda L_{\text{vlb}}$ with $\lambda = 0.001$
- $\Sigma_\theta$ is only trained using $L_{\text{vlb}}$, thus they use a stop-grad on the predicted mean when computing $L_{\text{vlb}}$ (see the sketch at the end of this section) (https://github.com/openai/improved-diffusion/blob/1bc7bbbdc414d83d4abf2ad8cc1446dc36c4e4d5/improved_diffusion/gaussian_diffusion.py#L679)
- simply use `mean_pred.detach()` as the mean when computing the VLB.
- For sample quality, the first few steps of the diffusion don't really matter, i.e. they only affect very small details. HOWEVER, for maximizing log-likelihood, the first few steps of the diffusion process matter the most, as they contribute the most to the variational lower bound (Fig. 2 of the paper)
- This is because the likelihood of a training sample with very little noise added must still be well calibrated and high for that sample, which is not really taken into account when using only the noise-matching loss.
- In "Improved denoising diffusion probabilistic models", they parametrize the variance as:
- $\Sigma_\theta(x_t, t) = \exp\big(v \log \beta_t + (1 - v) \log \tilde{\beta}_t\big)$, where $\beta_t$ is the variance schedule and $\tilde{\beta}_t$ is the variance of the posterior $q(x_{t-1} \mid x_t, x_0)$
- and $v$ is a vector (output by the model) containing one component per dimension
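A sketch tying these pieces together, assuming a model whose output has twice the channels of $x$ (the noise prediction and the interpolation vector $v$ stacked along channels); `vlb_kl_term` is a hypothetical helper computing the per-sample $L_{\text{vlb}}$ term (the real implementation lives in the linked improved-diffusion repo):

```python
# Sketch of the Improved-DDPM hybrid loss with a learned variance.
# Assumptions: model(x_t, t) returns [eps_pred, v] stacked along channels;
# betas, alphas_cumprod, posterior_log_variance are precomputed 1-D tensors;
# vlb_kl_term is a HYPOTHETICAL helper returning the per-sample L_vlb term.
import torch
import torch.nn.functional as F

def hybrid_loss(model, x0, betas, alphas_cumprod, posterior_log_variance,
                vlb_kl_term, lam=0.001):
    b = x0.shape[0]
    t = torch.randint(0, len(betas), (b,), device=x0.device)
    eps = torch.randn_like(x0)
    a_bar = alphas_cumprod[t].view(b, 1, 1, 1)
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps

    eps_pred, v = model(x_t, t).chunk(2, dim=1)

    # Sigma_theta(x_t, t) = exp(v * log(beta_t) + (1 - v) * log(beta_tilde_t)),
    # with the raw output mapped from [-1, 1] to [0, 1] as in improved-diffusion.
    frac = (v + 1) / 2
    log_var = frac * betas[t].log().view(b, 1, 1, 1) \
        + (1 - frac) * posterior_log_variance[t].view(b, 1, 1, 1)

    # L_simple trains the mean (through eps_pred)...
    l_simple = F.mse_loss(eps_pred, eps)

    # ...while L_vlb only trains the variance: stop-grad through the mean.
    l_vlb = vlb_kl_term(x0, x_t, t, eps_pred.detach(), log_var)

    return l_simple + lam * l_vlb.mean()   # L_hybrid, lambda = 0.001
```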
Better noise schedule
- cosine schedule: $\bar{\alpha}_t = \frac{f(t)}{f(0)}$ with $f(t) = \cos\!\left(\frac{t/T + s}{1 + s} \cdot \frac{\pi}{2}\right)^2$ (sketch below)
- they use a small offset $s$ to prevent $\beta_t$ from being too small near $t = 0$
- They selected $s$ such that $\sqrt{\beta_0}$ was slightly smaller than the pixel bin size $1/127.5$.
- $s$ can be smaller for other modalities.
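A short sketch of the cosine schedule with the paper's offset $s = 0.008$; the `max_beta` clipping near $t = T$ mirrors what improved-diffusion does:

```python
# Cosine noise schedule from Nichol & Dhariwal (2021).
import math
import torch

def cosine_betas(T: int, s: float = 0.008, max_beta: float = 0.999) -> torch.Tensor:
    """beta_t from alpha_bar(t) = f(t)/f(0), f(t) = cos((t/T + s)/(1 + s) * pi/2)^2."""
    f = lambda u: math.cos((u / T + s) / (1 + s) * math.pi / 2) ** 2
    # beta_t = 1 - alpha_bar(t) / alpha_bar(t-1); clip to avoid singularities near t = T
    betas = [min(1 - f(t + 1) / f(t), max_beta) for t in range(T)]
    return torch.tensor(betas)
```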
Reducing gradient noise
- $L_{\text{vlb}}$ introduces a lot of gradient noise
- gradient noise = l2 norm of concatenated gradient
- Noting that different terms of $L_{\text{vlb}}$ have greatly different magnitudes (Figure 2 of the paper), they hypothesized that sampling $t$ uniformly causes unnecessary noise in the $L_{\text{vlb}}$ objective
- A simple importance sampling technique reduces this noise: $L_{\text{vlb}} = \mathbb{E}_{t \sim p_t}\!\left[\frac{L_t}{p_t}\right]$ (sketch below)
- where $p_t \propto \sqrt{\mathbb{E}[L_t^2]}$ and $\sum_t p_t = 1$; $\mathbb{E}[L_t^2]$ is estimated online from a history of the 10 most recent losses per timestep
- They found that the importance sampling technique was not helpful when optimizing the less-noisy $L_{\text{hybrid}}$ objective directly.
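A sketch of this loss-aware timestep sampling, with the warm-up, history size, and class/function names chosen here for illustration (the paper's own implementation is in improved-diffusion):

```python
# Loss-aware importance sampling of timesteps (Nichol & Dhariwal, 2021):
# p_t ∝ sqrt(E[L_t^2]) estimated from the 10 most recent values of L_t per timestep,
# each sampled term reweighted by 1 / p_t. Uniform sampling is used until every
# timestep has enough history.
import torch

class LossAwareSampler:
    def __init__(self, T: int, history: int = 10):
        self.T = T
        self.history = history
        self.losses = [[] for _ in range(T)]   # recent L_t values per timestep

    def probs(self) -> torch.Tensor:
        if any(len(h) < self.history for h in self.losses):
            return torch.full((self.T,), 1.0 / self.T)      # warm-up: uniform
        w = torch.tensor([(sum(x * x for x in h) / len(h)) ** 0.5 for h in self.losses])
        return w / w.sum()                                   # p_t ∝ sqrt(E[L_t^2])

    def sample(self, batch: int):
        p = self.probs()
        t = torch.multinomial(p, batch, replacement=True)
        return t, p[t]                                       # timesteps and their probs

    def update(self, t: torch.Tensor, l_t: torch.Tensor):
        for ti, li in zip(t.tolist(), l_t.detach().tolist()):
            self.losses[ti] = (self.losses[ti] + [li])[-self.history:]

# Usage in a training step, with a hypothetical vlb_term_fn(x0, t) -> per-sample L_t:
#   t, p = sampler.sample(x0.shape[0])
#   l_t = vlb_term_fn(x0, t)
#   loss = (l_t / p.to(l_t.device)).mean()    # importance-weighted L_vlb
#   sampler.update(t, l_t)
```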
Würstchen (learned noise gating trick)
- We have $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon$
- To predict $\epsilon$, they have the network output two maps $a$ and $b$ and compute $\bar{\epsilon} = \frac{x_t - a}{b}$
- with $b = \mathrm{sigmoid}(\tilde{b}) \cdot (1 - 2e) + e$ for a small $e$ (e.g. $10^{-3}$); $a$ and $b$ have the same dimension as the noise $\epsilon$. The division is element-wise.
- It makes training more stable. They hypothesize this is because the model parameters are initialized to predict 0 at the beginning, which gives a large error at timesteps with a lot of noise (where the target noise is essentially the input). With this reformulation of the objective, the model initially returns (a scaling of) the input, making the loss small for very noised inputs.
- noise prediction: https://github.com/dome272/Wuerstchen/blob/main/modules.py#L307
- diffusion implementation: https://github.com/pabloppp/pytorch-tools/blob/master/torchtools/utils/diffusion.py
- Additionally, they do P2-weighted noise matching (see the sketch below)
- where the per-timestep weight is $w_t = \big(k + \mathrm{SNR}(t)\big)^{-\gamma}$ with $\mathrm{SNR}(t) = \frac{\bar{\alpha}_t}{1 - \bar{\alpha}_t}$ (P2 weighting, Choi et al., 2022)
- making higher noise levels contribute more to the loss
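A sketch of the gating plus the P2 weighting, assuming a `model(x_t, t)` whose output has twice the channels of $x$ (split into $a$ and $b$); this mirrors the idea in the linked modules.py / diffusion.py, not the exact code:

```python
import torch
import torch.nn.functional as F

def gated_eps_prediction(raw, x_t, e=1e-3):
    """eps_pred = (x_t - a) / b, with b squashed into [e, 1 - e] by a sigmoid."""
    a, b = raw.chunk(2, dim=1)
    b = torch.sigmoid(b) * (1 - 2 * e) + e
    # at init (raw ≈ 0): a ≈ 0, b ≈ 0.5, so eps_pred is roughly a scaling of x_t
    return (x_t - a) / b                        # element-wise division

def p2_weight(a_bar, k=1.0, gamma=1.0):
    """P2 weighting: w_t = (k + SNR(t))^-gamma, SNR(t) = a_bar / (1 - a_bar)."""
    return (k + a_bar / (1 - a_bar)) ** -gamma  # bigger weight at high noise (low SNR)

def gated_p2_loss(model, x0, t, alphas_cumprod):
    b = x0.shape[0]
    eps = torch.randn_like(x0)
    a_bar = alphas_cumprod[t].view(b, 1, 1, 1)
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps
    eps_pred = gated_eps_prediction(model(x_t, t), x_t)
    per_sample = F.mse_loss(eps_pred, eps, reduction="none").mean(dim=(1, 2, 3))
    return (p2_weight(a_bar.view(b)) * per_sample).mean()
```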
Analyzing and Improving the Training Dynamics of Diffusion Models
Noisy training signal
- The training dynamics of diffusion models remain challenging due to the highly stochastic loss function.
- The final image quality is dictated by faint image details predicted throughout the sampling chain
- small mistakes at intermediate steps can have snowball effects in subsequent iterations
- The network must accurately estimate the average clean image across a vast range of noise levels, Gaussian noise realizations, and conditioning inputs.
- Learning to do so is difficult given the chaotic training signal that is randomized over all of these aspects