pytorch 焊炬夹具中的列相关界限

eqqqjvef  于 2023-03-23  发布在  其他
关注(0)|答案(2)|浏览(136)

我想在2D数组上的PyTorchTensor上做一些类似于np.clip的事情。更具体地说,我想在特定的值范围内裁剪每一列(取决于列)。例如,在numpy中,你可以这样做:

x = np.array([-1,10,3])
low = np.array([0,0,1])
high = np.array([2,5,4])
clipped_x = np.clip(x, low, high)

clipped_x == np.array([0,5,3]) # True

我找到了torch.clamp,但不幸的是它不支持多维边界(整个Tensor只有一个标量值)。有没有一种“整洁”的方法可以将该函数扩展到我的情况?
谢谢!

dddzy1tm

dddzy1tm1#

不像np.clip那样简洁,但可以使用torch.maxtorch.min

In [1]: x
Out[1]:
tensor([[0.9752, 0.5587, 0.0972],
        [0.9534, 0.2731, 0.6953]])

设置每列的下限和上限

l = torch.tensor([[0.2, 0.3, 0.]])
u = torch.tensor([[0.8, 1., 0.65]])

请注意,下限l和上限u是1 × 3Tensor(具有单例维度的2D)。我们需要lu的这些维度可以广播到x的形状。
现在我们可以使用minmax进行剪辑:

clipped_x = torch.max(torch.min(x, u), l)

导致

tensor([[0.8000, 0.5587, 0.0972],
        [0.8000, 0.3000, 0.6500]])
tzdcorbm

tzdcorbm2#

对于任何人来说,谁是有同样的问题,像我几分钟前:
在大约两年的时间里,torch.clamp中还可以有列相关的边界(请参见PR):

In: x = torch.randn(2, 3)
        print(x)

Out: tensor([[-0.2069, 1.4082, 0.2615],
             [0.6478, 0.0883, -0.7795]])

设置下限和上限:

lower = torch.Tensor([[-1., 0., 0.]])
upper = torch.Tensor([[0., 1., 1.]])

现在你可以简单地使用torch.clamp如下:

In: clamped_x = torch.clamp(x, min=lower, max=upper)
    print(clamped_x)

Out: tensor([[-0.2069, 1.0000, 0.2615],
             [0.0000, 0.0883, 0.0000]])

我希望这能有所帮助:)

相关问题