如何在tensorflow中创建掩码Tensor?

ryhaxcpt  于 2023-10-23  发布在  其他
关注(0)|答案(2)|浏览(122)

为给定索引创建具有一定深度的Tensor,值为1
指数Tensor:

[[1 3]
 [2 4]
 [0 4]]

输出Tensor(深度=5):

[[0. 1. 0. 1. 0.]
 [0. 0. 1. 0. 1.]
 [1. 0. 0. 0. 1.]]
gab6jxml

gab6jxml1#

您可以通过首先转换为完整索引,然后使用sparse_to_dense函数将这些索引值设置为1来实现上述目标。

#Get full indices
mesh = tf.meshgrid(tf.range(indices.shape[1]), tf.range(indices.shape[0]))[1]
full_indices = tf.reshape(tf.stack([mesh, indices], axis=2), [-1,2])

#Output
[[0 1]
 [0 3]
 [1 2]
 [1 4]
 [2 0]
 [2 4]]

#use the above indices and set the output to 1.
#depth_x = 3, depth_y = 5
dense = tf.sparse_to_dense(full_indices,tf.constant([depth_x,depth_y]), tf.ones(tf.shape(full_indices)[0]))

# Output
#[[0. 1. 0. 1. 0.]
#[0. 0. 1. 0. 1.]
#[1. 0. 0. 0. 1.]]
hl0ma9xz

hl0ma9xz2#

Silimar的解决方案@vijay m的,但使用更直接的函数tf.scatter_nd

#Get full indices
mesh = tf.meshgrid(tf.range(indices.shape[1]), tf.range(indices.shape[0]))[1]
full_indices = tf.reshape(tf.stack([mesh, indices], axis=2), [-1,2])

#Output
[[0 1]
 [0 3]
 [1 2]
 [1 4]
 [2 0]
 [2 4]]
 
updates = tf.constant(1.0, shape=(tf.shape(full_indices)[1]))
scatter = tf.scatter_nd(indices, updates, shape)
sess.run(scatter)
#
# output
#
array([[0., 1., 0., 1., 0.],
       [0., 0., 1., 0., 1.],
       [1., 0., 0., 0., 1.]], dtype=float32)

相关问题