scipy 如何在使用numpy.random.normal时指定上限和下限

koaltpgm  于 2022-11-23  发布在  其他
关注(0)|答案(8)|浏览(204)

我希望能够从正态分布中选取值,这些值只能落在0和1之间,在某些情况下,我希望能够返回一个完全随机的分布,而在另一些情况下,我希望返回的值落在高斯分布的形状中。
目前我正在使用以下函数:

def blockedgauss(mu,sigma):
    while True:
        numb = random.gauss(mu,sigma)
        if (numb > 0 and numb < 1):
            break
    return numb

它从正态分布中选择一个值,然后如果它福尔斯在0到1的范围之外就丢弃它,但我觉得一定有更好的方法来做这件事。

6l7fqoea

6l7fqoea1#

听起来你需要一个truncated normal distribution。使用scipy,你可以使用scipy.stats.truncnorm从这样的分布中生成随机变量:

import matplotlib.pyplot as plt
import scipy.stats as stats

lower, upper = 3.5, 6
mu, sigma = 5, 0.7
X = stats.truncnorm(
    (lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)
N = stats.norm(loc=mu, scale=sigma)

fig, ax = plt.subplots(2, sharex=True)
ax[0].hist(X.rvs(10000), normed=True)
ax[1].hist(N.rvs(10000), normed=True)
plt.show()

上图显示截断正态分布,下图显示具有相同平均值mu和标准差sigma的正态分布。

pgky5nke

pgky5nke2#

我在寻找一种方法来返回一系列从0和1之间的正态分布中采样的值(即概率)时遇到了这篇文章。为了帮助其他有同样问题的人,我只想指出scipy.stats.truncnorm有一个内置的功能“.rvs”。
因此,如果需要100,000个样本,平均值为0.5,标准差为0.1:

import scipy.stats
lower = 0
upper = 1
mu = 0.5
sigma = 0.1
N = 100000

samples = scipy.stats.truncnorm.rvs(
          (lower-mu)/sigma,(upper-mu)/sigma,loc=mu,scale=sigma,size=N)

这给出了与numpy.random.normal非常相似的行为,但在所需的范围内。使用内置函数比循环收集样本要快得多,尤其是对于大的N值。

a0zr77ik

a0zr77ik3#

如果有人想要一个只使用numpy的解决方案,这里有一个使用normal函数和clip的简单实现(MacGyver的方法):

import numpy as np
    def truncated_normal(mean, stddev, minval, maxval):
        return np.clip(np.random.normal(mean, stddev), minval, maxval)

**EDIT:不要使用此方法!!这是您不应该使用方法!!**例如,

a = truncated_normal(np.zeros(10000), 1, -10, 10)
也许看起来很有效但是
b = truncated_normal(np.zeros(10000), 100, -1, 1)
绝对不会绘制截断的法线,如以下直方图所示:

很抱歉,希望没有人受伤!我想教训是,不要试图模仿MacGyver在编码...干杯,
安德烈斯

deyfvvtc

deyfvvtc4#

我已经做了一个示例脚本如下。它显示了如何使用API来实现我们想要的功能,例如用已知参数生成样本,如何计算CDF、PDF等。我还附上了一张图片来显示这一点。

#load libraries   
import scipy.stats as stats

#lower, upper, mu, and sigma are four parameters
lower, upper = 0.5, 1
mu, sigma = 0.6, 0.1

#instantiate an object X using the above four parameters,
X = stats.truncnorm((lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)

#generate 1000 sample data
samples = X.rvs(1000)

#compute the PDF of the sample data
pdf_probs = stats.truncnorm.pdf(samples, (lower-mu)/sigma, (upper-mu)/sigma, mu, sigma)

#compute the CDF of the sample data
cdf_probs = stas.truncnorm.cdf(samples, (lower-mu)/sigma, (upper-mu)/sigma, mu, sigma)

#make a histogram for the samples
plt.hist(samples, bins= 50,normed=True,alpha=0.3,label='histogram');

#plot the PDF curves 
plt.plot(samples[samples.argsort()],pdf_probs[samples.argsort()],linewidth=2.3,label='PDF curve')

#plot CDF curve        
plt.plot(samples[samples.argsort()],cdf_probs[samples.argsort()],linewidth=2.3,label='CDF curve')

#legend
plt.legend(loc='best')

wswtfjt7

wswtfjt75#

实际上,你可以把数据归一化,然后把它转换到你需要的范围。对不起,第一次使用,我不知道如何直接显示图片the function is shown

fquxozlt

fquxozlt6#

我已经用numpy测试了一些解决方案,通过试错法,我发现± variation除以3是标准差的一个很好的猜测。
下面是一些示例:
"基础知识"

import numpy as np
import matplotlib.pyplot as plt

val_min = 1000
val_max = 2000
variation = (val_max - val_min)/2
std_dev = variation/3
mean = (val_max + val_min)/2
dist_normal = np.random.normal(mean, std_dev,  1000)
print('Normal distribution\n\tMin: {0:.2f}, Max: {1:.2f}'
      .format(dist_normal.min(), dist_normal.max()))
plt.hist(dist_normal, bins=30)
plt.show()

比较案例

import numpy as np
import matplotlib.pyplot as plt

val_min = 1400
val_max = 2800
variation = (val_max - val_min)/2
std_dev = variation/3
mean = (val_max + val_min)/2
fig, ax = plt.subplots(3, 3)
plt.suptitle("Histogram examples by Davidson Lima (github.com/davidsonlima)", 
             fontweight='bold')
i = 0
j = 0
pos = 1
while (i < 3):
    while (j < 3):
        dist_normal = np.random.normal(mean, std_dev,  1000)
        max_min = 'Min: {0:.2f}, Max: {1:.2f}'.format(dist_normal.min(), dist_normal.max())
        ax[i, j].hist(dist_normal, bins=30, label='Dist' + str(pos))
        ax[i, j].set_title('Normal distribution ' + str(pos))
        ax[i, j].legend()
        ax[i, j].text(mean, 0, max_min, horizontalalignment='center', color='white',
                      bbox={'facecolor': 'red', 'alpha': 0.5})
        print('Normal distribution {0}\n\tMin: {1:.2f}, Max: {2:.2f}'
              .format(pos, dist_normal.min(), dist_normal.max()))
        j += 1
        pos += 1
    j = 0
    i += 1
plt.show()

如果有人有一个更好的方法与numpy,请在下面评论。

xa9qqrwz

xa9qqrwz7#

我使用numpy.random.normal和一些额外的代码开发了一个简单的函数,用于创建一个范围内的值列表。

def truncnormal(meanv, sd, minv, maxv, n):
    finallist = []
    initiallist = []
    while len(finallist) < n:
        initiallist = list(np.random.normal(meanv, sd, n))
        initiallist.sort()
        indexmin = 0
        indexmax = 0
        for item in initiallist:
            if item < minv:
                indexmin = indexmin + 1
            else:
                break
        for item in initiallist[::-1]:
            if item > maxv:
                indexmax = indexmax + 1
            else:
                break
        indexmax = -indexmax
        finallist = finallist + initiallist[indexmin:indexmax]
    shuffle(finallist)
    finallist = finallist[:n] 
    print(len(finallist), min(finallist), max(finallist))

truncnormal(10, 3, 8, 11, 10000)
mlnl4t2r

mlnl4t2r8#

truncnorm参数化是复杂的,所以这里有一个函数可以将参数化转换为更直观的东西:

from scipy.stats import truncnorm

def get_truncated_normal(mean=0, sd=1, low=0, upp=10):
    return truncnorm(
        (low - mean) / sd, (upp - mean) / sd, loc=mean, scale=sd)

如何使用?

1.使用以下参数示例化生成器:* 平均值 标准差 * 和 * 截断范围 *:

>>> X = get_truncated_normal(mean=8, sd=2, low=1, upp=10)

1.然后,可以使用X生成一个值:

>>> X.rvs()
 6.0491227353928894

1.或者,具有N个生成值的numpy数组:

>>> X.rvs(10)
 array([ 7.70231607,  6.7005871 ,  7.15203887,  6.06768994,  7.25153472,
         5.41384242,  7.75200702,  5.5725888 ,  7.38512757,  7.47567455])

视觉示例

以下是三种不同的截断正态分布图:

X1 = get_truncated_normal(mean=2, sd=1, low=1, upp=10)
X2 = get_truncated_normal(mean=5.5, sd=1, low=1, upp=10)
X3 = get_truncated_normal(mean=8, sd=1, low=1, upp=10)

import matplotlib.pyplot as plt
fig, ax = plt.subplots(3, sharex=True)
ax[0].hist(X1.rvs(10000), normed=True)
ax[1].hist(X2.rvs(10000), normed=True)
ax[2].hist(X3.rvs(10000), normed=True)
plt.show()

相关问题