如何让numpy将对象数组切片解释为单个数组?

8zzbczxx  于 2023-03-18  发布在  其他
关注(0)|答案(1)|浏览(173)

当你让Numpy从一个包含任意对象的集合中创建一个数组时,它会创建一个“object”类型的数组,这允许你在这些对象之间使用索引切片,但是由于对象本身对Numpy来说是未知的,你不能一次索引到对象中(即使那个特定的对象**实际上是一个Numpy数组)。
但是,如果你对object数组进行切片,选择实际上是numpy数组的object数组部分,numpy似乎不会将切片折叠成一个numpy数组,即使再次调用np.array()也是如此。

>>> aa = np.array([np.random.randn(3, 4), {'something': 'blah'}], dtype=object)
>>> aa.shape
(2,)
>>> np.array(aa[0:1])
array([array([[ 1.78237043, -0.61082005,  0.92160137,  0.58961677],
              [ 1.54183639, -0.43097464,  1.36213935, -1.2695875 ],
              [ 0.01431181, -0.62073519,  0.56267489, -0.46113538]])],
      dtype=object)
>>> np.array(aa[0:1]).shape # I want this to be (1, 3, 4)
(1,)

是否有任何方法可以做到这一点,而不需要双重拷贝(例如,不像这样:np.array(aa[0:1].tolist()))?对象数组是否允许您在没有副本的情况下执行此操作?

nhn9ugyo

nhn9ugyo1#

您可以使用np.stack将对象类型数组组合为普通的ndarray

>>> aa = np.array([np.random.randn(3, 4), {'something': 'blah'}], dtype=object)
>>> aa
array([array([[-6.36267204e-01,  8.95707498e-02,  1.09275216e+00,
               -3.70594544e-01],
              [ 8.32865823e-01, -6.53876690e-01,  1.21000457e+00,
                1.22046398e+00],
              [-5.30262118e-01,  1.17934947e-04,  4.45156002e-01,
               -6.61549444e-02]])                                ,
       {'something': 'blah'}], dtype=object)
>>> np.stack(aa[0:1])
array([[[-6.36267204e-01,  8.95707498e-02,  1.09275216e+00,
         -3.70594544e-01],
        [ 8.32865823e-01, -6.53876690e-01,  1.21000457e+00,
          1.22046398e+00],
        [-5.30262118e-01,  1.17934947e-04,  4.45156002e-01,
         -6.61549444e-02]]])
>>> np.stack(aa[0:1]).shape
(1, 3, 4)

这也适用于对象数组中的多个ndarray,只要它们具有兼容的大小。
在内部,它只是把对象数组当作一个序列并在其上迭代,我不确定它是否比np.array(aa[0:1].tolist())的解决方案有显著的性能优势。

相关问题