numpy中是否有内置函数在3d数组中迭代advanced

5lhxktic  于 2022-12-18  发布在  其他
关注(0)|答案(1)|浏览(88)

我想创建一个函数,该函数将一个数组作为其第一个参数,它将接受一个任意大小和形状的arr数组,并覆盖该数组在给定[a,b]区间内的所有值,使其等于c。a,b,c的数字将作为parameters.like的输入和输出提供给该函数,如下所示

arr = np.array([[[5., 2., -5.], [4., 3., 1.]]])
overwrite_interval(arr, -2., 2., 100.) -> ndarray([[[5., 100., -5.], [4., 3., 100.]]])

def overwrite_interval(arr , a , b , c):  
  for i in arr[:,:]: 
  arr[a,b] = c

arr = np.array([[[5., 2., -5.], [4., 3., 1.]]])
assert overwrite_interval(arr, -2., 2., 100.) #-> ndarray([[[5., 100., -5.], [4., 3., 100.]]])
7hiiyaii

7hiiyaii1#

我认为你提问的方式与你给出的例子不一致。首先,你给出的例子数组是3D的,不是2D的。你可以

>>> arr.shape
(1,2,3)
>>> arr.ndim
3

这可能是个错误,并且您希望数组是2D的,因此您可以

arr = np.array([[5., 2., -5.], [4., 3., 1.]])

而不是。
其次,如果ab是这样的值,如果一个元素介于两者之间,则将该元素设置为值c,而不是将ab作为索引,则np.where函数非常适合于此。

def overwrite_interval(arr , a , b , c):
    inds = np.where((arr >= a) * (arr <= b))
    arr[inds] = c
    return arr


np.where返回一个元组,所以有时直接使用布尔数组会更容易。

def overwrite_interval(arr , a , b , c):
    inds = (arr >= a) * (arr <= b)
    arr[inds] = c
    return arr

这对你有用吗?这是你想要的意思吗?注意,如果你仍然想要初始数组是一个3D数组,我提供的解决方案将照常工作。

相关问题