Requirements for a SciPy bootstrap wrapper function

4uqofj5v, posted 2023-06-29 in Bootstrap

I am trying to use the SciPy bootstrap function for a simple difference of medians. The example below, taken from the SciPy documentation, works fine.

from scipy.stats import bootstrap, mood, norm

def my_statistic(sample1, sample2, axis):
    statistic, _ = mood(sample1, sample2, axis=axis)
    return statistic

sample1 = norm.rvs(scale=1, size=100)
sample2 = norm.rvs(scale=2, size=100)
data = (sample1, sample2)
res = bootstrap(data, my_statistic, method='basic')

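However, when I try to use a function that computes the difference between the medians of two datasets, I get a "ValueError: zero-dimensional arrays cannot be concatenated" error. Here is the function in question.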

import numpy as np

def median_diff(group1, group2, axis=-1):
    diff = np.median(group1) - np.median(group2)
    return diff

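I have tried axis=0, 1, and -1, and it makes no difference. The argument is there only because the SciPy documentation says it is required.
The full error message with traceback is: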

ValueError                                Traceback (most recent call last)
/tmp/ipykernel_259/1837026830.py in <cell line: 1>()
----> 1 boot = bootstrap(lab4data, median_diff, method="basic")

/usr/local/lib/python3.10/dist-packages/scipy/stats/_resampling.py in bootstrap(data, statistic, n_resamples, batch, vectorized, paired, axis, confidence_level, method, bootstrap_result, random_state)
    589         # Compute bootstrap distribution of statistic
    590         theta_hat_b.append(statistic(*resampled_data, axis=-1))
--> 591     theta_hat_b = np.concatenate(theta_hat_b, axis=-1)
    592 
    593     # Calculate percentile interval
/usr/local/lib/python3.10/dist-packages/numpy/core/overrides.py in concatenate(*args, **kwargs)
ValueError: zero-dimensional arrays cannot be concatenated

How does my function need to be different?

Answer #1, by 6vl6ewon

Add a diagnostic print to the statistic function:

In [57]: def median_diff(group1, group2, axis=-1):
    ...:     diff = np.median(group1) - np.median(group2)
    ...:     print(group1.shape, group2.shape, diff)
    ...:     return diff
    ...: 
    ...:

Running it directly on the input data tuple:

In [58]: median_diff(*data)
(100,) (100,) -0.5152173131401271
Out[58]: -0.5152173131401271

But when it runs inside bootstrap, the data passed to the function is 2-D:

In [59]: res = bootstrap(data, median_diff, method='basic')
(9999, 100) (9999, 100) -0.46106790632314354
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-59-073accca9196> in <module>
----> 1 res = bootstrap(data, median_diff, method='basic')

/usr/lib/python3/dist-packages/scipy/stats/_bootstrap.py in bootstrap(data, statistic, vectorized, paired, axis, confidence_level, n_resamples, batch, method, random_state)
    464         # Compute bootstrap distribution of statistic
    465         theta_hat_b.append(statistic(*resampled_data, axis=-1))
--> 466     theta_hat_b = np.concatenate(theta_hat_b, axis=-1)
    467 
    468     # Calculate percentile interval

~/.local/lib/python3.10/site-packages/numpy/core/overrides.py in concatenate(*args, **kwargs)

ValueError: zero-dimensional arrays cannot be concatenated
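The shapes printed above hint at the cause. As a minimal sketch (the small `resamples` array is made-up illustration data, not what bootstrap actually generates), calling np.median without an axis flattens a 2-D batch into a single scalar, and a list of such scalars cannot be concatenated:

```python
import numpy as np

# Hypothetical stand-in for a batch of resamples: 3 resamples of 4 observations.
resamples = np.arange(12.0).reshape(3, 4)

# Without an axis, np.median flattens the whole array and returns one scalar.
flat = np.median(resamples)

# With axis=-1, it returns one median per resample (per row).
per_row = np.median(resamples, axis=-1)

print(np.ndim(flat), flat)        # 0-dimensional result
print(per_row.shape, per_row)     # one value per row

# Concatenating 0-dimensional results fails, much like in the traceback.
try:
    np.concatenate([flat], axis=-1)
    raised = False
except ValueError as exc:
    raised = True
    print("ValueError:", exc)

# Per-row results concatenate without trouble.
joined = np.concatenate([per_row], axis=-1)
```

So the statistic has to reduce along the last axis only and leave the resample dimension intact.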

Pass the axis argument through to np.median:

In [61]: def median_diff(group1, group2, axis=-1):
    ...:     diff = np.median(group1, axis) - np.median(group2, axis)
    ...:     print(group1.shape, group2.shape, diff)
    ...:     return diff
    ...:

The function now works:

In [62]: res = bootstrap(data, median_diff, method='basic')
(9999, 100) (9999, 100) [-0.58497618 -0.73443767 -0.75436132 ... -0.6955514  -0.62782726
 -0.56389115]
(100,) (100,) -0.5152173131401271

The concatenate error occurred because the function returned a scalar, whereas bootstrap expected an array of shape (9999,), one value per resample.
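For reference, here is a self-contained sketch of two ways to satisfy that requirement. The data is synthetic, n_resamples=999 is chosen only to keep the run short, and the name `median_diff_1d` is made up for illustration. The first version honors `axis`; the second keeps the original scalar function and passes `vectorized=False`, so bootstrap calls it once per resample with 1-D samples and no `axis` keyword.

```python
import numpy as np
from scipy.stats import bootstrap

# Synthetic data for illustration only.
rng = np.random.default_rng(0)
sample1 = rng.normal(scale=1, size=100)
sample2 = rng.normal(loc=0.5, scale=2, size=100)
data = (sample1, sample2)

# Option 1: vectorized statistic that reduces only along `axis`.
def median_diff(group1, group2, axis=-1):
    return np.median(group1, axis=axis) - np.median(group2, axis=axis)

res_vec = bootstrap(data, median_diff, n_resamples=999, method='basic')

# Option 2: scalar statistic for 1-D inputs, with vectorized=False so that
# bootstrap loops over the resamples itself (simpler, but slower).
def median_diff_1d(group1, group2):
    return np.median(group1) - np.median(group2)

res_loop = bootstrap(data, median_diff_1d, vectorized=False,
                     n_resamples=999, method='basic')

print(res_vec.confidence_interval)
print(res_loop.confidence_interval)
```

The two intervals will differ slightly because each call draws its own resamples. The vectorized form is usually the faster choice, while vectorized=False is convenient when the statistic is hard to write with an axis argument.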
