与训练相比,加载Pytorch权重导致性能较差

g6ll5ycj  于 2023-05-07  发布在  其他
关注(0)|答案(1)|浏览(182)

我目前正在使用this video(Github代码位于here)进行PPO强化学习。训练进行得很顺利,模型似乎能够学习如何保持杆子,但当我加载模型时,它的表现就像它刚刚初始化为随机权重一样。文件的时间戳对应于它上次保存的时间,所以我肯定会用最新版本的网络覆盖检查点文件。
我的inference.py文件如下:

import gym
import numpy as np
from ppo_torch import Agent
import time
import sys

env = gym.make('CartPole-v1', render_mode = "human")

batch_size = 16
n_epochs = 5
alpha = 0.0003

agent = Agent(n_actions=env.action_space.n, batch_size=batch_size, 
                    alpha=alpha, n_epochs=n_epochs, 
                    input_dims=env.observation_space.shape)

agent.load_models()
agent.critic.eval()
agent.actor.eval()

while True:
    observation, _ = env.reset()
    done = False
    while not done:
        start = time.time()
        action, prob, val = agent.choose_action(observation)
        observation_, reward, done, info, _ = env.step(action)

        env.render()
        end = time.time()
        print(f'{(1/(end-start)):.2f}', end="\r", flush=True)

我所做的唯一更改是在ActorNetwork和CriticNetwork中的self.checkpoint_file中添加“.pth”。所以Actor和Critic self.checkpoint_file路径是:

self.checkpoint_file = os.path.join(chkpt_dir, 'actor_torch_ppo.pth')
#...
self.checkpoint_file = os.path.join(chkpt_dir, 'critic_torch_ppo.pth')
dxxyhpgq

dxxyhpgq1#

找到问题了。实际上我并没有在环境中迈出一步后进行观察:

observation_, reward, done, info, _ = env.step(action)

注意我是如何得到observation_的,但仍然使用observation来得到下一个动作:

action, prob, val = agent.choose_action(observation)

这是环境重置时的初始观察结果。我解决了这个问题,它现在按预期工作。

相关问题