DDPM代码详细解读(3):图解模型各部分结构、用ConvNextBlock代替ResNet
这里重点分析一下用最新的ConvNextBlock代替ResBlock效果
整体U-net结构
基于ResBlock的U-net版本
之前博客多次介绍了Unet结构,用一张图表示:
代码如下:
1 | class UNet(nn.Module): |
基于ConvNextBlock的U-net版本
代码如下:
1 | class Unet(nn.Module): |
原先ResNet结构
作为Unet结构基础组成部分,Resnet结构负责引入时间 t 信号,并且attention机制给多模态带来可能,结构图如下:
这里再复习一下前面的代码:
1 | class ResBlock(nn.Module): |
ConvNextBlock改进
整体看起来结构和ResNet非常像,结构图如下:
代码如下:
1 | class ConvNextBlock(nn.Module): |
网上有研究工作者给出了二者的效果对比。
ResNet训练的结果:
替换ConvBlock后的效果:
直观来看确实ConvBlock效果要好一点。
T-embedding结构
我们要对 t 做embedding操作,可以使用不同的激活函数:swish() 或 nn.GELU() ,swish在pytorch中没有现成的,我们可以用 x*sigmoid(x) 代替。
一般 embedding 的 out_dim 是 in_dim 的四倍,结构如下:
swish()版本代码:
1 | class Swish(nn.Module): |
GELU()版本代码:
1 | time_dim = dim |
论文和代码:
代码:https://github.com/lucidrains/denoising-diffusion-pytorch
论文:https://arxiv.org/abs/2006.11239
- 本文作者: 李宝璐
- 本文链接: https://libaolu312.github.io/2023/11/13/DDPM代码详细解读-3-图解模型各部分结构、用ConvNextBlock代替ResNet/
- 版权声明: 本博客所有文章除特别声明外,均采用 MIT 许可协议。转载请注明出处!