Tensorflow:沿着Tensor轴的最大值的二进制掩码

iqxoj9l9  于 2023-01-13  发布在  其他
关注(0)|答案(1)|浏览(137)

如果我有一个N维Tensor,我想创建另一个值为0和1的Tensor(具有相同的形状),其中1在某个维度上与原始Tensor中的最大元素处于相同的位置。
我的一个约束是,我只想得到第一个最大元素沿该轴,以防有重复。
为了简化,我将使用较少的尺寸。

>>> x = tf.constant([[7, 2, 3], 
                     [5, 0, 1], 
                     [3, 8, 2]], dtype=tf.float32)

>>> tf.reduce_max(x, axis=-1)
tf.Tensor([7. 5. 8.], shape=(3,), dtype=float32)

我想要的是:

tf.Tensor([1. 0. 0.], 
          [1. 0. 0.],
          [0. 1. 0.], shape=(3,3), dtype=float32)

我尝试过(并意识到是错误的):

>>> tf.cast(tf.equal(x, tf.reduce_max(x, axis=-1, keepdims=True)), dtype=tf.float32)

# works fine when there are no duplicates
tf.Tensor([[1. 0. 0.]
           [1. 0. 0.]
           [0. 1. 0.]], shape=(3, 3), dtype=float32)

>>> y = tf.zeros([3,3])
>>> tf.cast(tf.equal(y, tf.reduce_max(y, axis=-1, keepdims=True)), dtype=tf.float32)

# fails when there are multiple identical values across dimension
tf.Tensor([[1. 1. 1.]
           [1. 1. 1.]
           [1. 1. 1.]], shape=(3, 3), dtype=float32)
    • 编辑:已解决**
tf.cast(tf.equal(tf.argsort(tf.argsort(x, 1, direction='DESCENDING'), 1), 0), tf.float32)
myss37ts

myss37ts1#

你可以使用double tf.argsort()来获得元素在轴1上的排名顺序,并得到最大排名,这将把最大值的last示例作为最高排名。

x = tf.constant([[7, 2, 3],  #max is 7
                 [5, 0, 5],  #max is 5 but duplicate in same row
                 [7, 8, 7]]) #max is 8 but shares 7 with first row too

tf.cast(tf.equal(tf.argsort(tf.argsort(x, 1), 1), x.shape[0]-1), tf.int64)
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[1, 0, 0],
       [0, 0, 1],
       [0, 1, 0]], dtype=int32)>

相关问题