tensorflow 在PPOTFPolicy中调用tf.keras.model.predict()函数时,在Ray RLlib中使用图形模式会触发错误

hlswsv35  于 2022-11-25  发布在  其他
关注(0)|答案(1)|浏览(214)

我正在使用Ray RLlib对PPO代理进行培训,并对PPOTFPolicy进行了两次修改。

  • 我在“build_tf_policy()"中的“mixins”参数中添加了一个mixin类(比如“Recal”)。这样,PPOTFPolicy将成为我的“Recal”类的子类,并且可以访问我在“Recal”中定义的成员函数。我的“Recal”类是tf.keras.Model的一个简单子类。
  • 我定义了一个“my_postprocess_fn”函数来替换“compute_gae_for_sample_batch”函数,该函数被赋予“build_tf_policy()"中的参数“postprocess_fn”。

“PPOTrainer=build_trainer(...)”函数保持不变。我使用framework=“tf”,并禁用了急切模式。
下面是伪代码。Here是colab的一个运行版本。

tf.compat.v1.disable_eager_execution()

class Recal:
    def __init__(self):
        self.recal_model = build_and_compile_keras_model()

def my_postprocess_fn(policy, sample_batch):
    with policy.model.graph.as_default():
        sample_batch = policy.recal_model.predict(sample_batch)
    return compute_gae_for_sample_batch(policy, sample_batch)

PPOTFPolicy = build_tf_policy(..., postprocess_fn=my_postprocess_fn, mixins=[..., Recal])
PPOTrainer = build_trainer(...)
ppo_trainer = PPOTrainer(config=DEFAULT_CONFIG, env="CartPole-v0")

for i in range(1):
    result = ppo_trainer.train()

这样,“Recal”类就是PPOTFPolicy的基类,当创建PPOTFPolicy的示例时,“Recal”在同一tensorflow 图中被示例化。但是当my_postprocess_fn()被调用时,它会引发一个错误(见下文)。

tensorflow.python.framework.errors_impl.FailedPreconditionError: Could not find variable default_policy_wk1/my_model/dense/kernel. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status error message=Container localhost does not exist. (Could not find resource: localhost/default_policy_wk1/my_model/dense/kernel)
     [[{{node default_policy_wk1/my_model_1/dense/MatMul/ReadVariableOp}}]]
bq3bfh9z

bq3bfh9z1#

我已经和雷一起探索了一段时间了。所以我想我可以给予你这个问题的答案。
Ray使用它自己的Model类版本。这个类没有tf.keras.Model.predict方法来获取批处理预测。但是它提供了其他选项。
我还没有弄清楚两个类的输出是否相等。在搜索这个问题的答案的过程中,我只遇到了你的问题。如果你看到这个,我很乐意继续对话。:)

相关问题