我正在尝试理解这几行Python代码。
def _wrapped_partitioned_step(
state, prng_key, inputs, unpadded_global_batch_size=None
):
del unpadded_global_batch_size
return partitioned_step_fn(state, prng_key, inputs)
return _wrapped_partitioned_step, None
在这里,它试图返回一个 Package 函数。但我很困惑为什么它不需要传递一个值,例如state、prng_key、inputs等。
1条答案
按热度按时间lokaqttq1#
这会将函数返回给包含方法(
partition()
)的调用方,并从那里使用参数调用该函数。Partitioner
类的文档字符串显示了如何使用它:partitioned_step_function
是您发布的代码返回的函数,参数通过