pytorch 计算'torch.tensor`中条目之间的成对距离

igsr9ssn  于 2023-02-04  发布在  其他
关注(0)|答案(1)|浏览(154)

我正在尝试实现一个流形对齐类型的损失说明here
给定Tensorembs

tensor([[ 0.0178,  0.0004, -0.0217,  ..., -0.0724,  0.0698, -0.0180],
        [ 0.0160,  0.0002, -0.0217,  ..., -0.0725,  0.0655, -0.0207],
        [ 0.0155, -0.0010, -0.0153,  ..., -0.0750,  0.0688, -0.0253],
        ...,
        [ 0.0130, -0.0113, -0.0078,  ..., -0.0805,  0.0634, -0.0241],
        [ 0.0120, -0.0047, -0.0135,  ..., -0.0846,  0.0722, -0.0230],
        [ 0.0120, -0.0048, -0.0142,  ..., -0.0843,  0.0734, -0.0246]],
       grad_fn=<AddmmBackward0>)

形状(256,64)是由网络生成的一批嵌入,我想计算行条目之间的所有成对距离。我尝试使用torch.nn.PairwiseDistance,但我不清楚它是否对我所寻找的有用。

mf98qq94

mf98qq941#

觉得奇怪的是,没有。有,它被称为torch.cdist,但它是“隐藏”在顶级。

>>> a = torch.rand((5,3))
>>> a
tensor([[0.0215, 0.0843, 0.3414],
        [0.9878, 0.5835, 0.3052],
        [0.0903, 0.7347, 0.0711],
        [0.9774, 0.8202, 0.7721],
        [0.7877, 0.9891, 0.4619]])
>>> torch.cdist(a,a)
tensor([[0.0000, 1.0883, 0.7077, 1.2809, 1.1918],
        [1.0883, 0.0000, 0.9398, 0.5236, 0.4787],
        [0.7077, 0.9398, 0.0000, 1.1339, 0.8390],
        [1.2809, 0.5236, 1.1339, 0.0000, 0.4010],
        [1.1918, 0.4787, 0.8390, 0.4010, 0.0000]])
>>> torch.nn.functional.pairwise_distance(a[0], a[2])
tensor(0.7077)

相关问题