NumPy -检查二维数组中的“5 in a row”模式

rhfm7lfc  于 2023-11-18  发布在  其他
关注(0)|答案(2)|浏览(117)

我正在尝试写一个函数,检查在一个二维数组中,在行、列和对角线方向上,是否至少有一个数字是“一行五个”。要检查的数组可以是大小>= 5的任何方阵,但我最有可能使用的是7 x7的。
例如,下面的矩阵在一个序列中有3次出现五个1的模式(至少检测一个就足够了)。一个在第一列,一个对角地从(0,6)到(5,1),另一个对角地从(1,0)到(5,6)。

A = np.array(
[
    [0, 1, 0, 0, 0, 0],
    [1, 0, 1, 0, 1, 0],
    [1, 0, 0, 1, 0, 0],
    [1, 0, 1 ,0, 1, 0],
    [1, 1, 0, 0, 0, 1],
    [1, 0, 0, 0, 0, 0]
]
)

字符串
我在这方面的尝试如下所示

import numpy as np
import cProfile
import pstats
import time

class Board:
    def __init__(self, size):
        self.data = np.zeros((size, size), dtype=np.byte)
        self.size = size
        # make sliding window of size 5 for each row and col
        self.rowWindows = np.lib.stride_tricks.sliding_window_view(self.data, window_shape=(1,5))
        self.colWindows = np.lib.stride_tricks.sliding_window_view(np.transpose(self.data), window_shape=(1,5))

        # make sliding window for both diagonal directions
        # storing them as object array since the diagonals windows have different sizes
        self.antiDiagonalWindow = np.array(
            [
                np.lib.stride_tricks.sliding_window_view(np.fliplr(self.data).diagonal(offset=i), window_shape=(5,))
                for i in range(-self.size + 5, self.size - 5 + 1, 1)
            ],
            dtype=object,
        )
        self.diagonalWindow = np.array(
            [
                np.lib.stride_tricks.sliding_window_view(self.data.diagonal(offset=i), window_shape=(5,))
                for i in range(-self.size + 5, self.size - 5 + 1, 1)
            ],
            dtype=object,
        )

    def hasFiveInRow(self, value):
        return (
           np.any(np.all(self.rowWindows == value, -1),)
            or np.any(np.all(self.colWindows == value,-1), )
            # have to use concat to turn sliding window views into 2d array
            # since diagonals have different sizes
            or np.any(np.all(np.concatenate(self.antiDiagonalWindow) == value, -1), )
            or  np.any(np.all(np.concatenate(self.diagonalWindow) == value, -1), )
        )

def benchMark():
    b = Board(size=7)
    b.data[:]=np.random.randint(low=0, high=3, size=(7,7))

    for i in range(100_000):
        val = b.hasFiveInRow(1)

# t0 = time.time()
# benchMark()
# print(time.time() - t0)
with cProfile.Profile() as p:
    benchMark()
    res = pstats.Stats(p)
    res.sort_stats(pstats.SortKey.TIME)
    res.print_stats()


结果性能不是太差,但我想提高它,如果可能的话,因为我使用它作为一个游戏ai树搜索的一部分,将不得不调用函数非常大量的次数。我认为np.any(np.all(windows))是不理想的,因为它必须创建许多布尔数组减少到一个单一的值。
cProfile日志显示了大量对'reduce'、'dictcomp'和_wrapreduction'等的调用,这些调用需要很长时间才能完成。
有没有更好的方法来寻找这个模式呢?我只需要检查这个模式是否以布尔值的形式至少出现过一次,尽管得到确切的位置和出现的次数会很好。
任何帮助将不胜感激!

uujelgoq

uujelgoq1#

我认为这是numba闪耀的场景:

from numba import njit

@njit
def check(A, value):
    for row in range(A.shape[0]):
        for col in range(A.shape[1]):
            if A[row, col] != value:
                continue

            # check row
            if col < (A.shape[1] - 4):
                if (
                    A[row, col]
                    == A[row, col + 1]
                    == A[row, col + 2]
                    == A[row, col + 3]
                    == A[row, col + 4]
                ):
                    return True

            # check column
            if row < (A.shape[0] - 4):
                if (
                    A[row, col]
                    == A[row + 1, col]
                    == A[row + 2, col]
                    == A[row + 3, col]
                    == A[row + 4, col]
                ):
                    return True

            # check diagonal 1
            if col < (A.shape[1] - 4) and row < (A.shape[0] - 4):
                if (
                    A[row, col]
                    == A[row + 1, col + 1]
                    == A[row + 2, col + 2]
                    == A[row + 3, col + 3]
                    == A[row + 4, col + 4]
                ):
                    return True

            # check diagonal 2
            if col > 3 and row < (A.shape[0] - 4):
                if (
                    A[row, col]
                    == A[row + 1, col - 1]
                    == A[row + 2, col - 2]
                    == A[row + 3, col - 3]
                    == A[row + 4, col - 4]
                ):
                    return True

    return False

字符串
快速基准:

from statistics import median
from timeit import repeat

import numpy as np
from numba import njit

