当使用scipy.stats中的multinomial.pmf时,如何处理nan值?

gijlo24d  于 2022-11-10  发布在  其他
关注(0)|答案(1)|浏览(150)

运行一个小实验,我注意到如果传递给multinomial.pmf的参数之和稍微大于1,则返回值为nan。
请参见下面的示例:

import numpy as np
from scipy.stats import multinomial as multi_s

def safe_multi(x, params):
    params_sum = params.sum()
    safe_params = params / params_sum if params_sum > 1 else params
    return multi_s.pmf(x, sum(x), safe_params)

params1 = np.array(
    [0.21310660657549002, 0.21310660657549002, 0.21310660657549002,
     2.8699968847179538e-06, 0.0023286820110742764, 2.8699968847179538e-06,
     0.0023286820110742764, 2.8699968847179538e-06, 2.8699968847179538e-06,
     0.0023258120141895593, 0.0016006555371205703, 0.0023258120141895593,
     0.0016006555371205703, 0.04333851102588555, 0.04333851102588555,
     0.04333851102588555, 0.04333851102588555, 0.04333851102588555,
     0.04333851102588555, 0.04333851102588555, 0.04333851102588555,
     0.0007251564770689873, 0.0007251564770689873, 7.377915317967555e-27,
     7.377915317967555e-27])

params2 = np.array(
    [0.3333333333333332, 0.3333333333333332, 0.3333333333333332,
     2.931077467598623e-93, 6.532951191692606e-25, 1.4080539652716124e-224,
     6.532951191692606e-25, 1.4080539652716124e-224, 1.4080539652716124e-224,
     6.532951191692606e-25, 6.532951191692606e-25, 6.532951191692606e-25,
     6.532951191692606e-25, 3.5127398105835854e-17, 3.5127398105835854e-17,
     3.5127398105835854e-17, 3.5127398105835854e-17, 3.5127398105835854e-17,
     3.5127398105835854e-17, 3.5127398105835854e-17, 3.5127398105835854e-17,
     7.860388790608641e-191, 7.860388790608641e-191, 2.931077467598623e-93,
     2.931077467598623e-93])

samples = np.array(
    [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

result1_scipy = multi_s.pmf(samples, samples.sum(), params1)
result2_scipy = multi_s.pmf(samples, samples.sum(), params2)
print(result1_scipy, result2_scipy)
print(params1.sum(), params2.sum())

print('-----------------')

result1_scipy_sum = multi_s.pmf(samples, samples.sum(), params1 / params1.sum())
result2_scipy_sum = multi_s.pmf(samples, samples.sum(), params2 / params2.sum())
print(result1_scipy_sum, result2_scipy_sum)
print((params1 / params1.sum()).sum(), (params2 / params2.sum()).sum())

print('-----------------')

result1 = safe_multi(samples, params1)
result2 = safe_multi(samples, params2)
print(result1, result2)

输出为:

nan 0.22222222222222202
1.0000000000000002 0.9999999999999999
-----------------
0.058068684987554825 nan
0.9999999999999998 1.0000000000000002
-----------------
0.058068684987554825 0.22222222222222202

有没有更好的方法可以处理参数可能发生的数值溢出?我的safe_multi() Package 器似乎可以解决这个问题,但我对处理这个问题的最佳实践感兴趣。
编辑:我发现了一个例子,如下所示,它似乎总是返回nan,尽管参数的和为1。

c = np.array(
    [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

params = np.array(
    [0.02702702702702703, 0.02702702702702703, 0.0, 0.0, 0.0,
     0.04054054054054054, 0.0, 0.0, 0.0, 0.0, 0.06756756756756757,
     0.0945945945945946, 0.06756756756756757, 0.06756756756756757,
     0.04054054054054054, 0.04054054054054054, 0.06756756756756757,
     0.06756756756756757, 0.013513513513513514, 0.013513513513513514, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.04054054054054054, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.013513513513513514, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.04054054054054054, 0.013513513513513514,
     0.013513513513513514, 0.013513513513513514, 0.013513513513513514, 0.0, 0.0,
     0.0, 0.0, 0.013513513513513514, 0.013513513513513514, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.02702702702702703, 0.02702702702702703, 0.013513513513513514,
     0.013513513513513514, 0.0, 0.0, 0.0, 0.0, 0.0, 0.013513513513513514, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.013513513513513514, 0.013513513513513514, 0.013513513513513514,
     0.013513513513513514, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.013513513513513514,
     0.0, 0.0, 0.0, 0.0, 0.013513513513513514, 0.013513513513513514, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

print(multi_s.pmf(c, c.sum(), params))
print(multi_s.pmf(c, c.sum(), params / params.sum()))
print(params.sum())

输出量:

nan
nan
1.0

分析了scipy代码后,似乎问题出在_multivariate.py的第3012行:

p[..., -1] = 1. - p[..., :-1].sum(axis=-1)

此行通过将最后一个参数设置为适当的值来确保参数之和为1。在上面的示例中,这为最后一个参数添加了一个极小的负值,该负值随后将被标记为问题。为了确保满足此条件,是否可以除以参数之和,而不是执行减法?

iugsix8n

iugsix8n1#

nan值是通过pmf()方法中的有效性检查引入的。pmf()方法(以及logpmf()) Package 了_logpmf(),它计算对数概率密度函数而不进行任何有效性检查。如果您信任您的输入并希望避免nan值,则可以直接使用多项式._logpmf()代替:

def multinomial_without_checks(x, params):
    return np.exp(multi_s._logpmf(x, sum(x), params))

在您编辑的示例上运行此命令将得到预期的结果:

c = np.array(
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

params = np.array(
    [0.02702702702702703, 0.02702702702702703, 0.0, 0.0, 0.0,
     0.04054054054054054, 0.0, 0.0, 0.0, 0.0, 0.06756756756756757,
     0.0945945945945946, 0.06756756756756757, 0.06756756756756757,
     0.04054054054054054, 0.04054054054054054, 0.06756756756756757,
     0.06756756756756757, 0.013513513513513514, 0.013513513513513514, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.04054054054054054, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.013513513513513514, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.04054054054054054, 0.013513513513513514,
     0.013513513513513514, 0.013513513513513514, 0.013513513513513514, 0.0, 0.0,
     0.0, 0.0, 0.013513513513513514, 0.013513513513513514, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.02702702702702703, 0.02702702702702703, 0.013513513513513514,
     0.013513513513513514, 0.0, 0.0, 0.0, 0.0, 0.0, 0.013513513513513514, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.013513513513513514, 0.013513513513513514, 0.013513513513513514,
     0.013513513513513514, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.013513513513513514,
     0.0, 0.0, 0.0, 0.0, 0.013513513513513514, 0.013513513513513514, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

print(multinomial_without_checks(c, params))

输出为:

0.02702702702702703

相关问题