二元泊松pmf方程的Numpy矢量化

5f0d552i  于 2022-11-24  发布在  其他
关注(0)|答案(1)|浏览(160)

我正在写一个函数来计算二元泊松分布的概率密度函数。

当所有的参数(x, y, theta1, theta2, theta0)都是标量时,这很容易,但如果不使用循环来放大参数,使这些参数成为向量,这就很麻烦了。

  • theta0是标量-等式中的“相关参数”
  • 具有长度ltheta1theta2
  • xy,它们都具有长度n

输出数组的形状为(l, n, n)。例如,输出数组中的切片[j, :, :]将如下所示:


第一部分(常数,在求和之前)我想我已经算出了:

import numpy as np
from scipy.special import factorial

def constant(theta1, theta2, theta0, x, y):

    exponential_part = np.exp(-(theta1 + theta2 + theta0)).reshape(-1, 1, 1)
    x = np.tile(x, (len(x), 1)).transpose()
    y = np.tile(y, (len(y), 1))

    double_factorial = (np.power(np.array(theta1).reshape(-1, 1, 1), x)/factorial(x)) * \
                       (np.power(np.array(theta2).reshape(-1, 1, 1), y)/factorial(y))

    return exponential_part * double_factorial

但是我在求和部分遇到了困难。我如何向量化一个极限取决于变量数组的求和呢?

zed5wv10

zed5wv101#

我想我已经弄明白了,基于@w-m建议的方法:根据出现的最大x或y值,计算可能出现的每个求和项,并使用掩码去掉不需要的求和项。假设您的x和y项从0到N,按连续顺序排列,这将计算比实际需要多三倍的项,但这可以通过使用矢量化来弥补。

参考实现

我首先写了一个纯Python的参考实现,它只是用循环来实现你的问题。有了4个嵌套的循环,它不是很快,但是在测试numpy版本的时候很方便。

import numpy as np
from scipy.special import factorial, comb
import operator as op
from functools import reduce

def choose(n, r):
    # https://stackoverflow.com/a/4941932/530160
    r = min(r, n-r)
    numer = reduce(op.mul, range(n, n-r, -1), 1)
    denom = reduce(op.mul, range(1, r+1), 1)
    return numer // denom  # or / in Python 2

def reference_impl_constant(s_theta1, s_theta2, s_theta0, s_x, s_y):
    # Cast to float to prevent overflow
    s_theta1 = float(s_theta1)
    s_theta2 = float(s_theta2)
    s_theta0 = float(s_theta0)
    s_x = float(s_x)
    s_y = float(s_y)
    term1 = np.exp(-(s_theta1 + s_theta2 + s_theta0))
    term2 = (s_theta1 ** s_x / factorial(s_x))
    term3 = (s_theta2 ** s_y / factorial(s_y))
    assert term1 >= 0
    assert term2 >= 0
    assert term3 >= 0
    return term1 * term2 * term3

def reference_impl_constant_loop(theta1, theta2, theta0, x, y):
    theta_len = theta1.shape[0]
    xy_len = x.shape[0]
    constant_array = np.zeros((theta_len, xy_len, xy_len))
    for i in range(theta_len):
        for j in range(xy_len):
            for k in range(xy_len):
                s_theta1 = theta1[i]
                s_theta2 = theta2[i]
                s_theta0 = theta0
                s_x = x[j]
                s_y = y[k]
                constant_term = reference_impl_constant(s_theta1, s_theta2, s_theta0, s_x, s_y)
                assert constant_term >= 0
                constant_array[i, j, k] = constant_term
    return constant_array

def reference_impl_summation(s_theta1, s_theta2, s_theta0, s_x, s_y):
    sum_ = 0
    for i in range(min(s_x, s_y) + 1):
        sum_ += choose(s_x, i) * choose(s_y, i) * factorial(i) * ((s_theta0/s_theta1/s_theta2) ** i)
    assert sum_ >= 0
    return sum_

