Pytorch的“折叠”和“展开”是如何工作的?

pwuypxnk  于 2023-01-26  发布在  其他
关注(0)|答案(4)|浏览(196)

我已经看过了official doc。我很难理解这个函数是用来做什么的,它是如何工作的。有人能用外行的话解释一下吗?

5lwkijsr

5lwkijsr1#

unfold将Tensor想象为具有重复的值的列/行的较长Tensor,这些值"折叠"在彼此的顶部,然后"展开":

  • size确定折叠的大小
  • step确定折叠的频率

例如,对于2x5Tensor,用step=1展开它,并在dim=1上修补size=2

x = torch.tensor([[1,2,3,4,5],
                  [6,7,8,9,10]])
>>> x.unfold(1,2,1)
tensor([[[ 1,  2], [ 2,  3], [ 3,  4], [ 4,  5]],
        [[ 6,  7], [ 7,  8], [ 8,  9], [ 9, 10]]])

fold大致与此操作相反,但在输出中对"重叠"值求和。

s5a0g9ez

s5a0g9ez2#

unfoldfold用于简化“滑动窗口”操作(如卷积)。假设您要将函数foo应用于特征Map/图像中的每个5x5窗口:

from torch.nn import functional as f
windows = f.unfold(x, kernel_size=5)

现在windows有了批处理的size-(55x.size(1))-num_windows,您可以在windows上应用foo

processed = foo(windows)

现在您需要将processed“折叠”回x的原始大小:

out = f.fold(processed, x.shape[-2:], kernel_size=5)

您需要注意paddingkernel_size,它们可能会影响您将processed“折叠”回x大小的能力。此外,fold * 对重叠元素求和 *,因此您可能需要将fold的输出除以补丁大小。
请注意,torch.unfold执行的操作与nn.Unfold不同。有关详细信息,请参见this thread

6gpjuf90

6gpjuf903#

一维展开很容易:

x = torch.arange(1, 9).float()
print(x)
# dimension, size, step
print(x.unfold(0, 2, 1))
print(x.unfold(0, 3, 2))

输出:

tensor([1., 2., 3., 4., 5., 6., 7., 8.])
tensor([[1., 2.],
        [2., 3.],
        [3., 4.],
        [4., 5.],
        [5., 6.],
        [6., 7.],
        [7., 8.]])
tensor([[1., 2., 3.],
        [3., 4., 5.],
        [5., 6., 7.]])

二维展开(也称为 * 修补 *)

一个二个一个一个

mcdcgff0

mcdcgff04#

由于4-DTensor没有答案,并且nn.functional. unflown()只接受4-DTensor,我将对此进行解释。
假设输入Tensor的形状为(batch_size, channels, height, width),我举了一个例子,其中batch_size = 1, channels = 2, height = 3, width = 3

kernel_size = 2,它只是一个2x2内核

相关问题