0%

torch | 查看网络的结构层以及相关参数等

如果我们要查看一个网络是否符合我们预期的时候,比如

我们想看 conv 之后尺寸变化等。

我们举一个简单的例子,一个 DNN

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class DNN(nn.Module):
def __init__(self):
super(DNN, self).__init__()
self.l1 = nn.Linear(1000, 2000)
self.l2 = nn.Linear(2000, 3000)
self.l3 = nn.Linear(3000, 2000)
self.l4 = nn.Linear(2000, 1000)

def forward(self, x):
x = x.view(-1, 1000)
x = F.relu(self.l1(x))
x = F.relu(self.l2(x))
x = F.relu(self.l3(x))

return self.l4(x).reshape(x.shape[0], 1, 1, 1000)


if __name__ == '__main__':
model = DNN()
s = torch.from_numpy(np.random.uniform(1, 5, 5000).reshape((5, 1, 1, 1000)))
test = model.forward(torch.tensor(s, dtype=torch.float))
print(test)

看上面的的代码,我们可以初始化,然后使用 forward

然后看什么 shape 就可以在 相应的语句上打断点查看。

请我喝杯咖啡吧~