我正在尝试实现一个流形对齐类型的损失说明here。
给定Tensorembs
tensor([[ 0.0178, 0.0004, -0.0217, ..., -0.0724, 0.0698, -0.0180],
[ 0.0160, 0.0002, -0.0217, ..., -0.0725, 0.0655, -0.0207],
[ 0.0155, -0.0010, -0.0153, ..., -0.0750, 0.0688, -0.0253],
...,
[ 0.0130, -0.0113, -0.0078, ..., -0.0805, 0.0634, -0.0241],
[ 0.0120, -0.0047, -0.0135, ..., -0.0846, 0.0722, -0.0230],
[ 0.0120, -0.0048, -0.0142, ..., -0.0843, 0.0734, -0.0246]],
grad_fn=<AddmmBackward0>)
形状(256,64)
是由网络生成的一批嵌入,我想计算行条目之间的所有成对距离。我尝试使用torch.nn.PairwiseDistance
,但我不清楚它是否对我所寻找的有用。
1条答案
按热度按时间mf98qq941#
觉得奇怪的是,没有。有,它被称为torch.cdist,但它是“隐藏”在顶级。