pytorch 不同形状的Tensor是如何组合的?

xurqigkl  于 2023-10-20  发布在  其他
关注(0)|答案(2)|浏览(138)

多维Tensor在什么逻辑中(例如,在Pytorch中)组合?假设TensorA的形状为(64,1,1,42),TensorB的形状为(1,42,42)。A & B的结果是什么?我如何提前确定结果形状(如果可能的话)?

r6vfmomb

r6vfmomb1#

PyTorchTensor根据PyTorch's broacasting semantics进行组合,这在很大程度上是从NumPy's broadcasting rules借用的。
PyTorch将这些规则总结如下:
如果以下规则成立,则两个Tensor是“可广播的”:

  • 每个Tensor至少有一个维度。
  • 从尾随维度开始循环访问维度大小时,维度大小必须相等、其中一个为1或其中一个不存在。

如果两个Tensorx,y是“可广播的”,则结果Tensor大小计算如下:

  • 如果x和y的维数不相等,则在维数较少的Tensor的维数前加上1,使它们的长度相等。
  • 然后,对于每个维度大小,结果维度大小是沿该维度的x和y沿着的大小的最大值。

在形状(64,1,1,42)和(1,42,42)的情况下,广播通过对齐它们的尾部尺寸来进行:

64,  1,   1,  42
     1,  42,  42

前置1维以使长度匹配:

64,  1,   1,  42
1    1,  42,  42

然后取其中一个维度为1的最大维度:

64,  1,  42,  42
p8h8hvxi

p8h8hvxi2#

不同形状的Tensor的组合通常是通过一种称为广播的机制来完成的。广播是NumPy和PyTorch等库中的一个强大工具,它有助于扩展Tensor的大小,而无需实际复制数据,从而在不同形状的Tensor之间执行操作。NumPy和PyTorch之间的广播规则基本一致。
广播规则:如果Tensor的维数不匹配,则在维数较少的Tensor的维数前添加1。
比较每个维度的大小:
如果大小匹配或其中一个大小为1,则该维度可以进行广播。如果两个大小都不为1,并且大小不匹配,则广播失败。在成功广播之后,每个Tensor的行为就好像它的形状是两个输入Tensor的形状的元素最大值。
在任何维度中,一个Tensor的大小为1,另一个Tensor的大小大于1,较小的Tensor表现得好像它已经扩展到匹配较大Tensor的大小。举例来说:
给出:
形状(64,1,1,42)的TensorA形状(1,42,42)的TensorB
要合并,请执行以下操作:
使维数相等:
TensorA:(64,1,1,42)TensorB:(1,1,42,42)比较尺寸:
维数是(64,1,1,42)和(1,1,42,42)每个维数要么相同,要么其中一个是1,所以广播是可能的。生成的Tensor将具有以下形状:(64,1,42,42)

import torch

# Creating example tensors
A = torch.rand((64, 1, 1, 42))
B = torch.rand((1, 42, 42))

# Broadcasting operation
result = A * B

# Outputting the resulting shape
print(result.shape)

这导致

torch.Size([64, 1, 42, 42])

果然
编辑:要解决这个答案的第一个评论:
谢谢!所以每个维度必须是一个或两个Tensor的维度必须匹配。如果前面提到的检查失败但重新排列会有帮助,是否也默认进行重新排列?
正确,在PyTorch中使用广播组合两个Tensor。两个Tensor的每个维度要么是1,要么它们匹配。PyTorch从尾部维度(最后一个)开始比较维度,并向后工作。
如果不满足广播条件,则会出现运行时错误。
但是,如果广播条件失败,您可以重新排列或重新塑造Tensor以使它们兼容:

添加单例维度:可以使用unsqueeze方法或numpy样式的None索引。例如,如果你有一个形状为(5,4)的Tensor,你希望它有一个形状(5,1,4),你可以添加一个单例维度。

tensor = tensor.unsqueeze(1) 
# or
tensor = tensor[:, None, :]

移除单例维度:您可以使用挤压方法来移除单一维度。例如,将Tensor从形状(5,1,4)更改为(5,4)。

tensor = tensor.squeeze(1)

**Reshape:**如果添加或删除单例维度还不够,您可能需要使用reshape方法完全重塑Tensor。

tensor = tensor.reshape(new_shape)

**置换维度:**如果是维度的顺序问题,可以使用置换方法重新排列。

tensor = tensor.permute(dimension_order)

在重新排列或重塑Tensor时,了解数据的语义至关重要。通过改变形状,您可能会改变Tensor的含义或其值之间的关系。始终确保这些操作在您的问题的上下文中是有意义的。

相关问题