pytorch中的groupby聚集平均值

5uzkadbs  于 2022-11-09  发布在  其他
关注(0)|答案(4)|浏览(337)

我有一个二维Tensor:

samples = torch.Tensor([
    [0.1, 0.1],    #-> group / class 1
    [0.2, 0.2],    #-> group / class 2
    [0.4, 0.4],    #-> group / class 2
    [0.0, 0.0]     #-> group / class 0
])

以及对应于类的每个样本的标签:

labels = torch.LongTensor([1, 2, 2, 0])

因此len(samples) == len(labels)。现在我想计算每个类/标签的平均值。因为有3个类(0、1和2),所以最终向量的维数应为[n_classes, samples.shape[1]]。因此,预期解应为:

result == torch.Tensor([
    [0.1, 0.1],
    [0.3, 0.3], # -> mean of [0.2, 0.2] and [0.4, 0.4]
    [0.0, 0.0]
])

问题:如何在没有for循环的情况下,在纯pytorch(即没有numpy,这样我就可以autograd)中完成这一操作?

neskvpey

neskvpey1#

你需要做的就是构造一个mxn矩阵(m=num个类,n=num个样本),它将选择合适的权重,并适当地缩放均值。然后你可以在新构造的矩阵和样本矩阵之间执行矩阵乘法。
给定您的标签,您的矩阵应为(每行是一个类编号,每个类是一个样本编号及其权重):

[[0.0000, 0.0000, 0.0000, 1.0000],
 [1.0000, 0.0000, 0.0000, 0.0000],
 [0.0000, 0.5000, 0.5000, 0.0000]]

可以按如下方式形成:

M = torch.zeros(labels.max()+1, len(samples))
M[labels, torch.arange(len(samples)] = 1
M = torch.nn.functional.normalize(M, p=1, dim=1)
torch.mm(M, samples)

输出量:

tensor([[0.0000, 0.0000],
        [0.1000, 0.1000],
        [0.3000, 0.3000]])

请注意,输出均值按类顺序正确排序。
为什么M[labels, torch.arange(len(samples))] = 1可以工作?
这是在标签和样本数之间执行广播操作。本质上,我们为标签中的每个元素生成一个二维索引:第一种方法指定它属于m个类中的哪一个,第二种方法简单地指定它的索引位置(从1到N)。另一种方法是显式地生成所有2D索引:

twoD_indices = []
for count, label in enumerate(labels):
  twoD_indices.append((label, count))
p8ekf7hl

p8ekf7hl2#

在此重新发布来自Pytorch forums中@ptrblck_de的答案

labels = labels.view(labels.size(0), 1).expand(-1, samples.size(1))

unique_labels, labels_count = labels.unique(dim=0, return_counts=True)

res = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(0, labels, samples)
res = res / labels_count.float().unsqueeze(1)
0md85ypi

0md85ypi3#

由于以前的解决方案不适用于稀疏组的情况(例如,不是所有的组都在数据中),因此我做了一个:)

def groupby_mean(value:torch.Tensor, labels:torch.LongTensor) -> (torch.Tensor, torch.LongTensor):
    """Group-wise average for (sparse) grouped tensors

    Args:
        value (torch.Tensor): values to average (# samples, latent dimension)
        labels (torch.LongTensor): labels for embedding parameters (# samples,)

    Returns: 
        result (torch.Tensor): (# unique labels, latent dimension)
        new_labels (torch.LongTensor): (# unique labels,)

    Examples:
        >>> samples = torch.Tensor([
                             [0.15, 0.15, 0.15],    #-> group / class 1
                             [0.2, 0.2, 0.2],    #-> group / class 3
                             [0.4, 0.4, 0.4],    #-> group / class 3
                             [0.0, 0.0, 0.0]     #-> group / class 0
                      ])
        >>> labels = torch.LongTensor([1, 5, 5, 0])
        >>> result, new_labels = groupby_mean(samples, labels)

        >>> result
        tensor([[0.0000, 0.0000, 0.0000],
            [0.1500, 0.1500, 0.1500],
            [0.3000, 0.3000, 0.3000]])

        >>> new_labels
        tensor([0, 1, 5])
    """
    uniques = labels.unique().tolist()
    labels = labels.tolist()

    key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}
    val_key = {val: key for key, val in zip(uniques, range(len(uniques)))}

    labels = torch.LongTensor(list(map(key_val.get, labels)))

    labels = labels.view(labels.size(0), 1).expand(-1, value.size(1))

    unique_labels, labels_count = labels.unique(dim=0, return_counts=True)
    result = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(0, labels, value)
    result = result / labels_count.float().unsqueeze(1)
    new_labels = torch.LongTensor(list(map(val_key.get, unique_labels[:, 0].tolist())))
    return result, new_labels
mgdq6dx1

mgdq6dx14#

对于3DTensor:

对于那些感兴趣的人。我将@yhenon的答案扩展到了这样的情况,其中labels是一个2DTensor,samples是一个3DTensor。如果你想批量执行这个操作(就像我一样),这可能会很有用。但是它附带了一个警告(见最后)。
第一个
输出量:

>>> result
tensor([[[0.0000, 0.0000],
         [0.1000, 0.1000],
         [0.3000, 0.3000],
         [0.0000, 0.0000]],

        [[0.5000, 0.5000],
         [0.2000, 0.2000],
         [0.4000, 0.4000],
         [0.1000, 0.1000]]])

**注意:**现在,result[0]的长度为4(而不是@yhenon的答案中的3),因为labels[1]包含3。最后一行只包含0。如果你不排除结果Tensor最后一行中的0,你可以使用这段代码,并在以后处理0。

相关问题