我有一个概率分布(应用Softmax后),其中每行的值总和为1
probs = tf.constant([
[0.0, 0.1, 0.2, 0.3, 0.4],
[0.5, 0.3, 0.2, 0.0, 0.0]])
我想从它和它们各自的概率值中使用tensorflow操作来采样k索引。
3索引的预期输出:
index: [
[4, 3, 4],
[0, 1, 0]
]
probs: [
[0.4, 0.3, 0.4],
[0.5, 0.3, 0.5]
]
我如何才能做到这一点?
1条答案
按热度按时间30byixjq1#
使用
tf.random.uniform
生成索引的随机Tensor:使用
tf.gather_nd
对概率Tensor进行索引,索引如下: