DDPM代码详细解读(2):Unet结构、正向和逆向过程、IS和FID测试、EMA优化
EMA优化
使用指数移动平均对模型参数进行优化,提高测试指标增加模型鲁棒性。
EMA的公式是: (EMA[t] = α * x[t] + (1 - α) * EMA[t-1]) \[ v_t=\beta \cdot v_{t-1}+\left( 1-\beta \right) \cdot \theta _t \] 其中 β 的设置为0.999,代码如下:
1 | def ema(source, target, decay): |
在训练的过程中,每一个 step 对 net_model 和 ema_model (即sample model)做ema:
1 | ema(net_model, ema_model, FLAGS.ema_decay) |
训练目标和采样目标
正向过程
正向过程即p过程,逆向过程即q过程、采样过程。
正向过程不涉及参数分布的计算和预测,可以理解为一个单纯add noise的过程。
训练和采样的训练目标如下:

上一篇博客详细解释了\(x_t\)和\(\epsilon _{\theta}\)是怎么计算的,正向过程的code就非常容易理解了:
1 | class GaussianDiffusionTrainer(nn.Module): |
逆向过程
\(x_t\)的分布符合高斯分布,这是通过均值和方差进行计算的: \[ q\left( x_t|x_0 \right) =N\left( x_t;\sqrt{\bar{\alpha}_t}x_0,\left( 1-\bar{\alpha}_t \right) I \right) \] 计算\(\sigma _tZ\)使用:
1 | torch.exp(0.5 * log_var) * noise |
而其他的参数都已经计算过了,所以重点是计算第一项的均值:
输入\(x_t\),得到\(x_{t-1}\),最终的代码如下:
1 | class GaussianDiffusionSampler(nn.Module): |
因为我们预测的是概率分布,所以最终将所有的值缩放到[-1,1]这个区间中。
IS和FID测试
IS简介
IS基于Google的预训练网络Inception Net-V3,Inception Net-V3是精心设计的卷积网络模型,输入为图片张量,输出为1000维向量。输出向量的每个维度的值对应图片属于某类的概率,因此整个向量可以看做一个概率分布。
p(y|x) 表示 Inception 输入生成图像 x 时的输出分布,p(x) 表示生成器 G 生成图像 x 的概率,p(y_i|x)表示 Inception 预测 x 为第 i 类的概率,IS 是衡量两者之间的 KL散度: \[ IS=\exp E_{x~P_G}KL\left( p\left( y|x \right) ||p\left( y \right) \right) \] IS越大,生成图片的质量越高。
FID简介
FID衡量真实图像分布和生成器生成之间的差异,因此FID越小,代表真实图像和生成图像之间的接近性,生成质量也就越高。
计算公式:(整个公式表示两个点(m, C)和(m_w, C_w)之间的距离的度量。这个距离包括它们在特征空间的欧几里得距离以及它们在协方差矩阵空间的距离。)
使用方法:
使用 Inception V2 预训练模型提取真实图像和假图像的特征向量(由生成器生成),计算生成的特征向量的特征均值。
生成特征向量\(C,C_w\)的协方差矩阵
计算矩阵的迹 \[ Tr\left( C+C_w-2\left( CC_w \right) ^{1/2} \right) \] 计算矩阵的迹可以参考博客:https://blog.csdn.net/lyxleft/article/details/84865805
计算平均向量的平方差
IS代码
1 | import torch |
1 | def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1): |
1 | if __name__ == '__main__': |
FID代码
1 | def calculate_fid(image1, image2): |
完整实例
1 | import matplotlib.pyplot as plt |
U-net网络结构
前面的博客已经详细说明了为什么要用U-net,以及U-net的结构
1 | import math |
- 本文作者: 李宝璐
- 本文链接: https://libaolu312.github.io/2023/11/13/DDPM代码详细解读-2-Unet结构、正向和逆向过程、IS和FID测试、EMA优化/
- 版权声明: 本博客所有文章除特别声明外,均采用 MIT 许可协议。转载请注明出处!