Sort a multidimensional numpy array while keeping the contained 2D blocks intact

yzckvree posted on 2023-03-02 in Other

I hope someone can help me.
I have a 6-dimensional numpy array:

my_array = {ndarray: (256,256,256,4,3,3)}

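I want to sort along the axis of length 4 (axis 3) while leaving the 3 × 3 blocks untouched. In other words, I want to sort many 3 × 3 blocks, where every 4 of them always form a group.
As a small-scale example, suppose I have a similar array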

my_array = {ndarray: (256,256,256,4,2,2)}

Each of the 256 * 256 * 256 groups could look like this:

[[[2,3],[1,3]],
[[1,2],[3,2]],
[[1,4],[2,1]],
[[1,2],[3,4]]]

I want these blocks sorted like this:

[[[1,2],[3,2]],
[[1,2],[3,4]],
[[1,4],[2,1]],
[[2,3],[1,3]]]
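The desired order is lexicographic on the flattened block values. A minimal sketch for a single group, assuming "sorted" means comparing blocks element by element in row-major order (which matches the expected output above):

```python
import numpy as np

# One group from the example: 4 blocks, each of shape 2x2.
group = np.array([[[2, 3], [1, 3]],
                  [[1, 2], [3, 2]],
                  [[1, 4], [2, 1]],
                  [[1, 2], [3, 4]]])

# Flatten each 2x2 block into a row of 4 values: shape (4, 4).
rows = group.reshape(len(group), -1)

# np.lexsort treats the LAST key as the primary key, so the columns are
# passed in reverse order to make each block's first value the primary key.
order = np.lexsort(rows.T[::-1])

sorted_group = group[order]
print(order)  # [1 3 2 0]
```

Indexing with `order` moves whole blocks, so no block is ever split apart.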

For the simple case of a 2D array, I can do this with my_2darray[:,np.lexsort(my_2darray)] (it sorts the columns and keeps each column intact).
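One detail of the 2D trick: np.lexsort uses the last key, here the last row, as the primary sort key, so my_2darray[:, np.lexsort(my_2darray)] orders columns by their bottom value first. A small sketch with an illustrative matrix `m` showing both key orders:

```python
import numpy as np

m = np.array([[3, 1, 3],
              [2, 5, 1]])

# lexsort(m): row 1 is the primary key, row 0 breaks ties.
last_row_primary = m[:, np.lexsort(m)]

# Reversing the rows makes row 0 the primary key instead.
first_row_primary = m[:, np.lexsort(m[::-1])]

print(last_row_primary)
# [[3 3 1]
#  [1 2 5]]
print(first_row_primary)
# [[1 3 3]
#  [5 1 2]]
```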
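I tried np.sort(my_array, axis=3), but that sorts the individual values rather than the blocks. I also tried every variant of my_array[:,np.lexsort(my_array)] and similar, and nothing I found works. As a side note, I found that the axis I want to sort on with lexsort has to be the last one, otherwise it behaves strangely. That is easy to handle with np.swapaxes, but it still doesn't work in the higher-dimensional case. Does anyone have any useful insights?
Thanks, everyone!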
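On the side note: np.lexsort accepts an `axis` argument (default -1) that selects which axis of each key is sorted, so the group axis does not have to be moved to the end. A sketch with hypothetical random data in which the block axis comes first:

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical data: 5 groups of 4 blocks (2x2 each), with the axis
# of length 4 placed FIRST to show it need not be last.
x = rng.integers(0, 3, size=(4, 5, 2, 2))   # (blocks, groups, 2, 2)

# One key per flattened block element, reversed so the first element is
# the primary key. Each key has shape (4, 5).
flat = x.reshape(4, 5, -1)                   # (4, 5, 4)
keys = np.moveaxis(flat, -1, 0)[::-1]        # (4 keys, 4 blocks, 5 groups)

# axis=0 sorts along the block axis of every key; no swapaxes needed.
order = np.lexsort(keys, axis=0)             # (4, 5)

# Reorder whole blocks; the trailing None axes broadcast over each 2x2 block.
sorted_x = np.take_along_axis(x, order[:, :, None, None], axis=0)
```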

nzk0hqpo (answer #1)

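Technically you can use this solution, but applying it to 5 dimensions can get a bit complicated, so here is an implementation. Please verify it yourself before relying on it.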

import numpy as np

# Create a 5-dimensional array as input.
np.random.seed(0)
a = np.random.randint(0, 10, size=(2, 2, 3, 2, 2))
print("a:", a.shape)  # (2, 2, 3, 2, 2)
print(a)
# [[[
#     [[5, 0], [3, 3]],
#     [[7, 9], [3, 5]],
#     [[2, 4], [7, 6]],
# ...

# Flatten all axes except the axis you want to sort on.
# That is, make a 3-dimensional array of (N, sort-axis, M).
b = a.reshape([-1, a.shape[-3], a.shape[-2] * a.shape[-1]])
print("b:", b.shape)  # (4, 3, 4)
print(b)
# [[
#     [5, 0, 3, 3],
#     [7, 9, 3, 5],
#     [2, 4, 7, 6],
# ...

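# Then use lexsort with the last axis (the flattened block values) as the sort keys.
# lexsort treats the LAST key as the primary key, so the key list is reversed
# to make the first value of each block the primary key.
idx = np.lexsort([b[..., i] for i in range(b.shape[-1])][::-1])
idx = np.lexsort(np.moveaxis(b, -1, 0)[::-1])  # Same as above, but faster.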
print("idx:", idx.shape)  # (4, 3)
print(idx)
# [
#     [2, 0, 1],
# ...

# idx holds, for each group, the sorted order of its blocks. Apply it like this:
c = np.array([b[i][idx[i]] for i in range(len(b))])
c = b[np.arange(len(b))[:, np.newaxis], idx]  # This is the same as above, but faster.
print("c:", c.shape)  # (4, 3, 4)
print(c)
# [[
#     [2, 4, 7, 6],
#     [5, 0, 3, 3],
#     [7, 9, 3, 5],
# ...

# Restore to the original shape.
d = c.reshape(a.shape)
print("d:", d.shape)  # (2, 2, 3, 2, 2)
print(d)
# [[[
#     [[2, 4], [7, 6]],
#     [[5, 0], [3, 3]],
#     [[7, 9], [3, 5]],
# ...
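The steps above can be wrapped in a reusable function for the question's actual layout `(..., 4, 3, 3)`. A minimal sketch, assuming the axis to sort is the one immediately before the block axes; `sort_blocks` is a hypothetical name, and `np.take_along_axis` stands in for the `arange`-based fancy indexing (both select the same elements):

```python
import numpy as np

def sort_blocks(arr, block_ndim=2):
    """Hypothetical helper: sort the blocks on axis -(block_ndim + 1)
    lexicographically, keeping each trailing block of block_ndim axes intact."""
    n_blocks = arr.shape[-block_ndim - 1]
    block_size = int(np.prod(arr.shape[-block_ndim:]))
    # (N groups, G blocks per group, K values per block)
    flat = arr.reshape(-1, n_blocks, block_size)
    # Keys of shape (K, N, G), reversed so the first block value is primary.
    keys = np.moveaxis(flat, -1, 0)[::-1]
    order = np.lexsort(keys)  # (N, G), sorted along the block axis
    out = np.take_along_axis(flat, order[..., None], axis=1)
    return out.reshape(arr.shape)

rng = np.random.default_rng(1)
# Small stand-in for the question's (256, 256, 256, 4, 3, 3) array.
a = rng.integers(0, 3, size=(3, 2, 4, 3, 3))
s = sort_blocks(a)
print(s.shape)  # (3, 2, 4, 3, 3)
```

For the 2x2 example in the question, `sort_blocks(my_array)` works unchanged because `block_ndim=2` covers both 2x2 and 3x3 blocks.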
