numpy 在Python中应该如何使用生成器?

tez616oj  于 2022-12-23  发布在  Python
关注(0)|答案(1)|浏览(153)

我尝试创建一个简单的python函数来定位函数的增加点和减少点(在线性代数中)。

import numpy as np

from pytest import approx

def gen_inc_dec_point(y_data: np.ndarray):
    i_start = 0
    i_stop = 0
    it = np.nditer(y_data)
    it2 = np.nditer(y_data)
    y_start = next(it2)
    y_stop = next(it)
    y = next(it2)

    # flag 1 means y values of the function is increasing
    # flag 0 indicate that there is no change in the values of y (constant)
    # flag -1 indicate the y values is decreasing
    flag = 0

    while True:
        try:
            if y_start == approx(y):
                flag = 0
                while y_stop == approx(y):
                    y_stop = next(it)
                    y = next(it2)
                    i_stop += 1
                yield [i_start, i_stop, flag]
                y_start = np.copy(y_stop)
                i_start = i_stop
                
            elif y_start < y:
                flag = 1
                while y_stop < y:
                    y_stop = next(it)
                    y = next(it2)
                    i_stop += 1
                yield [i_start, i_stop, flag]
                y_start = np.copy(y_stop)
                i_start = i_stop

            else:
                flag = -1
                while y_stop > y:
                    y_stop = next(it)
                    y = next(it2)
                    i_stop += 1
                yield [i_start, i_stop, flag]
                y_start = np.copy(y_stop)
                i_start = i_stop

        except StopIteration:
            yield [i_start, i_stop, flag]
            break

我用pytest做了几个测试,结果都和我想的一样,但后来我决定重构它(代码重复很少),结果如下:

import numpy as np

from pytest import approx

def gen_inc_dec_point(y_data: np.ndarray):
    i_start = 0
    i_stop = 0
    it = np.nditer(y_data)
    it2 = np.nditer(y_data)
    y_start = next(it)
    y_stop = next(it2)
    y = next(it2)

    # flag 1 means y values of the function is increasing
    # flag 0 indicate that there is no change in the values of y (constant)
    # flag -1 indicate the y values is decreasing
    flag = 0

    def advance_it(f_test):
        nonlocal i_start, i_stop, it, it2, y_start, y_stop, y, flag
        while f_test(y_stop, y):
            y_stop = next(it)
            y = next(it2)
            i_stop += 1
        yield [i_start, i_stop, flag]
        y_start = np.copy(y_stop)
        i_start = i_stop

    while True:
        try:
            if y_start == approx(y):
                flag = 0
                advance_it(lambda a, b: a == approx(b))
                
            elif y_start < y:
                flag = 1
                advance_it(lambda a, b: a < b)

            else:
                flag = -1
                advance_it(lambda a, b: a > b)

        except StopIteration:
            yield [i_start, i_stop, flag]
            break

但是它不起作用。因为声明在nonlocal中的变量好像从来没有更新过,所以出现了一个无限循环。你们能帮我找出我的代码出了什么问题吗?
下面是一些测试数据,以备各位决定测试。(摘自Stewart,J.,Redlin,L.,&沃森,S.(2016).代数和三角学第四版. Cengage Learning.第209-211页)

def gen_data():
    x = np.hstack((
        np.linspace(-2.5, -1, endpoint=False),
        np.linspace(-1, 0, endpoint=False),
        np.linspace(0, 2, endpoint=False),
        np.linspace(2, 3.5)
    ))
    fx = 12 * x**2 + 4 * x**3 - 3 * x**4
    return x, fx

从书上看,它说函数应该以x间隔增加(-inf,-1),以x间隔减小(-1,0),增加至(0,2),最终在以下位置下降(2,inf).使用numpy,我测试x在范围内(-2.5,3.5)。从已经用pandas导出到csv的x数据和y数据中,我观察到x数据和y数据的索引为:

expected_results = [
    [0, 50, 1],
    [50, 100, -1],
    [100, 150, 1],
    [150, 199, -1]]

例如,expected_results[0]表示x从x[0]增加到x[50]

nmpmafwu

nmpmafwu1#

基于rici的评论,我决定编辑我的代码;它起作用了。它是这样的:

import numpy as np

from pytest import approx

def gen_inc_dec_point(y: np.ndarray):
    i_start = 0
    i_stop = 0
    it = np.nditer(y)
    it2 = np.nditer(y)
    y_start = next(it)
    y_stop = next(it2)
    y = next(it2)

    # flag 1 means y values of the function is increasing
    # flag 0 indicate a y values is constants
    # flag -1 incidcate a y values is decreasing
    flag = 0

    def advance_it(f_test):
        nonlocal i_start, i_stop, y_start, y_stop, y
        while f_test(y_stop, y):
            i_stop += 1
            y_stop = next(it)
            y = next(it2)
        result = [i_start, i_stop, flag]
        y_start = np.copy(y_stop)
        i_start = i_stop
        return result

    while True:
        try:
            if y_start == approx(y):
                flag = 0
                yield advance_it(lambda a, b: a == approx(b))
                
            elif y_start < y:
                flag = 1
                yield advance_it(lambda a, b: a < b)

            else:
                flag = -1
                yield advance_it(lambda a, b: a > b)

        except StopIteration:
            yield [i_start, i_stop, flag]
            break

谢谢你的提醒,顺便说一下,it, it2, flag似乎不必声明为nonlocal,因为它们没有赋值语句。

相关问题