如何在PyTorch中以特定的新维度重复Tensor

xlpyo6sf  于 2022-11-09  发布在  其他
关注(0)|答案(5)|浏览(530)

如果我有一个形状为[M, N]的TensorA,我想重复TensorK次,这样结果B的形状为[M, K, N],并且每个切片B[:, k, :]应该与A具有相同的数据。这是没有for循环的最佳实践。K可能在其他维度上。
torch.repeat_interleave()tensor.repeat()似乎不起作用。或者我用错了方法。

taor4pac

taor4pac1#

tensor.repeat应该可以满足您的需要,但是您需要先插入一个单位维。为此,我们可以使用tensor.unsqueezetensor.reshape。由于unsqueeze是专门定义为插入一个单位维的,因此我们将使用它。

B = A.unsqueeze(1).repeat(1, K, 1)
  • 代码描述 * A.unsqueeze(1)A[M, N]转换为[M, 1, N],并且.repeat(1, K, 1)沿着第二维度重复TensorK多次。
zy1mlcev

zy1mlcev2#

Einops提供重复功能

import einops
einops.repeat(x, 'm n -> m k n', k=K)

repeat可以按任意顺序添加任意数量的轴,并同时对现有轴进行重新排序。

gab6jxml

gab6jxml3#

添加到由@Alleo提供的答案。您可以使用以下Einops函数。

einops.repeat(example_tensor, 'b h w -> (repeat b) h w', repeat=b)

其中,b是Tensor重复的次数,hw是Tensor的附加维数。
示例-

example_tensor.shape -> torch.Size([1, 40, 50]) 
repeated_tensor = einops.repeat(example_tensor, 'b h w -> (repeat b) h w', repeat=8)
repeated_tensor.shape -> torch.Size([8, 40, 50])

此处提供更多示例-https://einops.rocks/api/repeat/

92vpleto

92vpleto4#

重复的值占用大量内存,在大多数情况下,最好的做法是使用广播。因此,您可以使用A[:, None, :],这将使A.shape==(M, 1, N)
我同意重复这些值的一个例子是在下面的步骤中的原地操作。由于numpy和torch在实现上的不同,我喜欢不可知的(A * torch.ones(K, 1, 1)))后面跟着一个转置。

cetgtptt

cetgtptt5#

tensor.expand可能是比tensor.repeat更好的选择,因为根据以下公式:* 扩展Tensor不会分配新内存,而只是在现有Tensor上创建新视图,其中通过将步幅设置为0,将大小为1的维度扩展为更大的大小。*
但是,请注意:* 扩展Tensor得多个元素可能引用单个内存位置.因此,就地操作(尤其是矢量化得操作)可能导致不正确得行为.如果需要写入Tensor,请先克隆它们.*

M = N = K = 3
A = torch.arange(0, M * N).reshape((M, N))
B = A.unsqueeze(1).expand(M, K, N)
B

'''
tensor([[[0, 1, 2],
         [0, 1, 2],
         [0, 1, 2]],

        [[3, 4, 5],
         [3, 4, 5],
         [3, 4, 5]],

        [[6, 7, 8],
         [6, 7, 8],
         [6, 7, 8]]])
'''

相关问题