0%

torch | 广播机制

广播机制让 torch 处理不同维度的信息。

如果满足以下规则,则两个张量是“可广播的”:

  • 每个张量具有至少一个维度。
  • 从尾随尺寸开始迭代尺寸时,尺寸要么相等,要么其中之一为1,或者不存在其中之一

如果张量x和y可以符合广播的条件,那么:结果张量可以按照下面的方式计算:

  • 如果x和y的维度不相同,用1来扩张维度少的那个,使两个张量维度一致
  • 对于每个维度,结果维度是x,y对应维度的最大值
1
2
3
a = torch.ones(5, 1, 4, 1)
b = torch.ones(3, 1, 1)
print((a + b).shape)

输出

torch.Size([5, 3, 4, 1])
1
2
3
4
a = torch.ones(5, 1, 4, 1)
b = torch.from_numpy(np.array([1, 2, 5]).reshape((3, 1, 1)))
print(torch.mul(a, b).shape)
print(torch.mul(a, b))

输出

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
torch.Size([5, 3, 4, 1])
tensor([[[[1.],
[1.],
[1.],
[1.]],
[[2.],
[2.],
[2.],
[2.]],
[[5.],
[5.],
[5.],
[5.]]],

[[[1.],
[1.],
[1.],
[1.]],
[[2.],
[2.],
[2.],
[2.]],
[[5.],
[5.],
[5.],
[5.]]],

[[[1.],
[1.],
[1.],
[1.]],
[[2.],
[2.],
[2.],
[2.]],
[[5.],
[5.],
[5.],
[5.]]],

[[[1.],
[1.],
[1.],
[1.]],
[[2.],
[2.],
[2.],
[2.]],
[[5.],
[5.],
[5.],
[5.]]],

[[[1.],
[1.],
[1.],
[1.]],
[[2.],
[2.],
[2.],
[2.]],
[[5.],
[5.],
[5.],
[5.]]]])
请我喝杯咖啡吧~