大家好,欢迎来到IT知识分享网。
本文将系统讲解 PyTorch 中张量的乘法运算(二维与高维)和广播机制(Broadcasting)
1 二维张量运算(矩阵乘法)
在 PyTorch 中常见的二维张量乘法相当于矩阵乘法。例如:
运算规则
X.shape = [4, 3] Y.T.shape = [3, 4](注意需要转置) 输出形状为 [4, 4],即结果矩阵的大小为 X的行数 × Y的行数
示例
计算 X @ Y.T
[1,3,5] × [6,3,2] = 1×6 + 3×3 + 5×2 = 6 + 9 + 10 = 25
2 高维张量运算(Batch 矩阵乘法)
在实际神经网络中,我们往往需要对多个矩阵进行批量乘法,例如:
原理解释
等价于:
每组进行 [3,2] × [2,4] 的矩阵乘法,最终得到 3 个 [3,4] 的矩阵,结果为 [3, 3, 4]
3 广播机制(Broadcasting)
什么是广播?
广播机制是一种自动扩展张量维度以匹配运算的技术,无需复制数据,而是通过广播规则实现隐式对齐。
广播规则:
- 从右向左对齐维度
- 两个维度相等,或其中一个为 1,才允许广播
- 否则报错
示例:
- a.shape = [3, 3, 2]
- b.shape = [2, 4] → 自动变成 [1, 2, 4] → 广播成 [3, 2, 4]
- 执行 [3,3,2] @ [3,2,4] = [3,3,4]
实战中的广播使用场景
加偏置(bias)
权重共享
自定义 loss 中对齐
Attention mask 运算
广播机制的意义
优势 |
描述 |
简洁代码 |
减少 .unsqueeze()、.expand() 的冗余操作 |
高效计算 |
无需实际复制,内部使用视图节省内存 |
批量处理 |
便于神经网络批量计算的构建 |
数学一致性 |
保证标量、向量、矩阵混合运算自然成立 |
免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://haidsoft.com/179962.html