python 保存函数中的所有中间变量,以防函数失败

yks3o0rb  于 2023-10-14  发布在  Python
关注(0)|答案(3)|浏览(67)

bounty还有5天到期。此问题的答案有资格获得+50声望奖励。Alex Lenail正在寻找一个规范答案

我发现自己经常遇到这类问题。我有一个函数,

def compute(input):
    result = two_hour_computation(input)
    result = post_processing(result)
    return result

post_processing(result)失败。很明显要做的是将函数改为

import pickle

def compute(input):
    result = two_hour_computation(input)
    pickle.dump(result, open('intermediate_result.pickle', 'wb'))
    result = post_processing(result)
    return result

但我通常不会记得把所有函数都写成那样。我希望我有一个像这样的室内设计师:

@return_intermediate_results_if_something_goes_wrong
def compute(input):
    result = two_hour_computation(input)
    result = post_processing(result)
    return result

这样的东西存在吗?我在google上找不到。

p8h8hvxi

p8h8hvxi1#

函数的“外部”在运行时无法访问函数内部的局部变量的状态。所以这个问题不能用装饰器来解决。
在任何情况下,我都认为捕捉错误和保存有价值的中间结果的责任应该由程序员明确地完成。如果你“忘记”了做这件事,那对你来说一定不那么重要。
话虽如此,像 “在A、B或C引发异常的情况下执行X” 这样的情况是上下文管理器的典型用例。您可以编写自己的上下文管理器,它充当中间结果的存储桶(代替变量),并在异常退出时执行一些save操作。
大概是这样的:

from __future__ import annotations
from types import TracebackType
from typing import Generic, Optional, TypeVar

T = TypeVar("T")

class Saver(Generic[T]):
    def __init__(self, initial_value: Optional[T] = None) -> None:
        self._value = initial_value

    def __enter__(self) -> Saver[T]:
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        if exc_type is not None:
            self.save()

    def save(self) -> None:
        print(f"saved {self.value}!")

    @property
    def value(self) -> T:
        if self._value is None:
            raise RuntimeError
        return self._value

    @value.setter
    def value(self, value: T) -> None:
        self._value = value

显然,你可以这样做,而不是在save中使用print(f"saved {self.value}!")

with open('intermediate_result.pickle', 'wb') as f:
            pickle.dump(self.value, f)

现在,您需要记住的是将这些操作 Package 在with-语句中,并将中间结果分配给上下文管理器的value属性。演示:

def x_times_2(x: float) -> float:
    return x * 2

def one_over_x_minus_2(x: float) -> float:
    return 1 / (x - 2)

def main() -> None:
    with Saver(1.) as s:
        s.value = x_times_2(s.value)
        s.value = one_over_x_minus_2(s.value)
    print(s.value)

if __name__ == "__main__":
    main()

输出:

saved 2.0!
Traceback (most recent call last):
  [...]
    return 1 / (x - 2)
           ~~^~~~~~~~~
ZeroDivisionError: float division by zero

正如您所看到的,中间计算值2.0被“保存”了,即使下一个函数引发了异常。
值得注意的是,在本例中,上下文管理器仅在遇到异常时才调用save,而不是在上下文“和平”退出时。如果你愿意的话,你当然可以无条件地这么做。
这可能不像只是在函数上添加装饰器那么方便,但它可以完成工作。在我看来,你必须有意识地在这种情况下 Package 你的重要行动,这是一件好事,因为它教会你特别注意这些事情。
这是在Python中实现数据库事务之类的东西的典型方法(例如,在SQLAlchemy中)。

PS

