很多框架中提供的矩阵乘法都是出于简化计算的考虑,很多情况下在进行计算时候都会牵扯到 batch size 这一个维度,这就使得很多矩阵的计算是三维的,Pytorch中的bmm()函数就可以很方便的实现三维数组的乘法,而不用拆成二维数组使用for循环解决。在查资料的时候发现有些博客写的有些小地方不太对,而且有很多提问都是关于 bmm()函数具体是如何计算的,因此记录。
文章目录
1.torch.bmm()
函数定义:
def bmm(self: Tensor,
mat2: Tensor,
*,
out: Optional[Tensor] = None) -> Tensor
函数的传入参数很简单,两个三维矩阵而已,只是要注意这两个矩阵的shape有一些要求:
res = torch.bmm(ma, mb)
ma: [a, b, c]
mb: [a, c, d]
也就是说两个tensor的第一维是相等的,然后第一个数组的第三维和第二个数组的第二维度要求一样,对于剩下的则不做要求,其实这里的意思已经很明白了,两个三维矩阵的乘法其实就是保持第一维度不变,每次相当于一个切片做二维矩阵的乘法,对于上面的矩阵来说,就是 for i in range(a)
然后 ma[i] * mb[i]
,这是一个熟悉的二维矩阵乘法,两个矩阵的shape分别是[b, c]
和[c, d]
。因此,输出的结果的shape也很明显了:[a, b, d]
。下面验证一下:
2.验证
首先创建两个tensor:
a = torch.linspace(1, 24, 24).view(2, 3, 4) # shape [2, 3, 4]
b = torch.linspace(1, 16, 16).view(2, 4, 2) # shape [2, 4, 2]
两个tensor分别是:
tensor([[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]],
[[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.]]])
tensor([[[ 1., 2.],
[ 3., 4.],
[ 5., 6.],
[ 7., 8.]],
[[ 9., 10.],
[11., 12.],
[13., 14.],
[15., 16.]]])
接下来分别使用bmm函数和for循环方式实现乘法:
c = torch.bmm(a, b)
print(c)
d = np.array([torch.mm(a[i], b[i]).numpy() for i in range(len(a))])
print(d)
输出分别是:
tensor([[[ 50., 60.],
[ 114., 140.],
[ 178., 220.]],
[[ 706., 764.],
[ 898., 972.],
[1090., 1180.]]])
[[[ 50. 60.]
[ 114. 140.]
[ 178. 220.]]
[[ 706. 764.]
[ 898. 972.]
[1090. 1180.]]]
也可以使用函数检查一下:
print((d == c.numpy()).all())
输出:True
3.更实际一点的想法
就像刚才所说的那样,只要根据实际的情况考虑一下,这个函数的计算过程很好理解,由于 batch size的引入,所以处理数据的时候很容易出现三维数组,例如处理文本计算attention权重的时候,很容易得到的权重矩阵shape是 [batch_size, sequence_length]
,然后需要相乘的隐状态矩阵是 [batch_size, sequence_length, hidden_size]
。按照attention的计算方式,实际上就是权重矩阵中每一行的数值分别乘以隐状态矩阵中每一行的对应位置的隐状态,这个过程当然可以写循环,也可以简单的使用bmm函数计算,先将权重矩阵reshape成 [batch_size, 1, sequence_length]
然后bmm(weigths_matrix, hidden_matrix)
然后得到的结果就是attention计算的结果了。
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
文章由极客之音整理,本文链接:https://www.bmabk.com/index.php/post/116682.html