重写numpy函数以处理2D和3D输入

aij0ehis  于 2024-01-08  发布在  其他
关注(0)|答案(1)|浏览(194)

我正在尝试重写一个numpy函数,使其能够处理2d和3d输入。考虑以下代码:

import numpy as np
def adstock_geometric(x: np.array, theta: np.array):
    x_decayed = np.zeros_like(x)
    x_decayed[0] = x[0]

    for xi in range(1, len(x_decayed)):
        x_decayed[xi] = x[xi] + theta * x_decayed[xi - 1]

    return x_decayed

def testrun():
    rand3d = np.random.randint(0, 10, size=(4, 1000, 1)) / 10
    rand2d = np.random.randint(0, 10, size=(1000, 1)) / 10

    x = np.ones(10)

    output1d = adstock_geometric(x=x, theta=0.5)  # works fine
    output3d = adstock_geometric(x=x, theta=rand3d)
    output2d = adstock_geometric(x=x, theta=rand2d)
    
if __name__ == '__main__':
    testrun()

字符串
正如你所看到的,它在1d的情况下工作正常,但在2d和3d的情况下不工作。你可以想象theta是在第三维中堆叠的。预期的输出形状2d:(1000,10)预期的输出形状3d:(4,1000,10)
最明显的是在所有维度中穿梭,但这真的很慢。

kqqjbcuj

kqqjbcuj1#

这里有一个方法来解决这个问题:

def addstock_geom(x, theta):
    axis = n if (n := len(np.array(theta).shape) - 1) > 0 else None
    return ((theta**np.arange(x.size)) * x).cumsum(axis)

字符串
尝试使用以下方法

x = np.ones(10)
rand1d = 0.5
rand2d = np.random.randint(0, 10, size=(5, 1)) / 10
rand3d = np.random.randint(0, 10, size=(4, 2, 1)) / 10
addstock_geom(x, rand1d)
addstock_geom(x, rand2d)
addstock_geom(x, rand3d)

相关问题