PyTorch:如何从Tensor中采样,其中Tensor中的每个值都有不同的被选中可能性?

gr8qqesn  于 12个月前  发布在  其他
关注(0)|答案(4)|浏览(152)

给定TensorA = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316, 0.0316, 0.0860, 0.0316, 0.0860])包含概率之和为1(我删除了一些小数,但可以安全地假设它总是和为1),我想从A中采样一个值,其中该值本身是被采样的可能性。例如,从A采样0.0316的可能性是0.0316。采样值的输出仍然应该是Tensor。
我尝试使用WeightedRandomSampler,但它不允许选择的值是一个Tensor了,而是分离。
需要注意的是,我还想知道采样值在Tensor中的索引,也就是说,我采样0.2338,我想知道它是TensorA的索引12还是3

6bc51xsx

6bc51xsx1#

通过累加权重并选择随机浮点数[0,1)的插入索引,可以实现以期望概率进行选择。示例数组 A 稍微调整为总和为1。

import torch

A = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316, 0.0316, 0.0860, 0.0316, 0.0862], requires_grad=True)

p = A.cumsum(0)
#tensor([0.0316, 0.2654, 0.4992, 0.7330, 0.7646, 0.7962, 0.8822, 0.9138, 1.0000], grad_fn=<CumsumBackward0>))

idx = torch.searchsorted(p, torch.rand(1))
A[idx], idx

字符串
输出

(tensor([0.2338], grad_fn=<IndexBackward0>), tensor([3]))


这比A.multinomial(1)更常见的方法更快。
对一个元素进行10000次采样,以检查分布是否符合概率

from collections import Counter

Counter(int(A.multinomial(1)) for _ in range(10000))
#1 loop, best of 5: 233 ms per loop

# vs @HatemAli's solution
dist=torch.distributions.categorical.Categorical(probs=A)
Counter(int(dist.sample()) for _ in range(10000))
# 10 loops, best of 5: 107 ms per loop

Counter(int(torch.searchsorted(p, torch.rand(1))) for _ in range(10000))
# 10 loops, best of 5: 53.2 ms per loop


输出

Counter({0: 319,
         1: 2360,
         2: 2321,
         3: 2319,
         4: 330,
         5: 299,
         6: 903,
         7: 298,
         8: 851})

vatpfxk5

vatpfxk52#

这个怎么样?

probs = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316, 0.0316, 0.0860, 0.0316, 0.0860],requires_grad=True)

dist=torch.distributions.categorical.Categorical(probs=probs)
probs[dist.sample()]

字符串

hgncfbus

hgncfbus3#

从已接受的答案(Michael Szczesny)中得到的解可以扩展到覆盖具有概率的2dTensor,就像模型输出的softmax一样。只需相应地调整随机Tensor的维度。

def multisample_from_softmax(softmax_values):
    """
    quick weighted sampling using pytorch
    softmax_values : torch.tensor shaped (n_tokens, embedding_vocab_size)
    returns: torch.tensor shaped(n_tokens) with indices of sampled tokens
    """
    size = softmax_values.shape[0]
    rand_values = torch.rand((size, 1), device=softmax_values.device)
    cumprobs = softmax_values.cumsum(dim=1)
    selection = torch.searchsorted(cumprobs, rand_values).squeeze(1)
    selection_probs = (softmax_values[:, selection] * torch.eye(size, device=softmax_values.device)).diagonal()
    return selection, selection_probs

字符串

7ivaypg9

7ivaypg94#

你可以通过这样做来欺骗一点:

A = A*10000
temp = [[i]*A[i] for i in range(len(A))]
value = np.random.choice(temp)/10000

字符串

相关问题