我正在尝试重写一个numpy函数,使其能够处理2d和3d输入。考虑以下代码:
import numpy as np
def adstock_geometric(x: np.array, theta: np.array):
x_decayed = np.zeros_like(x)
x_decayed[0] = x[0]
for xi in range(1, len(x_decayed)):
x_decayed[xi] = x[xi] + theta * x_decayed[xi - 1]
return x_decayed
def testrun():
rand3d = np.random.randint(0, 10, size=(4, 1000, 1)) / 10
rand2d = np.random.randint(0, 10, size=(1000, 1)) / 10
x = np.ones(10)
output1d = adstock_geometric(x=x, theta=0.5) # works fine
output3d = adstock_geometric(x=x, theta=rand3d)
output2d = adstock_geometric(x=x, theta=rand2d)
if __name__ == '__main__':
testrun()
字符串
正如你所看到的,它在1d的情况下工作正常,但在2d和3d的情况下不工作。你可以想象theta是在第三维中堆叠的。预期的输出形状2d:(1000,10)预期的输出形状3d:(4,1000,10)
最明显的是在所有维度中穿梭,但这真的很慢。
1条答案
按热度按时间kqqjbcuj1#
这里有一个方法来解决这个问题:
字符串
尝试使用以下方法
型