对numpy数组的许多随机排列进行采样的最快方法

2nbm6dog  于 2023-01-26  发布在  其他
关注(0)|答案(2)|浏览(161)

与许多其他numpy/random函数不同,numpy.random.Generator.permutation()没有提供在一次函数调用中返回多个结果的明显方法。(1d)numpy数组x,我想采样xn排列(每个长度len(x)),并且具有形状为(n, len(x))的numpy阵列的结果。生成许多排列的一种简单方式是np.array([rng.permutation(x) for _ in range(n)])。这不是理想的,主要是因为循环是在Python中而不是在编译后的numpy函数中。

import numpy as np

rng = np.random.default_rng(1234)
# x is big enough to not want to enumerate all permutations
x = rng.standard_normal(size=20)
n = 10000
perms = np.array([rng.permutation(x) for _ in range(n)])

我的用例是使用蛮力搜索来寻找最小化特定属性的排列(构成“足够好”的搜索解决方案)。我可以使用numpy运算来计算每个排列的感兴趣属性,这些运算可以在得到的排列矩阵上很好地进行向量化/广播。结果证明,天真地生成排列矩阵是我代码中的瓶颈。有更好的方法吗?

jv4diomz

jv4diomz1#

您可以使用rng.permuted代替rng.permutation,并将其与np.tile组合,以便多次重复x,并单独打乱每个重复。

perms = rng.permuted(np.tile(x, n).reshape(n,x.size), axis=1)

这在我的机器上比你的初始代码快10倍。

balp4ylt

balp4ylt2#

注意Jérome的解决方案提供了一个n行的数组,但是它可能包含重复,不同的行可能有相同的x顺序(特别是当n大于x时)
如果您需要不重复地采样(就像我的例子一样),您可以始终执行set(list(perm))并保留唯一的组合“x”值

相关问题