def reference_impl_summation_loop(theta1, theta2, theta0, x, y):
    theta_len = theta1.shape[0]
    xy_len = x.shape[0]
    summation_array = np.zeros((theta_len, xy_len, xy_len))
    for i in range(theta_len):
        for j in range(xy_len):
            for k in range(xy_len):
                s_theta1 = theta1[i]
                s_theta2 = theta2[i]
                s_theta0 = theta0
                s_x = x[j]
                s_y = y[k]
                summation_term = reference_impl_summation(s_theta1, s_theta2, s_theta0, s_x, s_y)
                assert summation_term >= 0
                summation_array[i, j, k] = summation_term
    return summation_array

def reference_impl(theta1, theta2, theta0, x, y):
    # all array inputs must be 1D
    assert len(theta1.shape) == 1
    assert len(theta2.shape) == 1
    assert len(x.shape) == 1
    assert len(y.shape) == 1
    # theta vectors must have same length
    theta_len = theta1.shape[0]
    assert theta2.shape[0] == theta_len
    # x and y must have same length
    xy_len = x.shape[0]
    assert y.shape[0] == xy_len
    # theta0 is scalar
    assert isinstance(theta0, (int, float))
    constant_array = np.zeros((theta_len, xy_len, xy_len))
    output = np.zeros((theta_len, xy_len, xy_len))
    constant_array = reference_impl_constant_loop(theta1, theta2, theta0, x, y)
    summation_array = reference_impl_summation_loop(theta1, theta2, theta0, x, y)
    output = constant_array * summation_array
    return output

Numpy实作

我将其实现划分为两个函数。
函数的作用是:计算求和符号左边的所有内容。函数的作用是:计算求和符号内的所有内容。

import numpy as np
from scipy.special import factorial, comb

def fast_summation(theta1, theta2, theta0, x, y):
    x = np.tile(x, (len(x), 1)).transpose()
    y = np.tile(y, (len(y), 1))
    sum_limit = np.minimum(x, y)
    max_sum_limit = np.max(sum_limit)
    i = np.arange(max_sum_limit + 1).reshape(-1, 1, 1)
    summation_mask = (i <= sum_limit)
    theta_ratio = (theta0 / (theta1 * theta2)).reshape(-1, 1, 1, 1)
    theta_to_power = np.power(theta_ratio, i)
    terms = comb(x, i) * comb(y, i) * factorial(i) * theta_to_power
    # mask out terms which aren't part of sum
    terms *= summation_mask
    # axis 0 is theta
    # axis 1 is i
    # axis 2 & 3 are x and y
    # so sum across axis 1
    terms = terms.sum(axis=1)
    
    return terms

def fast_constant(theta1, theta2, theta0, x, y):
    theta1 = theta1.astype('float64')
    theta2 = theta2.astype('float64')
    exponential_part = np.exp(-(theta1 + theta2 + theta0)).reshape(-1, 1, 1)
    # x and y must be 1D
    assert len(x.shape) == 1
    assert len(y.shape) == 1
    # x and y must have same shape
    assert x.shape == y.shape
    x_len, y_len = x.shape[0], y.shape[0]
    x = x.reshape((x_len, 1))
    y = y.reshape((1, y_len))

    double_factorial = (np.power(np.array(theta1).reshape(-1, 1, 1), x)/factorial(x)) * \
                       (np.power(np.array(theta2).reshape(-1, 1, 1), y)/factorial(y))

    return exponential_part * double_factorial

def fast_impl(theta1, theta2, theta0, x, y):
    return fast_summation(theta1, theta2, theta0, x, y) * fast_constant(theta1, theta2, theta0, x, y)

基准测试

假设X和Y的范围从0到20,并且theta位于该范围内的某个位置,我得到的结果是numpy版本比纯python参考快大约280倍。

数值稳定性

我不确定它的数值稳定性如何。例如,当我将theta设为100时,我会得到一个浮点溢出。通常,当计算一个包含大量选择和阶乘表达式的表达式时,你会使用一些数学等价物,这会导致较小的中间和。在这种情况下,我对数学的了解非常少,我不知道你会怎么做。

相关问题