我想在2D数组上的PyTorchTensor上做一些类似于np.clip的事情。更具体地说,我想在特定的值范围内裁剪每一列(取决于列)。例如,在numpy中,你可以这样做:
x = np.array([-1,10,3])
low = np.array([0,0,1])
high = np.array([2,5,4])
clipped_x = np.clip(x, low, high)
clipped_x == np.array([0,5,3]) # True
我找到了torch.clamp,但不幸的是它不支持多维边界(整个Tensor只有一个标量值)。有没有一种“整洁”的方法可以将该函数扩展到我的情况?
谢谢!
2条答案
按热度按时间dddzy1tm1#
不像
np.clip
那样简洁,但可以使用torch.max
和torch.min
:设置每列的下限和上限
请注意,下限
l
和上限u
是1 × 3Tensor(具有单例维度的2D)。我们需要l
和u
的这些维度可以广播到x
的形状。现在我们可以使用
min
和max
进行剪辑:导致
tzdcorbm2#
对于任何人来说,谁是有同样的问题,像我几分钟前:
在大约两年的时间里,torch.clamp中还可以有列相关的边界(请参见PR):
设置下限和上限:
现在你可以简单地使用
torch.clamp
如下:我希望这能有所帮助:)