DDPM代码详细解读(1):数据集准备、超参数设置、loss设计、关键参数计算
之前照猫画虎复现了DDPM,但是其中许多代码很不好理解,遂写此博客进行学习记录。
超参数设置
超参数设置使用absl中flags进行管理,num_res_blocks是U-net中每个level中resnet层数,attn是attention block,这个是后面我们加入condition的途径,非常重要。
beta_1和beta_2对用于\(\beta _1\)和\(\beta _T\),实际的\(\beta _t\)是在\(\beta _1\)、\(\beta _T\)中线性采样得到的,DDPM原文中研究了是否固定\(\beta _t\)对实验结果的影响,后面很多论文也做了对比实验探索是否线性对实验效果的影响。(个人感觉这个参数是否线性对实验的影响不大,只需要满足足够小的假设即可)
T是采样的步长,这个对采样质量和生成时间影响非常大。(在实验中,T越大,采样时间越长,4070显卡采样一个batch的数据设置需要近20小时。但是T越大并不是质量越高,呈二次函数关系。)
image_size根据数据集实际情况设置,这是影响生成时间的重要因素,size和时间呈指数倍爆炸增长。
1 | FLAGS = flags.FLAGS |
训练CIFAR10数据集的配置信息
由于不同的数据集U-net channel、T、image_size等关键参数是不一样的,因此针对不同的数据集采用不同的txt文件进行管理。
1 | --T=1000 |
加载数据集
以加载cifar10数据集为例:
1 | # dataset |
每个loop使用next方法即可:
1 | x_0 = next(datalooper).to(device) |
loss计算
原文的loss计算公式: \[ L_{simple}\left( \theta \right) :=\mathbb{E}_{t,x_0,\epsilon}\left[ ||\epsilon -\epsilon _{\theta}\left( \sqrt{\bar{\alpha}}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon ,t \right) ||\begin{array}{c} 2\\ \\\ \end{array} \right] \] 计算的是纯噪声和\(\epsilon _{\theta}\left( \sqrt{\bar{\alpha}}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon ,t \right)\)之间的均方差:
1 | loss = F.mse_loss(self.model(x_t, t), noise, reduction='none') |
其中noise的size和input data一样的:
1 | noise = torch.randn_like(x_0) |
计算\(x_t\)
具体的计算公式如下: \[ x_t=\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon \] 其中时刻信息 t 是通过\(\bar{\alpha}_t\)表现的,代码如下:
1 | def extract(v, t, x_shape): |
其中extract函数的作用是选取特定下标 t 的信息并转换成特定维度。
计算\(\bar{\alpha}_t\)
1、根据\(\beta _1\)和\(\beta _T\)计算所有的\(\beta _t\)
DDPM原始的论文设置的\(\beta _t\)是线性增长,后面不少文章设置了指数增长等其他方式,只要满足足够小假设即可。
1 | self.register_buffer( |
2、计算\(\alpha_t\) \[ \alpha _t=1-\beta _t \]
1 | alphas = 1. - self.betas |
3、累乘得到\(\bar{\alpha}_t\)
1 | alphas_bar = torch.cumprod(alphas, dim=0) |
最后将这些一同写入buffer即可:
1 | self.register_buffer( |
计算\(\frac{1-\alpha _t}{\sqrt{1-\bar{\alpha}_t}}\)
1 | self.register_buffer( |
- 本文作者: 李宝璐
- 本文链接: https://libaolu312.github.io/2023/11/13/DDPM代码详细解读-1-数据集准备、超参数设置、loss设计、关键参数计算/
- 版权声明: 本博客所有文章除特别声明外,均采用 MIT 许可协议。转载请注明出处!