如何在不传递参数的情况下理解 Package 的Python代码

n53p2ov0  于 2023-05-19  发布在  Python
关注(0)|答案(1)|浏览(202)

我正在尝试理解这几行Python代码。

def _wrapped_partitioned_step(
        state, prng_key, inputs, unpadded_global_batch_size=None
    ):
    del unpadded_global_batch_size
    return partitioned_step_fn(state, prng_key, inputs)

return _wrapped_partitioned_step, None

在这里,它试图返回一个 Package 函数。但我很困惑为什么它不需要传递一个值,例如state、prng_key、inputs等。

lokaqttq

lokaqttq1#

这会将函数返回给包含方法(partition())的调用方,并从那里使用参数调用该函数。Partitioner类的文档字符串显示了如何使用它:

# Partition the step function.
  partitioned_step_fn, input_pspec = partitioner.partition(
      step_fn, inputs_shape_dtype, is_eval)

  # Split and preprocess the prng key.
  prng_key, train_key = jax.random.split(root_prng_key)
  train_key = partitioner.preprocess_prng_key(train_key)

  # Get the inputs, preprocess and use it to run the partitioned function.
  inputs = train_input_pipeline.get_next_padded()
  inputs = partitioner.preprocess_inputs(
      train_input_pipeline, inputs, input_pspec)
  partitioned_step_fn(
      train_state, train_key, inputs, unpadded_global_batch_size)

partitioned_step_function是您发布的代码返回的函数,参数通过

partitioned_step_fn(train_state, train_key, inputs, unpadded_global_batch_size)

相关问题