diff --git a/nnabla_rl/environments/__init__.py b/nnabla_rl/environments/__init__.py index 3088e9aa..7d496356 100644 --- a/nnabla_rl/environments/__init__.py +++ b/nnabla_rl/environments/__init__.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ # limitations under the License. from gym.envs.registration import register +from gymnasium.envs.registration import register as gymnasium_register from nnabla_rl.environments.dummy import (DummyAtariEnv, DummyContinuous, DummyContinuousActionGoalEnv, DummyDiscrete, # noqa DummyDiscreteActionGoalEnv, DummyDiscreteImg, DummyContinuousImg, @@ -21,7 +22,8 @@ DummyTupleContinuous, DummyTupleDiscrete, DummyTupleMixed, DummyTupleStateContinuous, DummyTupleStateDiscrete, DummyTupleActionContinuous, DummyTupleActionDiscrete, - DummyHybridEnv) + DummyHybridEnv, + DummyGymnasiumAtariEnv, DummyGymnasiumMujocoEnv) register( id='FakeMujocoNNablaRL-v1', @@ -87,3 +89,16 @@ entry_point='nnabla_rl.environments.dummy:DummyHybridEnv', max_episode_steps=10 ) + + +gymnasium_register( + id='FakeGymnasiumMujocoNNablaRL-v1', + entry_point='nnabla_rl.environments.dummy:DummyGymnasiumMujocoEnv', + max_episode_steps=10 +) + +gymnasium_register( + id='FakeGymnasiumAtariNNablaRLNoFrameskip-v1', + entry_point='nnabla_rl.environments.dummy:DummyGymnasiumAtariEnv', + max_episode_steps=10 +) diff --git a/nnabla_rl/environments/dummy.py b/nnabla_rl/environments/dummy.py index 738a0c2f..e601fa55 100644 --- a/nnabla_rl/environments/dummy.py +++ b/nnabla_rl/environments/dummy.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -16,8 +16,10 @@
 from typing import TYPE_CHECKING, cast
 
 import gym
+import gymnasium
 import numpy as np
 from gym.envs.registration import EnvSpec
+from gymnasium.envs.registration import EnvSpec as GymnasiumEnvSpec
 
 if TYPE_CHECKING:
     from gym.utils.seeding import RandomNumberGenerator
@@ -309,3 +311,90 @@ def __init__(self, max_episode_steps=None):
         super(DummyHybridEnv, self).__init__(max_episode_steps=max_episode_steps)
         self.action_space = gym.spaces.Tuple((gym.spaces.Discrete(5), gym.spaces.Box(low=0.0, high=1.0, shape=(5, ))))
         self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(5, ))
+
+
+# =========== gymnasium ==========
+class AbstractDummyGymnasiumEnv(gymnasium.Env):
+    def __init__(self, max_episode_steps):
+        self.spec = GymnasiumEnvSpec('dummy-v0', max_episode_steps=max_episode_steps)
+        self._episode_steps = 0
+
+    def reset(self):
+        self._episode_steps = 0
+        return self.observation_space.sample(), {}
+
+    def step(self, a):
+        self._episode_steps += 1
+        next_state = self.observation_space.sample()
+        reward = np.random.randn()
+        terminated = False
+        if self.spec.max_episode_steps is None:
+            truncated = False
+        else:
+            # truncate once the episode length reaches max_episode_steps
+            truncated = bool(self.spec.max_episode_steps <= self._episode_steps)
+        info = {'rnn_states': {'dummy_scope': {'dummy_state1': 1, 'dummy_state2': 2}}}
+        return next_state, reward, terminated, truncated, info
+
+
+class DummyGymnasiumAtariEnv(AbstractDummyGymnasiumEnv):
+    class DummyALE(object):
+        def __init__(self):
+            self._lives = 100
+
+        def lives(self):
+            self._lives -= 1
+            if self._lives < 0:
+                self._lives = 100
+            return self._lives
+
+    # seeding.np_random outputs np_random and seed
+    np_random = cast("RandomNumberGenerator", nnabla_rl.random.drng)
+
+    def __init__(self, done_at_random=True, max_episode_length=None):
+        super(DummyGymnasiumAtariEnv, self).__init__(
+            max_episode_steps=max_episode_length)
+        self.action_space = gymnasium.spaces.Discrete(4)
+        self.observation_space = gymnasium.spaces.Box(
+            low=0, high=255, shape=(84, 84, 3), dtype=np.uint8)
+        self.ale = DummyGymnasiumAtariEnv.DummyALE()
+        self._done_at_random = done_at_random
+        self._max_episode_length = max_episode_length
+        self._episode_length = None
+
+    def step(self, action):
+        assert self._episode_length is not None
+        observation = self.observation_space.sample()
+        self._episode_length += 1
+        if self._done_at_random:
+            done = bool(self.np_random.integers(10) == 0)
+        else:
+            done = False
+        if self._max_episode_length is not None:
+            done = (self._max_episode_length <= self._episode_length) or done
+        # gymnasium's step returns (obs, reward, terminated, truncated, info)
+        return observation, 1.0, done, False, {'needs_reset': False}
+
+    def reset(self):
+        self._episode_length = 0
+        return self.observation_space.sample(), {}
+
+    def get_action_meanings(self):
+        return ['NOOP', 'FIRE', 'LEFT', 'RIGHT']
+
+
+class DummyGymnasiumMujocoEnv(AbstractDummyGymnasiumEnv):
+    def __init__(self, max_episode_steps=None):
+        super(DummyGymnasiumMujocoEnv, self).__init__(max_episode_steps=max_episode_steps)
+        self.action_space = gymnasium.spaces.Box(low=0.0, high=1.0, shape=(5, ))
+        self.observation_space = gymnasium.spaces.Box(low=0.0, high=1.0, shape=(5, ))
+
+    def get_dataset(self):
+        dataset = {}
+        datasize = 2000
+        dataset['observations'] = np.stack([self.observation_space.sample() for _ in range(datasize)], axis=0)
+        dataset['actions'] = np.stack([self.action_space.sample() for _ in range(datasize)], axis=0)
+        dataset['rewards'] = np.random.randn(datasize, 1)
+        dataset['terminals'] = np.random.randint(2, size=(datasize, 1))
+        dataset['timeouts'] = np.zeros((datasize, 1))
+        return dataset
diff --git a/nnabla_rl/environments/wrappers/__init__.py b/nnabla_rl/environments/wrappers/__init__.py
index 238a2625..927dab3b 100644
--- a/nnabla_rl/environments/wrappers/__init__.py
+++ b/nnabla_rl/environments/wrappers/__init__.py
@@ -1,5 +1,5 @@
 # Copyright 2020,2021 Sony Corporation.
-# Copyright 2021,2022,2023 Sony Group Corporation.
+# Copyright 2021,2022,2023,2024 Sony Group Corporation.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,3 +20,4 @@
 from nnabla_rl.environments.wrappers.atari import make_atari, wrap_deepmind  # noqa
 from nnabla_rl.environments.wrappers.hybrid_env import (EmbedActionWrapper, FlattenActionWrapper,  # noqa
                                                         RemoveStepWrapper, ScaleActionWrapper, ScaleStateWrapper)
+from nnabla_rl.environments.wrappers.gymnasium import Gymnasium2GymWrapper  # noqa
diff --git a/nnabla_rl/environments/wrappers/atari.py b/nnabla_rl/environments/wrappers/atari.py
index 0e25c7aa..f47ac09f 100644
--- a/nnabla_rl/environments/wrappers/atari.py
+++ b/nnabla_rl/environments/wrappers/atari.py
@@ -1,5 +1,5 @@
 # Copyright 2020,2021 Sony Corporation.
-# Copyright 2021,2022,2023 Sony Group Corporation.
+# Copyright 2021,2022,2023,2024 Sony Group Corporation.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,10 +16,12 @@
 import cv2
 import gym
+import gymnasium
 import numpy as np
 from gym import spaces
 
 import nnabla_rl as rl
+from nnabla_rl.environments.wrappers.gymnasium import Gymnasium2GymWrapper
 from nnabla_rl.external.atari_wrappers import (ClipRewardEnv, EpisodicLifeEnv, FireResetEnv, MaxAndSkipEnv,
                                                NoopResetEnv, ScaledFloatFrame)
@@ -97,8 +99,14 @@ def __array__(self, dtype=None):
         return out
 
 
-def make_atari(env_id, max_frames_per_episode=None):
-    env = gym.make(env_id)
+def make_atari(env_id, max_frames_per_episode=None, use_gymnasium=False):
+    if use_gymnasium:
+        env = gymnasium.make(env_id)
+        env = Gymnasium2GymWrapper(env)
+        # the gymnasium-made env is not wrapped with gym's TimeLimit wrapper, so wrap it here
+        env = gym.wrappers.TimeLimit(env, max_episode_steps=env.spec.kwargs["max_num_frames_per_episode"])
+    else:
+        env = gym.make(env_id)
     if max_frames_per_episode is not None:
         env = env.unwrapped
         env = gym.wrappers.TimeLimit(env, max_episode_steps=max_frames_per_episode)
diff --git a/nnabla_rl/environments/wrappers/gymnasium.py b/nnabla_rl/environments/wrappers/gymnasium.py
new file mode 100644
index 00000000..ce4c4d52
--- /dev/null
+++ b/nnabla_rl/environments/wrappers/gymnasium.py
@@ -0,0 +1,90 @@
+# Copyright 2024 Sony Group Corporation.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import gym
+from gym import spaces as gym_spaces
+from gymnasium import spaces as gymnasium_spaces
+from gymnasium.utils import seeding
+
+
+class Gymnasium2GymWrapper(gym.Wrapper):
+    def __init__(self, env):
+        if isinstance(env, gym.Env) or isinstance(env, gym.Wrapper):
+            raise ValueError("'env' should not be an instance of 'gym.Env' or 'gym.Wrapper'")
+
+        super().__init__(env)
+
+        # observation space
+        if isinstance(env.observation_space, gymnasium_spaces.Tuple):
+            self.observation_space = gym_spaces.Tuple(
+                [self._translate_space(observation_space)
+                 for observation_space in env.observation_space]
+            )
+        elif isinstance(env.observation_space, gymnasium_spaces.Dict):
+            self.observation_space = gym_spaces.Dict(
+                {key: self._translate_space(observation_space)
+                 for key, observation_space in env.observation_space.items()}
+            )
+        else:
+            self.observation_space = self._translate_space(env.observation_space)
+
+        # action space
+        if isinstance(env.action_space, gymnasium_spaces.Tuple):
+            self.action_space = gym_spaces.Tuple(
+                [self._translate_space(action_space)
+                 for action_space in env.action_space]
+            )
+        elif isinstance(env.action_space, gymnasium_spaces.Dict):
+            self.action_space = gym_spaces.Dict(
+                {key: self._translate_space(action_space)
+                 for key, action_space in env.action_space.items()}
+            )
+        else:
+            self.action_space = self._translate_space(env.action_space)
+
+    def reset(self):
+        obs, _ = self.env.reset()
+        return obs
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        done = (terminated or truncated)
+        # keep gym's TimeLimit convention so downstream wrappers can detect truncation
+        info.update({"TimeLimit.truncated": truncated})
+        return obs, reward, done, info
+
+    def seed(self, seed: Optional[int] = None):
+        np_random, seed = seeding.np_random(seed)
+        self.env.np_random = np_random  # type: ignore
+        return [seed]
+
+    @property
+    def unwrapped(self):
+        # do not expose the inner gymnasium env; this wrapper itself acts as the base gym env
+        return self
+
+    def _translate_space(self, space):
+        if isinstance(space, gymnasium_spaces.Box):
+            return gym_spaces.Box(
+                low=space.low,
+                high=space.high,
+                shape=space.shape,
+                dtype=space.dtype
+            )
+        elif isinstance(space, gymnasium_spaces.Discrete):
+            return gym_spaces.Discrete(n=int(space.n))
+        else:
+            raise NotImplementedError
diff --git a/nnabla_rl/utils/reproductions.py b/nnabla_rl/utils/reproductions.py
index babc2a88..15fb804a 100644
--- a/nnabla_rl/utils/reproductions.py
+++ b/nnabla_rl/utils/reproductions.py
@@ -1,5 +1,5 @@
 # Copyright 2020,2021 Sony Corporation.
-# Copyright 2021,2022 Sony Group Corporation.
+# Copyright 2021,2022,2023,2024 Sony Group Corporation.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,12 +16,14 @@ import random as py_random import gym +import gymnasium import numpy as np import nnabla as nn import nnabla_rl as rl from nnabla_rl.environments.environment_info import EnvironmentInfo -from nnabla_rl.environments.wrappers import NumpyFloat32Env, ScreenRenderEnv, make_atari, wrap_deepmind +from nnabla_rl.environments.wrappers import (Gymnasium2GymWrapper, NumpyFloat32Env, ScreenRenderEnv, make_atari, + wrap_deepmind) from nnabla_rl.logger import logger @@ -53,11 +55,14 @@ def build_atari_env(id_or_env, print_info=True, max_frames_per_episode=None, frame_stack=True, - flicker_probability=0.0): + flicker_probability=0.0, + use_gymnasium=False): if isinstance(id_or_env, gym.Env): env = id_or_env + elif isinstance(id_or_env, gymnasium.Env): + env = Gymnasium2GymWrapper(id_or_env) else: - env = make_atari(id_or_env, max_frames_per_episode=max_frames_per_episode) + env = make_atari(id_or_env, max_frames_per_episode=max_frames_per_episode, use_gymnasium=use_gymnasium) if print_info: print_env_info(env) @@ -75,7 +80,7 @@ def build_atari_env(id_or_env, return env -def build_mujoco_env(id_or_env, test=False, seed=None, render=False, print_info=True): +def build_mujoco_env(id_or_env, test=False, seed=None, render=False, print_info=True, use_gymnasium=False): try: # Add pybullet env import pybullet_envs # noqa @@ -91,8 +96,14 @@ def build_mujoco_env(id_or_env, test=False, seed=None, render=False, print_info= if isinstance(id_or_env, gym.Env): env = id_or_env + elif isinstance(id_or_env, gymnasium.Env): + env = Gymnasium2GymWrapper(id_or_env) else: - env = gym.make(id_or_env) + if use_gymnasium: + env = gymnasium.make(id_or_env) + env = Gymnasium2GymWrapper(env) + else: + env = gym.make(id_or_env) if print_info: print_env_info(env) diff --git a/reproductions/algorithms/atari/a2c/a2c_reproduction.py b/reproductions/algorithms/atari/a2c/a2c_reproduction.py index 0e472be4..b0aea81d 100644 --- a/reproductions/algorithms/atari/a2c/a2c_reproduction.py +++ b/reproductions/algorithms/atari/a2c/a2c_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -28,9 +28,9 @@ def run_training(args): set_global_seed(args.seed) - train_env = build_atari_env(args.env, seed=args.seed) + train_env = build_atari_env(args.env, seed=args.seed, use_gymnasium=args.use_gymnasium) eval_env = build_atari_env( - args.env, test=True, seed=args.seed + 100, render=args.render) + args.env, test=True, seed=args.seed + 100, render=args.render, use_gymnasium=args.use_gymnasium) iteration_num_hook = H.IterationNumHook(timing=100) @@ -56,7 +56,8 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: raise ValueError('Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=False) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=False, + use_gymnasium=args.use_gymnasium) config = A.A2CConfig(gpu_id=args.gpu) a2c = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(a2c, A.A2C): @@ -83,6 +84,7 @@ def main(): parser.add_argument('--save_timing', type=int, default=250000) parser.add_argument('--eval_timing', type=int, default=250000) parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument('--use-gymnasium', action='store_true') args = parser.parse_args() diff --git a/reproductions/algorithms/atari/c51/c51_reproduction.py b/reproductions/algorithms/atari/c51/c51_reproduction.py index f9e4cf4d..4b95c82e 100644 --- a/reproductions/algorithms/atari/c51/c51_reproduction.py +++ b/reproductions/algorithms/atari/c51/c51_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -36,7 +36,8 @@ def run_training(args): outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render, + use_gymnasium=args.use_gymnasium) evaluator = TimestepEvaluator(num_timesteps=125000) evaluation_hook = H.EvaluationHook( eval_env, evaluator, timing=args.eval_timing, writer=W.FileWriter(outdir=outdir, @@ -44,7 +45,7 @@ def run_training(args): save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=100) - train_env = build_atari_env(args.env, seed=args.seed, render=args.render) + train_env = build_atari_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium) config = A.CategoricalDQNConfig(gpu_id=args.gpu) categorical_dqn = A.CategoricalDQN(train_env, @@ -63,7 +64,8 @@ def run_showcase(args): if args.snapshot_dir is None: raise ValueError( 'Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, + use_gymnasium=args.use_gymnasium) config = A.CategoricalDQNConfig(gpu_id=args.gpu) categorical_dqn = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(categorical_dqn, A.CategoricalDQN): @@ -86,6 +88,7 @@ def main(): parser.add_argument('--save_timing', type=int, default=250000) parser.add_argument('--eval_timing', type=int, default=250000) parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument('--use-gymnasium', action='store_true') args = parser.parse_args() diff --git a/reproductions/algorithms/atari/ddqn/ddqn_reproduction.py b/reproductions/algorithms/atari/ddqn/ddqn_reproduction.py index ffa18635..7d30496f 100644 --- a/reproductions/algorithms/atari/ddqn/ddqn_reproduction.py +++ b/reproductions/algorithms/atari/ddqn/ddqn_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -40,14 +40,15 @@ def run_training(args): set_global_seed(args.seed) writer = FileWriter(outdir, "evaluation_result") - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render, + use_gymnasium=args.use_gymnasium) evaluator = TimestepEvaluator(num_timesteps=125000) evaluation_hook = H.EvaluationHook(eval_env, evaluator, timing=args.eval_timing, writer=writer) iteration_num_hook = H.IterationNumHook(timing=100) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) - train_env = build_atari_env(args.env, seed=args.seed) + train_env = build_atari_env(args.env, seed=args.seed, use_gymnasium=args.use_gymnasium) config = A.DDQNConfig(gpu_id=args.gpu) ddqn = A.DDQN(train_env, @@ -65,7 +66,8 @@ def run_showcase(args): if args.snapshot_dir is None: raise ValueError( 'Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, + use_gymnasium=args.use_gymnasium) config = A.DDQNConfig(gpu_id=args.gpu) ddqn = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(ddqn, A.DDQN): @@ -93,6 +95,7 @@ def main(): parser.add_argument('--save_timing', type=int, default=250000) parser.add_argument('--eval_timing', type=int, default=250000) parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument('--use-gymnasium', action='store_true') args = parser.parse_args() diff --git a/reproductions/algorithms/atari/decision_transformer/decision_transformer_reproduction.py b/reproductions/algorithms/atari/decision_transformer/decision_transformer_reproduction.py index 9bd0ebfd..8fd80c20 100755 --- a/reproductions/algorithms/atari/decision_transformer/decision_transformer_reproduction.py +++ b/reproductions/algorithms/atari/decision_transformer/decision_transformer_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -190,7 +190,8 @@ def run_training(args): set_global_seed(args.seed) writer = FileWriter(outdir, "evaluation_result") - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render, + use_gymnasium=args.use_gymnasium) evaluator = EpisodicEvaluator(run_per_evaluation=10) evaluation_hook = H.EvaluationHook(eval_env, evaluator, timing=args.eval_timing, writer=writer) @@ -230,7 +231,8 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: raise ValueError('Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, + use_gymnasium=args.use_gymnasium) config = {'gpu_id': args.gpu} decision_transformer = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(decision_transformer, A.DecisionTransformer): @@ -264,6 +266,7 @@ def main(): parser.add_argument('--eval_timing', type=int, default=1) parser.add_argument('--showcase_runs', type=int, default=10) parser.add_argument('--target-return', type=int, default=None) + parser.add_argument('--use-gymnasium', action='store_true') args = parser.parse_args() diff --git a/reproductions/algorithms/atari/dqn/dqn_reproduction.py b/reproductions/algorithms/atari/dqn/dqn_reproduction.py index 8eb89c56..973ecfd6 100644 --- a/reproductions/algorithms/atari/dqn/dqn_reproduction.py +++ b/reproductions/algorithms/atari/dqn/dqn_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -40,14 +40,15 @@ def run_training(args): set_global_seed(args.seed) writer = FileWriter(outdir, "evaluation_result") - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render, + use_gymnasium=args.use_gymnasium) evaluator = TimestepEvaluator(num_timesteps=125000) evaluation_hook = H.EvaluationHook(eval_env, evaluator, timing=args.eval_timing, writer=writer) iteration_num_hook = H.IterationNumHook(timing=100) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) - train_env = build_atari_env(args.env, seed=args.seed) + train_env = build_atari_env(args.env, seed=args.seed, use_gymnasium=args.use_gymnasium) config = A.DQNConfig(gpu_id=args.gpu) dqn = A.DQN(train_env, @@ -65,7 +66,8 @@ def run_showcase(args): if args.snapshot_dir is None: raise ValueError( 'Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, + use_gymnasium=args.use_gymnasium) config = A.DQNConfig(gpu_id=args.gpu) dqn = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(dqn, A.DQN): @@ -93,6 +95,7 @@ def main(): parser.add_argument('--save_timing', type=int, default=250000) parser.add_argument('--eval_timing', type=int, default=250000) parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument('--use-gymnasium', action='store_true') args = parser.parse_args() diff --git a/reproductions/algorithms/atari/drqn/drqn_reproduction.py b/reproductions/algorithms/atari/drqn/drqn_reproduction.py index 8b9eb2e2..4dd9a3c9 100644 --- a/reproductions/algorithms/atari/drqn/drqn_reproduction.py +++ b/reproductions/algorithms/atari/drqn/drqn_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -47,7 +47,8 @@ def run_training(args): seed=args.seed + 100, render=args.render, frame_stack=False, - flicker_probability=flicker_probability) + flicker_probability=flicker_probability, + use_gymnasium=args.use_gymnasium) evaluator = EpisodicEvaluator(run_per_evaluation=10) evaluation_hook = H.EvaluationHook(eval_env, evaluator, timing=args.eval_timing, writer=writer) @@ -58,7 +59,8 @@ def run_training(args): seed=args.seed, render=args.render, frame_stack=False, - flicker_probability=flicker_probability) + flicker_probability=flicker_probability, + use_gymnasium=args.use_gymnasium) config = A.DRQNConfig(gpu_id=args.gpu) drqn = A.DRQN(train_env, config=config, replay_buffer_builder=MemoryEfficientBufferBuilder()) @@ -79,7 +81,8 @@ def run_showcase(args): seed=args.seed + 200, render=args.render, frame_stack=False, - flicker_probability=flicker_probability) + flicker_probability=flicker_probability, + use_gymnasium=args.use_gymnasium) config = A.DRQNConfig(gpu_id=args.gpu) drqn = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(drqn, A.DRQN): @@ -107,6 +110,7 @@ def main(): parser.add_argument('--eval_timing', type=int, default=50000) parser.add_argument('--showcase_runs', type=int, default=10) parser.add_argument('--flicker', action='store_true') + parser.add_argument('--use-gymnasium', action='store_true') args = parser.parse_args() diff --git a/reproductions/algorithms/atari/icml2015trpo/icml2015trpo_reproduction.py b/reproductions/algorithms/atari/icml2015trpo/icml2015trpo_reproduction.py index 3c9ea6e3..d3f55da9 100644 --- a/reproductions/algorithms/atari/icml2015trpo/icml2015trpo_reproduction.py +++ b/reproductions/algorithms/atari/icml2015trpo/icml2015trpo_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -30,14 +30,16 @@ def run_training(args): set_global_seed(args.seed) writer = FileWriter(outdir, "evaluation_result") - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render, + use_gymnasium=args.use_gymnasium) evaluator = EpisodicEvaluator() evaluation_hook = H.EvaluationHook(eval_env, evaluator, timing=args.eval_timing, writer=writer) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=int(1e5)) - train_env = build_atari_env(args.env, seed=args.seed, render=args.render) + train_env = build_atari_env(args.env, seed=args.seed, render=args.render, + use_gymnasium=args.use_gymnasium) config = A.ICML2015TRPOConfig(gpu_id=args.gpu, gpu_batch_size=args.gpu_batch_size) trpo = A.ICML2015TRPO(train_env, config=config) @@ -53,7 +55,8 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: raise ValueError('Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, + use_gymnasium=args.use_gymnasium) config = A.ICML2015TRPOConfig(gpu_id=args.gpu) trpo = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(trpo, A.ICML2015TRPO): @@ -77,6 +80,7 @@ def main(): parser.add_argument('--save_timing', type=int, default=100000) parser.add_argument('--eval_timing', type=int, default=100000) parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument('--use-gymnasium', action='store_true') args = parser.parse_args() diff --git a/reproductions/algorithms/atari/iqn/iqn_reproduction.py b/reproductions/algorithms/atari/iqn/iqn_reproduction.py index 4f686402..0afbbbf3 100644 --- a/reproductions/algorithms/atari/iqn/iqn_reproduction.py +++ b/reproductions/algorithms/atari/iqn/iqn_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -36,7 +36,8 @@ def run_training(args): outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render, + use_gymnasium=args.use_gymnasium) evaluator = TimestepEvaluator(num_timesteps=125000) evaluation_hook = H.EvaluationHook(eval_env, evaluator, timing=args.eval_timing, writer=W.FileWriter(outdir=outdir, file_prefix='evaluation_result')) @@ -44,7 +45,7 @@ def run_training(args): iteration_num_hook = H.IterationNumHook(timing=100) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) - train_env = build_atari_env(args.env, seed=args.seed, render=args.render) + train_env = build_atari_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium) config = A.IQNConfig(gpu_id=args.gpu) iqn = A.IQN(train_env, config=config, replay_buffer_builder=MemoryEfficientAtariBufferBuilder()) @@ -61,7 +62,8 @@ def run_showcase(args): if args.snapshot_dir is None: raise ValueError( 'Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, + use_gymnasium=args.use_gymnasium) config = A.IQNConfig(gpu_id=args.gpu) iqn = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(iqn, A.IQN): @@ -84,6 +86,7 @@ def main(): parser.add_argument('--save_timing', type=int, default=250000) parser.add_argument('--eval_timing', type=int, default=250000) parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument('--use-gymnasium', action='store_true') args = parser.parse_args() diff --git a/reproductions/algorithms/atari/munchausen_dqn/munchausen_dqn_reproduction.py b/reproductions/algorithms/atari/munchausen_dqn/munchausen_dqn_reproduction.py index a8288be7..8791daee 100644 --- a/reproductions/algorithms/atari/munchausen_dqn/munchausen_dqn_reproduction.py +++ b/reproductions/algorithms/atari/munchausen_dqn/munchausen_dqn_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -39,14 +39,14 @@ def run_training(args): outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium) writer = FileWriter(outdir, "evaluation_result") evaluator = TimestepEvaluator(num_timesteps=125000) evaluation_hook = H.EvaluationHook(eval_env, evaluator, timing=args.eval_timing, writer=writer) iteration_num_hook = H.IterationNumHook(timing=100) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) - train_env = build_atari_env(args.env, seed=args.seed, render=args.render) + train_env = build_atari_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium) config = A.MunchausenDQNConfig(gpu_id=args.gpu) m_dqn = A.MunchausenDQN(train_env, config=config, replay_buffer_builder=MemoryEfficientBufferBuilder()) @@ -60,7 +60,8 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: raise ValueError('Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, + use_gymnasium=args.use_gymnasium) config = A.MunchausenDQNConfig(gpu_id=args.gpu) m_dqn = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(m_dqn, A.MunchausenDQN): @@ -89,6 +90,7 @@ def main(): parser.add_argument('--save_timing', type=int, default=250000) parser.add_argument('--eval_timing', type=int, default=250000) parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument('--use-gymnasium', action='store_true') args = parser.parse_args() diff --git a/reproductions/algorithms/atari/munchausen_iqn/munchausen_iqn_reproduction.py b/reproductions/algorithms/atari/munchausen_iqn/munchausen_iqn_reproduction.py index eb45e839..3e998f9f 100644 --- a/reproductions/algorithms/atari/munchausen_iqn/munchausen_iqn_reproduction.py +++ b/reproductions/algorithms/atari/munchausen_iqn/munchausen_iqn_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -39,14 +39,14 @@ def run_training(args): outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium) writer = FileWriter(outdir, "evaluation_result") evaluator = TimestepEvaluator(num_timesteps=125000) evaluation_hook = H.EvaluationHook(eval_env, evaluator, timing=args.eval_timing, writer=writer) iteration_num_hook = H.IterationNumHook(timing=100) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) - train_env = build_atari_env(args.env, seed=args.seed, render=args.render) + train_env = build_atari_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium) config = A.MunchausenIQNConfig(gpu_id=args.gpu) m_iqn = A.MunchausenIQN(train_env, config=config, replay_buffer_builder=MemoryEfficientBufferBuilder()) @@ -60,7 +60,8 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: raise ValueError('Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, + use_gymnasium=args.use_gymnasium) config = A.MunchausenIQNConfig(gpu_id=args.gpu) m_iqn = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(m_iqn, A.MunchausenIQN): @@ -88,6 +89,7 @@ def main(): parser.add_argument('--save_timing', type=int, default=250000) parser.add_argument('--eval_timing', type=int, default=250000) parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument('--use-gymnasium', action='store_true') args = parser.parse_args() diff --git a/reproductions/algorithms/atari/ppo/ppo_reproduction.py b/reproductions/algorithms/atari/ppo/ppo_reproduction.py index 2dd355d3..9a043834 100644 --- a/reproductions/algorithms/atari/ppo/ppo_reproduction.py +++ b/reproductions/algorithms/atari/ppo/ppo_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -29,7 +29,8 @@ def run_training(args): outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render, + use_gymnasium=args.use_gymnasium) evaluator = TimestepEvaluator(num_timesteps=125000) evaluation_hook = H.EvaluationHook( eval_env, evaluator, timing=args.evaluate_timing, writer=W.FileWriter(outdir=outdir, @@ -39,7 +40,7 @@ def run_training(args): actor_num = 8 - train_env = build_atari_env(args.env, seed=args.seed, render=args.render) + train_env = build_atari_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium) config = A.PPOConfig(gpu_id=args.gpu, actor_num=actor_num, total_timesteps=args.total_iterations, @@ -61,7 +62,8 @@ def run_showcase(args): if args.snapshot_dir is None: raise ValueError( 'Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, + use_gymnasium=args.use_gymnasium) config = A.PPOConfig(gpu_id=args.gpu, timelimit_as_terminal=True, seed=args.seed, @@ -87,6 +89,7 @@ def main(): parser.add_argument('--save_timing', type=int, default=250000) parser.add_argument('--evaluate_timing', type=int, default=250000) parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument('--use-gymnasium', action='store_true') args = parser.parse_args() diff --git a/reproductions/algorithms/atari/qrdqn/qrdqn_reproduction.py b/reproductions/algorithms/atari/qrdqn/qrdqn_reproduction.py index 6349798a..0242f4fa 100644 --- a/reproductions/algorithms/atari/qrdqn/qrdqn_reproduction.py +++ b/reproductions/algorithms/atari/qrdqn/qrdqn_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -36,7 +36,8 @@ def run_training(args): outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render, + use_gymnasium=args.use_gymnasium) evaluator = TimestepEvaluator(num_timesteps=125000) evaluation_hook = H.EvaluationHook(eval_env, evaluator, timing=args.eval_timing, writer=W.FileWriter(outdir=outdir, file_prefix='evaluation_result')) @@ -44,7 +45,7 @@ def run_training(args): iteration_num_hook = H.IterationNumHook(timing=100) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) - train_env = build_atari_env(args.env, seed=args.seed, render=args.render) + train_env = build_atari_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium) config = A.QRDQNConfig(gpu_id=args.gpu) qrdqn = A.QRDQN(train_env, config=config, replay_buffer_builder=MemoryEfficientBufferBuilder()) @@ -61,7 +62,8 @@ def run_showcase(args): if args.snapshot_dir is None: raise ValueError( 'Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render) + eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, + use_gymnasium=args.use_gymnasium) config = A.QRDQNConfig(gpu_id=args.gpu) qrdqn = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(qrdqn, A.QRDQN): @@ -84,6 +86,7 @@ def main(): parser.add_argument('--save_timing', type=int, default=250000) parser.add_argument('--eval_timing', type=int, default=250000) parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument('--use-gymnasium', action='store_true') args = parser.parse_args() diff --git a/reproductions/algorithms/atari/rainbow/rainbow_reproduction.py b/reproductions/algorithms/atari/rainbow/rainbow_reproduction.py index 7d135e8e..f6a1c7f2 100644 --- a/reproductions/algorithms/atari/rainbow/rainbow_reproduction.py +++ b/reproductions/algorithms/atari/rainbow/rainbow_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -191,7 +191,8 @@ def run_training(args): eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render, - max_frames_per_episode=max_frames_per_episode) + max_frames_per_episode=max_frames_per_episode, + use_gymnasium=args.use_gymnasium) evaluator = TimestepEvaluator(num_timesteps=125000) evaluation_hook = H.EvaluationHook(eval_env, evaluator, @@ -201,7 +202,8 @@ def run_training(args): iteration_num_hook = H.IterationNumHook(timing=100) train_env = build_atari_env(args.env, seed=args.seed, render=args.render, - max_frames_per_episode=max_frames_per_episode) + max_frames_per_episode=max_frames_per_episode, + use_gymnasium=args.use_gymnasium) rainbow = setup_rainbow(train_env, args) hooks = [iteration_num_hook, save_snapshot_hook, evaluation_hook] @@ -223,7 +225,8 @@ def run_showcase(args): test=True, seed=args.seed + 200, render=args.render, - max_frames_per_episode=max_frames_per_episode) + max_frames_per_episode=max_frames_per_episode, + use_gymnasium=args.use_gymnasium) rainbow = load_rainbow(eval_env, args) if not isinstance(rainbow, A.Rainbow): raise ValueError('Loaded snapshot is not trained with Rainbow!') @@ -254,6 +257,7 @@ def main(): parser.add_argument('--save_timing', type=int, default=250000) parser.add_argument('--eval_timing', type=int, default=250000) parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument('--use-gymnasium', action='store_true') add_algorithm_options(parser) args = parser.parse_args() diff --git a/reproductions/algorithms/mujoco/atrpo/atrpo_reproduction.py b/reproductions/algorithms/mujoco/atrpo/atrpo_reproduction.py index 136d5b2e..99ea954f 100644 --- a/reproductions/algorithms/mujoco/atrpo/atrpo_reproduction.py +++ b/reproductions/algorithms/mujoco/atrpo/atrpo_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -33,7 +33,7 @@ def run_training(args): outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) - eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100) + eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium) evaluator = EpisodicEvaluator(run_per_evaluation=10) evaluation_hook = H.EvaluationHook(eval_env, evaluator, @@ -44,7 +44,7 @@ def run_training(args): save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=5000) - train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render) + train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium) train_env = EndlessEnv(train_env, reset_reward=-100) config = A.ATRPOConfig(gpu_id=args.gpu) atrpo = A.ATRPO(train_env, config=config) @@ -63,7 +63,8 @@ def run_showcase(args): raise ValueError( 'Please specify the snapshot dir for showcasing') config = A.ATRPOConfig(gpu_id=args.gpu) - eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render) + eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render, + use_gymnasium=args.use_gymnasium) atrpo = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(atrpo, A.ATRPO): raise ValueError('Loaded snapshot is not trained with ATRPO!') @@ -89,6 +90,7 @@ def main(): parser.add_argument('--save_timing', type=int, default=1000000) parser.add_argument('--eval_timing', type=int, default=50000) parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument('--use-gymnasium', action='store_true') args = parser.parse_args() diff --git a/reproductions/algorithms/mujoco/ddpg/ddpg_reproduction.py b/reproductions/algorithms/mujoco/ddpg/ddpg_reproduction.py index c8a0cf29..1f8c326f 100644 --- a/reproductions/algorithms/mujoco/ddpg/ddpg_reproduction.py +++ b/reproductions/algorithms/mujoco/ddpg/ddpg_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -38,7 +38,7 @@ def run_training(args): outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) - eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100) + eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium) evaluator = EpisodicEvaluator(run_per_evaluation=10) evaluation_hook = H.EvaluationHook(eval_env, evaluator, @@ -48,7 +48,7 @@ def run_training(args): iteration_num_hook = H.IterationNumHook(timing=100) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) - train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render) + train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium) timesteps = select_start_timesteps(args.env) config = A.DDPGConfig(gpu_id=args.gpu, start_timesteps=timesteps) ddpg = A.DDPG(train_env, config=config) @@ -66,7 +66,9 @@ def run_showcase(args): if args.snapshot_dir is None: raise ValueError( 'Please specify the snapshot dir for showcasing') - eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render) + + eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, + render=args.render, use_gymnasium=args.use_gymnasium) config = A.DDPGConfig(gpu_id=args.gpu) ddpg = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(ddpg, A.DDPG): @@ -89,6 +91,7 @@ def main(): parser.add_argument('--save_timing', type=int, default=5000) parser.add_argument('--eval_timing', type=int, default=5000) parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument('--use-gymnasium', action='store_true') args = parser.parse_args() diff --git a/reproductions/algorithms/mujoco/her/her_reproduction.py b/reproductions/algorithms/mujoco/her/her_reproduction.py index fdcb114d..caf61d14 100644 --- a/reproductions/algorithms/mujoco/her/her_reproduction.py +++ b/reproductions/algorithms/mujoco/her/her_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,12 +16,14 @@ from typing import List import gym +import gymnasium import nnabla_rl.algorithms as A import nnabla_rl.hooks as H import nnabla_rl.writers as W from nnabla_rl.environments.wrappers import NumpyFloat32Env, ScreenRenderEnv from nnabla_rl.environments.wrappers.goal_conditioned import GoalConditionedTupleObservationEnv, GoalEnvWrapper +from nnabla_rl.environments.wrappers.gymnasium import Gymnasium2GymWrapper from nnabla_rl.typing import Experience from nnabla_rl.utils import serializers from nnabla_rl.utils.evaluator import EpisodicSuccessEvaluator @@ -54,7 +56,7 @@ def check_success(experiences: List[Experience]) -> bool: return False -def build_mujoco_goal_conditioned_env(id_or_env, test=False, seed=None, render=False): +def build_mujoco_goal_conditioned_env(id_or_env, test=False, seed=None, render=False, use_gymnasium=False): try: # Add pybullet env import pybullet_envs # noqa @@ -70,10 +72,17 @@ def build_mujoco_goal_conditioned_env(id_or_env, test=False, seed=None, render=F if isinstance(id_or_env, gym.Env): env = id_or_env + elif isinstance(id_or_env, gymnasium.Env): + env = Gymnasium2GymWrapper(id_or_env) + use_gymnasium = True else: - # Currently, env-checker of OpenAI gym cannot handle dict observation. 
-        # So, avoid to use env-checker.
-        env = gym.make(id_or_env, disable_env_checker=True)
+        if use_gymnasium:
+            env = gymnasium.make(id_or_env)
+            env = Gymnasium2GymWrapper(env)
+        else:
+            # Currently, the env-checker of OpenAI gym cannot handle dict observations.
+            # So, avoid using the env-checker.
+            env = gym.make(id_or_env, disable_env_checker=True)
     env = GoalEnvWrapper(env)
     env = GoalConditionedTupleObservationEnv(env)
     print_env_info(env)
@@ -83,7 +92,8 @@
     if render:
         env = ScreenRenderEnv(env)
 
-    env.seed(seed)
+    if not use_gymnasium:
+        env.seed(seed)
     return env
 
 
@@ -94,8 +104,10 @@
     set_global_seed(args.seed)
 
    n_cycles = select_n_cycles(env_name=args.env)
-    train_env = build_mujoco_goal_conditioned_env(args.env, seed=args.seed, render=args.render)
-    eval_env = build_mujoco_goal_conditioned_env(args.env, test=True, seed=args.seed + 100, render=args.render)
+    train_env = build_mujoco_goal_conditioned_env(args.env, seed=args.seed, render=args.render,
+                                                  use_gymnasium=args.use_gymnasium)
+    eval_env = build_mujoco_goal_conditioned_env(args.env, test=True, seed=args.seed + 100, render=args.render,
+                                                 use_gymnasium=args.use_gymnasium)
 
     max_timesteps = train_env.spec.max_episode_steps
     iteration_per_epoch = n_cycles * n_update * max_timesteps
@@ -137,7 +149,8 @@
 def run_showcase(args):
     if args.snapshot_dir is None:
         raise ValueError('Please specify the snapshot dir for showcasing')
-    eval_env = build_mujoco_goal_conditioned_env(args.env, test=True, seed=args.seed + 200, render=args.render)
+    eval_env = build_mujoco_goal_conditioned_env(args.env, test=True, seed=args.seed + 200, render=args.render,
+                                                 use_gymnasium=args.use_gymnasium)
     config = A.HERConfig(gpu_id=args.gpu)
     her = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config})
     if not isinstance(her, A.HER):
@@ -159,6 +172,7 @@
     parser.add_argument('--save-dir', type=str, default=None)
     parser.add_argument('--total_iterations', type=int, default=20000000)
     parser.add_argument('--showcase_runs', type=int, default=10)
+    parser.add_argument('--use-gymnasium', action='store_true')
 
     args = parser.parse_args()
diff --git a/reproductions/algorithms/mujoco/icml2015trpo/icml2015trpo_reproduction.py b/reproductions/algorithms/mujoco/icml2015trpo/icml2015trpo_reproduction.py
index f8c040f2..1aec7e3b 100644
--- a/reproductions/algorithms/mujoco/icml2015trpo/icml2015trpo_reproduction.py
+++ b/reproductions/algorithms/mujoco/icml2015trpo/icml2015trpo_reproduction.py
@@ -1,5 +1,5 @@
-# Copyright 2021 Sony Group Corporation.
+# Copyright 2021,2022,2023,2024 Sony Group Corporation.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,7 +29,7 @@ def run_training(args): outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) - eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100) + eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium) evaluator = EpisodicEvaluator(run_per_evaluation=10) evaluation_hook = H.EvaluationHook(eval_env, evaluator, @@ -40,7 +40,7 @@ def run_training(args): save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=1000) - train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render) + train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium) config = A.ICML2015TRPOConfig(gpu_id=args.gpu, num_steps_per_iteration=1000000, batch_size=1000000, @@ -60,7 +60,8 @@ def run_showcase(args): if args.snapshot_dir is None: raise ValueError( 'Please specify the snapshot dir for showcasing') - eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render) + eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render, + use_gymnasium=args.use_gymnasium) config = A.ICML2015TRPOConfig(gpu_id=args.gpu) trpo = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(trpo, A.ICML2015TRPO): @@ -83,6 +84,7 @@ def main(): parser.add_argument('--save_timing', type=int, default=1000000) parser.add_argument('--eval_timing', type=int, default=1000000) parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument('--use-gymnasium', action='store_true') args = parser.parse_args() diff --git a/reproductions/algorithms/mujoco/icml2018sac/icml2018sac_reproduction.py b/reproductions/algorithms/mujoco/icml2018sac/icml2018sac_reproduction.py index 24805e7c..aee0350f 100644 --- a/reproductions/algorithms/mujoco/icml2018sac/icml2018sac_reproduction.py +++ b/reproductions/algorithms/mujoco/icml2018sac/icml2018sac_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -58,7 +58,7 @@ def run_training(args): outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) - eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100) + eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium) evaluator = EpisodicEvaluator(run_per_evaluation=10) evaluation_hook = H.EvaluationHook(eval_env, evaluator, @@ -69,7 +69,7 @@ def run_training(args): save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=100) - train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render) + train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium) timesteps = select_start_timesteps(args.env) reward_scalar = select_reward_scalar(args.env) config = A.ICML2018SACConfig(gpu_id=args.gpu, @@ -91,7 +91,8 @@ def run_showcase(args): if args.snapshot_dir is None: raise ValueError( 'Please specify the snapshot dir for showcasing') - eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render) + eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render, + use_gymnasium=args.use_gymnasium) config = A.ICML2018SACConfig(gpu_id=args.gpu) icml2018sac = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(icml2018sac, A.ICML2018SAC): @@ -114,6 +115,7 @@ def main(): parser.add_argument('--save_timing', type=int, default=5000) parser.add_argument('--eval_timing', type=int, default=5000) parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument('--use-gymnasium', action='store_true') args = parser.parse_args() diff --git a/reproductions/algorithms/mujoco/ppo/ppo_reproduction.py b/reproductions/algorithms/mujoco/ppo/ppo_reproduction.py index 293d617e..6637491e 100644 --- a/reproductions/algorithms/mujoco/ppo/ppo_reproduction.py +++ b/reproductions/algorithms/mujoco/ppo/ppo_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -37,7 +37,7 @@ def run_training(args): outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) - eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100) + eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium) evaluator = EpisodicEvaluator(run_per_evaluation=10) evaluation_hook = H.EvaluationHook(eval_env, evaluator, @@ -48,7 +48,7 @@ def run_training(args): iteration_num_hook = H.IterationNumHook(timing=1000) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) - train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render) + train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium) timelimit_as_terminal = select_timelimit_as_terminal(args.env) config = A.PPOConfig(gpu_id=args.gpu, epsilon=0.2, @@ -76,7 +76,8 @@ def run_showcase(args): if args.snapshot_dir is None: raise ValueError( 'Please specify the snapshot dir for showcasing') - eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render) + eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render, + use_gymnasium=args.use_gymnasium) config = A.PPOConfig(gpu_id=args.gpu) ppo = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(ppo, A.PPO): @@ -99,6 +100,7 @@ def main(): parser.add_argument('--save_timing', type=int, default=5000) parser.add_argument('--eval_timing', type=int, default=5000) parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument('--use-gymnasium', action='store_true') args = parser.parse_args() diff --git a/reproductions/algorithms/mujoco/qrsac/qrsac_reproduction.py b/reproductions/algorithms/mujoco/qrsac/qrsac_reproduction.py index 0b889412..b2b8ba86 100644 --- a/reproductions/algorithms/mujoco/qrsac/qrsac_reproduction.py +++ b/reproductions/algorithms/mujoco/qrsac/qrsac_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -40,7 +40,7 @@ def run_training(args):
     outdir = os.path.join(os.path.abspath(args.save_dir), outdir)
     set_global_seed(args.seed)
 
-    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100)
+    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium)
     evaluator = EpisodicEvaluator(run_per_evaluation=10)
     evaluation_hook = H.EvaluationHook(eval_env,
                                        evaluator,
@@ -50,7 +50,7 @@ def run_training(args):
     iteration_num_hook = H.IterationNumHook(timing=100)
     save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing)
 
-    train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render)
+    train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium)
     config = A.QRSACConfig(gpu_id=args.gpu,
                            fix_temperature=args.fix_temperature,
                            initial_temperature=args.initial_temperature,
@@ -71,7 +71,8 @@ def run_showcase(args):
     if args.snapshot_dir is None:
         raise ValueError(
             'Please specify the snapshot dir for showcasing')
-    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render)
+    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render,
+                                use_gymnasium=args.use_gymnasium)
     config = A.QRSACConfig(gpu_id=args.gpu)
     sac = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config})
     if not isinstance(sac, A.QRSAC):
@@ -99,6 +100,7 @@ def main():
     parser.add_argument('--fix-temperature', action='store_true')
     parser.add_argument('--initial-temperature', type=float, default=None)
     parser.add_argument('--num-steps', type=int, default=1, help='number of steps for n-step Q target')
+    parser.add_argument('--use-gymnasium', action='store_true')
 
     args = parser.parse_args()
diff --git a/reproductions/algorithms/mujoco/redq/redq_reproduction.py b/reproductions/algorithms/mujoco/redq/redq_reproduction.py
index 32d6de63..f0c7e78d 100644
--- a/reproductions/algorithms/mujoco/redq/redq_reproduction.py
+++ b/reproductions/algorithms/mujoco/redq/redq_reproduction.py
@@ -1,4 +1,4 @@
-# Copyright 2022 Sony Group Corporation.
+# Copyright 2022,2023,2024 Sony Group Corporation.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -38,7 +38,7 @@ def run_training(args):
     outdir = os.path.join(os.path.abspath(args.save_dir), outdir)
     set_global_seed(args.seed)
 
-    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100)
+    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium)
    evaluator = EpisodicEvaluator(run_per_evaluation=10)
     evaluation_hook = H.EvaluationHook(eval_env,
                                        evaluator,
@@ -48,7 +48,7 @@ def run_training(args):
     iteration_num_hook = H.IterationNumHook(timing=100)
     save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing)
 
-    train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render)
+    train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium)
     config = A.REDQConfig(gpu_id=args.gpu,
                           fix_temperature=args.fix_temperature)
     redq = A.REDQ(train_env, config=config)
@@ -66,7 +66,8 @@ def run_showcase(args):
     if args.snapshot_dir is None:
         raise ValueError(
             'Please specify the snapshot dir for showcasing')
-    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render)
+    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render,
+                                use_gymnasium=args.use_gymnasium)
     config = A.REDQConfig(gpu_id=args.gpu)
     redq = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config})
     if not isinstance(redq, A.REDQ):
@@ -92,6 +93,7 @@ def main():
 
     # REDQ algorithm config
     parser.add_argument('--fix-temperature', action='store_true')
+    parser.add_argument('--use-gymnasium', action='store_true')
 
     args = parser.parse_args()
diff --git a/reproductions/algorithms/mujoco/sac/sac_reproduction.py b/reproductions/algorithms/mujoco/sac/sac_reproduction.py
index 4e97ab03..f5751ddb 100644
--- a/reproductions/algorithms/mujoco/sac/sac_reproduction.py
+++ b/reproductions/algorithms/mujoco/sac/sac_reproduction.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Sony Group Corporation.
+# Copyright 2021,2022,2023,2024 Sony Group Corporation.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -40,7 +40,7 @@ def run_training(args):
     outdir = os.path.join(os.path.abspath(args.save_dir), outdir)
     set_global_seed(args.seed)
 
-    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100)
+    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium)
     evaluator = EpisodicEvaluator(run_per_evaluation=10)
     evaluation_hook = H.EvaluationHook(eval_env,
                                        evaluator,
@@ -50,7 +50,7 @@ def run_training(args):
     iteration_num_hook = H.IterationNumHook(timing=100)
     save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing)
 
-    train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render)
+    train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium)
     config = A.SACConfig(gpu_id=args.gpu,
                          fix_temperature=args.fix_temperature)
     sac = A.SAC(train_env, config=config)
@@ -68,7 +68,8 @@ def run_showcase(args):
     if args.snapshot_dir is None:
         raise ValueError(
             'Please specify the snapshot dir for showcasing')
-    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render)
+    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render,
+                                use_gymnasium=args.use_gymnasium)
     config = A.SACConfig(gpu_id=args.gpu)
     sac = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config})
     if not isinstance(sac, A.SAC):
@@ -91,6 +92,7 @@ def main():
     parser.add_argument('--save_timing', type=int, default=5000)
     parser.add_argument('--eval_timing', type=int, default=5000)
     parser.add_argument('--showcase_runs', type=int, default=10)
+    parser.add_argument('--use-gymnasium', action='store_true')
 
     # SAC algorithm config
     parser.add_argument('--fix-temperature', action='store_true')
diff --git a/reproductions/algorithms/mujoco/td3/td3_reproduction.py b/reproductions/algorithms/mujoco/td3/td3_reproduction.py
index b868e665..b5d44272 100644
--- a/reproductions/algorithms/mujoco/td3/td3_reproduction.py
+++ b/reproductions/algorithms/mujoco/td3/td3_reproduction.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Sony Group Corporation.
+# Copyright 2021,2022,2023,2024 Sony Group Corporation.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -38,7 +38,7 @@ def run_training(args):
     outdir = os.path.join(os.path.abspath(args.save_dir), outdir)
     set_global_seed(args.seed)
 
-    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100)
+    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium)
     evaluator = EpisodicEvaluator(run_per_evaluation=10)
     evaluation_hook = H.EvaluationHook(eval_env,
                                        evaluator,
@@ -47,7 +47,7 @@ def run_training(args):
     iteration_num_hook = H.IterationNumHook(timing=100)
     save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing)
 
-    train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render)
+    train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium)
     timesteps = select_start_timesteps(args.env)
     config = A.TD3Config(gpu_id=args.gpu, start_timesteps=timesteps)
     td3 = A.TD3(train_env, config=config)
@@ -65,7 +65,8 @@ def run_showcase(args):
     if args.snapshot_dir is None:
         raise ValueError(
             'Please specify the snapshot dir for showcasing')
-    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render)
+    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render,
+                                use_gymnasium=args.use_gymnasium)
     config = A.TD3Config(gpu_id=args.gpu)
     td3 = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config})
     if not isinstance(td3, A.TD3):
@@ -88,6 +89,7 @@ def main():
     parser.add_argument('--save_timing', type=int, default=5000)
     parser.add_argument('--eval_timing', type=int, default=5000)
     parser.add_argument('--showcase_runs', type=int, default=10)
+    parser.add_argument('--use-gymnasium', action='store_true')
 
     args = parser.parse_args()
diff --git a/reproductions/algorithms/mujoco/trpo/trpo_reproduction.py b/reproductions/algorithms/mujoco/trpo/trpo_reproduction.py
index e2ab2c83..c698821f 100644
--- a/reproductions/algorithms/mujoco/trpo/trpo_reproduction.py
+++ b/reproductions/algorithms/mujoco/trpo/trpo_reproduction.py
@@ -1,4 +1,4 @@
-# Copyright 2021 Sony Group Corporation.
+# Copyright 2021,2022,2023,2024 Sony Group Corporation.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -32,7 +32,7 @@ def run_training(args):
     outdir = os.path.join(os.path.abspath(args.save_dir), outdir)
     set_global_seed(args.seed)
 
-    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100)
+    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium)
     evaluator = EpisodicEvaluator(run_per_evaluation=10)
     evaluation_hook = H.EvaluationHook(eval_env,
                                        evaluator,
@@ -43,7 +43,7 @@ def run_training(args):
     save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing)
     iteration_num_hook = H.IterationNumHook(timing=5000)
 
-    train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render)
+    train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium)
     config = A.TRPOConfig(gpu_id=args.gpu)
     trpo = A.TRPO(train_env, config=config)
 
@@ -60,7 +60,8 @@ def run_showcase(args):
     if args.snapshot_dir is None:
         raise ValueError(
             'Please specify the snapshot dir for showcasing')
-    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render)
+    eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render,
+                                use_gymnasium=args.use_gymnasium)
     config = A.TRPOConfig(gpu_id=args.gpu)
     trpo = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config})
     if not isinstance(trpo, A.TRPO):
@@ -87,6 +88,7 @@ def main():
     parser.add_argument('--save_timing', type=int, default=10000)
     parser.add_argument('--eval_timing', type=int, default=10000)
     parser.add_argument('--showcase_runs', type=int, default=10)
+    parser.add_argument('--use-gymnasium', action='store_true')
 
     args = parser.parse_args()
diff --git a/requirements.txt b/requirements.txt
index 138b731a..7fef5fc0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,3 +14,4 @@ isort
 autopep8
 packaging
 docformatter
+gymnasium
\ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
index cac2478c..cceef4bc 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -46,6 +46,7 @@ install_requires =
     opencv-python
     packaging
     tqdm
+    gymnasium
 scripts =
     bin/check_best_iteration
     bin/compile_results
diff --git a/tests/environments/wrappers/test_gymnasium.py b/tests/environments/wrappers/test_gymnasium.py
new file mode 100644
index 00000000..c85cbad1
--- /dev/null
+++ b/tests/environments/wrappers/test_gymnasium.py
@@ -0,0 +1,63 @@
+# Copyright 2024 Sony Group Corporation.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pytest
+
+from nnabla_rl.environments.dummy import DummyContinuous, DummyGymnasiumMujocoEnv
+from nnabla_rl.environments.wrappers.common import NumpyFloat32Env
+from nnabla_rl.environments.wrappers.gymnasium import Gymnasium2GymWrapper
+
+max_episode_steps = 10
+
+
+class TestGymnasium(object):
+    def test_gym_env(self):
+        env = DummyContinuous(max_episode_steps=max_episode_steps)
+        with pytest.raises(ValueError):
+            env = Gymnasium2GymWrapper(env)
+
+    def test_gym_wrapper(self):
+        env = DummyContinuous(max_episode_steps=max_episode_steps)
+        env = NumpyFloat32Env(env)
+        with pytest.raises(ValueError):
+            env = Gymnasium2GymWrapper(env)
+
+    def test_reset(self):
+        env = DummyGymnasiumMujocoEnv(max_episode_steps=max_episode_steps)
+        raw_reset_outputs = env.reset()
+        assert isinstance(raw_reset_outputs, tuple)
+        assert isinstance(raw_reset_outputs[0], np.ndarray)
+        assert isinstance(raw_reset_outputs[1], dict)
+
+        wrapped_env = Gymnasium2GymWrapper(env)
+        reset_outputs = wrapped_env.reset()
+        assert isinstance(reset_outputs, np.ndarray)
+
+    def test_step(self):
+        env = DummyGymnasiumMujocoEnv(max_episode_steps=max_episode_steps)
+        action = env.action_space.sample()
+        raw_step_outputs = env.step(action)
+        assert isinstance(raw_step_outputs, tuple)
+        assert len(raw_step_outputs) == 5
+
+        wrapped_env = Gymnasium2GymWrapper(env)
+        action = wrapped_env.action_space.sample()
+        step_outputs = wrapped_env.step(action)
+        assert isinstance(step_outputs, tuple)
+        assert len(step_outputs) == 4
+
+
+if __name__ == "__main__":
+    pytest.main()
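Note for reviewers: `Gymnasium2GymWrapper` is the piece that lets the gymnasium-backed environments flow through nnabla_rl's existing gym-based training loop, which is why each reproduction script above only needs the small `use_gymnasium=args.use_gymnasium` plumbing. The sketch below is a minimal illustration of the API conversion that `test_reset` and `test_step` verify; it uses only names that appear in this patch, and the call pattern is inferred from the new test file rather than taken from any documentation.

    # Illustrative sketch only (not part of the patch); behavior inferred
    # from tests/environments/wrappers/test_gymnasium.py above.
    from nnabla_rl.environments.dummy import DummyGymnasiumMujocoEnv
    from nnabla_rl.environments.wrappers.gymnasium import Gymnasium2GymWrapper

    env = DummyGymnasiumMujocoEnv(max_episode_steps=10)

    # Raw gymnasium API: reset() returns (observation, info) and step()
    # returns a 5-tuple with separate `terminated` and `truncated` flags.
    obs, info = env.reset()
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())

    # Wrapped: the classic gym API that the rest of nnabla_rl expects.
    # reset() returns only the observation and step() returns a 4-tuple.
    env = Gymnasium2GymWrapper(env)
    obs = env.reset()
    obs, reward, done, info = env.step(env.action_space.sample())

With this conversion in place, passing `--use-gymnasium` to any of the reproduction scripts above switches the environment backend without touching the algorithm code.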