pytorch 使用多掩码时, Torch Tensor分配不起作用

lb3vh1jj  于 2024-01-09  发布在  其他
关注(0)|答案(2)|浏览(173)

我有Tensor和掩码,还有第二个掩码。
现在我想改变tensor[mask][second_mask]的值,但它不起作用。
我认为这是因为tensor[mask]返回一个新的tensor,而不是原始tensor的视图,并且将值应用于tensor[mask][second_mask]不会改变原始tensor的值。我的演示如下:

import torch

x = torch.linspace(1,9,9).reshape((3,3))
mask = x>5
second_mask = 0 # In practice, it will be a booltensor
x[mask][second_mask] = 100
print(x)

# One way to solve it is use a temp tensor be like:
t = x[mask]
t[second_mask] = 100
x[mask] = t

# but it is sooooo long, hoping for any convenient method

字符串

oknwwptz

oknwwptz1#

您可以使用索引而不是掩码:

import torch

x = torch.linspace(1,9,9).reshape((3,3))
mask_idx = torch.where(x > 5) # tuple of 1D tensors for x,y coordinates
second_mask = 0 # this is index in my understanding instead of mask, but I follow the given naming convention

# make final_mask_idx using x, y coordinates and second_mask
final_mask_idx = (mask_idx[0][second_mask], mask_idx[1][second_mask])

x[final_mask_idx] = 100

字符串

2ledvvac

2ledvvac2#

选择操作返回的不是视图,这是真的。
因为你的第二个掩码只是索引0,而不是真正的掩码,你总是试图更新第一个正掩码值吗?如果是这样,对于这种情况,你可以这样做:

>>> x = torch.linspace(1,9,9).reshape((3,3))
>>> x.flatten()[ mask.to(torch.int).argmax() ] = 100
>>> x
tensor([[  1.,   2.,   3.],
        [  4.,   5., 100.],
        [  7.,   8.,   9.]])

字符串
当所有的max值都相同时,argmax返回第一个max值。不幸的是,它没有为bool实现,因此是to(torch.int)

相关问题