What is the difference between PyTorch's torch.stack() and torch.cat() functions?

ih99xse1 · asked on 2022-11-29 · in: Other
Follow (0) | Answers (4) | Views (267)

The OpenAI REINFORCE and actor-critic examples for reinforcement learning contain the following code:
REINFORCE

policy_loss = torch.cat(policy_loss).sum()

actor-critic

loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()

One uses torch.cat and the other uses torch.stack for a similar use case.

As far as I understand, the docs don't draw any clear distinction between them.

I would be happy to understand the difference between these functions.


ghhkc1vu1#

stack
Concatenates a sequence of tensors along a new dimension.
cat
Concatenates the given sequence of tensors in the given dimension.
So if A and B both have shape (3, 4):

  • torch.cat([A, B], dim=0) will have shape (6, 4)
  • torch.stack([A, B], dim=0) will have shape (2, 3, 4)
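
A minimal sketch verifying those shapes:

import torch

A = torch.randn(3, 4)
B = torch.randn(3, 4)

print(torch.cat([A, B], dim=0).shape)    # torch.Size([6, 4]): the existing dim grows
print(torch.stack([A, B], dim=0).shape)  # torch.Size([2, 3, 4]): a new dim is added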

a14dhokn2#

t1 = torch.tensor([[1, 2],
                   [3, 4]])

t2 = torch.tensor([[5, 6],
                   [7, 8]])

| torch.stack | torch.cat |
| --- | --- |
| '**Stack**s' a sequence of tensors along a new dimension: | 'Con**cat**enates' a sequence of tensors along an existing dimension: |

These functions are analogous to numpy.stack and numpy.concatenate.
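
For t1 and t2 above, a quick sketch of what each call returns (default dim=0):

torch.stack((t1, t2))
# tensor([[[1, 2],
#          [3, 4]],
#
#         [[5, 6],
#          [7, 8]]])   shape (2, 2, 2): a new leading dimension

torch.cat((t1, t2))
# tensor([[1, 2],
#         [3, 4],
#         [5, 6],
#         [7, 8]])     shape (4, 2): dim 0 is extended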


h6my8fg23#

The earlier answers lack a good self-contained example, so here is one:

import torch

# stack vs cat

# cat "extends" a list in the given dimension e.g. adds more rows or columns

x = torch.randn(2, 3)
print(f'{x.size()}')

# add more rows (thus increasing the dimensionality of the column space from 2 to 6)
xnew_from_cat = torch.cat((x, x, x), 0)
print(f'{xnew_from_cat.size()}')

# add more columns (thus increasing the dimensionality of the row space from 3 to 9)
xnew_from_cat = torch.cat((x, x, x), 1)
print(f'{xnew_from_cat.size()}')

print()

# stack serves the same role as append in lists, i.e. it doesn't change the original
# vector space but instead adds a new index to the new tensor, so you retain the ability
# to get the original tensor you added to the list by indexing along the new dimension
xnew_from_stack = torch.stack((x, x, x, x), 0)
print(f'{xnew_from_stack.size()}')

xnew_from_stack = torch.stack((x, x, x, x), 1)
print(f'{xnew_from_stack.size()}')

xnew_from_stack = torch.stack((x, x, x, x), 2)
print(f'{xnew_from_stack.size()}')

# by default, stack adds the new dimension at the front (dim=0)
xnew_from_stack = torch.stack((x, x, x, x))
print(f'{xnew_from_stack.size()}')

print('I like to think of xnew_from_stack as a "tensor list" that you can pop from the front')

Output:

torch.Size([2, 3])
torch.Size([6, 3])
torch.Size([2, 9])
torch.Size([4, 2, 3])
torch.Size([2, 4, 3])
torch.Size([2, 3, 4])
torch.Size([4, 2, 3])
I like to think of xnew_from_stack as a "tensor list"

For reference, here are the definitions:
cat: Concatenates the given sequence of tensors in the given dimension. The consequence is that a specific dimension changes size, e.g. with dim=0 you are adding elements to the rows, which increases the dimensionality of the column space.
stack: Concatenates a sequence of tensors along a new dimension. I like to think of this as the torch "append" operation, since you can index/retrieve your original tensor by "popping" it from the front. With no dim argument, it stacks tensors along a new front dimension.
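
A tiny sketch of that append/pop intuition:

import torch

x = torch.randn(2, 3)
stacked = torch.stack((x, x, x))   # new front dimension: shape (3, 2, 3)
assert torch.equal(stacked[0], x)  # indexing the new dim "pops" an original back out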
Related:

  • Here is a link to the PyTorch forum with a discussion on this: I wished torch.tensor could convert a nested list of tensors into one big tensor with many dimensions, respecting the depth of the nested list.

Update: using nested lists of the same size

def tensorify(lst):
    """
    List must be a nested list of tensors (with no varying lengths within a dimension).
    A nested list with dimensions [D1, D2, ..., DN] -> tensor of size [D1, D2, ..., DN].

    :return: tensor of size [D1, D2, ..., DN]
    """
    # base case: if the current list is not nested any more, turn it into a tensor
    if not isinstance(lst[0], list):
        if isinstance(lst, torch.Tensor):
            return lst
        elif isinstance(lst[0], torch.Tensor):
            return torch.stack(lst, dim=0)
        else:  # the elements of lst are floats or something similar
            return torch.tensor(lst)
    current_dimension_i = len(lst)
    for d_i in range(current_dimension_i):
        tensor = tensorify(lst[d_i])
        lst[d_i] = tensor
    # after the loop, every lst[d_i] has been converted to a tensor
    tensor_lst = torch.stack(lst, dim=0)
    return tensor_lst

Here are some unit tests (I didn't write more tests, but it works with my real code, so I trust it's fine. Feel free to help me by adding more tests):

def test_tensorify():
    t = [1, 2, 3]
    print(tensorify(t).size())
    tt = [t, t, t]
    print(tensorify(tt))
    ttt = [tt, tt, tt]
    print(tensorify(ttt))

if __name__ == '__main__':
    test_tensorify()
    print('Done\a')

4si2a6ki4#

In case anyone is looking into the performance side of it, I ran a small experiment. In my case, I needed to convert a list of scalar tensors into a single tensor.

import torch
torch.__version__ # 1.10.2
x = [torch.randn(1) for _ in range(10000)]
torch.cat(x).shape, torch.stack(x).shape # torch.Size([10000]), torch.Size([10000, 1])

%timeit torch.cat(x) # 1.5 ms ± 476 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit torch.cat(x).reshape(-1,1) # 1.95 ms ± 534 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit torch.stack(x) # 5.36 ms ± 643 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

My conclusion is that even if you want the extra dimension that torch.stack provides, using torch.cat followed by reshape is faster.
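
A quick sanity check that the two routes produce the same tensor for these 1-element inputs:

import torch

x = [torch.randn(1) for _ in range(10000)]
a = torch.stack(x)               # shape (10000, 1)
b = torch.cat(x).reshape(-1, 1)  # same values, built the faster way
assert torch.equal(a, b)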

**Note:** this post is taken from the PyTorch forum (I am the author of the original post).