@njit
def check(A, value):
    for row in range(A.shape[0]):
        for col in range(A.shape[1]):
            if A[row, col] != value:
                continue

            # check row
            if col < (A.shape[1] - 4):
                if (
                    A[row, col]
                    == A[row, col + 1]
                    == A[row, col + 2]
                    == A[row, col + 3]
                    == A[row, col + 4]
                ):
                    return True

            # check column
            if row < (A.shape[0] - 4):
                if (
                    A[row, col]
                    == A[row + 1, col]
                    == A[row + 2, col]
                    == A[row + 3, col]
                    == A[row + 4, col]
                ):
                    return True

            # check diagonal 1
            if col < (A.shape[1] - 4) and row < (A.shape[0] - 4):
                if (
                    A[row, col]
                    == A[row + 1, col + 1]
                    == A[row + 2, col + 2]
                    == A[row + 3, col + 3]
                    == A[row + 4, col + 4]
                ):
                    return True

            # check diagonal 2
            if col > 3 and row < (A.shape[0] - 4):
                if (
                    A[row, col]
                    == A[row + 1, col - 1]
                    == A[row + 2, col - 2]
                    == A[row + 3, col - 3]
                    == A[row + 4, col - 4]
                ):
                    return True

    return False

class Board:
    def __init__(self, size):
        self.data = np.zeros((size, size), dtype=np.byte)
        self.size = size
        # make sliding window of size 5 for each row and col
        self.rowWindows = np.lib.stride_tricks.sliding_window_view(
            self.data, window_shape=(1, 5)
        )
        self.colWindows = np.lib.stride_tricks.sliding_window_view(
            np.transpose(self.data), window_shape=(1, 5)
        )

        # make sliding window for both diagonal directions
        # storing them as object array since the diagonals windows have different sizes
        self.antiDiagonalWindow = np.array(
            [
                np.lib.stride_tricks.sliding_window_view(
                    np.fliplr(self.data).diagonal(offset=i), window_shape=(5,)
                )
                for i in range(-self.size + 5, self.size - 5 + 1, 1)
            ],
            dtype=object,
        )
        self.diagonalWindow = np.array(
            [
                np.lib.stride_tricks.sliding_window_view(
                    self.data.diagonal(offset=i), window_shape=(5,)
                )
                for i in range(-self.size + 5, self.size - 5 + 1, 1)
            ],
            dtype=object,
        )

    def hasFiveInRow(self, value):
        return (
            np.any(
                np.all(self.rowWindows == value, -1),
            )
            or np.any(
                np.all(self.colWindows == value, -1),
            )
            # have to use concat to turn sliding window views into 2d array
            # since diagonals have different sizes
            or np.any(
                np.all(np.concatenate(self.antiDiagonalWindow) == value, -1),
            )
            or np.any(
                np.all(np.concatenate(self.diagonalWindow) == value, -1),
            )
        )

board = Board(size=7)
# len numba compile the check() function
check(board.data, 1)

t1 = repeat(
    "check(board.data, 1)",
    setup="board.data[:] = np.random.randint(low=0, high=3, size=(7, 7))",
    number=1,
    repeat=100_000,
    globals=globals(),
)
t2 = repeat(
    "board.hasFiveInRow(1)",
    setup="board.data[:] = np.random.randint(low=0, high=3, size=(7, 7))",
    number=1,
    repeat=100_000,
    globals=globals(),
)

print(f"{median(t1) * 1_000_000} us")
print(f"{median(t2) * 1_000_000} us")


在我的计算机上打印(AMD 5700 x,Python 3.11):

0.4700850695371628 us
27.799978852272034 us


所以~ 60倍加速。

1mrurvl1

1mrurvl12#

效率是关键--循环经常击败花哨的函数。我将“五行”的逻辑修改为一个简单的嵌套迭代。
它不再使用滑动窗口,而是手动检查每个单元格与其八个相邻单元格之间的关系,没有冗余。
将模式检查 Package 在一个干净的函数中,以保持主循环整洁。
结果证明了这一点--原来的几分钟现在变成了几秒钟。小的调整可以带来大的性能提升。

import numpy as np

class Board:
    def __init__(self, size):
        self.size = size
        self.data = np.zeros((size, size), dtype=np.int8)

    def hasFiveInRow(self, value):
        # Check rows
        for i in range(self.size):
            for j in range(self.size - 4):
                if np.all(self.data[i, j:j+5] == value):
                    return True

        # Check columns
        for i in range(self.size - 4):
            for j in range(self.size):
                if np.all(self.data[i:i+5, j] == value):
                    return True

        # Check diagonals
        for i in range(self.size - 4):
            for j in range(self.size - 4):
                if np.all(np.diagonal(self.data[i:i+5, j:j+5]) == value) or np.all(np.diagonal(self.data[i:i+5, j:j+5][::-1, :]) == value):
                    return True

        return False

def benchmark():
    b = Board(size=7)
    b.data[:] = np.random.randint(low=0, high=3, size=(7, 7))

    for i in range(100_000):
        val = b.hasFiveInRow(1)

if __name__ == "__main__":
    import cProfile
    with cProfile.Profile() as p:
        benchmark()
        res = pstats.Stats(p)
        res.sort_stats(pstats.SortKey.TIME)
        res.print_stats()

字符串

相关问题