我正在使用Ray RLlib对PPO代理进行培训,并对PPOTFPolicy进行了两次修改。
- 我在“build_tf_policy()"中的“mixins”参数中添加了一个mixin类(比如“Recal”)。这样,PPOTFPolicy将成为我的“Recal”类的子类,并且可以访问我在“Recal”中定义的成员函数。我的“Recal”类是tf.keras.Model的一个简单子类。
- 我定义了一个“my_postprocess_fn”函数来替换“compute_gae_for_sample_batch”函数,该函数被赋予“build_tf_policy()"中的参数“postprocess_fn”。
“PPOTrainer=build_trainer(...)”函数保持不变。我使用framework=“tf”,并禁用了急切模式。
下面是伪代码。Here是colab的一个运行版本。
tf.compat.v1.disable_eager_execution()
class Recal:
def __init__(self):
self.recal_model = build_and_compile_keras_model()
def my_postprocess_fn(policy, sample_batch):
with policy.model.graph.as_default():
sample_batch = policy.recal_model.predict(sample_batch)
return compute_gae_for_sample_batch(policy, sample_batch)
PPOTFPolicy = build_tf_policy(..., postprocess_fn=my_postprocess_fn, mixins=[..., Recal])
PPOTrainer = build_trainer(...)
ppo_trainer = PPOTrainer(config=DEFAULT_CONFIG, env="CartPole-v0")
for i in range(1):
result = ppo_trainer.train()
这样,“Recal”类就是PPOTFPolicy的基类,当创建PPOTFPolicy的示例时,“Recal”在同一tensorflow 图中被示例化。但是当my_postprocess_fn()被调用时,它会引发一个错误(见下文)。
tensorflow.python.framework.errors_impl.FailedPreconditionError: Could not find variable default_policy_wk1/my_model/dense/kernel. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status error message=Container localhost does not exist. (Could not find resource: localhost/default_policy_wk1/my_model/dense/kernel)
[[{{node default_policy_wk1/my_model_1/dense/MatMul/ReadVariableOp}}]]
1条答案
按热度按时间bq3bfh9z1#
我已经和雷一起探索了一段时间了。所以我想我可以给予你这个问题的答案。
Ray使用它自己的Model类版本。这个类没有tf.keras.Model.predict方法来获取批处理预测。但是它提供了其他选项。
我还没有弄清楚两个类的输出是否相等。在搜索这个问题的答案的过程中,我只遇到了你的问题。如果你看到这个,我很乐意继续对话。:)