numpy 使用元组的多索引

lg40wkob  于 12个月前  发布在  其他
关注(0)|答案(4)|浏览(113)

我有一个多维np.array。我知道前N维和后M维的形状。例如,

>>> n = (3,4,5)
>>> m = (6,)
>>> a = np.ones(n + m)
>>> a.shape
(3, 4, 5, 6)

字符串
使用元组作为索引可以快速索引前N个维度,如

>>> i = (1,1,2)
>>> a[i].shape
(6,)


使用列表并不能给予我所需要的结果

>>> i = [1,1,2]
>>> a[i].shape
(3, 4, 5, 6)


但是我在做多索引(既检索又赋值)时遇到了麻烦。例如,

>>> i = (1,1,2)
>>> j = (2,2,2)


我需要通过一些东西,

>>> a[[i, j]]


并得到(2, 6)的输出形状。
但我却得到了

>>> a[[i, j]].shape
(2, 3, 4, 5, 6)


>>> a[(i, j)].shape
(3, 5, 6)


我总是可以循环或改变我索引的方式(比如使用np.reshapenp.unravel_index),但是有没有更pythonic的方法来实现我所需要的?

EDIT我需要任何数量的索引,例如,

>>> i = (1,1,2)
>>> j = (2,2,2)
>>> k = (0,0,0)
...

ruarlubt

ruarlubt1#

考虑一个索引列表:

idx = [
    (1, 1, 2),  # Your i
    (2, 2, 2),  # Your j
    (0, 0, 0),  # Your k
    (1, 2, 1),  # ... 
    (2, 0, 1),  # extend as necessary
]

字符串
和形状为(3, 4, 5, 6)的数组a
当你写out = a[idx]时,numpy会这样解释:

out = np.array([
    [a[1], a[1], a[2]],
    [a[2], a[2], a[2]],
    [a[0], a[0], a[0]],
    [a[1], a[2], a[1]],
    [a[2], a[0], a[1]],
])


其中,例如a[0],只是a的第一个子阵列,因此具有形状(4, 5, 6)
结果,你得到的是一个形状为(5, 3)的数组(索引的形状),包含a!的(4, 5, 6)子数组(......最终结果是(5, 3, 4, 5, 6)np.shape(idx) + a.shape[1:])。
相反,您想要的是以下内容:

out = np.array([
    a[1, 1, 2],
    a[2, 2, 2],
    a[0, 0, 0],
    a[1, 2, 1],
    a[2, 0, 1],
])


在numpy中实现“矢量化”的方法如下:

out = a[
    [1, 2, 0, 1, 2],  # [idx[0][0], idx[1][0], idx[2][0], ...]
    [1, 2, 0, 2, 0],  # [idx[0][1], idx[1][1], idx[2][1], ...]
    [2, 2, 0, 1, 1]   # [idx[0][2], idx[1][2], idx[2][2], ...]
]


该行为记录在索引指南中:
高级索引始终作为一个索引广播和迭代:

result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M],
                           ..., ind_N[i_1, ..., i_M]]


要将原始idx转换为这样的索引器,可以使用tuple(zip(*idx))技巧。
Numpy的索引系统非常灵活,但这种灵活性的代价是这些“简单”的任务变得不直观......至少在我看来;)

djmepvbi

djmepvbi2#

提取每个选择,然后将它们重新组合成一个新的数组?

>>> np.array([a[i], a[j]]).shape
(2, 6)

字符串

zlwx9yxi

zlwx9yxi3#

我测试了ShadowRanger和Adulphylaxs的解决方案,并将它们与我最初的解进行了比较。我对它们进行了计时,Adulphylaxs的是最快的。

import time
import numpy as np

n = (3, 4, 5)
m = (2, 4, 6)
a = np.random.rand(*(n + m))

indices = [tuple(np.random.randint(n)) for _ in range(100)]

test = 100000

t1 = time.time()

for _ in range(test):
  x = a[tuple(zip(*indices))]  

t2 = time.time()

for _ in range(test):
  y = np.array([a[i] for i in indices])

t3 = time.time()

for _ in range(test):
  b = a.reshape(np.prod(n), np.prod(m))
  j = [np.ravel_multi_index(i, n) for i in indices]
  z = b[j].reshape(-1, *m)

t4 = time.time()

个字符

7kjnsjlb

7kjnsjlb4#

让我们学究气,以澄清发生了什么,在每一个案件。

In [19]: >>> n = (3,4,5)
    ...: >>> m = (6,)

字符串
加上元组(以及列表和字符串)加入它们:

In [20]: n+m
Out[20]: (3, 4, 5, 6)

In [21]: a=np.ones(n+m)


使用元组索引与单独输入每个标量是一样的。是逗号组成了元组,而不是()。实际上是解释器将元组传递给对象;它是对象自己的getitem方法来解释它。(如果你给列表给予一个元组,列表会抱怨,但数组喜欢它:))。

In [22]: a[1,1,2].shape
Out[22]: (6,)


有一个自动跟踪切片。在下面我将包括这一点,以明确(呃):

In [23]: a[1,1,2,:].shape
Out[23]: (6,)

In [24]: >>> i = (1,1,2)
    ...: >>> j = (2,2,2)


[i,j]被转换为数组,(2,3)形状:

In [25]: np.array([i,j])
Out[25]: 
array([[1, 1, 2],
       [2, 2, 2]])


所以这个数组只用于索引第一维。其余的都是跟踪。

In [26]: a[np.array([i,j]),:,:,:].shape
Out[26]: (2, 3, 4, 5, 6)


元组的元组:

In [27]: (i,j)
Out[27]: ((1, 1, 2), (2, 2, 2))


内部元组被转换为列表,或者更确切地说,转换为数组。因此,两个索引一起广播,选择一个(3,)形状(将其视为“对角线”)

In [29]: a[[1,1,2],[2,2,2],:,:].shape
Out[29]: (3, 5, 6)


加上ka[i,j,k]将得到给予(3,6)形状。
我不知道你的i,j应该如何产生(2,6)形状
等等,也许这相当于

In [32]: np.stack([a[i],a[j]]).shape
Out[32]: (2, 6)


或者等效地用np.array连接两个选择。
就像experiment一样,这里有一个等价物:

In [45]: b=np.arange(3*4*5*6).reshape(a.shape)

In [46]: c=np.stack([b[i],b[j]])
In [47]: d=b[[1,2],[1,2],[2,2]]
In [48]: np.allclose(c,d)
Out[48]: True


所以我们需要将i,j元组转换为这个pairs元组。

In [55]: tuple([[i1,j1] for i1,j1 in zip(i,j)])
Out[55]: ([1, 2], [1, 2], [2, 2])
In [56]: tuple(np.array([i,j]).T.tolist())
Out[56]: ([1, 2], [1, 2], [2, 2])
In [57]: tuple(np.stack([i,j],1).tolist())
Out[57]: ([1, 2], [1, 2], [2, 2])


用第三个元组

In [58]: k=(0,0,0)    
In [59]: tuple(np.stack([i,j,k],1).tolist())
Out[59]: ([1, 2, 0], [1, 2, 0], [2, 2, 0])


我们不需要tolist,尽管它相当快:

In [61]: b[tuple(np.stack([i,j,k],1))]
Out[61]: 
array([[162, 163, 164, 165, 166, 167],
       [312, 313, 314, 315, 316, 317],
       [  0,   1,   2,   3,   4,   5]])

相关问题