为了公平起见,我可能应该对我最初的陈述做一点修改。当然,你可以在函数中使用non-localstate,尽管这通常是不被鼓励的。用超级简单的术语来说,如果在你的例子中result是一个全局变量(你在函数中声明了global result),这实际上可以通过装饰器来解决。但我不推荐这种方法,因为全局状态是一种反模式。(它仍然需要你记住每次使用你为该作业指定的任何全局变量。

tag5nh1u

tag5nh1u2#

我并不是说这是一个好主意,但在函数引发异常后,读取函数的局部变量是可能的:

try:
    my_func_that_raises(...)
except Exception as e:
    traceback = e.__traceback__
    function_frame = traceback.tb_next.tb_frame
    all_local_variables_until_crash = function_frame.f_locals

你当然可以把它 Package 在一个装饰器里

from collections.abc import Callable
from typing import Any, ParamSpec, TypeVar

P = ParamSpec('P')
R = TypeVar('R')

def save_by_inspect(f: Callable[P, R]) -> Callable[P, R]:
    def save(*args: P.args, **kwargs: P.kwargs) -> Any:
        try:
            return f(*args, **kwargs)  # call the function
        except Exception as e:  # oh no, it crashed
            tb = e.__traceback__
            assert tb is not None  # you can implement a different strategy when the traceback is None
            function_tb = tb.tb_next
            assert function_tb is not None
            print(function_tb.tb_frame.f_locals)  # save the local variables
            raise

    return save

这个实现只打印局部变量(Mapping[str, Any])。
用法是这样的:

@save_by_inspect
def foo(crash: bool) -> int:
    expensive = 1 + 2 + 3
    if crash:
        raise RuntimeError("i crashed")

    return expensive + 123

assert foo(crash=False) == 129

foo(crash=True)

assert成功,第二个调用打印局部变量({'crash': True, 'expensive': 6}),然后再次引发RuntimeError。
但还有一个选择。如果你记得在你的函数中加入一个装饰器,你还可以添加一个最小干扰的安全特性。这个想法是yield每个结果,这是重要的,然后return最终结果。所以你的函数看起来像这样:

R = TypeVar('R')  # part of the library
Saved = Generator[Any, None, R]  # part of the library

@save_by_yield  # part of the library
def foo2(crash: bool) -> Saved[int]:  # user code
    x = 1
    yield x  # this is saved on crash
    if crash:
        raise RuntimeError("i crashed")
    return 123  # the final result

然后装饰器可以运行函数(不是像OP建议的那样逐行运行,而是逐yield运行)并收集产生的结果:

from collections.abc import Callable, Generator
from typing import Any, ParamSpec, TypeVar, cast

P = ParamSpec('P')
R = TypeVar('R')
Saved = Generator[Any, None, R]

def save_by_yield(f: Callable[P, Saved[R]]) -> Callable[P, R]:
    def save(*args: P.args, **kwargs: P.kwargs) -> R:
        generator = f(*args, **kwargs)
        save_this = []

        while True:
            try:
                save_this.append(next(generator))  # get an expensive result
            except StopIteration as stop:  # function is done
                return cast(R, stop.value)  # StopIteration is not generic :(
            except Exception:  # function crashed
                print(save_this)
                raise
    return save

每当生成器产生时,我们存储结果。当它引发Exception时,存储的结果被保存(print),当它完成时(StopIteration),结果只是返回。
注意:decorator的类型是正确的。即reveal_type(foo(...))int,当你忘记产生一个结果,mypy会抱怨这个
不兼容的返回值类型(得到“int”,应为“Generator[Any,None,int]”)
不是很漂亮,但也是个好东西。
注2:我省略了@functools.wraps以缩短代码

yftpprvb

yftpprvb3#

这可以用装饰器来完成,但装饰器应该在底层函数上,主要是因为它要简单得多。假设您只想重用计算,那么缓存系统应该是理想的。
我实现该高速缓存的方式非常简单,装饰器@cache.result获取调用签名(函数名和参数的md5),无论函数返回什么,如果函数完成,结果都作为文件写入磁盘; @cache.with_key('key')它是相同的,但全局的,任何函数,得到装饰与它和相同的关键字将返回相同的缓存值;在这两种情况下,装饰器都没有向所使用的函数添加额外的代码或复杂性。

import os
import io
import pickle
import hashlib

class Cache:

    cache_dir = os.path.join(
        os.path.dirname(__file__), 'cache'
    )

    def __init__(self, log=False):
        self.log_enabled = log

        if not os.path.exists(Cache.cache_dir):
            os.mkdir(Cache.cache_dir)

    def __log(self, text):
        if self.log_enabled:
            print(text)

    def _call_to_key(self, func, args, kwargs):
        key = str(func.__name__)
        buff = io.BytesIO()

        for e in [args, kwargs]:
            pickle.dump(e, buff)

            md5_str = hashlib.md5(
                buff.getvalue(), 
                usedforsecurity=False
            ).hexdigest()
            key += f'-{md5_str}'

            buff.seek(0)
            buff.truncate(0)

        buff.close()

        return key

    def __cache_or_invoke(self, key, func, args, kwargs):
        cache_key = key or self._call_to_key(func, args, kwargs)
        cache_file = os.path.join(cache.cache_dir, cache_key)

        # Every key is a file name in the folder
        if cache_key in os.listdir(Cache.cache_dir):
            self.__log(f'Cache hit  {cache_key}')

            with open(cache_file, 'rb') as file:
                return pickle.load(file)

        else:
            self.__log(f'Cache miss {cache_key}')

            ret = func(*args, **kwargs)

            with open(cache_file, 'wb') as file:
                pickle.dump(ret, file)

            return ret

    def with_key(self, key):
        def decorator(func):
            def wrapper(*args, **kwargs):
                return self.__cache_or_invoke(key, func, args, kwargs);
            return wrapper
        return decorator

    def result(self, func):
        def wrapper(*args, **kwargs):
            return self.__cache_or_invoke(None, func, args, kwargs);
        return wrapper

    def clear(self, key=None):
        if os.path.exists(Cache.cache_dir):
            for file in os.listdir(Cache.cache_dir):

                if key is not None:
                    if file != key:
                        continue            
                    if file.split('-', 1)[0] != key:
                        continue

                os.remove(
                    os.path.join(Cache.cache_dir, file)
                )
                self.__log(f'Cache cleared {file}')
cache = Cache(log=False)

class Rational:
    def __init__(self, den, num):
        self.den = den
        self.num = num

    def __str__(self):
        return f'{self.den}/{self.num}'

@cache.result
def half(rat):
    rat.num *= 2
    return rat

@cache.with_key('float')
def to_float(rat):
    return rat.den / rat.num

if __name__ == '__main__':
    f1 = Rational(1, 3)
    f2 = Rational(5, 2)

    # If not cleared the next [to_float]
    # will return the cache
    cache.clear(key='float')

    half1 = half(f1)
    print(half1)
    print(to_float(half1))

    half2 = half(f2)
    print(half2)
    # Wrong value, cache is set 
    # from the previous call
    print(to_float(half2))

    # cache.clear()

顺便说一下,没有必要使用cache.<func>作为装饰器,它也可以用作常规函数。

cache.result(half)(half2)
cache.with_key('float')(to_float)(half2)
cache.with_key('float')(lambda _: None)(half2)

相关问题