numpy 在Python中应该如何使用生成器?

tez616oj  于 2022-12-23  发布在  Python


import numpy as np

from pytest import approx

def gen_inc_dec_point(y_data: np.ndarray):
    i_start = 0
    i_stop = 0
    it = np.nditer(y_data)
    it2 = np.nditer(y_data)
    y_start = next(it2)
    y_stop = next(it)
    y = next(it2)

    # flag 1 means y values of the function is increasing
    # flag 0 indicate that there is no change in the values of y (constant)
    # flag -1 indicate the y values is decreasing
    flag = 0

    while True:
            if y_start == approx(y):
                flag = 0
                while y_stop == approx(y):
                    y_stop = next(it)
                    y = next(it2)
                    i_stop += 1
                yield [i_start, i_stop, flag]
                y_start = np.copy(y_stop)
                i_start = i_stop
            elif y_start < y:
                flag = 1
                while y_stop < y:
                    y_stop = next(it)
                    y = next(it2)
                    i_stop += 1
                yield [i_start, i_stop, flag]
                y_start = np.copy(y_stop)
                i_start = i_stop

                flag = -1
                while y_stop > y:
                    y_stop = next(it)
                    y = next(it2)
                    i_stop += 1
                yield [i_start, i_stop, flag]
                y_start = np.copy(y_stop)
                i_start = i_stop

        except StopIteration:
            yield [i_start, i_stop, flag]


import numpy as np

from pytest import approx

def gen_inc_dec_point(y_data: np.ndarray):
    i_start = 0
    i_stop = 0
    it = np.nditer(y_data)
    it2 = np.nditer(y_data)
    y_start = next(it)
    y_stop = next(it2)
    y = next(it2)

    # flag 1 means y values of the function is increasing
    # flag 0 indicate that there is no change in the values of y (constant)
    # flag -1 indicate the y values is decreasing
    flag = 0

    def advance_it(f_test):
        nonlocal i_start, i_stop, it, it2, y_start, y_stop, y, flag
        while f_test(y_stop, y):
            y_stop = next(it)
            y = next(it2)
            i_stop += 1
        yield [i_start, i_stop, flag]
        y_start = np.copy(y_stop)
        i_start = i_stop

    while True:
            if y_start == approx(y):
                flag = 0
                advance_it(lambda a, b: a == approx(b))
            elif y_start < y:
                flag = 1
                advance_it(lambda a, b: a < b)

                flag = -1
                advance_it(lambda a, b: a > b)

        except StopIteration:
            yield [i_start, i_stop, flag]

下面是一些测试数据,以备各位决定测试。(摘自Stewart,J.,Redlin,L.,&沃森,S.(2016).代数和三角学第四版. Cengage Learning.第209-211页)

def gen_data():
    x = np.hstack((
        np.linspace(-2.5, -1, endpoint=False),
        np.linspace(-1, 0, endpoint=False),
        np.linspace(0, 2, endpoint=False),
        np.linspace(2, 3.5)
    fx = 12 * x**2 + 4 * x**3 - 3 * x**4
    return x, fx


expected_results = [
    [0, 50, 1],
    [50, 100, -1],
    [100, 150, 1],
    [150, 199, -1]]





import numpy as np

from pytest import approx

def gen_inc_dec_point(y: np.ndarray):
    i_start = 0
    i_stop = 0
    it = np.nditer(y)
    it2 = np.nditer(y)
    y_start = next(it)
    y_stop = next(it2)
    y = next(it2)

    # flag 1 means y values of the function is increasing
    # flag 0 indicate a y values is constants
    # flag -1 incidcate a y values is decreasing
    flag = 0

    def advance_it(f_test):
        nonlocal i_start, i_stop, y_start, y_stop, y
        while f_test(y_stop, y):
            i_stop += 1
            y_stop = next(it)
            y = next(it2)
        result = [i_start, i_stop, flag]
        y_start = np.copy(y_stop)
        i_start = i_stop
        return result

    while True:
            if y_start == approx(y):
                flag = 0
                yield advance_it(lambda a, b: a == approx(b))
            elif y_start < y:
                flag = 1
                yield advance_it(lambda a, b: a < b)

                flag = -1
                yield advance_it(lambda a, b: a > b)

        except StopIteration:
            yield [i_start, i_stop, flag]

谢谢你的提醒,顺便说一下,it, it2, flag似乎不必声明为nonlocal,因为它们没有赋值语句。
