Pytorch等价于Pandas groupby.apply(list)的是什么?

cgvd09ve  于 2022-11-27  发布在  其他
关注(0)|答案(1)|浏览(138)

我有下面的pytorchTensorlong_format

tensor([[ 1.,  1.],
        [ 1.,  2.],
        [ 1.,  3.],
        [ 1.,  4.],
        [ 0.,  5.],
        [ 0.,  6.],
        [ 0.,  7.],
        [ 1.,  8.],
        [ 0.,  9.],
        [ 0., 10.]])

我想将第一列归为groupby,并将第二列存储为Tensor。不保证每个分组的结果大小相同。请参见下面的示例。

[tensor([ 1., 2., 3., 4., 8.]),
 tensor([ 5.,  6., 7., 9., 10.])]

有没有什么好的方法可以使用纯Pytorch操作符来完成这个任务?我希望避免使用for循环来实现可跟踪性。
我尝试过使用for循环和空Tensor的空列表,但这会导致不正确的跟踪(不同的输入值给出相同的结果)

n_groups = 2
inverted = [torch.empty([0]) for _ in range(n_groups)]
for index, value in long_format:
   value = value.unsqueeze(dim=0)
   index = index.int()
   if type(inverted[index]) != torch.Tensor:
      inverted[index] = value
   else:
      inverted[index] = torch.cat((inverted[index], value))
trnvg8h3

trnvg8h31#

您可以使用以下代码:

import torch
x = torch.tensor([[ 1.,  1.],
        [ 1.,  2.],
        [ 1.,  3.],
        [ 1.,  4.],
        [ 0.,  5.],
        [ 0.,  6.],
        [ 0.,  7.],
        [ 1.,  8.],
        [ 0.,  9.],
        [ 0., 10.]])

result =  [x[x[:,0]==i][:,1] for i in x[:,0].unique()]

输出

[tensor([ 5.,  6.,  7.,  9., 10.]), tensor([1., 2., 3., 4., 8.])]

相关问题