如何从概率分布的tensorflow 中抽取5个指标及其概率?

dtcbnfnu  于 2023-06-24  发布在  其他
关注(0)|答案(1)|浏览(136)

我有一个概率分布(应用Softmax后),其中每行的值总和为1

probs = tf.constant([
    [0.0, 0.1, 0.2, 0.3, 0.4],
    [0.5, 0.3, 0.2, 0.0, 0.0]])

我想从它和它们各自的概率值中使用tensorflow操作来采样k索引。
3索引的预期输出:

index: [
  [4, 3, 4],
  [0, 1, 0]
       ]

probs: [
  [0.4, 0.3, 0.4],
  [0.5, 0.3, 0.5]
       ]

我如何才能做到这一点?

30byixjq

30byixjq1#

使用tf.random.uniform生成索引的随机Tensor:

k = 3 # setting up the number of element to sample
# getting the dimension of the prob tensor
row, cols = tf.unstack(tf.shape(probs))
# Generating k values for each row between 0 and cols - 1
idx = tf.random.uniform((row, k), 0, cols, dtype=tf.int32)
>>> idx
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[0, 0, 4],
       [3, 1, 0]], dtype=int32)>

使用tf.gather_nd对概率Tensor进行索引,索引如下:

>>> tf.gather(probs, idx, batch_dims=1)
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[0. , 0. , 0.4],
       [0. , 0.3, 0.5]], dtype=float32)>

相关问题