aim [BUG] Stable-Baselines3集成无法工作

wyyhbhjk  于 25天前  发布在  其他
关注(0)|答案(4)|浏览(20)

🐛 Bug

你好!我一直在尝试使用AIM来跟踪Stable-Baselines3项目的指标。我尝试使用AimCallback(https://aimstack.readthedocs.io/en/latest/quick_start/integrations.html)与Stable-Baselines3一起监控日志和指标,以便在模型学习过程中进行监控。然而,我遇到了一个问题,那就是AIM运行完全没有跟踪任何指标。具体来说,我感兴趣的是跟踪诸如explained_variance和loss等指标。尽管如此,当我检查指标选项卡时,没有任何指标显示出来。

重现问题

import os
import gymnasium
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy
from aim.sb3 import AimCallback

environment_name = 'CartPole-v1'
env = gymnasium.make(environment_name)

env = DummyVecEnv([lambda: env])
model = PPO('MlpPolicy', env, verbose = 1)

model.learn(total_timesteps=10_000, callback=AimCallback(repo='.', experiment_name='example_experiment'))

我做错了什么吗?我遵循了AIM文档中指示的步骤,但没有跟踪到任何指标(https://aimstack.readthedocs.io/en/latest/quick_start/integrations.html)
非常感谢!

bt1cpqcv

bt1cpqcv1#

嘿,@eltonjohnfanboy!抱歉回复晚了,感谢你打开这个问题。我们的回调中有一个问题,我将确保在即将发布的版本中包含修复。在此期间,您可以使用以下脚本作为解决方法来跟踪指标:

from typing import Any, Dict, Tuple, Union
import os
import gymnasium
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.logger import HumanOutputFormat, KVWriter, Logger

from aim import Run
import numpy as np

class AimOutputFormat(KVWriter):
    """
    Track key/value pairs into Aim run.
    """

    def __init__(
        self,
        aim_run
    ):
        self.aim_run = aim_run

    def write(
        self,
        key_values: Dict[str, Any],
        key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
        step: int = 0,
    ) -> None:
        for (key, value), (_, excluded) in zip(
            sorted(key_values.items()), sorted(key_excluded.items())
        ):
            if excluded is not None and 'aim' in excluded:
                continue

            if isinstance(value, np.ScalarType):
                if not isinstance(value, str):
                    tag, key = key.split('/')
                    if tag in ['train', 'valid']:
                        context = {'subset': tag}
                    else:
                        context = {'tag': tag}

                    self.aim_run.track(value, key, step=step, context=context)

run = Run()
loggers = Logger(
    folder=None,
    output_formats=[AimOutputFormat(run)],
)

environment_name = 'CartPole-v1'
env = gymnasium.make(environment_name)

env = DummyVecEnv([lambda: env])
model = PPO('MlpPolicy', env, verbose = 1)
model.set_logger(loggers)
model.learn(total_timesteps=10_000)
ddrv8njm

ddrv8njm2#

你好,@mihran113,不用担心。非常感谢你的工作和更新!期待即将发布的版本,并在这段时间内使用脚本。:)

k4aesqcs

k4aesqcs3#

嘿,@eltonjohnfanboy!aim的新版本已经发货(3.23.0),其中包含了这个问题的修复。请让我知道一切是否如预期般工作,这样我就可以关闭这个问题了。

8xiog9wr

8xiog9wr4#

嗨,@mihran113,抱歉回复晚了。感谢更新;现在它可以正常工作了!:)

相关问题