我尝试在cartpole环境下使用stablebaseline 3实现A2 C算法。虽然训练似乎是成功的,达到了所需的奖励,但当我尝试使用模型时,奖励似乎很低。这是我的代码。我做错了什么?
import os
import gymnasium as gym
import numpy as np
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.env_util import make_vec_env
env_id = "CartPole-v1"
env = gym.make(env_id)
s_size = env.observation_space.shape
a_size = env.action_space
print("_____OBSERVATION SPACE_____ \n")
print("The State Space is: ", s_size)
print("Sample observation", env.observation_space.sample()) # Get a random observation
字符串
观察空间
状态空间为:(4,)样本观测值[-2.3014314e+00 4. 4097112 e +37 -4.1089469e-01 2. 7118910 e +38]
envs = make_vec_env(env_id,seed=1, n_envs=4)
envs = VecNormalize(envs, norm_obs=True, norm_reward=True, clip_obs=10.)
model = A2C(policy = "MlpPolicy",env = envs, verbose=1)
model.learn(15_000)
型
在这一步之后,我保存模型并重新加载以进行评估
model.save("a2c-"+env_id)
envs.save("vec_normalize.pkl")
型
当我加载保存的模型进行评估时,它产生的平均奖励为500
# Load the saved statistics
eval_env = DummyVecEnv([lambda: gym.make(env_id)])
eval_env = VecNormalize.load("vec_normalize.pkl", eval_env)
# We need to override the render_mode
eval_env.render_mode = "rgb_array"
# do not update them at test time
eval_env.training = False
# reward normalization is not needed at test time
eval_env.norm_reward = False
# Load the agent
model = A2C.load("a2c-"+env_id)
mean_reward, std_reward = evaluate_policy(model, eval_env)
print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")
型
平均奖励= 500.00 +/- 0.00
然而,当我尝试自己测试这个模型时,回报却很悲惨
for epi in range(10):
score = 0
state = env.reset()[0]
done = False
while not done:
a , _ = model.predict(state)
state_, r, done , _, _ = env.step(a)
score += r
state = np.copy(state_)
env.render()
print(f"Episode {epi} score {score}")
env.close()
型
第0集得分71.0
第1集得分70.0
第2集评分83.0
第3集评分62.0
第4集评分63.0
第5集评分59.0
第6集评分52.0
第7集评分54.0
第8集评分60.0
第9集评分69.0
1条答案
按热度按时间c3frrgcw1#
我终于能够解决这个问题了。看起来我需要在eval_env本身中模拟测试
字符串
494.0
297.0
359.0
402.0
406.0
500.0
500.0
500.0
371.0
371.0