张量广播
张量广播
原本打算把常见运算类型一起写进来,但是似乎量太大,就分开了。
张量广播
广播并不只是发生在不同形状的张量之间, 相同形状的张量进行运算,张量每个元素相应发生改变,这也是广播的一种形式。 但我们主要专注于,不同形状的张量之间的广播。
不同的形状的张量之间的广播计算是怎么做到的?
实际上,并不是说直接低维度向高维度计算,而是通常会有一些低维度的 向量先隐式地转化到高维度,然后再进行相同维度的逐点计算。
广播的规则,其实也是转化的规则,什么时候发生改变,发生什么样的改变。
标量和任意形状的张量广播
import torch
a = torch.zeros(3,4)
a+1
# --
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
具体是便利了每个元素,然后加上了 1。还是先将 1 转化为 3x4 的张量,然后再逐点相加。我偏向后者,因为这样明显更快。
不同形状或者维度的张量广播
先来一个约定,我们之后会用维度
来说明Tensor.shape
的长度,比如,Tensor.shape=(3,4,5)
,那么它就是三维了,用形状代表每个维度的具体值,比如上述的形状分别是3,4,5
。
形状 1 可以向任何形状广播
比如Tensor.shape=(1,3,4)
,实际上它可以直接用view
Resize 成为Tensor.shape=(3,4)
,并且不会丢失任何元素。 1 实际上可以被看成是标量,标量可以广播到任何形状,1 也可以广播到任何形状。
示例:
import torch
# success
a = torch.ones(1,3,4,5)
b = torch.ones(2,3,4,5)
print(a+b)
# success
a = torch.ones(2,1,4,5)
b = torch.ones(2,3,4,5)
print(a+b)
# success
a = torch.ones(1,1,4,1)
b = torch.ones(2,3,4,5)
print(a+b)
结果都相同:
tensor([[[[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.]],
[[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.]],
[[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.]]],
[[[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.]],
[[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.]],
[[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.],
[2., 2., 2., 2., 2.]]]])
在这里我们也可以推导出来,广播实际上在隐式转化的过程,做的实际上就是一种复制粘贴的操作,把 1 处的值复制拓展到任意数字处,有点像复印然后堆叠起来,关键的是堆叠的每一张都是最底下那张的复印件。
1 可以被视作一张万能卡,可以被当作任意形状进行匹配。
虽然咱们这个结论说起来轻巧,但是如果要实现这样一个功能,那实际上是一个噩梦。
相同形状,不同维度进行广播
示例:
import torch
# success
a = torch.ones(4,5)
b = torch.ones(2,3,4,5)
print(a+b)
# fail
a = torch.ones(2,3)
b = torch.ones(2,3,4,5)
print(a+b)
# fail
a = torch.ones(3,4)
b = torch.ones(2,3,4,5)
print(a+b)
我以为应该都可以广播的,但是实际上只有第一个可以广播,后两个都不行。
我尝试把这个问题拿去问 gpt,结果它的回答是一塌糊涂。
不过,从右往左的匹配方式倒是对的。我们可以这么考虑,第一步,它总是在张量的左侧unsqueeze
填充形状为1
的维度到相同维度,然后再尝试进行广播。
所以填充后分别是1,1,4,5
,1,1,2,3
,1,1,2,3
,所以只有第一个可以广播。
为什么不支持自由广播呢?
这是个哲学问题,我以前不主动学广播就是因为我以为它是自由广播,一想到就头痛,特别是以为case 3
也可以广播,但是实际上这不好,因为这样会导致很多不确定性,也会让广播的规则变得非常的冗余。不过好在,广播的规则还是很简单的。
总结:
- 如果维度相同,进行一一匹配,如果碰到形状为
1
,这个1
可以作为一个通配符。 - 如果维度不同,先
unsqueeze
填充形状为1
的维度到相同维度,然后再进行匹配。 - 如果形状在维度上不匹配,不会自动用
view
进行 resize 或者交换维度,而是直接抛出错误。
没有更多了哦,大道至简嘛。
不过你别说,数学越学越上头,爽。
番外-嘛,你知道什么是 torch.cat 吗?
给你看个例子:
a = torch.ones(2,4,4,5)
b = torch.ones(2,3,4,5)
torch.cat([a,b],dim=1)
以及和 gpt 的对话:torch.cat
你觉得torch.cat
是啥?
当然是torch
里的一只猫猫啦~。