0%

torch | 认识 UNet

从最简单的代码入手,认识 UNet


环境


  • python3.6
  • torch 1.3.1

简单的 UNet 网络


先看一下代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class ConvAutoEncoder(nn.Module):
def __init__(self):
super(ConvAutoEncoder, self).__init__()

# Zero padding is almost the same as average padding in this case
# Input = b, 1, 4, 300
self.encoder = nn.Sequential(
nn.Conv2d(1, 8, (4, 7), stride=1, padding=(0, 3)), # b, 8, 1, 300
nn.Tanh(),
nn.MaxPool2d((1, 2), stride=2), # b, 8, 1, 150
nn.Conv2d(8, 4, 3, stride=1, padding=1), # b, 4, 1, 150
nn.Tanh(),
nn.MaxPool2d((1, 2), stride=2) # b, 4, 1, 75
)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(4, 8, 3, stride=2, padding=1, output_padding=(0, 1)), # b, 8, 1, 150
nn.Tanh(),
nn.ConvTranspose2d(8, 8, 3, stride=2, padding=1, output_padding=(0, 1)), # b, 8, 1, 300
nn.Tanh(),
nn.ConvTranspose2d(8, 1, 3, stride=1, padding=1), # b, 1, 1, 300
)

def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x

UNet 有两个部分,分别是 encoderdecoder。关于 encoder,你可以参考我下面的博文。

下面重点说说 decoder

首先输入到 decoder 的数据维度是 N * 4 * 1 * 75。在第一个逆卷积中,nn.ConvTranspose2d(4, 8, 3, stride=2, padding=1, output_padding=(0, 1)),stride = 2 进行扩充,变成 N * 4 * 1 * 149,接着,进行卷积核为 3 * 3 的逆操作,即变成 149 + 3 - 1 = 151,数据变成 N * 8 * 3 * 151,接着进行 padding = 1 的向内压缩,变成 N * 8 * 1 * 149,又因为 output_padding=(0, 1) 这个向得到结果的右侧添加了一列,变成 N * 8 * 1 * 150。接着进入了第二个逆卷积操作 nn.ConvTranspose2d(8, 8, 3, stride=2, padding=1, output_padding=(0, 1)),同理可以得到数据为 N * 8 * 1 * 300 ,接着到了第三层逆卷积,数据变成了 N * 1 * 1 * 300

关于上面的计算过程,你可以参考我下面的博文。


函数说明


nn.ConvTranspose2d

torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode=’zeros’)

关于这个的用法,我推荐你看我下面的博文。

Parameters

  • in_channels (int) – 输入信号的通道数
  • out_channels (int) – 卷积产生的通道数
  • kernel_size (int or tuple) – 卷积核的大小
  • stride (int or tuple, optional) – 卷积步长,即要将输入扩大的倍数. Default: 1
  • padding (int or tuple, optional) – 向内压缩. Default: 0
  • output_padding (int or tuple, optional) – Additional size added to one side of each dimension in the output shape. Default: 0
  • groups (int, optional) – Number of blocked connections from input channels to output channels. Default: 1
  • bias (bool, optional) – If True, adds a learnable bias to the output. Default: True
  • dilation (int or tuple, optional) – Spacing between kernel elements. Default: 1
请我喝杯咖啡吧~