大家好,欢迎来到IT知识分享网。
PyTorch中的两个张量的乘法可以分为两种:
- 两个张量对应的元素相乘(element-wise),在PyTorch中可以通过
torch.mul函数(或者 ∗ * ∗运算符)实现 - 两个张量矩阵相乘(Matrix product),在PyTorch中可以通过
torch.matmul函数实现
本文主要介绍两个张量的矩阵相乘。
语法为:
torch.matmul(input, other, out = None)
函数对input和other两个张量进行矩阵相乘。为了方便后续的讲解,将input记为a,将other记为b。
- 若a为1D张量,b为1D张量,则返回两个张量的点积,则返回两个张量的点积(此时的torch.matmul不支持out参数)
举例如下:
import torch a = torch.tensor([1, 2]) b = torch.tensor([3, 4]) result = torch.matmul(a, b) print(result)
结果为:
(PyTorch) D:\Code Project>D:/Anaconda/envs/PyTorch/python.exe "d:/Code Project/demo.py" tensor(11)
- 若a为2D张量,b为2D张量,则返回两个张量的矩阵乘积。
举例为:
import torch a = torch.tensor([[1, 2],[3,4]]) b = torch.tensor([[5,6,7],[8,9,10]]) result = torch.matmul(a, b) print(result)
结果展示为:
(PyTorch) D:\Code Project>D:/Anaconda/envs/PyTorch/python.exe "d:/Code Project/demo.py" tensor([[21, 24, 27], [47, 54, 61]])
- 若a为1D张量,b为2D张量,torch.matmul函数:
首先,在1D张量a的前面插入一个长度为1的新维度变成2D张量;
然后,在满足第一个2D张量(矩阵)的列数(column)和第二个2D张量(矩阵)的行数(row)相同的条件下,两个2D张量矩阵乘积,否则会抛出错误;
最后,将矩阵乘积结果中长度为1的维度(前面插入的长度为1的新维度)删除作为最终torch.matmul函数返回的结果。
import torch a = torch.tensor([1, 2]) b = torch.tensor([[5, 6, 7],[8, 9, 10]]) result = torch.matmul(a, b) print(result, result.shape)
结果为:
(PyTorch) D:\Code Project>D:/Anaconda/envs/PyTorch/python.exe "d:/Code Project/demo.py" tensor([21, 24, 27]) torch.Size([3])
- 若a为2D张量,b为1D张量,torch.matmul函数:
首先,在1D张量b的后面插入一个长度为1的新维度变成2D张量;
然后,在满足第一个2D张量(矩阵)的列数(column)和第二个2D张量(矩阵)的行数(row)相同的条件下,两个2D张量矩阵乘积,否则会抛出错误;
最后,将矩阵乘积结果中长度为1的维度(后面插入的长度为1的新维度)删除作为最终torch.matmul函数返回的结果;
import torch b = torch.tensor([1, 2, 3]) a = torch.tensor([[5, 6, 7],[8, 9, 10]]) result = torch.matmul(a, b) print(result, result.shape)
结果展示为:
(PyTorch) D:\Code Project>D:/Anaconda/envs/PyTorch/python.exe "d:/Code Project/demo.py" tensor([38, 56]) torch.Size([2])
其中:
38 = 15+26+3*7
56 = 18+29+3*10
免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://haidsoft.com/111314.html