diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 71b77975..84b33216 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,9 +19,11 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt - - name: check python code format with autopep8 - run: | - autopep8 --diff . + - name: check python code format with black + uses: psf/black@stable + with: + options: "--check --verbose" + src: "." typing: runs-on: ubuntu-latest timeout-minutes: 3 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 72675223..d6a95e4d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -69,13 +69,13 @@ $ pip install -e . ### Code format guidelines -We use [autopep8](https://github.com/hhatto/autopep8) and [isort](https://github.com/PyCQA/isort) to keep consistent coding style. After finishing developing the code, run autopep8 and isort to ensure that your code is correctly formatted. +We use [black](https://github.com/psf/black) and [isort](https://github.com/PyCQA/isort) to keep consistent coding style. After finishing developing the code, run black and isort to ensure that your code is correctly formatted. -You can run autopep8 and isort as follows. +You can run black and isort as follows. ```sh cd -autopep8 . +black . ``` ```sh diff --git a/README.md b/README.md index 860ee523..4158acaa 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ [![Build status](https://github.com/sony/nnabla-rl/workflows/Build%20nnabla-rl/badge.svg)](https://github.com/sony/nnabla-rl/actions) [![Documentation Status](https://readthedocs.org/projects/nnabla-rl/badge/?version=latest)](https://nnabla-rl.readthedocs.io/en/latest/?badge=latest) [![Doc style](https://img.shields.io/badge/%20style-google-3666d6.svg)](https://google.github.io/styleguide/pyguide.html#s3.8-comments-and-docstrings) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) # Deep Reinforcement Learning Library built on top of Neural Network Libraries diff --git a/docs/source/conf.py b/docs/source/conf.py index e0effe58..48978841 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,14 +18,14 @@ import sphinx_rtd_theme -sys.path.insert(0, os.path.abspath('../../')) +sys.path.insert(0, os.path.abspath("../../")) import nnabla_rl # noqa # -- Project information ----------------------------------------------------- -project = 'nnablaRL' -copyright = '2021, Sony Group Corporation' -author = 'Sony Group Corporation' +project = "nnablaRL" +copyright = "2021, Sony Group Corporation" +author = "Sony Group Corporation" release = nnabla_rl.__version__ # -- General configuration --------------------------------------------------- @@ -34,22 +34,22 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. 
extensions = [ - 'sphinx.ext.doctest', - 'sphinx.ext.autodoc', - 'sphinx.ext.napoleon', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.mathjax', - 'sphinx.ext.autosummary', - 'sphinx.ext.viewcode' + "sphinx.ext.doctest", + "sphinx.ext.autodoc", + "sphinx.ext.napoleon", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.mathjax", + "sphinx.ext.autosummary", + "sphinx.ext.viewcode", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['ntemplates'] +templates_path = ["ntemplates"] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. diff --git a/examples/evaluate_trained_model.py b/examples/evaluate_trained_model.py index a16c64f3..0e2a8779 100644 --- a/examples/evaluate_trained_model.py +++ b/examples/evaluate_trained_model.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,16 +22,16 @@ def build_env(): try: - env = gym.make('Pendulum-v0') + env = gym.make("Pendulum-v0") except gym.error.DeprecatedEnv: - env = gym.make('Pendulum-v1') + env = gym.make("Pendulum-v1") env = NumpyFloat32Env(env) env = ScreenRenderEnv(env) return env def main(): - snapshot_dir = './pendulum_v0_snapshot/iteration-10000' + snapshot_dir = "./pendulum_v0_snapshot/iteration-10000" env = build_env() algorithm = serializers.load_snapshot(snapshot_dir, env) diff --git a/examples/hook_example.py b/examples/hook_example.py index 86984d19..9c4220a8 100644 --- a/examples/hook_example.py +++ b/examples/hook_example.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ def __init__(self): super().__init__(timing=1) def on_hook_called(self, algorithm): - print('hello!!') + print("hello!!") class PrintOnlyEvenIteraion(Hook): @@ -32,7 +32,7 @@ def __init__(self): super().__init__(timing=2) def on_hook_called(self, algorithm): - print('even iteration -> {}'.format(algorithm.iteration_num)) + print("even iteration -> {}".format(algorithm.iteration_num)) def main(): diff --git a/examples/recurrent_model.py b/examples/recurrent_model.py index 01f74f5d..8e5e1542 100644 --- a/examples/recurrent_model.py +++ b/examples/recurrent_model.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
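Almost every hunk that follows is a mechanical consequence of black's rules rather than a behavioural change: single quotes become double quotes, one-element tuples lose the space before the closing parenthesis, and long argument lists are re-wrapped. A minimal, hypothetical Python sketch (not code from this repository) of the style black produces:

# Quotes are normalized to double quotes and one-element tuples are written
# without a space, e.g. (256,) instead of (256, ).
shapes = {"state_h": (256,), "state_c": (256,)}

# Calls that exceed the line-length limit are exploded to one argument per
# line and keep a trailing ("magic") comma.
config = dict(
    gpu_id=0,
    learning_rate=1e-2,
    gamma=0.9,
    start_timesteps=200,
    unroll_steps=2,
)
print(shapes, config)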
@@ -65,8 +65,8 @@ def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} # You can use arbitral (but distinguishable) key as name # Use same key for same state - shapes['my_lstm_h'] = (self._lstm_state_size, ) - shapes['my_lstm_c'] = (self._lstm_state_size, ) + shapes["my_lstm_h"] = (self._lstm_state_size,) + shapes["my_lstm_c"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: @@ -74,8 +74,8 @@ def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} # You can use arbitral (but distinguishable) key as name # Use same key for same state. - states['my_lstm_h'] = self._h - states['my_lstm_c'] = self._c + states["my_lstm_h"] = self._h + states["my_lstm_c"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -88,8 +88,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): else: # Otherwise, set given states # Use the key defined in internal_state_shapes() for getting the states - self._h = states['my_lstm_h'] - self._c = states['my_lstm_c'] + self._h = states["my_lstm_h"] + self._c = states["my_lstm_c"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -103,11 +103,9 @@ def _is_internal_state_created(self) -> bool: class QFunctionWithRNNBuilder(ModelBuilder[QFunction]): - def build_model(self, - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: AlgorithmConfig, - **kwargs) -> QFunction: + def build_model( + self, scope_name: str, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, **kwargs + ) -> QFunction: action_num = env_info.action_dim return QFunctionWithRNN(scope_name, action_num) @@ -118,7 +116,7 @@ def build_solver(self, env_info: EnvironmentInfo, algorithm_config: AlgorithmCon def build_env(seed=None): - env = gym.make('MountainCar-v0') + env = gym.make("MountainCar-v0") env = NumpyFloat32Env(env) env = ScreenRenderEnv(env) env.seed(seed) @@ -133,30 +131,32 @@ def main(): # Here, we use DRQN algorithm # For the list of algorithms that support RNN layers see # https://github.com/sony/nnabla-rl/tree/master/nnabla_rl/algorithms - config = A.DRQNConfig(gpu_id=0, - learning_rate=1e-2, - gamma=0.9, - learner_update_frequency=1, - target_update_frequency=200, - start_timesteps=200, - replay_buffer_size=10000, - max_explore_steps=10000, - initial_epsilon=1.0, - final_epsilon=0.001, - test_epsilon=0.001, - grad_clip=None, - unroll_steps=2) # Unroll only for 2 timesteps for fast iteration. Because this is an example - drqn = A.DRQN(train_env, - config=config, - q_func_builder=QFunctionWithRNNBuilder(), - q_solver_builder=AdamSolverBuilder()) + config = A.DRQNConfig( + gpu_id=0, + learning_rate=1e-2, + gamma=0.9, + learner_update_frequency=1, + target_update_frequency=200, + start_timesteps=200, + replay_buffer_size=10000, + max_explore_steps=10000, + initial_epsilon=1.0, + final_epsilon=0.001, + test_epsilon=0.001, + grad_clip=None, + unroll_steps=2, + ) # Unroll only for 2 timesteps for fast iteration. 
Because this is an example + drqn = A.DRQN( + train_env, config=config, q_func_builder=QFunctionWithRNNBuilder(), q_solver_builder=AdamSolverBuilder() + ) # Optional: Add hooks to check the training progress eval_env = build_env(seed=100) evaluation_hook = H.EvaluationHook( eval_env, timing=1000, - writer=W.FileWriter(outdir='./mountain_car_v0_drqn_results', file_prefix='evaluation_result')) + writer=W.FileWriter(outdir="./mountain_car_v0_drqn_results", file_prefix="evaluation_result"), + ) iteration_num_hook = H.IterationNumHook(timing=100) drqn.set_hooks(hooks=[iteration_num_hook, evaluation_hook]) diff --git a/examples/rl_project_template/environment.py b/examples/rl_project_template/environment.py index da45a0ac..6594e15b 100644 --- a/examples/rl_project_template/environment.py +++ b/examples/rl_project_template/environment.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,15 +20,15 @@ def __init__(self, max_episode_steps=100): # max_episode_steps is the maximum possible steps that the rl agent interacts with this environment. # You can set this value to None if the is no limits. # The first argument is the name of this environment used when registering this environment to gym. - self.spec = EnvSpec('template-v0', max_episode_steps=max_episode_steps) + self.spec = EnvSpec("template-v0", max_episode_steps=max_episode_steps) self._episode_steps = 0 # Use gym's spaces to define the shapes and ranges of states and actions. # observation_space: definition of states's shape and its ranges # action_space: definition of actions's shape and its ranges - observation_shape = (10, ) # Example 10 dimensional state with range of [0.0, 1.0] each. + observation_shape = (10,) # Example 10 dimensional state with range of [0.0, 1.0] each. self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=observation_shape) - action_shape = (1, ) # 1 dimensional continuous action with range of [0.0, 1.0]. + action_shape = (1,) # 1 dimensional continuous action with range of [0.0, 1.0]. self.action_space = gym.spaces.Box(low=0.0, high=1.0, shape=action_shape) def reset(self): diff --git a/examples/rl_project_template/models.py b/examples/rl_project_template/models.py index 3fcf72ea..31dae47c 100644 --- a/examples/rl_project_template/models.py +++ b/examples/rl_project_template/models.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -58,7 +58,7 @@ def pi(self, state: nn.Variable) -> Distribution: h = NF.relu(x=h) h = NPF.affine(h, n_outmaps=256, name="linear2") h = NF.relu(x=h) - h = NPF.affine(h, n_outmaps=self._action_dim*2, name="linear3") + h = NPF.affine(h, n_outmaps=self._action_dim * 2, name="linear3") reshaped = NF.reshape(h, shape=(-1, 2, self._action_dim)) # Split the output into mean and variance of the Gaussian distribution. diff --git a/examples/rl_project_template/training.py b/examples/rl_project_template/training.py index 6c8e02bf..01c10155 100644 --- a/examples/rl_project_template/training.py +++ b/examples/rl_project_template/training.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -75,7 +75,7 @@ def run_training(args): # Save trained parameters every "timing" steps. # Without this, the parameters will not be saved. # We recommend saving parameters at every evaluation timing. - outdir = pathlib.Path(args.save_dir) / 'snapshots' + outdir = pathlib.Path(args.save_dir) / "snapshots" save_snapshot_hook = H.SaveSnapshotHook(outdir=outdir, timing=1000) # All instantiated hooks should be set at once. @@ -87,13 +87,13 @@ def run_training(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--gpu', type=int, default=-1) - parser.add_argument('--save-dir', type=str, default=str(pathlib.Path(__file__).parent)) - parser.add_argument('--seed', type=int, default=0) + parser.add_argument("--gpu", type=int, default=-1) + parser.add_argument("--save-dir", type=str, default=str(pathlib.Path(__file__).parent)) + parser.add_argument("--seed", type=int, default=0) args = parser.parse_args() run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/save_load_snapshot.py b/examples/save_load_snapshot.py index 5ca6980c..a3c8566e 100644 --- a/examples/save_load_snapshot.py +++ b/examples/save_load_snapshot.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ def main(): env = E.DummyContinuous() ddpg = A.DDPG(env, config=config) - outdir = './save_load_snapshot' + outdir = "./save_load_snapshot" # This actually saves the model and solver state right after the algorithm construction snapshot_dir = serializers.save_snapshot(outdir, ddpg) diff --git a/examples/training_example.py b/examples/training_example.py index ee25939d..a6c0a3a0 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,9 +23,9 @@ def build_env(seed=None): try: - env = gym.make('Pendulum-v0') + env = gym.make("Pendulum-v0") except gym.error.DeprecatedEnv: - env = gym.make('Pendulum-v1') + env = gym.make("Pendulum-v1") env = NumpyFloat32Env(env) env = ScreenRenderEnv(env) env.seed(seed) @@ -36,15 +36,14 @@ def main(): # Evaluate the trained network (Optional) eval_env = build_env(seed=100) evaluation_hook = H.EvaluationHook( - eval_env, - timing=1000, - writer=W.FileWriter(outdir='./pendulum_v0_ddpg_results', file_prefix='evaluation_result')) + eval_env, timing=1000, writer=W.FileWriter(outdir="./pendulum_v0_ddpg_results", file_prefix="evaluation_result") + ) # Pring iteration number every 100 iterations. 
iteration_num_hook = H.IterationNumHook(timing=100) # Save trained algorithm snapshot (Optional) - save_snapshot_hook = H.SaveSnapshotHook('./pendulum_v0_ddpg_results', timing=1000) + save_snapshot_hook = H.SaveSnapshotHook("./pendulum_v0_ddpg_results", timing=1000) train_env = build_env() config = A.DDPGConfig(start_timesteps=200, gpu_id=0) diff --git a/examples/writer_example.py b/examples/writer_example.py index c699aba7..ab39ba83 100644 --- a/examples/writer_example.py +++ b/examples/writer_example.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,7 +27,7 @@ class MyScalarWriter(Writer): def __init__(self, outdir): - self._outdir = os.path.join(outdir, 'writer') + self._outdir = os.path.join(outdir, "writer") create_dir_if_not_exist(outdir=self._outdir) self._monitor = Monitor(self._outdir) self._monitor_series = None @@ -43,22 +43,21 @@ def write_scalar(self, iteration_num, scalar): def _create_monitor_series(self, names): self._monitor_series = [] for name in names: - self._monitor_series.append(MonitorSeries( - name, self._monitor, interval=1, verbose=False)) + self._monitor_series.append(MonitorSeries(name, self._monitor, interval=1, verbose=False)) def build_env(seed=None): try: - env = gym.make('Pendulum-v0') + env = gym.make("Pendulum-v0") except gym.error.DeprecatedEnv: - env = gym.make('Pendulum-v1') + env = gym.make("Pendulum-v1") env = NumpyFloat32Env(env) env.seed(seed) return env def main(): - writer = MyScalarWriter('./pendulum_v0_ddpg_results') + writer = MyScalarWriter("./pendulum_v0_ddpg_results") training_state_hook = H.IterationStateHook(writer=writer, timing=100) train_env = build_env() diff --git a/interactive-demos/colab_utils.py b/interactive-demos/colab_utils.py index d1407189..13c74afd 100644 --- a/interactive-demos/colab_utils.py +++ b/interactive-demos/colab_utils.py @@ -1,5 +1,5 @@ # Copyright 2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,12 +33,12 @@ def __init__(self, env): def on_hook_called(self, algorithm): display.clear_output(wait=True) if self._image is None: - self._image = plt.imshow(self._env.render('rgb_array')) + self._image = plt.imshow(self._env.render("rgb_array")) else: - self._image.set_data(self._env.render('rgb_array')) + self._image.set_data(self._env.render("rgb_array")) plt.suptitle(f"iteration num : {self._iteration_num}") self._iteration_num += 1 - plt.axis('off') + plt.axis("off") display.display(plt.gcf()) def reset(self): diff --git a/nnabla_rl/__init__.py b/nnabla_rl/__init__.py index fe908954..9766285e 100644 --- a/nnabla_rl/__init__.py +++ b/nnabla_rl/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = '0.16.0.dev1' +__version__ = "0.16.0.dev1" from nnabla_rl.logger import enable_logging, disable_logging # noqa from nnabla_rl.scopes import eval_scope, is_eval_scope # noqa diff --git a/nnabla_rl/algorithm.py b/nnabla_rl/algorithm.py index 63134827..5a0aea49 100644 --- a/nnabla_rl/algorithm.py +++ b/nnabla_rl/algorithm.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,13 +31,14 @@ from nnabla_rl.model_trainers.model_trainer import ModelTrainer from nnabla_rl.replay_buffer import ReplayBuffer -F = TypeVar('F', bound=Callable[..., Any]) +F = TypeVar("F", bound=Callable[..., Any]) def eval_api(f: F) -> F: def wrapped_with_eval_scope(*args, **kwargs): with rl.eval_scope(): return f(*args, **kwargs) + return cast(F, wrapped_with_eval_scope) @@ -48,6 +49,7 @@ class AlgorithmConfig(Configuration): Args: gpu_id (int): id of the gpu to use. If negative, the training will run on cpu. Defaults to -1. """ + gpu_id: int = -1 @@ -83,14 +85,18 @@ def __init__(self, env_info, config=AlgorithmConfig()): self._hooks = [] if not self.is_supported_env(env_info): - raise UnsupportedEnvironmentException("{} does not support the enviroment. \ + raise UnsupportedEnvironmentException( + "{} does not support the enviroment. \ See the algorithm catalog (https://github.com/sony/nnabla-rl/tree/master/nnabla_rl/algorithms) \ - and confirm what kinds of enviroments are supported".format(self.__name__)) + and confirm what kinds of enviroments are supported".format( + self.__name__ + ) + ) if self._config.gpu_id < 0: - logger.info('algorithm will run on cpu.') + logger.info("algorithm will run on cpu.") else: - logger.info('algorithm will run on gpu: {}'.format(self._config.gpu_id)) + logger.info("algorithm will run on gpu: {}".format(self._config.gpu_id)) @property def __name__(self): @@ -107,9 +113,9 @@ def latest_iteration_state(self) -> Dict[str, Any]: Dict[str, Any]: Dictionary with items of training process state. """ latest_iteration_state: Dict[str, Any] = {} - latest_iteration_state['scalar'] = {} - latest_iteration_state['histogram'] = {} - latest_iteration_state['image'] = {} + latest_iteration_state["scalar"] = {} + latest_iteration_state["histogram"] = {} + latest_iteration_state["image"] = {} return latest_iteration_state @property @@ -223,9 +229,9 @@ def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}) - """ raise NotImplementedError - def compute_trajectory(self, - initial_trajectory: Sequence[Tuple[np.ndarray, Optional[np.ndarray]]]) \ - -> Tuple[Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], Sequence[Dict[str, Any]]]: + def compute_trajectory( + self, initial_trajectory: Sequence[Tuple[np.ndarray, Optional[np.ndarray]]] + ) -> Tuple[Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], Sequence[Dict[str, Any]]]: """Compute trajectory (sequence of state and action tuples) from given initial trajectory using current policy. Most of the reinforcement learning algorithms does not implement this method. 
Only the optimal @@ -289,7 +295,7 @@ def _has_rnn_models(self): def _assert_rnn_is_supported(self): if not self.is_rnn_supported(): - raise RuntimeError(f'{self.__name__} does not support rnn models but rnn models where given!') + raise RuntimeError(f"{self.__name__} does not support rnn models but rnn models where given!") @classmethod @abstractmethod diff --git a/nnabla_rl/algorithms/__init__.py b/nnabla_rl/algorithms/__init__.py index 2cf9c5d7..ebedd457 100644 --- a/nnabla_rl/algorithms/__init__.py +++ b/nnabla_rl/algorithms/__init__.py @@ -63,11 +63,9 @@ def register_algorithm(algorithm_class, config_class): global _ALGORITHMS if not issubclass(algorithm_class, Algorithm): - raise ValueError( - "{} is not subclass of Algorithm".format(algorithm_class)) + raise ValueError("{} is not subclass of Algorithm".format(algorithm_class)) if not issubclass(config_class, AlgorithmConfig): - raise ValueError( - "{} is not subclass of AlgorithmConfig".format(config_class)) + raise ValueError("{} is not subclass of AlgorithmConfig".format(config_class)) _ALGORITHMS[algorithm_class.__name__] = (algorithm_class, config_class) diff --git a/nnabla_rl/algorithms/a2c.py b/nnabla_rl/algorithms/a2c.py index cfe329b2..9087a83f 100644 --- a/nnabla_rl/algorithms/a2c.py +++ b/nnabla_rl/algorithms/a2c.py @@ -1,5 +1,5 @@ # Copyright 2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -34,8 +34,14 @@ from nnabla_rl.model_trainers.model_trainer import ModelTrainer, TrainingBatch from nnabla_rl.models import A3CPolicy, A3CSharedFunctionHead, A3CVFunction, StochasticPolicy, VFunction from nnabla_rl.utils.data import marshal_experiences, unzip -from nnabla_rl.utils.multiprocess import (copy_mp_arrays_to_params, copy_params_to_mp_arrays, mp_array_from_np_array, - mp_to_np_array, new_mp_arrays_from_params, np_to_mp_array) +from nnabla_rl.utils.multiprocess import ( + copy_mp_arrays_to_params, + copy_params_to_mp_arrays, + mp_array_from_np_array, + mp_to_np_array, + new_mp_arrays_from_params, + np_to_mp_array, +) from nnabla_rl.utils.reproductions import set_global_seed from nnabla_rl.utils.solver_wrappers import AutoClipGradByGlobalNorm @@ -67,6 +73,7 @@ class A2CConfig(AlgorithmConfig): learning_rate_decay_iterations (int): learning rate will be decreased lineary to 0 till this iteration number. If 0 or negative, learning rate will be kept fixed. Defaults to 50000000. """ + gamma: float = 0.99 n_steps: int = 5 learning_rate: float = 7e-4 @@ -86,48 +93,49 @@ def __post_init__(self): Check the set values are in valid range. 
""" - self._assert_between(self.gamma, 0.0, 1.0, 'gamma') - self._assert_between(self.decay, 0.0, 1.0, 'decay') - self._assert_positive(self.n_steps, 'n_steps') - self._assert_positive(self.actor_num, 'actor num') - self._assert_positive(self.learning_rate, 'learning_rate') + self._assert_between(self.gamma, 0.0, 1.0, "gamma") + self._assert_between(self.decay, 0.0, 1.0, "decay") + self._assert_positive(self.n_steps, "n_steps") + self._assert_positive(self.actor_num, "actor num") + self._assert_positive(self.learning_rate, "learning_rate") class DefaultPolicyBuilder(ModelBuilder[StochasticPolicy]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: A2CConfig, - **kwargs) -> StochasticPolicy: - _shared_function_head = A3CSharedFunctionHead(scope_name="shared", - state_shape=env_info.state_shape) - return A3CPolicy(head=_shared_function_head, - scope_name="shared", - state_shape=env_info.state_shape, - action_dim=env_info.action_dim) + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: A2CConfig, + **kwargs, + ) -> StochasticPolicy: + _shared_function_head = A3CSharedFunctionHead(scope_name="shared", state_shape=env_info.state_shape) + return A3CPolicy( + head=_shared_function_head, + scope_name="shared", + state_shape=env_info.state_shape, + action_dim=env_info.action_dim, + ) class DefaultVFunctionBuilder(ModelBuilder[VFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: A2CConfig, - **kwargs) -> VFunction: - _shared_function_head = A3CSharedFunctionHead(scope_name="shared", - state_shape=env_info.state_shape) - return A3CVFunction(head=_shared_function_head, - scope_name="shared", - state_shape=env_info.state_shape) + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: A2CConfig, + **kwargs, + ) -> VFunction: + _shared_function_head = A3CSharedFunctionHead(scope_name="shared", state_shape=env_info.state_shape) + return A3CVFunction(head=_shared_function_head, scope_name="shared", state_shape=env_info.state_shape) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: A2CConfig, - **kwargs) -> nn.solver.Solver: - solver = NS.RMSprop(lr=algorithm_config.learning_rate, - decay=algorithm_config.decay, - eps=algorithm_config.epsilon) + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: A2CConfig, **kwargs + ) -> nn.solver.Solver: + solver = NS.RMSprop( + lr=algorithm_config.learning_rate, decay=algorithm_config.decay, eps=algorithm_config.epsilon + ) if algorithm_config.max_grad_norm is None: return solver else: @@ -165,7 +173,7 @@ class A2C(Algorithm): _v_function_solver: nn.solver.Solver _policy: StochasticPolicy _policy_solver: nn.solver.Solver - _actors: List['_A2CActor'] + _actors: List["_A2CActor"] _actor_processes: List[mp.Process] _s_current_var: nn.Variable _a_current_var: nn.Variable @@ -182,18 +190,21 @@ class A2C(Algorithm): _evaluation_actor: _StochasticPolicyActionSelector - def __init__(self, env_or_env_info, - v_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), - v_solver_builder: SolverBuilder = DefaultSolverBuilder(), - policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), 
- config=A2CConfig()): + def __init__( + self, + env_or_env_info, + v_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), + v_solver_builder: SolverBuilder = DefaultSolverBuilder(), + policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), + config=A2CConfig(), + ): super(A2C, self).__init__(env_or_env_info, config=config) # Initialize on cpu and change the context later with nn.context_scope(context.get_nnabla_context(-1)): - self._policy = policy_builder('pi', self._env_info, self._config) - self._v_function = v_function_builder('v', self._env_info, self._config) + self._policy = policy_builder("pi", self._env_info, self._config) + self._v_function = v_function_builder("v", self._env_info, self._config) self._policy_solver = policy_solver_builder(self._env_info, self._config) self._policy_solver_builder = policy_solver_builder # keep for later use @@ -201,7 +212,8 @@ def __init__(self, env_or_env_info, self._v_solver_builder = v_solver_builder # keep for later use self._evaluation_actor = _StochasticPolicyActionSelector( - self._env_info, self._policy.shallowcopy(), deterministic=False) + self._env_info, self._policy.shallowcopy(), deterministic=False + ) @eval_api def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): @@ -211,7 +223,7 @@ def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): def _before_training_start(self, env_or_buffer): if not self._is_env(env_or_buffer): - raise ValueError('A2C only supports online training') + raise ValueError("A2C only supports online training") env = env_or_buffer # FIXME: This setup is a workaround for creating underlying model parameters @@ -246,20 +258,20 @@ def _setup_policy_training(self, env_or_buffer): models=self._policy, solvers={self._policy.scope_name: self._policy_solver}, env_info=self._env_info, - config=policy_trainer_config) + config=policy_trainer_config, + ) return policy_trainer def _setup_v_function_training(self, env_or_buffer): # training input/loss variables v_function_trainer_config = MT.v_value.MonteCarloVTrainerConfig( - reduction_method='mean', - v_loss_scalar=self._config.value_coefficient + reduction_method="mean", v_loss_scalar=self._config.value_coefficient ) v_function_trainer = MT.v_value.MonteCarloVTrainer( train_functions=self._v_function, solvers={self._v_function.scope_name: self._v_function_solver}, env_info=self._env_info, - config=v_function_trainer_config + config=v_function_trainer_config, ) return v_function_trainer @@ -275,12 +287,9 @@ def _launch_actor_processes(self, env): def _build_a2c_actors(self, env, v_function, policy): actors = [] for i in range(self._config.actor_num): - actor = _A2CActor(actor_num=i, - env=env, - env_info=self._env_info, - v_function=v_function, - policy=policy, - config=self._config) + actor = _A2CActor( + actor_num=i, env=env, env_info=self._env_info, v_function=v_function, policy=policy, config=self._config + ) actors.append(actor) return actors @@ -318,12 +327,9 @@ def _a2c_training(self, experiences): s, a, returns = experiences advantage = self._compute_advantage(s, returns) extra = {} - extra['advantage'] = advantage - extra['v_target'] = returns - batch = TrainingBatch(batch_size=len(a), - s_current=s, - a_current=a, - extra=extra) + extra["advantage"] = advantage + extra["v_target"] = returns + batch = TrainingBatch(batch_size=len(a), s_current=s, a_current=a, extra=extra) # lr decay alpha = self._config.learning_rate 
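The seemingly inconsistent treatment of call sites in a2c.py above follows from a single rule: black joins any call or signature that fits within the configured line length (apparently 120 characters in this project) onto one line, and otherwise splits it with one element per line plus a trailing comma. A small self-contained illustration with placeholder names (make_batch stands in for nnabla-rl's TrainingBatch and is not part of its API):

def make_batch(batch_size, s_current, a_current, extra):
    # Placeholder helper used only for this illustration.
    return dict(batch_size=batch_size, s_current=s_current, a_current=a_current, extra=extra)

# Short enough for one line, so black keeps (or re-collapses) it on one line.
batch = make_batch(batch_size=4, s_current=[0.0], a_current=[1], extra={})

# Too long for one line, so black writes one argument per line with a
# trailing comma, which keeps later diffs small when arguments are added.
config = dict(
    gamma=0.99,
    n_steps=5,
    learning_rate=7e-4,
    learning_rate_decay_iterations=50000000,
    value_coefficient=0.5,
    max_grad_norm=0.5,
    timelimit_as_terminal=False,
)
print(batch, config)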
@@ -338,7 +344,7 @@ def _a2c_training(self, experiences): self._v_function_trainer_state = self._v_function_trainer.train(batch) def _compute_advantage(self, s, returns): - if not hasattr(self, '_state_var_for_advantage'): + if not hasattr(self, "_state_var_for_advantage"): self._state_var_for_advantage = nn.Variable(s.shape) self._returns_var_for_advantage = nn.Variable(returns.shape) v_for_advantage = self._v_function.v(self._state_var_for_advantage) @@ -364,17 +370,18 @@ def _solvers(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_continuous_action_env() and not env_info.is_tuple_action_env() @property def latest_iteration_state(self): latest_iteration_state = super(A2C, self).latest_iteration_state - if hasattr(self, '_policy_trainer_state'): - latest_iteration_state['scalar'].update({'pi_loss': float(self._policy_trainer_state['pi_loss'])}) - if hasattr(self, '_v_function_trainer_state'): - latest_iteration_state['scalar'].update({'v_loss': float(self._v_function_trainer_state['v_loss'])}) + if hasattr(self, "_policy_trainer_state"): + latest_iteration_state["scalar"].update({"pi_loss": float(self._policy_trainer_state["pi_loss"])}) + if hasattr(self, "_v_function_trainer_state"): + latest_iteration_state["scalar"].update({"v_loss": float(self._v_function_trainer_state["v_loss"])}) return latest_iteration_state @property @@ -394,44 +401,41 @@ def __init__(self, actor_num, env, env_info, policy, v_function, config): self._config = config # IPC communication variables - self._disposed = mp.Value('i', False) + self._disposed = mp.Value("i", False) self._task_start_event = mp.Event() self._task_finish_event = mp.Event() self._policy_mp_arrays = new_mp_arrays_from_params(policy.get_parameters()) self._v_function_mp_arrays = new_mp_arrays_from_params(v_function.get_parameters()) - explorer_config = EE.RawPolicyExplorerConfig(initial_step_num=0, - timelimit_as_terminal=self._config.timelimit_as_terminal) - self._environment_explorer = EE.RawPolicyExplorer(policy_action_selector=self._compute_action, - env_info=self._env_info, - config=explorer_config) + explorer_config = EE.RawPolicyExplorerConfig( + initial_step_num=0, timelimit_as_terminal=self._config.timelimit_as_terminal + ) + self._environment_explorer = EE.RawPolicyExplorer( + policy_action_selector=self._compute_action, env_info=self._env_info, config=explorer_config + ) obs_space = self._env.observation_space action_space = self._env.action_space - MultiProcessingArrays = namedtuple('MultiProcessingArrays', ['state', 'action', 'returns']) + MultiProcessingArrays = namedtuple("MultiProcessingArrays", ["state", "action", "returns"]) state_mp_array_shape = (self._n_steps, *obs_space.shape) - state_mp_array = mp_array_from_np_array( - np.empty(shape=state_mp_array_shape, dtype=obs_space.dtype)) + state_mp_array = mp_array_from_np_array(np.empty(shape=state_mp_array_shape, dtype=obs_space.dtype)) if env_info.is_discrete_action_env(): action_mp_array_shape = (self._n_steps, 1) - action_mp_array = mp_array_from_np_array( - np.empty(shape=action_mp_array_shape, dtype=action_space.dtype)) + action_mp_array = mp_array_from_np_array(np.empty(shape=action_mp_array_shape, dtype=action_space.dtype)) else: action_mp_array_shape = (self._n_steps, action_space.shape[0]) - 
action_mp_array = mp_array_from_np_array( - np.empty(shape=action_mp_array_shape, dtype=action_space.dtype)) + action_mp_array = mp_array_from_np_array(np.empty(shape=action_mp_array_shape, dtype=action_space.dtype)) scalar_mp_array_shape = (self._n_steps, 1) - returns_mp_array = mp_array_from_np_array( - np.empty(shape=scalar_mp_array_shape, dtype=np.float32)) + returns_mp_array = mp_array_from_np_array(np.empty(shape=scalar_mp_array_shape, dtype=np.float32)) self._mp_arrays = MultiProcessingArrays( (state_mp_array, state_mp_array_shape, obs_space.dtype), (action_mp_array, action_mp_array_shape, action_space.dtype), - (returns_mp_array, scalar_mp_array_shape, np.float32) + (returns_mp_array, scalar_mp_array_shape, np.float32), ) self._exploration_actor = _StochasticPolicyActionSelector(env_info, policy, deterministic=False) @@ -465,7 +469,7 @@ def _run_actor_loop(self): seed = os.getpid() set_global_seed(seed) self._env.seed(seed) - while (True): + while True: self._task_start_event.wait() if self._disposed.get_obj(): break @@ -481,8 +485,7 @@ def _run_actor_loop(self): def _run_data_collection(self): experiences = self._environment_explorer.step(self._env, n=self._n_steps, break_if_done=False) s_last = experiences[-1][4] - experiences = [(s, a, r, non_terminal) - for (s, a, r, non_terminal, *_) in experiences] + experiences = [(s, a, r, non_terminal) for (s, a, r, non_terminal, *_) in experiences] processed_experiences = self._process_experiences(experiences, s_last) return processed_experiences @@ -502,7 +505,7 @@ def _compute_returns(self, rewards, non_terminals, value_last): def _compute_v(self, s): s = np.expand_dims(s, axis=0) - if not hasattr(self, '_state_var'): + if not hasattr(self, "_state_var"): self._state_var = nn.Variable(s.shape) self._v_var = self._v_function.v(self._state_var) self._v_var.need_grad = False @@ -514,6 +517,7 @@ def _compute_v(self, s): def _fill_result(self, experiences): def array_and_dtype(mp_arrays_item): return mp_arrays_item[0], mp_arrays_item[2] + (s, a, returns) = experiences np_to_mp_array(s, *array_and_dtype(self._mp_arrays.state)) np_to_mp_array(a, *array_and_dtype(self._mp_arrays.action)) diff --git a/nnabla_rl/algorithms/amp.py b/nnabla_rl/algorithms/amp.py index 1ac616b0..b7977899 100644 --- a/nnabla_rl/algorithms/amp.py +++ b/nnabla_rl/algorithms/amp.py @@ -29,28 +29,53 @@ import nnabla_rl.environment_explorers as EE import nnabla_rl.model_trainers as MT from nnabla_rl.algorithm import Algorithm, AlgorithmConfig, eval_api -from nnabla_rl.algorithms.common_utils import (_get_shape, _StatePreprocessedRewardFunction, - _StatePreprocessedStochasticPolicy, _StatePreprocessedVFunction, - _StochasticPolicyActionSelector) +from nnabla_rl.algorithms.common_utils import ( + _get_shape, + _StatePreprocessedRewardFunction, + _StatePreprocessedStochasticPolicy, + _StatePreprocessedVFunction, + _StochasticPolicyActionSelector, +) from nnabla_rl.builders import ExplorerBuilder, ModelBuilder, PreprocessorBuilder, ReplayBufferBuilder, SolverBuilder from nnabla_rl.environment_explorer import EnvironmentExplorer from nnabla_rl.environments.amp_env import AMPEnv, AMPGoalEnv, TaskResult from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.functions import compute_std, unnormalize from nnabla_rl.model_trainers.model_trainer import ModelTrainer, TrainingBatch -from nnabla_rl.models import (AMPDiscriminator, AMPGatedPolicy, AMPGatedVFunction, AMPPolicy, AMPVFunction, Model, - RewardFunction, StochasticPolicy, VFunction) +from 
nnabla_rl.models import ( + AMPDiscriminator, + AMPGatedPolicy, + AMPGatedVFunction, + AMPPolicy, + AMPVFunction, + Model, + RewardFunction, + StochasticPolicy, + VFunction, +) from nnabla_rl.preprocessors import Preprocessor from nnabla_rl.random import drng from nnabla_rl.replay_buffer import ReplayBuffer from nnabla_rl.replay_buffers.buffer_iterator import BufferIterator from nnabla_rl.typing import Experience from nnabla_rl.utils import context -from nnabla_rl.utils.data import (add_batch_dimension, compute_std_ndarray, marshal_experiences, normalize_ndarray, - set_data_to_variable, unnormalize_ndarray) +from nnabla_rl.utils.data import ( + add_batch_dimension, + compute_std_ndarray, + marshal_experiences, + normalize_ndarray, + set_data_to_variable, + unnormalize_ndarray, +) from nnabla_rl.utils.misc import create_variable -from nnabla_rl.utils.multiprocess import (copy_mp_arrays_to_params, copy_params_to_mp_arrays, mp_array_from_np_array, - mp_to_np_array, new_mp_arrays_from_params, np_to_mp_array) +from nnabla_rl.utils.multiprocess import ( + copy_mp_arrays_to_params, + copy_params_to_mp_arrays, + mp_array_from_np_array, + mp_to_np_array, + new_mp_arrays_from_params, + np_to_mp_array, +) from nnabla_rl.utils.reproductions import set_global_seed @@ -228,10 +253,12 @@ def __post_init__(self): self._assert_positive(self.discriminator_learning_rate, "discriminator_learning_rate") self._assert_positive(self.discriminator_momentum, "discriminator_momentum") self._assert_positive(self.discriminator_weight_decay, "discriminator_weight_decay") - self._assert_positive(self.discriminator_extra_regularization_coefficient, - "discriminator_extra_regularization_coefficient") - self._assert_positive(self.discriminator_gradient_penelty_coefficient, - "discriminator_gradient_penelty_coefficient") + self._assert_positive( + self.discriminator_extra_regularization_coefficient, "discriminator_extra_regularization_coefficient" + ) + self._assert_positive( + self.discriminator_gradient_penelty_coefficient, "discriminator_gradient_penelty_coefficient" + ) self._assert_positive(self.discriminator_batch_size, "discriminator_batch_size") self._assert_positive(self.discriminator_epochs, "discriminator_epochs") self._assert_positive(self.discriminator_reward_scale, "discriminator_reward_scale") @@ -240,11 +267,13 @@ def __post_init__(self): class DefaultPolicyBuilder(ModelBuilder[StochasticPolicy]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: AMPConfig, - **kwargs) -> StochasticPolicy: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: AMPConfig, + **kwargs, + ) -> StochasticPolicy: if env_info.is_goal_conditioned_env(): return AMPGatedPolicy(scope_name, env_info.action_dim, 0.01) else: @@ -252,11 +281,13 @@ def build_model(self, # type: ignore[override] class DefaultVFunctionBuilder(ModelBuilder[VFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: AMPConfig, - **kwargs) -> VFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: AMPConfig, + **kwargs, + ) -> VFunction: if env_info.is_goal_conditioned_env(): return AMPGatedVFunction(scope_name) else: @@ -264,68 +295,71 @@ def build_model(self, # type: ignore[override] class DefaultRewardFunctionBuilder(ModelBuilder[RewardFunction]): - def build_model(self, # type: 
ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: AMPConfig, - **kwargs) -> RewardFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: AMPConfig, + **kwargs, + ) -> RewardFunction: return AMPDiscriminator(scope_name, 1.0) class DefaultVFunctionSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: AMPConfig, - **kwargs) -> nn.solver.Solver: - return NS.Momentum(lr=algorithm_config.v_function_learning_rate, - momentum=algorithm_config.v_function_momentum) + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: AMPConfig, **kwargs + ) -> nn.solver.Solver: + return NS.Momentum(lr=algorithm_config.v_function_learning_rate, momentum=algorithm_config.v_function_momentum) class DefaultPolicySolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: AMPConfig, - **kwargs) -> nn.solver.Solver: - return NS.Momentum(lr=algorithm_config.policy_learning_rate, - momentum=algorithm_config.policy_momentum) + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: AMPConfig, **kwargs + ) -> nn.solver.Solver: + return NS.Momentum(lr=algorithm_config.policy_learning_rate, momentum=algorithm_config.policy_momentum) class DefaultRewardFunctionSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: AMPConfig, - **kwargs) -> nn.solver.Solver: - return NS.Momentum(lr=algorithm_config.discriminator_learning_rate, - momentum=algorithm_config.discriminator_momentum) + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: AMPConfig, **kwargs + ) -> nn.solver.Solver: + return NS.Momentum( + lr=algorithm_config.discriminator_learning_rate, momentum=algorithm_config.discriminator_momentum + ) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: AMPConfig, - algorithm: "AMP", - **kwargs) -> EnvironmentExplorer: - explorer_config = \ - EE.LinearDecayEpsilonGreedyExplorerConfig(initial_step_num=0, - timelimit_as_terminal=algorithm_config.timelimit_as_terminal, - initial_epsilon=1.0, - final_epsilon=algorithm_config.final_explore_rate, - max_explore_steps=algorithm_config.max_explore_steps, - append_explorer_info=True) - explorer = EE.LinearDecayEpsilonGreedyExplorer(greedy_action_selector=kwargs["greedy_action_selector"], - random_action_selector=kwargs["random_action_selector"], - env_info=env_info, - config=explorer_config) + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: AMPConfig, + algorithm: "AMP", + **kwargs, + ) -> EnvironmentExplorer: + explorer_config = EE.LinearDecayEpsilonGreedyExplorerConfig( + initial_step_num=0, + timelimit_as_terminal=algorithm_config.timelimit_as_terminal, + initial_epsilon=1.0, + final_epsilon=algorithm_config.final_explore_rate, + max_explore_steps=algorithm_config.max_explore_steps, + append_explorer_info=True, + ) + explorer = EE.LinearDecayEpsilonGreedyExplorer( + greedy_action_selector=kwargs["greedy_action_selector"], + random_action_selector=kwargs["random_action_selector"], + env_info=env_info, + config=explorer_config, + ) return explorer class 
DefaultReplayBufferBuilder(ReplayBufferBuilder): - def build_replay_buffer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: AMPConfig, - **kwargs) -> ReplayBuffer: + def build_replay_buffer( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: AMPConfig, **kwargs + ) -> ReplayBuffer: return ReplayBuffer( - capacity=int(np.ceil(algorithm_config.discriminator_agent_replay_buffer_size / algorithm_config.actor_num))) + capacity=int(np.ceil(algorithm_config.discriminator_agent_replay_buffer_size / algorithm_config.actor_num)) + ) class AMP(Algorithm): @@ -361,6 +395,7 @@ class AMP(Algorithm): (:py:class:`ReplayBufferBuilder `): builder of replay_buffer of \ discriminator. """ + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See https://mypy.readthedocs.io/en/stable/class_basics.html for details @@ -391,18 +426,20 @@ class AMP(Algorithm): _evaluation_actor: _StochasticPolicyActionSelector - def __init__(self, - env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: AMPConfig = AMPConfig(), - v_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), - v_solver_builder: SolverBuilder = DefaultVFunctionSolverBuilder(), - policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - policy_solver_builder: SolverBuilder = DefaultPolicySolverBuilder(), - reward_function_builder: ModelBuilder[RewardFunction] = DefaultRewardFunctionBuilder(), - reward_solver_builder: SolverBuilder = DefaultRewardFunctionSolverBuilder(), - state_preprocessor_builder: Optional[PreprocessorBuilder] = None, - env_explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), - discriminator_replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: AMPConfig = AMPConfig(), + v_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), + v_solver_builder: SolverBuilder = DefaultVFunctionSolverBuilder(), + policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + policy_solver_builder: SolverBuilder = DefaultPolicySolverBuilder(), + reward_function_builder: ModelBuilder[RewardFunction] = DefaultRewardFunctionBuilder(), + reward_solver_builder: SolverBuilder = DefaultRewardFunctionSolverBuilder(), + state_preprocessor_builder: Optional[PreprocessorBuilder] = None, + env_explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + discriminator_replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + ): super(AMP, self).__init__(env_or_env_info, config=config) # Initialize on cpu and change the context later @@ -415,18 +452,20 @@ def __init__(self, if state_preprocessor_builder is None: raise ValueError("State preprocessing is enabled but no preprocessor builder is given") - self._pi_v_state_preprocessor = state_preprocessor_builder("pi_v_preprocessor", - self._env_info, - self._config) - v_function = _StatePreprocessedVFunction(v_function=v_function, - preprocessor=self._pi_v_state_preprocessor) + self._pi_v_state_preprocessor = state_preprocessor_builder( + "pi_v_preprocessor", self._env_info, self._config + ) + v_function = _StatePreprocessedVFunction( + v_function=v_function, preprocessor=self._pi_v_state_preprocessor + ) policy = _StatePreprocessedStochasticPolicy(policy=policy, preprocessor=self._pi_v_state_preprocessor) - self._discriminator_state_preprocessor = 
state_preprocessor_builder("r_preprocessor", - self._env_info, - self._config) - discriminator = _StatePreprocessedRewardFunction(reward_function=discriminator, - preprocessor=self._discriminator_state_preprocessor) + self._discriminator_state_preprocessor = state_preprocessor_builder( + "r_preprocessor", self._env_info, self._config + ) + discriminator = _StatePreprocessedRewardFunction( + reward_function=discriminator, preprocessor=self._discriminator_state_preprocessor + ) self._v_function = v_function self._policy = policy @@ -440,22 +479,25 @@ def __init__(self, self._discriminator_solver_builder = reward_solver_builder # keep for later use self._env_explorer_builder = env_explorer_builder # keep for later use - self._evaluation_actor = _StochasticPolicyActionSelector(self._env_info, - self._policy.shallowcopy(), - deterministic=self._config.act_deterministic_in_eval) - self._discriminator_agent_replay_buffers = \ - [discriminator_replay_buffer_builder(env_info=self._env_info, algorithm_config=self._config) - for _ in range(self._config.actor_num)] - self._discriminator_expert_replay_buffers = \ - [discriminator_replay_buffer_builder(env_info=self._env_info, algorithm_config=self._config) - for _ in range(self._config.actor_num)] + self._evaluation_actor = _StochasticPolicyActionSelector( + self._env_info, self._policy.shallowcopy(), deterministic=self._config.act_deterministic_in_eval + ) + self._discriminator_agent_replay_buffers = [ + discriminator_replay_buffer_builder(env_info=self._env_info, algorithm_config=self._config) + for _ in range(self._config.actor_num) + ] + self._discriminator_expert_replay_buffers = [ + discriminator_replay_buffer_builder(env_info=self._env_info, algorithm_config=self._config) + for _ in range(self._config.actor_num) + ] if self._config.normalize_action: action_mean = add_batch_dimension(np.array(self._config.action_mean, dtype=np.float32)) self._action_mean = nn.Variable.from_numpy_array(action_mean) action_var = add_batch_dimension(np.array(self._config.action_var, dtype=np.float32)) - self._action_std = compute_std(nn.Variable.from_numpy_array(action_var), - epsilon=0.0, mode_for_floating_point_error="max") + self._action_std = compute_std( + nn.Variable.from_numpy_array(action_var), epsilon=0.0, mode_for_floating_point_error="max" + ) else: self._action_mean = None self._action_std = None @@ -466,8 +508,11 @@ def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): action, _ = self._evaluation_action_selector(state, begin_of_episode=begin_of_episode) if self._config.normalize_action: - std = compute_std_ndarray(np.array(self._config.action_var, dtype=np.float32), epsilon=0.0, - mode_for_floating_point_error="max") + std = compute_std_ndarray( + np.array(self._config.action_var, dtype=np.float32), + epsilon=0.0, + mode_for_floating_point_error="max", + ) action = unnormalize_ndarray(action, np.array(self._config.action_mean, dtype=np.float32), std) return action @@ -516,23 +561,27 @@ def _setup_policy_training(self, env_or_buffer): action_bound_loss_coefficient=self._config.action_bound_loss_coefficient, action_mean=self._config.action_mean, action_var=self._config.action_var, - regularization_coefficient=self._config.policy_weight_decay) + regularization_coefficient=self._config.policy_weight_decay, + ) policy_trainer = MT.policy_trainers.AMPPolicyTrainer( models=self._policy, solvers={self._policy.scope_name: self._policy_solver}, env_info=self._env_info, - config=policy_trainer_config) + config=policy_trainer_config, + ) 
return policy_trainer def _setup_v_function_training(self, env_or_buffer): # training input/loss variables - v_function_trainer_config = MT.v_value_trainers.MonteCarloVTrainerConfig(reduction_method="mean", - v_loss_scalar=0.5) + v_function_trainer_config = MT.v_value_trainers.MonteCarloVTrainerConfig( + reduction_method="mean", v_loss_scalar=0.5 + ) v_function_trainer = MT.v_value_trainers.MonteCarloVTrainer( train_functions=self._v_function, solvers={self._v_function.scope_name: self._v_function_solver}, env_info=self._env_info, - config=v_function_trainer_config) + config=v_function_trainer_config, + ) return v_function_trainer def _setup_reward_function_training(self, env_or_buffer): @@ -542,16 +591,21 @@ def _setup_reward_function_training(self, env_or_buffer): extra_regularization_coefficient=self._config.discriminator_extra_regularization_coefficient, extra_regularization_variable_names=self._config.discriminator_extra_regularization_variable_names, gradient_penelty_coefficient=self._config.discriminator_gradient_penelty_coefficient, - gradient_penalty_indexes=self._config.discriminator_gradient_penalty_indexes) - model = self._discriminator._reward_function if isinstance( - self._discriminator, _StatePreprocessedRewardFunction) else self._discriminator + gradient_penalty_indexes=self._config.discriminator_gradient_penalty_indexes, + ) + model = ( + self._discriminator._reward_function + if isinstance(self._discriminator, _StatePreprocessedRewardFunction) + else self._discriminator + ) preprocessor = self._discriminator_state_preprocessor if self._config.preprocess_state else None reward_function_trainer = MT.reward_trainiers.AMPRewardFunctionTrainer( models=model, solvers={self._discriminator.scope_name: self._discriminator_solver}, env_info=self._env_info, state_preprocessor=preprocessor, - config=reward_function_trainer_config) + config=reward_function_trainer_config, + ) return reward_function_trainer def _after_training_finish(self, env_or_buffer): @@ -561,20 +615,23 @@ def _after_training_finish(self, env_or_buffer): self._kill_actor_processes(process) def _launch_actor_processes(self, env): - actors = self._build_amp_actors(env=env, - policy=self._policy, - v_function=self._v_function, - state_preprocessor=(self._pi_v_state_preprocessor if - self._config.preprocess_state else None), - reward_function=self._discriminator, - reward_state_preprocessor=(self._discriminator_state_preprocessor if - self._config.preprocess_state else None), - env_explorer=self._env_explorer_builder( - self._env_info, - self._config, - self, - greedy_action_selector=self._compute_greedy_action, - random_action_selector=self._compute_explore_action)) + actors = self._build_amp_actors( + env=env, + policy=self._policy, + v_function=self._v_function, + state_preprocessor=(self._pi_v_state_preprocessor if self._config.preprocess_state else None), + reward_function=self._discriminator, + reward_state_preprocessor=( + self._discriminator_state_preprocessor if self._config.preprocess_state else None + ), + env_explorer=self._env_explorer_builder( + self._env_info, + self._config, + self, + greedy_action_selector=self._compute_greedy_action, + random_action_selector=self._compute_explore_action, + ), + ) processes = [] for actor in actors: if self._config.actor_num == 1: @@ -586,20 +643,23 @@ def _launch_actor_processes(self, env): processes.append(p) return actors, processes - def _build_amp_actors(self, env, policy, v_function, state_preprocessor, reward_function, - reward_state_preprocessor, 
env_explorer): + def _build_amp_actors( + self, env, policy, v_function, state_preprocessor, reward_function, reward_state_preprocessor, env_explorer + ): actors = [] for i in range(self._config.actor_num): - actor = _AMPActor(actor_num=i, - env=env, - env_info=self._env_info, - policy=policy, - v_function=v_function, - state_preprocessor=state_preprocessor, - reward_function=reward_function, - reward_state_preprocessor=reward_state_preprocessor, - config=self._config, - env_explorer=env_explorer) + actor = _AMPActor( + actor_num=i, + env=env, + env_info=self._env_info, + policy=policy, + v_function=v_function, + state_preprocessor=state_preprocessor, + reward_function=reward_function, + reward_state_preprocessor=reward_state_preprocessor, + config=self._config, + env_explorer=env_explorer, + ) actors.append(actor) return actors @@ -619,8 +679,12 @@ def _run_online_training_iteration(self, env): if update_interval < self.iteration_num: # NOTE: The first update (when update_interval == self.iteration_num) will be skipped policy_buffers, v_function_buffers = self._create_policy_and_v_function_buffers(experiences_per_agent) - self._amp_training(self._discriminator_agent_replay_buffers, self._discriminator_expert_replay_buffers, - policy_buffers, v_function_buffers) + self._amp_training( + self._discriminator_agent_replay_buffers, + self._discriminator_expert_replay_buffers, + policy_buffers, + v_function_buffers, + ) if self._config.preprocess_state and self.iteration_num < self._config.num_processor_samples: self._pi_v_state_preprocessor.update(s) @@ -641,7 +705,8 @@ def split_result(tuple_val): actor.update_state_preprocessor_params(casted_pi_v_state_preprocessor.get_parameters()) casted_discriminator_state_preprocessor = cast(Model, self._discriminator_state_preprocessor) actor.update_reward_state_preprocessor_params( - casted_discriminator_state_preprocessor.get_parameters()) + casted_discriminator_state_preprocessor.get_parameters() + ) else: # Its running on same process. No need to synchronize parameters with multiprocessing arrays. 
pass @@ -664,8 +729,10 @@ def split_result(tuple_val): experience = [ (s, a, r, non_terminal, n_s, log_prob, non_greedy, e_s, e_a, e_s_next, v_target, advantage) - for (s, a, r, non_terminal, n_s, log_prob, non_greedy, e_s, e_a, e_s_next, v_target, advantage) - in zip(*splitted_result)] + for (s, a, r, non_terminal, n_s, log_prob, non_greedy, e_s, e_a, e_s_next, v_target, advantage) in zip( + *splitted_result + ) + ] assert len(experience) == self._config.actor_timesteps experiences_per_agent.append(experience) @@ -675,9 +742,9 @@ def split_result(tuple_val): def _add_experience_to_reward_buffers(self, experience_per_agent): assert len(self._discriminator_agent_replay_buffers) == len(experience_per_agent) assert len(self._discriminator_expert_replay_buffers) == len(experience_per_agent) - for agent_buffer, expert_buffer, experience in zip(self._discriminator_agent_replay_buffers, - self._discriminator_expert_replay_buffers, - experience_per_agent): + for agent_buffer, expert_buffer, experience in zip( + self._discriminator_agent_replay_buffers, self._discriminator_expert_replay_buffers, experience_per_agent + ): agent_buffer.append_all(experience) expert_buffer.append_all(experience) @@ -710,8 +777,13 @@ def _kill_actor_processes(self, process): def _run_offline_training_iteration(self, buffer): raise NotImplementedError - def _amp_training(self, discriminator_agent_replay_buffers, discriminator_expert_replay_buffers, - policy_replay_buffers, v_function_replay_buffers): + def _amp_training( + self, + discriminator_agent_replay_buffers, + discriminator_expert_replay_buffers, + policy_replay_buffers, + v_function_replay_buffers, + ): self._reward_function_training(discriminator_agent_replay_buffers, discriminator_expert_replay_buffers) total_updates = (self._config.actor_num * self._config.actor_timesteps) // self._config.batch_size @@ -723,13 +795,14 @@ def _reward_function_training(self, agent_buffers: List[ReplayBuffer], expert_bu for _ in range(self._config.discriminator_epochs): for _ in range(num_updates): agent_experiences = _sample_experiences_from_buffers( - agent_buffers, self._config.discriminator_batch_size) + agent_buffers, self._config.discriminator_batch_size + ) (s_agent, a_agent, _, _, s_next_agent, *_) = marshal_experiences(agent_experiences) expert_experiences = _sample_experiences_from_buffers( - expert_buffers, self._config.discriminator_batch_size) - (_, _, _, _, _, _, _, s_expert, a_expert, s_next_expert, _, _) = marshal_experiences( - expert_experiences) + expert_buffers, self._config.discriminator_batch_size + ) + (_, _, _, _, _, _, _, s_expert, a_expert, s_next_expert, _, _) = marshal_experiences(expert_experiences) extra = {} extra["s_current_agent"] = s_agent @@ -744,7 +817,8 @@ def _reward_function_training(self, agent_buffers: List[ReplayBuffer], expert_bu def _v_function_training(self, total_updates, v_function_replay_buffers): v_function_buffer_iterator = _EquallySampleBufferIterator( - total_updates, v_function_replay_buffers, self._config.batch_size) + total_updates, v_function_replay_buffers, self._config.batch_size + ) for _ in range(self._config.epochs): for experiences in v_function_buffer_iterator: (s, a, _, _, _, _, v_target, _, _) = marshal_experiences(experiences) @@ -755,7 +829,8 @@ def _v_function_training(self, total_updates, v_function_replay_buffers): def _policy_training(self, total_updates, policy_replay_buffers): policy_buffer_iterator = _EquallySampleBufferIterator( - total_updates, policy_replay_buffers, self._config.batch_size) + 
total_updates, policy_replay_buffers, self._config.batch_size + ) for _ in range(self._config.epochs): for experiences in policy_buffer_iterator: (s, a, _, _, _, log_prob, _, advantage, _) = marshal_experiences(experiences) @@ -788,8 +863,9 @@ def _solvers(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance( - env_or_env_info, gym.Env) else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env() @eval_api @@ -854,9 +930,11 @@ def latest_iteration_state(self): @property def trainers(self): - return {"discriminator": self._discriminator_trainer, - "v_function": self._v_function_trainer, - "policy": self._policy_trainer} + return { + "discriminator": self._discriminator_trainer, + "v_function": self._v_function_trainer, + "policy": self._policy_trainer, + } def _sample_experiences_from_buffers(buffers: List[ReplayBuffer], batch_size: int) -> List[Experience]: @@ -880,17 +958,19 @@ def _concatenate_state(experiences_per_agent) -> Tuple[np.ndarray, np.ndarray]: class _AMPActor: - def __init__(self, - actor_num: int, - env: gym.Env, - env_info: EnvironmentInfo, - policy: StochasticPolicy, - v_function: VFunction, - state_preprocessor: Optional[Preprocessor], - reward_function: RewardFunction, - reward_state_preprocessor: Optional[Preprocessor], - env_explorer: EnvironmentExplorer, - config: AMPConfig): + def __init__( + self, + actor_num: int, + env: gym.Env, + env_info: EnvironmentInfo, + policy: StochasticPolicy, + v_function: VFunction, + state_preprocessor: Optional[Preprocessor], + reward_function: RewardFunction, + reward_state_preprocessor: Optional[Preprocessor], + env_explorer: EnvironmentExplorer, + config: AMPConfig, + ): # These variables will be copied when process is created self._actor_num = actor_num self._env = env @@ -925,60 +1005,78 @@ def __init__(self, casted_reward_state_preprocessor.get_parameters() ) - MultiProcessingArrays = namedtuple("MultiProcessingArrays", - ["state", - "action", - "reward", - "non_terminal", - "next_state", - "log_prob", - "non_greedy_action", - "expert_state", - "expert_action", - "expert_next_state", - "v_target", - "advantage"]) + MultiProcessingArrays = namedtuple( + "MultiProcessingArrays", + [ + "state", + "action", + "reward", + "non_terminal", + "next_state", + "log_prob", + "non_greedy_action", + "expert_state", + "expert_action", + "expert_next_state", + "v_target", + "advantage", + ], + ) state_mp_array = self._prepare_state_mp_array(env_info.observation_space, env_info) action_mp_array = self._prepare_action_mp_array(env_info.action_space, env_info) scalar_mp_array_shape = (self._timesteps, 1) - reward_mp_array = (mp_array_from_np_array(np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), - scalar_mp_array_shape, - np.float32) - non_terminal_mp_array = (mp_array_from_np_array(np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), - scalar_mp_array_shape, - np.float32) + reward_mp_array = ( + mp_array_from_np_array(np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), + scalar_mp_array_shape, + np.float32, + ) + non_terminal_mp_array = ( + mp_array_from_np_array(np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), + scalar_mp_array_shape, + np.float32, + ) next_state_mp_array = self._prepare_state_mp_array(env_info.observation_space, env_info) - log_prob_mp_array = 
(mp_array_from_np_array(np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), - scalar_mp_array_shape, - np.float32) - non_greedy_action_mp_array = (mp_array_from_np_array(np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), - scalar_mp_array_shape, - np.float32) - v_target_mp_array = (mp_array_from_np_array(np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), - scalar_mp_array_shape, - np.float32) - advantage_mp_array = (mp_array_from_np_array(np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), - scalar_mp_array_shape, - np.float32) + log_prob_mp_array = ( + mp_array_from_np_array(np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), + scalar_mp_array_shape, + np.float32, + ) + non_greedy_action_mp_array = ( + mp_array_from_np_array(np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), + scalar_mp_array_shape, + np.float32, + ) + v_target_mp_array = ( + mp_array_from_np_array(np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), + scalar_mp_array_shape, + np.float32, + ) + advantage_mp_array = ( + mp_array_from_np_array(np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), + scalar_mp_array_shape, + np.float32, + ) expert_state_mp_array = self._prepare_state_mp_array(env_info.observation_space, env_info) expert_action_mp_array = self._prepare_action_mp_array(env_info.action_space, env_info) expert_next_state_mp_array = self._prepare_state_mp_array(env_info.observation_space, env_info) - self._mp_arrays = MultiProcessingArrays(state_mp_array, - action_mp_array, - reward_mp_array, - non_terminal_mp_array, - next_state_mp_array, - log_prob_mp_array, - non_greedy_action_mp_array, - expert_state_mp_array, - expert_action_mp_array, - expert_next_state_mp_array, - v_target_mp_array, - advantage_mp_array) + self._mp_arrays = MultiProcessingArrays( + state_mp_array, + action_mp_array, + reward_mp_array, + non_terminal_mp_array, + next_state_mp_array, + log_prob_mp_array, + non_greedy_action_mp_array, + expert_state_mp_array, + expert_action_mp_array, + expert_next_state_mp_array, + v_target_mp_array, + advantage_mp_array, + ) self._reward_min = np.inf self._reward_max = -np.inf @@ -1057,8 +1155,9 @@ def _fill_result(self, experiences): indexes = np.arange(len(experiences)) drng.shuffle(indexes) experiences = [experiences[i] for i in indexes[: self._config.actor_timesteps]] - (s, a, r, non_terminal, s_next, log_prob, non_greedy_action, e_s, e_a, - e_s_next, v_target, advantage) = marshal_experiences(experiences) + (s, a, r, non_terminal, s_next, log_prob, non_greedy_action, e_s, e_a, e_s_next, v_target, advantage) = ( + marshal_experiences(experiences) + ) _copy_np_array_to_mp_array(s, self._mp_arrays.state) _copy_np_array_to_mp_array(a, self._mp_arrays.action) @@ -1086,13 +1185,15 @@ def _run_data_collection(self): lmb=self._config.lmb, value_clip=(self._reward_min / (1.0 - self._config.gamma), self._reward_max / (1.0 - self._config.gamma)), value_at_task_fail=self._config.value_at_task_fail, - value_at_task_success=self._config.value_at_task_success) + value_at_task_success=self._config.value_at_task_success, + ) assert self._config.target_value_clip[0] < self._config.target_value_clip[1] v_targets = np.clip(v_targets, a_min=self._config.target_value_clip[0], a_max=self._config.target_value_clip[1]) advantage_std = compute_std_ndarray(np.var(advantages), epsilon=1e-5, mode_for_floating_point_error="add") - advantages = normalize_ndarray(advantages, mean=np.mean( - advantages), std=advantage_std, value_clip=self._config.normalized_advantage_clip) + advantages = 
normalize_ndarray( + advantages, mean=np.mean(advantages), std=advantage_std, value_clip=self._config.normalized_advantage_clip + ) assert len(experiences) == len(v_targets) assert len(experiences) == len(advantages) @@ -1101,18 +1202,22 @@ def _run_data_collection(self): expert_s, expert_a, _, _, expert_n_s, _ = info["expert_experience"] assert "greedy_action" in info - organized_experiences.append((s, - a, - r, - non_terminal, - s_next, - info["log_prob"], - 0.0 if info["greedy_action"] else 1.0, - expert_s, - expert_a, - expert_n_s, - v_target, - advantage)) + organized_experiences.append( + ( + s, + a, + r, + non_terminal, + s_next, + info["log_prob"], + 0.0 if info["greedy_action"] else 1.0, + expert_s, + expert_a, + expert_n_s, + v_target, + advantage, + ) + ) return organized_experiences @@ -1135,8 +1240,9 @@ def _compute_rewards(self, experiences: List[Experience]) -> List[float]: self._reward_var.forward() if self._config.use_reward_from_env: - reward = (1.0 - self._config.lerp_reward_coefficient) * \ - float(self._reward_var.d) + self._config.lerp_reward_coefficient * float(env_r) + reward = (1.0 - self._config.lerp_reward_coefficient) * float( + self._reward_var.d + ) + self._config.lerp_reward_coefficient * float(env_r) else: reward = float(self._reward_var.d) rewards.append(reward) @@ -1168,29 +1274,27 @@ def _prepare_state_mp_array(self, obs_space, env_info): state_mp_array_dtypes = [] for space in obs_space: state_mp_array_shape = (self._timesteps, *space.shape) - state_mp_array = mp_array_from_np_array( - np.empty(shape=state_mp_array_shape, dtype=space.dtype)) + state_mp_array = mp_array_from_np_array(np.empty(shape=state_mp_array_shape, dtype=space.dtype)) state_mp_array_shapes.append(state_mp_array_shape) state_mp_array_dtypes.append(space.dtype) state_mp_arrays.append(state_mp_array) return tuple(x for x in zip(state_mp_arrays, state_mp_array_shapes, state_mp_array_dtypes)) else: state_mp_array_shape = (self._timesteps, *obs_space.shape) - state_mp_array = mp_array_from_np_array( - np.empty(shape=state_mp_array_shape, dtype=obs_space.dtype)) + state_mp_array = mp_array_from_np_array(np.empty(shape=state_mp_array_shape, dtype=obs_space.dtype)) return (state_mp_array, state_mp_array_shape, obs_space.dtype) def _prepare_action_mp_array(self, action_space, env_info): action_mp_array_shape = (self._timesteps, action_space.shape[0]) - action_mp_array = mp_array_from_np_array( - np.empty(shape=action_mp_array_shape, dtype=action_space.dtype)) + action_mp_array = mp_array_from_np_array(np.empty(shape=action_mp_array_shape, dtype=action_space.dtype)) return (action_mp_array, action_mp_array_shape, action_space.dtype) def _copy_np_array_to_mp_array( np_array: Union[np.ndarray, Tuple[np.ndarray]], - mp_array_shape_type: Union[Tuple[np.ndarray, Tuple[int, ...], np.dtype], - Tuple[Tuple[np.ndarray, Tuple[int, ...], np.dtype]]], + mp_array_shape_type: Union[ + Tuple[np.ndarray, Tuple[int, ...], np.dtype], Tuple[Tuple[np.ndarray, Tuple[int, ...], np.dtype]] + ], ): """Copy numpy array to multiprocessing array. 
@@ -1210,15 +1314,16 @@ def _copy_np_array_to_mp_array( raise ValueError("Invalid pair of np_array and mp_array!") -def _compute_v_target_and_advantage_with_clipping_and_overwriting(v_function: VFunction, - experiences: List[Experience], - rewards: List[float], - gamma: float, - lmb: float, - value_at_task_fail: float, - value_at_task_success: float, - value_clip: Optional[Tuple[float, float]] = None - ) -> Tuple[np.ndarray, np.ndarray]: +def _compute_v_target_and_advantage_with_clipping_and_overwriting( + v_function: VFunction, + experiences: List[Experience], + rewards: List[float], + gamma: float, + lmb: float, + value_at_task_fail: float, + value_at_task_success: float, + value_clip: Optional[Tuple[float, float]] = None, +) -> Tuple[np.ndarray, np.ndarray]: assert isinstance(v_function, VFunction), "Invalid v_function" if value_clip is not None: assert value_clip[0] < value_clip[1] @@ -1293,11 +1398,11 @@ def __init__(self, buffer, batch_size, shuffle=True): super().__init__(buffer, batch_size, shuffle, repeat=True) def next(self): - indices = self._indices[self._index:self._index + self._batch_size] - if (len(indices) < self._batch_size): + indices = self._indices[self._index : self._index + self._batch_size] + if len(indices) < self._batch_size: rest = self._batch_size - len(indices) self.reset() - indices = np.append(indices, self._indices[self._index:self._index + rest]) + indices = np.append(indices, self._indices[self._index : self._index + rest]) self._index += rest else: self._index += self._batch_size @@ -1306,13 +1411,13 @@ def next(self): __next__ = next -class _EquallySampleBufferIterator(): +class _EquallySampleBufferIterator: def __init__(self, total_num_iterations: int, replay_buffers: List[ReplayBuffer], batch_size: int): buffer_batch_size = int(np.ceil(batch_size / len(replay_buffers))) - self._internal_iterators = [_EndlessBufferIterator(buffer=buffer, - shuffle=False, - batch_size=buffer_batch_size) - for buffer in replay_buffers] + self._internal_iterators = [ + _EndlessBufferIterator(buffer=buffer, shuffle=False, batch_size=buffer_batch_size) + for buffer in replay_buffers + ] self._total_num_iterations = total_num_iterations self._replay_buffers = replay_buffers self._batch_size = batch_size @@ -1344,6 +1449,6 @@ def _sample(self): if len(experiences) > self._batch_size: drng.shuffle(experiences) - experiences = experiences[:self._batch_size] + experiences = experiences[: self._batch_size] return experiences diff --git a/nnabla_rl/algorithms/atrpo.py b/nnabla_rl/algorithms/atrpo.py index 713f1106..3ca68963 100644 --- a/nnabla_rl/algorithms/atrpo.py +++ b/nnabla_rl/algorithms/atrpo.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -24,8 +24,12 @@ import nnabla_rl.model_trainers as MT import nnabla_rl.preprocessors as RP from nnabla_rl.algorithm import Algorithm, AlgorithmConfig, eval_api -from nnabla_rl.algorithms.common_utils import (_StatePreprocessedStochasticPolicy, _StatePreprocessedVFunction, - _StochasticPolicyActionSelector, compute_average_v_target_and_advantage) +from nnabla_rl.algorithms.common_utils import ( + _StatePreprocessedStochasticPolicy, + _StatePreprocessedVFunction, + _StochasticPolicyActionSelector, + compute_average_v_target_and_advantage, +) from nnabla_rl.builders import ExplorerBuilder, ModelBuilder, PreprocessorBuilder, SolverBuilder from nnabla_rl.environment_explorer import EnvironmentExplorer from nnabla_rl.environments.environment_info import EnvironmentInfo @@ -74,6 +78,7 @@ class ATRPOConfig(AlgorithmConfig): learning_rate_decay_iterations (int): learning rate will be decreased lineary to 0 till this iteration number. If 0 or negative, learning rate will be kept fixed. Defaults to 10000000. """ + lmb: float = 0.95 num_steps_per_iteration: int = 5000 pi_batch_size: int = 5000 @@ -84,8 +89,8 @@ class ATRPOConfig(AlgorithmConfig): conjugate_gradient_iterations: int = 10 vf_epochs: int = 5 vf_batch_size: int = 64 - vf_learning_rate: float = 3. * 1e-4 - vf_l2_reg_coefficient: float = 3. * 1e-3 + vf_learning_rate: float = 3.0 * 1e-4 + vf_l2_reg_coefficient: float = 3.0 * 1e-3 preprocess_state: bool = True gpu_batch_size: Optional[int] = None learning_rate_decay_iterations: int = 10000000 @@ -95,69 +100,75 @@ def __post_init__(self): Check the values are in valid range. """ - self._assert_between(self.pi_batch_size, 0, self.num_steps_per_iteration, 'pi_batch_size') - self._assert_between(self.lmb, 0.0, 1.0, 'lmb') - self._assert_positive(self.num_steps_per_iteration, 'num_steps_per_iteration') - self._assert_between(self.pi_batch_size, 0, self.num_steps_per_iteration, 'pi_batch_size') - self._assert_positive(self.sigma_kl_divergence_constraint, 'sigma_kl_divergence_constraint') - self._assert_positive(self.maximum_backtrack_numbers, 'maximum_backtrack_numbers') - self._assert_positive(self.backtrack_coefficient, 'backtrack_coefficient') - self._assert_positive(self.conjugate_gradient_damping, 'conjugate_gradient_damping') - self._assert_positive(self.conjugate_gradient_iterations, 'conjugate_gradient_iterations') - self._assert_positive(self.vf_epochs, 'vf_epochs') - self._assert_positive(self.vf_batch_size, 'vf_batch_size') - self._assert_positive(self.vf_learning_rate, 'vf_learning_rate') - self._assert_positive(self.vf_l2_reg_coefficient, 'vf_l2_reg_coefficient') + self._assert_between(self.pi_batch_size, 0, self.num_steps_per_iteration, "pi_batch_size") + self._assert_between(self.lmb, 0.0, 1.0, "lmb") + self._assert_positive(self.num_steps_per_iteration, "num_steps_per_iteration") + self._assert_between(self.pi_batch_size, 0, self.num_steps_per_iteration, "pi_batch_size") + self._assert_positive(self.sigma_kl_divergence_constraint, "sigma_kl_divergence_constraint") + self._assert_positive(self.maximum_backtrack_numbers, "maximum_backtrack_numbers") + self._assert_positive(self.backtrack_coefficient, "backtrack_coefficient") + self._assert_positive(self.conjugate_gradient_damping, "conjugate_gradient_damping") + self._assert_positive(self.conjugate_gradient_iterations, "conjugate_gradient_iterations") + self._assert_positive(self.vf_epochs, "vf_epochs") + self._assert_positive(self.vf_batch_size, "vf_batch_size") + self._assert_positive(self.vf_learning_rate, "vf_learning_rate") + 
self._assert_positive(self.vf_l2_reg_coefficient, "vf_l2_reg_coefficient") class DefaultPolicyBuilder(ModelBuilder[StochasticPolicy]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: ATRPOConfig, - **kwargs) -> StochasticPolicy: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: ATRPOConfig, + **kwargs, + ) -> StochasticPolicy: return ATRPOPolicy(scope_name, env_info.action_dim) class DefaultVFunctionBuilder(ModelBuilder[VFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: ATRPOConfig, - **kwargs) -> VFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: ATRPOConfig, + **kwargs, + ) -> VFunction: return ATRPOVFunction(scope_name) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: ATRPOConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: ATRPOConfig, **kwargs + ) -> nn.solver.Solver: return NS.Adam(alpha=algorithm_config.vf_learning_rate) class DefaultPreprocessorBuilder(PreprocessorBuilder): - def build_preprocessor(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: ATRPOConfig, - **kwargs) -> Preprocessor: + def build_preprocessor( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: ATRPOConfig, + **kwargs, + ) -> Preprocessor: return RP.RunningMeanNormalizer(scope_name, env_info.state_shape, value_clip=(-5.0, 5.0)) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: ATRPOConfig, - algorithm: "ATRPO", - **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: ATRPOConfig, + algorithm: "ATRPO", + **kwargs, + ) -> EnvironmentExplorer: explorer_config = EE.RawPolicyExplorerConfig( - initial_step_num=algorithm.iteration_num, - timelimit_as_terminal=True + initial_step_num=algorithm.iteration_num, timelimit_as_terminal=True + ) + explorer = EE.RawPolicyExplorer( + policy_action_selector=algorithm._exploration_action_selector, env_info=env_info, config=explorer_config ) - explorer = EE.RawPolicyExplorer(policy_action_selector=algorithm._exploration_action_selector, - env_info=env_info, - config=explorer_config) return explorer @@ -207,25 +218,27 @@ class ATRPO(Algorithm): _policy_trainer_state: Dict[str, Any] _v_function_trainer_state: Dict[str, Any] - def __init__(self, - env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: ATRPOConfig = ATRPOConfig(), - v_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), - v_solver_builder: SolverBuilder = DefaultSolverBuilder(), - policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - state_preprocessor_builder: Optional[PreprocessorBuilder] = DefaultPreprocessorBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: ATRPOConfig = ATRPOConfig(), + v_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), + v_solver_builder: SolverBuilder = DefaultSolverBuilder(), + 
policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + state_preprocessor_builder: Optional[PreprocessorBuilder] = DefaultPreprocessorBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): super(ATRPO, self).__init__(env_or_env_info, config=config) self._explorer_builder = explorer_builder with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): - self._v_function = v_function_builder('v', self._env_info, self._config) - self._policy = policy_builder('pi', self._env_info, self._config) + self._v_function = v_function_builder("v", self._env_info, self._config) + self._policy = policy_builder("pi", self._env_info, self._config) self._preprocessor: Optional[Preprocessor] = None if self._config.preprocess_state and state_preprocessor_builder is not None: - preprocessor = state_preprocessor_builder('preprocessor', self._env_info, self._config) + preprocessor = state_preprocessor_builder("preprocessor", self._env_info, self._config) assert preprocessor is not None self._v_function = _StatePreprocessedVFunction(v_function=self._v_function, preprocessor=preprocessor) self._policy = _StatePreprocessedStochasticPolicy(policy=self._policy, preprocessor=preprocessor) @@ -235,14 +248,16 @@ def __init__(self, self._v_function_solver = solver self._evaluation_actor = _StochasticPolicyActionSelector( - self._env_info, self._policy.shallowcopy(), deterministic=True) + self._env_info, self._policy.shallowcopy(), deterministic=True + ) self._exploration_actor = _StochasticPolicyActionSelector( - self._env_info, self._policy.shallowcopy(), deterministic=False) + self._env_info, self._policy.shallowcopy(), deterministic=False + ) @eval_api def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): - action, _ = self._evaluation_action_selector(state, begin_of_episode=begin_of_episode) + action, _ = self._evaluation_action_selector(state, begin_of_episode=begin_of_episode) return action def _before_training_start(self, env_or_buffer): @@ -256,15 +271,13 @@ def _setup_environment_explorer(self, env_or_buffer): return None if self._is_buffer(env_or_buffer) else self._explorer_builder(self._env_info, self._config, self) def _setup_v_function_training(self, env_or_buffer): - v_function_trainer_config = MT.v_value.MonteCarloVTrainerConfig( - reduction_method='mean', - v_loss_scalar=1.0 - ) + v_function_trainer_config = MT.v_value.MonteCarloVTrainerConfig(reduction_method="mean", v_loss_scalar=1.0) v_function_trainer = MT.v_value.MonteCarloVTrainer( train_functions=self._v_function, solvers={self._v_function.scope_name: self._v_function_solver}, env_info=self._env_info, - config=v_function_trainer_config) + config=v_function_trainer_config, + ) return v_function_trainer def _setup_policy_training(self, env_or_buffer): @@ -274,11 +287,11 @@ def _setup_policy_training(self, env_or_buffer): maximum_backtrack_numbers=self._config.maximum_backtrack_numbers, conjugate_gradient_damping=self._config.conjugate_gradient_damping, conjugate_gradient_iterations=self._config.conjugate_gradient_iterations, - backtrack_coefficient=self._config.backtrack_coefficient) + backtrack_coefficient=self._config.backtrack_coefficient, + ) policy_trainer = MT.policy_trainers.TRPOPolicyTrainer( - model=self._policy, - env_info=self._env_info, - config=policy_trainer_config) + model=self._policy, env_info=self._env_info, config=policy_trainer_config + ) return policy_trainer def 
_run_online_training_iteration(self, env): @@ -316,10 +329,12 @@ def _align_experiences(self, buffer_iterator): s_batch, a_batch = self._align_state_and_action(buffer_iterator) - return s_batch[:self._config.num_steps_per_iteration], \ - a_batch[:self._config.num_steps_per_iteration], \ - v_target_batch[:self._config.num_steps_per_iteration], \ - adv_batch[:self._config.num_steps_per_iteration] + return ( + s_batch[: self._config.num_steps_per_iteration], + a_batch[: self._config.num_steps_per_iteration], + v_target_batch[: self._config.num_steps_per_iteration], + adv_batch[: self._config.num_steps_per_iteration], + ) def _compute_v_target_and_advantage(self, buffer_iterator): v_target_batch = [] @@ -328,7 +343,8 @@ def _compute_v_target_and_advantage(self, buffer_iterator): for experiences, _ in buffer_iterator: # length of experiences is 1 v_target, adv = compute_average_v_target_and_advantage( - self._v_function, experiences[0], lmb=self._config.lmb) + self._v_function, experiences[0], lmb=self._config.lmb + ) v_target_batch.append(v_target.reshape(-1, 1)) adv_batch.append(adv.reshape(-1, 1)) @@ -366,19 +382,21 @@ def _v_function_training(self, s, v_target): for _ in range(self._config.vf_epochs * num_iterations_per_epoch): indices = np.random.randint(0, self._config.num_steps_per_iteration, size=self._config.vf_batch_size) - batch = TrainingBatch(batch_size=self._config.vf_batch_size, - s_current=s[indices], - extra={'v_target': v_target[indices]}) + batch = TrainingBatch( + batch_size=self._config.vf_batch_size, s_current=s[indices], extra={"v_target": v_target[indices]} + ) self._v_function_trainer_state = self._v_function_trainer.train(batch) def _policy_training(self, s, a, v_target, advantage): extra = {} - extra['v_target'] = v_target[:self._config.pi_batch_size] - extra['advantage'] = advantage[:self._config.pi_batch_size] - batch = TrainingBatch(batch_size=self._config.pi_batch_size, - s_current=s[:self._config.pi_batch_size], - a_current=a[:self._config.pi_batch_size], - extra=extra) + extra["v_target"] = v_target[: self._config.pi_batch_size] + extra["advantage"] = advantage[: self._config.pi_batch_size] + batch = TrainingBatch( + batch_size=self._config.pi_batch_size, + s_current=s[: self._config.pi_batch_size], + a_current=a[: self._config.pi_batch_size], + extra=extra, + ) self._policy_trainer_state = self._policy_trainer.train(batch) @@ -403,15 +421,16 @@ def _solvers(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env() @property def latest_iteration_state(self): latest_iteration_state = super(ATRPO, self).latest_iteration_state - if hasattr(self, '_v_function_trainer_state'): - latest_iteration_state['scalar'].update({'v_loss': self._v_function_trainer_state['v_loss']}) + if hasattr(self, "_v_function_trainer_state"): + latest_iteration_state["scalar"].update({"v_loss": self._v_function_trainer_state["v_loss"]}) return latest_iteration_state @property diff --git a/nnabla_rl/algorithms/bcq.py b/nnabla_rl/algorithms/bcq.py index 2648aadd..af9e7d03 100644 --- a/nnabla_rl/algorithms/bcq.py +++ b/nnabla_rl/algorithms/bcq.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. 
+# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,8 +27,15 @@ from nnabla_rl.builders import ModelBuilder, SolverBuilder from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import ModelTrainer, TrainingBatch -from nnabla_rl.models import (BCQPerturbator, BCQVariationalAutoEncoder, DeterministicPolicy, Perturbator, QFunction, - TD3QFunction, VariationalAutoEncoder) +from nnabla_rl.models import ( + BCQPerturbator, + BCQVariationalAutoEncoder, + DeterministicPolicy, + Perturbator, + QFunction, + TD3QFunction, + VariationalAutoEncoder, +) from nnabla_rl.utils import context from nnabla_rl.utils.data import add_batch_dimension, marshal_experiences, set_data_to_variable from nnabla_rl.utils.misc import create_variable, sync_model @@ -53,8 +60,9 @@ class BCQConfig(AlgorithmConfig): num_q_ensembles (int): number of q function ensembles . Defaults to 2. num_action_samples (int): number of actions to sample for computing target q values. Defaults to 10. """ + gamma: float = 0.99 - learning_rate: float = 1.0*1e-3 + learning_rate: float = 1.0 * 1e-3 batch_size: int = 100 tau: float = 0.005 lmb: float = 0.75 @@ -67,56 +75,54 @@ def __post_init__(self): Check set values are in valid range. """ - self._assert_between(self.tau, 0.0, 1.0, 'tau') - self._assert_between(self.gamma, 0.0, 1.0, 'gamma') - self._assert_positive(self.lmb, 'lmb') - self._assert_positive(self.phi, 'phi') - self._assert_positive(self.num_q_ensembles, 'num_q_ensembles') - self._assert_positive(self.num_action_samples, 'num_action_samples') - self._assert_positive(self.batch_size, 'batch_size') + self._assert_between(self.tau, 0.0, 1.0, "tau") + self._assert_between(self.gamma, 0.0, 1.0, "gamma") + self._assert_positive(self.lmb, "lmb") + self._assert_positive(self.phi, "phi") + self._assert_positive(self.num_q_ensembles, "num_q_ensembles") + self._assert_positive(self.num_action_samples, "num_action_samples") + self._assert_positive(self.batch_size, "batch_size") class DefaultQFunctionBuilder(ModelBuilder[QFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: BCQConfig, - **kwargs) -> QFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: BCQConfig, + **kwargs, + ) -> QFunction: return TD3QFunction(scope_name) class DefaultVAEBuilder(ModelBuilder[VariationalAutoEncoder]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: BCQConfig, - **kwargs) -> VariationalAutoEncoder: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: BCQConfig, + **kwargs, + ) -> VariationalAutoEncoder: max_action_value = float(env_info.action_high[0]) - return BCQVariationalAutoEncoder(scope_name, - env_info.state_dim, - env_info.action_dim, - env_info.action_dim*2, - max_action_value) + return BCQVariationalAutoEncoder( + scope_name, env_info.state_dim, env_info.action_dim, env_info.action_dim * 2, max_action_value + ) class DefaultPerturbatorBuilder(ModelBuilder[Perturbator]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: BCQConfig, - **kwargs) -> Perturbator: + def build_model( # type: ignore[override] 
+ self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: BCQConfig, + **kwargs, + ) -> Perturbator: max_action_value = float(env_info.action_high[0]) - return BCQPerturbator(scope_name, - env_info.state_dim, - env_info.action_dim, - max_action_value) + return BCQPerturbator(scope_name, env_info.state_dim, env_info.action_dim, max_action_value) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: BCQConfig, - **kwargs): + def build_solver(self, env_info: EnvironmentInfo, algorithm_config: BCQConfig, **kwargs): # type: ignore[override] return NS.Adam(alpha=algorithm_config.learning_rate) @@ -173,15 +179,17 @@ class BCQ(Algorithm): _q_function_trainer_state: Dict[str, Any] _perturbator_trainer_state: Dict[str, Any] - def __init__(self, - env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: BCQConfig = BCQConfig(), - q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), - q_solver_builder: SolverBuilder = DefaultSolverBuilder(), - vae_builder: ModelBuilder[VariationalAutoEncoder] = DefaultVAEBuilder(), - vae_solver_builder: SolverBuilder = DefaultSolverBuilder(), - perturbator_builder: ModelBuilder[Perturbator] = DefaultPerturbatorBuilder(), - perturbator_solver_builder: SolverBuilder = DefaultSolverBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: BCQConfig = BCQConfig(), + q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), + q_solver_builder: SolverBuilder = DefaultSolverBuilder(), + vae_builder: ModelBuilder[VariationalAutoEncoder] = DefaultVAEBuilder(), + vae_solver_builder: SolverBuilder = DefaultSolverBuilder(), + perturbator_builder: ModelBuilder[Perturbator] = DefaultPerturbatorBuilder(), + perturbator_solver_builder: SolverBuilder = DefaultSolverBuilder(), + ): super(BCQ, self).__init__(env_or_env_info, config=config) with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): @@ -190,10 +198,8 @@ def __init__(self, self._target_q_ensembles = [] for i in range(self._config.num_q_ensembles): - q = q_function_builder(scope_name=f"q{i}", - env_info=self._env_info, - algorithm_config=self._config) - target_q = q.deepcopy(f'target_q{i}') + q = q_function_builder(scope_name=f"q{i}", env_info=self._env_info, algorithm_config=self._config) + target_q = q.deepcopy(f"target_q{i}") assert isinstance(target_q, QFunction) self._q_ensembles.append(q) self._q_solvers[q.scope_name] = q_solver_builder(env_info=self._env_info, algorithm_config=self._config) @@ -205,20 +211,20 @@ def __init__(self, self._xi = perturbator_builder(scope_name="xi", env_info=self._env_info, algorithm_config=self._config) self._xi_solver = perturbator_solver_builder(env_info=self._env_info, algorithm_config=self._config) self._target_xi = perturbator_builder( - scope_name="target_xi", env_info=self._env_info, algorithm_config=self._config) + scope_name="target_xi", env_info=self._env_info, algorithm_config=self._config + ) @eval_api def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): if has_batch_dimension(state, self._env_info): - raise RuntimeError(f'{self.__name__} does not support batched state!') + raise RuntimeError(f"{self.__name__} does not support batched state!") with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): state = add_batch_dimension(state) - if not hasattr(self, '_eval_state_var'): + if not hasattr(self, "_eval_state_var"): repeat_num = 100 
self._eval_state_var = create_variable(1, self._env_info.state_shape) if isinstance(self._eval_state_var, tuple): - state_var = tuple(RF.repeat(x=s_var, repeats=repeat_num, axis=0) - for s_var in self._eval_state_var) + state_var = tuple(RF.repeat(x=s_var, repeats=repeat_num, axis=0) for s_var in self._eval_state_var) else: state_var = RF.repeat(x=self._eval_state_var, repeats=repeat_num, axis=0) assert state_var.shape == (repeat_num, self._eval_state_var.shape[1]) @@ -247,13 +253,14 @@ def _setup_encoder_training(self, env_or_buffer): models=self._vae, solvers={self._vae.scope_name: self._vae_solver}, env_info=self._env_info, - config=trainer_config) + config=trainer_config, + ) return encoder_trainer def _setup_q_function_training(self, env_or_buffer): - trainer_config = MT.q_value.BCQQTrainerConfig(reduction_method='mean', - num_action_samples=self._config.num_action_samples, - lmb=self._config.lmb) + trainer_config = MT.q_value.BCQQTrainerConfig( + reduction_method="mean", num_action_samples=self._config.num_action_samples, lmb=self._config.lmb + ) # This is a wrapper class which outputs the target action for next state in q function training class PerturbedPolicy(DeterministicPolicy): @@ -266,6 +273,7 @@ def pi(self, s): a = self._vae.decode(z=None, state=s) noise = self._perturbator.generate_noise(s, a, phi=self._phi) return a + noise + target_policy = PerturbedPolicy(self._vae, self._target_xi, self._config.phi) q_function_trainer = MT.q_value.BCQQTrainer( train_functions=self._q_ensembles, @@ -273,15 +281,14 @@ def pi(self, s): target_functions=self._target_q_ensembles, target_policy=target_policy, env_info=self._env_info, - config=trainer_config) + config=trainer_config, + ) for q, target_q in zip(self._q_ensembles, self._target_q_ensembles): sync_model(q, target_q, 1.0) return q_function_trainer def _setup_perturbator_training(self, env_or_buffer): - trainer_config = MT.perturbator_trainers.BCQPerturbatorTrainerConfig( - phi=self._config.phi - ) + trainer_config = MT.perturbator_trainers.BCQPerturbatorTrainerConfig(phi=self._config.phi) perturbator_trainer = MT.perturbator.BCQPerturbatorTrainer( models=self._xi, @@ -289,12 +296,13 @@ def _setup_perturbator_training(self, env_or_buffer): q_function=self._q_ensembles[0], vae=self._vae, env_info=self._env_info, - config=trainer_config) + config=trainer_config, + ) sync_model(self._xi, self._target_xi, 1.0) return perturbator_trainer def _run_online_training_iteration(self, env): - raise NotImplementedError('BCQ does not support online training') + raise NotImplementedError("BCQ does not support online training") def _run_offline_training_iteration(self, buffer): self._bcq_training(buffer) @@ -302,14 +310,16 @@ def _run_offline_training_iteration(self, buffer): def _bcq_training(self, replay_buffer): experiences, info = replay_buffer.sample(self._config.batch_size) (s, a, r, non_terminal, s_next, *_) = marshal_experiences(experiences) - batch = TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - gamma=self._config.gamma, - reward=r, - non_terminal=non_terminal, - s_next=s_next, - weight=info['weights']) + batch = TrainingBatch( + batch_size=self._config.batch_size, + s_current=s, + a_current=a, + gamma=self._config.gamma, + reward=r, + non_terminal=non_terminal, + s_next=s_next, + weight=info["weights"], + ) # Train vae self._encoder_trainer_state = self._encoder_trainer.train(batch) @@ -317,7 +327,7 @@ def _bcq_training(self, replay_buffer): self._q_function_trainer_state = 
self._q_function_trainer.train(batch) for q, target_q in zip(self._q_ensembles, self._target_q_ensembles): sync_model(q, target_q, tau=self._config.tau) - td_errors = self._q_function_trainer_state['td_errors'] + td_errors = self._q_function_trainer_state["td_errors"] replay_buffer.update_priorities(td_errors) self._perturbator_trainer.train(batch) @@ -326,8 +336,7 @@ def _bcq_training(self, replay_buffer): self._perturbator_trainer_state = self._perturbator_trainer.train(batch) def _models(self): - models = [*self._q_ensembles, *self._target_q_ensembles, - self._vae, self._xi, self._target_xi] + models = [*self._q_ensembles, *self._target_q_ensembles, self._vae, self._xi, self._target_xi] return {model.scope_name: model for model in models} def _solvers(self): @@ -339,23 +348,27 @@ def _solvers(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env() @property def latest_iteration_state(self): latest_iteration_state = super(BCQ, self).latest_iteration_state - if hasattr(self, '_encoder_trainer_state'): - latest_iteration_state['scalar'].update( - {'encoder_loss': float(self._encoder_trainer_state['encoder_loss'])}) - if hasattr(self, '_perturbator_trainer_state'): - latest_iteration_state['scalar'].update( - {'perturbator_loss': float(self._perturbator_trainer_state['perturbator_loss'])}) - if hasattr(self, '_q_function_trainer_state'): - latest_iteration_state['scalar'].update({'q_loss': float(self._q_function_trainer_state['q_loss'])}) - latest_iteration_state['histogram'].update( - {'td_errors': self._q_function_trainer_state['td_errors'].flatten()}) + if hasattr(self, "_encoder_trainer_state"): + latest_iteration_state["scalar"].update( + {"encoder_loss": float(self._encoder_trainer_state["encoder_loss"])} + ) + if hasattr(self, "_perturbator_trainer_state"): + latest_iteration_state["scalar"].update( + {"perturbator_loss": float(self._perturbator_trainer_state["perturbator_loss"])} + ) + if hasattr(self, "_q_function_trainer_state"): + latest_iteration_state["scalar"].update({"q_loss": float(self._q_function_trainer_state["q_loss"])}) + latest_iteration_state["histogram"].update( + {"td_errors": self._q_function_trainer_state["td_errors"].flatten()} + ) return latest_iteration_state @property @@ -369,5 +382,6 @@ def trainers(self): if __name__ == "__main__": import nnabla_rl.environments as E + env = E.DummyContinuous() bcq = BCQ(env) diff --git a/nnabla_rl/algorithms/bear.py b/nnabla_rl/algorithms/bear.py index d58d9b33..8a466bfe 100644 --- a/nnabla_rl/algorithms/bear.py +++ b/nnabla_rl/algorithms/bear.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -29,8 +29,15 @@ from nnabla_rl.builders import ModelBuilder, SolverBuilder from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import ModelTrainer, TrainingBatch -from nnabla_rl.models import (BEARPolicy, DeterministicPolicy, QFunction, StochasticPolicy, TD3QFunction, - UnsquashedVariationalAutoEncoder, VariationalAutoEncoder) +from nnabla_rl.models import ( + BEARPolicy, + DeterministicPolicy, + QFunction, + StochasticPolicy, + TD3QFunction, + UnsquashedVariationalAutoEncoder, + VariationalAutoEncoder, +) from nnabla_rl.utils import context from nnabla_rl.utils.data import add_batch_dimension, marshal_experiences, set_data_to_variable from nnabla_rl.utils.misc import create_variable, sync_model @@ -64,6 +71,7 @@ class BEARConfig(AlgorithmConfig): use_mean_for_eval (bool): Use mean value instead of best action among the samples for evaluation.\ Defaults to False. """ + gamma: float = 0.99 learning_rate: float = 1e-3 batch_size: int = 100 @@ -73,7 +81,7 @@ class BEARConfig(AlgorithmConfig): num_q_ensembles: int = 2 num_mmd_actions: int = 5 num_action_samples: int = 10 - mmd_type: str = 'gaussian' + mmd_type: str = "gaussian" mmd_sigma: float = 20.0 initial_lagrange_multiplier: Optional[float] = None fix_lagrange_multiplier: bool = False @@ -86,56 +94,60 @@ def __post_init__(self): Check set values are in valid range. """ if not ((0.0 <= self.tau) & (self.tau <= 1.0)): - raise ValueError('tau must lie between [0.0, 1.0]') + raise ValueError("tau must lie between [0.0, 1.0]") if not ((0.0 <= self.gamma) & (self.gamma <= 1.0)): - raise ValueError('gamma must lie between [0.0, 1.0]') + raise ValueError("gamma must lie between [0.0, 1.0]") if not (0 <= self.num_q_ensembles): - raise ValueError('num q ensembles must not be negative') + raise ValueError("num q ensembles must not be negative") if not (0 <= self.num_mmd_actions): - raise ValueError('num mmd actions must not be negative') + raise ValueError("num mmd actions must not be negative") if not (0 <= self.num_action_samples): - raise ValueError('num action samples must not be negative') + raise ValueError("num action samples must not be negative") if not (0 <= self.warmup_iterations): - raise ValueError('warmup iterations must not be negative') + raise ValueError("warmup iterations must not be negative") if not (0 <= self.batch_size): - raise ValueError('batch size must not be negative') + raise ValueError("batch size must not be negative") class DefaultQFunctionBuilder(ModelBuilder[QFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: BEARConfig, - **kwargs) -> QFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: BEARConfig, + **kwargs, + ) -> QFunction: return TD3QFunction(scope_name) class DefaultPolicyBuilder(ModelBuilder[StochasticPolicy]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: BEARConfig, - **kwargs) -> StochasticPolicy: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: BEARConfig, + **kwargs, + ) -> StochasticPolicy: return BEARPolicy(scope_name, env_info.action_dim) class DefaultVAEBuilder(ModelBuilder[VariationalAutoEncoder]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: BEARConfig, - **kwargs) -> 
VariationalAutoEncoder: - return UnsquashedVariationalAutoEncoder(scope_name, - env_info.state_dim, - env_info.action_dim, - env_info.action_dim*2) + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: BEARConfig, + **kwargs, + ) -> VariationalAutoEncoder: + return UnsquashedVariationalAutoEncoder( + scope_name, env_info.state_dim, env_info.action_dim, env_info.action_dim * 2 + ) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: BEARConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: BEARConfig, **kwargs + ) -> nn.solver.Solver: return NS.Adam(alpha=algorithm_config.learning_rate) @@ -195,15 +207,18 @@ class BEAR(Algorithm): _policy_trainer_state: Dict[str, Any] _q_function_trainer_state: Dict[str, Any] - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: BEARConfig = BEARConfig(), - q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), - q_solver_builder: SolverBuilder = DefaultSolverBuilder(), - pi_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - pi_solver_builder: SolverBuilder = DefaultSolverBuilder(), - vae_builder: ModelBuilder[VariationalAutoEncoder] = DefaultVAEBuilder(), - vae_solver_builder: SolverBuilder = DefaultSolverBuilder(), - lagrange_solver_builder: SolverBuilder = DefaultSolverBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: BEARConfig = BEARConfig(), + q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), + q_solver_builder: SolverBuilder = DefaultSolverBuilder(), + pi_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + pi_solver_builder: SolverBuilder = DefaultSolverBuilder(), + vae_builder: ModelBuilder[VariationalAutoEncoder] = DefaultVAEBuilder(), + vae_solver_builder: SolverBuilder = DefaultSolverBuilder(), + lagrange_solver_builder: SolverBuilder = DefaultSolverBuilder(), + ): super(BEAR, self).__init__(env_or_env_info, config=config) with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): @@ -211,10 +226,12 @@ def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], self._q_solvers = {} self._target_q_ensembles = [] for i in range(self._config.num_q_ensembles): - q = q_function_builder(scope_name="q{}".format( - i), env_info=self._env_info, algorithm_config=self._config) + q = q_function_builder( + scope_name="q{}".format(i), env_info=self._env_info, algorithm_config=self._config + ) target_q = q_function_builder( - scope_name="target_q{}".format(i), env_info=self._env_info, algorithm_config=self._config) + scope_name="target_q{}".format(i), env_info=self._env_info, algorithm_config=self._config + ) self._q_ensembles.append(q) self._q_solvers[q.scope_name] = q_solver_builder(env_info=self._env_info, algorithm_config=self._config) self._target_q_ensembles.append(target_q) @@ -227,17 +244,17 @@ def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], self._vae_solver = vae_solver_builder(env_info=self._env_info, algorithm_config=self._config) self._lagrange = MT.policy_trainers.bear_policy_trainer.AdjustableLagrangeMultiplier( - scope_name="alpha", - initial_value=self._config.initial_lagrange_multiplier) + scope_name="alpha", initial_value=self._config.initial_lagrange_multiplier + ) self._lagrange_solver = 
lagrange_solver_builder(env_info=self._env_info, algorithm_config=self._config) @eval_api def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): if has_batch_dimension(state, self._env_info): - raise RuntimeError(f'{self.__name__} does not support batched state!') + raise RuntimeError(f"{self.__name__} does not support batched state!") with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): state = add_batch_dimension(state) - if not hasattr(self, '_eval_state_var'): + if not hasattr(self, "_eval_state_var"): self._eval_state_var = create_variable(1, self._env_info.state_shape) if self._config.use_mean_for_eval: eval_distribution = self._pi.pi(self._eval_state_var) @@ -245,8 +262,9 @@ def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): else: repeat_num = 100 if isinstance(self._eval_state_var, tuple): - state_var = tuple(RF.repeat(x=s_var, repeats=repeat_num, axis=0) - for s_var in self._eval_state_var) + state_var = tuple( + RF.repeat(x=s_var, repeats=repeat_num, axis=0) for s_var in self._eval_state_var + ) else: state_var = RF.repeat(x=self._eval_state_var, repeats=repeat_num, axis=0) assert state_var.shape == (repeat_num, self._eval_state_var.shape[1]) @@ -284,17 +302,25 @@ def encode_and_decode(self, s, **kwargs): latent_distribution, reconstructed = self._original_vae.encode_and_decode(s, **kwargs) return latent_distribution, NF.tanh(reconstructed) - def encode(self, *args): raise NotImplementedError - def decode(self, *args): raise NotImplementedError - def decode_multiple(self, decode_num, *args): raise NotImplementedError - def latent_distribution(self, *args): raise NotImplementedError + def encode(self, *args): + raise NotImplementedError + + def decode(self, *args): + raise NotImplementedError + + def decode_multiple(self, decode_num, *args): + raise NotImplementedError + + def latent_distribution(self, *args): + raise NotImplementedError squashed_action_vae = SquashedActionVAE(self._vae) encoder_trainer = MT.encoder_trainers.KLDVariationalAutoEncoderTrainer( models=squashed_action_vae, solvers={self._vae.scope_name: self._vae_solver}, env_info=self._env_info, - config=trainer_config) + config=trainer_config, + ) return encoder_trainer def _setup_q_function_training(self, env_or_buffer): @@ -307,18 +333,20 @@ def __init__(self, target_pi): def pi(self, s): policy_distribution = self._target_pi.pi(s) return NF.tanh(policy_distribution.sample()) + target_policy = PerturbedPolicy(self._target_pi) - trainer_config = MT.q_value.BCQQTrainerConfig(reduction_method='mean', - num_action_samples=self._config.num_action_samples, - lmb=self._config.lmb) + trainer_config = MT.q_value.BCQQTrainerConfig( + reduction_method="mean", num_action_samples=self._config.num_action_samples, lmb=self._config.lmb + ) q_function_trainer = MT.q_value.BCQQTrainer( train_functions=self._q_ensembles, solvers=self._q_solvers, target_functions=self._target_q_ensembles, target_policy=target_policy, env_info=self._env_info, - config=trainer_config) + config=trainer_config, + ) for q, target_q in zip(self._q_ensembles, self._target_q_ensembles): sync_model(q, target_q, 1.0) return q_function_trainer @@ -329,7 +357,8 @@ def _setup_policy_training(self, env_or_buffer): mmd_type=self._config.mmd_type, epsilon=self._config.epsilon, fix_lagrange_multiplier=self._config.fix_lagrange_multiplier, - warmup_iterations=self._config.warmup_iterations-self._iteration_num) + warmup_iterations=self._config.warmup_iterations - self._iteration_num, + ) class 
SquashedActionQ(QFunction): def __init__(self, original_q): @@ -349,7 +378,8 @@ def q(self, s, a): lagrange_multiplier=self._lagrange, lagrange_solver=self._lagrange_solver, env_info=self._env_info, - config=trainer_config) + config=trainer_config, + ) sync_model(self._pi, self._target_pi, 1.0) return policy_trainer @@ -363,19 +393,21 @@ def _run_offline_training_iteration(self, buffer): def _bear_training(self, replay_buffer): experiences, info = replay_buffer.sample(self._config.batch_size) (s, a, r, non_terminal, s_next, *_) = marshal_experiences(experiences) - batch = TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - gamma=self._config.gamma, - reward=r, - non_terminal=non_terminal, - s_next=s_next, - weight=info['weights']) + batch = TrainingBatch( + batch_size=self._config.batch_size, + s_current=s, + a_current=a, + gamma=self._config.gamma, + reward=r, + non_terminal=non_terminal, + s_next=s_next, + weight=info["weights"], + ) self._q_function_trainer_state = self._q_function_trainer.train(batch) for q, target_q in zip(self._q_ensembles, self._target_q_ensembles): sync_model(q, target_q, tau=self._config.tau) - td_errors = self._q_function_trainer_state['td_errors'] + td_errors = self._q_function_trainer_state["td_errors"] replay_buffer.update_priorities(td_errors) self._encoder_trainer_state = self._encoder_trainer.train(batch) @@ -383,9 +415,7 @@ def _bear_training(self, replay_buffer): sync_model(self._pi, self._target_pi, tau=self._config.tau) def _models(self): - models = [*self._q_ensembles, *self._target_q_ensembles, - self._pi, self._target_pi, self._vae, - self._lagrange] + models = [*self._q_ensembles, *self._target_q_ensembles, self._pi, self._target_pi, self._vae, self._lagrange] return {model.scope_name: model for model in models} def _solvers(self): @@ -399,21 +429,23 @@ def _solvers(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env() @property def latest_iteration_state(self): latest_iteration_state = super(BEAR, self).latest_iteration_state - if hasattr(self, '_encoder_trainer_state'): - latest_iteration_state['scalar'].update( - {'encoder_loss': float(self._encoder_trainer_state['encoder_loss'])}) - if hasattr(self, '_policy_trainer_state'): - latest_iteration_state['scalar'].update({'pi_loss': float(self._policy_trainer_state['pi_loss'])}) - if hasattr(self, '_q_function_trainer_state'): - latest_iteration_state['scalar'].update({'q_loss': float(self._q_function_trainer_state['q_loss'])}) - latest_iteration_state['histogram'].update({'td_errors': self._q_function_trainer_state['td_errors']}) + if hasattr(self, "_encoder_trainer_state"): + latest_iteration_state["scalar"].update( + {"encoder_loss": float(self._encoder_trainer_state["encoder_loss"])} + ) + if hasattr(self, "_policy_trainer_state"): + latest_iteration_state["scalar"].update({"pi_loss": float(self._policy_trainer_state["pi_loss"])}) + if hasattr(self, "_q_function_trainer_state"): + latest_iteration_state["scalar"].update({"q_loss": float(self._q_function_trainer_state["q_loss"])}) + latest_iteration_state["histogram"].update({"td_errors": self._q_function_trainer_state["td_errors"]}) return latest_iteration_state @property diff --git 
a/nnabla_rl/algorithms/categorical_ddqn.py b/nnabla_rl/algorithms/categorical_ddqn.py index 25e4f89e..11c5d441 100644 --- a/nnabla_rl/algorithms/categorical_ddqn.py +++ b/nnabla_rl/algorithms/categorical_ddqn.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,9 +18,14 @@ import gym import nnabla_rl.model_trainers as MT -from nnabla_rl.algorithms.categorical_dqn import (CategoricalDQN, CategoricalDQNConfig, DefaultExplorerBuilder, - DefaultReplayBufferBuilder, DefaultSolverBuilder, - DefaultValueDistFunctionBuilder) +from nnabla_rl.algorithms.categorical_dqn import ( + CategoricalDQN, + CategoricalDQNConfig, + DefaultExplorerBuilder, + DefaultReplayBufferBuilder, + DefaultSolverBuilder, + DefaultValueDistFunctionBuilder, +) from nnabla_rl.builders import ExplorerBuilder, ModelBuilder, ReplayBufferBuilder, SolverBuilder from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.models import ValueDistributionFunction @@ -58,19 +63,23 @@ class CategoricalDDQN(CategoricalDQN): builder of environment explorer """ - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: CategoricalDQNConfig = CategoricalDDQNConfig(), - value_distribution_builder: ModelBuilder[ValueDistributionFunction] - = DefaultValueDistFunctionBuilder(), - value_distribution_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): - super(CategoricalDDQN, self).__init__(env_or_env_info, - config=config, - value_distribution_builder=value_distribution_builder, - value_distribution_solver_builder=value_distribution_solver_builder, - replay_buffer_builder=replay_buffer_builder, - explorer_builder=explorer_builder) + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: CategoricalDQNConfig = CategoricalDDQNConfig(), + value_distribution_builder: ModelBuilder[ValueDistributionFunction] = DefaultValueDistFunctionBuilder(), + value_distribution_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): + super(CategoricalDDQN, self).__init__( + env_or_env_info, + config=config, + value_distribution_builder=value_distribution_builder, + value_distribution_solver_builder=value_distribution_solver_builder, + replay_buffer_builder=replay_buffer_builder, + explorer_builder=explorer_builder, + ) def _setup_value_distribution_function_training(self, env_or_buffer): trainer_config = MT.q_value_trainers.CategoricalDDQNQTrainerConfig( @@ -81,14 +90,16 @@ def _setup_value_distribution_function_training(self, env_or_buffer): reduction_method=self._config.loss_reduction_method, unroll_steps=self._config.unroll_steps, burn_in_steps=self._config.burn_in_steps, - reset_on_terminal=self._config.reset_rnn_on_terminal) + reset_on_terminal=self._config.reset_rnn_on_terminal, + ) model_trainer = MT.q_value_trainers.CategoricalDDQNQTrainer( train_function=self._atom_p, solvers={self._atom_p.scope_name: self._atom_p_solver}, target_function=self._target_atom_p, env_info=self._env_info, - config=trainer_config) + config=trainer_config, + ) # NOTE: Copy initial parameters after setting 
up the training # Because the parameter is created after training graph construction diff --git a/nnabla_rl/algorithms/categorical_dqn.py b/nnabla_rl/algorithms/categorical_dqn.py index 69675f94..c55746c4 100644 --- a/nnabla_rl/algorithms/categorical_dqn.py +++ b/nnabla_rl/algorithms/categorical_dqn.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -106,71 +106,72 @@ def __post_init__(self): Check set values are in valid range. """ - self._assert_between(self.gamma, 0.0, 1.0, 'gamma') - self._assert_positive(self.learning_rate, 'learning_rate') - self._assert_positive(self.batch_size, 'batch_size') - self._assert_positive(self.num_steps, 'num_steps') - self._assert_positive(self.learner_update_frequency, 'learner_update_frequency') - self._assert_positive(self.target_update_frequency, 'target_update_frequency') - self._assert_positive(self.start_timesteps, 'start_timesteps') - self._assert_positive(self.replay_buffer_size, 'replay_buffer_size') - self._assert_smaller_than(self.start_timesteps, self.replay_buffer_size, 'start_timesteps') - self._assert_positive(self.max_explore_steps, 'max_explore_steps') - self._assert_between(self.initial_epsilon, 0.0, 1.0, 'initial_epsilon') - self._assert_between(self.final_epsilon, 0.0, 1.0, 'final_epsilon') - self._assert_between(self.test_epsilon, 0.0, 1.0, 'test_epsilon') - self._assert_positive(self.num_atoms, 'num_atoms') - self._assert_positive(self.unroll_steps, 'unroll_steps') - self._assert_positive_or_zero(self.burn_in_steps, 'burn_in_steps') + self._assert_between(self.gamma, 0.0, 1.0, "gamma") + self._assert_positive(self.learning_rate, "learning_rate") + self._assert_positive(self.batch_size, "batch_size") + self._assert_positive(self.num_steps, "num_steps") + self._assert_positive(self.learner_update_frequency, "learner_update_frequency") + self._assert_positive(self.target_update_frequency, "target_update_frequency") + self._assert_positive(self.start_timesteps, "start_timesteps") + self._assert_positive(self.replay_buffer_size, "replay_buffer_size") + self._assert_smaller_than(self.start_timesteps, self.replay_buffer_size, "start_timesteps") + self._assert_positive(self.max_explore_steps, "max_explore_steps") + self._assert_between(self.initial_epsilon, 0.0, 1.0, "initial_epsilon") + self._assert_between(self.final_epsilon, 0.0, 1.0, "final_epsilon") + self._assert_between(self.test_epsilon, 0.0, 1.0, "test_epsilon") + self._assert_positive(self.num_atoms, "num_atoms") + self._assert_positive(self.unroll_steps, "unroll_steps") + self._assert_positive_or_zero(self.burn_in_steps, "burn_in_steps") class DefaultValueDistFunctionBuilder(ModelBuilder[ValueDistributionFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: CategoricalDQNConfig, - **kwargs) -> ValueDistributionFunction: - return C51ValueDistributionFunction(scope_name, - env_info.action_dim, - algorithm_config.num_atoms, - algorithm_config.v_min, - algorithm_config.v_max) + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: CategoricalDQNConfig, + **kwargs, + ) -> ValueDistributionFunction: + return C51ValueDistributionFunction( + scope_name, env_info.action_dim, 
algorithm_config.num_atoms, algorithm_config.v_min, algorithm_config.v_max + ) class DefaultReplayBufferBuilder(ReplayBufferBuilder): - def build_replay_buffer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: CategoricalDQNConfig, - **kwargs) -> ReplayBuffer: + def build_replay_buffer( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: CategoricalDQNConfig, **kwargs + ) -> ReplayBuffer: return ReplayBuffer(capacity=algorithm_config.replay_buffer_size) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: CategoricalDQNConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: CategoricalDQNConfig, **kwargs + ) -> nn.solver.Solver: return NS.Adam(alpha=algorithm_config.learning_rate, eps=1e-2 / algorithm_config.batch_size) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: CategoricalDQNConfig, - algorithm: "CategoricalDQN", - **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: CategoricalDQNConfig, + algorithm: "CategoricalDQN", + **kwargs, + ) -> EnvironmentExplorer: explorer_config = EE.LinearDecayEpsilonGreedyExplorerConfig( warmup_random_steps=algorithm_config.start_timesteps, initial_step_num=algorithm.iteration_num, initial_epsilon=algorithm_config.initial_epsilon, final_epsilon=algorithm_config.final_epsilon, - max_explore_steps=algorithm_config.max_explore_steps + max_explore_steps=algorithm_config.max_explore_steps, ) explorer = EE.LinearDecayEpsilonGreedyExplorer( greedy_action_selector=algorithm._exploration_action_selector, random_action_selector=algorithm._random_action_selector, env_info=env_info, - config=explorer_config) + config=explorer_config, + ) return explorer @@ -214,21 +215,23 @@ class CategoricalDQN(Algorithm): _model_trainer_state: Dict[str, Any] - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: CategoricalDQNConfig = CategoricalDQNConfig(), - value_distribution_builder: ModelBuilder[ValueDistributionFunction] - = DefaultValueDistFunctionBuilder(), - value_distribution_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: CategoricalDQNConfig = CategoricalDQNConfig(), + value_distribution_builder: ModelBuilder[ValueDistributionFunction] = DefaultValueDistFunctionBuilder(), + value_distribution_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): super(CategoricalDQN, self).__init__(env_or_env_info, config=config) self._explorer_builder = explorer_builder with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): - self._atom_p = value_distribution_builder('atom_p_train', self._env_info, self._config) + self._atom_p = value_distribution_builder("atom_p_train", self._env_info, self._config) self._atom_p_solver = value_distribution_solver_builder(self._env_info, self._config) - self._target_atom_p = self._atom_p.deepcopy('target_atom_p_train') + 
self._target_atom_p = self._atom_p.deepcopy("target_atom_p_train") self._replay_buffer = replay_buffer_builder(self._env_info, self._config) @@ -238,11 +241,13 @@ def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], @eval_api def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): - (action, _), _ = epsilon_greedy_action_selection(state, - self._evaluation_action_selector, - self._random_action_selector, - epsilon=self._config.test_epsilon, - begin_of_episode=begin_of_episode) + (action, _), _ = epsilon_greedy_action_selection( + state, + self._evaluation_action_selector, + self._random_action_selector, + epsilon=self._config.test_epsilon, + begin_of_episode=begin_of_episode, + ) return action def _before_training_start(self, env_or_buffer): @@ -263,14 +268,16 @@ def _setup_value_distribution_function_training(self, env_or_buffer): reduction_method=self._config.loss_reduction_method, unroll_steps=self._config.unroll_steps, burn_in_steps=self._config.burn_in_steps, - reset_on_terminal=self._config.reset_rnn_on_terminal) + reset_on_terminal=self._config.reset_rnn_on_terminal, + ) model_trainer = MT.q_value_trainers.CategoricalDQNQTrainer( train_functions=self._atom_p, solvers={self._atom_p.scope_name: self._atom_p_solver}, target_function=self._target_atom_p, env_info=self._env_info, - config=trainer_config) + config=trainer_config, + ) # NOTE: Copy initial parameters after setting up the training # Because the parameter is created after training graph construction @@ -291,28 +298,30 @@ def _categorical_dqn_training(self, replay_buffer): num_steps = self._config.num_steps + self._config.burn_in_steps + self._config.unroll_steps - 1 experiences_tuple, info = replay_buffer.sample(self._config.batch_size, num_steps=num_steps) if num_steps == 1: - experiences_tuple = (experiences_tuple, ) + experiences_tuple = (experiences_tuple,) assert len(experiences_tuple) == num_steps batch = None for experiences in reversed(experiences_tuple): (s, a, r, non_terminal, s_next, rnn_states_dict, *_) = marshal_experiences(experiences) - rnn_states = rnn_states_dict['rnn_states'] if 'rnn_states' in rnn_states_dict else {} - batch = TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - gamma=self._config.gamma, - reward=r, - non_terminal=non_terminal, - s_next=s_next, - weight=info['weights'], - next_step_batch=batch, - rnn_states=rnn_states) + rnn_states = rnn_states_dict["rnn_states"] if "rnn_states" in rnn_states_dict else {} + batch = TrainingBatch( + batch_size=self._config.batch_size, + s_current=s, + a_current=a, + gamma=self._config.gamma, + reward=r, + non_terminal=non_terminal, + s_next=s_next, + weight=info["weights"], + next_step_batch=batch, + rnn_states=rnn_states, + ) self._model_trainer_state = self._model_trainer.train(batch) if self.iteration_num % self._config.target_update_frequency == 0: sync_model(self._atom_p, self._target_atom_p) - td_errors = self._model_trainer_state['td_errors'] + td_errors = self._model_trainer_state["td_errors"] replay_buffer.update_priorities(td_errors) def _evaluation_action_selector(self, s, *, begin_of_episode=False): @@ -323,7 +332,7 @@ def _exploration_action_selector(self, s, *, begin_of_episode=False): def _random_action_selector(self, s, *, begin_of_episode=False): action = self._env_info.action_space.sample() - return np.asarray(action).reshape((1, )), {} + return np.asarray(action).reshape((1,)), {} def _models(self): 
models = {} @@ -337,8 +346,9 @@ def _solvers(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_continuous_action_env() and not env_info.is_tuple_action_env() @classmethod @@ -348,10 +358,11 @@ def is_rnn_supported(self): @property def latest_iteration_state(self): latest_iteration_state = super(CategoricalDQN, self).latest_iteration_state - if hasattr(self, '_model_trainer_state'): - latest_iteration_state['scalar'].update( - {'cross_entropy_loss': float(self._model_trainer_state['cross_entropy_loss'])}) - latest_iteration_state['histogram'].update({'td_errors': self._model_trainer_state['td_errors'].flatten()}) + if hasattr(self, "_model_trainer_state"): + latest_iteration_state["scalar"].update( + {"cross_entropy_loss": float(self._model_trainer_state["cross_entropy_loss"])} + ) + latest_iteration_state["histogram"].update({"td_errors": self._model_trainer_state["td_errors"].flatten()}) return latest_iteration_state @property diff --git a/nnabla_rl/algorithms/common_utils.py b/nnabla_rl/algorithms/common_utils.py index 7e2289cc..1318b7fa 100644 --- a/nnabla_rl/algorithms/common_utils.py +++ b/nnabla_rl/algorithms/common_utils.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,9 +24,18 @@ from nnabla_rl.algorithm import eval_api from nnabla_rl.distributions.distribution import Distribution from nnabla_rl.environments.environment_info import EnvironmentInfo -from nnabla_rl.models import (DeterministicDecisionTransformer, DeterministicDynamics, DeterministicPolicy, - FactoredContinuousQFunction, Model, QFunction, RewardFunction, - StochasticDecisionTransformer, StochasticPolicy, VFunction) +from nnabla_rl.models import ( + DeterministicDecisionTransformer, + DeterministicDynamics, + DeterministicPolicy, + FactoredContinuousQFunction, + Model, + QFunction, + RewardFunction, + StochasticDecisionTransformer, + StochasticPolicy, + VFunction, +) from nnabla_rl.preprocessors import Preprocessor from nnabla_rl.typing import Experience, State from nnabla_rl.utils.data import add_batch_dimension, marshal_experiences, set_data_to_variable @@ -51,10 +60,9 @@ def has_batch_dimension(state: State, env_info: EnvironmentInfo): return not (fed_state_shape == env_state_shape) -def compute_v_target_and_advantage(v_function: VFunction, - experiences: Sequence[Experience], - gamma: float = 0.99, - lmb: float = 0.97) -> Tuple[np.ndarray, np.ndarray]: +def compute_v_target_and_advantage( + v_function: VFunction, experiences: Sequence[Experience], gamma: float = 0.99, lmb: float = 0.97 +) -> Tuple[np.ndarray, np.ndarray]: """Compute value target and advantage by using Generalized Advantage Estimation (GAE) @@ -75,7 +83,7 @@ def compute_v_target_and_advantage(v_function: VFunction, T = len(experiences) v_targets: np.ndarray = np.empty(shape=(T, 1), dtype=np.float32) advantages: np.ndarray = np.empty(shape=(T, 1), dtype=np.float32) - advantage: np.float32 = np.float32(0.) 
+ advantage: np.float32 = np.float32(0.0) v_current = None v_next = None @@ -108,10 +116,8 @@ def compute_v_target_and_advantage(v_function: VFunction, return np.array(v_targets, dtype=np.float32), np.array(advantages, dtype=np.float32) -def compute_average_v_target_and_advantage(v_function: VFunction, - experiences: Sequence[Experience], - lmb=0.95): - ''' Compute value target and advantage by using Average Reward Criterion +def compute_average_v_target_and_advantage(v_function: VFunction, experiences: Sequence[Experience], lmb=0.95): + """Compute value target and advantage by using Average Reward Criterion See: https://arxiv.org/pdf/2106.07329.pdf Args: @@ -121,12 +127,12 @@ def compute_average_v_target_and_advantage(v_function: VFunction, lmb (float): lambda Returns: Tuple[np.ndarray, np.ndarray]: target of value and advantage - ''' + """ assert isinstance(v_function, VFunction), "Invalid v_function" T = len(experiences) v_targets: np.ndarray = np.empty(shape=(T, 1), dtype=np.float32) advantages: np.ndarray = np.empty(shape=(T, 1), dtype=np.float32) - advantage: np.float32 = np.float32(0.) + advantage: np.float32 = np.float32(0.0) v_current = None v_next = None @@ -175,9 +181,9 @@ def v(self, s: nn.Variable): preprocessed_state = self._preprocessor.process(s) return self._v_function.v(preprocessed_state) - def deepcopy(self, new_scope_name: str) -> '_StatePreprocessedVFunction': + def deepcopy(self, new_scope_name: str) -> "_StatePreprocessedVFunction": copied = super().deepcopy(new_scope_name=new_scope_name) - assert isinstance(copied, _StatePreprocessedVFunction) + assert isinstance(copied, _StatePreprocessedVFunction) copied._v_function._scope_name = new_scope_name return copied @@ -207,9 +213,9 @@ def pi(self, s: nn.Variable) -> nn.Variable: preprocessed_state = self._preprocessor.process(s) return self._policy.pi(preprocessed_state) - def deepcopy(self, new_scope_name: str) -> '_StatePreprocessedDeterministicPolicy': + def deepcopy(self, new_scope_name: str) -> "_StatePreprocessedDeterministicPolicy": copied = super().deepcopy(new_scope_name=new_scope_name) - assert isinstance(copied, _StatePreprocessedDeterministicPolicy) + assert isinstance(copied, _StatePreprocessedDeterministicPolicy) copied._policy._scope_name = new_scope_name return copied @@ -239,9 +245,9 @@ def pi(self, s: nn.Variable) -> Distribution: preprocessed_state = self._preprocessor.process(s) return self._policy.pi(preprocessed_state) - def deepcopy(self, new_scope_name: str) -> '_StatePreprocessedStochasticPolicy': + def deepcopy(self, new_scope_name: str) -> "_StatePreprocessedStochasticPolicy": copied = super().deepcopy(new_scope_name=new_scope_name) - assert isinstance(copied, _StatePreprocessedStochasticPolicy) + assert isinstance(copied, _StatePreprocessedStochasticPolicy) copied._policy._scope_name = new_scope_name return copied @@ -272,9 +278,9 @@ def r(self, s_current: nn.Variable, a_current: nn.Variable, s_next: nn.Variable) preprocessed_state_next = self._preprocessor.process(s_next) return self._reward_function.r(preprocessed_state_current, a_current, preprocessed_state_next) - def deepcopy(self, new_scope_name: str) -> '_StatePreprocessedRewardFunction': + def deepcopy(self, new_scope_name: str) -> "_StatePreprocessedRewardFunction": copied = super().deepcopy(new_scope_name=new_scope_name) - assert isinstance(copied, _StatePreprocessedRewardFunction) + assert isinstance(copied, _StatePreprocessedRewardFunction) copied._reward_function._scope_name = new_scope_name return copied @@ -316,9 +322,9 @@ def 
argmax_q(self, s: nn.Variable) -> nn.Variable: preprocessed_state = self._preprocessor.process(s) return self._q_function.argmax_q(preprocessed_state) - def deepcopy(self, new_scope_name: str) -> '_StatePreprocessedQFunction': + def deepcopy(self, new_scope_name: str) -> "_StatePreprocessedQFunction": copied = super().deepcopy(new_scope_name=new_scope_name) - assert isinstance(copied, _StatePreprocessedQFunction) + assert isinstance(copied, _StatePreprocessedQFunction) copied._q_function._scope_name = new_scope_name return copied @@ -335,7 +341,7 @@ def get_internal_states(self) -> Dict[str, nn.Variable]: return self._q_function.get_internal_states() -M = TypeVar('M', bound=Model) +M = TypeVar("M", bound=Model) class _ActionSelector(Generic[M], metaclass=ABCMeta): @@ -355,7 +361,7 @@ def __call__(self, s: Union[np.ndarray, Tuple[np.ndarray, ...]], *, begin_of_epi if not has_batch_dimension(s, self._env_info): s = add_batch_dimension(s) batch_size = len(s[0]) if self._env_info.is_tuple_state_env() else len(s) - if not hasattr(self, '_eval_state_var') or batch_size != self._batch_size: + if not hasattr(self, "_eval_state_var") or batch_size != self._batch_size: # Variable creation self._eval_state_var = create_variable(batch_size, self._env_info.state_shape) if self._model.is_recurrent(): @@ -389,8 +395,15 @@ def _compute_action(self, state_var: nn.Variable) -> nn.Variable: class _DecisionTransformerActionSelector(_ActionSelector[DecisionTransformerModel]): - def __init__(self, env_info: EnvironmentInfo, decision_transformer: DecisionTransformerModel, - max_timesteps: int, context_length: int, target_return: float, reward_scale: float): + def __init__( + self, + env_info: EnvironmentInfo, + decision_transformer: DecisionTransformerModel, + max_timesteps: int, + context_length: int, + target_return: float, + reward_scale: float, + ): super().__init__(env_info, decision_transformer) self._max_timesteps = max_timesteps self._context_length = context_length @@ -399,13 +412,13 @@ def __init__(self, env_info: EnvironmentInfo, decision_transformer: DecisionTran def __call__(self, s: Union[np.ndarray, Tuple[np.ndarray, ...]], *, begin_of_episode: bool = False, extra_info={}): if self._env_info.is_tuple_state_env(): - raise NotImplementedError('Tuple env not supported') + raise NotImplementedError("Tuple env not supported") if not has_batch_dimension(s, self._env_info): s = add_batch_dimension(s) batch_size = len(s) - if not hasattr(self, '_eval_states') or batch_size != self._batch_size: + if not hasattr(self, "_eval_states") or batch_size != self._batch_size: self._eval_states = np.empty(shape=(batch_size, self._context_length, *self._env_info.state_shape)) self._eval_actions = np.empty(shape=(batch_size, self._context_length, *self._env_info.action_shape)) self._eval_rtgs = np.empty(shape=(batch_size, self._context_length, 1)) @@ -420,19 +433,19 @@ def __call__(self, s: Union[np.ndarray, Tuple[np.ndarray, ...]], *, begin_of_epi if t == 0: self._eval_rtgs[:, T, ...] = self._target_return * self._reward_scale else: - reward = extra_info['reward'] * self._reward_scale - self._eval_rtgs[:, T, ...] = self._eval_rtgs[:, T-1, ...] - reward + reward = extra_info["reward"] * self._reward_scale + self._eval_rtgs[:, T, ...] = self._eval_rtgs[:, T - 1, ...] 
- reward with nn.auto_forward(): - states_var = nn.Variable.from_numpy_array(self._eval_states[:, 0:T+1, ...]) + states_var = nn.Variable.from_numpy_array(self._eval_states[:, 0 : T + 1, ...]) if begin_of_episode: actions_var = None else: if self._context_length <= t: - actions_var = nn.Variable.from_numpy_array(self._eval_actions[:, 0:T+1, ...]) + actions_var = nn.Variable.from_numpy_array(self._eval_actions[:, 0 : T + 1, ...]) else: actions_var = nn.Variable.from_numpy_array(self._eval_actions[:, 0:T, ...]) - rtgs_var = nn.Variable.from_numpy_array(self._eval_rtgs[:, 0:T+1, ...]) + rtgs_var = nn.Variable.from_numpy_array(self._eval_rtgs[:, 0 : T + 1, ...]) timesteps_var = nn.Variable.from_numpy_array(self._eval_timesteps) if isinstance(self._model, DeterministicDecisionTransformer): @@ -503,17 +516,13 @@ def __init__(self, env_info: EnvironmentInfo, model: M): self._batch_size = 1 @eval_api - def __call__(self, - s: Union[np.ndarray, Tuple[np.ndarray, ...]], - a: np.ndarray, - *, - begin_of_episode: bool = False): + def __call__(self, s: Union[np.ndarray, Tuple[np.ndarray, ...]], a: np.ndarray, *, begin_of_episode: bool = False): if not has_batch_dimension(s, self._env_info): s = add_batch_dimension(s) if not has_batch_dimension(a, self._env_info): a = cast(np.ndarray, add_batch_dimension(a)) batch_size = len(s[0]) if self._env_info.is_tuple_state_env() else len(s) - if not hasattr(self, '_eval_state_var') or batch_size != self._batch_size: + if not hasattr(self, "_eval_state_var") or batch_size != self._batch_size: # Variable creation self._eval_state_var = create_variable(batch_size, self._env_info.state_shape) self._eval_action_var = create_variable(batch_size, self._env_info.action_shape) @@ -562,6 +571,7 @@ class _InfluenceMetricsEvaluator: env_info (EnvironmentInfo): Environment infomation. q_function (FactoredContinuousQFunction): Factored Q-function for continuous action. """ + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See https://mypy.readthedocs.io/en/stable/class_basics.html for details @@ -574,13 +584,14 @@ def __init__(self, env_info: EnvironmentInfo, q_function: FactoredContinuousQFun self._batch_size = 1 @eval_api - def __call__(self, s: Union[np.ndarray, Tuple[np.ndarray, ...]], a: np.ndarray, *, begin_of_episode: bool = False) \ - -> Tuple[np.ndarray, Dict]: + def __call__( + self, s: Union[np.ndarray, Tuple[np.ndarray, ...]], a: np.ndarray, *, begin_of_episode: bool = False + ) -> Tuple[np.ndarray, Dict]: if not has_batch_dimension(s, self._env_info): s = add_batch_dimension(s) a = cast(np.ndarray, add_batch_dimension(a)) batch_size = len(s[0]) if self._env_info.is_tuple_state_env() else len(s) - if not hasattr(self, '_eval_state_var') or batch_size != self._batch_size: + if not hasattr(self, "_eval_state_var") or batch_size != self._batch_size: # Variable creation self._batch_size = batch_size self._eval_state_var = create_variable(batch_size, self._env_info.state_shape) diff --git a/nnabla_rl/algorithms/ddp.py b/nnabla_rl/algorithms/ddp.py index b934c8e5..4bb02e69 100644 --- a/nnabla_rl/algorithms/ddp.py +++ b/nnabla_rl/algorithms/ddp.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
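As a rough orientation for the DDP hunks that follow, here is a minimal, framework-free sketch of the backtracking step-size schedule and the expected cost reduction they reformat. The standalone function below is an illustrative assumption (in the library this logic lives inside DDP._improve_trajectory and DDP._expected_cost_reduction), but the alpha schedule and the per-step formula mirror the code in the hunks.

import numpy as np

def expected_cost_reduction(ks, Qus, Quus, alpha):
    # delta_J = sum_t [ alpha * k_t^T Qu_t + 0.5 * alpha^2 * k_t^T Quu_t k_t ]
    delta_J = 0.0
    for k, Qu, Quu in zip(ks, Qus, Quus):
        linear_part = alpha * k.T.dot(Qu)
        squared_part = 0.5 * (alpha**2.0) * k.T.dot(Quu).dot(k)
        delta_J += float(linear_part) + float(squared_part)
    return delta_J

# Backtracking line-search candidates, largest step first, as in the diff.
alphas = 0.9 ** (np.arange(10) ** 2)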
@@ -36,6 +36,7 @@ class DDPConfig(AlgorithmConfig): modification_factor (float): Modification factor for the regularizer. Defaults to 2.0. accept_improvement_ratio (float): Threshold value for deciding to accept the update or not. Defaults to 0.0 """ + T_max: int = 50 num_iterations: int = 10 mu_min: float = 1e-6 @@ -45,8 +46,8 @@ class DDPConfig(AlgorithmConfig): def __post_init__(self): super().__post_init__() - self._assert_positive(self.T_max, 'T_max') - self._assert_positive(self.num_iterations, 'num_iterations') + self._assert_positive(self.T_max, "T_max") + self._assert_positive(self.num_iterations, "num_iterations") class DDP(Algorithm): @@ -69,13 +70,10 @@ class DDP(Algorithm): config (:py:class:`DDPConfig `): the parameter for DDP controller """ + _config: DDPConfig - def __init__(self, - env_or_env_info, - dynamics: Dynamics, - cost_function: CostFunction, - config=DDPConfig()): + def __init__(self, env_or_env_info, dynamics: Dynamics, cost_function: CostFunction, config=DDPConfig()): super(DDP, self).__init__(env_or_env_info, config=config) self._dynamics = dynamics self._cost_function = cost_function @@ -92,9 +90,9 @@ def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): return improved_trajectory[0][1] @eval_api - def compute_trajectory(self, - initial_trajectory: Sequence[Tuple[np.ndarray, Optional[np.ndarray]]]) \ - -> Tuple[Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], Sequence[Dict[str, Any]]]: + def compute_trajectory( + self, initial_trajectory: Sequence[Tuple[np.ndarray, Optional[np.ndarray]]] + ) -> Tuple[Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], Sequence[Dict[str, Any]]]: assert len(initial_trajectory) == self._config.T_max dynamics = self._dynamics cost_function = self._cost_function @@ -102,25 +100,28 @@ def compute_trajectory(self, delta = 0.0 trajectory = initial_trajectory for _ in range(self._config.num_iterations): - trajectory, trajectory_info, mu, delta = \ - self._improve_trajectory(trajectory, dynamics, cost_function, mu, delta) + trajectory, trajectory_info, mu, delta = self._improve_trajectory( + trajectory, dynamics, cost_function, mu, delta + ) return trajectory, trajectory_info - def _optimize(self, - initial_state: Union[np.ndarray, Sequence[Tuple[np.ndarray, Optional[np.ndarray]]]], - dynamics: Dynamics, - cost_function: CostFunction, - **kwargs) \ - -> Tuple[Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], Sequence[Dict[str, Any]]]: + def _optimize( + self, + initial_state: Union[np.ndarray, Sequence[Tuple[np.ndarray, Optional[np.ndarray]]]], + dynamics: Dynamics, + cost_function: CostFunction, + **kwargs, + ) -> Tuple[Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], Sequence[Dict[str, Any]]]: assert len(initial_state) == self._config.T_max initial_state = cast(Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], initial_state) mu = 0.0 delta = 0.0 trajectory = initial_state for _ in range(self._config.num_iterations): - trajectory, trajectory_info, mu, delta = \ - self._improve_trajectory(trajectory, dynamics, cost_function, mu, delta) + trajectory, trajectory_info, mu, delta = self._improve_trajectory( + trajectory, dynamics, cost_function, mu, delta + ) return trajectory, trajectory_info @@ -133,15 +134,14 @@ def _compute_initial_trajectory(self, x0, dynamics, T, u): trajectory.append((x, None)) return trajectory - def _improve_trajectory(self, - trajectory: Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], - dynamics: Dynamics, - cost_function: CostFunction, - mu: float, - delta: float) -> 
Tuple[Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], - Sequence[Dict[str, Any]], - float, - float]: + def _improve_trajectory( + self, + trajectory: Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], + dynamics: Dynamics, + cost_function: CostFunction, + mu: float, + delta: float, + ) -> Tuple[Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], Sequence[Dict[str, Any]], float, float]: while True: ks, Ks, Qus, Quus, Quu_invs, success = self._backward_pass(trajectory, dynamics, cost_function, mu) mu, delta = self._update_regularizer(mu, delta, increase=not success) @@ -149,7 +149,7 @@ def _improve_trajectory(self, break # Backtracking linear search - alphas = 0.9**(np.arange(10) ** 2) + alphas = 0.9 ** (np.arange(10) ** 2) improved_trajectory = trajectory improved_trajectory_info: Sequence[Dict[str, Any]] = [] J_current = self._compute_cost(trajectory, cost_function) @@ -163,7 +163,7 @@ def _improve_trajectory(self, improved_trajectory = new_trajectory # append Quu for info, k, K, Quu, Quu_inv in zip(new_trajectory_info, ks, Ks, Quus, Quu_invs): - info.update({'k': k, 'K': K, 'Quu': Quu, 'Quu_inv': Quu_inv}) + info.update({"k": k, "K": K, "Quu": Quu, "Quu_inv": Quu_inv}) improved_trajectory_info = new_trajectory_info break return improved_trajectory, improved_trajectory_info, mu, delta @@ -204,7 +204,7 @@ def _backward_pass(self, trajectory, dynamics, cost_function, mu): Fx, Fu = dynamics.gradient(x, u, self._config.T_max - t - 1) # Hessians should be a 3d tensor - Fxx, Fxu, Fux, Fuu = dynamics.hessian(x, u, self._config.T_max - t - 1) + Fxx, Fxu, Fux, Fuu = dynamics.hessian(x, u, self._config.T_max - t - 1) Quu = Cuu + Fu.T.dot(Vxx + mu * E).dot(Fu) + np.tensordot(Vx, Fuu, axes=(0, 0)).reshape((u_dim, u_dim)) @@ -247,7 +247,7 @@ def _forward_pass( dynamics: Dynamics, ks: List[np.ndarray], Ks: List[np.ndarray], - alpha: float + alpha: float, ) -> Tuple[Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], Sequence[Dict[str, Any]]]: x_hat = trajectory[0][0] new_trajectory = [] @@ -277,9 +277,9 @@ def _compute_cost( def _expected_cost_reduction(self, ks, Qus, Quus, alpha) -> float: delta_J = 0.0 - for (k, Qu, Quu) in zip(ks, Qus, Quus): + for k, Qu, Quu in zip(ks, Qus, Quus): linear_part = alpha * k.T.dot(Qu) - squared_part = 0.5 * (alpha ** 2.0) * k.T.dot(Quu).dot(k) + squared_part = 0.5 * (alpha**2.0) * k.T.dot(Quu).dot(k) delta_J += float(linear_part) + float(squared_part) return delta_J @@ -287,16 +287,16 @@ def _is_positive_definite(self, symmetric_matrix: np.ndarray): return np.all(np.linalg.eigvals(symmetric_matrix) > 0.0) def _before_training_start(self, env_or_buffer): - raise NotImplementedError('You do not need training to use this algorithm.') + raise NotImplementedError("You do not need training to use this algorithm.") def _run_online_training_iteration(self, env): - raise NotImplementedError('You do not need training to use this algorithm.') + raise NotImplementedError("You do not need training to use this algorithm.") def _run_offline_training_iteration(self, buffer): - raise NotImplementedError('You do not need training to use this algorithm.') + raise NotImplementedError("You do not need training to use this algorithm.") def _after_training_finish(self, env_or_buffer): - raise NotImplementedError('You do not need training to use this algorithm.') + raise NotImplementedError("You do not need training to use this algorithm.") def _models(self): return {} @@ -306,8 +306,9 @@ def _solvers(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = 
EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env() @property diff --git a/nnabla_rl/algorithms/ddpg.py b/nnabla_rl/algorithms/ddpg.py index 544615a9..6f504ab9 100644 --- a/nnabla_rl/algorithms/ddpg.py +++ b/nnabla_rl/algorithms/ddpg.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -74,7 +74,7 @@ class DDPGConfig(AlgorithmConfig): """ gamma: float = 0.99 - learning_rate: float = 1.0*1e-3 + learning_rate: float = 1.0 * 1e-3 batch_size: int = 100 tau: float = 0.005 start_timesteps: int = 10000 @@ -96,72 +96,76 @@ def __post_init__(self): Check set values are in valid range. """ - self._assert_between(self.gamma, 0.0, 1.0, 'gamma') - self._assert_positive(self.learning_rate, 'learning_rate') - self._assert_positive(self.batch_size, 'batch_size') - self._assert_positive(self.start_timesteps, 'start_timesteps') - self._assert_positive(self.replay_buffer_size, 'replay_buffer_size') - self._assert_positive(self.exploration_noise_sigma, 'exploration_noise_sigma') + self._assert_between(self.gamma, 0.0, 1.0, "gamma") + self._assert_positive(self.learning_rate, "learning_rate") + self._assert_positive(self.batch_size, "batch_size") + self._assert_positive(self.start_timesteps, "start_timesteps") + self._assert_positive(self.replay_buffer_size, "replay_buffer_size") + self._assert_positive(self.exploration_noise_sigma, "exploration_noise_sigma") - self._assert_positive(self.critic_unroll_steps, 'critic_unroll_steps') - self._assert_positive_or_zero(self.critic_burn_in_steps, 'critic_burn_in_steps') - self._assert_positive(self.actor_unroll_steps, 'actor_unroll_steps') - self._assert_positive_or_zero(self.actor_burn_in_steps, 'actor_burn_in_steps') + self._assert_positive(self.critic_unroll_steps, "critic_unroll_steps") + self._assert_positive_or_zero(self.critic_burn_in_steps, "critic_burn_in_steps") + self._assert_positive(self.actor_unroll_steps, "actor_unroll_steps") + self._assert_positive_or_zero(self.actor_burn_in_steps, "actor_burn_in_steps") class DefaultCriticBuilder(ModelBuilder[QFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: DDPGConfig, - **kwargs) -> QFunction: - target_policy = kwargs.get('target_policy') + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: DDPGConfig, + **kwargs, + ) -> QFunction: + target_policy = kwargs.get("target_policy") return TD3QFunction(scope_name, optimal_policy=target_policy) class DefaultActorBuilder(ModelBuilder[DeterministicPolicy]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: DDPGConfig, - **kwargs) -> DeterministicPolicy: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: DDPGConfig, + **kwargs, + ) -> DeterministicPolicy: max_action_value = float(env_info.action_high[0]) return TD3Policy(scope_name, env_info.action_dim, max_action_value=max_action_value) 
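The DDPG hunks below repeatedly call sync_model(self._q, self._target_q, tau=self._config.tau) and sync_model(self._pi, self._target_pi, tau=1.0). A minimal numpy sketch of the soft (Polyak) target update this presumably performs is given here for reference; the helper name and the dict-of-arrays parameter representation are assumptions for illustration, not nnabla-rl's API.

import numpy as np

def soft_update(target_params, source_params, tau=0.005):
    # Polyak averaging: target <- tau * source + (1 - tau) * target.
    # tau=1.0 reduces to a hard copy, matching the initial sync in the diff.
    for name, src in source_params.items():
        target_params[name] = tau * src + (1.0 - tau) * target_params[name]
    return target_params

# Usage sketch with plain arrays standing in for model parameters.
target = {"w": np.zeros(3)}
source = {"w": np.ones(3)}
target = soft_update(target, source, tau=0.005)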
class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: DDPGConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: DDPGConfig, **kwargs + ) -> nn.solver.Solver: return NS.Adam(alpha=algorithm_config.learning_rate) class DefaultReplayBufferBuilder(ReplayBufferBuilder): - def build_replay_buffer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: DDPGConfig, - **kwargs) -> ReplayBuffer: + def build_replay_buffer( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: DDPGConfig, **kwargs + ) -> ReplayBuffer: return ReplayBuffer(capacity=algorithm_config.replay_buffer_size) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: DDPGConfig, - algorithm: "DDPG", - **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: DDPGConfig, + algorithm: "DDPG", + **kwargs, + ) -> EnvironmentExplorer: explorer_config = EE.GaussianExplorerConfig( warmup_random_steps=algorithm_config.start_timesteps, initial_step_num=algorithm.iteration_num, timelimit_as_terminal=False, action_clip_low=env_info.action_low, action_clip_high=env_info.action_high, - sigma=algorithm_config.exploration_noise_sigma + sigma=algorithm_config.exploration_noise_sigma, + ) + explorer = EE.GaussianExplorer( + policy_action_selector=algorithm._exploration_action_selector, env_info=env_info, config=explorer_config ) - explorer = EE.GaussianExplorer(policy_action_selector=algorithm._exploration_action_selector, - env_info=env_info, - config=explorer_config) return explorer @@ -214,14 +218,17 @@ class DDPG(Algorithm): _policy_trainer_state: Dict[str, Any] _q_function_trainer_state: Dict[str, Any] - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: DDPGConfig = DDPGConfig(), - critic_builder: ModelBuilder[QFunction] = DefaultCriticBuilder(), - critic_solver_builder: SolverBuilder = DefaultSolverBuilder(), - actor_builder: ModelBuilder[DeterministicPolicy] = DefaultActorBuilder(), - actor_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: DDPGConfig = DDPGConfig(), + critic_builder: ModelBuilder[QFunction] = DefaultCriticBuilder(), + critic_solver_builder: SolverBuilder = DefaultSolverBuilder(), + actor_builder: ModelBuilder[DeterministicPolicy] = DefaultActorBuilder(), + actor_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): super(DDPG, self).__init__(env_or_env_info, config=config) self._explorer_builder = explorer_builder @@ -229,7 +236,7 @@ def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): self._q = critic_builder(scope_name="q", env_info=self._env_info, algorithm_config=self._config) self._q_solver = critic_solver_builder(env_info=self._env_info, algorithm_config=self._config) - self._target_q = self._q.deepcopy('target_' + self._q.scope_name) + 
self._target_q = self._q.deepcopy("target_" + self._q.scope_name) self._pi = actor_builder(scope_name="pi", env_info=self._env_info, algorithm_config=self._config) self._pi_solver = actor_solver_builder(env_info=self._env_info, algorithm_config=self._config) @@ -258,12 +265,13 @@ def _setup_environment_explorer(self, env_or_buffer): def _setup_q_function_training(self, env_or_buffer): q_function_trainer_config = MT.q_value_trainers.DDPGQTrainerConfig( - reduction_method='mean', + reduction_method="mean", grad_clip=None, num_steps=self._config.num_steps, unroll_steps=self._config.critic_unroll_steps, burn_in_steps=self._config.critic_burn_in_steps, - reset_on_terminal=self._config.critic_reset_rnn_on_terminal) + reset_on_terminal=self._config.critic_reset_rnn_on_terminal, + ) q_function_trainer = MT.q_value_trainers.DDPGQTrainer( train_functions=self._q, @@ -271,7 +279,8 @@ def _setup_q_function_training(self, env_or_buffer): target_functions=self._target_q, target_policy=self._target_pi, env_info=self._env_info, - config=q_function_trainer_config) + config=q_function_trainer_config, + ) sync_model(self._q, self._target_q) return q_function_trainer @@ -279,14 +288,16 @@ def _setup_policy_training(self, env_or_buffer): policy_trainer_config = MT.policy_trainers.DPGPolicyTrainerConfig( unroll_steps=self._config.actor_unroll_steps, burn_in_steps=self._config.actor_burn_in_steps, - reset_on_terminal=self._config.actor_reset_rnn_on_terminal) + reset_on_terminal=self._config.actor_reset_rnn_on_terminal, + ) policy_trainer = MT.policy_trainers.DPGPolicyTrainer( models=self._pi, solvers={self._pi.scope_name: self._pi_solver}, q_function=self._q, env_info=self._env_info, - config=policy_trainer_config) + config=policy_trainer_config, + ) sync_model(self._pi, self._target_pi, tau=1.0) return policy_trainer @@ -305,23 +316,25 @@ def _ddpg_training(self, replay_buffer): num_steps = max(actor_steps, critic_steps) experiences_tuple, info = replay_buffer.sample(self._config.batch_size, num_steps=num_steps) if num_steps == 1: - experiences_tuple = (experiences_tuple, ) + experiences_tuple = (experiences_tuple,) assert len(experiences_tuple) == num_steps batch = None for experiences in reversed(experiences_tuple): (s, a, r, non_terminal, s_next, rnn_states_dict, *_) = marshal_experiences(experiences) - rnn_states = rnn_states_dict['rnn_states'] if 'rnn_states' in rnn_states_dict else {} - batch = TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - gamma=self._config.gamma, - reward=r, - non_terminal=non_terminal, - s_next=s_next, - weight=info['weights'], - next_step_batch=batch, - rnn_states=rnn_states) + rnn_states = rnn_states_dict["rnn_states"] if "rnn_states" in rnn_states_dict else {} + batch = TrainingBatch( + batch_size=self._config.batch_size, + s_current=s, + a_current=a, + gamma=self._config.gamma, + reward=r, + non_terminal=non_terminal, + s_next=s_next, + weight=info["weights"], + next_step_batch=batch, + rnn_states=rnn_states, + ) self._q_function_trainer_state = self._q_function_trainer.train(batch) sync_model(self._q, self._target_q, tau=self._config.tau) @@ -329,7 +342,7 @@ def _ddpg_training(self, replay_buffer): self._policy_trainer_state = self._policy_trainer.train(batch) sync_model(self._pi, self._target_pi, tau=self._config.tau) - td_errors = self._q_function_trainer_state['td_errors'] + td_errors = self._q_function_trainer_state["td_errors"] replay_buffer.update_priorities(td_errors) def _evaluation_action_selector(self, s, *, begin_of_episode=False): @@ 
-357,19 +370,21 @@ def is_rnn_supported(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env() @property def latest_iteration_state(self): latest_iteration_state = super(DDPG, self).latest_iteration_state - if hasattr(self, '_policy_trainer_state'): - latest_iteration_state['scalar'].update({'pi_loss': float(self._policy_trainer_state['pi_loss'])}) - if hasattr(self, '_q_function_trainer_state'): - latest_iteration_state['scalar'].update({'q_loss': float(self._q_function_trainer_state['q_loss'])}) - latest_iteration_state['histogram'].update( - {'td_errors': self._q_function_trainer_state['td_errors'].flatten()}) + if hasattr(self, "_policy_trainer_state"): + latest_iteration_state["scalar"].update({"pi_loss": float(self._policy_trainer_state["pi_loss"])}) + if hasattr(self, "_q_function_trainer_state"): + latest_iteration_state["scalar"].update({"q_loss": float(self._q_function_trainer_state["q_loss"])}) + latest_iteration_state["histogram"].update( + {"td_errors": self._q_function_trainer_state["td_errors"].flatten()} + ) return latest_iteration_state @property diff --git a/nnabla_rl/algorithms/ddqn.py b/nnabla_rl/algorithms/ddqn.py index ddf93981..4b678777 100644 --- a/nnabla_rl/algorithms/ddqn.py +++ b/nnabla_rl/algorithms/ddqn.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,8 +18,14 @@ import gym import nnabla_rl.model_trainers as MT -from nnabla_rl.algorithms.dqn import (DQN, DefaultExplorerBuilder, DefaultQFunctionBuilder, DefaultReplayBufferBuilder, - DefaultSolverBuilder, DQNConfig) +from nnabla_rl.algorithms.dqn import ( + DQN, + DefaultExplorerBuilder, + DefaultQFunctionBuilder, + DefaultReplayBufferBuilder, + DefaultSolverBuilder, + DQNConfig, +) from nnabla_rl.builders import ModelBuilder, ReplayBufferBuilder, SolverBuilder from nnabla_rl.builders.explorer_builder import ExplorerBuilder from nnabla_rl.environments.environment_info import EnvironmentInfo @@ -55,6 +61,7 @@ class DDQNConfig(DQNConfig): test_epsilon (float): the epsilon value on testing. Defaults to 0.05. grad_clip (Optional[Tuple[float, float]]): Clip the gradient of final layer. Defaults to (-1.0, 1.0). 
""" + pass @@ -87,31 +94,40 @@ class DDQN(DQN): # See https://mypy.readthedocs.io/en/stable/class_basics.html for details _config: DDQNConfig - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: DDQNConfig = DDQNConfig(), - q_func_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), - q_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): - super(DDQN, self).__init__(env_or_env_info=env_or_env_info, - config=config, - q_func_builder=q_func_builder, - q_solver_builder=q_solver_builder, - replay_buffer_builder=replay_buffer_builder, - explorer_builder=explorer_builder) + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: DDQNConfig = DDQNConfig(), + q_func_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), + q_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): + super(DDQN, self).__init__( + env_or_env_info=env_or_env_info, + config=config, + q_func_builder=q_func_builder, + q_solver_builder=q_solver_builder, + replay_buffer_builder=replay_buffer_builder, + explorer_builder=explorer_builder, + ) def _setup_q_function_training(self, env_or_buffer): - trainer_config = MT.q_value_trainers.DDQNQTrainerConfig(num_steps=self._config.num_steps, - reduction_method='sum', - grad_clip=self._config.grad_clip, - unroll_steps=self._config.unroll_steps, - burn_in_steps=self._config.burn_in_steps, - reset_on_terminal=self._config.reset_rnn_on_terminal) - - q_function_trainer = MT.q_value_trainers.DDQNQTrainer(train_function=self._q, - solvers={self._q.scope_name: self._q_solver}, - target_function=self._target_q, - env_info=self._env_info, - config=trainer_config) + trainer_config = MT.q_value_trainers.DDQNQTrainerConfig( + num_steps=self._config.num_steps, + reduction_method="sum", + grad_clip=self._config.grad_clip, + unroll_steps=self._config.unroll_steps, + burn_in_steps=self._config.burn_in_steps, + reset_on_terminal=self._config.reset_rnn_on_terminal, + ) + + q_function_trainer = MT.q_value_trainers.DDQNQTrainer( + train_function=self._q, + solvers={self._q.scope_name: self._q_solver}, + target_function=self._target_q, + env_info=self._env_info, + config=trainer_config, + ) sync_model(self._q, self._target_q) return q_function_trainer diff --git a/nnabla_rl/algorithms/decision_transformer.py b/nnabla_rl/algorithms/decision_transformer.py index 01687d09..7da3e968 100644 --- a/nnabla_rl/algorithms/decision_transformer.py +++ b/nnabla_rl/algorithms/decision_transformer.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -54,6 +54,7 @@ class DecisionTransformerConfig(AlgorithmConfig): target_return (int): Initial target return used to compute the evaluation action. Defaults to 90. reward_scale (float): Reward scaler. Reward received during evaluation will be multiplied by this value. """ + learning_rate: float = 6.0e-4 batch_size: int = 128 context_length: int = 30 @@ -69,34 +70,34 @@ def __post_init__(self): Check set values are in valid range. 
""" - self._assert_positive(self.learning_rate, 'learning_rate') - self._assert_positive(self.batch_size, 'batch_size') - self._assert_positive(self.context_length, 'context_length') - self._assert_positive(self.grad_clip_norm, 'grad_clip_norm') + self._assert_positive(self.learning_rate, "learning_rate") + self._assert_positive(self.batch_size, "batch_size") + self._assert_positive(self.context_length, "context_length") + self._assert_positive(self.grad_clip_norm, "grad_clip_norm") if self.max_timesteps is not None: - self._assert_positive(self.max_timesteps, 'max_timesteps') - self._assert_positive(self.weight_decay, 'weight_decay') - self._assert_positive(self.target_return, 'target_return') + self._assert_positive(self.max_timesteps, "max_timesteps") + self._assert_positive(self.weight_decay, "weight_decay") + self._assert_positive(self.target_return, "target_return") class DefaultTransformerBuilder(ModelBuilder[DecisionTransformerModel]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: DecisionTransformerConfig, - **kwargs) -> DecisionTransformerModel: - max_timesteps = cast(int, kwargs['max_timesteps']) - return AtariDecisionTransformer(scope_name, - env_info.action_dim, - max_timestep=max_timesteps, - context_length=algorithm_config.context_length) + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: DecisionTransformerConfig, + **kwargs, + ) -> DecisionTransformerModel: + max_timesteps = cast(int, kwargs["max_timesteps"]) + return AtariDecisionTransformer( + scope_name, env_info.action_dim, max_timestep=max_timesteps, context_length=algorithm_config.context_length + ) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: DecisionTransformerConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: DecisionTransformerConfig, **kwargs + ) -> nn.solver.Solver: solver = NS.Adam(alpha=algorithm_config.learning_rate, beta1=0.9, beta2=0.95) return AutoClipGradByNorm(solver, algorithm_config.grad_clip_norm) @@ -128,12 +129,15 @@ class DecisionTransformer(Algorithm): _decision_transformer_trainer_state: Dict[str, Any] _action_selector: _DecisionTransformerActionSelector - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: DecisionTransformerConfig = DecisionTransformerConfig(), - transformer_builder: ModelBuilder[DecisionTransformerModel] = DefaultTransformerBuilder(), - transformer_solver_builder: SolverBuilder = DefaultSolverBuilder(), - transformer_wd_solver_builder: Optional[SolverBuilder] = None, - lr_scheduler_builder: Optional[LearningRateSchedulerBuilder] = None): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: DecisionTransformerConfig = DecisionTransformerConfig(), + transformer_builder: ModelBuilder[DecisionTransformerModel] = DefaultTransformerBuilder(), + transformer_solver_builder: SolverBuilder = DefaultSolverBuilder(), + transformer_wd_solver_builder: Optional[SolverBuilder] = None, + lr_scheduler_builder: Optional[LearningRateSchedulerBuilder] = None, + ): super(DecisionTransformer, self).__init__(env_or_env_info, config=config) if config.max_timesteps is None: assert not np.isposinf(self._env_info.max_episode_steps) @@ -141,29 +145,42 @@ def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], 
else: self._max_timesteps = config.max_timesteps with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): - self._decision_transformer = transformer_builder(scope_name='decision_transformer', - env_info=self._env_info, - algorithm_config=self._config, - max_timesteps=self._max_timesteps) + self._decision_transformer = transformer_builder( + scope_name="decision_transformer", + env_info=self._env_info, + algorithm_config=self._config, + max_timesteps=self._max_timesteps, + ) self._decision_transformer_solver = transformer_solver_builder( - env_info=self._env_info, algorithm_config=self._config) - self._decision_transformer_wd_solver = None if transformer_wd_solver_builder is None else \ - transformer_wd_solver_builder(env_info=self._env_info, algorithm_config=self._config) - self._lr_scheduler = None if lr_scheduler_builder is None else lr_scheduler_builder( - env_info=self._env_info, algorithm_config=self._config) - - self._action_selector = _DecisionTransformerActionSelector(self._env_info, - self._decision_transformer.shallowcopy(), - self._max_timesteps, - self._config.context_length, - self._config.target_return, - self._config.reward_scale) + env_info=self._env_info, algorithm_config=self._config + ) + self._decision_transformer_wd_solver = ( + None + if transformer_wd_solver_builder is None + else transformer_wd_solver_builder(env_info=self._env_info, algorithm_config=self._config) + ) + self._lr_scheduler = ( + None + if lr_scheduler_builder is None + else lr_scheduler_builder(env_info=self._env_info, algorithm_config=self._config) + ) + + self._action_selector = _DecisionTransformerActionSelector( + self._env_info, + self._decision_transformer.shallowcopy(), + self._max_timesteps, + self._config.context_length, + self._config.target_return, + self._config.reward_scale, + ) @eval_api def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): - if 'reward' not in extra_info: - raise ValueError(f'{self.__name__} requires previous reward info in addition to state to compute action.' - 'use extra_info["reward"]=reward') + if "reward" not in extra_info: + raise ValueError( + f"{self.__name__} requires previous reward info in addition to state to compute action." 
+ 'use extra_info["reward"]=reward' + ) return self._action_selector(state, begin_of_episode=begin_of_episode, extra_info=extra_info) def _before_training_start(self, env_or_buffer): @@ -174,7 +191,8 @@ def _before_training_start(self, env_or_buffer): def _setup_decision_transformer_training(self, env_or_buffer): if isinstance(self._decision_transformer, DeterministicDecisionTransformer): trainer_config = MT.dt_trainers.DeterministicDecisionTransformerTrainerConfig( - context_length=self._config.context_length) + context_length=self._config.context_length + ) solvers = {self._decision_transformer.scope_name: self._decision_transformer_solver} wd_solver = self._decision_transformer_wd_solver wd_solvers = None if wd_solver is None else {self._decision_transformer.scope_name: wd_solver} @@ -183,11 +201,13 @@ def _setup_decision_transformer_training(self, env_or_buffer): solvers=solvers, wd_solvers=wd_solvers, env_info=self._env_info, - config=trainer_config) + config=trainer_config, + ) return decision_transformer_trainer if isinstance(self._decision_transformer, StochasticDecisionTransformer): trainer_config = MT.dt_trainers.StochasticDecisionTransformerTrainerConfig( - context_length=self._config.context_length) + context_length=self._config.context_length + ) solvers = {self._decision_transformer.scope_name: self._decision_transformer_solver} wd_solver = self._decision_transformer_wd_solver wd_solvers = None if wd_solver is None else {self._decision_transformer.scope_name: wd_solver} @@ -196,13 +216,15 @@ def _setup_decision_transformer_training(self, env_or_buffer): solvers=solvers, wd_solvers=wd_solvers, env_info=self._env_info, - config=trainer_config) + config=trainer_config, + ) return decision_transformer_trainer raise NotImplementedError( - 'Unknown model type. Model should be either Deterministic/StochasticDecisionTransformer') + "Unknown model type. 
Model should be either Deterministic/StochasticDecisionTransformer" + ) def _run_online_training_iteration(self, env): - raise NotImplementedError(f'Online training is not supported for {self.__name__}') + raise NotImplementedError(f"Online training is not supported for {self.__name__}") def _run_offline_training_iteration(self, buffer): assert isinstance(buffer, TrajectoryReplayBuffer) @@ -211,13 +233,15 @@ def _decision_transformer_training(self, replay_buffer, run_epoch): if run_epoch: buffer_iterator = _TrajectoryBufferIterator( - replay_buffer, self._config.batch_size, self._config.context_length) + replay_buffer, self._config.batch_size, self._config.context_length + ) # Run 1 epoch for trajectories, info in buffer_iterator: self._decision_transformer_iteration(trajectories, info) else: - trajectories, info = replay_buffer.sample_trajectories_portion(self._config.batch_size, - self._config.context_length) + trajectories, info = replay_buffer.sample_trajectories_portion( + self._config.batch_size, self._config.context_length + ) self._decision_transformer_iteration(trajectories, info) def _decision_transformer_iteration(self, trajectories, info): @@ -229,15 +253,17 @@ def _decision_transformer_iteration(self, trajectories, info): marshaled = marshal_experiences(trajectory) experiences.append(marshaled) (s, a, _, _, _, extra, *_) = marshal_experiences(experiences) - extra['target'] = a - extra['rtg'] = extra['rtg'] # NOTE: insure that 'rtg' exists - extra['timesteps'] = extra['timesteps'][:, 0:1, :] - batch = TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - weight=info['weights'], - next_step_batch=None, - extra=extra) + extra["target"] = a + extra["rtg"] = extra["rtg"] # NOTE: ensure that 'rtg' exists + extra["timesteps"] = extra["timesteps"][:, 0:1, :] + batch = TrainingBatch( + batch_size=self._config.batch_size, + s_current=s, + a_current=a, + weight=info["weights"], + next_step_batch=None, + extra=extra, + ) self._decision_transformer_trainer_state = self._decision_transformer_trainer.train(batch) if self._lr_scheduler is not None: @@ -249,22 +275,24 @@ def _models(self): models = {} models[self._decision_transformer.scope_name] = self._decision_transformer if self._decision_transformer_wd_solver is not None: - models[f'{self._decision_transformer.scope_name}_wd'] = self._decision_transformer + models[f"{self._decision_transformer.scope_name}_wd"] = self._decision_transformer return models def _solvers(self): solvers = {} solvers[self._decision_transformer.scope_name] = self._decision_transformer_solver if self._decision_transformer_wd_solver is not None: - solvers[f'{self._decision_transformer.scope_name}_wd'] = self._decision_transformer_wd_solver + solvers[f"{self._decision_transformer.scope_name}_wd"] = self._decision_transformer_wd_solver return solvers @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance( - env_or_env_info, gym.Env) else env_or_env_info - return ((env_info.is_continuous_action_env() or env_info.is_discrete_action_env()) - and not env_info.is_tuple_action_env()) + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) + return ( + env_info.is_continuous_action_env() or env_info.is_discrete_action_env() + ) and not env_info.is_tuple_action_env() @classmethod def is_rnn_supported(self): @@ -273,9 +301,8 @@ def is_rnn_supported(self):
@property def latest_iteration_state(self): latest_iteration_state = super(DecisionTransformer, self).latest_iteration_state - if hasattr(self, '_decision_transformer_trainer_state'): - latest_iteration_state['scalar'].update( - {'loss': float(self._decision_transformer_trainer_state['loss'])}) + if hasattr(self, "_decision_transformer_trainer_state"): + latest_iteration_state["scalar"].update({"loss": float(self._decision_transformer_trainer_state["loss"])}) return latest_iteration_state @property diff --git a/nnabla_rl/algorithms/demme_sac.py b/nnabla_rl/algorithms/demme_sac.py index 9a1791ba..24194a17 100644 --- a/nnabla_rl/algorithms/demme_sac.py +++ b/nnabla_rl/algorithms/demme_sac.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -118,7 +118,7 @@ class DEMMESACConfig(AlgorithmConfig): """ gamma: float = 0.99 - learning_rate: float = 3.0*1e-4 + learning_rate: float = 3.0 * 1e-4 batch_size: int = 256 tau: float = 0.005 environment_steps: int = 1 @@ -164,27 +164,27 @@ def __post_init__(self): Check the values are in valid range. """ - self._assert_between(self.tau, 0.0, 1.0, 'tau') - self._assert_between(self.gamma, 0.0, 1.0, 'gamma') - self._assert_positive(self.gradient_steps, 'gradient_steps') - self._assert_positive(self.environment_steps, 'environment_steps') - self._assert_positive(self.start_timesteps, 'start_timesteps') - self._assert_positive(self.target_update_interval, 'target_update_interval') - self._assert_positive(self.num_rr_steps, 'num_rr_steps') - self._assert_positive(self.num_re_steps, 'num_re_steps') - - self._assert_positive(self.pi_t_unroll_steps, 'pi_t_unroll_steps') - self._assert_positive_or_zero(self.pi_t_burn_in_steps, 'pi_t_burn_in_steps') - self._assert_positive(self.pi_e_unroll_steps, 'pi_e_unroll_steps') - self._assert_positive_or_zero(self.pi_e_burn_in_steps, 'pi_e_burn_in_steps') - self._assert_positive(self.q_rr_unroll_steps, 'q_rr_unroll_steps') - self._assert_positive_or_zero(self.q_rr_burn_in_steps, 'q_rr_burn_in_steps') - self._assert_positive(self.q_re_unroll_steps, 'q_re_unroll_steps') - self._assert_positive_or_zero(self.q_re_burn_in_steps, 'q_re_burn_in_steps') - self._assert_positive(self.v_rr_unroll_steps, 'v_rr_unroll_steps') - self._assert_positive_or_zero(self.v_rr_burn_in_steps, 'v_rr_burn_in_steps') - self._assert_positive(self.v_re_unroll_steps, 'v_re_unroll_steps') - self._assert_positive_or_zero(self.v_re_burn_in_steps, 'v_re_burn_in_steps') + self._assert_between(self.tau, 0.0, 1.0, "tau") + self._assert_between(self.gamma, 0.0, 1.0, "gamma") + self._assert_positive(self.gradient_steps, "gradient_steps") + self._assert_positive(self.environment_steps, "environment_steps") + self._assert_positive(self.start_timesteps, "start_timesteps") + self._assert_positive(self.target_update_interval, "target_update_interval") + self._assert_positive(self.num_rr_steps, "num_rr_steps") + self._assert_positive(self.num_re_steps, "num_re_steps") + + self._assert_positive(self.pi_t_unroll_steps, "pi_t_unroll_steps") + self._assert_positive_or_zero(self.pi_t_burn_in_steps, "pi_t_burn_in_steps") + self._assert_positive(self.pi_e_unroll_steps, "pi_e_unroll_steps") + self._assert_positive_or_zero(self.pi_e_burn_in_steps, "pi_e_burn_in_steps") + self._assert_positive(self.q_rr_unroll_steps, "q_rr_unroll_steps") + 
self._assert_positive_or_zero(self.q_rr_burn_in_steps, "q_rr_burn_in_steps") + self._assert_positive(self.q_re_unroll_steps, "q_re_unroll_steps") + self._assert_positive_or_zero(self.q_re_burn_in_steps, "q_re_burn_in_steps") + self._assert_positive(self.v_rr_unroll_steps, "v_rr_unroll_steps") + self._assert_positive_or_zero(self.v_rr_burn_in_steps, "v_rr_burn_in_steps") + self._assert_positive(self.v_re_unroll_steps, "v_re_unroll_steps") + self._assert_positive_or_zero(self.v_re_burn_in_steps, "v_re_burn_in_steps") if self.alpha_pi is not None: # Recompute with alpha_pi @@ -192,65 +192,71 @@ def __post_init__(self): class DefaultVFunctionBuilder(ModelBuilder[VFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: DEMMESACConfig, - **kwargs) -> VFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: DEMMESACConfig, + **kwargs, + ) -> VFunction: return SACVFunction(scope_name) class DefaultQFunctionBuilder(ModelBuilder[QFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: DEMMESACConfig, - **kwargs) -> QFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: DEMMESACConfig, + **kwargs, + ) -> QFunction: return SACQFunction(scope_name) class DefaultPolicyBuilder(ModelBuilder[StochasticPolicy]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: DEMMESACConfig, - **kwargs) -> StochasticPolicy: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: DEMMESACConfig, + **kwargs, + ) -> StochasticPolicy: return SACPolicy(scope_name, env_info.action_dim) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: DEMMESACConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: DEMMESACConfig, **kwargs + ) -> nn.solver.Solver: assert isinstance(algorithm_config, DEMMESACConfig) return NS.Adam(alpha=algorithm_config.learning_rate) class DefaultReplayBufferBuilder(ReplayBufferBuilder): - def build_replay_buffer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: DEMMESACConfig, - **kwargs) -> ReplayBuffer: + def build_replay_buffer( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: DEMMESACConfig, **kwargs + ) -> ReplayBuffer: assert isinstance(algorithm_config, DEMMESACConfig) return ReplayBuffer(capacity=algorithm_config.replay_buffer_size) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: DEMMESACConfig, - algorithm: "DEMMESAC", - **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: DEMMESACConfig, + algorithm: "DEMMESAC", + **kwargs, + ) -> EnvironmentExplorer: explorer_config = EE.RawPolicyExplorerConfig( warmup_random_steps=algorithm_config.start_timesteps, reward_scalar=algorithm_config.reward_scalar, initial_step_num=algorithm.iteration_num, - timelimit_as_terminal=False + timelimit_as_terminal=False, + ) + explorer = EE.RawPolicyExplorer( + 
policy_action_selector=algorithm._exploration_action_selector, env_info=env_info, config=explorer_config ) - explorer = EE.RawPolicyExplorer(policy_action_selector=algorithm._exploration_action_selector, - env_info=env_info, - config=explorer_config) return explorer @@ -338,58 +344,69 @@ class DEMMESAC(Algorithm): _v_rr_trainer_state: Dict[str, Any] _v_re_trainer_state: Dict[str, Any] - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: DEMMESACConfig = DEMMESACConfig(), - v_rr_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), - v_rr_solver_builder: SolverBuilder = DefaultSolverBuilder(), - v_re_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), - v_re_solver_builder: SolverBuilder = DefaultSolverBuilder(), - q_rr_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), - q_rr_solver_builder: SolverBuilder = DefaultSolverBuilder(), - q_re_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), - q_re_solver_builder: SolverBuilder = DefaultSolverBuilder(), - pi_t_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - pi_t_solver_builder: SolverBuilder = DefaultSolverBuilder(), - pi_e_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - pi_e_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: DEMMESACConfig = DEMMESACConfig(), + v_rr_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), + v_rr_solver_builder: SolverBuilder = DefaultSolverBuilder(), + v_re_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), + v_re_solver_builder: SolverBuilder = DefaultSolverBuilder(), + q_rr_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), + q_rr_solver_builder: SolverBuilder = DefaultSolverBuilder(), + q_re_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), + q_re_solver_builder: SolverBuilder = DefaultSolverBuilder(), + pi_t_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + pi_t_solver_builder: SolverBuilder = DefaultSolverBuilder(), + pi_e_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + pi_e_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): super(DEMMESAC, self).__init__(env_or_env_info, config=config) self._explorer_builder = explorer_builder with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): self._v_rr = v_rr_function_builder( - scope_name="v_rr", env_info=self._env_info, algorithm_config=self._config) + scope_name="v_rr", env_info=self._env_info, algorithm_config=self._config + ) self._v_rr_solver = v_rr_solver_builder(env_info=self._env_info, algorithm_config=self._config) self._v_re = v_re_function_builder( - scope_name="v_re", env_info=self._env_info, algorithm_config=self._config) + scope_name="v_re", env_info=self._env_info, algorithm_config=self._config + ) self._v_re_solver = v_re_solver_builder(env_info=self._env_info, algorithm_config=self._config) - self._target_v_rr = self._v_rr.deepcopy('target_' + self._v_rr.scope_name) - self._target_v_re = self._v_re.deepcopy('target_' + self._v_re.scope_name) + self._target_v_rr = self._v_rr.deepcopy("target_" + 
self._v_rr.scope_name) + self._target_v_re = self._v_re.deepcopy("target_" + self._v_re.scope_name) self._q_rr1 = q_rr_function_builder( - scope_name="q_rr1", env_info=self._env_info, algorithm_config=self._config) + scope_name="q_rr1", env_info=self._env_info, algorithm_config=self._config + ) self._q_rr2 = q_rr_function_builder( - scope_name="q_rr2", env_info=self._env_info, algorithm_config=self._config) + scope_name="q_rr2", env_info=self._env_info, algorithm_config=self._config + ) self._train_q_rr_functions = [self._q_rr1, self._q_rr2] self._train_q_rr_solvers = {} for q in self._train_q_rr_functions: self._train_q_rr_solvers[q.scope_name] = q_rr_solver_builder( - env_info=self._env_info, algorithm_config=self._config) + env_info=self._env_info, algorithm_config=self._config + ) self._q_re1 = q_re_function_builder( - scope_name="q_re1", env_info=self._env_info, algorithm_config=self._config) + scope_name="q_re1", env_info=self._env_info, algorithm_config=self._config + ) self._q_re2 = q_re_function_builder( - scope_name="q_re2", env_info=self._env_info, algorithm_config=self._config) + scope_name="q_re2", env_info=self._env_info, algorithm_config=self._config + ) self._train_q_re_functions = [self._q_re1, self._q_re2] self._train_q_re_solvers = {} for q in self._train_q_re_functions: self._train_q_re_solvers[q.scope_name] = q_re_solver_builder( - env_info=self._env_info, algorithm_config=self._config) + env_info=self._env_info, algorithm_config=self._config + ) self._pi_t = pi_t_builder(scope_name="pi_t", env_info=self._env_info, algorithm_config=self._config) self._pi_t_solver = pi_t_solver_builder(env_info=self._env_info, algorithm_config=self._config) @@ -400,9 +417,11 @@ def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], self._replay_buffer = replay_buffer_builder(env_info=self._env_info, algorithm_config=self._config) self._evaluation_actor = _StochasticPolicyActionSelector( - self._env_info, self._pi_t.shallowcopy(), deterministic=True) + self._env_info, self._pi_t.shallowcopy(), deterministic=True + ) self._exploration_actor = _StochasticPolicyActionSelector( - self._env_info, self._pi_t.shallowcopy(), deterministic=False) + self._env_info, self._pi_t.shallowcopy(), deterministic=False + ) @eval_api def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): @@ -428,14 +447,16 @@ def _setup_pi_t_training(self, env_or_buffer): policy_trainer_config = MT.policy_trainers.DEMMEPolicyTrainerConfig( unroll_steps=self._config.pi_t_unroll_steps, burn_in_steps=self._config.pi_t_burn_in_steps, - reset_on_terminal=self._config.pi_t_reset_rnn_on_terminal) + reset_on_terminal=self._config.pi_t_reset_rnn_on_terminal, + ) policy_trainer = MT.policy_trainers.DEMMEPolicyTrainer( models=self._pi_t, solvers={self._pi_t.scope_name: self._pi_t_solver}, q_rr_functions=self._train_q_rr_functions, q_re_functions=self._train_q_re_functions, env_info=self._env_info, - config=policy_trainer_config) + config=policy_trainer_config, + ) return policy_trainer def _setup_pi_e_training(self, env_or_buffer): @@ -444,10 +465,11 @@ def _setup_pi_e_training(self, env_or_buffer): fixed_temperature=True, unroll_steps=self._config.pi_e_unroll_steps, burn_in_steps=self._config.pi_e_burn_in_steps, - reset_on_terminal=self._config.pi_e_reset_rnn_on_terminal) + reset_on_terminal=self._config.pi_e_reset_rnn_on_terminal, + ) temperature = MT.policy_trainers.soft_policy_trainer.AdjustableTemperature( - scope_name='temperature', - initial_value=1.0) + scope_name="temperature", 
initial_value=1.0 + ) policy_trainer = MT.policy_trainers.SoftPolicyTrainer( models=self._pi_e, solvers={self._pi_e.scope_name: self._pi_e_solver}, @@ -455,70 +477,78 @@ def _setup_pi_e_training(self, env_or_buffer): temperature=temperature, temperature_solver=None, env_info=self._env_info, - config=policy_trainer_config) + config=policy_trainer_config, + ) return policy_trainer def _setup_q_rr_training(self, env_or_buffer): q_rr_trainer_param = MT.q_value_trainers.VTargetedQTrainerConfig( - reduction_method='mean', + reduction_method="mean", q_loss_scalar=0.5, num_steps=self._config.num_rr_steps, unroll_steps=self._config.q_rr_unroll_steps, burn_in_steps=self._config.q_rr_burn_in_steps, - reset_on_terminal=self._config.q_rr_reset_rnn_on_terminal) + reset_on_terminal=self._config.q_rr_reset_rnn_on_terminal, + ) q_rr_trainer = MT.q_value_trainers.VTargetedQTrainer( train_functions=self._train_q_rr_functions, solvers=self._train_q_rr_solvers, target_functions=self._target_v_rr, env_info=self._env_info, - config=q_rr_trainer_param) + config=q_rr_trainer_param, + ) return q_rr_trainer def _setup_q_re_training(self, env_or_buffer): q_re_trainer_param = MT.q_value_trainers.VTargetedQTrainerConfig( - reduction_method='mean', + reduction_method="mean", q_loss_scalar=0.5, num_steps=self._config.num_re_steps, unroll_steps=self._config.q_re_unroll_steps, burn_in_steps=self._config.q_re_burn_in_steps, reset_on_terminal=self._config.q_re_reset_rnn_on_terminal, - pure_exploration=True) + pure_exploration=True, + ) q_re_trainer = MT.q_value_trainers.VTargetedQTrainer( train_functions=self._train_q_re_functions, solvers=self._train_q_re_solvers, target_functions=self._target_v_re, env_info=self._env_info, - config=q_re_trainer_param) + config=q_re_trainer_param, + ) return q_re_trainer def _setup_v_rr_training(self, env_or_buffer): v_rr_trainer_config = MT.v_value_trainers.DEMMEVTrainerConfig( - reduction_method='mean', + reduction_method="mean", v_loss_scalar=0.5, unroll_steps=self._config.v_rr_unroll_steps, burn_in_steps=self._config.v_rr_burn_in_steps, - reset_on_terminal=self._config.v_rr_reset_rnn_on_terminal) + reset_on_terminal=self._config.v_rr_reset_rnn_on_terminal, + ) v_rr_trainer = MT.v_value_trainers.DEMMEVTrainer( train_functions=self._v_rr, solvers={self._v_rr.scope_name: self._v_rr_solver}, target_functions=self._train_q_rr_functions, # Set training q_rr as target target_policy=self._pi_t, env_info=self._env_info, - config=v_rr_trainer_config) + config=v_rr_trainer_config, + ) sync_model(self._v_rr, self._target_v_rr, 1.0) return v_rr_trainer def _setup_v_re_training(self, env_or_buffer): alpha_q = MT.policy_trainers.soft_policy_trainer.AdjustableTemperature( - scope_name='alpha_q', - initial_value=self._config.alpha_q) + scope_name="alpha_q", initial_value=self._config.alpha_q + ) v_re_trainer_config = MT.v_value_trainers.MMEVTrainerConfig( - reduction_method='mean', + reduction_method="mean", v_loss_scalar=0.5, unroll_steps=self._config.v_re_unroll_steps, burn_in_steps=self._config.v_re_burn_in_steps, - reset_on_terminal=self._config.v_re_reset_rnn_on_terminal) + reset_on_terminal=self._config.v_re_reset_rnn_on_terminal, + ) v_re_trainer = MT.v_value_trainers.MMEVTrainer( train_functions=self._v_re, temperature=alpha_q, @@ -526,7 +556,8 @@ def _setup_v_re_training(self, env_or_buffer): target_functions=self._train_q_re_functions, # Set training q_re as target target_policy=self._pi_e, env_info=self._env_info, - config=v_re_trainer_config) + config=v_re_trainer_config, + ) 
sync_model(self._v_re, self._target_v_re, 1.0) return v_re_trainer @@ -561,23 +592,25 @@ def _demme_training(self, replay_buffer): num_steps = max(pi_steps, max(q_steps, v_steps)) experiences_tuple, info = replay_buffer.sample(self._config.batch_size, num_steps=num_steps) if num_steps == 1: - experiences_tuple = (experiences_tuple, ) + experiences_tuple = (experiences_tuple,) assert len(experiences_tuple) == num_steps batch = None for experiences in reversed(experiences_tuple): (s, a, r, non_terminal, s_next, rnn_states_dict, *_) = marshal_experiences(experiences) - rnn_states = rnn_states_dict['rnn_states'] if 'rnn_states' in rnn_states_dict else {} - batch = TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - gamma=self._config.gamma, - reward=r, - non_terminal=non_terminal, - s_next=s_next, - weight=info['weights'], - next_step_batch=batch, - rnn_states=rnn_states) + rnn_states = rnn_states_dict["rnn_states"] if "rnn_states" in rnn_states_dict else {} + batch = TrainingBatch( + batch_size=self._config.batch_size, + s_current=s, + a_current=a, + gamma=self._config.gamma, + reward=r, + non_terminal=non_terminal, + s_next=s_next, + weight=info["weights"], + next_step_batch=batch, + rnn_states=rnn_states, + ) # Train in the order of v -> q -> policy self._v_rr_trainer_state = self._v_rr_trainer.train(batch) @@ -591,7 +624,7 @@ def _demme_training(self, replay_buffer): self._pi_e_trainer_state = self._pi_e_trainer.train(batch) # Use q_rr's td error - td_errors = self._q_rr_trainer_state['td_errors'] + td_errors = self._q_rr_trainer_state["td_errors"] replay_buffer.update_priorities(td_errors) def _evaluation_action_selector(self, s, *, begin_of_episode=False): @@ -601,9 +634,18 @@ def _exploration_action_selector(self, s, *, begin_of_episode=False): return self._exploration_actor(s, begin_of_episode=begin_of_episode) def _models(self): - models = [self._v_rr, self._v_re, self._target_v_rr, self._target_v_re, - self._q_rr1, self._q_rr2, self._q_re1, self._q_re2, - self._pi_t, self._pi_e] + models = [ + self._v_rr, + self._v_re, + self._target_v_rr, + self._target_v_re, + self._q_rr1, + self._q_rr2, + self._q_re1, + self._q_re2, + self._pi_t, + self._pi_e, + ] return {model.scope_name: model for model in models} def _solvers(self): @@ -622,29 +664,32 @@ def is_rnn_supported(cls): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env() @property def latest_iteration_state(self): latest_iteration_state = super(DEMMESAC, self).latest_iteration_state - if hasattr(self, '_pi_t_trainer_state'): - latest_iteration_state['scalar'].update({'pi_t_loss': float(self._pi_t_trainer_state['pi_loss'])}) - if hasattr(self, '_pi_e_trainer_state'): - latest_iteration_state['scalar'].update({'pi_e_loss': float(self._pi_e_trainer_state['pi_loss'])}) - if hasattr(self, '_v_rr_trainer_state'): - latest_iteration_state['scalar'].update({'v_re_loss': float(self._v_re_trainer_state['v_loss'])}) - if hasattr(self, '_v_re_trainer_state'): - latest_iteration_state['scalar'].update({'v_rr_loss': float(self._v_rr_trainer_state['v_loss'])}) - if hasattr(self, '_q_rr_trainer_state'): - latest_iteration_state['scalar'].update({'q_rr_loss': 
float(self._q_rr_trainer_state['q_loss'])}) - latest_iteration_state['histogram'].update( - {'q_rr_td_errors': self._q_rr_trainer_state['td_errors'].flatten()}) - if hasattr(self, '_q_re_trainer_state'): - latest_iteration_state['scalar'].update({'q_re_loss': float(self._q_re_trainer_state['q_loss'])}) - latest_iteration_state['histogram'].update( - {'q_re_td_errors': self._q_re_trainer_state['td_errors'].flatten()}) + if hasattr(self, "_pi_t_trainer_state"): + latest_iteration_state["scalar"].update({"pi_t_loss": float(self._pi_t_trainer_state["pi_loss"])}) + if hasattr(self, "_pi_e_trainer_state"): + latest_iteration_state["scalar"].update({"pi_e_loss": float(self._pi_e_trainer_state["pi_loss"])}) + if hasattr(self, "_v_rr_trainer_state"): + latest_iteration_state["scalar"].update({"v_re_loss": float(self._v_re_trainer_state["v_loss"])}) + if hasattr(self, "_v_re_trainer_state"): + latest_iteration_state["scalar"].update({"v_rr_loss": float(self._v_rr_trainer_state["v_loss"])}) + if hasattr(self, "_q_rr_trainer_state"): + latest_iteration_state["scalar"].update({"q_rr_loss": float(self._q_rr_trainer_state["q_loss"])}) + latest_iteration_state["histogram"].update( + {"q_rr_td_errors": self._q_rr_trainer_state["td_errors"].flatten()} + ) + if hasattr(self, "_q_re_trainer_state"): + latest_iteration_state["scalar"].update({"q_re_loss": float(self._q_re_trainer_state["q_loss"])}) + latest_iteration_state["histogram"].update( + {"q_re_td_errors": self._q_re_trainer_state["td_errors"].flatten()} + ) return latest_iteration_state @property diff --git a/nnabla_rl/algorithms/dqn.py b/nnabla_rl/algorithms/dqn.py index 413db909..04a02cfa 100644 --- a/nnabla_rl/algorithms/dqn.py +++ b/nnabla_rl/algorithms/dqn.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -74,6 +74,7 @@ class DQNConfig(AlgorithmConfig): This flag does not take effect if given model is not an RNN model. Defaults to False. """ + gamma: float = 0.99 learning_rate: float = 2.5e-4 batch_size: int = 32 @@ -100,72 +101,76 @@ def __post_init__(self): Check set values are in valid range. 
""" - self._assert_between(self.gamma, 0.0, 1.0, 'gamma') - self._assert_positive(self.learning_rate, 'learning_rate') - self._assert_positive(self.batch_size, 'batch_size') - self._assert_positive(self.num_steps, 'num_steps') - self._assert_positive(self.learner_update_frequency, 'learner_update_frequency') - self._assert_positive(self.target_update_frequency, 'target_update_frequency') - self._assert_positive(self.start_timesteps, 'start_timesteps') - self._assert_positive(self.replay_buffer_size, 'replay_buffer_size') - self._assert_smaller_than(self.start_timesteps, self.replay_buffer_size, 'start_timesteps') - self._assert_between(self.initial_epsilon, 0.0, 1.0, 'initial_epsilon') - self._assert_between(self.final_epsilon, 0.0, 1.0, 'final_epsilon') - self._assert_between(self.test_epsilon, 0.0, 1.0, 'test_epsilon') - self._assert_positive(self.max_explore_steps, 'max_explore_steps') - self._assert_positive(self.unroll_steps, 'unroll_steps') - self._assert_positive_or_zero(self.burn_in_steps, 'burn_in_steps') + self._assert_between(self.gamma, 0.0, 1.0, "gamma") + self._assert_positive(self.learning_rate, "learning_rate") + self._assert_positive(self.batch_size, "batch_size") + self._assert_positive(self.num_steps, "num_steps") + self._assert_positive(self.learner_update_frequency, "learner_update_frequency") + self._assert_positive(self.target_update_frequency, "target_update_frequency") + self._assert_positive(self.start_timesteps, "start_timesteps") + self._assert_positive(self.replay_buffer_size, "replay_buffer_size") + self._assert_smaller_than(self.start_timesteps, self.replay_buffer_size, "start_timesteps") + self._assert_between(self.initial_epsilon, 0.0, 1.0, "initial_epsilon") + self._assert_between(self.final_epsilon, 0.0, 1.0, "final_epsilon") + self._assert_between(self.test_epsilon, 0.0, 1.0, "test_epsilon") + self._assert_positive(self.max_explore_steps, "max_explore_steps") + self._assert_positive(self.unroll_steps, "unroll_steps") + self._assert_positive_or_zero(self.burn_in_steps, "burn_in_steps") class DefaultQFunctionBuilder(ModelBuilder[QFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: DQNConfig, - **kwargs) -> QFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: DQNConfig, + **kwargs, + ) -> QFunction: return DQNQFunction(scope_name, env_info.action_dim) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: DQNConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: DQNConfig, **kwargs + ) -> nn.solver.Solver: # this decay is equivalent to 'gradient momentum' and 'squared gradient momentum' of the nature paper decay: float = 0.95 momentum: float = 0.0 min_squared_gradient: float = 0.01 - solver = NS.RMSpropGraves(lr=algorithm_config.learning_rate, decay=decay, - momentum=momentum, eps=min_squared_gradient) + solver = NS.RMSpropGraves( + lr=algorithm_config.learning_rate, decay=decay, momentum=momentum, eps=min_squared_gradient + ) return solver class DefaultReplayBufferBuilder(ReplayBufferBuilder): - def build_replay_buffer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: DQNConfig, - **kwargs) -> ReplayBuffer: + def build_replay_buffer( # type: ignore[override] + self, env_info: EnvironmentInfo, 
algorithm_config: DQNConfig, **kwargs + ) -> ReplayBuffer: return ReplayBuffer(capacity=algorithm_config.replay_buffer_size) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: DQNConfig, - algorithm: "DQN", - **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: DQNConfig, + algorithm: "DQN", + **kwargs, + ) -> EnvironmentExplorer: explorer_config = EE.LinearDecayEpsilonGreedyExplorerConfig( warmup_random_steps=algorithm_config.start_timesteps, initial_step_num=algorithm.iteration_num, initial_epsilon=algorithm_config.initial_epsilon, final_epsilon=algorithm_config.final_epsilon, - max_explore_steps=algorithm_config.max_explore_steps + max_explore_steps=algorithm_config.max_explore_steps, ) explorer = EE.LinearDecayEpsilonGreedyExplorer( greedy_action_selector=algorithm._exploration_action_selector, random_action_selector=algorithm._random_action_selector, env_info=env_info, - config=explorer_config) + config=explorer_config, + ) return explorer @@ -212,25 +217,28 @@ class DQN(Algorithm): _evaluation_actor: _GreedyActionSelector _exploration_actor: _GreedyActionSelector - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: DQNConfig = DQNConfig(), - q_func_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), - q_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: DQNConfig = DQNConfig(), + q_func_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), + q_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): super(DQN, self).__init__(env_or_env_info, config=config) self._explorer_builder = explorer_builder with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): - self._q = q_func_builder(scope_name='q', env_info=self._env_info, algorithm_config=self._config) + self._q = q_func_builder(scope_name="q", env_info=self._env_info, algorithm_config=self._config) self._q_solver = q_solver_builder(env_info=self._env_info, algorithm_config=self._config) - self._target_q = self._q.deepcopy('target_' + self._q.scope_name) + self._target_q = self._q.deepcopy("target_" + self._q.scope_name) self._replay_buffer = replay_buffer_builder(env_info=self._env_info, algorithm_config=self._config) - self._environment_explorer = explorer_builder(env_info=self._env_info, - algorithm_config=self._config, - algorithm=self) + self._environment_explorer = explorer_builder( + env_info=self._env_info, algorithm_config=self._config, algorithm=self + ) self._evaluation_actor = _GreedyActionSelector(self._env_info, self._q.shallowcopy()) self._exploration_actor = _GreedyActionSelector(self._env_info, self._q.shallowcopy()) @@ -238,11 +246,13 @@ def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], @eval_api def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): - (action, _), _ = epsilon_greedy_action_selection(state, - self._evaluation_action_selector, - self._random_action_selector, - 
epsilon=self._config.test_epsilon, - begin_of_episode=begin_of_episode) + (action, _), _ = epsilon_greedy_action_selection( + state, + self._evaluation_action_selector, + self._random_action_selector, + epsilon=self._config.test_epsilon, + begin_of_episode=begin_of_episode, + ) return action def _before_training_start(self, env_or_buffer): @@ -257,18 +267,20 @@ def _setup_environment_explorer(self, env_or_buffer): def _setup_q_function_training(self, env_or_buffer): trainer_config = MT.q_value_trainers.DQNQTrainerConfig( num_steps=self._config.num_steps, - reduction_method='sum', + reduction_method="sum", grad_clip=self._config.grad_clip, unroll_steps=self._config.unroll_steps, burn_in_steps=self._config.burn_in_steps, - reset_on_terminal=self._config.reset_rnn_on_terminal) + reset_on_terminal=self._config.reset_rnn_on_terminal, + ) q_function_trainer = MT.q_value_trainers.DQNQTrainer( train_functions=self._q, solvers={self._q.scope_name: self._q_solver}, target_function=self._target_q, env_info=self._env_info, - config=trainer_config) + config=trainer_config, + ) sync_model(self._q, self._target_q) return q_function_trainer @@ -290,35 +302,37 @@ def _exploration_action_selector(self, s, *, begin_of_episode=False): def _random_action_selector(self, s, *, begin_of_episode=False): action = self._env_info.action_space.sample() - return np.asarray(action).reshape((1, )), {} + return np.asarray(action).reshape((1,)), {} def _dqn_training(self, replay_buffer): num_steps = self._config.num_steps + self._config.burn_in_steps + self._config.unroll_steps - 1 experiences_tuple, info = replay_buffer.sample(self._config.batch_size, num_steps=num_steps) if num_steps == 1: - experiences_tuple = (experiences_tuple, ) + experiences_tuple = (experiences_tuple,) assert len(experiences_tuple) == num_steps batch = None for experiences in reversed(experiences_tuple): (s, a, r, non_terminal, s_next, rnn_states_dict, *_) = marshal_experiences(experiences) - rnn_states = rnn_states_dict['rnn_states'] if 'rnn_states' in rnn_states_dict else {} - batch = TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - gamma=self._config.gamma, - reward=r, - non_terminal=non_terminal, - s_next=s_next, - weight=info['weights'], - next_step_batch=batch, - rnn_states=rnn_states) + rnn_states = rnn_states_dict["rnn_states"] if "rnn_states" in rnn_states_dict else {} + batch = TrainingBatch( + batch_size=self._config.batch_size, + s_current=s, + a_current=a, + gamma=self._config.gamma, + reward=r, + non_terminal=non_terminal, + s_next=s_next, + weight=info["weights"], + next_step_batch=batch, + rnn_states=rnn_states, + ) self._q_function_trainer_state = self._q_function_trainer.train(batch) if self.iteration_num % self._config.target_update_frequency == 0: sync_model(self._q, self._target_q) - td_errors = self._q_function_trainer_state['td_errors'] + td_errors = self._q_function_trainer_state["td_errors"] replay_buffer.update_priorities(td_errors) def _models(self): @@ -333,8 +347,9 @@ def _solvers(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_continuous_action_env() and not env_info.is_tuple_action_env() @classmethod @@ -344,10 +359,11 @@ def is_rnn_supported(self): @property def latest_iteration_state(self): 
latest_iteration_state = super(DQN, self).latest_iteration_state - if hasattr(self, '_q_function_trainer_state'): - latest_iteration_state['scalar'].update({'q_loss': float(self._q_function_trainer_state['q_loss'])}) - latest_iteration_state['histogram'].update( - {'td_errors': self._q_function_trainer_state['td_errors'].flatten()}) + if hasattr(self, "_q_function_trainer_state"): + latest_iteration_state["scalar"].update({"q_loss": float(self._q_function_trainer_state["q_loss"])}) + latest_iteration_state["histogram"].update( + {"td_errors": self._q_function_trainer_state["td_errors"].flatten()} + ) return latest_iteration_state @property diff --git a/nnabla_rl/algorithms/drqn.py b/nnabla_rl/algorithms/drqn.py index a57feca9..e084adf0 100644 --- a/nnabla_rl/algorithms/drqn.py +++ b/nnabla_rl/algorithms/drqn.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -54,10 +54,9 @@ class DRQNConfig(DQNConfig): class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: DRQNConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: DRQNConfig, **kwargs + ) -> nn.solver.Solver: decay: float = 0.95 solver = NS.Adadelta(lr=algorithm_config.learning_rate, decay=decay) solver = AutoClipGradByNorm(solver, algorithm_config.clip_grad_norm) @@ -65,11 +64,13 @@ def build_solver(self, # type: ignore[override] class DefaultQFunctionBuilder(ModelBuilder[QFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: DRQNConfig, - **kwargs) -> QFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: DRQNConfig, + **kwargs, + ) -> QFunction: return DRQNQFunction(scope_name, env_info.action_dim) @@ -99,33 +100,40 @@ class DRQN(DQN): # See https://mypy.readthedocs.io/en/stable/class_basics.html for details _config: DRQNConfig - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: DRQNConfig = DRQNConfig(), - q_func_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), - q_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): - super(DRQN, self).__init__(env_or_env_info, - config=config, - q_func_builder=q_func_builder, - q_solver_builder=q_solver_builder, - replay_buffer_builder=replay_buffer_builder, - explorer_builder=explorer_builder) + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: DRQNConfig = DRQNConfig(), + q_func_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), + q_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): + super(DRQN, self).__init__( + env_or_env_info, + config=config, + q_func_builder=q_func_builder, + q_solver_builder=q_solver_builder, + replay_buffer_builder=replay_buffer_builder, + explorer_builder=explorer_builder, + ) def _setup_q_function_training(self, env_or_buffer): trainer_config = 
MT.q_value_trainers.DQNQTrainerConfig( num_steps=self._config.num_steps, - reduction_method='mean', # This parameter is different from DQN + reduction_method="mean", # This parameter is different from DQN grad_clip=self._config.grad_clip, unroll_steps=self._config.unroll_steps, burn_in_steps=self._config.burn_in_steps, - reset_on_terminal=self._config.reset_rnn_on_terminal) + reset_on_terminal=self._config.reset_rnn_on_terminal, + ) q_function_trainer = MT.q_value_trainers.DQNQTrainer( train_functions=self._q, solvers={self._q.scope_name: self._q_solver}, target_function=self._target_q, env_info=self._env_info, - config=trainer_config) + config=trainer_config, + ) sync_model(self._q, self._target_q) return q_function_trainer diff --git a/nnabla_rl/algorithms/dummy.py b/nnabla_rl/algorithms/dummy.py index 1e8484e1..bb7c206c 100644 --- a/nnabla_rl/algorithms/dummy.py +++ b/nnabla_rl/algorithms/dummy.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -43,12 +43,10 @@ def _before_training_start(self, env_or_buffer): logger.debug("Before training start!! Write your algorithm's initializations here.") def _run_online_training_iteration(self, env): - logger.debug("Running online training loop. Iteartion: {}".format( - self.iteration_num)) + logger.debug("Running online training loop. Iteration: {}".format(self.iteration_num)) def _run_offline_training_iteration(self, buffer): - logger.debug("Running offline training loop. Iteartion: {}".format( - self.iteration_num)) + logger.debug("Running offline training loop. Iteration: {}".format(self.iteration_num)) def _after_training_finish(self, env_or_buffer): logger.debug("Training finished. Do your algorithm's finalizations here.") diff --git a/nnabla_rl/algorithms/gail.py b/nnabla_rl/algorithms/gail.py index a8998c9b..841a52ff 100644 --- a/nnabla_rl/algorithms/gail.py +++ b/nnabla_rl/algorithms/gail.py @@ -1,5 +1,5 @@ # Copyright 2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
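The dqn.py, drqn.py, and demme_sac.py hunks above only reshape the config range checks and constructor signatures into black's one-argument-per-line layout; the keyword-argument API itself is unchanged. As a quick sanity sketch of that API (added here for illustration, not part of the diff), the snippet below builds a DQNConfig with values that satisfy the reformatted __post_init__ checks and shows one value that does not. It assumes nnabla-rl is installed and that DQNConfig is exposed through the nnabla_rl.algorithms namespace used elsewhere in this repository; the exact exception type raised by the _assert_* helpers is not visible in this diff, so the example catches broadly.

import nnabla_rl.algorithms as A

# Values inside the ranges enforced by DQNConfig.__post_init__ (gamma in [0, 1], positive batch_size, ...).
config = A.DQNConfig(gamma=0.99, learning_rate=2.5e-4, batch_size=32)
print(config.gamma, config.batch_size)

# Out-of-range values are rejected by the reformatted _assert_between check;
# the exception type is not shown in this diff, so catch broadly for illustration.
try:
    A.DQNConfig(gamma=1.5)
except Exception as error:
    print("rejected:", error)

Because the changes in these hunks are purely syntactic (exploded argument lists and trailing commas), a snippet like this behaves identically before and after the reformat.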
@@ -27,15 +27,26 @@ import nnabla_rl.model_trainers as MT import nnabla_rl.preprocessors as RP from nnabla_rl.algorithm import Algorithm, AlgorithmConfig, eval_api -from nnabla_rl.algorithms.common_utils import (_StatePreprocessedRewardFunction, _StatePreprocessedStochasticPolicy, - _StatePreprocessedVFunction, _StochasticPolicyActionSelector, - compute_v_target_and_advantage) +from nnabla_rl.algorithms.common_utils import ( + _StatePreprocessedRewardFunction, + _StatePreprocessedStochasticPolicy, + _StatePreprocessedVFunction, + _StochasticPolicyActionSelector, + compute_v_target_and_advantage, +) from nnabla_rl.builders import ExplorerBuilder, ModelBuilder, PreprocessorBuilder, SolverBuilder from nnabla_rl.environment_explorer import EnvironmentExplorer from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import ModelTrainer, TrainingBatch -from nnabla_rl.models import (GAILDiscriminator, GAILPolicy, GAILVFunction, Model, RewardFunction, StochasticPolicy, - VFunction) +from nnabla_rl.models import ( + GAILDiscriminator, + GAILPolicy, + GAILVFunction, + Model, + RewardFunction, + StochasticPolicy, + VFunction, +) from nnabla_rl.preprocessors import Preprocessor from nnabla_rl.replay_buffer import ReplayBuffer from nnabla_rl.replay_buffers.buffer_iterator import BufferIterator @@ -113,90 +124,97 @@ def __post_init__(self): Check the values are in valid range. """ - self._assert_between(self.pi_batch_size, 0, self.num_steps_per_iteration, 'pi_batch_size') + self._assert_between(self.pi_batch_size, 0, self.num_steps_per_iteration, "pi_batch_size") self._assert_positive(self.discriminator_learning_rate, "discriminator_learning_rate") self._assert_positive(self.discriminator_batch_size, "discriminator_batch_size") self._assert_positive(self.policy_update_frequency, "policy_update_frequency") self._assert_positive(self.discriminator_update_frequency, "discriminator_update_frequency") self._assert_positive(self.adversary_entropy_coef, "adversarial_entropy_coef") - self._assert_between(self.gamma, 0.0, 1.0, 'gamma') - self._assert_between(self.lmb, 0.0, 1.0, 'lmb') - self._assert_positive(self.num_steps_per_iteration, 'num_steps_per_iteration') - self._assert_positive(self.sigma_kl_divergence_constraint, 'sigma_kl_divergence_constraint') - self._assert_positive(self.maximum_backtrack_numbers, 'maximum_backtrack_numbers') - self._assert_positive(self.conjugate_gradient_damping, 'conjugate_gradient_damping') - self._assert_positive(self.conjugate_gradient_iterations, 'conjugate_gradient_iterations') - self._assert_positive(self.vf_epochs, 'vf_epochs') - self._assert_positive(self.vf_batch_size, 'vf_batch_size') - self._assert_positive(self.vf_learning_rate, 'vf_learning_rate') + self._assert_between(self.gamma, 0.0, 1.0, "gamma") + self._assert_between(self.lmb, 0.0, 1.0, "lmb") + self._assert_positive(self.num_steps_per_iteration, "num_steps_per_iteration") + self._assert_positive(self.sigma_kl_divergence_constraint, "sigma_kl_divergence_constraint") + self._assert_positive(self.maximum_backtrack_numbers, "maximum_backtrack_numbers") + self._assert_positive(self.conjugate_gradient_damping, "conjugate_gradient_damping") + self._assert_positive(self.conjugate_gradient_iterations, "conjugate_gradient_iterations") + self._assert_positive(self.vf_epochs, "vf_epochs") + self._assert_positive(self.vf_batch_size, "vf_batch_size") + self._assert_positive(self.vf_learning_rate, "vf_learning_rate") class 
DefaultPreprocessorBuilder(PreprocessorBuilder): - def build_preprocessor(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: GAILConfig, - **kwargs) -> Preprocessor: + def build_preprocessor( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: GAILConfig, + **kwargs, + ) -> Preprocessor: return RP.RunningMeanNormalizer(scope_name, env_info.state_shape, value_clip=(-5.0, 5.0)) class DefaultPolicyBuilder(ModelBuilder[StochasticPolicy]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: GAILConfig, - **kwargs) -> StochasticPolicy: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: GAILConfig, + **kwargs, + ) -> StochasticPolicy: return GAILPolicy(scope_name, env_info.action_dim) class DefaultVFunctionBuilder(ModelBuilder[VFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: GAILConfig, - **kwargs) -> VFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: GAILConfig, + **kwargs, + ) -> VFunction: return GAILVFunction(scope_name) class DefaultRewardFunctionBuilder(ModelBuilder[RewardFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: GAILConfig, - **kwargs) -> RewardFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: GAILConfig, + **kwargs, + ) -> RewardFunction: return GAILDiscriminator(scope_name) class DefaultVFunctionSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: GAILConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: GAILConfig, **kwargs + ) -> nn.solver.Solver: return NS.Adam(alpha=algorithm_config.vf_learning_rate) class DefaultRewardFunctionSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: GAILConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: GAILConfig, **kwargs + ) -> nn.solver.Solver: assert isinstance(algorithm_config, GAILConfig) return NS.Adam(alpha=algorithm_config.discriminator_learning_rate) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: GAILConfig, - algorithm: "GAIL", - **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: GAILConfig, + algorithm: "GAIL", + **kwargs, + ) -> EnvironmentExplorer: explorer_config = EE.RawPolicyExplorerConfig( - initial_step_num=algorithm.iteration_num, - timelimit_as_terminal=False + initial_step_num=algorithm.iteration_num, timelimit_as_terminal=False + ) + explorer = EE.RawPolicyExplorer( + policy_action_selector=algorithm._exploration_action_selector, env_info=env_info, config=explorer_config ) - explorer = EE.RawPolicyExplorer(policy_action_selector=algorithm._exploration_action_selector, - env_info=env_info, - config=explorer_config) return explorer @@ -255,16 +273,19 @@ class GAIL(Algorithm): 
_evaluation_actor: _StochasticPolicyActionSelector _exploration_actor: _StochasticPolicyActionSelector - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - expert_buffer: ReplayBuffer, - config: GAILConfig = GAILConfig(), - v_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), - v_solver_builder: SolverBuilder = DefaultVFunctionSolverBuilder(), - policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - reward_function_builder: ModelBuilder[RewardFunction] = DefaultRewardFunctionBuilder(), - reward_solver_builder: SolverBuilder = DefaultRewardFunctionSolverBuilder(), - state_preprocessor_builder: Optional[PreprocessorBuilder] = DefaultPreprocessorBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + expert_buffer: ReplayBuffer, + config: GAILConfig = GAILConfig(), + v_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), + v_solver_builder: SolverBuilder = DefaultVFunctionSolverBuilder(), + policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + reward_function_builder: ModelBuilder[RewardFunction] = DefaultRewardFunctionBuilder(), + reward_solver_builder: SolverBuilder = DefaultRewardFunctionSolverBuilder(), + state_preprocessor_builder: Optional[PreprocessorBuilder] = DefaultPreprocessorBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): super(GAIL, self).__init__(env_or_env_info, config=config) self._explorer_builder = explorer_builder @@ -276,13 +297,14 @@ def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], if self._config.preprocess_state: if state_preprocessor_builder is None: - raise ValueError('State preprocessing is enabled but no preprocessor builder is given') - pi_v_preprocessor = state_preprocessor_builder('pi_v_preprocessor', self._env_info, self._config) + raise ValueError("State preprocessing is enabled but no preprocessor builder is given") + pi_v_preprocessor = state_preprocessor_builder("pi_v_preprocessor", self._env_info, self._config) v_function = _StatePreprocessedVFunction(v_function=v_function, preprocessor=pi_v_preprocessor) policy = _StatePreprocessedStochasticPolicy(policy=policy, preprocessor=pi_v_preprocessor) - r_preprocessor = state_preprocessor_builder('r_preprocessor', self._env_info, self._config) + r_preprocessor = state_preprocessor_builder("r_preprocessor", self._env_info, self._config) discriminator = _StatePreprocessedRewardFunction( - reward_function=discriminator, preprocessor=r_preprocessor) + reward_function=discriminator, preprocessor=r_preprocessor + ) self._pi_v_state_preprocessor = pi_v_preprocessor self._r_state_preprocessor = r_preprocessor self._v_function = v_function @@ -295,9 +317,11 @@ def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], self._expert_buffer = expert_buffer self._evaluation_actor = _StochasticPolicyActionSelector( - self._env_info, self._policy.shallowcopy(), deterministic=self._config.act_deterministic_in_eval) + self._env_info, self._policy.shallowcopy(), deterministic=self._config.act_deterministic_in_eval + ) self._exploration_actor = _StochasticPolicyActionSelector( - self._env_info, self._policy.shallowcopy(), deterministic=False) + self._env_info, self._policy.shallowcopy(), deterministic=False + ) @eval_api def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): @@ -318,14 +342,14 @@ def _setup_environment_explorer(self, env_or_buffer): def 
_setup_v_function_training(self, env_or_buffer): v_function_trainer_config = MT.v_value_trainers.MonteCarloVTrainerConfig( - reduction_method='mean', - v_loss_scalar=1.0 + reduction_method="mean", v_loss_scalar=1.0 ) v_function_trainer = MT.v_value_trainers.MonteCarloVTrainer( train_functions=self._v_function, solvers={self._v_function.scope_name: self._v_function_solver}, env_info=self._env_info, - config=v_function_trainer_config) + config=v_function_trainer_config, + ) return v_function_trainer def _setup_policy_training(self, env_or_buffer): @@ -333,24 +357,25 @@ def _setup_policy_training(self, env_or_buffer): sigma_kl_divergence_constraint=self._config.sigma_kl_divergence_constraint, maximum_backtrack_numbers=self._config.maximum_backtrack_numbers, conjugate_gradient_damping=self._config.conjugate_gradient_damping, - conjugate_gradient_iterations=self._config.conjugate_gradient_iterations) + conjugate_gradient_iterations=self._config.conjugate_gradient_iterations, + ) policy_trainer = MT.policy_trainers.TRPOPolicyTrainer( - model=self._policy, - env_info=self._env_info, - config=policy_trainer_config) + model=self._policy, env_info=self._env_info, config=policy_trainer_config + ) return policy_trainer def _setup_reward_function_training(self, env_or_buffer): reward_function_trainer_config = MT.reward_trainiers.GAILRewardFunctionTrainerConfig( batch_size=self._config.discriminator_batch_size, learning_rate=self._config.discriminator_learning_rate, - entropy_coef=self._config.adversary_entropy_coef + entropy_coef=self._config.adversary_entropy_coef, ) reward_function_trainer = MT.reward_trainiers.GAILRewardFunctionTrainer( models=self._discriminator, solvers={self._discriminator.scope_name: self._discriminator_solver}, env_info=self._env_info, - config=reward_function_trainer_config) + config=reward_function_trainer_config, + ) return reward_function_trainer @@ -371,13 +396,13 @@ def _run_online_training_iteration(self, env): def _label_experience(self, experience): labeled_experience = [] - if not hasattr(self, '_s_var_label'): + if not hasattr(self, "_s_var_label"): # build graph self._s_var_label = create_variable(1, self._env_info.state_shape) self._s_next_var_label = create_variable(1, self._env_info.state_shape) self._a_var_label = create_variable(1, self._env_info.action_shape) logits_fake = self._discriminator.r(self._s_var_label, self._a_var_label, self._s_next_var_label) - self._reward = -NF.log(1. 
- NF.sigmoid(logits_fake) + 1e-8) + self._reward = -NF.log(1.0 - NF.sigmoid(logits_fake) + 1e-8) for s, a, _, non_terminal, n_s, info in experience: # forward and get reward @@ -408,24 +433,28 @@ def _gail_training(self, buffer): # discriminator learning if self._iteration_num % self._config.discriminator_update_frequency == 0: - s_curr_expert, a_curr_expert, s_next_expert, s_curr_agent, a_curr_agent, s_next_agent = \ + s_curr_expert, a_curr_expert, s_next_expert, s_curr_agent, a_curr_agent, s_next_agent = ( self._align_discriminator_experiences(buffer_iterator) + ) if self._config.preprocess_state: self._r_state_preprocessor.update(np.concatenate([s_curr_agent, s_curr_expert], axis=0)) - self._discriminator_training(s_curr_expert, a_curr_expert, s_next_expert, - s_curr_agent, a_curr_agent, s_next_agent) + self._discriminator_training( + s_curr_expert, a_curr_expert, s_next_expert, s_curr_agent, a_curr_agent, s_next_agent + ) def _align_policy_experiences(self, buffer_iterator): v_target_batch, adv_batch = self._compute_v_target_and_advantage(buffer_iterator) s_batch, a_batch, _ = self._align_state_and_action(buffer_iterator) - return s_batch[:self._config.num_steps_per_iteration], \ - a_batch[:self._config.num_steps_per_iteration], \ - v_target_batch[:self._config.num_steps_per_iteration], \ - adv_batch[:self._config.num_steps_per_iteration] + return ( + s_batch[: self._config.num_steps_per_iteration], + a_batch[: self._config.num_steps_per_iteration], + v_target_batch[: self._config.num_steps_per_iteration], + adv_batch[: self._config.num_steps_per_iteration], + ) def _compute_v_target_and_advantage(self, buffer_iterator): v_target_batch = [] @@ -435,7 +464,8 @@ def _compute_v_target_and_advantage(self, buffer_iterator): for experiences, *_ in buffer_iterator: # length of experiences is 1 v_target, adv = compute_v_target_and_advantage( - self._v_function, experiences[0], gamma=self._config.gamma, lmb=self._config.lmb) + self._v_function, experiences[0], gamma=self._config.gamma, lmb=self._config.lmb + ) v_target_batch.append(v_target.reshape(-1, 1)) adv_batch.append(adv.reshape(-1, 1)) @@ -476,7 +506,8 @@ def _align_discriminator_experiences(self, buffer_iterator): s_expert_batch, a_expert_batch, _, _, s_next_expert_batch, *_ = marshal_experiences(expert_experience) # sample agent data s_batch, a_batch, s_next_batch = self._align_state_and_action( - buffer_iterator, batch_size=self._config.discriminator_batch_size) + buffer_iterator, batch_size=self._config.discriminator_batch_size + ) return s_expert_batch, a_expert_batch, s_next_expert_batch, s_batch, a_batch, s_next_batch @@ -485,34 +516,36 @@ def _v_function_training(self, s, v_target): for _ in range(self._config.vf_epochs * num_iterations_per_epoch): indices = np.random.randint(0, self._config.num_steps_per_iteration, size=self._config.vf_batch_size) - batch = TrainingBatch(batch_size=self._config.vf_batch_size, - s_current=s[indices], - extra={'v_target': v_target[indices]}) + batch = TrainingBatch( + batch_size=self._config.vf_batch_size, s_current=s[indices], extra={"v_target": v_target[indices]} + ) self._v_function_trainer_state = self._v_function_trainer.train(batch) def _policy_training(self, s, a, v_target, advantage): extra = {} - extra['v_target'] = v_target[:self._config.pi_batch_size] - extra['advantage'] = advantage[:self._config.pi_batch_size] - batch = TrainingBatch(batch_size=self._config.pi_batch_size, - s_current=s[:self._config.pi_batch_size], - a_current=a[:self._config.pi_batch_size], - extra=extra) + 
extra["v_target"] = v_target[: self._config.pi_batch_size] + extra["advantage"] = advantage[: self._config.pi_batch_size] + batch = TrainingBatch( + batch_size=self._config.pi_batch_size, + s_current=s[: self._config.pi_batch_size], + a_current=a[: self._config.pi_batch_size], + extra=extra, + ) self._policy_trainer_state = self._policy_trainer.train(batch) - def _discriminator_training(self, s_curr_expert, a_curr_expert, s_next_expert, - s_curr_agent, a_curr_agent, s_next_agent): + def _discriminator_training( + self, s_curr_expert, a_curr_expert, s_next_expert, s_curr_agent, a_curr_agent, s_next_agent + ): extra = {} - extra['s_current_agent'] = s_curr_agent[:self._config.discriminator_batch_size] - extra['a_current_agent'] = a_curr_agent[:self._config.discriminator_batch_size] - extra['s_next_agent'] = s_next_agent[:self._config.discriminator_batch_size] - extra['s_current_expert'] = s_curr_expert[:self._config.discriminator_batch_size] - extra['a_current_expert'] = a_curr_expert[:self._config.discriminator_batch_size] - extra['s_next_expert'] = s_next_expert[:self._config.discriminator_batch_size] + extra["s_current_agent"] = s_curr_agent[: self._config.discriminator_batch_size] + extra["a_current_agent"] = a_curr_agent[: self._config.discriminator_batch_size] + extra["s_next_agent"] = s_next_agent[: self._config.discriminator_batch_size] + extra["s_current_expert"] = s_curr_expert[: self._config.discriminator_batch_size] + extra["a_current_expert"] = a_curr_expert[: self._config.discriminator_batch_size] + extra["s_next_expert"] = s_next_expert[: self._config.discriminator_batch_size] - batch = TrainingBatch(batch_size=self._config.discriminator_batch_size, - extra=extra) + batch = TrainingBatch(batch_size=self._config.discriminator_batch_size, extra=extra) self._discriminator_trainer_state = self._discriminator_trainer.train(batch) @@ -541,18 +574,20 @@ def _solvers(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env() @property def latest_iteration_state(self): latest_iteration_state = super(GAIL, self).latest_iteration_state - if hasattr(self, '_discriminator_trainer_state'): - latest_iteration_state['scalar'].update( - {'reward_loss': float(self._discriminator_trainer_state['reward_loss'])}) - if hasattr(self, '_v_function_trainer_state'): - latest_iteration_state['scalar'].update({'v_loss': float(self._v_function_trainer_state['v_loss'])}) + if hasattr(self, "_discriminator_trainer_state"): + latest_iteration_state["scalar"].update( + {"reward_loss": float(self._discriminator_trainer_state["reward_loss"])} + ) + if hasattr(self, "_v_function_trainer_state"): + latest_iteration_state["scalar"].update({"v_loss": float(self._v_function_trainer_state["v_loss"])}) return latest_iteration_state @property diff --git a/nnabla_rl/algorithms/her.py b/nnabla_rl/algorithms/her.py index 0ff0c0c6..38f86087 100644 --- a/nnabla_rl/algorithms/her.py +++ b/nnabla_rl/algorithms/her.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -25,8 +25,11 @@ import nnabla_rl.model_trainers as MT import nnabla_rl.preprocessors as RP from nnabla_rl.algorithms import DDPG, DDPGConfig -from nnabla_rl.algorithms.common_utils import (_DeterministicPolicyActionSelector, - _StatePreprocessedDeterministicPolicy, _StatePreprocessedQFunction) +from nnabla_rl.algorithms.common_utils import ( + _DeterministicPolicyActionSelector, + _StatePreprocessedDeterministicPolicy, + _StatePreprocessedQFunction, +) from nnabla_rl.builders import ModelBuilder, PreprocessorBuilder, ReplayBufferBuilder, SolverBuilder from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingBatch @@ -85,49 +88,59 @@ class HERConfig(DDPGConfig): class HERActorBuilder(ModelBuilder[DeterministicPolicy]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: HERConfig, - **kwargs) -> DeterministicPolicy: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: HERConfig, + **kwargs, + ) -> DeterministicPolicy: max_action_value = float(env_info.action_high[0]) return HERPolicy(scope_name, env_info.action_dim, max_action_value=max_action_value) class HERCriticBuilder(ModelBuilder[QFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: HERConfig, - **kwargs) -> QFunction: - target_policy = kwargs.get('target_policy') + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: HERConfig, + **kwargs, + ) -> QFunction: + target_policy = kwargs.get("target_policy") return HERQFunction(scope_name, optimal_policy=target_policy) class HERPreprocessorBuilder(PreprocessorBuilder): - def build_preprocessor(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: HERConfig, - **kwargs) -> Preprocessor: - return RP.HERPreprocessor('preprocessor', env_info.state_shape, - epsilon=algorithm_config.normalize_epsilon, - value_clip=algorithm_config.normalize_clip_range) + def build_preprocessor( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: HERConfig, + **kwargs, + ) -> Preprocessor: + return RP.HERPreprocessor( + "preprocessor", + env_info.state_shape, + epsilon=algorithm_config.normalize_epsilon, + value_clip=algorithm_config.normalize_clip_range, + ) class HERSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: HERConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: HERConfig, **kwargs + ) -> nn.solver.Solver: return NS.Adam(alpha=algorithm_config.learning_rate) class HindsightReplayBufferBuilder(ReplayBufferBuilder): def __call__(self, env_info, algorithm_config, **kwargs): - return HindsightReplayBuffer(reward_function=env_info.reward_function, - hindsight_prob=algorithm_config.hindsight_prob, - capacity=algorithm_config.replay_buffer_size) + return HindsightReplayBuffer( + reward_function=env_info.reward_function, + hindsight_prob=algorithm_config.hindsight_prob, + capacity=algorithm_config.replay_buffer_size, + ) class HER(DDPG): @@ -157,6 +170,7 @@ class HER(DDPG): replay_buffer_builder (:py:class:`ReplayBufferBuilder `): builder of replay_buffer """ + _config: HERConfig _q: QFunction _q_solver: 
nn.solver.Solver @@ -167,25 +181,30 @@ class HER(DDPG): _state_preprocessor: Optional[Preprocessor] _replay_buffer: HindsightReplayBuffer - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: HERConfig = HERConfig(), - critic_builder: ModelBuilder[QFunction] = HERCriticBuilder(), - critic_solver_builder: SolverBuilder = HERSolverBuilder(), - actor_builder: ModelBuilder[DeterministicPolicy] = HERActorBuilder(), - actor_solver_builder: SolverBuilder = HERSolverBuilder(), - state_preprocessor_builder: Optional[PreprocessorBuilder] = HERPreprocessorBuilder(), - replay_buffer_builder: ReplayBufferBuilder = HindsightReplayBufferBuilder()): - - super(HER, self).__init__(env_or_env_info=env_or_env_info, - config=config, - critic_builder=critic_builder, - critic_solver_builder=critic_solver_builder, - actor_builder=actor_builder, - actor_solver_builder=actor_solver_builder, - replay_buffer_builder=replay_buffer_builder) + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: HERConfig = HERConfig(), + critic_builder: ModelBuilder[QFunction] = HERCriticBuilder(), + critic_solver_builder: SolverBuilder = HERSolverBuilder(), + actor_builder: ModelBuilder[DeterministicPolicy] = HERActorBuilder(), + actor_solver_builder: SolverBuilder = HERSolverBuilder(), + state_preprocessor_builder: Optional[PreprocessorBuilder] = HERPreprocessorBuilder(), + replay_buffer_builder: ReplayBufferBuilder = HindsightReplayBufferBuilder(), + ): + + super(HER, self).__init__( + env_or_env_info=env_or_env_info, + config=config, + critic_builder=critic_builder, + critic_solver_builder=critic_solver_builder, + actor_builder=actor_builder, + actor_solver_builder=actor_solver_builder, + replay_buffer_builder=replay_buffer_builder, + ) if self._config.preprocess_state and state_preprocessor_builder is not None: - preprocessor = state_preprocessor_builder('preprocessor', self._env_info, self._config) + preprocessor = state_preprocessor_builder("preprocessor", self._env_info, self._config) assert preprocessor is not None self._q = _StatePreprocessedQFunction(q_function=self._q, preprocessor=preprocessor) self._target_q = _StatePreprocessedQFunction(q_function=self._target_q, preprocessor=preprocessor) @@ -199,19 +218,22 @@ def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], def _setup_q_function_training(self, env_or_buffer): q_function_trainer_config = MT.q_value_trainers.HERQTrainerConfig( - reduction_method='mean', + reduction_method="mean", grad_clip=None, return_clip=self._config.return_clip, unroll_steps=self._config.critic_unroll_steps, burn_in_steps=self._config.critic_burn_in_steps, - reset_on_terminal=self._config.critic_reset_rnn_on_terminal) - - q_function_trainer = MT.q_value_trainers.HERQTrainer(train_functions=self._q, - solvers={self._q.scope_name: self._q_solver}, - target_functions=self._target_q, - target_policy=self._target_pi, - env_info=self._env_info, - config=q_function_trainer_config) + reset_on_terminal=self._config.critic_reset_rnn_on_terminal, + ) + + q_function_trainer = MT.q_value_trainers.HERQTrainer( + train_functions=self._q, + solvers={self._q.scope_name: self._q_solver}, + target_functions=self._target_q, + target_policy=self._target_pi, + env_info=self._env_info, + config=q_function_trainer_config, + ) sync_model(self._q, self._target_q) return q_function_trainer @@ -220,13 +242,16 @@ def _setup_policy_training(self, env_or_buffer): action_loss_coef=self._config.action_loss_coef, 
unroll_steps=self._config.actor_unroll_steps, burn_in_steps=self._config.actor_burn_in_steps, - reset_on_terminal=self._config.actor_reset_rnn_on_terminal) + reset_on_terminal=self._config.actor_reset_rnn_on_terminal, + ) - policy_trainer = MT.policy_trainers.HERPolicyTrainer(models=self._pi, - solvers={self._pi.scope_name: self._pi_solver}, - q_function=self._q, - env_info=self._env_info, - config=policy_trainer_config) + policy_trainer = MT.policy_trainers.HERPolicyTrainer( + models=self._pi, + solvers={self._pi.scope_name: self._pi_solver}, + q_function=self._q, + env_info=self._env_info, + config=policy_trainer_config, + ) sync_model(self._pi, self._target_pi) return policy_trainer @@ -290,28 +315,30 @@ def _her_training(self, replay_buffer): for i in range(self._config.n_update): experiences_tuple, info = replay_buffer.sample(self._config.batch_size, num_steps=num_steps) if num_steps == 1: - experiences_tuple = (experiences_tuple, ) + experiences_tuple = (experiences_tuple,) assert len(experiences_tuple) == num_steps batch = None for experiences in reversed(experiences_tuple): (s, a, r, non_terminal, s_next, rnn_states_dict, *_) = marshal_experiences(experiences) - rnn_states = rnn_states_dict['rnn_states'] if 'rnn_states' in rnn_states_dict else {} - batch = TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - gamma=self._config.gamma, - reward=r, - non_terminal=non_terminal, - s_next=s_next, - weight=info['weights'], - next_step_batch=batch, - rnn_states=rnn_states) + rnn_states = rnn_states_dict["rnn_states"] if "rnn_states" in rnn_states_dict else {} + batch = TrainingBatch( + batch_size=self._config.batch_size, + s_current=s, + a_current=a, + gamma=self._config.gamma, + reward=r, + non_terminal=non_terminal, + s_next=s_next, + weight=info["weights"], + next_step_batch=batch, + rnn_states=rnn_states, + ) self._q_function_trainer_state = self._q_function_trainer.train(batch) self._policy_trainer_state = self._policy_trainer.train(batch) - td_errors = np.abs(self._q_function_trainer_state['td_errors']) + td_errors = np.abs(self._q_function_trainer_state["td_errors"]) replay_buffer.update_priorities(td_errors) # target update @@ -346,8 +373,9 @@ def _models(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) # continuous action env is_continuous_action_env = env_info.is_continuous_action_env() diff --git a/nnabla_rl/algorithms/hyar.py b/nnabla_rl/algorithms/hyar.py index 52483279..059e4a2a 100644 --- a/nnabla_rl/algorithms/hyar.py +++ b/nnabla_rl/algorithms/hyar.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -118,127 +118,130 @@ class HyARConfig(TD3Config): vae_buffer_size: int = int(2e6) latent_select_batch_size: int = 5000 - latent_select_range: float = 96. 
+ latent_select_range: float = 96.0 noise_decay_steps: int = 1000 initial_exploration_noise: float = 1.0 final_exploration_noise: float = 0.1 def __post_init__(self): - self._assert_positive(self.latent_dim, 'latent_dim') - self._assert_positive(self.embed_dim, 'embed_dim') - self._assert_positive(self.T, 'T') - self._assert_positive_or_zero(self.vae_pretrain_episodes, 'vae_pretrain_episodes') - self._assert_positive(self.vae_pretrain_batch_size, 'vae_pretrain_batch_size') - self._assert_positive_or_zero(self.vae_pretrain_times, 'vae_pretrain_times') - self._assert_positive(self.vae_training_batch_size, 'vae_training_batch_size') - self._assert_positive_or_zero(self.vae_training_times, 'vae_training_times') - self._assert_positive_or_zero(self.vae_learning_rate, 'vae_learning_rate') - self._assert_positive(self.vae_buffer_size, 'vae_buffer_size') - self._assert_positive(self.latent_select_batch_size, 'latent_select_batch_size') - self._assert_between(self.latent_select_range, 0, 100, 'latent_select_range') - self._assert_positive(self.noise_decay_steps, 'noise_decay_steps') - self._assert_positive(self.initial_exploration_noise, 'initial_exploration_noise') - self._assert_positive(self.final_exploration_noise, 'final_exploration_noise') + self._assert_positive(self.latent_dim, "latent_dim") + self._assert_positive(self.embed_dim, "embed_dim") + self._assert_positive(self.T, "T") + self._assert_positive_or_zero(self.vae_pretrain_episodes, "vae_pretrain_episodes") + self._assert_positive(self.vae_pretrain_batch_size, "vae_pretrain_batch_size") + self._assert_positive_or_zero(self.vae_pretrain_times, "vae_pretrain_times") + self._assert_positive(self.vae_training_batch_size, "vae_training_batch_size") + self._assert_positive_or_zero(self.vae_training_times, "vae_training_times") + self._assert_positive_or_zero(self.vae_learning_rate, "vae_learning_rate") + self._assert_positive(self.vae_buffer_size, "vae_buffer_size") + self._assert_positive(self.latent_select_batch_size, "latent_select_batch_size") + self._assert_between(self.latent_select_range, 0, 100, "latent_select_range") + self._assert_positive(self.noise_decay_steps, "noise_decay_steps") + self._assert_positive(self.initial_exploration_noise, "initial_exploration_noise") + self._assert_positive(self.final_exploration_noise, "final_exploration_noise") return super().__post_init__() class DefaultCriticBuilder(ModelBuilder[QFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: HyARConfig, - **kwargs) -> QFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: HyARConfig, + **kwargs, + ) -> QFunction: return HyARQFunction(scope_name) class DefaultActorBuilder(ModelBuilder[DeterministicPolicy]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: HyARConfig, - **kwargs) -> DeterministicPolicy: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: HyARConfig, + **kwargs, + ) -> DeterministicPolicy: max_action_value = 1.0 action_dim = algorithm_config.latent_dim + algorithm_config.embed_dim return HyARPolicy(scope_name, action_dim, max_action_value=max_action_value) class DefaultVAEBuilder(ModelBuilder[HyARVAE]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: HyARConfig, - **kwargs) -> HyARVAE: - 
return HyARVAE(scope_name, - state_dim=env_info.state_dim, - action_dim=env_info.action_dim, - encode_dim=algorithm_config.latent_dim, - embed_dim=algorithm_config.embed_dim) + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: HyARConfig, + **kwargs, + ) -> HyARVAE: + return HyARVAE( + scope_name, + state_dim=env_info.state_dim, + action_dim=env_info.action_dim, + encode_dim=algorithm_config.latent_dim, + embed_dim=algorithm_config.embed_dim, + ) class DefaultActorSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: HyARConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: HyARConfig, **kwargs + ) -> nn.solver.Solver: solver = NS.Adam(alpha=algorithm_config.learning_rate) return AutoClipGradByNorm(solver, 10.0) class DefaultVAESolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: HyARConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: HyARConfig, **kwargs + ) -> nn.solver.Solver: return NS.Adam(alpha=algorithm_config.vae_learning_rate) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: HyARConfig, - algorithm: "HyAR", - **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: HyARConfig, + algorithm: "HyAR", + **kwargs, + ) -> EnvironmentExplorer: explorer_config = HyARPolicyExplorerConfig( - warmup_random_steps=0, - initial_step_num=algorithm.iteration_num, - timelimit_as_terminal=False + warmup_random_steps=0, initial_step_num=algorithm.iteration_num, timelimit_as_terminal=False + ) + explorer = HyARPolicyExplorer( + policy_action_selector=algorithm._exploration_action_selector, env_info=env_info, config=explorer_config ) - explorer = HyARPolicyExplorer(policy_action_selector=algorithm._exploration_action_selector, - env_info=env_info, - config=explorer_config) return explorer class DefaultPretrainExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: HyARConfig, - algorithm: "HyAR", - **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: HyARConfig, + algorithm: "HyAR", + **kwargs, + ) -> EnvironmentExplorer: explorer_config = HyARPretrainExplorerConfig( - warmup_random_steps=0, - initial_step_num=algorithm.iteration_num, - timelimit_as_terminal=False + warmup_random_steps=0, initial_step_num=algorithm.iteration_num, timelimit_as_terminal=False ) - explorer = HyARPretrainExplorer(env_info=env_info, - config=explorer_config) + explorer = HyARPretrainExplorer(env_info=env_info, config=explorer_config) return explorer class DefaultReplayBufferBuilder(ReplayBufferBuilder): - def build_replay_buffer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: HyARConfig, - **kwargs) -> ReplayBuffer: + def build_replay_buffer( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: HyARConfig, **kwargs + ) -> ReplayBuffer: return ReplacementSamplingReplayBuffer(capacity=algorithm_config.replay_buffer_size) class 
DefaultVAEBufferBuilder(ReplayBufferBuilder): - def build_replay_buffer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: HyARConfig, - **kwargs) -> ReplayBuffer: + def build_replay_buffer( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: HyARConfig, **kwargs + ) -> ReplayBuffer: return ReplacementSamplingReplayBuffer(capacity=algorithm_config.vae_buffer_size) @@ -272,6 +275,7 @@ class HyAR(TD3): pretrain_explorer_builder (:py:class:`ExplorerBuilder `): builder of environment explorer for pretraining stage """ + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See https://mypy.readthedocs.io/en/stable/class_basics.html for details @@ -279,30 +283,34 @@ class HyAR(TD3): _evaluation_actor: "_HyARPolicyActionSelector" # type: ignore _exploration_actor: "_HyARPolicyActionSelector" # type: ignore - def __init__(self, - env_or_env_info, - config: HyARConfig = HyARConfig(), - critic_builder=DefaultCriticBuilder(), - critic_solver_builder=DefaultSolverBuilder(), - actor_builder=DefaultActorBuilder(), - actor_solver_builder=DefaultActorSolverBuilder(), - vae_builder=DefaultVAEBuilder(), - vae_solver_buidler=DefaultVAESolverBuilder(), - replay_buffer_builder=DefaultReplayBufferBuilder(), - vae_buffer_builder=DefaultVAEBufferBuilder(), - explorer_builder=DefaultExplorerBuilder(), - pretrain_explorer_builder=DefaultPretrainExplorerBuilder()): - super().__init__(env_or_env_info, - config, - critic_builder, - critic_solver_builder, - actor_builder, - actor_solver_builder, - replay_buffer_builder, - explorer_builder) + def __init__( + self, + env_or_env_info, + config: HyARConfig = HyARConfig(), + critic_builder=DefaultCriticBuilder(), + critic_solver_builder=DefaultSolverBuilder(), + actor_builder=DefaultActorBuilder(), + actor_solver_builder=DefaultActorSolverBuilder(), + vae_builder=DefaultVAEBuilder(), + vae_solver_buidler=DefaultVAESolverBuilder(), + replay_buffer_builder=DefaultReplayBufferBuilder(), + vae_buffer_builder=DefaultVAEBufferBuilder(), + explorer_builder=DefaultExplorerBuilder(), + pretrain_explorer_builder=DefaultPretrainExplorerBuilder(), + ): + super().__init__( + env_or_env_info, + config, + critic_builder, + critic_solver_builder, + actor_builder, + actor_solver_builder, + replay_buffer_builder, + explorer_builder, + ) with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): - self._vae = vae_builder('vae', self._env_info, self._config) + self._vae = vae_builder("vae", self._env_info, self._config) self._vae_solver = vae_solver_buidler(self._env_info, self._config) # We use different replay buffer for vae self._vae_replay_buffer = vae_buffer_builder(env_info=self._env_info, algorithm_config=self._config) @@ -313,7 +321,8 @@ def __init__(self, self._pi.shallowcopy(), self._vae.shallowcopy(), embed_dim=self._config.embed_dim, - latent_dim=self._config.latent_dim) + latent_dim=self._config.latent_dim, + ) self._exploration_actor = _HyARPolicyActionSelector( self._env_info, self._pi.shallowcopy(), @@ -323,7 +332,8 @@ def __init__(self, append_noise=True, sigma=self._config.exploration_noise_sigma, action_clip_low=-1.0, - action_clip_high=1.0) + action_clip_high=1.0, + ) self._episode_number = 1 self._experienced_episodes = 0 @@ -337,7 +347,7 @@ def _before_training_start(self, env_or_buffer): def _setup_q_function_training(self, env_or_buffer): # training input/loss variables q_function_trainer_config = 
MT.q_value_trainers.HyARQTrainerConfig( - reduction_method='mean', + reduction_method="mean", q_loss_scalar=1.0, grad_clip=None, train_action_noise_sigma=self._config.train_action_noise_sigma, @@ -349,7 +359,8 @@ def _setup_q_function_training(self, env_or_buffer): burn_in_steps=self._config.critic_burn_in_steps, reset_on_terminal=self._config.critic_reset_rnn_on_terminal, embed_dim=self._config.embed_dim, - latent_dim=self._config.latent_dim) + latent_dim=self._config.latent_dim, + ) q_function_trainer = MT.q_value_trainers.HyARQTrainer( train_functions=self._train_q_functions, solvers=self._train_q_solvers, @@ -357,26 +368,29 @@ def _setup_q_function_training(self, env_or_buffer): target_policy=self._target_pi, vae=self._vae, env_info=self._env_info, - config=q_function_trainer_config) + config=q_function_trainer_config, + ) for q, target_q in zip(self._train_q_functions, self._target_q_functions): sync_model(q, target_q) return q_function_trainer def _setup_policy_training(self, env_or_buffer): # return super()._setup_policy_training(env_or_buffer) - action_dim = self._config.latent_dim+self._config.embed_dim + action_dim = self._config.latent_dim + self._config.embed_dim policy_trainer_config = MT.policy_trainers.HyARPolicyTrainerConfig( unroll_steps=self._config.actor_unroll_steps, burn_in_steps=self._config.actor_burn_in_steps, reset_on_terminal=self._config.actor_reset_rnn_on_terminal, p_max=np.ones(shape=(1, action_dim)), - p_min=-np.ones(shape=(1, action_dim))) + p_min=-np.ones(shape=(1, action_dim)), + ) policy_trainer = MT.policy_trainers.HyARPolicyTrainer( models=self._pi, solvers={self._pi.scope_name: self._pi_solver}, q_function=self._q1, env_info=self._env_info, - config=policy_trainer_config) + config=policy_trainer_config, + ) sync_model(self._pi, self._target_pi, 1.0) return policy_trainer @@ -385,16 +399,18 @@ def _setup_vae_training(self, env_or_buffer): vae_trainer_config = MT.encoder_trainers.HyARVAETrainerConfig( unroll_steps=self._config.critic_unroll_steps, burn_in_steps=self._config.critic_burn_in_steps, - reset_on_terminal=self._config.critic_reset_rnn_on_terminal) - return MT.encoder_trainers.HyARVAETrainer(self._vae, - {self._vae.scope_name: self._vae_solver}, - self._env_info, - vae_trainer_config) + reset_on_terminal=self._config.critic_reset_rnn_on_terminal, + ) + return MT.encoder_trainers.HyARVAETrainer( + self._vae, {self._vae.scope_name: self._vae_solver}, self._env_info, vae_trainer_config + ) def _setup_pretrain_explorer(self, env_or_buffer): - return None if self._is_buffer(env_or_buffer) else self._pretrain_explorer_builder(self._env_info, - self._config, - self) + return ( + None + if self._is_buffer(env_or_buffer) + else self._pretrain_explorer_builder(self._env_info, self._config, self) + ) def _pretrain_vae(self, env: gym.Env): for _ in range(self._config.vae_pretrain_episodes): @@ -415,13 +431,15 @@ def _run_online_training_iteration(self, env): self._vae_replay_buffer.append_all(experiences) (_, _, _, non_terminal, *_) = experiences[-1] - end_of_episode = (non_terminal == 0.0) + end_of_episode = non_terminal == 0.0 if end_of_episode: self._experienced_episodes += 1 - if (self._experienced_episodes < self._config.noise_decay_steps): + if self._experienced_episodes < self._config.noise_decay_steps: ratio = self._experienced_episodes / self._config.noise_decay_steps - new_sigma = self._config.initial_exploration_noise * (1.0 - ratio) \ + new_sigma = ( + self._config.initial_exploration_noise * (1.0 - ratio) + self._config.final_exploration_noise * 
ratio + ) self._exploration_actor.update_sigma(sigma=new_sigma) else: self._exploration_actor.update_sigma(sigma=self._config.final_exploration_noise) @@ -448,28 +466,30 @@ def _rl_training(self, replay_buffer): num_steps = max(actor_steps, critic_steps) experiences_tuple, info = replay_buffer.sample(self._config.batch_size, num_steps=num_steps) if num_steps == 1: - experiences_tuple = (experiences_tuple, ) + experiences_tuple = (experiences_tuple,) assert len(experiences_tuple) == num_steps batch = None for experiences in reversed(experiences_tuple): (s, a, r, non_terminal, s_next, extra, *_) = marshal_experiences(experiences) - rnn_states = extra['rnn_states'] if 'rnn_states' in extra else {} - extra.update({'c_rate': self._c_rate, 'ds_rate': self._ds_rate}) - batch = TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - gamma=self._config.gamma, - reward=r, - non_terminal=non_terminal, - s_next=s_next, - extra=extra, - weight=info['weights'], - next_step_batch=batch, - rnn_states=rnn_states) + rnn_states = extra["rnn_states"] if "rnn_states" in extra else {} + extra.update({"c_rate": self._c_rate, "ds_rate": self._ds_rate}) + batch = TrainingBatch( + batch_size=self._config.batch_size, + s_current=s, + a_current=a, + gamma=self._config.gamma, + reward=r, + non_terminal=non_terminal, + s_next=s_next, + extra=extra, + weight=info["weights"], + next_step_batch=batch, + rnn_states=rnn_states, + ) self._q_function_trainer_state = self._q_function_trainer.train(batch) - td_errors = self._q_function_trainer_state['td_errors'] + td_errors = self._q_function_trainer_state["td_errors"] replay_buffer.update_priorities(td_errors) if self.iteration_num % self._config.d == 0: @@ -487,24 +507,26 @@ def _vae_training(self, replay_buffer, batch_size): num_steps = max(actor_steps, critic_steps) experiences_tuple, info = replay_buffer.sample(batch_size, num_steps=num_steps) if num_steps == 1: - experiences_tuple = (experiences_tuple, ) + experiences_tuple = (experiences_tuple,) assert len(experiences_tuple) == num_steps batch = None for experiences in reversed(experiences_tuple): (s, a, r, non_terminal, s_next, extra, *_) = marshal_experiences(experiences) - rnn_states = extra['rnn_states'] if 'rnn_states' in extra else {} - batch = TrainingBatch(batch_size=batch_size, - s_current=s, - a_current=a, - gamma=self._config.gamma, - reward=r, - non_terminal=non_terminal, - s_next=s_next, - extra=extra, - weight=info['weights'], - next_step_batch=batch, - rnn_states=rnn_states) + rnn_states = extra["rnn_states"] if "rnn_states" in extra else {} + batch = TrainingBatch( + batch_size=batch_size, + s_current=s, + a_current=a, + gamma=self._config.gamma, + reward=r, + non_terminal=non_terminal, + s_next=s_next, + extra=extra, + weight=info["weights"], + next_step_batch=batch, + rnn_states=rnn_states, + ) self._vae_trainer_state = self._vae_trainer.train(batch) @@ -525,8 +547,9 @@ def _compute_reconstruction_rate(self, replay_buffer): experiences, _ = replay_buffer.sample(num_samples=batch_size) (s, a, _, _, s_next, *_) = marshal_experiences(experiences) - if not hasattr(self, '_rate_state_var'): + if not hasattr(self, "_rate_state_var"): from nnabla_rl.utils.misc import create_variable + self._rate_state_var = create_variable(batch_size, self._env_info.state_shape) self._rate_action_var = create_variable(batch_size, self._env_info.action_shape) self._rate_next_state_var = create_variable(batch_size, self._env_info.state_shape) @@ -534,7 +557,8 @@ def _compute_reconstruction_rate(self, 
replay_buffer): action1, action2 = self._rate_action_var x = action1 if isinstance(self._env_info.action_space[0], gym.spaces.Box) else action2 latent_distribution, (_, predicted_ds) = self._vae.encode_and_decode( - x=x, state=self._rate_state_var, action=self._rate_action_var) + x=x, state=self._rate_state_var, action=self._rate_action_var + ) z = latent_distribution.sample() # NOTE: ascending order z_sorted = NF.sort(z, axis=0) @@ -560,8 +584,9 @@ def _compute_reconstruction_rate(self, replay_buffer): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return env_info.is_tuple_action_env() and not env_info.is_tuple_state_env() @classmethod @@ -571,28 +596,33 @@ def is_rnn_supported(self): @property def latest_iteration_state(self): latest_iteration_state = super().latest_iteration_state - if hasattr(self, '_vae_trainer_state'): - latest_iteration_state['scalar'].update( - {'encoder_loss': float(self._vae_trainer_state['encoder_loss']), - 'kl_loss': float(self._vae_trainer_state['kl_loss']), - 'reconstruction_loss': float(self._vae_trainer_state['reconstruction_loss']), - 'dyn_loss': float(self._vae_trainer_state['dyn_loss'])}) + if hasattr(self, "_vae_trainer_state"): + latest_iteration_state["scalar"].update( + { + "encoder_loss": float(self._vae_trainer_state["encoder_loss"]), + "kl_loss": float(self._vae_trainer_state["kl_loss"]), + "reconstruction_loss": float(self._vae_trainer_state["reconstruction_loss"]), + "dyn_loss": float(self._vae_trainer_state["dyn_loss"]), + } + ) return latest_iteration_state class _HyARPolicyActionSelector(_ActionSelector[DeterministicPolicy]): _vae: HyARVAE - def __init__(self, - env_info: EnvironmentInfo, - model: DeterministicPolicy, - vae: HyARVAE, - embed_dim: int, - latent_dim: int, - append_noise: bool = False, - action_clip_low: float = np.finfo(np.float32).min, # type: ignore - action_clip_high: float = np.finfo(np.float32).max, # type: ignore - sigma: float = 1.0): + def __init__( + self, + env_info: EnvironmentInfo, + model: DeterministicPolicy, + vae: HyARVAE, + embed_dim: int, + latent_dim: int, + append_noise: bool = False, + action_clip_low: float = np.finfo(np.float32).min, # type: ignore + action_clip_high: float = np.finfo(np.float32).max, # type: ignore + sigma: float = 1.0, + ): super().__init__(env_info, model) self._vae = vae self._embed_dim = embed_dim @@ -617,7 +647,7 @@ def __call__(self, s, *, begin_of_episode=False, extra_info={}): # self._e.d[0] and self._z.d[0] e = self._e.d[0] z = self._z.d[0] - info.update({'e': e, 'z': z}) + info.update({"e": e, "z": z}) (d_action, c_action) = action return (d_action, c_action), info @@ -634,9 +664,9 @@ def _compute_action(self, state_var: nn.Variable) -> nn.Variable: noise = NF.randn(shape=latent_action.shape) latent_action = latent_action + noise * self._sigma latent_action = NF.clip_by_value(latent_action, min=self._action_clip_low, max=self._action_clip_high) - self._e = latent_action[:, :self._embed_dim] + self._e = latent_action[:, : self._embed_dim] self._e.persistent = True - self._z = latent_action[:, self._embed_dim:] + self._z = latent_action[:, self._embed_dim :] self._z.persistent = True assert latent_action.shape[-1] == self._embed_dim + self._latent_dim @@ -668,9 +698,7 @@ class HyARPretrainExplorerConfig(EnvironmentExplorerConfig): 
class HyARPretrainExplorer(EnvironmentExplorer): - def __init__(self, - env_info: EnvironmentInfo, - config: HyARPretrainExplorerConfig = HyARPretrainExplorerConfig()): + def __init__(self, env_info: EnvironmentInfo, config: HyARPretrainExplorerConfig = HyARPretrainExplorerConfig()): super().__init__(env_info, config) def action(self, step: int, state, *, begin_of_episode: bool = False): @@ -687,13 +715,13 @@ def _sample_action(self, env_info): action = [] for a, action_space in zip(env_info.action_space.sample(), env_info.action_space): if isinstance(action_space, gym.spaces.Discrete): - a = np.asarray(a).reshape((1, )) + a = np.asarray(a).reshape((1,)) action.append(a) action = tuple(action) else: if env_info.is_discrete_action_env(): action = env_info.action_space.sample() - action = np.asarray(action).reshape((1, )) + action = np.asarray(action).reshape((1,)) else: action = env_info.action_space.sample() return action, action_info diff --git a/nnabla_rl/algorithms/icml2015_trpo.py b/nnabla_rl/algorithms/icml2015_trpo.py index 1875dbb8..28d89710 100644 --- a/nnabla_rl/algorithms/icml2015_trpo.py +++ b/nnabla_rl/algorithms/icml2015_trpo.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -55,6 +55,7 @@ class ICML2015TRPOConfig(AlgorithmConfig): conjugate_gradient_damping (float): Damping size of conjugate gradient method. Defaults to 0.1. conjugate_gradient_iterations (int): Number of iterations of conjugate gradient method. Defaults to 20. """ + gamma: float = 0.99 num_steps_per_iteration: int = int(1e5) batch_size: int = int(1e5) @@ -69,51 +70,52 @@ def __post_init__(self): Check the values are in valid range. 
""" - self._assert_between(self.gamma, 0.0, 1.0, 'gamma') - self._assert_between(self.batch_size, 0, self.num_steps_per_iteration, 'batch_size') - self._assert_positive(self.num_steps_per_iteration, 'num_steps_per_iteration') - self._assert_positive(self.sigma_kl_divergence_constraint, 'sigma_kl_divergence_constraint') - self._assert_positive(self.maximum_backtrack_numbers, 'maximum_backtrack_numbers') - self._assert_positive(self.conjugate_gradient_damping, 'conjugate_gradient_damping') + self._assert_between(self.gamma, 0.0, 1.0, "gamma") + self._assert_between(self.batch_size, 0, self.num_steps_per_iteration, "batch_size") + self._assert_positive(self.num_steps_per_iteration, "num_steps_per_iteration") + self._assert_positive(self.sigma_kl_divergence_constraint, "sigma_kl_divergence_constraint") + self._assert_positive(self.maximum_backtrack_numbers, "maximum_backtrack_numbers") + self._assert_positive(self.conjugate_gradient_damping, "conjugate_gradient_damping") class DefaultPolicyBuilder(ModelBuilder[StochasticPolicy]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: ICML2015TRPOConfig, - **kwargs) -> StochasticPolicy: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: ICML2015TRPOConfig, + **kwargs, + ) -> StochasticPolicy: if env_info.is_discrete_action_env(): return self._build_default_discrete_policy(scope_name, env_info, algorithm_config) else: return self._build_default_continuous_policy(scope_name, env_info, algorithm_config) - def _build_default_continuous_policy(self, - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: ICML2015TRPOConfig, - **kwargs) -> StochasticPolicy: + def _build_default_continuous_policy( + self, scope_name: str, env_info: EnvironmentInfo, algorithm_config: ICML2015TRPOConfig, **kwargs + ) -> StochasticPolicy: return ICML2015TRPOMujocoPolicy(scope_name, env_info.action_dim) - def _build_default_discrete_policy(self, - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: ICML2015TRPOConfig, - **kwargs) -> StochasticPolicy: + def _build_default_discrete_policy( + self, scope_name: str, env_info: EnvironmentInfo, algorithm_config: ICML2015TRPOConfig, **kwargs + ) -> StochasticPolicy: return ICML2015TRPOAtariPolicy(scope_name, env_info.action_dim) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: ICML2015TRPOConfig, - algorithm: "ICML2015TRPO", - **kwargs) -> EnvironmentExplorer: - explorer_config = EE.RawPolicyExplorerConfig(initial_step_num=algorithm.iteration_num, - timelimit_as_terminal=False) - explorer = EE.RawPolicyExplorer(policy_action_selector=algorithm._exploration_action_selector, - env_info=env_info, - config=explorer_config) + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: ICML2015TRPOConfig, + algorithm: "ICML2015TRPO", + **kwargs, + ) -> EnvironmentExplorer: + explorer_config = EE.RawPolicyExplorerConfig( + initial_step_num=algorithm.iteration_num, timelimit_as_terminal=False + ) + explorer = EE.RawPolicyExplorer( + policy_action_selector=algorithm._exploration_action_selector, env_info=env_info, config=explorer_config + ) return explorer @@ -150,10 +152,13 @@ class ICML2015TRPO(Algorithm): _evaluation_actor: _StochasticPolicyActionSelector _exploration_actor: _StochasticPolicyActionSelector - def __init__(self, 
env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: ICML2015TRPOConfig = ICML2015TRPOConfig(), - policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: ICML2015TRPOConfig = ICML2015TRPOConfig(), + policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): super(ICML2015TRPO, self).__init__(env_or_env_info, config=config) self._explorer_builder = explorer_builder @@ -162,9 +167,11 @@ def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], self._policy = policy_builder("pi", self._env_info, self._config) self._evaluation_actor = _StochasticPolicyActionSelector( - self._env_info, self._policy.shallowcopy(), deterministic=False) + self._env_info, self._policy.shallowcopy(), deterministic=False + ) self._exploration_actor = _StochasticPolicyActionSelector( - self._env_info, self._policy.shallowcopy(), deterministic=False) + self._env_info, self._policy.shallowcopy(), deterministic=False + ) @eval_api def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): @@ -187,11 +194,11 @@ def _setup_policy_training(self, env_or_buffer): sigma_kl_divergence_constraint=self._config.sigma_kl_divergence_constraint, maximum_backtrack_numbers=self._config.maximum_backtrack_numbers, conjugate_gradient_damping=self._config.conjugate_gradient_damping, - conjugate_gradient_iterations=self._config.conjugate_gradient_iterations) + conjugate_gradient_iterations=self._config.conjugate_gradient_iterations, + ) policy_trainer = MT.policy_trainers.TRPOPolicyTrainer( - model=self._policy, - env_info=self._env_info, - config=policy_trainer_config) + model=self._policy, env_info=self._env_info, config=policy_trainer_config + ) return policy_trainer @@ -217,11 +224,8 @@ def _trpo_training(self, buffer): s, a, accumulated_reward = self._align_experiences(buffer_iterator) extra = {} - extra['advantage'] = accumulated_reward # Use accumulated_reward as advantage - batch = TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - extra=extra) + extra["advantage"] = accumulated_reward # Use accumulated_reward as advantage + batch = TrainingBatch(batch_size=self._config.batch_size, s_current=s, a_current=a, extra=extra) self._policy_trainer_state = self._policy_trainer.train(batch) @@ -243,25 +247,25 @@ def _align_experiences(self, buffer_iterator): a_batch = np.concatenate(a_batch, axis=0) accumulated_reward_batch = np.concatenate(accumulated_reward_batch, axis=0) - return s_batch[:self._config.num_steps_per_iteration], \ - a_batch[:self._config.num_steps_per_iteration], \ - accumulated_reward_batch[:self._config.num_steps_per_iteration] + return ( + s_batch[: self._config.num_steps_per_iteration], + a_batch[: self._config.num_steps_per_iteration], + accumulated_reward_batch[: self._config.num_steps_per_iteration], + ) def _compute_accumulated_reward(self, reward_sequence, gamma): if not reward_sequence.ndim == 1: raise ValueError("Invalid reward_sequence dimension") episode_length = len(reward_sequence) - gamma_seq = np.array( - [gamma**i for i in range(episode_length)]) + gamma_seq = np.array([gamma**i for i in range(episode_length)]) - left_justified_gamma_seqs = np.tril( - np.tile(gamma_seq, (episode_length, 1)), k=0)[::-1] - mask = left_justified_gamma_seqs != 0. 
+ left_justified_gamma_seqs = np.tril(np.tile(gamma_seq, (episode_length, 1)), k=0)[::-1] + mask = left_justified_gamma_seqs != 0.0 gamma_seqs = np.zeros((episode_length, episode_length)) gamma_seqs[np.triu_indices(episode_length)] = left_justified_gamma_seqs[mask] - return np.sum(reward_sequence*gamma_seqs, axis=1, keepdims=True) + return np.sum(reward_sequence * gamma_seqs, axis=1, keepdims=True) def _evaluation_action_selector(self, s, *, begin_of_episode=False): return self._evaluation_actor(s, begin_of_episode=begin_of_episode) @@ -279,8 +283,9 @@ def _solvers(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_tuple_action_env() @property diff --git a/nnabla_rl/algorithms/icml2018_sac.py b/nnabla_rl/algorithms/icml2018_sac.py index be37067c..58826c50 100644 --- a/nnabla_rl/algorithms/icml2018_sac.py +++ b/nnabla_rl/algorithms/icml2018_sac.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -86,7 +86,7 @@ class ICML2018SACConfig(AlgorithmConfig): """ gamma: float = 0.99 - learning_rate: float = 3.0*1e-4 + learning_rate: float = 3.0 * 1e-4 batch_size: int = 256 tau: float = 0.005 environment_steps: int = 1 @@ -115,82 +115,88 @@ def __post_init__(self): Check the values are in valid range. """ - self._assert_between(self.tau, 0.0, 1.0, 'tau') - self._assert_between(self.gamma, 0.0, 1.0, 'gamma') - self._assert_positive(self.gradient_steps, 'gradient_steps') - self._assert_positive(self.environment_steps, 'environment_steps') - self._assert_positive(self.start_timesteps, 'start_timesteps') - self._assert_positive(self.target_update_interval, 'target_update_interval') - self._assert_positive(self.num_steps, 'num_steps') - - self._assert_positive(self.pi_unroll_steps, 'pi_unroll_steps') - self._assert_positive_or_zero(self.pi_burn_in_steps, 'pi_burn_in_steps') - self._assert_positive(self.q_unroll_steps, 'q_unroll_steps') - self._assert_positive_or_zero(self.q_burn_in_steps, 'q_burn_in_steps') - self._assert_positive(self.v_unroll_steps, 'v_unroll_steps') - self._assert_positive_or_zero(self.v_burn_in_steps, 'v_burn_in_steps') + self._assert_between(self.tau, 0.0, 1.0, "tau") + self._assert_between(self.gamma, 0.0, 1.0, "gamma") + self._assert_positive(self.gradient_steps, "gradient_steps") + self._assert_positive(self.environment_steps, "environment_steps") + self._assert_positive(self.start_timesteps, "start_timesteps") + self._assert_positive(self.target_update_interval, "target_update_interval") + self._assert_positive(self.num_steps, "num_steps") + + self._assert_positive(self.pi_unroll_steps, "pi_unroll_steps") + self._assert_positive_or_zero(self.pi_burn_in_steps, "pi_burn_in_steps") + self._assert_positive(self.q_unroll_steps, "q_unroll_steps") + self._assert_positive_or_zero(self.q_burn_in_steps, "q_burn_in_steps") + self._assert_positive(self.v_unroll_steps, "v_unroll_steps") + self._assert_positive_or_zero(self.v_burn_in_steps, "v_burn_in_steps") class DefaultVFunctionBuilder(ModelBuilder[VFunction]): - def build_model(self, # type: ignore[override] - 
scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: ICML2018SACConfig, - **kwargs) -> VFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: ICML2018SACConfig, + **kwargs, + ) -> VFunction: return SACVFunction(scope_name) class DefaultQFunctionBuilder(ModelBuilder[QFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: ICML2018SACConfig, - **kwargs) -> QFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: ICML2018SACConfig, + **kwargs, + ) -> QFunction: return SACQFunction(scope_name) class DefaultPolicyBuilder(ModelBuilder[StochasticPolicy]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: ICML2018SACConfig, - **kwargs) -> StochasticPolicy: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: ICML2018SACConfig, + **kwargs, + ) -> StochasticPolicy: return SACPolicy(scope_name, env_info.action_dim) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: ICML2018SACConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: ICML2018SACConfig, **kwargs + ) -> nn.solver.Solver: assert isinstance(algorithm_config, ICML2018SACConfig) return NS.Adam(alpha=algorithm_config.learning_rate) class DefaultReplayBufferBuilder(ReplayBufferBuilder): - def build_replay_buffer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: ICML2018SACConfig, - **kwargs) -> ReplayBuffer: + def build_replay_buffer( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: ICML2018SACConfig, **kwargs + ) -> ReplayBuffer: assert isinstance(algorithm_config, ICML2018SACConfig) return ReplayBuffer(capacity=algorithm_config.replay_buffer_size) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: ICML2018SACConfig, - algorithm: "ICML2018SAC", - **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: ICML2018SACConfig, + algorithm: "ICML2018SAC", + **kwargs, + ) -> EnvironmentExplorer: explorer_config = EE.RawPolicyExplorerConfig( warmup_random_steps=algorithm_config.start_timesteps, reward_scalar=algorithm_config.reward_scalar, initial_step_num=algorithm.iteration_num, - timelimit_as_terminal=False + timelimit_as_terminal=False, + ) + explorer = EE.RawPolicyExplorer( + policy_action_selector=algorithm._exploration_action_selector, env_info=env_info, config=explorer_config ) - explorer = EE.RawPolicyExplorer(policy_action_selector=algorithm._exploration_action_selector, - env_info=env_info, - config=explorer_config) return explorer @@ -256,16 +262,19 @@ class ICML2018SAC(Algorithm): _q_function_trainer_state: Dict[str, Any] _v_function_trainer_state: Dict[str, Any] - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: ICML2018SACConfig = ICML2018SACConfig(), - v_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), - v_solver_builder: SolverBuilder = DefaultSolverBuilder(), - q_function_builder: ModelBuilder[QFunction] = 
DefaultQFunctionBuilder(), - q_solver_builder: SolverBuilder = DefaultSolverBuilder(), - policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: ICML2018SACConfig = ICML2018SACConfig(), + v_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), + v_solver_builder: SolverBuilder = DefaultSolverBuilder(), + q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), + q_solver_builder: SolverBuilder = DefaultSolverBuilder(), + policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): super(ICML2018SAC, self).__init__(env_or_env_info, config=config) self._explorer_builder = explorer_builder @@ -273,7 +282,7 @@ def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): self._v = v_function_builder(scope_name="v", env_info=self._env_info, algorithm_config=self._config) self._v_solver = v_solver_builder(env_info=self._env_info, algorithm_config=self._config) - self._target_v = self._v.deepcopy('target_' + self._v.scope_name) + self._target_v = self._v.deepcopy("target_" + self._v.scope_name) self._q1 = q_function_builder(scope_name="q1", env_info=self._env_info, algorithm_config=self._config) self._q2 = q_function_builder(scope_name="q2", env_info=self._env_info, algorithm_config=self._config) @@ -282,7 +291,8 @@ def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], self._train_q_solvers = {} for q in self._train_q_functions: self._train_q_solvers[q.scope_name] = q_solver_builder( - env_info=self._env_info, algorithm_config=self._config) + env_info=self._env_info, algorithm_config=self._config + ) self._pi = policy_builder(scope_name="pi", env_info=self._env_info, algorithm_config=self._config) self._pi_solver = policy_solver_builder(env_info=self._env_info, algorithm_config=self._config) @@ -290,9 +300,11 @@ def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], self._replay_buffer = replay_buffer_builder(env_info=self._env_info, algorithm_config=self._config) self._evaluation_actor = _StochasticPolicyActionSelector( - self._env_info, self._pi.shallowcopy(), deterministic=True) + self._env_info, self._pi.shallowcopy(), deterministic=True + ) self._exploration_actor = _StochasticPolicyActionSelector( - self._env_info, self._pi.shallowcopy(), deterministic=False) + self._env_info, self._pi.shallowcopy(), deterministic=False + ) @eval_api def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): @@ -317,10 +329,11 @@ def _setup_policy_training(self, env_or_buffer): fixed_temperature=True, unroll_steps=self._config.pi_unroll_steps, burn_in_steps=self._config.pi_burn_in_steps, - reset_on_terminal=self._config.pi_reset_rnn_on_terminal) + reset_on_terminal=self._config.pi_reset_rnn_on_terminal, + ) temperature = MT.policy_trainers.soft_policy_trainer.AdjustableTemperature( - scope_name='temperature', - initial_value=1.0) + scope_name="temperature", initial_value=1.0 + ) policy_trainer = 
MT.policy_trainers.SoftPolicyTrainer( models=self._pi, solvers={self._pi.scope_name: self._pi_solver}, @@ -328,39 +341,44 @@ def _setup_policy_training(self, env_or_buffer): temperature=temperature, temperature_solver=None, env_info=self._env_info, - config=policy_trainer_config) + config=policy_trainer_config, + ) return policy_trainer def _setup_q_function_training(self, env_or_buffer): q_function_trainer_param = MT.q_value_trainers.VTargetedQTrainerConfig( - reduction_method='mean', + reduction_method="mean", q_loss_scalar=0.5, num_steps=self._config.num_steps, unroll_steps=self._config.q_unroll_steps, burn_in_steps=self._config.q_burn_in_steps, - reset_on_terminal=self._config.q_reset_rnn_on_terminal) + reset_on_terminal=self._config.q_reset_rnn_on_terminal, + ) q_function_trainer = MT.q_value_trainers.VTargetedQTrainer( train_functions=self._train_q_functions, solvers=self._train_q_solvers, target_functions=self._target_v, env_info=self._env_info, - config=q_function_trainer_param) + config=q_function_trainer_param, + ) return q_function_trainer def _setup_v_function_training(self, env_or_buffer): v_function_trainer_config = MT.v_value_trainers.SoftVTrainerConfig( - reduction_method='mean', + reduction_method="mean", v_loss_scalar=0.5, unroll_steps=self._config.v_unroll_steps, burn_in_steps=self._config.v_burn_in_steps, - reset_on_terminal=self._config.v_reset_rnn_on_terminal) + reset_on_terminal=self._config.v_reset_rnn_on_terminal, + ) v_function_trainer = MT.v_value_trainers.SoftVTrainer( train_functions=self._v, solvers={self._v.scope_name: self._v_solver}, target_functions=self._train_q_functions, # Set training q as target target_policy=self._pi, env_info=self._env_info, - config=v_function_trainer_config) + config=v_function_trainer_config, + ) sync_model(self._v, self._target_v, 1.0) return v_function_trainer @@ -389,23 +407,25 @@ def _sac_training(self, replay_buffer): num_steps = max(pi_steps, max(q_steps, v_steps)) experiences_tuple, info = replay_buffer.sample(self._config.batch_size, num_steps=num_steps) if num_steps == 1: - experiences_tuple = (experiences_tuple, ) + experiences_tuple = (experiences_tuple,) assert len(experiences_tuple) == num_steps batch = None for experiences in reversed(experiences_tuple): (s, a, r, non_terminal, s_next, rnn_states_dict, *_) = marshal_experiences(experiences) - rnn_states = rnn_states_dict['rnn_states'] if 'rnn_states' in rnn_states_dict else {} - batch = TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - gamma=self._config.gamma, - reward=r, - non_terminal=non_terminal, - s_next=s_next, - weight=info['weights'], - next_step_batch=batch, - rnn_states=rnn_states) + rnn_states = rnn_states_dict["rnn_states"] if "rnn_states" in rnn_states_dict else {} + batch = TrainingBatch( + batch_size=self._config.batch_size, + s_current=s, + a_current=a, + gamma=self._config.gamma, + reward=r, + non_terminal=non_terminal, + s_next=s_next, + weight=info["weights"], + next_step_batch=batch, + rnn_states=rnn_states, + ) # Train in the order of v -> q -> policy self._v_function_trainer_state = self._v_function_trainer.train(batch) @@ -414,7 +434,7 @@ def _sac_training(self, replay_buffer): sync_model(self._v, self._target_v, tau=self._config.tau) self._policy_trainer_state = self._policy_trainer.train(batch) - td_errors = self._q_function_trainer_state['td_errors'] + td_errors = self._q_function_trainer_state["td_errors"] replay_buffer.update_priorities(td_errors) def _evaluation_action_selector(self, s, *, 
begin_of_episode=False): @@ -440,21 +460,23 @@ def is_rnn_supported(cls): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env() @property def latest_iteration_state(self): latest_iteration_state = super(ICML2018SAC, self).latest_iteration_state - if hasattr(self, '_policy_trainer_state'): - latest_iteration_state['scalar'].update({'pi_loss': float(self._policy_trainer_state['pi_loss'])}) - if hasattr(self, '_v_function_trainer_state'): - latest_iteration_state['scalar'].update({'v_loss': float(self._v_function_trainer_state['v_loss'])}) - if hasattr(self, '_q_function_trainer_state'): - latest_iteration_state['scalar'].update({'q_loss': float(self._q_function_trainer_state['q_loss'])}) - latest_iteration_state['histogram'].update( - {'td_errors': self._q_function_trainer_state['td_errors'].flatten()}) + if hasattr(self, "_policy_trainer_state"): + latest_iteration_state["scalar"].update({"pi_loss": float(self._policy_trainer_state["pi_loss"])}) + if hasattr(self, "_v_function_trainer_state"): + latest_iteration_state["scalar"].update({"v_loss": float(self._v_function_trainer_state["v_loss"])}) + if hasattr(self, "_q_function_trainer_state"): + latest_iteration_state["scalar"].update({"q_loss": float(self._q_function_trainer_state["q_loss"])}) + latest_iteration_state["histogram"].update( + {"td_errors": self._q_function_trainer_state["td_errors"].flatten()} + ) return latest_iteration_state @property diff --git a/nnabla_rl/algorithms/icra2018_qtopt.py b/nnabla_rl/algorithms/icra2018_qtopt.py index 63e18a28..4648c1b0 100644 --- a/nnabla_rl/algorithms/icra2018_qtopt.py +++ b/nnabla_rl/algorithms/icra2018_qtopt.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -84,6 +84,7 @@ class ICRA2018QtOptConfig(DDQNConfig): random_sample_size (int): number of candidates at the sampling step of random shooting method. Defaults to 16. """ + gamma: float = 0.9 learning_rate: float = 0.001 batch_size: int = 64 @@ -107,82 +108,87 @@ def __post_init__(self): Check set values are in valid range. 
""" - self._assert_between(self.gamma, 0.0, 1.0, 'gamma') - self._assert_positive(self.learning_rate, 'learning_rate') - self._assert_positive(self.batch_size, 'batch_size') - self._assert_positive(self.num_steps, 'num_steps') - self._assert_positive(self.q_loss_scalar, 'q_loss_scalar') - self._assert_positive(self.learner_update_frequency, 'learner_update_frequency') - self._assert_positive(self.target_update_frequency, 'target_update_frequency') - self._assert_positive(self.start_timesteps, 'start_timesteps') - self._assert_positive(self.replay_buffer_size, 'replay_buffer_size') - self._assert_smaller_than(self.start_timesteps, self.replay_buffer_size, 'start_timesteps') - self._assert_between(self.initial_epsilon, 0.0, 1.0, 'initial_epsilon') - self._assert_between(self.final_epsilon, 0.0, 1.0, 'final_epsilon') - self._assert_between(self.test_epsilon, 0.0, 1.0, 'test_epsilon') - self._assert_positive(self.max_explore_steps, 'max_explore_steps') - self._assert_positive(self.unroll_steps, 'unroll_steps') - self._assert_positive_or_zero(self.burn_in_steps, 'burn_in_steps') - self._assert_positive(self.cem_sample_size, 'cem_sample_size') - self._assert_positive(self.cem_num_elites, 'cem_num_elites') + self._assert_between(self.gamma, 0.0, 1.0, "gamma") + self._assert_positive(self.learning_rate, "learning_rate") + self._assert_positive(self.batch_size, "batch_size") + self._assert_positive(self.num_steps, "num_steps") + self._assert_positive(self.q_loss_scalar, "q_loss_scalar") + self._assert_positive(self.learner_update_frequency, "learner_update_frequency") + self._assert_positive(self.target_update_frequency, "target_update_frequency") + self._assert_positive(self.start_timesteps, "start_timesteps") + self._assert_positive(self.replay_buffer_size, "replay_buffer_size") + self._assert_smaller_than(self.start_timesteps, self.replay_buffer_size, "start_timesteps") + self._assert_between(self.initial_epsilon, 0.0, 1.0, "initial_epsilon") + self._assert_between(self.final_epsilon, 0.0, 1.0, "final_epsilon") + self._assert_between(self.test_epsilon, 0.0, 1.0, "test_epsilon") + self._assert_positive(self.max_explore_steps, "max_explore_steps") + self._assert_positive(self.unroll_steps, "unroll_steps") + self._assert_positive_or_zero(self.burn_in_steps, "burn_in_steps") + self._assert_positive(self.cem_sample_size, "cem_sample_size") + self._assert_positive(self.cem_num_elites, "cem_num_elites") self._assert_positive_or_zero(self.cem_alpha, "cem_alpha") - self._assert_positive(self.cem_num_iterations, 'cem_num_iterations') - self._assert_positive(self.random_sample_size, 'random_sample_size') + self._assert_positive(self.cem_num_iterations, "cem_num_iterations") + self._assert_positive(self.random_sample_size, "random_sample_size") class DefaultQFunctionBuilder(ModelBuilder[QFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: ICRA2018QtOptConfig, - **kwargs) -> QFunction: - return ICRA2018QtOptQFunction(scope_name, - env_info.action_dim, - action_high=env_info.action_high, - action_low=env_info.action_low, - cem_initial_mean=algorithm_config.cem_initial_mean, - cem_initial_variance=algorithm_config.cem_initial_variance, - cem_sample_size=algorithm_config.cem_sample_size, - cem_num_elites=algorithm_config.cem_num_elites, - cem_num_iterations=algorithm_config.cem_num_iterations, - cem_alpha=algorithm_config.cem_alpha, - random_sample_size=algorithm_config.random_sample_size) + def build_model( # type: ignore[override] + 
self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: ICRA2018QtOptConfig, + **kwargs, + ) -> QFunction: + return ICRA2018QtOptQFunction( + scope_name, + env_info.action_dim, + action_high=env_info.action_high, + action_low=env_info.action_low, + cem_initial_mean=algorithm_config.cem_initial_mean, + cem_initial_variance=algorithm_config.cem_initial_variance, + cem_sample_size=algorithm_config.cem_sample_size, + cem_num_elites=algorithm_config.cem_num_elites, + cem_num_iterations=algorithm_config.cem_num_iterations, + cem_alpha=algorithm_config.cem_alpha, + random_sample_size=algorithm_config.random_sample_size, + ) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: ICRA2018QtOptConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: ICRA2018QtOptConfig, **kwargs + ) -> nn.solver.Solver: return NS.Adam(alpha=algorithm_config.learning_rate) class DefaultReplayBufferBuilder(ReplayBufferBuilder): - def build_replay_buffer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: ICRA2018QtOptConfig, - **kwargs) -> ReplayBuffer: + def build_replay_buffer( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: ICRA2018QtOptConfig, **kwargs + ) -> ReplayBuffer: return ReplayBuffer(capacity=algorithm_config.replay_buffer_size) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: ICRA2018QtOptConfig, - algorithm: "ICRA2018QtOpt", - **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: ICRA2018QtOptConfig, + algorithm: "ICRA2018QtOpt", + **kwargs, + ) -> EnvironmentExplorer: explorer_config = EE.LinearDecayEpsilonGreedyExplorerConfig( warmup_random_steps=algorithm_config.start_timesteps, initial_step_num=algorithm.iteration_num, initial_epsilon=algorithm_config.initial_epsilon, final_epsilon=algorithm_config.final_epsilon, - max_explore_steps=algorithm_config.max_explore_steps + max_explore_steps=algorithm_config.max_explore_steps, ) explorer = EE.LinearDecayEpsilonGreedyExplorer( greedy_action_selector=algorithm._exploration_action_selector, random_action_selector=algorithm._random_action_selector, env_info=env_info, - config=explorer_config) + config=explorer_config, + ) return explorer @@ -227,33 +233,42 @@ class ICRA2018QtOpt(DDQN): _evaluation_actor: _GreedyActionSelector _exploration_actor: _GreedyActionSelector - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: ICRA2018QtOptConfig = ICRA2018QtOptConfig(), - q_func_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), - q_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): - super(ICRA2018QtOpt, self).__init__(env_or_env_info=env_or_env_info, - config=config, - q_func_builder=q_func_builder, - q_solver_builder=q_solver_builder, - replay_buffer_builder=replay_buffer_builder, - explorer_builder=explorer_builder) + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: ICRA2018QtOptConfig = ICRA2018QtOptConfig(), + q_func_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), + q_solver_builder: SolverBuilder = 
DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): + super(ICRA2018QtOpt, self).__init__( + env_or_env_info=env_or_env_info, + config=config, + q_func_builder=q_func_builder, + q_solver_builder=q_solver_builder, + replay_buffer_builder=replay_buffer_builder, + explorer_builder=explorer_builder, + ) def _setup_q_function_training(self, env_or_buffer): - trainer_config = MT.q_value_trainers.DDQNQTrainerConfig(num_steps=self._config.num_steps, - q_loss_scalar=self._config.q_loss_scalar, - reduction_method='sum', - grad_clip=self._config.grad_clip, - unroll_steps=self._config.unroll_steps, - burn_in_steps=self._config.burn_in_steps, - reset_on_terminal=self._config.reset_rnn_on_terminal) + trainer_config = MT.q_value_trainers.DDQNQTrainerConfig( + num_steps=self._config.num_steps, + q_loss_scalar=self._config.q_loss_scalar, + reduction_method="sum", + grad_clip=self._config.grad_clip, + unroll_steps=self._config.unroll_steps, + burn_in_steps=self._config.burn_in_steps, + reset_on_terminal=self._config.reset_rnn_on_terminal, + ) - q_function_trainer = MT.q_value_trainers.DDQNQTrainer(train_function=self._q, - solvers={self._q.scope_name: self._q_solver}, - target_function=self._target_q, - env_info=self._env_info, - config=trainer_config) + q_function_trainer = MT.q_value_trainers.DDQNQTrainer( + train_function=self._q, + solvers={self._q.scope_name: self._q_solver}, + target_function=self._target_q, + env_info=self._env_info, + config=trainer_config, + ) sync_model(self._q, self._target_q) return q_function_trainer @@ -262,6 +277,7 @@ def _random_action_selector(self, s, *, begin_of_episode=False): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return env_info.is_continuous_action_env() and not env_info.is_tuple_action_env() diff --git a/nnabla_rl/algorithms/ilqr.py b/nnabla_rl/algorithms/ilqr.py index 4e332cc3..845a89f5 100644 --- a/nnabla_rl/algorithms/ilqr.py +++ b/nnabla_rl/algorithms/ilqr.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -46,13 +46,10 @@ class iLQR(DDP): config (:py:class:`iLQRConfig `): the parameter for iLQR controller """ + _config: iLQRConfig - def __init__(self, - env_or_env_info, - dynamics: Dynamics, - cost_function: CostFunction, - config=iLQRConfig()): + def __init__(self, env_or_env_info, dynamics: Dynamics, cost_function: CostFunction, config=iLQRConfig()): super(iLQR, self).__init__(env_or_env_info, dynamics, cost_function, config=config) def _backward_pass(self, trajectory, dynamics, cost_function, mu): diff --git a/nnabla_rl/algorithms/iqn.py b/nnabla_rl/algorithms/iqn.py index 904a3e00..76fddd1b 100644 --- a/nnabla_rl/algorithms/iqn.py +++ b/nnabla_rl/algorithms/iqn.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
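The QtOpt hunks above only reflow the `cem_*` and `random_sample_size` configuration values and the builders around `ICRA2018QtOptQFunction`; the behavior is unchanged. As a reading aid, here is a minimal NumPy sketch of the kind of cross-entropy search those parameters configure, namely approximately maximizing a Q-function over a box-constrained continuous action space. It is not nnabla-rl code: `cem_argmax_q`, the `q_values` callable, and the way `alpha` smooths the sampling distribution toward its previous estimate are assumptions made for illustration only.

```python
import numpy as np


def cem_argmax_q(q_values, action_low, action_high, initial_mean, initial_variance,
                 sample_size=64, num_elites=10, num_iterations=2, alpha=0.0, rng=None):
    """Approximate argmax_a q_values(a) over the box [action_low, action_high] via CEM.

    `q_values` maps a (sample_size, action_dim) array of candidate actions to a
    (sample_size,) array of Q-values.
    """
    rng = np.random.default_rng() if rng is None else rng
    mean = np.array(initial_mean, dtype=float)
    var = np.array(initial_variance, dtype=float)
    for _ in range(num_iterations):
        # Sample candidate actions from a diagonal Gaussian, clipped to the action bounds.
        samples = rng.normal(mean, np.sqrt(var), size=(sample_size, mean.shape[0]))
        samples = np.clip(samples, action_low, action_high)
        # Keep the highest-scoring candidates and refit the sampling distribution.
        elites = samples[np.argsort(q_values(samples))[-num_elites:]]
        mean = alpha * mean + (1.0 - alpha) * elites.mean(axis=0)
        var = alpha * var + (1.0 - alpha) * elites.var(axis=0)
    return mean


# Toy usage: the "Q-function" is a concave quadratic peaking at a = (0.3, 0.3).
best_action = cem_argmax_q(
    q_values=lambda a: -np.sum((a - 0.3) ** 2, axis=1),
    action_low=-1.0,
    action_high=1.0,
    initial_mean=np.zeros(2),
    initial_variance=np.ones(2),
)
```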
@@ -78,6 +78,7 @@ class IQNConfig(AlgorithmConfig): This flag does not take effect if given model is not an RNN model. Defaults to True. """ + gamma: float = 0.99 learning_rate: float = 0.00005 batch_size: int = 32 @@ -106,24 +107,24 @@ def __post_init__(self): Check that set values are in valid range. """ - self._assert_between(self.gamma, 0.0, 1.0, 'gamma') - self._assert_positive(self.batch_size, 'batch_size') - self._assert_positive(self.replay_buffer_size, 'replay_buffer_size') - self._assert_positive(self.num_steps, 'num_steps') - self._assert_positive(self.learner_update_frequency, 'learner_update_frequency') - self._assert_positive(self.target_update_frequency, 'target_update_frequency') - self._assert_positive(self.max_explore_steps, 'max_explore_steps') - self._assert_positive(self.learning_rate, 'learning_rate') - self._assert_positive(self.initial_epsilon, 'initial_epsilon') - self._assert_positive(self.final_epsilon, 'final_epsilon') - self._assert_positive(self.test_epsilon, 'test_epsilon') - self._assert_positive(self.N, 'N') - self._assert_positive(self.N_prime, 'N_prime') - self._assert_positive(self.K, 'K') - self._assert_positive(self.kappa, 'kappa') - self._assert_positive(self.embedding_dim, 'embedding_dim') - self._assert_positive(self.unroll_steps, 'unroll_steps') - self._assert_positive_or_zero(self.burn_in_steps, 'burn_in_steps') + self._assert_between(self.gamma, 0.0, 1.0, "gamma") + self._assert_positive(self.batch_size, "batch_size") + self._assert_positive(self.replay_buffer_size, "replay_buffer_size") + self._assert_positive(self.num_steps, "num_steps") + self._assert_positive(self.learner_update_frequency, "learner_update_frequency") + self._assert_positive(self.target_update_frequency, "target_update_frequency") + self._assert_positive(self.max_explore_steps, "max_explore_steps") + self._assert_positive(self.learning_rate, "learning_rate") + self._assert_positive(self.initial_epsilon, "initial_epsilon") + self._assert_positive(self.final_epsilon, "final_epsilon") + self._assert_positive(self.test_epsilon, "test_epsilon") + self._assert_positive(self.N, "N") + self._assert_positive(self.N_prime, "N_prime") + self._assert_positive(self.K, "K") + self._assert_positive(self.kappa, "kappa") + self._assert_positive(self.embedding_dim, "embedding_dim") + self._assert_positive(self.unroll_steps, "unroll_steps") + self._assert_positive_or_zero(self.burn_in_steps, "burn_in_steps") def risk_neutral_measure(tau): @@ -131,56 +132,61 @@ def risk_neutral_measure(tau): class DefaultQuantileFunctionBuilder(ModelBuilder[StateActionQuantileFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: IQNConfig, - **kwargs) -> StateActionQuantileFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: IQNConfig, + **kwargs, + ) -> StateActionQuantileFunction: assert isinstance(algorithm_config, IQNConfig) - risk_measure_function = kwargs['risk_measure_function'] - return IQNQuantileFunction(scope_name, - env_info.action_dim, - algorithm_config.embedding_dim, - K=algorithm_config.K, - risk_measure_function=risk_measure_function) + risk_measure_function = kwargs["risk_measure_function"] + return IQNQuantileFunction( + scope_name, + env_info.action_dim, + algorithm_config.embedding_dim, + K=algorithm_config.K, + risk_measure_function=risk_measure_function, + ) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: 
ignore[override] - env_info: EnvironmentInfo, - algorithm_config: IQNConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: IQNConfig, **kwargs + ) -> nn.solver.Solver: assert isinstance(algorithm_config, IQNConfig) return NS.Adam(alpha=algorithm_config.learning_rate, eps=1e-2 / algorithm_config.batch_size) class DefaultReplayBufferBuilder(ReplayBufferBuilder): - def build_replay_buffer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: IQNConfig, - **kwargs) -> ReplayBuffer: + def build_replay_buffer( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: IQNConfig, **kwargs + ) -> ReplayBuffer: assert isinstance(algorithm_config, IQNConfig) return ReplayBuffer(capacity=algorithm_config.replay_buffer_size) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: IQNConfig, - algorithm: "IQN", - **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: IQNConfig, + algorithm: "IQN", + **kwargs, + ) -> EnvironmentExplorer: explorer_config = EE.LinearDecayEpsilonGreedyExplorerConfig( warmup_random_steps=algorithm_config.start_timesteps, initial_step_num=algorithm.iteration_num, initial_epsilon=algorithm_config.initial_epsilon, final_epsilon=algorithm_config.final_epsilon, - max_explore_steps=algorithm_config.max_explore_steps + max_explore_steps=algorithm_config.max_explore_steps, ) explorer = EE.LinearDecayEpsilonGreedyExplorer( greedy_action_selector=algorithm._exploration_action_selector, random_action_selector=algorithm._random_action_selector, env_info=env_info, - config=explorer_config) + config=explorer_config, + ) return explorer @@ -221,42 +227,49 @@ class IQN(Algorithm): _quantile_function_trainer_state: Dict[str, Any] - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: IQNConfig = IQNConfig(), - quantile_function_builder: ModelBuilder[StateActionQuantileFunction] - = DefaultQuantileFunctionBuilder(), - quantile_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), - risk_measure_function=risk_neutral_measure): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: IQNConfig = IQNConfig(), + quantile_function_builder: ModelBuilder[StateActionQuantileFunction] = DefaultQuantileFunctionBuilder(), + quantile_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + risk_measure_function=risk_neutral_measure, + ): super(IQN, self).__init__(env_or_env_info, config=config) self._explorer_builder = explorer_builder with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): kwargs = {} - kwargs['risk_measure_function'] = risk_measure_function + kwargs["risk_measure_function"] = risk_measure_function self._quantile_function = quantile_function_builder( - 'quantile_function', self._env_info, self._config, **kwargs) - self._target_quantile_function = self._quantile_function.deepcopy('target_quantile_function') + "quantile_function", self._env_info, self._config, **kwargs + ) + self._target_quantile_function = 
self._quantile_function.deepcopy("target_quantile_function") self._quantile_function_solver = quantile_solver_builder(self._env_info, self._config) self._replay_buffer = replay_buffer_builder(self._env_info, self._config) self._evaluation_actor = _GreedyActionSelector( - self._env_info, self._quantile_function.shallowcopy().as_q_function()) + self._env_info, self._quantile_function.shallowcopy().as_q_function() + ) self._exploration_actor = _GreedyActionSelector( - self._env_info, self._quantile_function.shallowcopy().as_q_function()) + self._env_info, self._quantile_function.shallowcopy().as_q_function() + ) @eval_api def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): - (action, _), _ = epsilon_greedy_action_selection(state, - self._evaluation_action_selector, - self._random_action_selector, - epsilon=self._config.test_epsilon, - begin_of_episode=begin_of_episode) + (action, _), _ = epsilon_greedy_action_selection( + state, + self._evaluation_action_selector, + self._random_action_selector, + epsilon=self._config.test_epsilon, + begin_of_episode=begin_of_episode, + ) return action def _before_training_start(self, env_or_buffer): @@ -277,14 +290,16 @@ def _setup_quantile_function_training(self, env_or_buffer): kappa=self._config.kappa, unroll_steps=self._config.unroll_steps, burn_in_steps=self._config.burn_in_steps, - reset_on_terminal=self._config.reset_rnn_on_terminal) + reset_on_terminal=self._config.reset_rnn_on_terminal, + ) quantile_function_trainer = MT.q_value_trainers.IQNQTrainer( train_functions=self._quantile_function, solvers={self._quantile_function.scope_name: self._quantile_function_solver}, target_function=self._target_quantile_function, env_info=self._env_info, - config=trainer_config) + config=trainer_config, + ) # NOTE: Copy initial parameters after setting up the training # Because the parameter is created after training graph construction @@ -306,23 +321,25 @@ def _iqn_training(self, replay_buffer): num_steps = self._config.num_steps + self._config.burn_in_steps + self._config.unroll_steps - 1 experiences_tuple, info = replay_buffer.sample(self._config.batch_size, num_steps=num_steps) if num_steps == 1: - experiences_tuple = (experiences_tuple, ) + experiences_tuple = (experiences_tuple,) assert len(experiences_tuple) == num_steps batch = None for experiences in reversed(experiences_tuple): (s, a, r, non_terminal, s_next, rnn_states_dict, *_) = marshal_experiences(experiences) - rnn_states = rnn_states_dict['rnn_states'] if 'rnn_states' in rnn_states_dict else {} - batch = TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - gamma=self._config.gamma, - reward=r, - non_terminal=non_terminal, - s_next=s_next, - weight=info['weights'], - next_step_batch=batch, - rnn_states=rnn_states) + rnn_states = rnn_states_dict["rnn_states"] if "rnn_states" in rnn_states_dict else {} + batch = TrainingBatch( + batch_size=self._config.batch_size, + s_current=s, + a_current=a, + gamma=self._config.gamma, + reward=r, + non_terminal=non_terminal, + s_next=s_next, + weight=info["weights"], + next_step_batch=batch, + rnn_states=rnn_states, + ) self._quantile_function_trainer_state = self._quantile_function_trainer.train(batch) if self.iteration_num % self._config.target_update_frequency: @@ -336,7 +353,7 @@ def _exploration_action_selector(self, s, *, begin_of_episode=False): def _random_action_selector(self, s, *, begin_of_episode=False): action = 
self._env_info.action_space.sample() - return np.asarray(action).reshape((1, )), {} + return np.asarray(action).reshape((1,)), {} def _models(self): models = {} @@ -350,8 +367,9 @@ def _solvers(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_continuous_action_env() and not env_info.is_tuple_action_env() @classmethod @@ -361,8 +379,8 @@ def is_rnn_supported(self): @property def latest_iteration_state(self): latest_iteration_state = super(IQN, self).latest_iteration_state - if hasattr(self, '_quantile_function_trainer_state'): - latest_iteration_state['scalar'].update({'q_loss': float(self._quantile_function_trainer_state['q_loss'])}) + if hasattr(self, "_quantile_function_trainer_state"): + latest_iteration_state["scalar"].update({"q_loss": float(self._quantile_function_trainer_state["q_loss"])}) return latest_iteration_state @property diff --git a/nnabla_rl/algorithms/lqr.py b/nnabla_rl/algorithms/lqr.py index 493df540..eab013ef 100644 --- a/nnabla_rl/algorithms/lqr.py +++ b/nnabla_rl/algorithms/lqr.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,12 +31,13 @@ class LQRConfig(AlgorithmConfig): Args: T_max (int): Planning time step length. Defaults to 50. """ + T_max: int = 50 def __post_init__(self): super().__post_init__() - self._assert_positive(self.T_max, 'T_max') + self._assert_positive(self.T_max, "T_max") class LQR(Algorithm): @@ -53,13 +54,10 @@ class LQR(Algorithm): config (:py:class:`LQRConfig `): the parameter for LQR controller """ + _config: LQRConfig - def __init__(self, - env_or_env_info, - dynamics: Dynamics, - cost_function: CostFunction, - config=LQRConfig()): + def __init__(self, env_or_env_info, dynamics: Dynamics, cost_function: CostFunction, config=LQRConfig()): super(LQR, self).__init__(env_or_env_info, config=config) self._dynamics = dynamics self._cost_function = cost_function @@ -76,9 +74,9 @@ def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): return improved_trajectory[0][1] @eval_api - def compute_trajectory(self, - initial_trajectory: Sequence[Tuple[np.ndarray, Optional[np.ndarray]]]) \ - -> Tuple[Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], Sequence[Dict[str, Any]]]: + def compute_trajectory( + self, initial_trajectory: Sequence[Tuple[np.ndarray, Optional[np.ndarray]]] + ) -> Tuple[Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], Sequence[Dict[str, Any]]]: assert len(initial_trajectory) == self._config.T_max dynamics = self._dynamics cost_function = self._cost_function @@ -93,12 +91,13 @@ def _compute_initial_trajectory(self, x0, dynamics, T, u): trajectory.append((x, None)) return trajectory - def _optimize(self, - initial_state: Union[np.ndarray, Sequence[Tuple[np.ndarray, Optional[np.ndarray]]]], - dynamics: Dynamics, - cost_function: CostFunction, - **kwargs) \ - -> Tuple[Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], Sequence[Dict[str, Any]]]: + def _optimize( + self, + initial_state: Union[np.ndarray, Sequence[Tuple[np.ndarray, Optional[np.ndarray]]]], + dynamics: Dynamics, + cost_function: CostFunction, + **kwargs, + ) -> 
Tuple[Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], Sequence[Dict[str, Any]]]: assert len(initial_state) == self._config.T_max initial_state = cast(Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], initial_state) x_last, u_last = initial_state[-1] @@ -114,7 +113,7 @@ def _optimize(self, assert F is not None assert R is not None C = np.linalg.inv(R + (B.T.dot(Sk).dot(B))) - D = (F.T + B.T.dot(Sk).dot(A)) + D = F.T + B.T.dot(Sk).dot(A) Sk = Q + A.T.dot(Sk).dot(A) - D.T.dot(C).dot(D) matrices.append((Sk, A, B, R, F)) @@ -125,7 +124,7 @@ def _optimize(self, u = self._compute_optimal_input(x, S, A, B, R, F) trajectory.append((x, u)) # Save quadratic cost coefficient R as Quu and R^-1 as Quu_inv - trajectory_info.append({'Quu': R, 'Quu_inv': np.linalg.inv(R)}) + trajectory_info.append({"Quu": R, "Quu_inv": np.linalg.inv(R)}) x, _ = dynamics.next_state(x, u, t) trajectory.append((x, None)) # final timestep input is None @@ -134,20 +133,20 @@ def _optimize(self, def _compute_optimal_input(self, x, S, A, B, R, F) -> np.ndarray: C = np.linalg.inv(R + (B.T.dot(S).dot(B))) - D = (F.T + B.T.dot(S).dot(A)) + D = F.T + B.T.dot(S).dot(A) return cast(np.ndarray, -C.dot(D).dot(x)) def _before_training_start(self, env_or_buffer): - raise NotImplementedError('You do not need training to use this algorithm.') + raise NotImplementedError("You do not need training to use this algorithm.") def _run_online_training_iteration(self, env): - raise NotImplementedError('You do not need training to use this algorithm.') + raise NotImplementedError("You do not need training to use this algorithm.") def _run_offline_training_iteration(self, buffer): - raise NotImplementedError('You do not need training to use this algorithm.') + raise NotImplementedError("You do not need training to use this algorithm.") def _after_training_finish(self, env_or_buffer): - raise NotImplementedError('You do not need training to use this algorithm.') + raise NotImplementedError("You do not need training to use this algorithm.") def _models(self): return {} @@ -157,8 +156,9 @@ def _solvers(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env() @property diff --git a/nnabla_rl/algorithms/mme_sac.py b/nnabla_rl/algorithms/mme_sac.py index 3e3c780a..db352e8d 100644 --- a/nnabla_rl/algorithms/mme_sac.py +++ b/nnabla_rl/algorithms/mme_sac.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
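The LQR hunks above are likewise formatting-only, but the `_optimize` and `_compute_optimal_input` bodies they touch are compact enough that the recursion is easy to miss. Below is a minimal NumPy sketch of that backward Riccati pass, using the same update visible in the diff: C = inv(R + B^T S B), D = F^T + B^T S A, S <- Q + A^T S A - D^T C D, and u = -C D x. The `lqr_backward_pass` helper, the double-integrator dynamics, and the treatment of `F` as a state-input cross-cost term are illustrative assumptions, not part of the library.

```python
import numpy as np


def lqr_backward_pass(A, B, Q, R, F, S_terminal, T):
    """Backward Riccati recursion; returns gains K_t such that u_t = -K_t @ x_t."""
    S = S_terminal
    gains = []
    for _ in range(T):
        C = np.linalg.inv(R + B.T @ S @ B)
        D = F.T + B.T @ S @ A
        gains.append(C @ D)                    # optimal input is u = -(C @ D) @ x
        S = Q + A.T @ S @ A - D.T @ C @ D      # value-function (cost-to-go) update
    return list(reversed(gains))               # gains ordered from t = 0 to T - 1


# Toy double-integrator example (assumed dynamics and costs).
dt = 0.1
A = np.array([[1.0, dt], [0.0, 1.0]])
B = np.array([[0.0], [dt]])
Q = np.eye(2)
R = np.array([[0.1]])
F = np.zeros((2, 1))                           # no state-input cross cost in this toy case
gains = lqr_backward_pass(A, B, Q, R, F, S_terminal=Q, T=50)
x0 = np.array([1.0, 0.0])
u0 = -gains[0] @ x0                            # first control input for the initial state
```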
@@ -19,9 +19,14 @@ import nnabla_rl.model_trainers as MT from nnabla_rl.algorithms import ICML2018SAC, ICML2018SACConfig -from nnabla_rl.algorithms.icml2018_sac import (DefaultExplorerBuilder, DefaultPolicyBuilder, DefaultQFunctionBuilder, - DefaultReplayBufferBuilder, DefaultSolverBuilder, - DefaultVFunctionBuilder) +from nnabla_rl.algorithms.icml2018_sac import ( + DefaultExplorerBuilder, + DefaultPolicyBuilder, + DefaultQFunctionBuilder, + DefaultReplayBufferBuilder, + DefaultSolverBuilder, + DefaultVFunctionBuilder, +) from nnabla_rl.builders import ExplorerBuilder, ModelBuilder, ReplayBufferBuilder, SolverBuilder from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.models import QFunction, StochasticPolicy, VFunction @@ -37,6 +42,7 @@ class MMESACConfig(ICML2018SACConfig): Otherwise 1/alpha_pi will be used to scale the reward. Defaults to None. alpha_q (float): Temperature value for negative entropy term. Defaults to 1.0. """ + # override configurations reward_scalar: float = 5.0 alpha_pi: Optional[float] = None @@ -89,37 +95,43 @@ class MMESAC(ICML2018SAC): # See https://mypy.readthedocs.io/en/stable/class_basics.html for details _config: MMESACConfig - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: MMESACConfig = MMESACConfig(), - v_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), - v_solver_builder: SolverBuilder = DefaultSolverBuilder(), - q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), - q_solver_builder: SolverBuilder = DefaultSolverBuilder(), - policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): - super(MMESAC, self).__init__(env_or_env_info, - config=config, - v_function_builder=v_function_builder, - v_solver_builder=v_solver_builder, - q_function_builder=q_function_builder, - q_solver_builder=q_solver_builder, - policy_builder=policy_builder, - policy_solver_builder=policy_solver_builder, - replay_buffer_builder=replay_buffer_builder, - explorer_builder=explorer_builder) + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: MMESACConfig = MMESACConfig(), + v_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), + v_solver_builder: SolverBuilder = DefaultSolverBuilder(), + q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), + q_solver_builder: SolverBuilder = DefaultSolverBuilder(), + policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): + super(MMESAC, self).__init__( + env_or_env_info, + config=config, + v_function_builder=v_function_builder, + v_solver_builder=v_solver_builder, + q_function_builder=q_function_builder, + q_solver_builder=q_solver_builder, + policy_builder=policy_builder, + policy_solver_builder=policy_solver_builder, + replay_buffer_builder=replay_buffer_builder, + explorer_builder=explorer_builder, + ) def _setup_v_function_training(self, env_or_buffer): alpha_q = MT.policy_trainers.soft_policy_trainer.AdjustableTemperature( - scope_name='alpha_q', - initial_value=self._config.alpha_q) + scope_name="alpha_q", 
initial_value=self._config.alpha_q + ) v_function_trainer_config = MT.v_value_trainers.MMEVTrainerConfig( - reduction_method='mean', + reduction_method="mean", v_loss_scalar=0.5, unroll_steps=self._config.v_unroll_steps, burn_in_steps=self._config.v_burn_in_steps, - reset_on_terminal=self._config.v_reset_rnn_on_terminal) + reset_on_terminal=self._config.v_reset_rnn_on_terminal, + ) v_function_trainer = MT.v_value_trainers.MMEVTrainer( train_functions=self._v, temperature=alpha_q, @@ -127,7 +139,8 @@ def _setup_v_function_training(self, env_or_buffer): target_functions=self._train_q_functions, # Set training q as target target_policy=self._pi, env_info=self._env_info, - config=v_function_trainer_config) + config=v_function_trainer_config, + ) sync_model(self._v, self._target_v, 1.0) return v_function_trainer diff --git a/nnabla_rl/algorithms/mppi.py b/nnabla_rl/algorithms/mppi.py index 398cdff0..01c6c73d 100644 --- a/nnabla_rl/algorithms/mppi.py +++ b/nnabla_rl/algorithms/mppi.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -69,7 +69,8 @@ class MPPIConfig(AlgorithmConfig): dt (float): Time interval between states. Defaults to 0.05 [s]. We strongly recommended to adjust this interval considering the sensor frequency. """ - learning_rate: float = 1.0*1e-3 + + learning_rate: float = 1.0 * 1e-3 batch_size: int = 100 replay_buffer_size: int = 1000000 training_iterations: int = 500 @@ -87,36 +88,36 @@ class MPPIConfig(AlgorithmConfig): def __post_init__(self): super().__post_init__() - self._assert_positive(self.lmb, 'lmb') - self._assert_positive(self.M, 'M') - self._assert_positive(self.K, 'K') - self._assert_positive(self.T, 'T') - self._assert_positive(self.dt, 'dt') + self._assert_positive(self.lmb, "lmb") + self._assert_positive(self.M, "M") + self._assert_positive(self.K, "K") + self._assert_positive(self.T, "T") + self._assert_positive(self.dt, "dt") class DefaultDynamicsBuilder(ModelBuilder[DeterministicDynamics]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: MPPIConfig, - **kwargs) -> DeterministicDynamics: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: MPPIConfig, + **kwargs, + ) -> DeterministicDynamics: return MPPIDeterministicDynamics(scope_name, dt=algorithm_config.dt) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: MPPIConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: MPPIConfig, **kwargs + ) -> nn.solver.Solver: # return NS.RMSprop(lr=algorithm_config.learning_rate) return NS.Adam(alpha=algorithm_config.learning_rate) class DefaultReplayBufferBuilder(ReplayBufferBuilder): - def build_replay_buffer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: MPPIConfig, - **kwargs) -> ReplayBuffer: + def build_replay_buffer( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: MPPIConfig, **kwargs + ) -> ReplayBuffer: return ReplayBuffer(capacity=algorithm_config.replay_buffer_size) @@ -161,22 +162,25 @@ class MPPI(Algorithm): builder of replay_buffer. 
If you have bootstrap data, override the default builder and return a replay buffer with bootstrap data. """ + _config: MPPIConfig _evaluation_dynamics: _DeterministicStatePredictor - def __init__(self, - env_or_env_info, - cost_function: CostFunction, - known_dynamics: Optional[Dynamics] = None, - state_normalizer: Optional[Callable[[np.ndarray], np.ndarray]] = None, - config: MPPIConfig = MPPIConfig(), - dynamics_builder: ModelBuilder[DeterministicDynamics] = DefaultDynamicsBuilder(), - dynamics_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder()): + def __init__( + self, + env_or_env_info, + cost_function: CostFunction, + known_dynamics: Optional[Dynamics] = None, + state_normalizer: Optional[Callable[[np.ndarray], np.ndarray]] = None, + config: MPPIConfig = MPPIConfig(), + dynamics_builder: ModelBuilder[DeterministicDynamics] = DefaultDynamicsBuilder(), + dynamics_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + ): super(MPPI, self).__init__(env_or_env_info, config=config) with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): self._known_dynamics = known_dynamics - self._dynamics = dynamics_builder('dynamics', env_info=self._env_info, algorithm_config=self._config) + self._dynamics = dynamics_builder("dynamics", env_info=self._env_info, algorithm_config=self._config) self._dynamics_solver = dynamics_solver_builder(env_info=self._env_info, algorithm_config=self._config) self._replay_buffer = replay_buffer_builder(env_info=self._env_info, algorithm_config=self._config) @@ -196,9 +200,9 @@ def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): return control_inputs[0] @eval_api - def compute_trajectory(self, - initial_trajectory: Sequence[Tuple[np.ndarray, Optional[np.ndarray]]]) \ - -> Tuple[Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], Sequence[Dict[str, Any]]]: + def compute_trajectory( + self, initial_trajectory: Sequence[Tuple[np.ndarray, Optional[np.ndarray]]] + ) -> Tuple[Sequence[Tuple[np.ndarray, Optional[np.ndarray]]], Sequence[Dict[str, Any]]]: assert len(initial_trajectory) == self._config.T x, u = unzip(initial_trajectory) @@ -225,13 +229,15 @@ def _setup_dynamics_training(self, env_or_buffer): unroll_steps=self._config.unroll_steps, burn_in_steps=self._config.burn_in_steps, reset_on_terminal=self._config.reset_rnn_on_terminal, - dt=self._config.dt) + dt=self._config.dt, + ) dynamics_trainer = MT.dynamics_trainers.MPPIDynamicsTrainer( models=self._dynamics, solvers={self._dynamics.scope_name: self._dynamics_solver}, env_info=self._env_info, - config=dynamics_trainer_config) + config=dynamics_trainer_config, + ) return dynamics_trainer def _run_online_training_iteration(self, env): @@ -243,26 +249,28 @@ def _run_online_training_iteration(self, env): self._replay_buffer.append_all(experiences) # D U Dj def _run_offline_training_iteration(self, buffer): - raise NotImplementedError('You can not train MPPI only with buffer. Try online training.') + raise NotImplementedError("You can not train MPPI only with buffer. 
Try online training.") def _mppi_training(self, replay_buffer): # train the dynamics model num_steps = self._config.burn_in_steps + self._config.unroll_steps experiences_tuple, info = replay_buffer.sample(self._config.batch_size, num_steps=num_steps) if num_steps == 1: - experiences_tuple = (experiences_tuple, ) + experiences_tuple = (experiences_tuple,) assert len(experiences_tuple) == num_steps batch = None for experiences in reversed(experiences_tuple): (s, a, _, non_terminal, s_next, *_) = marshal_experiences(experiences) - batch = TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - s_next=s_next, - non_terminal=non_terminal, - weight=info['weights'], - next_step_batch=batch, - rnn_states={}) + batch = TrainingBatch( + batch_size=self._config.batch_size, + s_current=s, + a_current=a, + s_next=s_next, + non_terminal=non_terminal, + weight=info["weights"], + next_step_batch=batch, + rnn_states={}, + ) self._dynamics_trainer_state = self._dynamics_trainer.train(batch) @@ -291,7 +299,7 @@ def _compute_control_inputs(self, x, control_inputs): dummy_states = [] improved_inputs = control_inputs.copy() control_inputs = np.broadcast_to(control_inputs, shape=(self._config.K, *control_inputs.shape)) - mean = np.zeros(shape=(self._env_info.action_dim, )) + mean = np.zeros(shape=(self._env_info.action_dim,)) cov = np.eye(N=self._env_info.action_dim) if self._config.covariance is not None: assert cov.shape == self._config.covariance.shape @@ -323,8 +331,8 @@ def _compute_control_inputs(self, x, control_inputs): for k in range(self._config.K): S[k] += self._cost_function.evaluate(x[k], zero_control, self._config.T, final_state=True) beta = np.min(S) - eta = np.sum(np.exp(-(S - beta)/self._config.lmb)) - weights = np.exp(-(S - beta)/self._config.lmb) / eta + eta = np.sum(np.exp(-(S - beta) / self._config.lmb)) + weights = np.exp(-(S - beta) / self._config.lmb) / eta du = np.sum(weights[:, np.newaxis, :] * input_noise, axis=0) improved_inputs += du @@ -363,17 +371,19 @@ def _solvers(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env() @property def latest_iteration_state(self): latest_iteration_state = super(MPPI, self).latest_iteration_state - if hasattr(self, '_dynamics_trainer_state'): - print('latest iteration state') - latest_iteration_state['scalar'].update( - {'dynamics_loss': float(self._dynamics_trainer_state['dynamics_loss'])}) + if hasattr(self, "_dynamics_trainer_state"): + print("latest iteration state") + latest_iteration_state["scalar"].update( + {"dynamics_loss": float(self._dynamics_trainer_state["dynamics_loss"])} + ) return latest_iteration_state @classmethod diff --git a/nnabla_rl/algorithms/munchausen_dqn.py b/nnabla_rl/algorithms/munchausen_dqn.py index 5d3bc9f1..4c6fab1d 100644 --- a/nnabla_rl/algorithms/munchausen_dqn.py +++ b/nnabla_rl/algorithms/munchausen_dqn.py @@ -1,5 +1,5 @@ # Copyright 2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
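The MPPI hunks above again change layout only, but the weighting logic inside `_compute_control_inputs` is worth spelling out: per-rollout costs are shifted by their minimum, exponentiated with temperature `lmb`, normalized, and used to average the sampled input noise into the nominal control sequence. The sketch below reproduces that update in isolation; the `mppi_update` helper, the (K, T, action_dim) noise shape, and the scalar per-rollout cost are simplifying assumptions rather than the library's exact interfaces.

```python
import numpy as np


def mppi_update(nominal_inputs, input_noise, rollout_costs, lmb):
    """Softmin-weighted control update: low-cost rollouts pull the nominal inputs."""
    beta = np.min(rollout_costs)                        # best (lowest) rollout cost
    weights = np.exp(-(rollout_costs - beta) / lmb)     # shifted for numerical stability
    weights /= np.sum(weights)                          # the eta normalizer in the diff
    du = np.sum(weights[:, np.newaxis, np.newaxis] * input_noise, axis=0)
    return nominal_inputs + du


# Toy usage with assumed shapes: K rollouts over a horizon T with a 1-D action.
K, T, action_dim = 100, 20, 1
rng = np.random.default_rng(0)
nominal = np.zeros((T, action_dim))
noise = rng.normal(scale=0.5, size=(K, T, action_dim))
costs = rng.uniform(size=K)                             # stand-in for simulated rollout costs
improved = mppi_update(nominal, noise, costs, lmb=1.0)
```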
@@ -21,8 +21,13 @@ import nnabla as nn import nnabla.solvers as NS import nnabla_rl.model_trainers as MT -from nnabla_rl.algorithms.dqn import (DQN, DefaultExplorerBuilder, DefaultQFunctionBuilder, DefaultReplayBufferBuilder, - DQNConfig) +from nnabla_rl.algorithms.dqn import ( + DQN, + DefaultExplorerBuilder, + DefaultQFunctionBuilder, + DefaultReplayBufferBuilder, + DQNConfig, +) from nnabla_rl.builders import ExplorerBuilder, ModelBuilder, ReplayBufferBuilder, SolverBuilder from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.models import QFunction @@ -60,15 +65,14 @@ def __post_init__(self): Check set values are in valid range. """ super().__post_init__() - self._assert_positive(self.max_explore_steps, 'max_explore_steps') - self._assert_negative(self.clipping_value, 'clipping_value') + self._assert_positive(self.max_explore_steps, "max_explore_steps") + self._assert_negative(self.clipping_value, "clipping_value") class DefaultQSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: MunchausenDQNConfig, - **kwargs) -> nn.solvers.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: MunchausenDQNConfig, **kwargs + ) -> nn.solvers.Solver: assert isinstance(algorithm_config, MunchausenDQNConfig) return NS.Adam(algorithm_config.learning_rate, eps=1e-2 / algorithm_config.batch_size) @@ -101,23 +105,28 @@ class MunchausenDQN(DQN): # See https://mypy.readthedocs.io/en/stable/class_basics.html for details _config: MunchausenDQNConfig - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: MunchausenDQNConfig = MunchausenDQNConfig(), - q_func_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), - q_solver_builder: SolverBuilder = DefaultQSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): - super(MunchausenDQN, self).__init__(env_or_env_info=env_or_env_info, - config=config, - q_func_builder=q_func_builder, - q_solver_builder=q_solver_builder, - replay_buffer_builder=replay_buffer_builder, - explorer_builder=explorer_builder) + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: MunchausenDQNConfig = MunchausenDQNConfig(), + q_func_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), + q_solver_builder: SolverBuilder = DefaultQSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): + super(MunchausenDQN, self).__init__( + env_or_env_info=env_or_env_info, + config=config, + q_func_builder=q_func_builder, + q_solver_builder=q_solver_builder, + replay_buffer_builder=replay_buffer_builder, + explorer_builder=explorer_builder, + ) def _setup_q_function_training(self, env_or_buffer): trainer_config = MT.q_value_trainers.MunchausenDQNQTrainerConfig( num_steps=self._config.num_steps, - reduction_method='mean', + reduction_method="mean", q_loss_scalar=0.5, grad_clip=(-1.0, 1.0), tau=self._config.entropy_temperature, @@ -126,13 +135,15 @@ def _setup_q_function_training(self, env_or_buffer): clip_max=0.0, unroll_steps=self._config.unroll_steps, burn_in_steps=self._config.burn_in_steps, - reset_on_terminal=self._config.reset_rnn_on_terminal) + reset_on_terminal=self._config.reset_rnn_on_terminal, + ) q_function_trainer = MT.q_value_trainers.MunchausenDQNQTrainer( 
train_functions=self._q, solvers={self._q.scope_name: self._q_solver}, target_function=self._target_q, env_info=self._env_info, - config=trainer_config) + config=trainer_config, + ) sync_model(self._q, self._target_q) return q_function_trainer diff --git a/nnabla_rl/algorithms/munchausen_iqn.py b/nnabla_rl/algorithms/munchausen_iqn.py index caed501b..5bdb38a8 100644 --- a/nnabla_rl/algorithms/munchausen_iqn.py +++ b/nnabla_rl/algorithms/munchausen_iqn.py @@ -1,5 +1,5 @@ # Copyright 2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,8 +19,15 @@ import gym import nnabla_rl.model_trainers as MT -from nnabla_rl.algorithms.iqn import (IQN, DefaultExplorerBuilder, DefaultQuantileFunctionBuilder, - DefaultReplayBufferBuilder, DefaultSolverBuilder, IQNConfig, risk_neutral_measure) +from nnabla_rl.algorithms.iqn import ( + IQN, + DefaultExplorerBuilder, + DefaultQuantileFunctionBuilder, + DefaultReplayBufferBuilder, + DefaultSolverBuilder, + IQNConfig, + risk_neutral_measure, +) from nnabla_rl.builders import ExplorerBuilder, ModelBuilder, ReplayBufferBuilder, SolverBuilder from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.models import StateActionQuantileFunction @@ -48,8 +55,8 @@ def __post_init__(self): Check that set values are in valid range. """ super().__post_init__() - self._assert_positive(self.embedding_dim, 'embedding_dim') - self._assert_negative(self.clipping_value, 'clipping_value') + self._assert_positive(self.embedding_dim, "embedding_dim") + self._assert_negative(self.clipping_value, "clipping_value") class MunchausenIQN(IQN): @@ -81,21 +88,25 @@ class MunchausenIQN(IQN): # See https://mypy.readthedocs.io/en/stable/class_basics.html for details _config: MunchausenIQNConfig - def __init__(self, - env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: MunchausenIQNConfig = MunchausenIQNConfig(), - risk_measure_function=risk_neutral_measure, - quantile_function_builder: ModelBuilder[StateActionQuantileFunction] - = DefaultQuantileFunctionBuilder(), - quantile_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): - super(MunchausenIQN, self).__init__(env_or_env_info, config=config, - risk_measure_function=risk_measure_function, - quantile_function_builder=quantile_function_builder, - quantile_solver_builder=quantile_solver_builder, - replay_buffer_builder=replay_buffer_builder, - explorer_builder=explorer_builder) + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: MunchausenIQNConfig = MunchausenIQNConfig(), + risk_measure_function=risk_neutral_measure, + quantile_function_builder: ModelBuilder[StateActionQuantileFunction] = DefaultQuantileFunctionBuilder(), + quantile_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): + super(MunchausenIQN, self).__init__( + env_or_env_info, + config=config, + risk_measure_function=risk_measure_function, + quantile_function_builder=quantile_function_builder, + quantile_solver_builder=quantile_solver_builder, + replay_buffer_builder=replay_buffer_builder, + 
explorer_builder=explorer_builder, + ) def _setup_quantile_function_training(self, env_or_buffer): trainer_config = MT.q_value_trainers.MunchausenIQNQTrainerConfig( @@ -110,14 +121,16 @@ def _setup_quantile_function_training(self, env_or_buffer): clip_max=0.0, unroll_steps=self._config.unroll_steps, burn_in_steps=self._config.burn_in_steps, - reset_on_terminal=self._config.reset_rnn_on_terminal) + reset_on_terminal=self._config.reset_rnn_on_terminal, + ) quantile_function_trainer = MT.q_value_trainers.MunchausenIQNQTrainer( train_functions=self._quantile_function, solvers={self._quantile_function.scope_name: self._quantile_function_solver}, target_function=self._target_quantile_function, env_info=self._env_info, - config=trainer_config) + config=trainer_config, + ) # NOTE: Copy initial parameters after setting up the training # Because the parameter is created after training graph construction diff --git a/nnabla_rl/algorithms/ppo.py b/nnabla_rl/algorithms/ppo.py index 4dfc49f8..73e8cc8c 100644 --- a/nnabla_rl/algorithms/ppo.py +++ b/nnabla_rl/algorithms/ppo.py @@ -30,21 +30,39 @@ import nnabla_rl.preprocessors as RP import nnabla_rl.utils.context as context from nnabla_rl.algorithm import Algorithm, AlgorithmConfig, eval_api -from nnabla_rl.algorithms.common_utils import (_StatePreprocessedStochasticPolicy, _StatePreprocessedVFunction, - _StochasticPolicyActionSelector, compute_v_target_and_advantage) +from nnabla_rl.algorithms.common_utils import ( + _StatePreprocessedStochasticPolicy, + _StatePreprocessedVFunction, + _StochasticPolicyActionSelector, + compute_v_target_and_advantage, +) from nnabla_rl.builders import ModelBuilder, PreprocessorBuilder, SolverBuilder from nnabla_rl.environment_explorer import EnvironmentExplorer from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import ModelTrainer, TrainingBatch -from nnabla_rl.models import (Model, PPOAtariPolicy, PPOAtariVFunction, PPOMujocoPolicy, PPOMujocoVFunction, - PPOSharedFunctionHead, StochasticPolicy, VFunction) +from nnabla_rl.models import ( + Model, + PPOAtariPolicy, + PPOAtariVFunction, + PPOMujocoPolicy, + PPOMujocoVFunction, + PPOSharedFunctionHead, + StochasticPolicy, + VFunction, +) from nnabla_rl.preprocessors.preprocessor import Preprocessor from nnabla_rl.replay_buffer import ReplayBuffer from nnabla_rl.replay_buffers import BufferIterator from nnabla_rl.utils.data import add_batch_dimension, marshal_experiences, set_data_to_variable, unzip from nnabla_rl.utils.misc import create_variable -from nnabla_rl.utils.multiprocess import (copy_mp_arrays_to_params, copy_params_to_mp_arrays, mp_array_from_np_array, - mp_to_np_array, new_mp_arrays_from_params, np_to_mp_array) +from nnabla_rl.utils.multiprocess import ( + copy_mp_arrays_to_params, + copy_params_to_mp_arrays, + mp_array_from_np_array, + mp_to_np_array, + new_mp_arrays_from_params, + np_to_mp_array, +) from nnabla_rl.utils.reproductions import set_global_seed @@ -79,7 +97,7 @@ class PPOConfig(AlgorithmConfig): epsilon: float = 0.1 gamma: float = 0.99 - learning_rate: float = 2.5*1e-4 + learning_rate: float = 2.5 * 1e-4 lmb: float = 0.95 entropy_coefficient: float = 0.01 value_coefficient: float = 1.0 @@ -98,90 +116,81 @@ def __post_init__(self): Check the set values are in valid range. 
""" - self._assert_between(self.gamma, 0.0, 1.0, 'gamma') - self._assert_positive(self.actor_num, 'actor num') - self._assert_positive(self.epochs, 'epochs') - self._assert_positive(self.batch_size, 'batch_size') - self._assert_positive(self.actor_timesteps, 'actor_timesteps') - self._assert_positive(self.total_timesteps, 'total_timesteps') + self._assert_between(self.gamma, 0.0, 1.0, "gamma") + self._assert_positive(self.actor_num, "actor num") + self._assert_positive(self.epochs, "epochs") + self._assert_positive(self.batch_size, "batch_size") + self._assert_positive(self.actor_timesteps, "actor_timesteps") + self._assert_positive(self.total_timesteps, "total_timesteps") class DefaultPolicyBuilder(ModelBuilder[StochasticPolicy]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: PPOConfig, - **kwargs) -> StochasticPolicy: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: PPOConfig, + **kwargs, + ) -> StochasticPolicy: if env_info.is_discrete_action_env(): # scope name is same as that of v-function -> parameter is shared across models automatically return self._build_shared_policy("shared", env_info, algorithm_config) else: return self._build_mujoco_policy(scope_name, env_info, algorithm_config) - def _build_shared_policy(self, - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: PPOConfig, - **kwargs) -> StochasticPolicy: - _shared_function_head = PPOSharedFunctionHead(scope_name=scope_name, - state_shape=env_info.state_shape, - action_dim=env_info.action_dim) + def _build_shared_policy( + self, scope_name: str, env_info: EnvironmentInfo, algorithm_config: PPOConfig, **kwargs + ) -> StochasticPolicy: + _shared_function_head = PPOSharedFunctionHead( + scope_name=scope_name, state_shape=env_info.state_shape, action_dim=env_info.action_dim + ) return PPOAtariPolicy(scope_name=scope_name, action_dim=env_info.action_dim, head=_shared_function_head) - def _build_mujoco_policy(self, - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: PPOConfig, - **kwargs) -> StochasticPolicy: + def _build_mujoco_policy( + self, scope_name: str, env_info: EnvironmentInfo, algorithm_config: PPOConfig, **kwargs + ) -> StochasticPolicy: return PPOMujocoPolicy(scope_name=scope_name, action_dim=env_info.action_dim) class DefaultVFunctionBuilder(ModelBuilder[VFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: PPOConfig, - **kwargs) -> VFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: PPOConfig, + **kwargs, + ) -> VFunction: if env_info.is_discrete_action_env(): # scope name is same as that of policy -> parameter is shared across models automatically return self._build_shared_v_function("shared", env_info, algorithm_config) else: return self._build_mujoco_v_function(scope_name, env_info, algorithm_config) - def _build_shared_v_function(self, - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: PPOConfig, - **kwargs) -> VFunction: - _shared_function_head = PPOSharedFunctionHead(scope_name=scope_name, - state_shape=env_info.state_shape, - action_dim=env_info.action_dim) + def _build_shared_v_function( + self, scope_name: str, env_info: EnvironmentInfo, algorithm_config: PPOConfig, **kwargs + ) -> VFunction: + _shared_function_head = PPOSharedFunctionHead( + scope_name=scope_name, 
state_shape=env_info.state_shape, action_dim=env_info.action_dim + ) return PPOAtariVFunction(scope_name=scope_name, head=_shared_function_head) - def _build_mujoco_v_function(self, - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: PPOConfig, - **kwargs) -> VFunction: + def _build_mujoco_v_function( + self, scope_name: str, env_info: EnvironmentInfo, algorithm_config: PPOConfig, **kwargs + ) -> VFunction: return PPOMujocoVFunction(scope_name=scope_name) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, - env_info: EnvironmentInfo, - algorithm_config: AlgorithmConfig, - **kwargs) -> nn.solver.Solver: + def build_solver(self, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, **kwargs) -> nn.solver.Solver: assert isinstance(algorithm_config, PPOConfig) return NS.Adam(alpha=algorithm_config.learning_rate, eps=1e-5) class DefaultPreprocessorBuilder(PreprocessorBuilder): - def build_preprocessor(self, - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: AlgorithmConfig, - **kwargs) -> Preprocessor: - return RP.RunningMeanNormalizer('preprocessor', env_info.state_shape, value_clip=(-5.0, 5.0)) + def build_preprocessor( + self, scope_name: str, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, **kwargs + ) -> Preprocessor: + return RP.RunningMeanNormalizer("preprocessor", env_info.state_shape, value_clip=(-5.0, 5.0)) class PPO(Algorithm): @@ -225,29 +234,32 @@ class PPO(Algorithm): _policy_solver_builder: SolverBuilder _v_solver_builder: SolverBuilder - _actors: List['_PPOActor'] + _actors: List["_PPOActor"] _actor_processes: List[Union[mp.Process, th.Thread]] _policy_trainer_state: Dict[str, Any] _v_function_trainer_state: Dict[str, Any] - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: PPOConfig = PPOConfig(), - v_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), - v_solver_builder: SolverBuilder = DefaultSolverBuilder(), - policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), - state_preprocessor_builder: Optional[PreprocessorBuilder] = DefaultPreprocessorBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: PPOConfig = PPOConfig(), + v_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), + v_solver_builder: SolverBuilder = DefaultSolverBuilder(), + policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), + state_preprocessor_builder: Optional[PreprocessorBuilder] = DefaultPreprocessorBuilder(), + ): super(PPO, self).__init__(env_or_env_info, config=config) # Initialize on cpu and change the context later with nn.context_scope(context.get_nnabla_context(-1)): - self._v_function = v_function_builder('v', self._env_info, self._config) - self._policy = policy_builder('pi', self._env_info, self._config) + self._v_function = v_function_builder("v", self._env_info, self._config) + self._policy = policy_builder("pi", self._env_info, self._config) self._state_preprocessor = None if self._config.preprocess_state and state_preprocessor_builder is not None: - preprocessor = state_preprocessor_builder('preprocessor', self._env_info, self._config) + preprocessor = state_preprocessor_builder("preprocessor", self._env_info, self._config) assert preprocessor is not None self._v_function = _StatePreprocessedVFunction(v_function=self._v_function, 
preprocessor=preprocessor) self._policy = _StatePreprocessedStochasticPolicy(policy=self._policy, preprocessor=preprocessor) @@ -259,7 +271,8 @@ def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], self._v_solver_builder = v_solver_builder # keep for later use self._evaluation_actor = _StochasticPolicyActionSelector( - self._env_info, self._policy.shallowcopy(), deterministic=False) + self._env_info, self._policy.shallowcopy(), deterministic=False + ) @eval_api def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): @@ -269,7 +282,7 @@ def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): def _before_training_start(self, env_or_buffer): if not self._is_env(env_or_buffer): - raise ValueError('PPO only supports online training') + raise ValueError("PPO only supports online training") env = env_or_buffer # FIXME: This setup is a workaround for creating underlying model parameters @@ -295,27 +308,27 @@ def _before_training_start(self, env_or_buffer): def _setup_policy_training(self, env_or_buffer): policy_trainer_config = MT.policy_trainers.PPOPolicyTrainerConfig( - epsilon=self._config.epsilon, - entropy_coefficient=self._config.entropy_coefficient + epsilon=self._config.epsilon, entropy_coefficient=self._config.entropy_coefficient ) policy_trainer = MT.policy_trainers.PPOPolicyTrainer( models=self._policy, solvers={self._policy.scope_name: self._policy_solver}, env_info=self._env_info, - config=policy_trainer_config) + config=policy_trainer_config, + ) return policy_trainer def _setup_v_function_training(self, env_or_buffer): # training input/loss variables v_function_trainer_config = MT.v_value.MonteCarloVTrainerConfig( - reduction_method='mean', - v_loss_scalar=self._config.value_coefficient + reduction_method="mean", v_loss_scalar=self._config.value_coefficient ) v_function_trainer = MT.v_value.MonteCarloVTrainer( train_functions=self._v_function, solvers={self._v_function.scope_name: self._v_function_solver}, env_info=self._env_info, - config=v_function_trainer_config) + config=v_function_trainer_config, + ) return v_function_trainer def _after_training_finish(self, env_or_buffer): @@ -332,8 +345,7 @@ def normalize(values): if self.iteration_num % update_interval != 0: return - s, a, r, non_terminal, s_next, log_prob, v_targets, advantages = \ - self._collect_experiences(self._actors) + s, a, r, non_terminal, s_next, log_prob, v_targets, advantages = self._collect_experiences(self._actors) if self._config.preprocess_state: self._state_preprocessor.update(s) @@ -350,10 +362,9 @@ def normalize(values): buffer_iterator.reset() def _launch_actor_processes(self, env): - actors = self._build_ppo_actors(env, - v_function=self._v_function, - policy=self._policy, - state_preprocessor=self._state_preprocessor) + actors = self._build_ppo_actors( + env, v_function=self._v_function, policy=self._policy, state_preprocessor=self._state_preprocessor + ) processes = [] for actor in actors: if self._config.actor_num == 1: @@ -408,21 +419,18 @@ def concat_result(result): def _ppo_training(self, experiences): if self._config.decrease_alpha: - alpha = (1.0 - self.iteration_num / self._config.total_timesteps) + alpha = 1.0 - self.iteration_num / self._config.total_timesteps alpha = np.maximum(alpha, 0.0) else: alpha = 1.0 (s, a, _, _, _, log_prob, v_target, advantage) = marshal_experiences(experiences) extra = {} - extra['log_prob'] = log_prob - extra['advantage'] = advantage - extra['alpha'] = alpha - extra['v_target'] = v_target - batch = 
TrainingBatch(batch_size=len(experiences), - s_current=s, - a_current=a, - extra=extra) + extra["log_prob"] = log_prob + extra["advantage"] = advantage + extra["alpha"] = alpha + extra["v_target"] = v_target + batch = TrainingBatch(batch_size=len(experiences), s_current=s, a_current=a, extra=extra) self._policy_trainer.set_learning_rate(self._config.learning_rate * alpha) self._policy_trainer_state = self._policy_trainer.train(batch) @@ -448,8 +456,9 @@ def _solvers(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_tuple_action_env() def _build_ppo_actors(self, env, v_function, policy, state_preprocessor): @@ -462,17 +471,18 @@ def _build_ppo_actors(self, env, v_function, policy, state_preprocessor): v_function=v_function, policy=policy, state_preprocessor=state_preprocessor, - config=self._config) + config=self._config, + ) actors.append(actor) return actors @property def latest_iteration_state(self): latest_iteration_state = super(PPO, self).latest_iteration_state - if hasattr(self, '_policy_trainer_state'): - latest_iteration_state['scalar'].update({'pi_loss': float(self._policy_trainer_state['pi_loss'])}) - if hasattr(self, '_v_function_trainer_state'): - latest_iteration_state['scalar'].update({'v_loss': float(self._v_function_trainer_state['v_loss'])}) + if hasattr(self, "_policy_trainer_state"): + latest_iteration_state["scalar"].update({"pi_loss": float(self._policy_trainer_state["pi_loss"])}) + if hasattr(self, "_v_function_trainer_state"): + latest_iteration_state["scalar"].update({"v_loss": float(self._v_function_trainer_state["v_loss"])}) return latest_iteration_state @property @@ -511,7 +521,7 @@ def __init__(self, actor_num, env, env_info, v_function, policy, state_preproces self._config = config # IPC communication variables - self._disposed = mp.Value('i', False) + self._disposed = mp.Value("i", False) self._task_start_event = mp.Event() self._task_finish_event = mp.Event() @@ -521,35 +531,50 @@ def __init__(self, actor_num, env, env_info, v_function, policy, state_preproces self._state_preprocessor_mp_arrays = new_mp_arrays_from_params(state_preprocessor.get_parameters()) explorer_config = EE.RawPolicyExplorerConfig( - initial_step_num=0, - timelimit_as_terminal=self._config.timelimit_as_terminal + initial_step_num=0, timelimit_as_terminal=self._config.timelimit_as_terminal + ) + self._environment_explorer = EE.RawPolicyExplorer( + policy_action_selector=self._compute_action, env_info=self._env_info, config=explorer_config ) - self._environment_explorer = EE.RawPolicyExplorer(policy_action_selector=self._compute_action, - env_info=self._env_info, - config=explorer_config) obs_space = self._env.observation_space action_space = self._env.action_space - MultiProcessingArrays = namedtuple('MultiProcessingArrays', - ['state', 'action', 'reward', 'non_terminal', - 'next_state', 'log_prob', 'v_target', 'advantage']) + MultiProcessingArrays = namedtuple( + "MultiProcessingArrays", + ["state", "action", "reward", "non_terminal", "next_state", "log_prob", "v_target", "advantage"], + ) state_mp_array = self._prepare_state_mp_array(obs_space, env_info) action_mp_array = self._prepare_action_mp_array(action_space, env_info) scalar_mp_array_shape = (self._timesteps, 1) - reward_mp_array = 
(mp_array_from_np_array( - np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), scalar_mp_array_shape, np.float32) - non_terminal_mp_array = (mp_array_from_np_array( - np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), scalar_mp_array_shape, np.float32) + reward_mp_array = ( + mp_array_from_np_array(np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), + scalar_mp_array_shape, + np.float32, + ) + non_terminal_mp_array = ( + mp_array_from_np_array(np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), + scalar_mp_array_shape, + np.float32, + ) next_state_mp_array = self._prepare_state_mp_array(obs_space, env_info) - log_prob_mp_array = (mp_array_from_np_array( - np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), scalar_mp_array_shape, np.float32) - v_target_mp_array = (mp_array_from_np_array( - np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), scalar_mp_array_shape, np.float32) - advantage_mp_array = (mp_array_from_np_array( - np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), scalar_mp_array_shape, np.float32) + log_prob_mp_array = ( + mp_array_from_np_array(np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), + scalar_mp_array_shape, + np.float32, + ) + v_target_mp_array = ( + mp_array_from_np_array(np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), + scalar_mp_array_shape, + np.float32, + ) + advantage_mp_array = ( + mp_array_from_np_array(np.empty(shape=scalar_mp_array_shape, dtype=np.float32)), + scalar_mp_array_shape, + np.float32, + ) self._mp_arrays = MultiProcessingArrays( state_mp_array, @@ -559,7 +584,7 @@ def __init__(self, actor_num, env, env_info, v_function, policy, state_preproces next_state_mp_array, log_prob_mp_array, v_target_mp_array, - advantage_mp_array + advantage_mp_array, ) def __call__(self): @@ -580,6 +605,7 @@ def _mp_to_np_array(mp_array): return tuple(mp_to_np_array(*array) for array in mp_array) else: return mp_to_np_array(*mp_array) + self._task_finish_event.wait() return tuple(_mp_to_np_array(mp_array) for mp_array in self._mp_arrays) @@ -601,7 +627,7 @@ def _run_actor_loop(self): set_global_seed(seed) self._env.seed(seed) - while (True): + while True: self._task_start_event.wait() if self._disposed.get_obj(): break @@ -620,16 +646,18 @@ def _run_actor_loop(self): def _run_data_collection(self): experiences = self._environment_explorer.step(self._env, n=self._timesteps) - experiences = [(s, a, r, non_terminal, s_next, info['log_prob']) - for (s, a, r, non_terminal, s_next, info) in experiences] + experiences = [ + (s, a, r, non_terminal, s_next, info["log_prob"]) for (s, a, r, non_terminal, s_next, info) in experiences + ] v_targets, advantages = compute_v_target_and_advantage( - self._v_function, experiences, gamma=self._gamma, lmb=self._lambda) + self._v_function, experiences, gamma=self._gamma, lmb=self._lambda + ) return experiences, v_targets, advantages @eval_api def _compute_action(self, s, *, begin_of_episode=False): s = add_batch_dimension(s) - if not hasattr(self, '_eval_state_var'): + if not hasattr(self, "_eval_state_var"): self._eval_state_var = create_variable(1, self._env_info.state_shape) distribution = self._policy.pi(self._eval_state_var) self._eval_action, self._eval_log_prob = distribution.sample_and_compute_log_prob() @@ -638,7 +666,7 @@ def _compute_action(self, s, *, begin_of_episode=False): action = np.squeeze(self._eval_action.d, axis=0) log_prob = np.squeeze(self._eval_log_prob.d, axis=0) info = {} - info['log_prob'] = log_prob + info["log_prob"] = log_prob if 
self._env_info.is_discrete_action_env(): return np.int32(action), info else: @@ -677,27 +705,23 @@ def _prepare_state_mp_array(self, obs_space, env_info): state_mp_array_dtypes = [] for space in obs_space: state_mp_array_shape = (self._timesteps, *space.shape) - state_mp_array = mp_array_from_np_array( - np.empty(shape=state_mp_array_shape, dtype=space.dtype)) + state_mp_array = mp_array_from_np_array(np.empty(shape=state_mp_array_shape, dtype=space.dtype)) state_mp_array_shapes.append(state_mp_array_shape) state_mp_array_dtypes.append(space.dtype) state_mp_arrays.append(state_mp_array) return tuple(x for x in zip(state_mp_arrays, state_mp_array_shapes, state_mp_array_dtypes)) else: state_mp_array_shape = (self._timesteps, *obs_space.shape) - state_mp_array = mp_array_from_np_array( - np.empty(shape=state_mp_array_shape, dtype=obs_space.dtype)) + state_mp_array = mp_array_from_np_array(np.empty(shape=state_mp_array_shape, dtype=obs_space.dtype)) return (state_mp_array, state_mp_array_shape, obs_space.dtype) def _prepare_action_mp_array(self, action_space, env_info): if env_info.is_discrete_action_env(): action_mp_array_shape = (self._timesteps, 1) - action_mp_array = mp_array_from_np_array( - np.empty(shape=action_mp_array_shape, dtype=action_space.dtype)) + action_mp_array = mp_array_from_np_array(np.empty(shape=action_mp_array_shape, dtype=action_space.dtype)) else: action_mp_array_shape = (self._timesteps, action_space.shape[0]) - action_mp_array = mp_array_from_np_array( - np.empty(shape=action_mp_array_shape, dtype=action_space.dtype)) + action_mp_array = mp_array_from_np_array(np.empty(shape=action_mp_array_shape, dtype=action_space.dtype)) return (action_mp_array, action_mp_array_shape, action_space.dtype) diff --git a/nnabla_rl/algorithms/qrdqn.py b/nnabla_rl/algorithms/qrdqn.py index 73bf3290..b01868f5 100644 --- a/nnabla_rl/algorithms/qrdqn.py +++ b/nnabla_rl/algorithms/qrdqn.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -101,66 +101,69 @@ def __post_init__(self): Check that set values are in valid range. 
""" - self._assert_between(self.gamma, 0.0, 1.0, 'gamma') - self._assert_positive(self.batch_size, 'batch_size') - self._assert_positive(self.num_steps, 'num_steps') - self._assert_positive(self.replay_buffer_size, 'replay_buffer_size') - self._assert_positive(self.learner_update_frequency, 'learner_update_frequency') - self._assert_positive(self.target_update_frequency, 'target_update_frequency') - self._assert_positive(self.max_explore_steps, 'max_explore_steps') - self._assert_positive(self.learning_rate, 'learning_rate') - self._assert_positive(self.initial_epsilon, 'initial_epsilon') - self._assert_positive(self.final_epsilon, 'final_epsilon') - self._assert_positive(self.test_epsilon, 'test_epsilon') - self._assert_positive(self.num_quantiles, 'num_quantiles') - self._assert_positive(self.kappa, 'kappa') - self._assert_positive(self.unroll_steps, 'unroll_steps') - self._assert_positive_or_zero(self.burn_in_steps, 'burn_in_steps') + self._assert_between(self.gamma, 0.0, 1.0, "gamma") + self._assert_positive(self.batch_size, "batch_size") + self._assert_positive(self.num_steps, "num_steps") + self._assert_positive(self.replay_buffer_size, "replay_buffer_size") + self._assert_positive(self.learner_update_frequency, "learner_update_frequency") + self._assert_positive(self.target_update_frequency, "target_update_frequency") + self._assert_positive(self.max_explore_steps, "max_explore_steps") + self._assert_positive(self.learning_rate, "learning_rate") + self._assert_positive(self.initial_epsilon, "initial_epsilon") + self._assert_positive(self.final_epsilon, "final_epsilon") + self._assert_positive(self.test_epsilon, "test_epsilon") + self._assert_positive(self.num_quantiles, "num_quantiles") + self._assert_positive(self.kappa, "kappa") + self._assert_positive(self.unroll_steps, "unroll_steps") + self._assert_positive_or_zero(self.burn_in_steps, "burn_in_steps") class DefaultQuantileBuilder(ModelBuilder[QuantileDistributionFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: QRDQNConfig, - **kwargs) -> QuantileDistributionFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: QRDQNConfig, + **kwargs, + ) -> QuantileDistributionFunction: return QRDQNQuantileDistributionFunction(scope_name, env_info.action_dim, algorithm_config.num_quantiles) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: QRDQNConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: QRDQNConfig, **kwargs + ) -> nn.solver.Solver: return NS.Adam(alpha=algorithm_config.learning_rate, eps=1e-2 / algorithm_config.batch_size) class DefaultReplayBufferBuilder(ReplayBufferBuilder): - def build_replay_buffer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: QRDQNConfig, - **kwargs) -> ReplayBuffer: + def build_replay_buffer( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: QRDQNConfig, **kwargs + ) -> ReplayBuffer: return ReplayBuffer(capacity=algorithm_config.replay_buffer_size) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: QRDQNConfig, - algorithm: "QRDQN", - **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + 
env_info: EnvironmentInfo, + algorithm_config: QRDQNConfig, + algorithm: "QRDQN", + **kwargs, + ) -> EnvironmentExplorer: explorer_config = EE.LinearDecayEpsilonGreedyExplorerConfig( warmup_random_steps=algorithm_config.start_timesteps, initial_step_num=algorithm.iteration_num, initial_epsilon=algorithm_config.initial_epsilon, final_epsilon=algorithm_config.final_epsilon, - max_explore_steps=algorithm_config.max_explore_steps + max_explore_steps=algorithm_config.max_explore_steps, ) explorer = EE.LinearDecayEpsilonGreedyExplorer( greedy_action_selector=algorithm._exploration_action_selector, random_action_selector=algorithm._random_action_selector, env_info=env_info, - config=explorer_config) + config=explorer_config, + ) return explorer @@ -203,36 +206,43 @@ class QRDQN(Algorithm): _quantile_dist_trainer_state: Dict[str, Any] - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: QRDQNConfig = QRDQNConfig(), - quantile_dist_function_builder: ModelBuilder[QuantileDistributionFunction] = DefaultQuantileBuilder(), - quantile_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: QRDQNConfig = QRDQNConfig(), + quantile_dist_function_builder: ModelBuilder[QuantileDistributionFunction] = DefaultQuantileBuilder(), + quantile_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): super(QRDQN, self).__init__(env_or_env_info, config=config) self._explorer_builder = explorer_builder with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): - self._quantile_dist = quantile_dist_function_builder('quantile_dist_train', self._env_info, self._config) + self._quantile_dist = quantile_dist_function_builder("quantile_dist_train", self._env_info, self._config) self._quantile_dist_solver = quantile_solver_builder(self._env_info, self._config) - self._target_quantile_dist = self._quantile_dist.deepcopy('quantile_dist_target') + self._target_quantile_dist = self._quantile_dist.deepcopy("quantile_dist_target") self._replay_buffer = replay_buffer_builder(self._env_info, self._config) self._evaluation_actor = _GreedyActionSelector( - self._env_info, self._quantile_dist.shallowcopy().as_q_function()) + self._env_info, self._quantile_dist.shallowcopy().as_q_function() + ) self._exploration_actor = _GreedyActionSelector( - self._env_info, self._quantile_dist.shallowcopy().as_q_function()) + self._env_info, self._quantile_dist.shallowcopy().as_q_function() + ) @eval_api def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): - (action, _), _ = epsilon_greedy_action_selection(state, - self._evaluation_action_selector, - self._random_action_selector, - epsilon=self._config.test_epsilon, - begin_of_episode=begin_of_episode) + (action, _), _ = epsilon_greedy_action_selection( + state, + self._evaluation_action_selector, + self._random_action_selector, + epsilon=self._config.test_epsilon, + begin_of_episode=begin_of_episode, + ) return action def _before_training_start(self, env_or_buffer): @@ -251,14 +261,16 @@ def _setup_quantile_function_training(self, env_or_buffer): kappa=self._config.kappa, 
unroll_steps=self._config.unroll_steps, burn_in_steps=self._config.burn_in_steps, - reset_on_terminal=self._config.reset_rnn_on_terminal) + reset_on_terminal=self._config.reset_rnn_on_terminal, + ) quantile_dist_trainer = MT.q_value_trainers.QRDQNQTrainer( train_functions=self._quantile_dist, solvers={self._quantile_dist.scope_name: self._quantile_dist_solver}, target_function=self._target_quantile_dist, env_info=self._env_info, - config=trainer_config) + config=trainer_config, + ) # NOTE: Copy initial parameters after setting up the training # Because the parameter is created after training graph construction @@ -279,23 +291,25 @@ def _qrdqn_training(self, replay_buffer): num_steps = self._config.num_steps + self._config.burn_in_steps + self._config.unroll_steps - 1 experiences_tuple, info = replay_buffer.sample(self._config.batch_size, num_steps=num_steps) if num_steps == 1: - experiences_tuple = (experiences_tuple, ) + experiences_tuple = (experiences_tuple,) assert len(experiences_tuple) == num_steps batch = None for experiences in reversed(experiences_tuple): (s, a, r, non_terminal, s_next, rnn_states_dict, *_) = marshal_experiences(experiences) - rnn_states = rnn_states_dict['rnn_states'] if 'rnn_states' in rnn_states_dict else {} - batch = TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - gamma=self._config.gamma, - reward=r, - non_terminal=non_terminal, - s_next=s_next, - weight=info['weights'], - next_step_batch=batch, - rnn_states=rnn_states) + rnn_states = rnn_states_dict["rnn_states"] if "rnn_states" in rnn_states_dict else {} + batch = TrainingBatch( + batch_size=self._config.batch_size, + s_current=s, + a_current=a, + gamma=self._config.gamma, + reward=r, + non_terminal=non_terminal, + s_next=s_next, + weight=info["weights"], + next_step_batch=batch, + rnn_states=rnn_states, + ) self._quantile_dist_trainer_state = self._quantile_dist_trainer.train(batch) if self.iteration_num % self._config.target_update_frequency: @@ -307,9 +321,9 @@ def _evaluation_action_selector(self, s, *, begin_of_episode=False): def _exploration_action_selector(self, s, *, begin_of_episode=False): return self._exploration_actor(s, begin_of_episode=begin_of_episode) - def _random_action_selector(self, s, *, begin_of_episode=False): + def _random_action_selector(self, s, *, begin_of_episode=False): action = self._env_info.action_space.sample() - return np.asarray(action).reshape((1, )), {} + return np.asarray(action).reshape((1,)), {} def _models(self): models = {} @@ -323,8 +337,9 @@ def _solvers(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_continuous_action_env() and not env_info.is_tuple_action_env() @classmethod @@ -334,8 +349,8 @@ def is_rnn_supported(self): @property def latest_iteration_state(self): latest_iteration_state = super(QRDQN, self).latest_iteration_state - if hasattr(self, '_quantile_dist_trainer_state'): - latest_iteration_state['scalar'].update({'q_loss': float(self._quantile_dist_trainer_state['q_loss'])}) + if hasattr(self, "_quantile_dist_trainer_state"): + latest_iteration_state["scalar"].update({"q_loss": float(self._quantile_dist_trainer_state["q_loss"])}) return latest_iteration_state @property diff --git a/nnabla_rl/algorithms/qrsac.py 
b/nnabla_rl/algorithms/qrsac.py index b6db1dcc..0ca40269 100644 --- a/nnabla_rl/algorithms/qrsac.py +++ b/nnabla_rl/algorithms/qrsac.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,8 +27,12 @@ from nnabla_rl.environment_explorer import EnvironmentExplorer from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import ModelTrainer, TrainingBatch -from nnabla_rl.models import (QRSACQuantileDistributionFunction, QuantileDistributionFunction, SACPolicy, - StochasticPolicy) +from nnabla_rl.models import ( + QRSACQuantileDistributionFunction, + QuantileDistributionFunction, + SACPolicy, + StochasticPolicy, +) from nnabla_rl.replay_buffer import ReplayBuffer from nnabla_rl.utils import context from nnabla_rl.utils.data import marshal_experiences @@ -80,7 +84,7 @@ class QRSACConfig(AlgorithmConfig): """ gamma: float = 0.99 - learning_rate: float = 3.0*1e-4 + learning_rate: float = 3.0 * 1e-4 batch_size: int = 256 tau: float = 0.005 environment_steps: int = 1 @@ -107,70 +111,74 @@ class QRSACConfig(AlgorithmConfig): def __post_init__(self): """__post_init__ Check set values are in valid range.""" - self._assert_between(self.tau, 0.0, 1.0, 'tau') - self._assert_between(self.gamma, 0.0, 1.0, 'gamma') - self._assert_positive(self.gradient_steps, 'gradient_steps') - self._assert_positive(self.environment_steps, 'environment_steps') + self._assert_between(self.tau, 0.0, 1.0, "tau") + self._assert_between(self.gamma, 0.0, 1.0, "gamma") + self._assert_positive(self.gradient_steps, "gradient_steps") + self._assert_positive(self.environment_steps, "environment_steps") if self.initial_temperature is not None: - self._assert_positive(self.initial_temperature, 'initial_temperature') - self._assert_positive(self.start_timesteps, 'start_timesteps') + self._assert_positive(self.initial_temperature, "initial_temperature") + self._assert_positive(self.start_timesteps, "start_timesteps") - self._assert_positive(self.critic_unroll_steps, 'critic_unroll_steps') - self._assert_positive_or_zero(self.critic_burn_in_steps, 'critic_burn_in_steps') - self._assert_positive(self.actor_unroll_steps, 'actor_unroll_steps') - self._assert_positive_or_zero(self.actor_burn_in_steps, 'actor_burn_in_steps') - self._assert_positive(self.num_quantiles, 'num_quantiles') - self._assert_positive(self.kappa, 'kappa') + self._assert_positive(self.critic_unroll_steps, "critic_unroll_steps") + self._assert_positive_or_zero(self.critic_burn_in_steps, "critic_burn_in_steps") + self._assert_positive(self.actor_unroll_steps, "actor_unroll_steps") + self._assert_positive_or_zero(self.actor_burn_in_steps, "actor_burn_in_steps") + self._assert_positive(self.num_quantiles, "num_quantiles") + self._assert_positive(self.kappa, "kappa") class DefaultQuantileFunctionBuilder(ModelBuilder[QuantileDistributionFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: QRSACConfig, - **kwargs) -> QuantileDistributionFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: QRSACConfig, + **kwargs, + ) -> QuantileDistributionFunction: return QRSACQuantileDistributionFunction(scope_name, n_quantile=algorithm_config.num_quantiles) class 
DefaultPolicyBuilder(ModelBuilder[StochasticPolicy]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: QRSACConfig, - **kwargs) -> StochasticPolicy: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: QRSACConfig, + **kwargs, + ) -> StochasticPolicy: return SACPolicy(scope_name, env_info.action_dim) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: QRSACConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: QRSACConfig, **kwargs + ) -> nn.solver.Solver: return NS.Adam(alpha=algorithm_config.learning_rate) class DefaultReplayBufferBuilder(ReplayBufferBuilder): - def build_replay_buffer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: QRSACConfig, - **kwargs) -> ReplayBuffer: + def build_replay_buffer( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: QRSACConfig, **kwargs + ) -> ReplayBuffer: return ReplayBuffer(capacity=algorithm_config.replay_buffer_size) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: QRSACConfig, - algorithm: "QRSAC", - **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: QRSACConfig, + algorithm: "QRSAC", + **kwargs, + ) -> EnvironmentExplorer: explorer_config = EE.RawPolicyExplorerConfig( warmup_random_steps=algorithm_config.start_timesteps, initial_step_num=algorithm.iteration_num, - timelimit_as_terminal=False + timelimit_as_terminal=False, + ) + explorer = EE.RawPolicyExplorer( + policy_action_selector=algorithm._exploration_action_selector, env_info=env_info, config=explorer_config ) - explorer = EE.RawPolicyExplorer(policy_action_selector=algorithm._exploration_action_selector, - env_info=env_info, - config=explorer_config) return explorer @@ -226,36 +234,41 @@ class QRSAC(Algorithm): _policy_trainer_state: Dict[str, Any] _quantile_function_trainer_state: Dict[str, Any] - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: QRSACConfig = QRSACConfig(), - quantile_function_builder: ModelBuilder[QuantileDistributionFunction] - = DefaultQuantileFunctionBuilder(), - quantile_solver_builder: SolverBuilder = DefaultSolverBuilder(), - policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), - temperature_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: QRSACConfig = QRSACConfig(), + quantile_function_builder: ModelBuilder[QuantileDistributionFunction] = DefaultQuantileFunctionBuilder(), + quantile_solver_builder: SolverBuilder = DefaultSolverBuilder(), + policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), + temperature_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = 
DefaultExplorerBuilder(), + ): super(QRSAC, self).__init__(env_or_env_info, config=config) self._explorer_builder = explorer_builder with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): self._q1 = quantile_function_builder( - scope_name="q1", env_info=self._env_info, algorithm_config=self._config) + scope_name="q1", env_info=self._env_info, algorithm_config=self._config + ) self._q2 = quantile_function_builder( - scope_name="q2", env_info=self._env_info, algorithm_config=self._config) + scope_name="q2", env_info=self._env_info, algorithm_config=self._config + ) self._train_q_functions = [self._q1, self._q2] - self._train_q_solvers = {q.scope_name: quantile_solver_builder(self._env_info, self._config) - for q in self._train_q_functions} - self._target_q_functions = [q.deepcopy('target_' + q.scope_name) for q in self._train_q_functions] + self._train_q_solvers = { + q.scope_name: quantile_solver_builder(self._env_info, self._config) for q in self._train_q_functions + } + self._target_q_functions = [q.deepcopy("target_" + q.scope_name) for q in self._train_q_functions] self._pi = policy_builder(scope_name="pi", env_info=self._env_info, algorithm_config=self._config) self._pi_solver = policy_solver_builder(self._env_info, self._config) self._temperature = MT.policy_trainers.soft_policy_trainer.AdjustableTemperature( - scope_name='temperature', - initial_value=self._config.initial_temperature) + scope_name="temperature", initial_value=self._config.initial_temperature + ) if not self._config.fix_temperature: self._temperature_solver = temperature_solver_builder(self._env_info, self._config) else: @@ -264,9 +277,11 @@ def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], self._replay_buffer = replay_buffer_builder(self._env_info, self._config) self._evaluation_actor = _StochasticPolicyActionSelector( - self._env_info, self._pi.shallowcopy(), deterministic=True) + self._env_info, self._pi.shallowcopy(), deterministic=True + ) self._exploration_actor = _StochasticPolicyActionSelector( - self._env_info, self._pi.shallowcopy(), deterministic=False) + self._env_info, self._pi.shallowcopy(), deterministic=False + ) @eval_api def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): @@ -290,7 +305,8 @@ def _setup_policy_training(self, env_or_buffer): target_entropy=self._config.target_entropy, unroll_steps=self._config.actor_unroll_steps, burn_in_steps=self._config.actor_burn_in_steps, - reset_on_terminal=self._config.actor_reset_rnn_on_terminal) + reset_on_terminal=self._config.actor_reset_rnn_on_terminal, + ) policy_trainer = MT.policy_trainers.SoftPolicyTrainer( models=self._pi, solvers={self._pi.scope_name: self._pi_solver}, @@ -298,7 +314,8 @@ def _setup_policy_training(self, env_or_buffer): temperature_solver=self._temperature_solver, q_functions=[self._q1.as_q_function(), self._q2.as_q_function()], env_info=self._env_info, - config=policy_trainer_config) + config=policy_trainer_config, + ) return policy_trainer def _setup_quantile_function_training(self, env_or_buffer): @@ -309,7 +326,8 @@ def _setup_quantile_function_training(self, env_or_buffer): num_steps=self._config.num_steps, unroll_steps=self._config.critic_unroll_steps, burn_in_steps=self._config.critic_burn_in_steps, - reset_on_terminal=self._config.critic_reset_rnn_on_terminal) + reset_on_terminal=self._config.critic_reset_rnn_on_terminal, + ) quantile_function_trainer = MT.q_value_trainers.QRSACQTrainer( train_functions=self._train_q_functions, @@ -318,7 +336,8 @@ def 
_setup_quantile_function_training(self, env_or_buffer): target_policy=self._pi, temperature=self._policy_trainer.get_temperature(), env_info=self._env_info, - config=quantile_function_trainer_config) + config=quantile_function_trainer_config, + ) for q, target_q in zip(self._train_q_functions, self._target_q_functions): sync_model(q, target_q) return quantile_function_trainer @@ -346,23 +365,25 @@ def _qrsac_training(self, replay_buffer): num_steps = max(actor_steps, critic_steps) experiences_tuple, info = replay_buffer.sample(self._config.batch_size, num_steps=num_steps) if num_steps == 1: - experiences_tuple = (experiences_tuple, ) + experiences_tuple = (experiences_tuple,) assert len(experiences_tuple) == num_steps batch = None for experiences in reversed(experiences_tuple): (s, a, r, non_terminal, s_next, rnn_states_dict, *_) = marshal_experiences(experiences) - rnn_states = rnn_states_dict['rnn_states'] if 'rnn_states' in rnn_states_dict else {} - batch = TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - gamma=self._config.gamma, - reward=r, - non_terminal=non_terminal, - s_next=s_next, - weight=info['weights'], - next_step_batch=batch, - rnn_states=rnn_states) + rnn_states = rnn_states_dict["rnn_states"] if "rnn_states" in rnn_states_dict else {} + batch = TrainingBatch( + batch_size=self._config.batch_size, + s_current=s, + a_current=a, + gamma=self._config.gamma, + reward=r, + non_terminal=non_terminal, + s_next=s_next, + weight=info["weights"], + next_step_batch=batch, + rnn_states=rnn_states, + ) self._quantile_function_trainer_state = self._quantile_function_trainer.train(batch) for q, target_q in zip(self._train_q_functions, self._target_q_functions): @@ -393,17 +414,18 @@ def is_rnn_supported(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env() @property def latest_iteration_state(self): latest_iteration_state = super().latest_iteration_state - if hasattr(self, '_policy_trainer_state'): - latest_iteration_state['scalar'].update({'pi_loss': float(self._policy_trainer_state['pi_loss'])}) - if hasattr(self, '_quantile_function_trainer_state'): - latest_iteration_state['scalar'].update({'q_loss': float(self._quantile_function_trainer_state['q_loss'])}) + if hasattr(self, "_policy_trainer_state"): + latest_iteration_state["scalar"].update({"pi_loss": float(self._policy_trainer_state["pi_loss"])}) + if hasattr(self, "_quantile_function_trainer_state"): + latest_iteration_state["scalar"].update({"q_loss": float(self._quantile_function_trainer_state["q_loss"])}) return latest_iteration_state @property diff --git a/nnabla_rl/algorithms/rainbow.py b/nnabla_rl/algorithms/rainbow.py index d002b84d..c0857a3b 100644 --- a/nnabla_rl/algorithms/rainbow.py +++ b/nnabla_rl/algorithms/rainbow.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -61,6 +61,7 @@ class RainbowConfig(CategoricalDDQNConfig): categorical q value update. :math:`r + \\gamma\\max_{a}{Q_{\\text{target}}(s_{t+1}, a)}`. Defaults to False. 
""" + learning_rate: float = 0.00025 / 4 start_timesteps: int = 20000 # 20k steps = 80k frames in Atari game target_update_frequency: int = 8000 # 8k steps = 32k frames in Atari game @@ -76,57 +77,58 @@ class RainbowConfig(CategoricalDDQNConfig): class DefaultValueDistFunctionBuilder(ModelBuilder[ValueDistributionFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: RainbowConfig, - **kwargs) -> ValueDistributionFunction: - return RainbowValueDistributionFunction(scope_name, - env_info.action_dim, - algorithm_config.num_atoms, - algorithm_config.v_min, - algorithm_config.v_max) + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: RainbowConfig, + **kwargs, + ) -> ValueDistributionFunction: + return RainbowValueDistributionFunction( + scope_name, env_info.action_dim, algorithm_config.num_atoms, algorithm_config.v_min, algorithm_config.v_max + ) class DefaultReplayBufferBuilder(ReplayBufferBuilder): - def build_replay_buffer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: RainbowConfig, - **kwargs) -> ReplayBuffer: - return ProportionalPrioritizedReplayBuffer(capacity=algorithm_config.replay_buffer_size, - alpha=algorithm_config.alpha, - beta=algorithm_config.beta, - betasteps=algorithm_config.betasteps, - error_clip=(-100, 100), - normalization_method="batch_max") + def build_replay_buffer( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: RainbowConfig, **kwargs + ) -> ReplayBuffer: + return ProportionalPrioritizedReplayBuffer( + capacity=algorithm_config.replay_buffer_size, + alpha=algorithm_config.alpha, + beta=algorithm_config.beta, + betasteps=algorithm_config.betasteps, + error_clip=(-100, 100), + normalization_method="batch_max", + ) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: RainbowConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: RainbowConfig, **kwargs + ) -> nn.solver.Solver: return NS.Adam(alpha=algorithm_config.learning_rate, eps=1.5e-4) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: RainbowConfig, - algorithm: "Rainbow", - **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: RainbowConfig, + algorithm: "Rainbow", + **kwargs, + ) -> EnvironmentExplorer: explorer_config = EE.RawPolicyExplorerConfig( - warmup_random_steps=algorithm_config.warmup_random_steps, - initial_step_num=algorithm.iteration_num + warmup_random_steps=algorithm_config.warmup_random_steps, initial_step_num=algorithm.iteration_num + ) + explorer = EE.RawPolicyExplorer( + policy_action_selector=algorithm._exploration_action_selector, env_info=env_info, config=explorer_config ) - explorer = EE.RawPolicyExplorer(policy_action_selector=algorithm._exploration_action_selector, - env_info=env_info, - config=explorer_config) return explorer class Rainbow(CategoricalDDQN): - '''Rainbow algorithm. + """Rainbow algorithm. This class implements the Rainbow algorithm proposed by M. Bellemare, et al. 
in the paper: "Rainbow: Combining Improvements in Deep Reinforcement Learning" For details see: https://arxiv.org/abs/1710.02298 @@ -145,21 +147,25 @@ class Rainbow(CategoricalDDQN): builder of replay_buffer explorer_builder (:py:class:`ExplorerBuilder `): builder of environment explorer - ''' - - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: RainbowConfig = RainbowConfig(), - value_distribution_builder: ModelBuilder[ValueDistributionFunction] - = DefaultValueDistFunctionBuilder(), - value_distribution_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): - super(Rainbow, self).__init__(env_or_env_info, - config=config, - value_distribution_builder=value_distribution_builder, - value_distribution_solver_builder=value_distribution_solver_builder, - replay_buffer_builder=replay_buffer_builder, - explorer_builder=explorer_builder) + """ + + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: RainbowConfig = RainbowConfig(), + value_distribution_builder: ModelBuilder[ValueDistributionFunction] = DefaultValueDistFunctionBuilder(), + value_distribution_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): + super(Rainbow, self).__init__( + env_or_env_info, + config=config, + value_distribution_builder=value_distribution_builder, + value_distribution_solver_builder=value_distribution_solver_builder, + replay_buffer_builder=replay_buffer_builder, + explorer_builder=explorer_builder, + ) def _setup_value_distribution_function_training(self, env_or_buffer): if self._config.no_double: diff --git a/nnabla_rl/algorithms/redq.py b/nnabla_rl/algorithms/redq.py index 989b23ca..c767635d 100644 --- a/nnabla_rl/algorithms/redq.py +++ b/nnabla_rl/algorithms/redq.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,8 +18,15 @@ import gym import nnabla_rl.model_trainers as MT -from nnabla_rl.algorithms.sac import (SAC, DefaultExplorerBuilder, DefaultPolicyBuilder, DefaultQFunctionBuilder, - DefaultReplayBufferBuilder, DefaultSolverBuilder, SACConfig) +from nnabla_rl.algorithms.sac import ( + SAC, + DefaultExplorerBuilder, + DefaultPolicyBuilder, + DefaultQFunctionBuilder, + DefaultReplayBufferBuilder, + DefaultSolverBuilder, + SACConfig, +) from nnabla_rl.builders import ExplorerBuilder, ModelBuilder, ReplayBufferBuilder, SolverBuilder from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingBatch @@ -37,6 +44,7 @@ class REDQConfig(SACConfig): M (int): Size of subset M. Defaults to 2. N (int): Number of q functions of an ensemble. Defaults to 10. 
""" + # override timesteps start_timesteps: int = 5000 @@ -48,9 +56,9 @@ class REDQConfig(SACConfig): def __post_init__(self): """__post_init__ Check set values are in valid range.""" super().__post_init__() - self._assert_positive(self.G, 'G') - self._assert_positive(self.N, 'N') - self._assert_positive(self.M, 'M') + self._assert_positive(self.G, "G") + self._assert_positive(self.N, "N") + self._assert_positive(self.M, "M") class REDQ(SAC): @@ -87,35 +95,40 @@ class REDQ(SAC): # See https://mypy.readthedocs.io/en/stable/class_basics.html for details _config: REDQConfig - def __init__(self, - env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: REDQConfig = REDQConfig(), - q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), - q_solver_builder: SolverBuilder = DefaultSolverBuilder(), - policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), - temperature_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): - super().__init__(env_or_env_info, - config=config, - q_function_builder=q_function_builder, - q_solver_builder=q_solver_builder, - policy_builder=policy_builder, - policy_solver_builder=policy_solver_builder, - temperature_solver_builder=temperature_solver_builder, - replay_buffer_builder=replay_buffer_builder, - explorer_builder=explorer_builder) + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: REDQConfig = REDQConfig(), + q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), + q_solver_builder: SolverBuilder = DefaultSolverBuilder(), + policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), + temperature_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): + super().__init__( + env_or_env_info, + config=config, + q_function_builder=q_function_builder, + q_solver_builder=q_solver_builder, + policy_builder=policy_builder, + policy_solver_builder=policy_solver_builder, + temperature_solver_builder=temperature_solver_builder, + replay_buffer_builder=replay_buffer_builder, + explorer_builder=explorer_builder, + ) def _setup_q_function_training(self, env_or_buffer): q_function_trainer_config = MT.q_value_trainers.REDQQTrainerConfig( - reduction_method='mean', + reduction_method="mean", grad_clip=None, M=self._config.M, num_steps=self._config.num_steps, unroll_steps=self._config.critic_unroll_steps, burn_in_steps=self._config.critic_burn_in_steps, - reset_on_terminal=self._config.critic_reset_rnn_on_terminal) + reset_on_terminal=self._config.critic_reset_rnn_on_terminal, + ) q_function_trainer = MT.q_value_trainers.REDQQTrainer( train_functions=self._train_q_functions, @@ -124,7 +137,8 @@ def _setup_q_function_training(self, env_or_buffer): target_policy=self._pi, temperature=self._policy_trainer.get_temperature(), env_info=self._env_info, - config=q_function_trainer_config) + config=q_function_trainer_config, + ) for q, target_q in zip(self._train_q_functions, self._target_q_functions): sync_model(q, target_q) return q_function_trainer @@ -154,28 +168,30 @@ def _redq_training(self, replay_buffer): for _ in range(self._config.G): experiences_tuple, info = 
replay_buffer.sample(self._config.batch_size, num_steps=num_steps) if num_steps == 1: - experiences_tuple = (experiences_tuple, ) + experiences_tuple = (experiences_tuple,) assert len(experiences_tuple) == num_steps batch = None for experiences in reversed(experiences_tuple): (s, a, r, non_terminal, s_next, rnn_states_dict, *_) = marshal_experiences(experiences) - rnn_states = rnn_states_dict['rnn_states'] if 'rnn_states' in rnn_states_dict else {} - batch = TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - gamma=self._config.gamma, - reward=r, - non_terminal=non_terminal, - s_next=s_next, - weight=info['weights'], - next_step_batch=batch, - rnn_states=rnn_states) + rnn_states = rnn_states_dict["rnn_states"] if "rnn_states" in rnn_states_dict else {} + batch = TrainingBatch( + batch_size=self._config.batch_size, + s_current=s, + a_current=a, + gamma=self._config.gamma, + reward=r, + non_terminal=non_terminal, + s_next=s_next, + weight=info["weights"], + next_step_batch=batch, + rnn_states=rnn_states, + ) self._q_function_trainer_state = self._q_function_trainer.train(batch) for q, target_q in zip(self._train_q_functions, self._target_q_functions): sync_model(q, target_q, tau=self._config.tau) - td_errors = self._q_function_trainer_state['td_errors'] + td_errors = self._q_function_trainer_state["td_errors"] replay_buffer.update_priorities(td_errors) self._policy_trainer_state = self._policy_trainer.train(batch) diff --git a/nnabla_rl/algorithms/reinforce.py b/nnabla_rl/algorithms/reinforce.py index ddde0d75..17dfc484 100644 --- a/nnabla_rl/algorithms/reinforce.py +++ b/nnabla_rl/algorithms/reinforce.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -51,6 +51,7 @@ class REINFORCEConfig(AlgorithmConfig): fixed_ln_var (float): Fixed log variance of the policy.\ This configuration is only valid when the enviroment is continuous. Defaults to 1.0. """ + reward_scale: float = 0.01 num_rollouts_per_train_iteration: int = 10 learning_rate: float = 1e-3 @@ -63,60 +64,59 @@ def __post_init__(self): Check the set values are in valid range. 
""" - self._assert_positive(self.reward_scale, 'reward_scale') - self._assert_positive(self.num_rollouts_per_train_iteration, 'num_rollouts_per_train_iteration') - self._assert_positive(self.learning_rate, 'learning_rate') - self._assert_positive(self.clip_grad_norm, 'clip_grad_norm') + self._assert_positive(self.reward_scale, "reward_scale") + self._assert_positive(self.num_rollouts_per_train_iteration, "num_rollouts_per_train_iteration") + self._assert_positive(self.learning_rate, "learning_rate") + self._assert_positive(self.clip_grad_norm, "clip_grad_norm") class DefaultPolicyBuilder(ModelBuilder[StochasticPolicy]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: REINFORCEConfig, - **kwargs) -> StochasticPolicy: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: REINFORCEConfig, + **kwargs, + ) -> StochasticPolicy: if env_info.is_discrete_action_env(): return self._build_discrete_policy(scope_name, env_info, algorithm_config) else: return self._build_continuous_policy(scope_name, env_info, algorithm_config) - def _build_continuous_policy(self, - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: REINFORCEConfig, - **kwargs) -> StochasticPolicy: + def _build_continuous_policy( + self, scope_name: str, env_info: EnvironmentInfo, algorithm_config: REINFORCEConfig, **kwargs + ) -> StochasticPolicy: return REINFORCEContinousPolicy(scope_name, env_info.action_dim, algorithm_config.fixed_ln_var) - def _build_discrete_policy(self, - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: REINFORCEConfig, - **kwargs) -> StochasticPolicy: + def _build_discrete_policy( + self, scope_name: str, env_info: EnvironmentInfo, algorithm_config: REINFORCEConfig, **kwargs + ) -> StochasticPolicy: return REINFORCEDiscretePolicy(scope_name, env_info.action_dim) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: REINFORCEConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: REINFORCEConfig, **kwargs + ) -> nn.solver.Solver: return NS.Adam(alpha=algorithm_config.learning_rate) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: REINFORCEConfig, - algorithm: "REINFORCE", - **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: REINFORCEConfig, + algorithm: "REINFORCE", + **kwargs, + ) -> EnvironmentExplorer: explorer_config = EE.RawPolicyExplorerConfig( reward_scalar=algorithm_config.reward_scale, initial_step_num=algorithm.iteration_num, - timelimit_as_terminal=False + timelimit_as_terminal=False, + ) + explorer = EE.RawPolicyExplorer( + policy_action_selector=algorithm._exploration_action_selector, env_info=env_info, config=explorer_config ) - explorer = EE.RawPolicyExplorer(policy_action_selector=algorithm._exploration_action_selector, - env_info=env_info, - config=explorer_config) return explorer @@ -142,6 +142,7 @@ class REINFORCE(Algorithm): explorer_builder (:py:class:`ExplorerBuilder `): builder of environment explorer """ + _config: REINFORCEConfig _policy: StochasticPolicy _policy_solver: nn.solver.Solver @@ -155,12 +156,14 @@ class REINFORCE(Algorithm): _policy_trainer_state: Dict[str, Any] 
- def __init__(self, - env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: REINFORCEConfig = REINFORCEConfig(), - policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: REINFORCEConfig = REINFORCEConfig(), + policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): super(REINFORCE, self).__init__(env_or_env_info, config=config) self._explorer_builder = explorer_builder @@ -170,9 +173,11 @@ def __init__(self, self._policy_solver = policy_solver_builder(self._env_info, self._config) self._evaluation_actor = _StochasticPolicyActionSelector( - self._env_info, self._policy.shallowcopy(), deterministic=False) + self._env_info, self._policy.shallowcopy(), deterministic=False + ) self._exploration_actor = _StochasticPolicyActionSelector( - self._env_info, self._policy.shallowcopy(), deterministic=False) + self._env_info, self._policy.shallowcopy(), deterministic=False + ) @eval_api def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): @@ -192,12 +197,14 @@ def _setup_environment_explorer(self, env_or_buffer): def _setup_policy_training(self, env_or_buffer): policy_trainer_config = MT.policy_trainers.REINFORCEPolicyTrainerConfig( pi_loss_scalar=1.0 / self._config.num_rollouts_per_train_iteration, - grad_clip_norm=self._config.clip_grad_norm) + grad_clip_norm=self._config.clip_grad_norm, + ) policy_trainer = MT.policy_trainers.REINFORCEPolicyTrainer( models=self._policy, solvers={self._policy.scope_name: self._policy_solver}, env_info=self._env_info, - config=policy_trainer_config) + config=policy_trainer_config, + ) return policy_trainer def _run_online_training_iteration(self, env): @@ -218,11 +225,8 @@ def _reinforce_training(self, buffer): s_batch, a_batch, target_return = self._align_experiences_and_compute_accumulated_reward(experiences) batch_size = len(s_batch) extra = {} - extra['target_return'] = np.reshape(target_return, newshape=(batch_size, 1)) - batch = TrainingBatch(batch_size, - s_current=s_batch, - a_current=a_batch, - extra=extra) + extra["target_return"] = np.reshape(target_return, newshape=(batch_size, 1)) + batch = TrainingBatch(batch_size, s_current=s_batch, a_current=a_batch, extra=extra) self._policy_trainer_state = self._policy_trainer.train(batch) @@ -243,8 +247,7 @@ def _align_experiences_and_compute_accumulated_reward(self, experiences): s_batch = np.concatenate((s_batch, s_seq), axis=0) a_batch = np.concatenate((a_batch, a_seq), axis=0) - accumulated_reward_batch = np.concatenate( - (accumulated_reward_batch, accumulated_reward)) + accumulated_reward_batch = np.concatenate((accumulated_reward_batch, accumulated_reward)) return s_batch, a_batch, accumulated_reward_batch @@ -266,15 +269,16 @@ def _solvers(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_tuple_action_env() @property def latest_iteration_state(self): latest_iteration_state = super(REINFORCE, self).latest_iteration_state - 
if hasattr(self, '_policy_trainer_state'): - latest_iteration_state['scalar'].update({'pi_loss': float(self._policy_trainer_state['pi_loss'])}) + if hasattr(self, "_policy_trainer_state"): + latest_iteration_state["scalar"].update({"pi_loss": float(self._policy_trainer_state["pi_loss"])}) return latest_iteration_state @property diff --git a/nnabla_rl/algorithms/sac.py b/nnabla_rl/algorithms/sac.py index fc83549a..b3f20ecb 100644 --- a/nnabla_rl/algorithms/sac.py +++ b/nnabla_rl/algorithms/sac.py @@ -78,7 +78,7 @@ class SACConfig(AlgorithmConfig): """ gamma: float = 0.99 - learning_rate: float = 3.0*1e-4 + learning_rate: float = 3.0 * 1e-4 batch_size: int = 256 tau: float = 0.005 environment_steps: int = 1 @@ -101,69 +101,72 @@ class SACConfig(AlgorithmConfig): def __post_init__(self): """__post_init__ Check set values are in valid range.""" - self._assert_between(self.tau, 0.0, 1.0, 'tau') - self._assert_between(self.gamma, 0.0, 1.0, 'gamma') - self._assert_positive(self.gradient_steps, 'gradient_steps') - self._assert_positive(self.environment_steps, 'environment_steps') + self._assert_between(self.tau, 0.0, 1.0, "tau") + self._assert_between(self.gamma, 0.0, 1.0, "gamma") + self._assert_positive(self.gradient_steps, "gradient_steps") + self._assert_positive(self.environment_steps, "environment_steps") if self.initial_temperature is not None: - self._assert_positive( - self.initial_temperature, 'initial_temperature') - self._assert_positive(self.start_timesteps, 'start_timesteps') + self._assert_positive(self.initial_temperature, "initial_temperature") + self._assert_positive(self.start_timesteps, "start_timesteps") - self._assert_positive(self.critic_unroll_steps, 'critic_unroll_steps') - self._assert_positive_or_zero(self.critic_burn_in_steps, 'critic_burn_in_steps') - self._assert_positive(self.actor_unroll_steps, 'actor_unroll_steps') - self._assert_positive_or_zero(self.actor_burn_in_steps, 'actor_burn_in_steps') + self._assert_positive(self.critic_unroll_steps, "critic_unroll_steps") + self._assert_positive_or_zero(self.critic_burn_in_steps, "critic_burn_in_steps") + self._assert_positive(self.actor_unroll_steps, "actor_unroll_steps") + self._assert_positive_or_zero(self.actor_burn_in_steps, "actor_burn_in_steps") class DefaultQFunctionBuilder(ModelBuilder[QFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: SACConfig, - **kwargs) -> QFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: SACConfig, + **kwargs, + ) -> QFunction: return SACQFunction(scope_name) class DefaultPolicyBuilder(ModelBuilder[StochasticPolicy]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: SACConfig, - **kwargs) -> StochasticPolicy: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: SACConfig, + **kwargs, + ) -> StochasticPolicy: return SACPolicy(scope_name, env_info.action_dim) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: SACConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: SACConfig, **kwargs + ) -> nn.solver.Solver: return NS.Adam(alpha=algorithm_config.learning_rate) class DefaultReplayBufferBuilder(ReplayBufferBuilder): - def 
build_replay_buffer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: SACConfig, - **kwargs) -> ReplayBuffer: + def build_replay_buffer( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: SACConfig, **kwargs + ) -> ReplayBuffer: return ReplayBuffer(capacity=algorithm_config.replay_buffer_size) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: SACConfig, - algorithm: "SAC", - **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: SACConfig, + algorithm: "SAC", + **kwargs, + ) -> EnvironmentExplorer: explorer_config = EE.RawPolicyExplorerConfig( warmup_random_steps=algorithm_config.start_timesteps, initial_step_num=algorithm.iteration_num, - timelimit_as_terminal=False + timelimit_as_terminal=False, + ) + explorer = EE.RawPolicyExplorer( + policy_action_selector=algorithm._exploration_action_selector, env_info=env_info, config=explorer_config ) - explorer = EE.RawPolicyExplorer(policy_action_selector=algorithm._exploration_action_selector, - env_info=env_info, - config=explorer_config) return explorer @@ -222,24 +225,28 @@ class SAC(Algorithm): _policy_trainer_state: Dict[str, Any] _q_function_trainer_state: Dict[str, Any] - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: SACConfig = SACConfig(), - q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), - q_solver_builder: SolverBuilder = DefaultSolverBuilder(), - policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), - temperature_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: SACConfig = SACConfig(), + q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), + q_solver_builder: SolverBuilder = DefaultSolverBuilder(), + policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), + temperature_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): super(SAC, self).__init__(env_or_env_info, config=config) self._explorer_builder = explorer_builder with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): self._train_q_functions = self._build_q_functions(q_function_builder) - self._train_q_solvers = {q.scope_name: q_solver_builder(self._env_info, self._config) - for q in self._train_q_functions} - self._target_q_functions = [q.deepcopy('target_' + q.scope_name) for q in self._train_q_functions] + self._train_q_solvers = { + q.scope_name: q_solver_builder(self._env_info, self._config) for q in self._train_q_functions + } + self._target_q_functions = [q.deepcopy("target_" + q.scope_name) for q in self._train_q_functions] self._pi = policy_builder(scope_name="pi", env_info=self._env_info, algorithm_config=self._config) self._pi_solver = policy_solver_builder(self._env_info, self._config) @@ -279,8 +286,8 @@ def _setup_environment_explorer(self, env_or_buffer): def _setup_temperature_model(self): return 
MT.policy_trainers.soft_policy_trainer.AdjustableTemperature( - scope_name='temperature', - initial_value=self._config.initial_temperature) + scope_name="temperature", initial_value=self._config.initial_temperature + ) def _setup_policy_training(self, env_or_buffer): policy_trainer_config = MT.policy_trainers.SoftPolicyTrainerConfig( @@ -288,7 +295,8 @@ def _setup_policy_training(self, env_or_buffer): target_entropy=self._config.target_entropy, unroll_steps=self._config.actor_unroll_steps, burn_in_steps=self._config.actor_burn_in_steps, - reset_on_terminal=self._config.actor_reset_rnn_on_terminal) + reset_on_terminal=self._config.actor_reset_rnn_on_terminal, + ) policy_trainer = MT.policy_trainers.SoftPolicyTrainer( models=self._pi, solvers={self._pi.scope_name: self._pi_solver}, @@ -296,18 +304,20 @@ def _setup_policy_training(self, env_or_buffer): temperature_solver=self._temperature_solver, q_functions=self._train_q_functions, env_info=self._env_info, - config=policy_trainer_config) + config=policy_trainer_config, + ) return policy_trainer def _setup_q_function_training(self, env_or_buffer): # training input/loss variables q_function_trainer_config = MT.q_value_trainers.SoftQTrainerConfig( - reduction_method='mean', + reduction_method="mean", grad_clip=None, num_steps=self._config.num_steps, unroll_steps=self._config.critic_unroll_steps, burn_in_steps=self._config.critic_burn_in_steps, - reset_on_terminal=self._config.critic_reset_rnn_on_terminal) + reset_on_terminal=self._config.critic_reset_rnn_on_terminal, + ) q_function_trainer = MT.q_value_trainers.SoftQTrainer( train_functions=self._train_q_functions, @@ -316,7 +326,8 @@ def _setup_q_function_training(self, env_or_buffer): target_policy=self._pi, temperature=self._policy_trainer.get_temperature(), env_info=self._env_info, - config=q_function_trainer_config) + config=q_function_trainer_config, + ) for q, target_q in zip(self._train_q_functions, self._target_q_functions): sync_model(q, target_q) return q_function_trainer @@ -344,30 +355,32 @@ def _sac_training(self, replay_buffer): num_steps = max(actor_steps, critic_steps) experiences_tuple, info = replay_buffer.sample(self._config.batch_size, num_steps=num_steps) if num_steps == 1: - experiences_tuple = (experiences_tuple, ) + experiences_tuple = (experiences_tuple,) assert len(experiences_tuple) == num_steps batch = None for experiences in reversed(experiences_tuple): (s, a, r, non_terminal, s_next, rnn_states_dict, *_) = marshal_experiences(experiences) - rnn_states = rnn_states_dict['rnn_states'] if 'rnn_states' in rnn_states_dict else {} - batch = TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - gamma=self._config.gamma, - reward=r, - non_terminal=non_terminal, - s_next=s_next, - weight=info['weights'], - next_step_batch=batch, - rnn_states=rnn_states) + rnn_states = rnn_states_dict["rnn_states"] if "rnn_states" in rnn_states_dict else {} + batch = TrainingBatch( + batch_size=self._config.batch_size, + s_current=s, + a_current=a, + gamma=self._config.gamma, + reward=r, + non_terminal=non_terminal, + s_next=s_next, + weight=info["weights"], + next_step_batch=batch, + rnn_states=rnn_states, + ) self._q_function_trainer_state = self._q_function_trainer.train(batch) for q, target_q in zip(self._train_q_functions, self._target_q_functions): sync_model(q, target_q, tau=self._config.tau) self._policy_trainer_state = self._policy_trainer.train(batch) - td_errors = self._q_function_trainer_state['td_errors'] + td_errors = 
self._q_function_trainer_state["td_errors"] replay_buffer.update_priorities(td_errors) def _evaluation_action_selector(self, s, *, begin_of_episode=False): @@ -401,19 +414,21 @@ def is_rnn_supported(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env() @property def latest_iteration_state(self): latest_iteration_state = super(SAC, self).latest_iteration_state - if hasattr(self, '_policy_trainer_state'): - latest_iteration_state['scalar'].update({'pi_loss': float(self._policy_trainer_state['pi_loss'])}) - if hasattr(self, '_q_function_trainer_state'): - latest_iteration_state['scalar'].update({'q_loss': float(self._q_function_trainer_state['q_loss'])}) - latest_iteration_state['histogram'].update( - {'td_errors': self._q_function_trainer_state['td_errors'].flatten()}) + if hasattr(self, "_policy_trainer_state"): + latest_iteration_state["scalar"].update({"pi_loss": float(self._policy_trainer_state["pi_loss"])}) + if hasattr(self, "_q_function_trainer_state"): + latest_iteration_state["scalar"].update({"q_loss": float(self._q_function_trainer_state["q_loss"])}) + latest_iteration_state["histogram"].update( + {"td_errors": self._q_function_trainer_state["td_errors"].flatten()} + ) return latest_iteration_state @property diff --git a/nnabla_rl/algorithms/sacd.py b/nnabla_rl/algorithms/sacd.py index fb426517..6829b1d3 100644 --- a/nnabla_rl/algorithms/sacd.py +++ b/nnabla_rl/algorithms/sacd.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
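# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of this patch): consuming the
# scalars that SAC.latest_iteration_state publishes above ("pi_loss",
# "q_loss" and the "td_errors" histogram) from a user-defined Hook.  The
# Hook base class, set_hooks() and train_online() follow the library's usual
# entry points; the environment id, timing and iteration counts are
# assumptions made only for this example.
# ---------------------------------------------------------------------------
import gym
import nnabla_rl.algorithms as A
from nnabla_rl.hook import Hook


class LossLoggerHook(Hook):
    def __init__(self):
        super().__init__(timing=100)  # invoked every 100 training iterations

    def on_hook_called(self, algorithm):
        scalars = algorithm.latest_iteration_state["scalar"]
        # "pi_loss"/"q_loss" appear only after the trainers have run at least once
        if "q_loss" in scalars:
            print(f"iter {algorithm.iteration_num}: "
                  f"pi_loss={scalars['pi_loss']:.3f} q_loss={scalars['q_loss']:.3f}")


env = gym.make("Pendulum-v1")  # assumed continuous-action environment
sac = A.SAC(env, config=A.SACConfig(start_timesteps=200))
sac.set_hooks([LossLoggerHook()])
sac.train_online(env, total_iterations=1000)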
@@ -22,8 +22,14 @@ import nnabla_rl.model_trainers as MT from nnabla_rl.algorithm import eval_api from nnabla_rl.algorithms.common_utils import _InfluenceMetricsEvaluator -from nnabla_rl.algorithms.sac import (SAC, DefaultExplorerBuilder, DefaultPolicyBuilder, DefaultReplayBufferBuilder, - DefaultSolverBuilder, SACConfig) +from nnabla_rl.algorithms.sac import ( + SAC, + DefaultExplorerBuilder, + DefaultPolicyBuilder, + DefaultReplayBufferBuilder, + DefaultSolverBuilder, + SACConfig, +) from nnabla_rl.builders import ExplorerBuilder, ModelBuilder, ReplayBufferBuilder, SolverBuilder from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.models import QFunction, SACDQFunction, StochasticPolicy @@ -79,15 +85,17 @@ class SACDConfig(SACConfig): def __post_init__(self): super().__post_init__() - self._assert_positive(self.reward_dimension, 'reward_dimension') + self._assert_positive(self.reward_dimension, "reward_dimension") class DefaultQFunctionBuilder(ModelBuilder[QFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: SACDConfig, - **kwargs) -> QFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: SACDConfig, + **kwargs, + ) -> QFunction: # increment reward dimension to accomodate entropy bonus return SACDQFunction(scope_name, algorithm_config.reward_dimension + 1) @@ -129,16 +137,18 @@ class SACD(SAC): _config: SACDConfig _influence_metrics_evaluator: _InfluenceMetricsEvaluator - def __init__(self, - env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: SACDConfig = SACDConfig(), - q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), - q_solver_builder: SolverBuilder = DefaultSolverBuilder(), - policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), - temperature_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: SACDConfig = SACDConfig(), + q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), + q_solver_builder: SolverBuilder = DefaultSolverBuilder(), + policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), + temperature_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): super(SACD, self).__init__( env_or_env_info=env_or_env_info, config=config, @@ -156,13 +166,14 @@ def __init__(self, def _setup_q_function_training(self, env_or_buffer): # training input/loss variables q_function_trainer_config = MT.q_value_trainers.SoftQDTrainerConfig( - reduction_method='mean', + reduction_method="mean", grad_clip=None, num_steps=self._config.num_steps, unroll_steps=self._config.critic_unroll_steps, burn_in_steps=self._config.critic_burn_in_steps, reset_on_terminal=self._config.critic_reset_rnn_on_terminal, - reward_dimension=self._config.reward_dimension) + reward_dimension=self._config.reward_dimension, + ) q_function_trainer = MT.q_value_trainers.SoftQDTrainer( train_functions=self._train_q_functions, @@ -171,17 +182,16 @@ def 
_setup_q_function_training(self, env_or_buffer): target_policy=self._pi, temperature=self._policy_trainer.get_temperature(), env_info=self._env_info, - config=q_function_trainer_config) + config=q_function_trainer_config, + ) for q, target_q in zip(self._train_q_functions, self._target_q_functions): sync_model(q, target_q) return q_function_trainer @eval_api - def compute_influence_metrics(self, - state: np.ndarray, - action: np.ndarray, - *, - begin_of_episode: bool = False) -> np.ndarray: + def compute_influence_metrics( + self, state: np.ndarray, action: np.ndarray, *, begin_of_episode: bool = False + ) -> np.ndarray: """Compute relative influence metrics. The influence metrics represent how much each reward component contributes to an agent's decisions. diff --git a/nnabla_rl/algorithms/srsac.py b/nnabla_rl/algorithms/srsac.py index 037d0c7e..db80b8a9 100644 --- a/nnabla_rl/algorithms/srsac.py +++ b/nnabla_rl/algorithms/srsac.py @@ -20,8 +20,15 @@ import nnabla as nn import nnabla_rl.model_trainers as MT -from nnabla_rl.algorithms.sac import (SAC, DefaultExplorerBuilder, DefaultPolicyBuilder, DefaultQFunctionBuilder, - DefaultReplayBufferBuilder, DefaultSolverBuilder, SACConfig) +from nnabla_rl.algorithms.sac import ( + SAC, + DefaultExplorerBuilder, + DefaultPolicyBuilder, + DefaultQFunctionBuilder, + DefaultReplayBufferBuilder, + DefaultSolverBuilder, + SACConfig, +) from nnabla_rl.builders import ExplorerBuilder, ModelBuilder, ReplayBufferBuilder, SolverBuilder from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingBatch @@ -80,8 +87,8 @@ class SRSACConfig(SACConfig): def __post_init__(self): super().__post_init__() - self._assert_positive(self.replay_ratio, 'replay_ratio') - self._assert_positive(self.reset_interval, 'reset_interval') + self._assert_positive(self.replay_ratio, "replay_ratio") + self._assert_positive(self.reset_interval, "reset_interval") class SRSAC(SAC): @@ -119,16 +126,18 @@ class SRSAC(SAC): # See https://mypy.readthedocs.io/en/stable/class_basics.html for details _config: SRSACConfig - def __init__(self, - env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: SRSACConfig = SRSACConfig(), - q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), - q_solver_builder: SolverBuilder = DefaultSolverBuilder(), - policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), - temperature_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: SRSACConfig = SRSACConfig(), + q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), + q_solver_builder: SolverBuilder = DefaultSolverBuilder(), + policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), + temperature_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): super(SRSAC, self).__init__( env_or_env_info=env_or_env_info, config=config, @@ -212,6 +221,7 @@ class EfficientSRSACConfig(SRSACConfig): replay_ratio (int): Number of updates per environment step. 
reset_interval (int): Paramerters will be reset every this number of updates. """ + actor_reset_rnn_on_terminal: bool = False critic_reset_rnn_on_terminal: bool = False @@ -219,23 +229,25 @@ def __post_init__(self): super().__post_init__() def fill_warning_message(config_name, config_value, expected_value): - return f'''{config_name} is set to {config_value}(!={expected_value}) - but this value does not take any effect on EfficentSRSAC.''' + return f"""{config_name} is set to {config_value}(!={expected_value}) + but this value does not take any effect on EfficentSRSAC.""" + if 1 != self.num_steps: - warnings.warn(fill_warning_message('num_steps', self.num_steps, 1)) + warnings.warn(fill_warning_message("num_steps", self.num_steps, 1)) if 0 != self.actor_burn_in_steps: - warnings.warn(fill_warning_message('actor_burn_in_steps', self.actor_burn_in_steps, 0)) + warnings.warn(fill_warning_message("actor_burn_in_steps", self.actor_burn_in_steps, 0)) if 1 != self.actor_unroll_steps: - warnings.warn(fill_warning_message('actor_unroll_steps', self.actor_unroll_steps, 1)) + warnings.warn(fill_warning_message("actor_unroll_steps", self.actor_unroll_steps, 1)) if self.actor_reset_rnn_on_terminal: - warnings.warn(fill_warning_message('actor_reset_rnn_on_terminal', self.actor_reset_rnn_on_terminal, False)) + warnings.warn(fill_warning_message("actor_reset_rnn_on_terminal", self.actor_reset_rnn_on_terminal, False)) if 0 != self.critic_burn_in_steps: - warnings.warn(fill_warning_message('critic_burn_in_steps', self.critic_burn_in_steps, 0)) + warnings.warn(fill_warning_message("critic_burn_in_steps", self.critic_burn_in_steps, 0)) if 1 != self.critic_unroll_steps: - warnings.warn(fill_warning_message('critic_unroll_steps', self.critic_unroll_steps, 1)) + warnings.warn(fill_warning_message("critic_unroll_steps", self.critic_unroll_steps, 1)) if self.critic_reset_rnn_on_terminal: - warnings.warn(fill_warning_message('critic_reset_rnn_on_terminal', - self.critic_reset_rnn_on_terminal, False)) + warnings.warn( + fill_warning_message("critic_reset_rnn_on_terminal", self.critic_reset_rnn_on_terminal, False) + ) class EfficientSRSAC(SRSAC): @@ -276,25 +288,29 @@ class EfficientSRSAC(SRSAC): # See https://mypy.readthedocs.io/en/stable/class_basics.html for details _config: EfficientSRSACConfig - def __init__(self, - env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: EfficientSRSACConfig = EfficientSRSACConfig(), - q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), - q_solver_builder: SolverBuilder = DefaultSolverBuilder(), - policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), - temperature_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): - super().__init__(env_or_env_info=env_or_env_info, - config=config, - q_function_builder=q_function_builder, - q_solver_builder=q_solver_builder, - policy_builder=policy_builder, - policy_solver_builder=policy_solver_builder, - temperature_solver_builder=temperature_solver_builder, - replay_buffer_builder=replay_buffer_builder, - explorer_builder=explorer_builder) + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: EfficientSRSACConfig = EfficientSRSACConfig(), + q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), + q_solver_builder: SolverBuilder = 
DefaultSolverBuilder(), + policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), + temperature_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): + super().__init__( + env_or_env_info=env_or_env_info, + config=config, + q_function_builder=q_function_builder, + q_solver_builder=q_solver_builder, + policy_builder=policy_builder, + policy_solver_builder=policy_solver_builder, + temperature_solver_builder=temperature_solver_builder, + replay_buffer_builder=replay_buffer_builder, + explorer_builder=explorer_builder, + ) @classmethod def is_rnn_supported(cls): @@ -314,7 +330,8 @@ def _setup_actor_critic_training(self, env_or_buffer): fixed_temperature=self._config.fix_temperature, target_entropy=self._config.target_entropy, replay_ratio=self._config.replay_ratio, - tau=self._config.tau) + tau=self._config.tau, + ) actor_critic_trainer = MT.hybrid_trainers.SRSACActorCriticTrainer( pi=self._pi, pi_solvers={self._pi.scope_name: self._pi_solver}, @@ -324,7 +341,8 @@ def _setup_actor_critic_training(self, env_or_buffer): temperature=self._temperature, temperature_solver=self._temperature_solver, env_info=self._env_info, - config=actor_critic_trainer_config) + config=actor_critic_trainer_config, + ) return actor_critic_trainer def _run_gradient_step(self, replay_buffer): @@ -350,21 +368,23 @@ def _efficient_srsac_training(self, replay_buffer): batch = None for experiences, info in zip(experiences_tuple, info_tuple): (s, a, r, non_terminal, s_next, rnn_states_dict, *_) = marshal_experiences(experiences) - rnn_states = rnn_states_dict['rnn_states'] if 'rnn_states' in rnn_states_dict else {} - batch = TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - gamma=self._config.gamma, - reward=r, - non_terminal=non_terminal, - s_next=s_next, - weight=info['weights'], - next_step_batch=batch, - rnn_states=rnn_states) + rnn_states = rnn_states_dict["rnn_states"] if "rnn_states" in rnn_states_dict else {} + batch = TrainingBatch( + batch_size=self._config.batch_size, + s_current=s, + a_current=a, + gamma=self._config.gamma, + reward=r, + non_terminal=non_terminal, + s_next=s_next, + weight=info["weights"], + next_step_batch=batch, + rnn_states=rnn_states, + ) self._actor_critic_trainer_state = self._actor_critic_trainer.train(batch) - td_errors = self._actor_critic_trainer_state['td_errors'] + td_errors = self._actor_critic_trainer_state["td_errors"] replay_buffer.update_priorities(td_errors) def _reconstruct_training_graphs(self): @@ -374,10 +394,11 @@ def _reconstruct_training_graphs(self): @property def latest_iteration_state(self): latest_iteration_state = super(SAC, self).latest_iteration_state - if hasattr(self, '_actor_critic_trainer_state'): - latest_iteration_state['scalar'].update({'pi_loss': float(self._actor_critic_trainer_state['pi_loss'])}) - if hasattr(self, '_actor_critic_trainer_state'): - latest_iteration_state['scalar'].update({'q_loss': float(self._actor_critic_trainer_state['q_loss'])}) - latest_iteration_state['histogram'].update( - {'td_errors': self._actor_critic_trainer_state['td_errors'].flatten()}) + if hasattr(self, "_actor_critic_trainer_state"): + latest_iteration_state["scalar"].update({"pi_loss": float(self._actor_critic_trainer_state["pi_loss"])}) + if hasattr(self, "_actor_critic_trainer_state"): + 
latest_iteration_state["scalar"].update({"q_loss": float(self._actor_critic_trainer_state["q_loss"])}) + latest_iteration_state["histogram"].update( + {"td_errors": self._actor_critic_trainer_state["td_errors"].flatten()} + ) return latest_iteration_state diff --git a/nnabla_rl/algorithms/td3.py b/nnabla_rl/algorithms/td3.py index 6c8ce11a..0f09adf2 100644 --- a/nnabla_rl/algorithms/td3.py +++ b/nnabla_rl/algorithms/td3.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -78,7 +78,7 @@ class TD3Config(AlgorithmConfig): """ gamma: float = 0.99 - learning_rate: float = 1.0*1e-3 + learning_rate: float = 1.0 * 1e-3 batch_size: int = 100 tau: float = 0.005 start_timesteps: int = 10000 @@ -103,75 +103,79 @@ def __post_init__(self): Check the set values are in valid range. """ - self._assert_between(self.gamma, 0.0, 1.0, 'gamma') - self._assert_positive(self.batch_size, 'batch_size') - self._assert_between(self.tau, 0.0, 1.0, 'tau') - self._assert_positive(self.start_timesteps, 'start_timesteps') - self._assert_positive(self.replay_buffer_size, 'replay_buffer_size') - self._assert_positive(self.d, 'd') - self._assert_positive(self.exploration_noise_sigma, 'exploration_noise_sigma') - self._assert_positive(self.train_action_noise_sigma, 'train_action_noise_sigma') - self._assert_positive(self.train_action_noise_abs, 'train_action_noise_abs') - - self._assert_positive(self.critic_unroll_steps, 'critic_unroll_steps') - self._assert_positive_or_zero(self.critic_burn_in_steps, 'critic_burn_in_steps') - self._assert_positive(self.actor_unroll_steps, 'actor_unroll_steps') - self._assert_positive_or_zero(self.actor_burn_in_steps, 'actor_burn_in_steps') + self._assert_between(self.gamma, 0.0, 1.0, "gamma") + self._assert_positive(self.batch_size, "batch_size") + self._assert_between(self.tau, 0.0, 1.0, "tau") + self._assert_positive(self.start_timesteps, "start_timesteps") + self._assert_positive(self.replay_buffer_size, "replay_buffer_size") + self._assert_positive(self.d, "d") + self._assert_positive(self.exploration_noise_sigma, "exploration_noise_sigma") + self._assert_positive(self.train_action_noise_sigma, "train_action_noise_sigma") + self._assert_positive(self.train_action_noise_abs, "train_action_noise_abs") + + self._assert_positive(self.critic_unroll_steps, "critic_unroll_steps") + self._assert_positive_or_zero(self.critic_burn_in_steps, "critic_burn_in_steps") + self._assert_positive(self.actor_unroll_steps, "actor_unroll_steps") + self._assert_positive_or_zero(self.actor_burn_in_steps, "actor_burn_in_steps") class DefaultCriticBuilder(ModelBuilder[QFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: TD3Config, - **kwargs) -> QFunction: - target_policy = kwargs.get('target_policy') + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: TD3Config, + **kwargs, + ) -> QFunction: + target_policy = kwargs.get("target_policy") return TD3QFunction(scope_name, optimal_policy=target_policy) class DefaultActorBuilder(ModelBuilder[DeterministicPolicy]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: TD3Config, - **kwargs) -> DeterministicPolicy: + 
def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: TD3Config, + **kwargs, + ) -> DeterministicPolicy: max_action_value = float(env_info.action_high[0]) return TD3Policy(scope_name, env_info.action_dim, max_action_value=max_action_value) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: TD3Config, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: TD3Config, **kwargs + ) -> nn.solver.Solver: return NS.Adam(alpha=algorithm_config.learning_rate) class DefaultReplayBufferBuilder(ReplayBufferBuilder): - def build_replay_buffer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: TD3Config, - **kwargs) -> ReplayBuffer: + def build_replay_buffer( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: TD3Config, **kwargs + ) -> ReplayBuffer: return ReplayBuffer(capacity=algorithm_config.replay_buffer_size) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: TD3Config, - algorithm: "TD3", - **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: TD3Config, + algorithm: "TD3", + **kwargs, + ) -> EnvironmentExplorer: explorer_config = EE.GaussianExplorerConfig( warmup_random_steps=algorithm_config.start_timesteps, initial_step_num=algorithm.iteration_num, timelimit_as_terminal=False, action_clip_low=env_info.action_low, action_clip_high=env_info.action_high, - sigma=algorithm_config.exploration_noise_sigma + sigma=algorithm_config.exploration_noise_sigma, + ) + explorer = EE.GaussianExplorer( + policy_action_selector=algorithm._exploration_action_selector, env_info=env_info, config=explorer_config ) - explorer = EE.GaussianExplorer(policy_action_selector=algorithm._exploration_action_selector, - env_info=env_info, - config=explorer_config) return explorer @@ -227,14 +231,17 @@ class TD3(Algorithm): _policy_trainer_state: Dict[str, Any] _q_function_trainer_state: Dict[str, Any] - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: TD3Config = TD3Config(), - critic_builder: ModelBuilder[QFunction] = DefaultCriticBuilder(), - critic_solver_builder: SolverBuilder = DefaultSolverBuilder(), - actor_builder: ModelBuilder[DeterministicPolicy] = DefaultActorBuilder(), - actor_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: TD3Config = TD3Config(), + critic_builder: ModelBuilder[QFunction] = DefaultCriticBuilder(), + critic_solver_builder: SolverBuilder = DefaultSolverBuilder(), + actor_builder: ModelBuilder[DeterministicPolicy] = DefaultActorBuilder(), + actor_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): super(TD3, self).__init__(env_or_env_info, config=config) self._explorer_builder = explorer_builder @@ -243,13 +250,15 @@ def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], self._q1 = critic_builder(scope_name="q1", 
env_info=self._env_info, algorithm_config=self._config) self._q2 = critic_builder(scope_name="q2", env_info=self._env_info, algorithm_config=self._config) self._train_q_functions = [self._q1, self._q2] - self._train_q_solvers = {q.scope_name: critic_solver_builder( - env_info=self._env_info, algorithm_config=self._config) for q in self._train_q_functions} - self._target_q_functions = [q.deepcopy('target_' + q.scope_name) for q in self._train_q_functions] + self._train_q_solvers = { + q.scope_name: critic_solver_builder(env_info=self._env_info, algorithm_config=self._config) + for q in self._train_q_functions + } + self._target_q_functions = [q.deepcopy("target_" + q.scope_name) for q in self._train_q_functions] self._pi = actor_builder(scope_name="pi", env_info=self._env_info, algorithm_config=self._config) self._pi_solver = actor_solver_builder(env_info=self._env_info, algorithm_config=self._config) - self._target_pi = self._pi.deepcopy('target_' + self._pi.scope_name) + self._target_pi = self._pi.deepcopy("target_" + self._pi.scope_name) self._replay_buffer = replay_buffer_builder(env_info=self._env_info, algorithm_config=self._config) @@ -275,21 +284,23 @@ def _setup_environment_explorer(self, env_or_buffer): def _setup_q_function_training(self, env_or_buffer): # training input/loss variables q_function_trainer_config = MT.q_value_trainers.TD3QTrainerConfig( - reduction_method='mean', + reduction_method="mean", grad_clip=None, train_action_noise_sigma=self._config.train_action_noise_sigma, train_action_noise_abs=self._config.train_action_noise_abs, num_steps=self._config.num_steps, unroll_steps=self._config.critic_unroll_steps, burn_in_steps=self._config.critic_burn_in_steps, - reset_on_terminal=self._config.critic_reset_rnn_on_terminal) + reset_on_terminal=self._config.critic_reset_rnn_on_terminal, + ) q_function_trainer = MT.q_value_trainers.TD3QTrainer( train_functions=self._train_q_functions, solvers=self._train_q_solvers, target_functions=self._target_q_functions, target_policy=self._target_pi, env_info=self._env_info, - config=q_function_trainer_config) + config=q_function_trainer_config, + ) for q, target_q in zip(self._train_q_functions, self._target_q_functions): sync_model(q, target_q) return q_function_trainer @@ -298,13 +309,15 @@ def _setup_policy_training(self, env_or_buffer): policy_trainer_config = MT.policy_trainers.DPGPolicyTrainerConfig( unroll_steps=self._config.actor_unroll_steps, burn_in_steps=self._config.actor_burn_in_steps, - reset_on_terminal=self._config.actor_reset_rnn_on_terminal) + reset_on_terminal=self._config.actor_reset_rnn_on_terminal, + ) policy_trainer = MT.policy_trainers.DPGPolicyTrainer( models=self._pi, solvers={self._pi.scope_name: self._pi_solver}, q_function=self._q1, env_info=self._env_info, - config=policy_trainer_config) + config=policy_trainer_config, + ) sync_model(self._pi, self._target_pi, 1.0) return policy_trainer @@ -325,26 +338,28 @@ def _td3_training(self, replay_buffer): num_steps = max(actor_steps, critic_steps) experiences_tuple, info = replay_buffer.sample(self._config.batch_size, num_steps=num_steps) if num_steps == 1: - experiences_tuple = (experiences_tuple, ) + experiences_tuple = (experiences_tuple,) assert len(experiences_tuple) == num_steps batch = None for experiences in reversed(experiences_tuple): (s, a, r, non_terminal, s_next, rnn_states_dict, *_) = marshal_experiences(experiences) - rnn_states = rnn_states_dict['rnn_states'] if 'rnn_states' in rnn_states_dict else {} - batch = 
TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - gamma=self._config.gamma, - reward=r, - non_terminal=non_terminal, - s_next=s_next, - weight=info['weights'], - next_step_batch=batch, - rnn_states=rnn_states) + rnn_states = rnn_states_dict["rnn_states"] if "rnn_states" in rnn_states_dict else {} + batch = TrainingBatch( + batch_size=self._config.batch_size, + s_current=s, + a_current=a, + gamma=self._config.gamma, + reward=r, + non_terminal=non_terminal, + s_next=s_next, + weight=info["weights"], + next_step_batch=batch, + rnn_states=rnn_states, + ) self._q_function_trainer_state = self._q_function_trainer.train(batch) - td_errors = self._q_function_trainer_state['td_errors'] + td_errors = self._q_function_trainer_state["td_errors"] replay_buffer.update_priorities(td_errors) if self.iteration_num % self._config.d == 0: @@ -382,19 +397,21 @@ def is_rnn_supported(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env() @property def latest_iteration_state(self): latest_iteration_state = super(TD3, self).latest_iteration_state - if hasattr(self, '_policy_trainer_state'): - latest_iteration_state['scalar'].update({'pi_loss': float(self._policy_trainer_state['pi_loss'])}) - if hasattr(self, '_q_function_trainer_state'): - latest_iteration_state['scalar'].update({'q_loss': float(self._q_function_trainer_state['q_loss'])}) - latest_iteration_state['histogram'].update( - {'td_errors': self._q_function_trainer_state['td_errors'].flatten()}) + if hasattr(self, "_policy_trainer_state"): + latest_iteration_state["scalar"].update({"pi_loss": float(self._policy_trainer_state["pi_loss"])}) + if hasattr(self, "_q_function_trainer_state"): + latest_iteration_state["scalar"].update({"q_loss": float(self._q_function_trainer_state["q_loss"])}) + latest_iteration_state["histogram"].update( + {"td_errors": self._q_function_trainer_state["td_errors"].flatten()} + ) return latest_iteration_state def trainers(self): diff --git a/nnabla_rl/algorithms/trpo.py b/nnabla_rl/algorithms/trpo.py index 7483cfe2..89f88bb5 100644 --- a/nnabla_rl/algorithms/trpo.py +++ b/nnabla_rl/algorithms/trpo.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
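# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of this patch): the TD3Config fields
# referenced above (start_timesteps, exploration_noise_sigma and the delayed
# policy-update period d) set explicitly.  The environment id and the
# concrete values are assumptions made only for this example.
# ---------------------------------------------------------------------------
import gym
import nnabla_rl.algorithms as A

config = A.TD3Config(
    start_timesteps=1000,         # warm-up steps with random actions before policy actions
    exploration_noise_sigma=0.1,  # sigma of the Gaussian exploration noise
    d=2,                          # actor and target networks update once every d critic updates
)
env = gym.make("Pendulum-v1")     # assumed continuous-action environment
td3 = A.TD3(env, config=config)
td3.train_online(env, total_iterations=10000)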
@@ -25,8 +25,12 @@ import nnabla_rl.model_trainers as MT import nnabla_rl.preprocessors as RP from nnabla_rl.algorithm import Algorithm, AlgorithmConfig, eval_api -from nnabla_rl.algorithms.common_utils import (_StatePreprocessedStochasticPolicy, _StatePreprocessedVFunction, - _StochasticPolicyActionSelector, compute_v_target_and_advantage) +from nnabla_rl.algorithms.common_utils import ( + _StatePreprocessedStochasticPolicy, + _StatePreprocessedVFunction, + _StochasticPolicyActionSelector, + compute_v_target_and_advantage, +) from nnabla_rl.builders import ExplorerBuilder, ModelBuilder, PreprocessorBuilder, SolverBuilder from nnabla_rl.environment_explorer import EnvironmentExplorer from nnabla_rl.environments.environment_info import EnvironmentInfo @@ -71,6 +75,7 @@ class TRPOConfig(AlgorithmConfig): As long as gpu memory size is enough, this configuration should not be specified. If not specified, \ gpu_batch_size is the same as pi_batch_size. Defaults to None. """ + gamma: float = 0.995 lmb: float = 0.97 num_steps_per_iteration: int = 5000 @@ -90,68 +95,74 @@ def __post_init__(self): Check the values are in valid range. """ - self._assert_between(self.pi_batch_size, 0, self.num_steps_per_iteration, 'pi_batch_size') - self._assert_between(self.gamma, 0.0, 1.0, 'gamma') - self._assert_between(self.lmb, 0.0, 1.0, 'lmb') - self._assert_positive(self.num_steps_per_iteration, 'num_steps_per_iteration') - self._assert_between(self.pi_batch_size, 0, self.num_steps_per_iteration, 'pi_batch_size') - self._assert_positive(self.sigma_kl_divergence_constraint, 'sigma_kl_divergence_constraint') - self._assert_positive(self.maximum_backtrack_numbers, 'maximum_backtrack_numbers') - self._assert_positive(self.conjugate_gradient_damping, 'conjugate_gradient_damping') - self._assert_positive(self.conjugate_gradient_iterations, 'conjugate_gradient_iterations') - self._assert_positive(self.vf_epochs, 'vf_epochs') - self._assert_positive(self.vf_batch_size, 'vf_batch_size') - self._assert_positive(self.vf_learning_rate, 'vf_learning_rate') + self._assert_between(self.pi_batch_size, 0, self.num_steps_per_iteration, "pi_batch_size") + self._assert_between(self.gamma, 0.0, 1.0, "gamma") + self._assert_between(self.lmb, 0.0, 1.0, "lmb") + self._assert_positive(self.num_steps_per_iteration, "num_steps_per_iteration") + self._assert_between(self.pi_batch_size, 0, self.num_steps_per_iteration, "pi_batch_size") + self._assert_positive(self.sigma_kl_divergence_constraint, "sigma_kl_divergence_constraint") + self._assert_positive(self.maximum_backtrack_numbers, "maximum_backtrack_numbers") + self._assert_positive(self.conjugate_gradient_damping, "conjugate_gradient_damping") + self._assert_positive(self.conjugate_gradient_iterations, "conjugate_gradient_iterations") + self._assert_positive(self.vf_epochs, "vf_epochs") + self._assert_positive(self.vf_batch_size, "vf_batch_size") + self._assert_positive(self.vf_learning_rate, "vf_learning_rate") class DefaultPolicyBuilder(ModelBuilder[StochasticPolicy]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: TRPOConfig, - **kwargs) -> StochasticPolicy: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: TRPOConfig, + **kwargs, + ) -> StochasticPolicy: return TRPOPolicy(scope_name, env_info.action_dim) class DefaultVFunctionBuilder(ModelBuilder[VFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: 
EnvironmentInfo, - algorithm_config: TRPOConfig, - **kwargs) -> VFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: TRPOConfig, + **kwargs, + ) -> VFunction: return TRPOVFunction(scope_name) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: TRPOConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: TRPOConfig, **kwargs + ) -> nn.solver.Solver: return NS.Adam(alpha=algorithm_config.vf_learning_rate) class DefaultPreprocessorBuilder(PreprocessorBuilder): - def build_preprocessor(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: TRPOConfig, - **kwargs) -> Preprocessor: + def build_preprocessor( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: TRPOConfig, + **kwargs, + ) -> Preprocessor: return RP.RunningMeanNormalizer(scope_name, env_info.state_shape, value_clip=(-5.0, 5.0)) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: TRPOConfig, - algorithm: "TRPO", - **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: TRPOConfig, + algorithm: "TRPO", + **kwargs, + ) -> EnvironmentExplorer: explorer_config = EE.RawPolicyExplorerConfig( - initial_step_num=algorithm.iteration_num, - timelimit_as_terminal=False + initial_step_num=algorithm.iteration_num, timelimit_as_terminal=False + ) + explorer = EE.RawPolicyExplorer( + policy_action_selector=algorithm._exploration_action_selector, env_info=env_info, config=explorer_config ) - explorer = EE.RawPolicyExplorer(policy_action_selector=algorithm._exploration_action_selector, - env_info=env_info, - config=explorer_config) return explorer @@ -201,25 +212,27 @@ class TRPO(Algorithm): _policy_trainer_state: Dict[str, Any] _v_function_trainer_state: Dict[str, Any] - def __init__(self, - env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: TRPOConfig = TRPOConfig(), - v_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), - v_solver_builder: SolverBuilder = DefaultSolverBuilder(), - policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - state_preprocessor_builder: Optional[PreprocessorBuilder] = DefaultPreprocessorBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: TRPOConfig = TRPOConfig(), + v_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), + v_solver_builder: SolverBuilder = DefaultSolverBuilder(), + policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + state_preprocessor_builder: Optional[PreprocessorBuilder] = DefaultPreprocessorBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): super(TRPO, self).__init__(env_or_env_info, config=config) self._explorer_builder = explorer_builder with nn.context_scope(context.get_nnabla_context(self._config.gpu_id)): - self._v_function = v_function_builder('v', self._env_info, self._config) - self._policy = policy_builder('pi', self._env_info, self._config) + self._v_function = v_function_builder("v", self._env_info, self._config) + self._policy = policy_builder("pi", 
self._env_info, self._config) self._preprocessor: Optional[Preprocessor] = None if self._config.preprocess_state and state_preprocessor_builder is not None: - preprocessor = state_preprocessor_builder('preprocessor', self._env_info, self._config) + preprocessor = state_preprocessor_builder("preprocessor", self._env_info, self._config) assert preprocessor is not None self._v_function = _StatePreprocessedVFunction(v_function=self._v_function, preprocessor=preprocessor) self._policy = _StatePreprocessedStochasticPolicy(policy=self._policy, preprocessor=preprocessor) @@ -227,9 +240,11 @@ def __init__(self, self._v_function_solver = v_solver_builder(self._env_info, self._config) self._evaluation_actor = _StochasticPolicyActionSelector( - self._env_info, self._policy.shallowcopy(), deterministic=False) + self._env_info, self._policy.shallowcopy(), deterministic=False + ) self._exploration_actor = _StochasticPolicyActionSelector( - self._env_info, self._policy.shallowcopy(), deterministic=False) + self._env_info, self._policy.shallowcopy(), deterministic=False + ) @eval_api def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): @@ -248,15 +263,13 @@ def _setup_environment_explorer(self, env_or_buffer): return None if self._is_buffer(env_or_buffer) else self._explorer_builder(self._env_info, self._config, self) def _setup_v_function_training(self, env_or_buffer): - v_function_trainer_config = MT.v_value.MonteCarloVTrainerConfig( - reduction_method='mean', - v_loss_scalar=1.0 - ) + v_function_trainer_config = MT.v_value.MonteCarloVTrainerConfig(reduction_method="mean", v_loss_scalar=1.0) v_function_trainer = MT.v_value.MonteCarloVTrainer( train_functions=self._v_function, solvers={self._v_function.scope_name: self._v_function_solver}, env_info=self._env_info, - config=v_function_trainer_config) + config=v_function_trainer_config, + ) return v_function_trainer def _setup_policy_training(self, env_or_buffer): @@ -265,11 +278,11 @@ def _setup_policy_training(self, env_or_buffer): sigma_kl_divergence_constraint=self._config.sigma_kl_divergence_constraint, maximum_backtrack_numbers=self._config.maximum_backtrack_numbers, conjugate_gradient_damping=self._config.conjugate_gradient_damping, - conjugate_gradient_iterations=self._config.conjugate_gradient_iterations) + conjugate_gradient_iterations=self._config.conjugate_gradient_iterations, + ) policy_trainer = MT.policy_trainers.TRPOPolicyTrainer( - model=self._policy, - env_info=self._env_info, - config=policy_trainer_config) + model=self._policy, env_info=self._env_info, config=policy_trainer_config + ) return policy_trainer def _run_online_training_iteration(self, env): @@ -307,10 +320,12 @@ def _align_experiences(self, buffer_iterator): s_batch, a_batch = self._align_state_and_action(buffer_iterator) - return s_batch[:self._config.num_steps_per_iteration], \ - a_batch[:self._config.num_steps_per_iteration], \ - v_target_batch[:self._config.num_steps_per_iteration], \ - adv_batch[:self._config.num_steps_per_iteration] + return ( + s_batch[: self._config.num_steps_per_iteration], + a_batch[: self._config.num_steps_per_iteration], + v_target_batch[: self._config.num_steps_per_iteration], + adv_batch[: self._config.num_steps_per_iteration], + ) def _compute_v_target_and_advantage(self, buffer_iterator): v_target_batch = [] @@ -319,7 +334,8 @@ def _compute_v_target_and_advantage(self, buffer_iterator): for experiences, _ in buffer_iterator: # length of experiences is 1 v_target, adv = compute_v_target_and_advantage( - 
self._v_function, experiences[0], gamma=self._config.gamma, lmb=self._config.lmb) + self._v_function, experiences[0], gamma=self._config.gamma, lmb=self._config.lmb + ) v_target_batch.append(v_target.reshape(-1, 1)) adv_batch.append(adv.reshape(-1, 1)) @@ -351,19 +367,21 @@ def _v_function_training(self, s, v_target): for _ in range(self._config.vf_epochs * num_iterations_per_epoch): indices = np.random.randint(0, self._config.num_steps_per_iteration, size=self._config.vf_batch_size) - batch = TrainingBatch(batch_size=self._config.vf_batch_size, - s_current=s[indices], - extra={'v_target': v_target[indices]}) + batch = TrainingBatch( + batch_size=self._config.vf_batch_size, s_current=s[indices], extra={"v_target": v_target[indices]} + ) self._v_function_trainer_state = self._v_function_trainer.train(batch) def _policy_training(self, s, a, v_target, advantage): extra = {} - extra['v_target'] = v_target[:self._config.pi_batch_size] - extra['advantage'] = advantage[:self._config.pi_batch_size] - batch = TrainingBatch(batch_size=self._config.pi_batch_size, - s_current=s[:self._config.pi_batch_size], - a_current=a[:self._config.pi_batch_size], - extra=extra) + extra["v_target"] = v_target[: self._config.pi_batch_size] + extra["advantage"] = advantage[: self._config.pi_batch_size] + batch = TrainingBatch( + batch_size=self._config.pi_batch_size, + s_current=s[: self._config.pi_batch_size], + a_current=a[: self._config.pi_batch_size], + extra=extra, + ) self._policy_trainer_state = self._policy_trainer.train(batch) @@ -388,15 +406,16 @@ def _solvers(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env() @property def latest_iteration_state(self): latest_iteration_state = super(TRPO, self).latest_iteration_state - if hasattr(self, '_v_function_trainer_state'): - latest_iteration_state['scalar'].update({'v_loss': float(self._v_function_trainer_state['v_loss'])}) + if hasattr(self, "_v_function_trainer_state"): + latest_iteration_state["scalar"].update({"v_loss": float(self._v_function_trainer_state["v_loss"])}) return latest_iteration_state def trainers(self): diff --git a/nnabla_rl/algorithms/xql.py b/nnabla_rl/algorithms/xql.py index c3a4f577..3ae76355 100644 --- a/nnabla_rl/algorithms/xql.py +++ b/nnabla_rl/algorithms/xql.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
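The TRPO hunks above only restyle the default builders and the constructor; behaviour is unchanged. As a usage sketch that is not part of this patch, overriding one of those builders looks like the following. The module paths, the Pendulum environment, and the smaller learning rate are assumptions for illustration; the `build_solver` signature and the `v_solver_builder` argument are taken from the hunks above.

```python
# Sketch (not part of this patch): plug a custom SolverBuilder into TRPO,
# following the black-formatted build_solver signature above.
import gym
import nnabla.solvers as NS

from nnabla_rl.algorithms.trpo import TRPO, TRPOConfig  # assumed module path
from nnabla_rl.builders.solver_builder import SolverBuilder


class SmallLrVSolverBuilder(SolverBuilder):
    def build_solver(self, env_info, algorithm_config, **kwargs):  # type: ignore[override]
        # Hypothetical tweak: same Adam solver, ten times smaller V-function learning rate.
        return NS.Adam(alpha=algorithm_config.vf_learning_rate * 0.1)


env = gym.make("Pendulum-v1")  # any continuous-action env supported by TRPO
trpo = TRPO(env, config=TRPOConfig(), v_solver_builder=SmallLrVSolverBuilder())
```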
@@ -89,7 +89,7 @@ class XQLConfig(AlgorithmConfig): """ gamma: float = 0.99 - learning_rate: float = 3.0*1e-4 + learning_rate: float = 3.0 * 1e-4 batch_size: int = 256 tau: float = 0.005 value_temperature: float = 2.0 @@ -113,79 +113,85 @@ class XQLConfig(AlgorithmConfig): def __post_init__(self): """__post_init__ Check set values are in valid range.""" - self._assert_between(self.gamma, 0.0, 1.0, 'gamma') - self._assert_positive(self.learning_rate, 'learning_rate') - self._assert_positive(self.batch_size, 'batch_size') - self._assert_between(self.tau, 0.0, 1.0, 'tau') - self._assert_positive(self.start_timesteps, 'start_timesteps') - self._assert_positive(self.replay_buffer_size, 'replay_buffer_size') - self._assert_positive(self.num_steps, 'num_steps') - - self._assert_positive(self.pi_unroll_steps, 'pi_unroll_steps') - self._assert_positive_or_zero(self.pi_burn_in_steps, 'pi_burn_in_steps') - self._assert_positive(self.q_unroll_steps, 'q_unroll_steps') - self._assert_positive_or_zero(self.q_burn_in_steps, 'q_burn_in_steps') - self._assert_positive(self.v_unroll_steps, 'v_unroll_steps') - self._assert_positive_or_zero(self.v_burn_in_steps, 'v_burn_in_steps') + self._assert_between(self.gamma, 0.0, 1.0, "gamma") + self._assert_positive(self.learning_rate, "learning_rate") + self._assert_positive(self.batch_size, "batch_size") + self._assert_between(self.tau, 0.0, 1.0, "tau") + self._assert_positive(self.start_timesteps, "start_timesteps") + self._assert_positive(self.replay_buffer_size, "replay_buffer_size") + self._assert_positive(self.num_steps, "num_steps") + + self._assert_positive(self.pi_unroll_steps, "pi_unroll_steps") + self._assert_positive_or_zero(self.pi_burn_in_steps, "pi_burn_in_steps") + self._assert_positive(self.q_unroll_steps, "q_unroll_steps") + self._assert_positive_or_zero(self.q_burn_in_steps, "q_burn_in_steps") + self._assert_positive(self.v_unroll_steps, "v_unroll_steps") + self._assert_positive_or_zero(self.v_burn_in_steps, "v_burn_in_steps") class DefaultQFunctionBuilder(ModelBuilder[QFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: XQLConfig, - **kwargs) -> QFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: XQLConfig, + **kwargs, + ) -> QFunction: return XQLQFunction(scope_name) class DefaultVFunctionBuilder(ModelBuilder[VFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: XQLConfig, - **kwargs) -> VFunction: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: XQLConfig, + **kwargs, + ) -> VFunction: return XQLVFunction(scope_name) class DefaultPolicyBuilder(ModelBuilder[StochasticPolicy]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: XQLConfig, - **kwargs) -> StochasticPolicy: + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: XQLConfig, + **kwargs, + ) -> StochasticPolicy: return XQLPolicy(scope_name, env_info.action_dim) class DefaultSolverBuilder(SolverBuilder): - def build_solver(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: XQLConfig, - **kwargs) -> nn.solver.Solver: + def build_solver( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: XQLConfig, **kwargs + ) 
-> nn.solver.Solver: return NS.Adam(alpha=algorithm_config.learning_rate) class DefaultReplayBufferBuilder(ReplayBufferBuilder): - def build_replay_buffer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: XQLConfig, - **kwargs) -> ReplayBuffer: + def build_replay_buffer( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: XQLConfig, **kwargs + ) -> ReplayBuffer: return ReplayBuffer(capacity=algorithm_config.replay_buffer_size) class DefaultExplorerBuilder(ExplorerBuilder): - def build_explorer(self, # type: ignore[override] - env_info: EnvironmentInfo, - algorithm_config: XQLConfig, - algorithm: "XQL", - **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: XQLConfig, + algorithm: "XQL", + **kwargs, + ) -> EnvironmentExplorer: explorer_config = EE.RawPolicyExplorerConfig( warmup_random_steps=algorithm_config.start_timesteps, initial_step_num=algorithm.iteration_num, - timelimit_as_terminal=False + timelimit_as_terminal=False, + ) + explorer = EE.RawPolicyExplorer( + policy_action_selector=algorithm._exploration_action_selector, env_info=env_info, config=explorer_config ) - explorer = EE.RawPolicyExplorer(policy_action_selector=algorithm._exploration_action_selector, - env_info=env_info, - config=explorer_config) return explorer @@ -243,16 +249,19 @@ class XQL(Algorithm): _v_function_trainer_state: Dict[str, Any] _q_function_trainer_state: Dict[str, Any] - def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], - config: XQLConfig = XQLConfig(), - q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), - q_solver_builder: SolverBuilder = DefaultSolverBuilder(), - v_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), - v_solver_builder: SolverBuilder = DefaultSolverBuilder(), - policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), - policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), - replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), - explorer_builder: ExplorerBuilder = DefaultExplorerBuilder()): + def __init__( + self, + env_or_env_info: Union[gym.Env, EnvironmentInfo], + config: XQLConfig = XQLConfig(), + q_function_builder: ModelBuilder[QFunction] = DefaultQFunctionBuilder(), + q_solver_builder: SolverBuilder = DefaultSolverBuilder(), + v_function_builder: ModelBuilder[VFunction] = DefaultVFunctionBuilder(), + v_solver_builder: SolverBuilder = DefaultSolverBuilder(), + policy_builder: ModelBuilder[StochasticPolicy] = DefaultPolicyBuilder(), + policy_solver_builder: SolverBuilder = DefaultSolverBuilder(), + replay_buffer_builder: ReplayBufferBuilder = DefaultReplayBufferBuilder(), + explorer_builder: ExplorerBuilder = DefaultExplorerBuilder(), + ): super(XQL, self).__init__(env_or_env_info, config=config) self._explorer_builder = explorer_builder @@ -265,8 +274,9 @@ def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], self._train_q_solvers = {} for q in self._train_q_functions: self._train_q_solvers[q.scope_name] = q_solver_builder( - env_info=self._env_info, algorithm_config=self._config) - self._target_q_functions = [q.deepcopy('target_' + q.scope_name) for q in self._train_q_functions] + env_info=self._env_info, algorithm_config=self._config + ) + self._target_q_functions = [q.deepcopy("target_" + q.scope_name) for q in self._train_q_functions] self._v_function = v_function_builder("v", env_info=self._env_info, 
algorithm_config=self._config) self._v_solver = v_solver_builder(self._env_info, self._config) @@ -277,9 +287,11 @@ def __init__(self, env_or_env_info: Union[gym.Env, EnvironmentInfo], self._replay_buffer = replay_buffer_builder(self._env_info, self._config) self._evaluation_actor = _StochasticPolicyActionSelector( - self._env_info, self._pi.shallowcopy(), deterministic=True) + self._env_info, self._pi.shallowcopy(), deterministic=True + ) self._exploration_actor = _StochasticPolicyActionSelector( - self._env_info, self._pi.shallowcopy(), deterministic=False) + self._env_info, self._pi.shallowcopy(), deterministic=False + ) @eval_api def compute_eval_action(self, state, *, begin_of_episode=False, extra_info={}): @@ -305,46 +317,51 @@ def _setup_policy_training(self, env_or_buffer): beta=self._config.policy_temperature, unroll_steps=self._config.pi_unroll_steps, burn_in_steps=self._config.pi_burn_in_steps, - reset_on_terminal=self._config.pi_reset_rnn_on_terminal) + reset_on_terminal=self._config.pi_reset_rnn_on_terminal, + ) policy_trainer = MT.policy_trainers.XQLForwardPolicyTrainer( models=self._pi, solvers={self._pi.scope_name: self._pi_solver}, q_functions=self._target_q_functions, v_function=self._v_function, env_info=self._env_info, - config=policy_trainer_config) + config=policy_trainer_config, + ) return policy_trainer else: raise NotImplementedError def _setup_q_function_training(self, env_or_buffer): q_function_trainer_config = MT.q_value_trainers.VTargetedQTrainerConfig( - loss_type='huber', - reduction_method='mean', + loss_type="huber", + reduction_method="mean", num_steps=self._config.num_steps, huber_delta=20.0, q_loss_scalar=0.5, unroll_steps=self._config.q_unroll_steps, burn_in_steps=self._config.q_burn_in_steps, - reset_on_terminal=self._config.q_reset_rnn_on_terminal) + reset_on_terminal=self._config.q_reset_rnn_on_terminal, + ) q_function_trainer = MT.q_value_trainers.VTargetedQTrainer( train_functions=self._train_q_functions, solvers=self._train_q_solvers, target_functions=self._v_function, env_info=self._env_info, - config=q_function_trainer_config) + config=q_function_trainer_config, + ) return q_function_trainer def _setup_v_function_training(self, env_or_buffer): is_offline = isinstance(env_or_buffer, ReplayBuffer) v_function_trainer_config = MT.v_value_trainers.XQLVTrainerConfig( - reduction_method='mean', + reduction_method="mean", beta=self._config.value_temperature, unroll_steps=self._config.v_unroll_steps, burn_in_steps=self._config.v_burn_in_steps, - reset_on_terminal=self._config.v_reset_rnn_on_terminal) + reset_on_terminal=self._config.v_reset_rnn_on_terminal, + ) v_function_trainer = MT.v_value_trainers.XQLVTrainer( train_functions=self._v_function, @@ -352,7 +369,8 @@ def _setup_v_function_training(self, env_or_buffer): target_functions=self._target_q_functions, target_policy=None if is_offline else self._pi, env_info=self._env_info, - config=v_function_trainer_config) + config=v_function_trainer_config, + ) return v_function_trainer def _run_online_training_iteration(self, env): @@ -368,23 +386,25 @@ def _xql_training(self, replay_buffer): num_steps = max(pi_steps, max(q_steps, v_steps)) experiences_tuple, info = replay_buffer.sample(self._config.batch_size, num_steps=num_steps) if num_steps == 1: - experiences_tuple = (experiences_tuple, ) + experiences_tuple = (experiences_tuple,) assert len(experiences_tuple) == num_steps batch = None for experiences in reversed(experiences_tuple): (s, a, r, non_terminal, s_next, rnn_states_dict, *_) = 
marshal_experiences(experiences) - rnn_states = rnn_states_dict['rnn_states'] if 'rnn_states' in rnn_states_dict else {} - batch = TrainingBatch(batch_size=self._config.batch_size, - s_current=s, - a_current=a, - gamma=self._config.gamma, - reward=r, - non_terminal=non_terminal, - s_next=s_next, - weight=info['weights'], - next_step_batch=batch, - rnn_states=rnn_states) + rnn_states = rnn_states_dict["rnn_states"] if "rnn_states" in rnn_states_dict else {} + batch = TrainingBatch( + batch_size=self._config.batch_size, + s_current=s, + a_current=a, + gamma=self._config.gamma, + reward=r, + non_terminal=non_terminal, + s_next=s_next, + weight=info["weights"], + next_step_batch=batch, + rnn_states=rnn_states, + ) self._v_function_trainer_state = self._v_function_trainer.train(batch) self._policy_trainer_state = self._policy_trainer.train(batch) @@ -392,7 +412,7 @@ def _xql_training(self, replay_buffer): for q, target_q in zip(self._train_q_functions, self._target_q_functions): sync_model(q, target_q, tau=self._config.tau) - td_errors = self._q_function_trainer_state['td_errors'] + td_errors = self._q_function_trainer_state["td_errors"] replay_buffer.update_priorities(td_errors) def _evaluation_action_selector(self, s, *, begin_of_episode=False): @@ -418,25 +438,29 @@ def is_rnn_supported(self): @classmethod def is_supported_env(cls, env_or_env_info): - env_info = EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) \ - else env_or_env_info + env_info = ( + EnvironmentInfo.from_env(env_or_env_info) if isinstance(env_or_env_info, gym.Env) else env_or_env_info + ) return not env_info.is_discrete_action_env() and not env_info.is_tuple_action_env() @property def latest_iteration_state(self): latest_iteration_state = super(XQL, self).latest_iteration_state - if hasattr(self, '_policy_trainer_state'): - latest_iteration_state['scalar'].update({'pi_loss': float(self._policy_trainer_state['pi_loss'])}) - if hasattr(self, '_v_function_trainer_state'): - latest_iteration_state['scalar'].update({'v_loss': float(self._v_function_trainer_state['v_loss'])}) - if hasattr(self, '_q_function_trainer_state'): - latest_iteration_state['scalar'].update({'q_loss': float(self._q_function_trainer_state['q_loss'])}) - latest_iteration_state['histogram'].update( - {'td_errors': self._q_function_trainer_state['td_errors'].flatten()}) + if hasattr(self, "_policy_trainer_state"): + latest_iteration_state["scalar"].update({"pi_loss": float(self._policy_trainer_state["pi_loss"])}) + if hasattr(self, "_v_function_trainer_state"): + latest_iteration_state["scalar"].update({"v_loss": float(self._v_function_trainer_state["v_loss"])}) + if hasattr(self, "_q_function_trainer_state"): + latest_iteration_state["scalar"].update({"q_loss": float(self._q_function_trainer_state["q_loss"])}) + latest_iteration_state["histogram"].update( + {"td_errors": self._q_function_trainer_state["td_errors"].flatten()} + ) return latest_iteration_state @property def trainers(self): - return {"q_function": self._q_function_trainer, - "v_function": self._v_function_trainer, - "policy": self._policy_trainer} + return { + "q_function": self._q_function_trainer, + "v_function": self._v_function_trainer, + "policy": self._policy_trainer, + } diff --git a/nnabla_rl/builders/explorer_builder.py b/nnabla_rl/builders/explorer_builder.py index 3387370c..7d81b704 100644 --- a/nnabla_rl/builders/explorer_builder.py +++ b/nnabla_rl/builders/explorer_builder.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. 
+# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,21 +19,16 @@ class ExplorerBuilder(object): - """Explorer builder interface class - """ - - def __call__(self, - env_info: EnvironmentInfo, - algorithm_config: AlgorithmConfig, - algorithm: Algorithm, - **kwargs) -> EnvironmentExplorer: + """Explorer builder interface class""" + + def __call__( + self, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, algorithm: Algorithm, **kwargs + ) -> EnvironmentExplorer: return self.build_explorer(env_info, algorithm_config, algorithm, **kwargs) - def build_explorer(self, - env_info: EnvironmentInfo, - algorithm_config: AlgorithmConfig, - algorithm: Algorithm, - **kwargs) -> EnvironmentExplorer: + def build_explorer( + self, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, algorithm: Algorithm, **kwargs + ) -> EnvironmentExplorer: """Build explorer. Args: diff --git a/nnabla_rl/builders/lr_scheduler_builder.py b/nnabla_rl/builders/lr_scheduler_builder.py index 9db23162..c4658166 100644 --- a/nnabla_rl/builders/lr_scheduler_builder.py +++ b/nnabla_rl/builders/lr_scheduler_builder.py @@ -1,5 +1,4 @@ - -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,20 +18,15 @@ from nnabla_rl.environments.environment_info import EnvironmentInfo -class LearningRateSchedulerBuilder(): - """Learning rate scheduler builder interface class - """ +class LearningRateSchedulerBuilder: + """Learning rate scheduler builder interface class""" - def __call__(self, - env_info: EnvironmentInfo, - algorithm_config: AlgorithmConfig, - **kwargs) -> nn.solver.Solver: + def __call__(self, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, **kwargs) -> nn.solver.Solver: return self.build_scheduler(env_info, algorithm_config, **kwargs) - def build_scheduler(self, - env_info: EnvironmentInfo, - algorithm_config: AlgorithmConfig, - **kwargs) -> BaseLearningRateScheduler: + def build_scheduler( + self, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, **kwargs + ) -> BaseLearningRateScheduler: """Build learning rate scheduler function Args: diff --git a/nnabla_rl/builders/model_builder.py b/nnabla_rl/builders/model_builder.py index 713dc649..823db2d8 100644 --- a/nnabla_rl/builders/model_builder.py +++ b/nnabla_rl/builders/model_builder.py @@ -1,5 +1,5 @@ # Copyright 2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
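The `ExplorerBuilder` interface above receives the algorithm instance in addition to `env_info` and the config, which is how the default builders in this patch read `algorithm.iteration_num` and the exploration action selector. A minimal custom builder, sketched under the assumption that the same `RawPolicyExplorer` pieces are reused; the warmup length is a hypothetical value.

```python
# Sketch (not part of this patch): a custom ExplorerBuilder mirroring the
# DefaultExplorerBuilder pattern shown earlier, with an added warmup phase.
import nnabla_rl.environment_explorers as EE
from nnabla_rl.builders.explorer_builder import ExplorerBuilder


class WarmupExplorerBuilder(ExplorerBuilder):
    def build_explorer(self, env_info, algorithm_config, algorithm, **kwargs):  # type: ignore[override]
        config = EE.RawPolicyExplorerConfig(
            warmup_random_steps=1000,  # hypothetical warmup length
            initial_step_num=algorithm.iteration_num,
            timelimit_as_terminal=False,
        )
        return EE.RawPolicyExplorer(
            policy_action_selector=algorithm._exploration_action_selector,
            env_info=env_info,
            config=config,
        )
```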
@@ -18,25 +18,16 @@ from nnabla_rl.algorithm import AlgorithmConfig from nnabla_rl.environments.environment_info import EnvironmentInfo -T = TypeVar('T') +T = TypeVar("T") class ModelBuilder(Generic[T]): - """Model builder interface class - """ - - def __call__(self, - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: AlgorithmConfig, - **kwargs) -> T: + """Model builder interface class""" + + def __call__(self, scope_name: str, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, **kwargs) -> T: return self.build_model(scope_name, env_info, algorithm_config, **kwargs) - def build_model(self, - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: AlgorithmConfig, - **kwargs) -> T: + def build_model(self, scope_name: str, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, **kwargs) -> T: """Build model. Args: diff --git a/nnabla_rl/builders/preprocessor_builder.py b/nnabla_rl/builders/preprocessor_builder.py index 9fad3a47..9f0b857b 100644 --- a/nnabla_rl/builders/preprocessor_builder.py +++ b/nnabla_rl/builders/preprocessor_builder.py @@ -1,5 +1,5 @@ # Copyright 2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,22 +18,17 @@ from nnabla_rl.preprocessors.preprocessor import Preprocessor -class PreprocessorBuilder(): - """Preprocessor builder interface class - """ +class PreprocessorBuilder: + """Preprocessor builder interface class""" - def __call__(self, - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: AlgorithmConfig, - **kwargs) -> Preprocessor: + def __call__( + self, scope_name: str, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, **kwargs + ) -> Preprocessor: return self.build_preprocessor(scope_name, env_info, algorithm_config, **kwargs) - def build_preprocessor(self, - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: AlgorithmConfig, - **kwargs) -> Preprocessor: + def build_preprocessor( + self, scope_name: str, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, **kwargs + ) -> Preprocessor: """Build preprocessor Args: diff --git a/nnabla_rl/builders/replay_buffer_builder.py b/nnabla_rl/builders/replay_buffer_builder.py index 85844f6e..01aef045 100644 --- a/nnabla_rl/builders/replay_buffer_builder.py +++ b/nnabla_rl/builders/replay_buffer_builder.py @@ -1,5 +1,5 @@ # Copyright 2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,20 +18,15 @@ from nnabla_rl.replay_buffer import ReplayBuffer -class ReplayBufferBuilder(): - """ReplayBuffer builder interface class - """ +class ReplayBufferBuilder: + """ReplayBuffer builder interface class""" - def __call__(self, - env_info: EnvironmentInfo, - algorithm_config: AlgorithmConfig, - **kwargs) -> ReplayBuffer: + def __call__(self, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, **kwargs) -> ReplayBuffer: return self.build_replay_buffer(env_info, algorithm_config, **kwargs) - def build_replay_buffer(self, - env_info: EnvironmentInfo, - algorithm_config: AlgorithmConfig, - **kwargs) -> ReplayBuffer: + def build_replay_buffer( + self, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, **kwargs + ) -> ReplayBuffer: """Build replay buffer Args: diff --git a/nnabla_rl/builders/solver_builder.py b/nnabla_rl/builders/solver_builder.py index 103551b9..81e54563 100644 --- a/nnabla_rl/builders/solver_builder.py +++ b/nnabla_rl/builders/solver_builder.py @@ -1,5 +1,5 @@ # Copyright 2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,20 +18,13 @@ from nnabla_rl.environments.environment_info import EnvironmentInfo -class SolverBuilder(): - """Solver builder interface class - """ +class SolverBuilder: + """Solver builder interface class""" - def __call__(self, - env_info: EnvironmentInfo, - algorithm_config: AlgorithmConfig, - **kwargs) -> nn.solver.Solver: + def __call__(self, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, **kwargs) -> nn.solver.Solver: return self.build_solver(env_info, algorithm_config, **kwargs) - def build_solver(self, - env_info: EnvironmentInfo, - algorithm_config: AlgorithmConfig, - **kwargs) -> nn.solver.Solver: + def build_solver(self, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, **kwargs) -> nn.solver.Solver: """Build solver function Args: diff --git a/nnabla_rl/configuration.py b/nnabla_rl/configuration.py index 763c7d85..f22af521 100644 --- a/nnabla_rl/configuration.py +++ b/nnabla_rl/configuration.py @@ -1,5 +1,5 @@ # Copyright 2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -17,7 +17,7 @@ @dataclass -class Configuration(): +class Configuration: def __post_init__(self): pass @@ -26,47 +26,46 @@ def to_dict(self): def _assert_positive(self, config, config_name): if config <= 0: - raise ValueError('{} must be positive'.format(config_name)) + raise ValueError("{} must be positive".format(config_name)) def _assert_positive_or_zero(self, config, config_name): if config < 0: - raise ValueError('{} must be positive'.format(config_name)) + raise ValueError("{} must be positive".format(config_name)) def _assert_negative(self, config, config_name): if 0 <= config: - raise ValueError('{} must be negative'.format(config_name)) + raise ValueError("{} must be negative".format(config_name)) def _assert_negative_or_zero(self, config, config_name): if 0 < config: - raise ValueError('{} must be positive'.format(config_name)) + raise ValueError("{} must be positive".format(config_name)) def _assert_between(self, config, low, high, config_name): if not (low <= config and config <= high): - raise ValueError( - '{} must lie between [{}, {}]'.format(config_name, low, high)) + raise ValueError("{} must lie between [{}, {}]".format(config_name, low, high)) def _assert_one_of(self, config, choices, config_name): if config not in choices: - raise ValueError(f'{config_name} is not available. Available choices: {choices}') + raise ValueError(f"{config_name} is not available. Available choices: {choices}") def _assert_ascending_order(self, config, config_name): - ascending = all(config[i] <= config[i+1] for i in range(len(config)-1)) + ascending = all(config[i] <= config[i + 1] for i in range(len(config) - 1)) if not ascending: - raise ValueError(f'{config_name} is not in ascending order!: {config}') + raise ValueError(f"{config_name} is not in ascending order!: {config}") def _assert_descending_order(self, config, config_name): - descending = all(config[i] >= config[i+1] for i in range(len(config)-1)) + descending = all(config[i] >= config[i + 1] for i in range(len(config) - 1)) if not descending: - raise ValueError(f'{config_name} is not in descending order!: {config}') + raise ValueError(f"{config_name} is not in descending order!: {config}") def _assert_smaller_than(self, config, ref_value, config_name): if config > ref_value: - raise ValueError(f'{config_name} is not in smaller than reference value!: {config} > {ref_value}') + raise ValueError(f"{config_name} is not in smaller than reference value!: {config} > {ref_value}") def _assert_greater_than(self, config, ref_value, config_name): if config < ref_value: - raise ValueError(f'{config_name} is not greater than reference value!: {config} < {ref_value}') + raise ValueError(f"{config_name} is not greater than reference value!: {config} < {ref_value}") def _assert_length(self, config, expected_length, config_name): if len(config) != expected_length: - raise ValueError(f'{config_name} length is not {expected_length}') + raise ValueError(f"{config_name} length is not {expected_length}") diff --git a/nnabla_rl/distributions/bernoulli.py b/nnabla_rl/distributions/bernoulli.py index 55818845..b6124926 100644 --- a/nnabla_rl/distributions/bernoulli.py +++ b/nnabla_rl/distributions/bernoulli.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
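The `Configuration` helpers above (`_assert_positive`, `_assert_between`, and so on) are what the algorithm configs in this patch call from `__post_init__`, as `XQLConfig` does. A minimal sketch of a custom config using the same pattern; the class name and fields are hypothetical.

```python
# Sketch (not part of this patch): validating fields with the Configuration
# helpers shown above, in the same way XQLConfig.__post_init__ does.
from dataclasses import dataclass

from nnabla_rl.configuration import Configuration


@dataclass
class MyTrainingConfig(Configuration):
    gamma: float = 0.99
    batch_size: int = 256

    def __post_init__(self):
        self._assert_between(self.gamma, 0.0, 1.0, "gamma")
        self._assert_positive(self.batch_size, "batch_size")
```

Constructing `MyTrainingConfig(gamma=1.5)` raises `ValueError("gamma must lie between [0.0, 1.0]")`, matching the message format above.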
@@ -90,6 +90,8 @@ def entropy(self): def kl_divergence(self, q): assert isinstance(q, Bernoulli) - return NF.sum(self._distribution * (self._log_distribution - q._log_distribution), - axis=len(self._distribution.shape) - 1, - keepdims=True) + return NF.sum( + self._distribution * (self._log_distribution - q._log_distribution), + axis=len(self._distribution.shape) - 1, + keepdims=True, + ) diff --git a/nnabla_rl/distributions/common_utils.py b/nnabla_rl/distributions/common_utils.py index 64d7fac2..0fc69be9 100644 --- a/nnabla_rl/distributions/common_utils.py +++ b/nnabla_rl/distributions/common_utils.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,4 +22,4 @@ def gaussian_log_prob(x, mean, var, ln_var): # log N(x|mu, var) # = -0.5*log2*pi - 0.5 * ln_var - 0.5 * (x-mu)**2 / var axis = len(x.shape) - 1 - return NF.sum(-0.5 * np.log(2.0 * np.pi) - 0.5 * ln_var - 0.5 * (x-mean)**2 / var, axis=axis, keepdims=True) + return NF.sum(-0.5 * np.log(2.0 * np.pi) - 0.5 * ln_var - 0.5 * (x - mean) ** 2 / var, axis=axis, keepdims=True) diff --git a/nnabla_rl/distributions/distribution.py b/nnabla_rl/distributions/distribution.py index 23b809fc..043ceb60 100644 --- a/nnabla_rl/distributions/distribution.py +++ b/nnabla_rl/distributions/distribution.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -42,8 +42,9 @@ def ndim(self) -> int: """The number of dimensions of the distribution.""" raise NotImplementedError - def sample_multiple(self, num_samples: int, noise_clip: Optional[Tuple[float, float]] = None - ) -> Union[nn.Variable, np.ndarray]: + def sample_multiple( + self, num_samples: int, noise_clip: Optional[Tuple[float, float]] = None + ) -> Union[nn.Variable, np.ndarray]: """Sample mutiple value from the distribution New axis will be added between the first and second axis. Thefore, the returned value shape for mean and variance with shape (batch_size, data_shape) will be @@ -92,8 +93,9 @@ def log_prob(self, x: Union[nn.Variable, np.ndarray]) -> Union[nn.Variable, np.n """ raise NotImplementedError - def sample_and_compute_log_prob(self, noise_clip: Optional[Tuple[float, float]] = None) \ - -> Union[Tuple[nn.Variable, nn.Variable], Tuple[np.ndarray, np.ndarray]]: + def sample_and_compute_log_prob( + self, noise_clip: Optional[Tuple[float, float]] = None + ) -> Union[Tuple[nn.Variable, nn.Variable], Tuple[np.ndarray, np.ndarray]]: """Sample a value from the distribution and compute its log probability. @@ -114,7 +116,7 @@ def entropy(self) -> Union[nn.Variable, np.ndarray]: """ raise NotImplementedError - def kl_divergence(self, q: 'Distribution') -> Union[nn.Variable, np.ndarray]: + def kl_divergence(self, q: "Distribution") -> Union[nn.Variable, np.ndarray]: """Compute the kullback leibler divergence between given distribution. 
This function will compute KL(self||q) diff --git a/nnabla_rl/distributions/gaussian.py b/nnabla_rl/distributions/gaussian.py index 31c287e4..f7e46ac3 100644 --- a/nnabla_rl/distributions/gaussian.py +++ b/nnabla_rl/distributions/gaussian.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -43,13 +43,15 @@ def __init__(self, mean: Union[nn.Variable, np.ndarray], ln_var: Union[nn.Variab warnings.warn( "Numpy ndarrays are given as mean and ln_var.\n" "From v0.12.0, if numpy.ndarray is given, " - "all Gaussian class methods return numpy.ndarray not nnabla.Variable") + "all Gaussian class methods return numpy.ndarray not nnabla.Variable" + ) self._delegate = NumpyGaussian(mean, ln_var) elif isinstance(mean, nn.Variable) and isinstance(ln_var, nn.Variable): self._delegate = NnablaGaussian(mean, ln_var) else: raise ValueError( - f"Invalid type or a pair of types, mean type is {type(mean)} and ln type is {type(ln_var)}") + f"Invalid type or a pair of types, mean type is {type(mean)} and ln type is {type(ln_var)}" + ) @property def ndim(self): @@ -106,32 +108,29 @@ def ndim(self): return self._ndim def sample(self, noise_clip=None): - return RF.sample_gaussian(self._mean, - self._ln_var, - noise_clip=noise_clip) + return RF.sample_gaussian(self._mean, self._ln_var, noise_clip=noise_clip) def sample_multiple(self, num_samples, noise_clip=None): - return RF.sample_gaussian_multiple(self._mean, - self._ln_var, - num_samples, - noise_clip=noise_clip) + return RF.sample_gaussian_multiple(self._mean, self._ln_var, num_samples, noise_clip=noise_clip) def sample_and_compute_log_prob(self, noise_clip=None): - x = RF.sample_gaussian(mean=self._mean, - ln_var=self._ln_var, - noise_clip=noise_clip) + x = RF.sample_gaussian(mean=self._mean, ln_var=self._ln_var, noise_clip=noise_clip) return x, self.log_prob(x) def sample_multiple_and_compute_log_prob(self, num_samples, noise_clip=None): - x = RF.sample_gaussian_multiple(self._mean, - self._ln_var, - num_samples, - noise_clip=noise_clip) + x = RF.sample_gaussian_multiple(self._mean, self._ln_var, num_samples, noise_clip=noise_clip) mean = RF.expand_dims(self._mean, axis=1) var = RF.expand_dims(self._var, axis=1) ln_var = RF.expand_dims(self._ln_var, axis=1) - assert mean.shape == (self._batch_size, 1, ) + self._data_dim + assert ( + mean.shape + == ( + self._batch_size, + 1, + ) + + self._data_dim + ) assert var.shape == mean.shape assert ln_var.shape == mean.shape @@ -155,9 +154,9 @@ def entropy(self): def kl_divergence(self, q): assert isinstance(q, NnablaGaussian) p = self - return 0.5 * NF.sum(q._ln_var - p._ln_var + (p._var + (p._mean - q._mean) ** 2.0) / q._var - 1, - axis=1, - keepdims=True) + return 0.5 * NF.sum( + q._ln_var - p._ln_var + (p._var + (p._mean - q._mean) ** 2.0) / q._var - 1, axis=1, keepdims=True + ) class NumpyGaussian(Distribution): @@ -167,7 +166,7 @@ class NumpyGaussian(Distribution): def __init__(self, mean: np.ndarray, ln_var: np.ndarray) -> None: super(Distribution, self).__init__() self._dim = mean.shape[0] - assert (self._dim, ) == mean.shape + assert (self._dim,) == mean.shape assert (self._dim, self._dim) == ln_var.shape self._mean = mean self._var = np.exp(ln_var) @@ -199,7 +198,7 @@ def sample_multiple(self, num_samples, noise_clip=None): def 
sample_multiple_and_compute_log_prob(self, num_samples, noise_clip=None): raise NotImplementedError - def kl_divergence(self, q: 'Distribution'): + def kl_divergence(self, q: "Distribution"): if not isinstance(q, NumpyGaussian): raise NotImplementedError diff --git a/nnabla_rl/distributions/gmm.py b/nnabla_rl/distributions/gmm.py index 4bdea3e9..05f82cfd 100644 --- a/nnabla_rl/distributions/gmm.py +++ b/nnabla_rl/distributions/gmm.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,21 +36,23 @@ class GMM(ContinuosDistribution): mixing coefficients of each gaussian distribution. :math:`\\pi_k`. """ - def __init__(self, - means: Union[nn.Variable, np.ndarray], - covariances: Union[nn.Variable, np.ndarray], - mixing_coefficients: Union[nn.Variable, np.ndarray]): + def __init__( + self, + means: Union[nn.Variable, np.ndarray], + covariances: Union[nn.Variable, np.ndarray], + mixing_coefficients: Union[nn.Variable, np.ndarray], + ): super(GMM, self).__init__() if ( - isinstance(means, np.ndarray) and - isinstance(covariances, np.ndarray) and - isinstance(mixing_coefficients, np.ndarray) + isinstance(means, np.ndarray) + and isinstance(covariances, np.ndarray) + and isinstance(mixing_coefficients, np.ndarray) ): self._delegate = NumpyGMM(means, covariances, mixing_coefficients) elif ( - isinstance(means, nn.Variable) and - isinstance(covariances, nn.Variable) and - isinstance(mixing_coefficients, nn.Variable) + isinstance(means, nn.Variable) + and isinstance(covariances, nn.Variable) + and isinstance(mixing_coefficients, nn.Variable) ): raise NotImplementedError else: @@ -58,7 +60,8 @@ def __init__(self, "Invalid type or a set of types.\n" f"means type is {type(means)}, " f"covariances type is {type(covariances)} and" - f"mixing_coefficients type is {type(mixing_coefficients)}") + f"mixing_coefficients type is {type(mixing_coefficients)}" + ) def sample(self, noise_clip=None): raise self._delegate.sample(noise_clip) @@ -86,13 +89,13 @@ def __init__(self, means: np.ndarray, covariances: np.ndarray, mixing_coefficien super(NumpyGMM, self).__init__() self._num_classes, self._dim = means.shape assert (self._num_classes, self._dim, self._dim) == covariances.shape - assert (self._num_classes, ) == mixing_coefficients.shape + assert (self._num_classes,) == mixing_coefficients.shape self._means = means # shape (num_classes, dim) self._covariances = covariances # shape (num_classes, dim, dim) self._mixing_coefficients = mixing_coefficients # shape(num_classes, ) @staticmethod - def from_gmm_parameter(parameter: 'GMMParameter') -> 'NumpyGMM': + def from_gmm_parameter(parameter: "GMMParameter") -> "NumpyGMM": return NumpyGMM(parameter._means, parameter._covariances, parameter._mixing_coefficients) def sample(self, noise_clip=None): @@ -130,7 +133,7 @@ def log_prob(self, x): # compute quadratic form value without solving inverse matrix diff = (x - mean).T # shape (dim, num_data) diff_inv_cholesky_decomposed_cov = linalg.solve_triangular(cholesky_decomposed_cov, diff, lower=True) - log_probs[:, i] -= 0.5 * np.sum(diff_inv_cholesky_decomposed_cov ** 2, axis=0) + log_probs[:, i] -= 0.5 * np.sum(diff_inv_cholesky_decomposed_cov**2, axis=0) log_probs += np.log(self._mixing_coefficients) return log_probs @@ -160,8 +163,9 @@ def compute_responsibility(data: np.ndarray, distribution: NumpyGMM) -> Tuple[np 
return np.exp(log_probs), np.exp(log_responsibility) -def compute_mean_and_covariance(distribution: NumpyGMM, - mixing_coefficients: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]: +def compute_mean_and_covariance( + distribution: NumpyGMM, mixing_coefficients: Optional[np.ndarray] = None +) -> Tuple[np.ndarray, np.ndarray]: """Compute mean and covariance of gmm. Args: @@ -173,8 +177,7 @@ def compute_mean_and_covariance(distribution: NumpyGMM, """ if mixing_coefficients is None: # sum of all class's mean, shape (dim, ) - mean = np.sum(distribution._means * - distribution._mixing_coefficients[:, np.newaxis], axis=0) + mean = np.sum(distribution._means * distribution._mixing_coefficients[:, np.newaxis], axis=0) else: # sum of all class's mean, shape (dim, ) mean = np.sum(distribution._means * mixing_coefficients[:, np.newaxis], axis=0) @@ -236,9 +239,10 @@ def logsumexp(x: np.ndarray, axis: int = 0, keepdims: bool = False) -> np.ndarra np.ndarray: log sum exp value of the input """ max_x = np.max(x, axis=axis, keepdims=True) - max_x[max_x == -float('inf')] = 0. + max_x[max_x == -float("inf")] = 0.0 if keepdims: - return cast(np.ndarray, np.log(np.sum(np.exp(x-max_x), axis=axis, keepdims=keepdims)) + max_x) + return cast(np.ndarray, np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=keepdims)) + max_x) else: - return cast(np.ndarray, np.log(np.sum(np.exp(x-max_x), - axis=axis, keepdims=keepdims)) + np.squeeze(max_x, axis=axis)) + return cast( + np.ndarray, np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=keepdims)) + np.squeeze(max_x, axis=axis) + ) diff --git a/nnabla_rl/distributions/one_hot_softmax.py b/nnabla_rl/distributions/one_hot_softmax.py index 89caeb17..43e8ce7b 100644 --- a/nnabla_rl/distributions/one_hot_softmax.py +++ b/nnabla_rl/distributions/one_hot_softmax.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
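The `logsumexp` helper above subtracts the row maximum before exponentiating, which keeps the computation finite for large inputs, and zeroes out `-inf` maxima so that all-`-inf` rows yield `-inf` instead of `nan`. A quick numpy check of the identity on values where the naive form is still representable:

```python
# Sketch (not part of this patch): the max-subtraction trick used by the
# logsumexp helper above, checked against the naive computation.
import numpy as np

x = np.array([[-3.0, 0.5, 2.0], [10.0, 9.0, 8.0]])

max_x = np.max(x, axis=1, keepdims=True)
stable = np.log(np.sum(np.exp(x - max_x), axis=1, keepdims=True)) + max_x
naive = np.log(np.sum(np.exp(x), axis=1, keepdims=True))

assert np.allclose(stable, naive)
```

With entries around 1000 the naive form overflows to `inf` while the stable form does not.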
@@ -37,7 +37,7 @@ def ndim(self): def sample(self, noise_clip=None): sample = NF.random_choice(self._actions, w=self._distribution) - one_hot = NF.one_hot(sample, shape=(self._num_class, )) + one_hot = NF.one_hot(sample, shape=(self._num_class,)) one_hot.need_grad = False # straight through biased gradient estimator assert one_hot.shape == self._distribution.shape @@ -47,7 +47,7 @@ def sample(self, noise_clip=None): def sample_and_compute_log_prob(self, noise_clip=None): sample = NF.random_choice(self._actions, w=self._distribution) log_prob = self.log_prob(sample) - one_hot = NF.one_hot(sample, shape=(self._num_class, )) + one_hot = NF.one_hot(sample, shape=(self._num_class,)) one_hot.need_grad = False # straight through biased gradient estimator assert one_hot.shape == self._distribution.shape @@ -56,7 +56,7 @@ def sample_and_compute_log_prob(self, noise_clip=None): def choose_probable(self): class_index = RF.argmax(self._distribution, axis=len(self._distribution.shape) - 1, keepdims=True) - one_hot = NF.one_hot(class_index, shape=(self._num_class, )) + one_hot = NF.one_hot(class_index, shape=(self._num_class,)) one_hot.need_grad = False # straight through biased gradient estimator assert one_hot.shape == self._distribution.shape diff --git a/nnabla_rl/distributions/softmax.py b/nnabla_rl/distributions/softmax.py index f58f3316..fb23ba97 100644 --- a/nnabla_rl/distributions/softmax.py +++ b/nnabla_rl/distributions/softmax.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -41,8 +41,7 @@ def __init__(self, z): self._batch_size = z.shape[0] self._num_class = z.shape[-1] - labels = np.array( - [label for label in range(self._num_class)], dtype=np.int32) + labels = np.array([label for label in range(self._num_class)], dtype=np.int32) self._labels = nn.Variable.from_numpy_array(labels) self._actions = self._labels for size in reversed(z.shape[0:-1]): @@ -73,7 +72,7 @@ def mean(self): raise NotImplementedError def log_prob(self, x): - one_hot_action = NF.one_hot(x, shape=(self._num_class, )) + one_hot_action = NF.one_hot(x, shape=(self._num_class,)) return NF.sum(self._log_distribution * one_hot_action, axis=len(self._distribution.shape) - 1, keepdims=True) def entropy(self): @@ -83,6 +82,8 @@ def entropy(self): def kl_divergence(self, q): if not isinstance(q, Softmax): raise ValueError("Invalid q to compute kl divergence") - return NF.sum(self._distribution * (self._log_distribution - q._log_distribution), - axis=len(self._distribution.shape) - 1, - keepdims=True) + return NF.sum( + self._distribution * (self._log_distribution - q._log_distribution), + axis=len(self._distribution.shape) - 1, + keepdims=True, + ) diff --git a/nnabla_rl/distributions/squashed_gaussian.py b/nnabla_rl/distributions/squashed_gaussian.py index e5fb1a9e..72ae1118 100644 --- a/nnabla_rl/distributions/squashed_gaussian.py +++ b/nnabla_rl/distributions/squashed_gaussian.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
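Both `Bernoulli.kl_divergence` earlier in this patch and `Softmax.kl_divergence` above compute the discrete KL divergence, the probability-weighted sum of log-probability differences. A plain numpy version for a single distribution pair, as a sanity check of the formula only:

```python
# Sketch (not part of this patch): the discrete KL divergence computed by
# Softmax.kl_divergence above, written with plain numpy for one example.
import numpy as np

p = np.array([0.7, 0.2, 0.1])
q = np.array([0.5, 0.3, 0.2])

kl = np.sum(p * (np.log(p) - np.log(q)))  # KL(p || q)
print(kl)  # ~0.085; non-negative, and zero only when p == q
```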
@@ -53,16 +53,11 @@ def ndim(self): return self._ndim def sample(self, noise_clip=None): - x = RF.sample_gaussian(mean=self._mean, - ln_var=self._ln_var, - noise_clip=noise_clip) + x = RF.sample_gaussian(mean=self._mean, ln_var=self._ln_var, noise_clip=noise_clip) return NF.tanh(x) def sample_multiple(self, num_samples, noise_clip=None): - x = RF.sample_gaussian_multiple(self._mean, - self._ln_var, - num_samples, - noise_clip=noise_clip) + x = RF.sample_gaussian_multiple(self._mean, self._ln_var, num_samples, noise_clip=noise_clip) return NF.tanh(x) def sample_and_compute_log_prob(self, noise_clip=None): @@ -71,10 +66,8 @@ def sample_and_compute_log_prob(self, noise_clip=None): If you forward the two variables independently, you'll get a log_prob for different sample, since different random variables are sampled internally.""" - x = RF.sample_gaussian( - mean=self._mean, ln_var=self._ln_var, noise_clip=noise_clip) - log_prob = self._log_prob_internal( - x, self._mean, self._var, self._ln_var) + x = RF.sample_gaussian(mean=self._mean, ln_var=self._ln_var, noise_clip=noise_clip) + log_prob = self._log_prob_internal(x, self._mean, self._var, self._ln_var) return NF.tanh(x), log_prob def sample_multiple_and_compute_log_prob(self, num_samples, noise_clip=None): @@ -83,10 +76,7 @@ def sample_multiple_and_compute_log_prob(self, num_samples, noise_clip=None): If you forward the two variables independently, you'll get a log_prob for different sample, since different random variables are sampled internally.""" - x = RF.sample_gaussian_multiple(self._mean, - self._ln_var, - num_samples=num_samples, - noise_clip=noise_clip) + x = RF.sample_gaussian_multiple(self._mean, self._ln_var, num_samples=num_samples, noise_clip=noise_clip) mean = RF.expand_dims(self._mean, axis=1) var = RF.expand_dims(self._var, axis=1) ln_var = RF.expand_dims(self._ln_var, axis=1) diff --git a/nnabla_rl/environment_explorer.py b/nnabla_rl/environment_explorer.py index 742023aa..8980e27d 100644 --- a/nnabla_rl/environment_explorer.py +++ b/nnabla_rl/environment_explorer.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
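`SquashedGaussian` above draws `x` from a diagonal Gaussian and returns `tanh(x)`; the underlying per-dimension log density is the one summed by `gaussian_log_prob` earlier in this patch, `-0.5*log(2*pi) - 0.5*ln_var - 0.5*(x - mean)**2 / var`. A numpy cross-check against `scipy.stats.norm.logpdf`; treating scipy as available is an assumption here, though the GMM code in this patch already relies on `linalg.solve_triangular`.

```python
# Sketch (not part of this patch): the per-dimension Gaussian log density used
# by gaussian_log_prob earlier in this patch, checked against scipy.
import numpy as np
from scipy import stats

x = np.array([0.3, -1.2, 0.0])
mean = np.array([0.0, -1.0, 0.5])
ln_var = np.array([0.0, -0.5, 1.0])
var = np.exp(ln_var)

manual = np.sum(-0.5 * np.log(2.0 * np.pi) - 0.5 * ln_var - 0.5 * (x - mean) ** 2 / var)
reference = np.sum(stats.norm.logpdf(x, loc=mean, scale=np.sqrt(var)))

assert np.isclose(manual, reference)
```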
@@ -46,9 +46,7 @@ class EnvironmentExplorer(metaclass=ABCMeta): _next_state: Union[State, None] _steps: int - def __init__(self, - env_info: EnvironmentInfo, - config: EnvironmentExplorerConfig = EnvironmentExplorerConfig()): + def __init__(self, env_info: EnvironmentInfo, config: EnvironmentExplorerConfig = EnvironmentExplorerConfig()): self._env_info = env_info self._config = config @@ -124,12 +122,12 @@ def _step_once(self, env, *, begin_of_episode=False) -> Tuple[Experience, bool]: if self._steps < self._config.warmup_random_steps: self._action, action_info = self._warmup_action(env, begin_of_episode=begin_of_episode) else: - self._action, action_info = self.action(self._steps, - cast(np.ndarray, self._state), - begin_of_episode=begin_of_episode) + self._action, action_info = self.action( + self._steps, cast(np.ndarray, self._state), begin_of_episode=begin_of_episode + ) self._next_state, r, done, step_info = env.step(self._action) - timelimit = step_info.get('TimeLimit.truncated', False) + timelimit = step_info.get("TimeLimit.truncated", False) if _is_end_of_episode(done, timelimit, self._config.timelimit_as_terminal): non_terminal = 0.0 else: @@ -138,12 +136,14 @@ def _step_once(self, env, *, begin_of_episode=False) -> Tuple[Experience, bool]: extra_info: Dict[str, Any] = {} extra_info.update(action_info) extra_info.update(step_info) - experience = (cast(np.ndarray, self._state), - cast(np.ndarray, self._action), - r * self._config.reward_scalar, - non_terminal, - cast(np.ndarray, self._next_state), - extra_info) + experience = ( + cast(np.ndarray, self._state), + cast(np.ndarray, self._action), + r * self._config.reward_scalar, + non_terminal, + cast(np.ndarray, self._next_state), + extra_info, + ) if done: self._state = env.reset() @@ -168,13 +168,13 @@ def _sample_action(env, env_info): action = [] for a, action_space in zip(env.action_space.sample(), env_info.action_space): if isinstance(action_space, gym.spaces.Discrete): - a = np.asarray(a).reshape((1, )) + a = np.asarray(a).reshape((1,)) action.append(a) action = tuple(action) else: if env_info.is_discrete_action_env(): action = env.action_space.sample() - action = np.asarray(action).reshape((1, )) + action = np.asarray(action).reshape((1,)) else: action = env.action_space.sample() return action, action_info diff --git a/nnabla_rl/environment_explorers/__init__.py b/nnabla_rl/environment_explorers/__init__.py index a22005c0..374f9447 100644 --- a/nnabla_rl/environment_explorers/__init__.py +++ b/nnabla_rl/environment_explorers/__init__.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,10 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from nnabla_rl.environment_explorers.epsilon_greedy_explorer import (NoDecayEpsilonGreedyExplorer, # noqa - NoDecayEpsilonGreedyExplorerConfig, - LinearDecayEpsilonGreedyExplorer, - LinearDecayEpsilonGreedyExplorerConfig) +from nnabla_rl.environment_explorers.epsilon_greedy_explorer import ( # noqa + NoDecayEpsilonGreedyExplorer, + NoDecayEpsilonGreedyExplorerConfig, + LinearDecayEpsilonGreedyExplorer, + LinearDecayEpsilonGreedyExplorerConfig, +) from nnabla_rl.environment_explorers.gaussian_explorer import GaussianExplorer, GaussianExplorerConfig # noqa from nnabla_rl.environment_explorers.raw_policy_explorer import RawPolicyExplorer, RawPolicyExplorerConfig # noqa diff --git a/nnabla_rl/environment_explorers/epsilon_greedy_explorer.py b/nnabla_rl/environment_explorers/epsilon_greedy_explorer.py index 43b16a29..7a0b4a99 100644 --- a/nnabla_rl/environment_explorers/epsilon_greedy_explorer.py +++ b/nnabla_rl/environment_explorers/epsilon_greedy_explorer.py @@ -23,12 +23,14 @@ from nnabla_rl.typing import ActionSelector -def epsilon_greedy_action_selection(state: np.ndarray, - greedy_action_selector: ActionSelector, - random_action_selector: ActionSelector, - epsilon: float, - *, - begin_of_episode: bool = False): +def epsilon_greedy_action_selection( + state: np.ndarray, + greedy_action_selector: ActionSelector, + random_action_selector: ActionSelector, + epsilon: float, + *, + begin_of_episode: bool = False, +): if np.random.rand() > epsilon: # optimal action return greedy_action_selector(state, begin_of_episode=begin_of_episode), True @@ -42,7 +44,7 @@ class NoDecayEpsilonGreedyExplorerConfig(EnvironmentExplorerConfig): epsilon: float = 1.0 def __post_init__(self): - self._assert_between(self.epsilon, 0.0, 1.0, 'epsilon') + self._assert_between(self.epsilon, 0.0, 1.0, "epsilon") class NoDecayEpsilonGreedyExplorer(EnvironmentExplorer): @@ -51,22 +53,26 @@ class NoDecayEpsilonGreedyExplorer(EnvironmentExplorer): # See https://mypy.readthedocs.io/en/stable/class_basics.html for details _config: NoDecayEpsilonGreedyExplorerConfig - def __init__(self, - greedy_action_selector: ActionSelector, - random_action_selector: ActionSelector, - env_info: EnvironmentInfo, - config: NoDecayEpsilonGreedyExplorerConfig = NoDecayEpsilonGreedyExplorerConfig()): + def __init__( + self, + greedy_action_selector: ActionSelector, + random_action_selector: ActionSelector, + env_info: EnvironmentInfo, + config: NoDecayEpsilonGreedyExplorerConfig = NoDecayEpsilonGreedyExplorerConfig(), + ): super().__init__(env_info, config) self._greedy_action_selector = greedy_action_selector self._random_action_selector = random_action_selector def action(self, step: int, state: np.ndarray, *, begin_of_episode: bool = False) -> Tuple[np.ndarray, Dict]: epsilon = self._config.epsilon - (action, info), _ = epsilon_greedy_action_selection(state, - self._greedy_action_selector, - self._random_action_selector, - epsilon, - begin_of_episode=begin_of_episode) + (action, info), _ = epsilon_greedy_action_selection( + state, + self._greedy_action_selector, + self._random_action_selector, + epsilon, + begin_of_episode=begin_of_episode, + ) return action, info @@ -90,10 +96,10 @@ class LinearDecayEpsilonGreedyExplorerConfig(EnvironmentExplorerConfig): append_explorer_info: bool = False def __post_init__(self): - self._assert_between(self.initial_epsilon, 0.0, 1.0, 'initial_epsilon') - self._assert_between(self.final_epsilon, 0.0, 1.0, 'final_epsilon') - self._assert_descending_order([self.initial_epsilon, self.final_epsilon], 
'initial/final epsilon') - self._assert_positive(self.max_explore_steps, 'max_explore_steps') + self._assert_between(self.initial_epsilon, 0.0, 1.0, "initial_epsilon") + self._assert_between(self.final_epsilon, 0.0, 1.0, "final_epsilon") + self._assert_descending_order([self.initial_epsilon, self.final_epsilon], "initial/final epsilon") + self._assert_positive(self.max_explore_steps, "max_explore_steps") class LinearDecayEpsilonGreedyExplorer(EnvironmentExplorer): @@ -117,29 +123,34 @@ class LinearDecayEpsilonGreedyExplorer(EnvironmentExplorer): # See https://mypy.readthedocs.io/en/stable/class_basics.html for details _config: LinearDecayEpsilonGreedyExplorerConfig - def __init__(self, - greedy_action_selector: ActionSelector, - random_action_selector: ActionSelector, - env_info: EnvironmentInfo, - config: LinearDecayEpsilonGreedyExplorerConfig = LinearDecayEpsilonGreedyExplorerConfig()): + def __init__( + self, + greedy_action_selector: ActionSelector, + random_action_selector: ActionSelector, + env_info: EnvironmentInfo, + config: LinearDecayEpsilonGreedyExplorerConfig = LinearDecayEpsilonGreedyExplorerConfig(), + ): super().__init__(env_info, config) self._greedy_action_selector = greedy_action_selector self._random_action_selector = random_action_selector def action(self, step: int, state: np.ndarray, *, begin_of_episode: bool = False) -> Tuple[np.ndarray, Dict]: epsilon = self._compute_epsilon(step) - (action, info), is_greedy_action = epsilon_greedy_action_selection(state, - self._greedy_action_selector, - self._random_action_selector, - epsilon, - begin_of_episode=begin_of_episode) + (action, info), is_greedy_action = epsilon_greedy_action_selection( + state, + self._greedy_action_selector, + self._random_action_selector, + epsilon, + begin_of_episode=begin_of_episode, + ) if self._config.append_explorer_info: info.update({"greedy_action": is_greedy_action, "explore_rate": epsilon}) return action, info def _compute_epsilon(self, step): assert 0 <= step - delta_epsilon = step / self._config.max_explore_steps \ - * (self._config.initial_epsilon - self._config.final_epsilon) + delta_epsilon = ( + step / self._config.max_explore_steps * (self._config.initial_epsilon - self._config.final_epsilon) + ) epsilon = self._config.initial_epsilon - delta_epsilon return max(epsilon, self._config.final_epsilon) diff --git a/nnabla_rl/environment_explorers/gaussian_explorer.py b/nnabla_rl/environment_explorers/gaussian_explorer.py index 341d991a..79ac95c1 100644 --- a/nnabla_rl/environment_explorers/gaussian_explorer.py +++ b/nnabla_rl/environment_explorers/gaussian_explorer.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
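`LinearDecayEpsilonGreedyExplorer._compute_epsilon` above interpolates linearly from `initial_epsilon` down to `final_epsilon` over `max_explore_steps` and stays at `final_epsilon` afterwards. The same schedule as a standalone function, with hypothetical parameter values:

```python
# Sketch (not part of this patch): the linear epsilon schedule implemented by
# LinearDecayEpsilonGreedyExplorer._compute_epsilon above.
def linear_epsilon(step, initial_epsilon=1.0, final_epsilon=0.01, max_explore_steps=100000):
    delta = step / max_explore_steps * (initial_epsilon - final_epsilon)
    return max(initial_epsilon - delta, final_epsilon)


print(linear_epsilon(0))       # 1.0
print(linear_epsilon(50000))   # 0.505
print(linear_epsilon(200000))  # 0.01 (clipped at final_epsilon)
```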
@@ -40,7 +40,7 @@ class GaussianExplorerConfig(EnvironmentExplorerConfig): sigma: float = 1.0 def __post_init__(self): - self._assert_positive(self.sigma, 'sigma') + self._assert_positive(self.sigma, "sigma") class GaussianExplorer(EnvironmentExplorer): @@ -62,10 +62,12 @@ class GaussianExplorer(EnvironmentExplorer): # See https://mypy.readthedocs.io/en/stable/class_basics.html for details _config: GaussianExplorerConfig - def __init__(self, - policy_action_selector: ActionSelector, - env_info: EnvironmentInfo, - config: GaussianExplorerConfig = GaussianExplorerConfig()): + def __init__( + self, + policy_action_selector: ActionSelector, + env_info: EnvironmentInfo, + config: GaussianExplorerConfig = GaussianExplorerConfig(), + ): super().__init__(env_info, config) self._policy_action_selector = policy_action_selector diff --git a/nnabla_rl/environment_explorers/raw_policy_explorer.py b/nnabla_rl/environment_explorers/raw_policy_explorer.py index 92030189..3b05eac2 100644 --- a/nnabla_rl/environment_explorers/raw_policy_explorer.py +++ b/nnabla_rl/environment_explorers/raw_policy_explorer.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -42,10 +42,12 @@ class RawPolicyExplorer(EnvironmentExplorer): `): the config of this class. """ - def __init__(self, - policy_action_selector: ActionSelector, - env_info: EnvironmentInfo, - config: RawPolicyExplorerConfig = RawPolicyExplorerConfig()): + def __init__( + self, + policy_action_selector: ActionSelector, + env_info: EnvironmentInfo, + config: RawPolicyExplorerConfig = RawPolicyExplorerConfig(), + ): super().__init__(env_info, config) self._policy_action_selector = policy_action_selector diff --git a/nnabla_rl/environments/__init__.py b/nnabla_rl/environments/__init__.py index 3ac5092f..6a20ad95 100644 --- a/nnabla_rl/environments/__init__.py +++ b/nnabla_rl/environments/__init__.py @@ -16,106 +16,105 @@ from gym.envs.registration import register from gymnasium.envs.registration import register as gymnasium_register -from nnabla_rl.environments.dummy import (DummyAtariEnv, DummyContinuous, DummyContinuousActionGoalEnv, DummyDiscrete, # noqa - DummyDiscreteActionGoalEnv, DummyDiscreteImg, DummyContinuousImg, - DummyFactoredContinuous, DummyMujocoEnv, - DummyTupleContinuous, DummyTupleDiscrete, DummyTupleMixed, - DummyTupleStateContinuous, DummyTupleStateDiscrete, - DummyTupleActionContinuous, DummyTupleActionDiscrete, - DummyHybridEnv, DummyAMPEnv, DummyAMPGoalEnv, - DummyGymnasiumAtariEnv, DummyGymnasiumMujocoEnv) - -register( - id='FakeMujocoNNablaRL-v1', - entry_point='nnabla_rl.environments.dummy:DummyMujocoEnv', - max_episode_steps=10 +from nnabla_rl.environments.dummy import ( # noqa + DummyAtariEnv, + DummyContinuous, + DummyContinuousActionGoalEnv, + DummyDiscrete, + DummyDiscreteActionGoalEnv, + DummyDiscreteImg, + DummyContinuousImg, + DummyFactoredContinuous, + DummyMujocoEnv, + DummyTupleContinuous, + DummyTupleDiscrete, + DummyTupleMixed, + DummyTupleStateContinuous, + DummyTupleStateDiscrete, + DummyTupleActionContinuous, + DummyTupleActionDiscrete, + DummyHybridEnv, + DummyAMPEnv, + DummyAMPGoalEnv, + DummyGymnasiumAtariEnv, + DummyGymnasiumMujocoEnv, ) +register(id="FakeMujocoNNablaRL-v1", entry_point="nnabla_rl.environments.dummy:DummyMujocoEnv", max_episode_steps=10) + 
register( - id='FakeDMControlNNablaRL-v1', - entry_point='nnabla_rl.environments.dummy:DummyDMControlEnv', - max_episode_steps=10 + id="FakeDMControlNNablaRL-v1", entry_point="nnabla_rl.environments.dummy:DummyDMControlEnv", max_episode_steps=10 ) register( - id='FakeAtariNNablaRLNoFrameskip-v1', - entry_point='nnabla_rl.environments.dummy:DummyAtariEnv', - max_episode_steps=10 + id="FakeAtariNNablaRLNoFrameskip-v1", entry_point="nnabla_rl.environments.dummy:DummyAtariEnv", max_episode_steps=10 ) register( - id='FakeGoalConditionedNNablaRL-v1', - entry_point='nnabla_rl.environments.dummy:DummyContinuousActionGoalEnv', - max_episode_steps=10 + id="FakeGoalConditionedNNablaRL-v1", + entry_point="nnabla_rl.environments.dummy:DummyContinuousActionGoalEnv", + max_episode_steps=10, ) register( - id='FactoredLunarLanderContinuousV2NNablaRL-v1', - entry_point='nnabla_rl.environments.factored_envs:FactoredLunarLanderV2', + id="FactoredLunarLanderContinuousV2NNablaRL-v1", + entry_point="nnabla_rl.environments.factored_envs:FactoredLunarLanderV2", kwargs={"continuous": True}, max_episode_steps=1000, reward_threshold=200.0, ) register( - id='FactoredAntV4NNablaRL-v1', - entry_point='nnabla_rl.environments.factored_envs:FactoredAntV4', + id="FactoredAntV4NNablaRL-v1", + entry_point="nnabla_rl.environments.factored_envs:FactoredAntV4", max_episode_steps=1000, reward_threshold=6000.0, ) register( - id='FactoredHopperV4NNablaRL-v1', - entry_point='nnabla_rl.environments.factored_envs:FactoredHopperV4', + id="FactoredHopperV4NNablaRL-v1", + entry_point="nnabla_rl.environments.factored_envs:FactoredHopperV4", max_episode_steps=1000, reward_threshold=3800.0, ) register( - id='FactoredHalfCheetahV4NNablaRL-v1', - entry_point='nnabla_rl.environments.factored_envs:FactoredHalfCheetahV4', + id="FactoredHalfCheetahV4NNablaRL-v1", + entry_point="nnabla_rl.environments.factored_envs:FactoredHalfCheetahV4", max_episode_steps=1000, reward_threshold=4800.0, ) register( - id='FactoredWalker2dV4NNablaRL-v1', - entry_point='nnabla_rl.environments.factored_envs:FactoredWalker2dV4', + id="FactoredWalker2dV4NNablaRL-v1", + entry_point="nnabla_rl.environments.factored_envs:FactoredWalker2dV4", max_episode_steps=1000, ) register( - id='FactoredHumanoidV4NNablaRL-v1', - entry_point='nnabla_rl.environments.factored_envs:FactoredHumanoidV4', + id="FactoredHumanoidV4NNablaRL-v1", + entry_point="nnabla_rl.environments.factored_envs:FactoredHumanoidV4", max_episode_steps=1000, ) -register( - id='FakeHybridNNablaRL-v1', - entry_point='nnabla_rl.environments.dummy:DummyHybridEnv', - max_episode_steps=10 -) +register(id="FakeHybridNNablaRL-v1", entry_point="nnabla_rl.environments.dummy:DummyHybridEnv", max_episode_steps=10) -register( - id='FakeAMPNNablaRL-v1', - entry_point='nnabla_rl.environments.dummy:DummyAMPEnv', - max_episode_steps=10 -) +register(id="FakeAMPNNablaRL-v1", entry_point="nnabla_rl.environments.dummy:DummyAMPEnv", max_episode_steps=10) register( - id='FakeAMPGoalConditionedNNablaRL-v1', - entry_point='nnabla_rl.environments.dummy:DummyAMPGoalEnv', - max_episode_steps=10 + id="FakeAMPGoalConditionedNNablaRL-v1", + entry_point="nnabla_rl.environments.dummy:DummyAMPGoalEnv", + max_episode_steps=10, ) gymnasium_register( - id='FakeGymnasiumMujocoNNablaRL-v1', - entry_point='nnabla_rl.environments.dummy:DummyGymnasiumMujocoEnv', - max_episode_steps=10 + id="FakeGymnasiumMujocoNNablaRL-v1", + entry_point="nnabla_rl.environments.dummy:DummyGymnasiumMujocoEnv", + max_episode_steps=10, ) gymnasium_register( - 
id='FakeGymnasiumAtariNNablaRLNoFrameskip-v1', - entry_point='nnabla_rl.environments.dummy:DummyGymnasiumAtariEnv', - max_episode_steps=10 + id="FakeGymnasiumAtariNNablaRLNoFrameskip-v1", + entry_point="nnabla_rl.environments.dummy:DummyGymnasiumAtariEnv", + max_episode_steps=10, ) diff --git a/nnabla_rl/environments/dummy.py b/nnabla_rl/environments/dummy.py index bde7c0fb..f6247b80 100644 --- a/nnabla_rl/environments/dummy.py +++ b/nnabla_rl/environments/dummy.py @@ -32,7 +32,7 @@ class AbstractDummyEnv(gym.Env): def __init__(self, max_episode_steps): - self.spec = EnvSpec('dummy-v0', max_episode_steps=max_episode_steps) + self.spec = EnvSpec("dummy-v0", max_episode_steps=max_episode_steps) self._episode_steps = 0 def reset(self): @@ -43,25 +43,23 @@ def step(self, a): next_state = self.observation_space.sample() reward = np.random.randn() done = False if self.spec.max_episode_steps is None else bool(self._episode_steps < self.spec.max_episode_steps) - info = {'rnn_states': {'dummy_scope': {'dummy_state1': 1, 'dummy_state2': 2}}} + info = {"rnn_states": {"dummy_scope": {"dummy_state1": 1, "dummy_state2": 2}}} self._episode_steps += 1 return next_state, reward, done, info class DummyContinuous(AbstractDummyEnv): - def __init__(self, max_episode_steps=None, observation_shape=(5, ), action_shape=(5, )): - super(DummyContinuous, self).__init__( - max_episode_steps=max_episode_steps) + def __init__(self, max_episode_steps=None, observation_shape=(5,), action_shape=(5,)): + super(DummyContinuous, self).__init__(max_episode_steps=max_episode_steps) self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=observation_shape) self.action_space = gym.spaces.Box(low=-1.0, high=5.0, shape=action_shape) class DummyFactoredContinuous(DummyContinuous): - def __init__(self, max_episode_steps=None, observation_shape=(5, ), action_shape=(5, ), reward_dimension=1): + def __init__(self, max_episode_steps=None, observation_shape=(5,), action_shape=(5,), reward_dimension=1): super(DummyFactoredContinuous, self).__init__( - max_episode_steps=max_episode_steps, - observation_shape=observation_shape, - action_shape=action_shape) + max_episode_steps=max_episode_steps, observation_shape=observation_shape, action_shape=action_shape + ) self.reward_dimension = reward_dimension def step(self, a): @@ -71,20 +69,20 @@ def step(self, a): class DummyDiscrete(AbstractDummyEnv): def __init__(self, max_episode_steps=None): - super(DummyDiscrete, self).__init__( - max_episode_steps=max_episode_steps) + super(DummyDiscrete, self).__init__(max_episode_steps=max_episode_steps) self.action_space = gym.spaces.Discrete(4) self.observation_space = gym.spaces.Discrete(5) class DummyTupleContinuous(AbstractDummyEnv): def __init__(self, max_episode_steps=None): - super(DummyTupleContinuous, self).__init__( - max_episode_steps=max_episode_steps) - self.action_space = gym.spaces.Tuple((gym.spaces.Box(low=0.0, high=1.0, shape=(2, )), - gym.spaces.Box(low=0.0, high=1.0, shape=(3, )))) - self.observation_space = gym.spaces.Tuple((gym.spaces.Box(low=0.0, high=1.0, shape=(4, )), - gym.spaces.Box(low=0.0, high=1.0, shape=(5, )))) + super(DummyTupleContinuous, self).__init__(max_episode_steps=max_episode_steps) + self.action_space = gym.spaces.Tuple( + (gym.spaces.Box(low=0.0, high=1.0, shape=(2,)), gym.spaces.Box(low=0.0, high=1.0, shape=(3,))) + ) + self.observation_space = gym.spaces.Tuple( + (gym.spaces.Box(low=0.0, high=1.0, shape=(4,)), gym.spaces.Box(low=0.0, high=1.0, shape=(5,))) + ) def step(self, a): for a, action_space in 
zip(a, self.action_space): @@ -94,8 +92,7 @@ def step(self, a): class DummyTupleDiscrete(AbstractDummyEnv): def __init__(self, max_episode_steps=None): - super(DummyTupleDiscrete, self).__init__( - max_episode_steps=max_episode_steps) + super(DummyTupleDiscrete, self).__init__(max_episode_steps=max_episode_steps) self.action_space = gym.spaces.Tuple((gym.spaces.Discrete(2), gym.spaces.Discrete(3))) self.observation_space = gym.spaces.Tuple((gym.spaces.Discrete(4), gym.spaces.Discrete(5))) @@ -107,12 +104,11 @@ def step(self, a): class DummyTupleMixed(AbstractDummyEnv): def __init__(self, max_episode_steps=None): - super(DummyTupleMixed, self).__init__( - max_episode_steps=max_episode_steps) - self.action_space = gym.spaces.Tuple((gym.spaces.Discrete(2), - gym.spaces.Box(low=0.0, high=1.0, shape=(3, )))) - self.observation_space = gym.spaces.Tuple((gym.spaces.Discrete(4), - gym.spaces.Box(low=0.0, high=1.0, shape=(5, )))) + super(DummyTupleMixed, self).__init__(max_episode_steps=max_episode_steps) + self.action_space = gym.spaces.Tuple((gym.spaces.Discrete(2), gym.spaces.Box(low=0.0, high=1.0, shape=(3,)))) + self.observation_space = gym.spaces.Tuple( + (gym.spaces.Discrete(4), gym.spaces.Box(low=0.0, high=1.0, shape=(5,))) + ) def step(self, a): for a, action_space in zip(a, self.action_space): @@ -122,54 +118,48 @@ def step(self, a): class DummyTupleStateContinuous(AbstractDummyEnv): def __init__(self, max_episode_steps=None): - super(DummyTupleStateContinuous, self).__init__( - max_episode_steps=max_episode_steps) - self.action_space = gym.spaces.Box(low=0.0, high=1.0, shape=(2, )) - self.observation_space = gym.spaces.Tuple((gym.spaces.Box(low=0.0, high=1.0, shape=(4, )), - gym.spaces.Box(low=0.0, high=1.0, shape=(5, )))) + super(DummyTupleStateContinuous, self).__init__(max_episode_steps=max_episode_steps) + self.action_space = gym.spaces.Box(low=0.0, high=1.0, shape=(2,)) + self.observation_space = gym.spaces.Tuple( + (gym.spaces.Box(low=0.0, high=1.0, shape=(4,)), gym.spaces.Box(low=0.0, high=1.0, shape=(5,))) + ) class DummyTupleStateDiscrete(AbstractDummyEnv): def __init__(self, max_episode_steps=None): - super(DummyTupleStateDiscrete, self).__init__( - max_episode_steps=max_episode_steps) + super(DummyTupleStateDiscrete, self).__init__(max_episode_steps=max_episode_steps) self.action_space = gym.spaces.Discrete(2) self.observation_space = gym.spaces.Tuple((gym.spaces.Discrete(4), gym.spaces.Discrete(5))) class DummyTupleActionContinuous(AbstractDummyEnv): def __init__(self, max_episode_steps=None): - super(DummyTupleActionContinuous, self).__init__( - max_episode_steps=max_episode_steps) - self.action_space = gym.spaces.Tuple((gym.spaces.Box(low=0.0, high=1.0, shape=(2, )), - gym.spaces.Box(low=0.0, high=1.0, shape=(3, )))) - self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(4, )) + super(DummyTupleActionContinuous, self).__init__(max_episode_steps=max_episode_steps) + self.action_space = gym.spaces.Tuple( + (gym.spaces.Box(low=0.0, high=1.0, shape=(2,)), gym.spaces.Box(low=0.0, high=1.0, shape=(3,))) + ) + self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(4,)) class DummyTupleActionDiscrete(AbstractDummyEnv): def __init__(self, max_episode_steps=None): - super(DummyTupleActionDiscrete, self).__init__( - max_episode_steps=max_episode_steps) + super(DummyTupleActionDiscrete, self).__init__(max_episode_steps=max_episode_steps) self.action_space = gym.spaces.Tuple((gym.spaces.Discrete(2), gym.spaces.Discrete(3))) self.observation_space = 
gym.spaces.Discrete(4) class DummyDiscreteImg(AbstractDummyEnv): def __init__(self, max_episode_steps=None): - super(DummyDiscreteImg, self).__init__( - max_episode_steps=max_episode_steps) - self.observation_space = gym.spaces.Box( - low=0.0, high=1.0, shape=(4, 84, 84)) + super(DummyDiscreteImg, self).__init__(max_episode_steps=max_episode_steps) + self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(4, 84, 84)) self.action_space = gym.spaces.Discrete(4) class DummyContinuousImg(AbstractDummyEnv): def __init__(self, image_shape=(3, 64, 64), max_episode_steps=None): - super(DummyContinuousImg, self).__init__( - max_episode_steps=max_episode_steps) - self.observation_space = gym.spaces.Box( - low=0.0, high=1.0, shape=image_shape) - self.action_space = gym.spaces.Box(low=0.0, high=1.0, shape=(2, )) + super(DummyContinuousImg, self).__init__(max_episode_steps=max_episode_steps) + self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=image_shape) + self.action_space = gym.spaces.Box(low=0.0, high=1.0, shape=(2,)) class DummyAtariEnv(AbstractDummyEnv): @@ -187,11 +177,9 @@ def lives(self): np_random = cast("RandomNumberGenerator", nnabla_rl.random.drng) def __init__(self, done_at_random=True, max_episode_length=None): - super(DummyAtariEnv, self).__init__( - max_episode_steps=max_episode_length) + super(DummyAtariEnv, self).__init__(max_episode_steps=max_episode_length) self.action_space = gym.spaces.Discrete(4) - self.observation_space = gym.spaces.Box( - low=0, high=255, shape=(84, 84, 3), dtype=np.uint8) + self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 3), dtype=np.uint8) self.ale = DummyAtariEnv.DummyALE() self._done_at_random = done_at_random self._max_episode_length = max_episode_length @@ -207,30 +195,30 @@ def step(self, action): done = False if self._max_episode_length is not None: done = (self._max_episode_length <= self._episode_length) or done - return observation, 1.0, done, {'needs_reset': False} + return observation, 1.0, done, {"needs_reset": False} def reset(self): self._episode_length = 0 return self.observation_space.sample() def get_action_meanings(self): - return ['NOOP', 'FIRE', 'LEFT', 'RIGHT'] + return ["NOOP", "FIRE", "LEFT", "RIGHT"] class DummyMujocoEnv(AbstractDummyEnv): def __init__(self, max_episode_steps=None): super(DummyMujocoEnv, self).__init__(max_episode_steps=max_episode_steps) - self.action_space = gym.spaces.Box(low=0.0, high=1.0, shape=(5, )) - self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(5, )) + self.action_space = gym.spaces.Box(low=0.0, high=1.0, shape=(5,)) + self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(5,)) def get_dataset(self): dataset = {} datasize = 2000 - dataset['observations'] = np.stack([self.observation_space.sample() for _ in range(datasize)], axis=0) - dataset['actions'] = np.stack([self.action_space.sample() for _ in range(datasize)], axis=0) - dataset['rewards'] = np.random.randn(datasize, 1) - dataset['terminals'] = np.random.randint(2, size=(datasize, 1)) - dataset['timeouts'] = np.zeros((datasize, 1)) + dataset["observations"] = np.stack([self.observation_space.sample() for _ in range(datasize)], axis=0) + dataset["actions"] = np.stack([self.action_space.sample() for _ in range(datasize)], axis=0) + dataset["rewards"] = np.random.randn(datasize, 1) + dataset["terminals"] = np.random.randint(2, size=(datasize, 1)) + dataset["timeouts"] = np.zeros((datasize, 1)) return dataset @@ -240,11 +228,15 @@ class DummyDMControlEnv(DummyMujocoEnv): class 
DummyContinuousActionGoalEnv(GoalEnv): def __init__(self, max_episode_steps=10): - self.spec = EnvSpec('dummy-continuou-action-goal-v0', max_episode_steps=max_episode_steps) - self.observation_space = gym.spaces.Dict({'observation': gym.spaces.Box(low=0.0, high=1.0, shape=(5, )), - 'achieved_goal': gym.spaces.Box(low=0.0, high=1.0, shape=(2, )), - 'desired_goal': gym.spaces.Box(low=0.0, high=1.0, shape=(2, ))}) - self.action_space = gym.spaces.Box(low=0.0, high=1.0, shape=(2, )) + self.spec = EnvSpec("dummy-continuou-action-goal-v0", max_episode_steps=max_episode_steps) + self.observation_space = gym.spaces.Dict( + { + "observation": gym.spaces.Box(low=0.0, high=1.0, shape=(5,)), + "achieved_goal": gym.spaces.Box(low=0.0, high=1.0, shape=(2,)), + "desired_goal": gym.spaces.Box(low=0.0, high=1.0, shape=(2,)), + } + ) + self.action_space = gym.spaces.Box(low=0.0, high=1.0, shape=(2,)) self._max_episode_length = max_episode_steps self._episode_length = 0 self._desired_goal = None @@ -253,15 +245,15 @@ def reset(self): super(DummyContinuousActionGoalEnv, self).reset() self._episode_length = 0 state = self.observation_space.sample() - self._desired_goal = state['desired_goal'] + self._desired_goal = state["desired_goal"] return state def step(self, a): next_state = self.observation_space.sample() - next_state['desired_goal'] = self._desired_goal - reward = self.compute_reward(next_state['achieved_goal'], next_state['desired_goal'], {}) + next_state["desired_goal"] = self._desired_goal + reward = self.compute_reward(next_state["achieved_goal"], next_state["desired_goal"], {}) self._episode_length += 1 - info = {'is_success': reward} + info = {"is_success": reward} if self._episode_length >= self._max_episode_length: done = True else: @@ -277,10 +269,14 @@ def compute_reward(self, achieved_goal, desired_goal, info): class DummyDiscreteActionGoalEnv(GoalEnv): def __init__(self, max_episode_steps=10): - self.spec = EnvSpec('dummy-discrete-action-goal-v0', max_episode_steps=max_episode_steps) - self.observation_space = gym.spaces.Dict({'observation': gym.spaces.Box(low=0.0, high=1.0, shape=(5, )), - 'achieved_goal': gym.spaces.Box(low=0.0, high=1.0, shape=(2, )), - 'desired_goal': gym.spaces.Box(low=0.0, high=1.0, shape=(2, ))}) + self.spec = EnvSpec("dummy-discrete-action-goal-v0", max_episode_steps=max_episode_steps) + self.observation_space = gym.spaces.Dict( + { + "observation": gym.spaces.Box(low=0.0, high=1.0, shape=(5,)), + "achieved_goal": gym.spaces.Box(low=0.0, high=1.0, shape=(2,)), + "desired_goal": gym.spaces.Box(low=0.0, high=1.0, shape=(2,)), + } + ) self.action_space = gym.spaces.Discrete(n=3) self._max_episode_length = max_episode_steps self._episode_length = 0 @@ -290,15 +286,15 @@ def reset(self): super(DummyDiscreteActionGoalEnv, self).reset() self._episode_length = 0 state = self.observation_space.sample() - self._desired_goal = state['desired_goal'] + self._desired_goal = state["desired_goal"] return state def step(self, a): next_state = self.observation_space.sample() - next_state['desired_goal'] = self._desired_goal - reward = self.compute_reward(next_state['achieved_goal'], next_state['desired_goal'], {}) + next_state["desired_goal"] = self._desired_goal + reward = self.compute_reward(next_state["achieved_goal"], next_state["desired_goal"], {}) self._episode_length += 1 - info = {'is_success': reward} + info = {"is_success": reward} if self._episode_length >= self._max_episode_length: done = True else: @@ -315,23 +311,28 @@ def compute_reward(self, achieved_goal, 
desired_goal, info): class DummyHybridEnv(AbstractDummyEnv): def __init__(self, max_episode_steps=None): super(DummyHybridEnv, self).__init__(max_episode_steps=max_episode_steps) - self.action_space = gym.spaces.Tuple((gym.spaces.Discrete(5), gym.spaces.Box(low=0.0, high=1.0, shape=(5, )))) - self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(5, )) + self.action_space = gym.spaces.Tuple((gym.spaces.Discrete(5), gym.spaces.Box(low=0.0, high=1.0, shape=(5,)))) + self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(5,)) class DummyAMPEnv(AMPEnv): def __init__(self, max_episode_steps=10): - self.spec = EnvSpec('dummy-amp-v0', max_episode_steps=max_episode_steps) - self.action_space = gym.spaces.Box(low=0.0, high=1.0, shape=(4, )) + self.spec = EnvSpec("dummy-amp-v0", max_episode_steps=max_episode_steps) + self.action_space = gym.spaces.Box(low=0.0, high=1.0, shape=(4,)) self.observation_space = gym.spaces.Tuple( - [gym.spaces.Box(low=0.0, high=1.0, shape=(2, )), - gym.spaces.Box(low=0.0, high=1.0, shape=(5, )), - gym.spaces.Box(low=0.0, high=1.0, shape=(1, ))]) + [ + gym.spaces.Box(low=0.0, high=1.0, shape=(2,)), + gym.spaces.Box(low=0.0, high=1.0, shape=(5,)), + gym.spaces.Box(low=0.0, high=1.0, shape=(1,)), + ] + ) self.reward_range = (0.0, 1.0) - self.observation_mean = tuple([np.zeros(2, dtype=np.float32), np.zeros( - 5, dtype=np.float32), np.zeros(1, dtype=np.float32)]) - self.observation_var = tuple([np.ones(2, dtype=np.float32), np.ones( - 5, dtype=np.float32), np.ones(1, dtype=np.float32)]) + self.observation_mean = tuple( + [np.zeros(2, dtype=np.float32), np.zeros(5, dtype=np.float32), np.zeros(1, dtype=np.float32)] + ) + self.observation_var = tuple( + [np.ones(2, dtype=np.float32), np.ones(5, dtype=np.float32), np.ones(1, dtype=np.float32)] + ) self.action_mean = np.zeros((4,), dtype=np.float32) self.action_var = np.ones((4,), dtype=np.float32) self.reward_at_task_fail = 0.0 @@ -360,35 +361,38 @@ def _step(self, a): next_state = list(self.observation_space.sample()) reward = np.random.randn() done = self._episode_steps >= self.spec.max_episode_steps - info = {'rnn_states': {'dummy_scope': {'dummy_state1': 1, 'dummy_state2': 2}}} + info = {"rnn_states": {"dummy_scope": {"dummy_state1": 1, "dummy_state2": 2}}} return tuple(next_state), reward, done, info class DummyAMPGoalEnv(AMPGoalEnv): def __init__(self, max_episode_steps=10): - self.spec = EnvSpec('dummy-amp-goal-v0', max_episode_steps=max_episode_steps) - self.action_space = gym.spaces.Box(low=0.0, high=1.0, shape=(4, )) + self.spec = EnvSpec("dummy-amp-goal-v0", max_episode_steps=max_episode_steps) + self.action_space = gym.spaces.Box(low=0.0, high=1.0, shape=(4,)) observation_space = gym.spaces.Tuple( - [gym.spaces.Box(low=0.0, high=1.0, shape=(2, )), - gym.spaces.Box(low=0.0, high=1.0, shape=(5, )), - gym.spaces.Box(low=0.0, high=1.0, shape=(1, ))]) - goal_state_space = gym.spaces.Tuple([gym.spaces.Box(low=-np.inf, - high=np.inf, - shape=(3,), - dtype=np.float32), - gym.spaces.Box(low=0.0, - high=1.0, - shape=(1,), - dtype=np.float32)]) - self.observation_space = gym.spaces.Dict({"observation": observation_space, - "desired_goal": goal_state_space, - "achieved_goal": goal_state_space}) + [ + gym.spaces.Box(low=0.0, high=1.0, shape=(2,)), + gym.spaces.Box(low=0.0, high=1.0, shape=(5,)), + gym.spaces.Box(low=0.0, high=1.0, shape=(1,)), + ] + ) + goal_state_space = gym.spaces.Tuple( + [ + gym.spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float32), + gym.spaces.Box(low=0.0, high=1.0, 
shape=(1,), dtype=np.float32), + ] + ) + self.observation_space = gym.spaces.Dict( + {"observation": observation_space, "desired_goal": goal_state_space, "achieved_goal": goal_state_space} + ) self.reward_range = (0.0, 1.0) - self.observation_mean = tuple([np.zeros(2, dtype=np.float32), np.zeros( - 5, dtype=np.float32), np.zeros(1, dtype=np.float32)]) - self.observation_var = tuple([np.ones(2, dtype=np.float32), np.ones( - 5, dtype=np.float32), np.ones(1, dtype=np.float32)]) + self.observation_mean = tuple( + [np.zeros(2, dtype=np.float32), np.zeros(5, dtype=np.float32), np.zeros(1, dtype=np.float32)] + ) + self.observation_var = tuple( + [np.ones(2, dtype=np.float32), np.ones(5, dtype=np.float32), np.ones(1, dtype=np.float32)] + ) self.action_mean = np.zeros((4,), dtype=np.float32) self.action_var = np.ones((4,), dtype=np.float32) self.reward_at_task_fail = 0.0 @@ -408,8 +412,14 @@ def is_valid_episode(self, state, reward, done, info) -> bool: def expert_experience(self, state, reward, done, info): action = self.action_space.sample() - return (self._generate_dummy_goal_env_flatten_state(), action, 0.0, - False, self._generate_dummy_goal_env_flatten_state(), {}) + return ( + self._generate_dummy_goal_env_flatten_state(), + action, + 0.0, + False, + self._generate_dummy_goal_env_flatten_state(), + {}, + ) def _generate_dummy_goal_env_flatten_state(self): state: List[np.ndarray] = [] @@ -428,14 +438,14 @@ def _step(self, a): next_state = self.observation_space.sample() reward = np.random.randn() done = self._episode_steps >= self.spec.max_episode_steps - info = {'rnn_states': {'dummy_scope': {'dummy_state1': 1, 'dummy_state2': 2}}} + info = {"rnn_states": {"dummy_scope": {"dummy_state1": 1, "dummy_state2": 2}}} return next_state, reward, done, info # =========== gymnasium ========== class AbstractDummyGymnasiumEnv(gymnasium.Env): def __init__(self, max_episode_steps): - self.spec = GymnasiumEnvSpec('dummy-v0', max_episode_steps=max_episode_steps) + self.spec = GymnasiumEnvSpec("dummy-v0", max_episode_steps=max_episode_steps) self._episode_steps = 0 def reset(self): @@ -450,7 +460,7 @@ def step(self, a): truncated = False else: truncated = bool(self._episode_steps < self.spec.max_episode_steps) - info = {'rnn_states': {'dummy_scope': {'dummy_state1': 1, 'dummy_state2': 2}}} + info = {"rnn_states": {"dummy_scope": {"dummy_state1": 1, "dummy_state2": 2}}} self._episode_steps += 1 return next_state, reward, terminated, truncated, info @@ -470,11 +480,9 @@ def lives(self): np_random = cast("RandomNumberGenerator", nnabla_rl.random.drng) def __init__(self, done_at_random=True, max_episode_length=None): - super(DummyGymnasiumAtariEnv, self).__init__( - max_episode_steps=max_episode_length) + super(DummyGymnasiumAtariEnv, self).__init__(max_episode_steps=max_episode_length) self.action_space = gymnasium.spaces.Discrete(4) - self.observation_space = gymnasium.spaces.Box( - low=0, high=255, shape=(84, 84, 3), dtype=np.uint8) + self.observation_space = gymnasium.spaces.Box(low=0, high=255, shape=(84, 84, 3), dtype=np.uint8) self.ale = DummyGymnasiumAtariEnv.DummyALE() self._done_at_random = done_at_random self._max_episode_length = max_episode_length @@ -490,28 +498,28 @@ def step(self, action): done = False if self._max_episode_length is not None: done = (self._max_episode_length <= self._episode_length) or done - return observation, 1.0, done, {'needs_reset': False} + return observation, 1.0, done, {"needs_reset": False} def reset(self): self._episode_length = 0 return self.observation_space.sample() 
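Illustrative aside: the gymnasium-style dummies above follow the newer five-value step contract (observation, reward, terminated, truncated, info), while the rest of the library consumes gym's four-value form; the Gymnasium2GymWrapper hunk later in this patch collapses the two termination flags. A rough sketch of that mapping, with gymnasium_env standing in for any gymnasium.Env instance:

    obs, reward, terminated, truncated, info = gymnasium_env.step(action)
    done = terminated or truncated              # single done flag handed back to gym-style callers
    info["TimeLimit.truncated"] = truncated     # kept so downstream code can still tell a timeout apart
    # gym-style result: (obs, reward, done, info)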
def get_action_meanings(self): - return ['NOOP', 'FIRE', 'LEFT', 'RIGHT'] + return ["NOOP", "FIRE", "LEFT", "RIGHT"] class DummyGymnasiumMujocoEnv(AbstractDummyGymnasiumEnv): def __init__(self, max_episode_steps=None): super(DummyGymnasiumMujocoEnv, self).__init__(max_episode_steps=max_episode_steps) - self.action_space = gymnasium.spaces.Box(low=0.0, high=1.0, shape=(5, )) - self.observation_space = gymnasium.spaces.Box(low=0.0, high=1.0, shape=(5, )) + self.action_space = gymnasium.spaces.Box(low=0.0, high=1.0, shape=(5,)) + self.observation_space = gymnasium.spaces.Box(low=0.0, high=1.0, shape=(5,)) def get_dataset(self): dataset = {} datasize = 2000 - dataset['observations'] = np.stack([self.observation_space.sample() for _ in range(datasize)], axis=0) - dataset['actions'] = np.stack([self.action_space.sample() for _ in range(datasize)], axis=0) - dataset['rewards'] = np.random.randn(datasize, 1) - dataset['terminals'] = np.random.randint(2, size=(datasize, 1)) - dataset['timeouts'] = np.zeros((datasize, 1)) + dataset["observations"] = np.stack([self.observation_space.sample() for _ in range(datasize)], axis=0) + dataset["actions"] = np.stack([self.action_space.sample() for _ in range(datasize)], axis=0) + dataset["rewards"] = np.random.randn(datasize, 1) + dataset["terminals"] = np.random.randint(2, size=(datasize, 1)) + dataset["timeouts"] = np.zeros((datasize, 1)) return dataset diff --git a/nnabla_rl/environments/environment_info.py b/nnabla_rl/environments/environment_info.py index 4ee40e8e..cbf1111c 100644 --- a/nnabla_rl/environments/environment_info.py +++ b/nnabla_rl/environments/environment_info.py @@ -18,8 +18,14 @@ import gym -from nnabla_rl.environments.gym_utils import (extract_max_episode_steps, get_space_dim, get_space_high, get_space_low, - get_space_shape, is_same_space_type) +from nnabla_rl.environments.gym_utils import ( + extract_max_episode_steps, + get_space_dim, + get_space_high, + get_space_low, + get_space_shape, + is_same_space_type, +) from nnabla_rl.external.goal_env import GoalEnv @@ -30,16 +36,19 @@ class EnvironmentInfo(object): This class contains the basic information of the target training environment. 
""" + observation_space: gym.spaces.Space action_space: gym.spaces.Space max_episode_steps: int - def __init__(self, - observation_space, - action_space, - max_episode_steps, - unwrapped_env, - reward_function: Optional[Callable[[Any, Any, Dict], int]] = None): + def __init__( + self, + observation_space, + action_space, + max_episode_steps, + unwrapped_env, + reward_function: Optional[Callable[[Any, Any, Dict], int]] = None, + ): self.observation_space = observation_space self.action_space = action_space self.max_episode_steps = max_episode_steps @@ -71,13 +80,15 @@ def from_env(env): >>> env_info.state_shape (4,) """ - reward_function = env.compute_reward if hasattr(env, 'compute_reward') else None + reward_function = env.compute_reward if hasattr(env, "compute_reward") else None unwrapped_env = env.unwrapped - return EnvironmentInfo(observation_space=env.observation_space, - action_space=env.action_space, - max_episode_steps=extract_max_episode_steps(env), - unwrapped_env=unwrapped_env, - reward_function=reward_function) + return EnvironmentInfo( + observation_space=env.observation_space, + action_space=env.action_space, + max_episode_steps=extract_max_episode_steps(env), + unwrapped_env=unwrapped_env, + reward_function=reward_function, + ) def is_discrete_action_env(self): """Check whether the action to execute in the environment is discrete diff --git a/nnabla_rl/environments/factored_envs.py b/nnabla_rl/environments/factored_envs.py index c897c458..affae2c4 100644 --- a/nnabla_rl/environments/factored_envs.py +++ b/nnabla_rl/environments/factored_envs.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -16,8 +16,18 @@ from typing import Optional, Tuple import numpy as np -from gym.envs.box2d.lunar_lander import (FPS, LEG_DOWN, MAIN_ENGINE_POWER, SCALE, SIDE_ENGINE_AWAY, SIDE_ENGINE_HEIGHT, - SIDE_ENGINE_POWER, VIEWPORT_H, VIEWPORT_W, LunarLander) +from gym.envs.box2d.lunar_lander import ( + FPS, + LEG_DOWN, + MAIN_ENGINE_POWER, + SCALE, + SIDE_ENGINE_AWAY, + SIDE_ENGINE_HEIGHT, + SIDE_ENGINE_POWER, + VIEWPORT_H, + VIEWPORT_W, + LunarLander, +) from gym.envs.mujoco.ant_v4 import AntEnv from gym.envs.mujoco.half_cheetah_v4 import HalfCheetahEnv from gym.envs.mujoco.hopper_v4 import HopperEnv @@ -66,17 +76,11 @@ def step(self, action: Action) -> Tuple[State, Reward, bool, Info]: # type: ign # Update wind assert self.lander is not None, "You forgot to call reset()" - if self.enable_wind and not ( - self.legs[0].ground_contact or self.legs[1].ground_contact - ): + if self.enable_wind and not (self.legs[0].ground_contact or self.legs[1].ground_contact): # the function used for wind is tanh(sin(2 k x) + sin(pi k x)), # which is proven to never be periodic, k = 0.01 wind_mag = ( - math.tanh( - math.sin(0.02 * self.wind_idx) - + (math.sin(math.pi * 0.01 * self.wind_idx)) - ) - * self.wind_power + math.tanh(math.sin(0.02 * self.wind_idx) + (math.sin(math.pi * 0.01 * self.wind_idx))) * self.wind_power ) self.wind_idx += 1 self.lander.ApplyForceToCenter( @@ -86,10 +90,9 @@ def step(self, action: Action) -> Tuple[State, Reward, bool, Info]: # type: ign # the function used for torque is tanh(sin(2 k x) + sin(pi k x)), # which is proven to never be periodic, k = 0.01 - torque_mag = math.tanh( - math.sin(0.02 * self.torque_idx) - + (math.sin(math.pi * 0.01 * self.torque_idx)) - ) * (self.turbulence_power) + torque_mag = math.tanh(math.sin(0.02 * self.torque_idx) + (math.sin(math.pi * 0.01 * self.torque_idx))) * ( + self.turbulence_power + ) self.torque_idx += 1 self.lander.ApplyTorque( (torque_mag), @@ -99,9 +102,7 @@ def step(self, action: Action) -> Tuple[State, Reward, bool, Info]: # type: ign if self.continuous: action = np.clip(action, -1, +1).astype(np.float32) else: - assert self.action_space.contains( - action - ), f"{action!r} ({type(action)}) invalid " + assert self.action_space.contains(action), f"{action!r} ({type(action)}) invalid " # Engines tip = (math.sin(self.lander.angle), math.cos(self.lander.angle)) @@ -109,9 +110,7 @@ def step(self, action: Action) -> Tuple[State, Reward, bool, Info]: # type: ign dispersion = [self.np_random.uniform(-1.0, +1.0) / SCALE for _ in range(2)] m_power = 0.0 - if (self.continuous and action[0] > 0.0) or ( - not self.continuous and action == 2 - ): + if (self.continuous and action[0] > 0.0) or (not self.continuous and action == 2): # Main engine if self.continuous: m_power = (np.clip(action[0], 0.0, 1.0) + 1.0) * 0.5 # 0.5..1.0 @@ -140,9 +139,7 @@ def step(self, action: Action) -> Tuple[State, Reward, bool, Info]: # type: ign ) s_power = 0.0 - if (self.continuous and np.abs(action[1]) > 0.5) or ( - not self.continuous and action in [1, 3] - ): + if (self.continuous and np.abs(action[1]) > 0.5) or (not self.continuous and action in [1, 3]): # Orientation engines if self.continuous: direction = np.sign(action[1]) @@ -151,12 +148,8 @@ def step(self, action: Action) -> Tuple[State, Reward, bool, Info]: # type: ign else: direction = action - 2 s_power = 1.0 - ox = tip[0] * dispersion[0] + side[0] * ( - 3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE - ) - oy = -tip[1] * dispersion[0] - side[1] * ( - 3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / 
SCALE - ) + ox = tip[0] * dispersion[0] + side[0] * (3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE) + oy = -tip[1] * dispersion[0] - side[1] * (3 * dispersion[1] + direction * SIDE_ENGINE_AWAY / SCALE) impulse_pos = ( self.lander.position[0] + ox - tip[0] * 17 / SCALE, self.lander.position[1] + oy + tip[1] * SIDE_ENGINE_HEIGHT / SCALE, @@ -191,10 +184,12 @@ def step(self, action: Action) -> Tuple[State, Reward, bool, Info]: # type: ign # shaping rewards prev_state = np.zeros_like(state) if self.prev_state is None else self.prev_state - reward_position = -100 * (np.sqrt(state[0] ** 2 + state[1] ** 2) - - np.sqrt(prev_state[0] ** 2 + prev_state[1] ** 2)) - reward_velocity = -100 * (np.sqrt(state[2] ** 2 + state[3] ** 2) - - np.sqrt(prev_state[2] ** 2 + prev_state[3] ** 2)) + reward_position = -100 * ( + np.sqrt(state[0] ** 2 + state[1] ** 2) - np.sqrt(prev_state[0] ** 2 + prev_state[1] ** 2) + ) + reward_velocity = -100 * ( + np.sqrt(state[2] ** 2 + state[3] ** 2) - np.sqrt(prev_state[2] ** 2 + prev_state[3] ** 2) + ) reward_angle = -100 * (abs(state[4]) - abs(prev_state[4])) reward_left_leg = 10 * (state[6] - prev_state[6]) reward_right_leg = 10 * (state[7] - prev_state[7]) @@ -220,8 +215,17 @@ def step(self, action: Action) -> Tuple[State, Reward, bool, Info]: # type: ign if self.render_mode == "human": self.render() - reward = [reward_position, reward_velocity, reward_angle, reward_left_leg, reward_right_leg, - reward_main_engine, reward_side_engine, reward_failure, reward_success] + reward = [ + reward_position, + reward_velocity, + reward_angle, + reward_left_leg, + reward_right_leg, + reward_main_engine, + reward_side_engine, + reward_failure, + reward_success, + ] return np.array(state, dtype=np.float32), np.array(reward), terminated, {} diff --git a/nnabla_rl/environments/gym_utils.py b/nnabla_rl/environments/gym_utils.py index d9db6daf..4af3fecc 100644 --- a/nnabla_rl/environments/gym_utils.py +++ b/nnabla_rl/environments/gym_utils.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ def get_space_shape(space: gym.spaces.Space) -> Tuple[int, ...]: if isinstance(space, gym.spaces.Box): return tuple(space.shape) elif isinstance(space, gym.spaces.Discrete): - return (1, ) + return (1,) else: raise ValueError @@ -41,7 +41,7 @@ def get_space_high(space: gym.spaces.Space) -> Union[np.ndarray, str]: if isinstance(space, gym.spaces.Box): return np.asarray(space.high) elif isinstance(space, gym.spaces.Discrete): - return 'N/A' + return "N/A" else: raise ValueError @@ -50,7 +50,7 @@ def get_space_low(space: gym.spaces.Space) -> Union[np.ndarray, str]: if isinstance(space, gym.spaces.Box): return np.asarray(space.low) elif isinstance(space, gym.spaces.Discrete): - return 'N/A' + return "N/A" else: raise ValueError @@ -65,8 +65,7 @@ def extract_max_episode_steps(env_or_env_info): return env_or_env_info.max_episode_steps -def is_same_space_type(query_space: gym.spaces.Space, - key_space: Union[gym.spaces.Discrete, gym.spaces.Box]) -> bool: +def is_same_space_type(query_space: gym.spaces.Space, key_space: Union[gym.spaces.Discrete, gym.spaces.Box]) -> bool: """Check whether the query_space has the same type of key_space or not. 
Note that if the query_space is gym.spaces.Tuple, this method checks whether all of the element of the query_space are the key_space or not. @@ -77,6 +76,7 @@ def is_same_space_type(query_space: gym.spaces.Space, Returns: bool: True if the query_space is the same as key_space. Otherwise False. """ + def _check_each_space_type(space, key_space) -> bool: if key_space == gym.spaces.Discrete: return isinstance(space, gym.spaces.Discrete) diff --git a/nnabla_rl/environments/wrappers/__init__.py b/nnabla_rl/environments/wrappers/__init__.py index 83ee504f..b069f66e 100644 --- a/nnabla_rl/environments/wrappers/__init__.py +++ b/nnabla_rl/environments/wrappers/__init__.py @@ -13,11 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nnabla_rl.environments.wrappers.common import (Float32RewardEnv, HWCToCHWEnv, NumpyFloat32Env, # noqa - ScreenRenderEnv, TimestepAsStateEnv, FlattenNestedTupleStateWrapper) +from nnabla_rl.environments.wrappers.common import ( # noqa + Float32RewardEnv, + HWCToCHWEnv, + NumpyFloat32Env, + ScreenRenderEnv, + TimestepAsStateEnv, + FlattenNestedTupleStateWrapper, +) from nnabla_rl.environments.wrappers.mujoco import EndlessEnv # noqa from nnabla_rl.environments.wrappers.atari import make_atari, wrap_deepmind # noqa -from nnabla_rl.environments.wrappers.hybrid_env import (EmbedActionWrapper, FlattenActionWrapper, # noqa - RemoveStepWrapper, ScaleActionWrapper, ScaleStateWrapper) +from nnabla_rl.environments.wrappers.hybrid_env import ( # noqa + EmbedActionWrapper, + FlattenActionWrapper, + RemoveStepWrapper, + ScaleActionWrapper, + ScaleStateWrapper, +) from nnabla_rl.environments.wrappers.gymnasium import Gymnasium2GymWrapper # noqa diff --git a/nnabla_rl/environments/wrappers/atari.py b/nnabla_rl/environments/wrappers/atari.py index f47ac09f..3466ad99 100644 --- a/nnabla_rl/environments/wrappers/atari.py +++ b/nnabla_rl/environments/wrappers/atari.py @@ -22,8 +22,14 @@ import nnabla_rl as rl from nnabla_rl.environments.wrappers.gymnasium import Gymnasium2GymWrapper -from nnabla_rl.external.atari_wrappers import (ClipRewardEnv, EpisodicLifeEnv, FireResetEnv, MaxAndSkipEnv, - NoopResetEnv, ScaledFloatFrame) +from nnabla_rl.external.atari_wrappers import ( + ClipRewardEnv, + EpisodicLifeEnv, + FireResetEnv, + MaxAndSkipEnv, + NoopResetEnv, + ScaledFloatFrame, +) cv2.ocl.setUseOpenCL(False) @@ -49,14 +55,11 @@ def __init__(self, env): self.width = 84 self.height = 84 obs_shape = (1, self.height, self.width) # 'chw' order - self.observation_space = spaces.Box( - low=0, high=255, - shape=obs_shape, dtype=np.uint8) + self.observation_space = spaces.Box(low=0, high=255, shape=obs_shape, dtype=np.uint8) def observation(self, frame): frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) - frame = cv2.resize(frame, (self.width, self.height), - interpolation=cv2.INTER_AREA) + frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA) return frame.reshape(self.observation_space.low.shape) @@ -69,8 +72,7 @@ def __init__(self, env, k): orig_obs_space = env.observation_space low = np.repeat(orig_obs_space.low, k, axis=stack_axis) high = np.repeat(orig_obs_space.high, k, axis=stack_axis) - self.observation_space = spaces.Box( - low=low, high=high, dtype=orig_obs_space.dtype) + self.observation_space = spaces.Box(low=low, high=high, dtype=orig_obs_space.dtype) def reset(self): ob = self.env.reset() @@ -110,24 +112,26 @@ def make_atari(env_id, max_frames_per_episode=None, use_gymnasium=False): if 
max_frames_per_episode is not None: env = env.unwrapped env = gym.wrappers.TimeLimit(env, max_episode_steps=max_frames_per_episode) - assert 'NoFrameskip' in env.spec.id + assert "NoFrameskip" in env.spec.id assert isinstance(env, gym.wrappers.TimeLimit) env = NoopResetEnv(env, noop_max=30) env = MaxAndSkipEnv(env, skip=4) return env -def wrap_deepmind(env, - episode_life=True, - clip_rewards=True, - normalize=True, - frame_stack=True, - fire_reset=False, - flicker_probability=0.0): +def wrap_deepmind( + env, + episode_life=True, + clip_rewards=True, + normalize=True, + frame_stack=True, + fire_reset=False, + flicker_probability=0.0, +): """Configure environment for DeepMind-style Atari.""" if episode_life: env = EpisodicLifeEnv(env) - if fire_reset and 'FIRE' in env.unwrapped.get_action_meanings(): + if fire_reset and "FIRE" in env.unwrapped.get_action_meanings(): env = FireResetEnv(env) env = CHWWarpFrame(env) if 0.0 < flicker_probability: diff --git a/nnabla_rl/environments/wrappers/common.py b/nnabla_rl/environments/wrappers/common.py index 0bc8e25a..4ed1aade 100644 --- a/nnabla_rl/environments/wrappers/common.py +++ b/nnabla_rl/environments/wrappers/common.py @@ -28,8 +28,7 @@ def __init__(self, env): self.dtype = np.float32 if isinstance(env.observation_space, spaces.Tuple): self.observation_space = spaces.Tuple( - [self._create_observation_space(observation_space) - for observation_space in env.observation_space] + [self._create_observation_space(observation_space) for observation_space in env.observation_space] ) else: self.observation_space = self._create_observation_space(env.observation_space) @@ -37,10 +36,7 @@ def __init__(self, env): def _create_observation_space(self, observation_space): if isinstance(observation_space, spaces.Box): return spaces.Box( - low=observation_space.low, - high=observation_space.high, - shape=observation_space.shape, - dtype=self.dtype + low=observation_space.low, high=observation_space.high, shape=observation_space.shape, dtype=self.dtype ) elif isinstance(observation_space, spaces.Discrete): return spaces.Discrete(n=observation_space.n) @@ -98,10 +94,7 @@ def reverse_action(self, action): def _create_action_space(self, action_space): if isinstance(action_space, spaces.Box): return spaces.Box( - low=action_space.low, - high=action_space.high, - shape=action_space.shape, - dtype=self.continuous_dtype + low=action_space.low, high=action_space.high, shape=action_space.shape, dtype=self.continuous_dtype ) elif isinstance(action_space, spaces.Discrete): return spaces.Discrete(n=action_space.n) @@ -139,7 +132,7 @@ class ScreenRenderEnv(gym.Wrapper): def __init__(self, env): super(ScreenRenderEnv, self).__init__(env) self._installed_gym_version = parse(gym.__version__) - self._gym_version25 = parse('0.25.0') + self._gym_version25 = parse("0.25.0") self._env_name = "Unknown" if env.unwrapped.spec is None else env.unwrapped.spec.id def step(self, action): @@ -148,7 +141,7 @@ def step(self, action): return results def reset(self): - if 'Bullet' in self._env_name: + if "Bullet" in self._env_name: self._render_env() state = self.env.reset() else: @@ -159,8 +152,8 @@ def reset(self): def _render_env(self): if self._gym_version25 <= self._installed_gym_version: # 0.25.0 <= gym version - rgb_array = self.env.render(mode='rgb_array') - cv2.imshow(f'{self._env_name}', cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)) + rgb_array = self.env.render(mode="rgb_array") + cv2.imshow(f"{self._env_name}", cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)) cv2.waitKey(1) else: # old gym 
version @@ -180,8 +173,8 @@ def step(self, action): self._episode_num += 1 episode_steps = len(self._episode_rewards) episode_return = np.sum(self._episode_rewards) - logger.info(f'Episode #{self._episode_num} finished.') - logger.info(f'Episode steps: {episode_steps}. Total return: {episode_return}.') + logger.info(f"Episode #{self._episode_num} finished.") + logger.info(f"Episode steps: {episode_steps}. Total return: {episode_return}.") self._episode_rewards.clear() return s_next, reward, done, info @@ -201,7 +194,7 @@ def __init__(self, env): super(TimestepAsStateEnv, self).__init__(env) self._timestep = 0 obs_space = self.observation_space - timestep_obs_space = spaces.Box(low=0., high=np.inf, shape=(1, ), dtype=np.float32) + timestep_obs_space = spaces.Box(low=0.0, high=np.inf, shape=(1,), dtype=np.float32) self.observation_space = spaces.Tuple([obs_space, timestep_obs_space]) def reset(self): diff --git a/nnabla_rl/environments/wrappers/goal_conditioned.py b/nnabla_rl/environments/wrappers/goal_conditioned.py index e1633c72..20e34b25 100644 --- a/nnabla_rl/environments/wrappers/goal_conditioned.py +++ b/nnabla_rl/environments/wrappers/goal_conditioned.py @@ -109,9 +109,7 @@ def np_random(self, value): @property def _np_random(self): - raise AttributeError( - "Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`." - ) + raise AttributeError("Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`.") def step(self, action): """Steps through the environment with action.""" @@ -148,7 +146,7 @@ def compute_reward(self, achieved_goal, desired_goal, info): class GoalConditionedTupleObservationEnv(gym.ObservationWrapper): def __init__(self, env: GoalEnv): super(GoalConditionedTupleObservationEnv, self).__init__(env) - self._observation_keys = ['observation', 'desired_goal', 'achieved_goal'] + self._observation_keys = ["observation", "desired_goal", "achieved_goal"] self._check_env(env) self._observation_space = self._build_observation_space(env) @@ -159,7 +157,7 @@ def _check_env(self, env: GoalEnv): for key in env.observation_space.spaces: if key not in self._observation_keys: - error_msg = f'{key} should be included in observation_space!!' + error_msg = f"{key} should be included in observation_space!!" raise ValueError(error_msg) def _build_observation_space(self, env: GoalEnv): @@ -177,7 +175,7 @@ def observation(self, observation): def _check_observation(self, observation): for key in observation.keys(): if key not in self._observation_keys: - error_msg = f'{key} should be included in observations!!' + error_msg = f"{key} should be included in observations!!" 
raise ValueError(error_msg) def compute_reward(self, achieved_goal, desired_goal, info): diff --git a/nnabla_rl/environments/wrappers/gymnasium.py b/nnabla_rl/environments/wrappers/gymnasium.py index ce4c4d52..ce6ae32f 100644 --- a/nnabla_rl/environments/wrappers/gymnasium.py +++ b/nnabla_rl/environments/wrappers/gymnasium.py @@ -30,13 +30,14 @@ def __init__(self, env): # observation space if isinstance(env.observation_space, gymnasium_spaces.Tuple): self.observation_space = gym_spaces.Tuple( - [self._translate_space(observation_space) - for observation_space in env.observation_space] + [self._translate_space(observation_space) for observation_space in env.observation_space] ) elif isinstance(env.observation_space, gymnasium_spaces.Dict): self.observation_space = gym_spaces.Dict( - {key: self._translate_space(observation_space) - for key, observation_space in env.observation_space.items()} + { + key: self._translate_space(observation_space) + for key, observation_space in env.observation_space.items() + } ) else: self.observation_space = self._translate_space(env.observation_space) @@ -44,13 +45,11 @@ def __init__(self, env): # action space if isinstance(env.action_space, gymnasium_spaces.Tuple): self.action_space = gym_spaces.Tuple( - [self._translate_space(action_space) - for action_space in env.action_space] + [self._translate_space(action_space) for action_space in env.action_space] ) elif isinstance(env.action_space, gymnasium_spaces.Dict): self.action_space = gym_spaces.Dict( - {key: self._translate_space(action_space) - for key, action_space in env.action_space.items()} + {key: self._translate_space(action_space) for key, action_space in env.action_space.items()} ) else: self.action_space = self._translate_space(env.action_space) @@ -61,7 +60,7 @@ def reset(self): def step(self, action): obs, reward, terminated, truncated, info = self.env.step(action) - done = (terminated or truncated) + done = terminated or truncated info.update({"TimeLimit.truncated": truncated}) return obs, reward, done, info @@ -76,12 +75,7 @@ def unwrapped(self): def _translate_space(self, space): if isinstance(space, gymnasium_spaces.Box): - return gym_spaces.Box( - low=space.low, - high=space.high, - shape=space.shape, - dtype=space.dtype - ) + return gym_spaces.Box(low=space.low, high=space.high, shape=space.shape, dtype=space.dtype) elif isinstance(space, gymnasium_spaces.Discrete): return gym_spaces.Discrete(n=int(space.n)) else: diff --git a/nnabla_rl/environments/wrappers/hybrid_env.py b/nnabla_rl/environments/wrappers/hybrid_env.py index 6c78a9ff..0b29a181 100644 --- a/nnabla_rl/environments/wrappers/hybrid_env.py +++ b/nnabla_rl/environments/wrappers/hybrid_env.py @@ -28,9 +28,14 @@ def __init__(self, env): original_action_space = env.action_space num_actions = original_action_space.spaces[0].n discrete_space = original_action_space.spaces[0] - continuous_space = [gym.spaces.Box(original_action_space.spaces[1].spaces[i].low, - original_action_space.spaces[1].spaces[i].high, - dtype=np.float32) for i in range(0, num_actions)] + continuous_space = [ + gym.spaces.Box( + original_action_space.spaces[1].spaces[i].low, + original_action_space.spaces[1].spaces[i].high, + dtype=np.float32, + ) + for i in range(0, num_actions) + ] self.action_space = gym.spaces.Tuple((discrete_space, *continuous_space)) def step(self, action): @@ -47,14 +52,18 @@ def __init__(self, env: Env): if self._is_box(env.observation_space): observation_shape = cast(SupportsIndex, env.observation_space.shape) - self.observation_space = 
gym.spaces.Box(low=-np.ones(shape=observation_shape), - high=np.ones(shape=observation_shape), - dtype=np.float32) + self.observation_space = gym.spaces.Box( + low=-np.ones(shape=observation_shape), high=np.ones(shape=observation_shape), dtype=np.float32 + ) elif self._is_tuple(env.observation_space): - spaces = [gym.spaces.Box(low=-np.ones(shape=space.shape), - high=np.ones(shape=space.shape), - dtype=np.float32) - if self._is_box(space) else space for space in cast(gym.spaces.Tuple, env.observation_space)] + spaces = [ + ( + gym.spaces.Box(low=-np.ones(shape=space.shape), high=np.ones(shape=space.shape), dtype=np.float32) + if self._is_box(space) + else space + ) + for space in cast(gym.spaces.Tuple, env.observation_space) + ] self.observation_space = gym.spaces.Tuple(spaces) # type: ignore else: raise NotImplementedError @@ -87,14 +96,18 @@ def __init__(self, env: Env): if self._is_box(env.action_space): action_shape = cast(SupportsIndex, env.action_space.shape) - self.action_space = gym.spaces.Box(low=-np.ones(shape=action_shape), - high=np.ones(shape=action_shape), - dtype=np.float32) + self.action_space = gym.spaces.Box( + low=-np.ones(shape=action_shape), high=np.ones(shape=action_shape), dtype=np.float32 + ) elif self._is_tuple(env.action_space): - spaces = [gym.spaces.Box(low=-np.ones(shape=space.shape), - high=np.ones(shape=space.shape), - dtype=np.float32) - if self._is_box(space) else space for space in cast(gym.spaces.Tuple, env.action_space)] + spaces = [ + ( + gym.spaces.Box(low=-np.ones(shape=space.shape), high=np.ones(shape=space.shape), dtype=np.float32) + if self._is_box(space) + else space + ) + for space in cast(gym.spaces.Tuple, env.action_space) + ] self.action_space = gym.spaces.Tuple(spaces) # type: ignore else: raise NotImplementedError @@ -133,9 +146,9 @@ def __init__(self, env: Env): self._original_action_range = [space.high - space.low for space in original_action_space[1:]] action_size = max(self._original_action_size) d_action_space = original_action_space[0] - c_action_space = gym.spaces.Box(low=-np.ones(shape=(action_size, )), - high=np.ones(shape=(action_size, )), - dtype=np.float32) + c_action_space = gym.spaces.Box( + low=-np.ones(shape=(action_size,)), high=np.ones(shape=(action_size,)), dtype=np.float32 + ) self.action_space = gym.spaces.Tuple((d_action_space, c_action_space)) # type: ignore def action(self, action): @@ -163,8 +176,8 @@ def __init__(self, env: Env): self.embed_map = np.random.normal(size=(self.d_action_dim, self.d_action_dim)) def action(self, action): - d_action = self._decode(action[:self.d_action_dim]) - c_action = action[self.d_action_dim:] + d_action = self._decode(action[: self.d_action_dim]) + c_action = action[self.d_action_dim :] return (d_action, c_action) def reverse_action(self, action): @@ -182,7 +195,7 @@ class RemoveStepWrapper(gym.ObservationWrapper): def __init__(self, env: Env): super().__init__(env) if not isinstance(env.observation_space, gym.spaces.Tuple): # type: ignore - raise ValueError('observation space is not a tuple!') + raise ValueError("observation space is not a tuple!") self.observation_space = cast(gym.spaces.Tuple, env.observation_space)[0] def observation(self, observation): diff --git a/nnabla_rl/environments/wrappers/mujoco.py b/nnabla_rl/environments/wrappers/mujoco.py index 70404e73..6f3bf0b0 100644 --- a/nnabla_rl/environments/wrappers/mujoco.py +++ b/nnabla_rl/environments/wrappers/mujoco.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022 Sony Group Corporation. 
+# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,12 +21,12 @@ class EndlessEnv(gym.Wrapper): - ''' Endless Env + """ Endless Env This environment wrapper makes an environment be endless. \ Any done flags will be False except for the timelimit, and reset_reward (Usually, this value is negative) will be given. The done flag is True only if the number of steps reaches the timelimit of the environment. - ''' + """ def __init__(self, env: gym.Env, reset_reward: float): super(EndlessEnv, self).__init__(env) @@ -35,12 +35,13 @@ def __init__(self, env: gym.Env, reset_reward: float): self._num_steps = 0 self._reset_reward = reset_reward - def step(self, action: Action) -> Union[Tuple[State, Reward, bool, Info], # type: ignore - Tuple[State, Reward, bool, bool, Info]]: + def step( # type: ignore + self, action: Action + ) -> Union[Tuple[State, Reward, bool, Info], Tuple[State, Reward, bool, bool, Info]]: self._num_steps += 1 next_state, reward, done, info = cast(Tuple[State, float, bool, Dict], super().step(action)) - timelimit = info.pop('TimeLimit.truncated', False) or (self._num_steps == self._max_episode_steps) + timelimit = info.pop("TimeLimit.truncated", False) or (self._num_steps == self._max_episode_steps) if timelimit: self._num_steps = 0 diff --git a/nnabla_rl/exceptions.py b/nnabla_rl/exceptions.py index 447cf74e..8a050dac 100644 --- a/nnabla_rl/exceptions.py +++ b/nnabla_rl/exceptions.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,18 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. + class NNablaRLError(Exception): """Base class of all specific exceptions defined for nnabla_rl.""" + pass class UnsupportedTrainingException(NNablaRLError): """Raised when the algorithm does not support requested training procedure.""" + pass class UnsupportedEnvironmentException(NNablaRLError): """Raised when the algorithm does not support given environment to train the policy.""" + pass diff --git a/nnabla_rl/functions.py b/nnabla_rl/functions.py index f1ffee2c..08c8054e 100644 --- a/nnabla_rl/functions.py +++ b/nnabla_rl/functions.py @@ -21,9 +21,9 @@ import nnabla.functions as NF -def sample_gaussian(mean: nn.Variable, - ln_var: nn.Variable, - noise_clip: Optional[Tuple[float, float]] = None) -> nn.Variable: +def sample_gaussian( + mean: nn.Variable, ln_var: nn.Variable, noise_clip: Optional[Tuple[float, float]] = None +) -> nn.Variable: """Sample value from a gaussian distribution of given mean and variance. 
Args: @@ -35,7 +35,7 @@ def sample_gaussian(mean: nn.Variable, nn.Variable: Sampled value from gaussian distribution of given mean and variance """ if not (mean.shape == ln_var.shape): - raise ValueError('mean and ln_var has different shape') + raise ValueError("mean and ln_var has different shape") noise = NF.randn(shape=mean.shape) stddev = NF.exp(ln_var * 0.5) @@ -45,10 +45,9 @@ def sample_gaussian(mean: nn.Variable, return mean + stddev * noise -def sample_gaussian_multiple(mean: nn.Variable, - ln_var: nn.Variable, - num_samples: int, - noise_clip: Optional[Tuple[float, float]] = None) -> nn.Variable: +def sample_gaussian_multiple( + mean: nn.Variable, ln_var: nn.Variable, num_samples: int, noise_clip: Optional[Tuple[float, float]] = None +) -> nn.Variable: """Sample multiple values from a gaussian distribution of given mean and variance. The returned variable will have an additional axis in the middle as follows (batch_size, num_samples, dimension) @@ -63,13 +62,12 @@ def sample_gaussian_multiple(mean: nn.Variable, nn.Variable: Sampled values from gaussian distribution of given mean and variance """ if not (mean.shape == ln_var.shape): - raise ValueError('mean and ln_var has different shape') + raise ValueError("mean and ln_var has different shape") batch_size = mean.shape[0] data_shape = mean.shape[1:] mean = NF.reshape(mean, shape=(batch_size, 1, *data_shape)) - stddev = NF.reshape(NF.exp(ln_var * 0.5), - shape=(batch_size, 1, *data_shape)) + stddev = NF.reshape(NF.exp(ln_var * 0.5), shape=(batch_size, 1, *data_shape)) output_shape = (batch_size, num_samples, *data_shape) @@ -110,9 +108,9 @@ def repeat(x: nn.Variable, repeats: int, axis: int) -> nn.Variable: assert isinstance(repeats, int) assert axis is not None assert axis < len(x.shape) - reshape_size = (*x.shape[0:axis+1], 1, *x.shape[axis+1:]) - repeater_size = (*x.shape[0:axis+1], repeats, *x.shape[axis+1:]) - final_size = (*x.shape[0:axis], x.shape[axis] * repeats, *x.shape[axis+1:]) + reshape_size = (*x.shape[0 : axis + 1], 1, *x.shape[axis + 1 :]) + repeater_size = (*x.shape[0 : axis + 1], repeats, *x.shape[axis + 1 :]) + final_size = (*x.shape[0:axis], x.shape[axis] * repeats, *x.shape[axis + 1 :]) x = NF.reshape(x=x, shape=reshape_size) x = NF.broadcast(x, repeater_size) return NF.reshape(x, final_size) @@ -228,7 +226,7 @@ def minimum_n(variables: Sequence[nn.Variable]) -> nn.Variable: nn.Variable: Minimum value among the list of variables """ if len(variables) < 1: - raise ValueError('Variables must have at least 1 variable') + raise ValueError("Variables must have at least 1 variable") if len(variables) == 1: return variables[0] if len(variables) == 2: @@ -240,10 +238,15 @@ def minimum_n(variables: Sequence[nn.Variable]) -> nn.Variable: return minimum -def gaussian_cross_entropy_method(objective_function: Callable[[nn.Variable], nn.Variable], - init_mean: Union[nn.Variable, np.ndarray], init_var: Union[nn.Variable, np.ndarray], - sample_size: int = 500, num_elites: int = 10, - num_iterations: int = 5, alpha: float = 0.25) -> Tuple[nn.Variable, nn.Variable]: +def gaussian_cross_entropy_method( + objective_function: Callable[[nn.Variable], nn.Variable], + init_mean: Union[nn.Variable, np.ndarray], + init_var: Union[nn.Variable, np.ndarray], + sample_size: int = 500, + num_elites: int = 10, + num_iterations: int = 5, + alpha: float = 0.25, +) -> Tuple[nn.Variable, nn.Variable]: """Optimize objective function with respect to input using cross entropy method using gaussian distribution. 
Candidates are sampled from a gaussian distribution :math:`\\mathcal{N}(mean,\\,variance)` @@ -341,7 +344,7 @@ def objective_function(time_seq_action): # new_mean.shape = (batch_size, 1, gaussian_dimension) new_mean = NF.mean(elites, axis=1, keepdims=True) # new_var.shape = (batch_size, 1, gaussian_dimension) - new_var = NF.mean((elites - new_mean)**2, axis=1, keepdims=True) + new_var = NF.mean((elites - new_mean) ** 2, axis=1, keepdims=True) mean = alpha * mean + (1 - alpha) * new_mean.reshape((batch_size, gaussian_dimension)) var = alpha * var + (1 - alpha) * new_var.reshape((batch_size, gaussian_dimension)) @@ -349,10 +352,12 @@ def objective_function(time_seq_action): return mean, top -def random_shooting_method(objective_function: Callable[[nn.Variable], nn.Variable], - upper_bound: np.ndarray, - lower_bound: np.ndarray, - sample_size: int = 500) -> nn.Variable: +def random_shooting_method( + objective_function: Callable[[nn.Variable], nn.Variable], + upper_bound: np.ndarray, + lower_bound: np.ndarray, + sample_size: int = 500, +) -> nn.Variable: """Optimize objective function with respect to the variable using random shooting method. Candidates are sampled from a uniform distribution :math:`x \\sim U(lower\\:bound, upper\\:bound)`. @@ -481,6 +486,7 @@ def triangular_matrix(diagonal: nn.Variable, non_diagonal: Optional[nn.Variable] Returns: nn.Variable: lower triangular matrix constructed from given variables. """ + def _flat_tri_indices(batch_size, matrix_dim, upper): matrix_size = matrix_dim * matrix_dim @@ -502,8 +508,8 @@ def _flat_tri_indices(batch_size, matrix_dim, upper): scatter_indices = _flat_tri_indices(batch_size, matrix_dim=diagonal_size, upper=upper) matrix_size = diagonal_size * diagonal_size - non_diagonal_part = NF.reshape(non_diagonal, shape=(batch_size * non_diagonal_size, )) - non_diagonal_part = NF.scatter_nd(non_diagonal_part, scatter_indices, shape=(batch_size * matrix_size, )) + non_diagonal_part = NF.reshape(non_diagonal, shape=(batch_size * non_diagonal_size,)) + non_diagonal_part = NF.scatter_nd(non_diagonal_part, scatter_indices, shape=(batch_size * matrix_size,)) non_diagonal_part = NF.reshape(non_diagonal_part, shape=(batch_size, diagonal_size, diagonal_size)) return diagonal_part + non_diagonal_part @@ -556,7 +562,7 @@ def pytorch_equivalent_gather(x: nn.Variable, indices: nn.Variable, axis: int) - nn.Variable: gathered (in pytorch's style) variable. 
""" assert x.shape[:axis] == indices.shape[:axis] - assert x.shape[axis+1:] == indices.shape[axis+1:] + assert x.shape[axis + 1 :] == indices.shape[axis + 1 :] if axis != len(x.shape) - 1: x = swapaxes(x, axis, len(x.shape) - 1) indices = swapaxes(indices, axis, len(indices.shape) - 1) @@ -596,7 +602,7 @@ def concat_interleave(variables: Sequence[nn.Variable], axis: int) -> nn.Variabl indices = np.swapaxes(indices, axis, len(indices.shape) - 1) original_size = axis_size // variable_num for i in range(axis_size): - item_index = (i // variable_num) + item_index = i // variable_num var_index = i % variable_num data_index = var_index * original_size + item_index indices[..., i] = data_index @@ -622,10 +628,9 @@ def swapaxes(x: nn.Variable, axis1: int, axis2: int) -> nn.Variable: return NF.transpose(x, axes=axes) -def normalize(x: nn.Variable, - mean: nn.Variable, - std: nn.Variable, - value_clip: Optional[Tuple[float, float]] = None) -> nn.Variable: +def normalize( + x: nn.Variable, mean: nn.Variable, std: nn.Variable, value_clip: Optional[Tuple[float, float]] = None +) -> nn.Variable: """Normalize the given variable. Args: @@ -643,10 +648,9 @@ def normalize(x: nn.Variable, return normalized -def unnormalize(x: nn.Variable, - mean: nn.Variable, - std: nn.Variable, - value_clip: Optional[Tuple[float, float]] = None) -> nn.Variable: +def unnormalize( + x: nn.Variable, mean: nn.Variable, std: nn.Variable, value_clip: Optional[Tuple[float, float]] = None +) -> nn.Variable: """Unnormalize the given variable. Args: diff --git a/nnabla_rl/hook.py b/nnabla_rl/hook.py index a615a5f4..8941cb92 100644 --- a/nnabla_rl/hook.py +++ b/nnabla_rl/hook.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -32,13 +32,13 @@ class Hook(metaclass=ABCMeta): def __init__(self, timing: int = 1000): self._timing = timing - def __call__(self, algorithm: 'Algorithm'): + def __call__(self, algorithm: "Algorithm"): if algorithm.iteration_num % self._timing != 0: return self.on_hook_called(algorithm) @abstractmethod - def on_hook_called(self, algorithm: 'Algorithm'): + def on_hook_called(self, algorithm: "Algorithm"): """Called every "timing" iteration which is set on Hook's instance creation. Will run additional periodical operation (see each class' documentation) during the training. @@ -49,7 +49,7 @@ def on_hook_called(self, algorithm: 'Algorithm'): """ raise NotImplementedError - def setup(self, algorithm: 'Algorithm', total_iterations: int): + def setup(self, algorithm: "Algorithm", total_iterations: int): """Called before the training starts. Args: @@ -59,7 +59,7 @@ def setup(self, algorithm: 'Algorithm', total_iterations: int): """ pass - def teardown(self, algorithm: 'Algorithm', total_iterations: int): + def teardown(self, algorithm: "Algorithm", total_iterations: int): """Called after the training ends. Args: diff --git a/nnabla_rl/hooks/computational_graph_hook.py b/nnabla_rl/hooks/computational_graph_hook.py index a10eabf2..c562723f 100644 --- a/nnabla_rl/hooks/computational_graph_hook.py +++ b/nnabla_rl/hooks/computational_graph_hook.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -72,6 +72,6 @@ def on_hook_called(self, algorithm): contents = {"networks": networks, "executors": executors} save(path, contents) - logger.info(f'Training computational graphs have been saved to {path}.') + logger.info(f"Training computational graphs have been saved to {path}.") self._saved = True diff --git a/nnabla_rl/hooks/epoch_num_hook.py b/nnabla_rl/hooks/epoch_num_hook.py index c0191ca6..2086a90f 100644 --- a/nnabla_rl/hooks/epoch_num_hook.py +++ b/nnabla_rl/hooks/epoch_num_hook.py @@ -1,5 +1,4 @@ - -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/nnabla_rl/hooks/evaluation_hook.py b/nnabla_rl/hooks/evaluation_hook.py index 639e1d98..cd133077 100644 --- a/nnabla_rl/hooks/evaluation_hook.py +++ b/nnabla_rl/hooks/evaluation_hook.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -42,27 +42,29 @@ def __init__(self, env, evaluator=EpisodicEvaluator(), timing=1000, writer=None) def on_hook_called(self, algorithm): iteration_num = algorithm.iteration_num - logger.info( - 'Starting evaluation at iteration {}.'.format(iteration_num)) + logger.info("Starting evaluation at iteration {}.".format(iteration_num)) returns = self._evaluator(algorithm, self._env) mean = np.mean(returns) std_dev = np.std(returns) median = np.median(returns) - logger.info('Evaluation results at iteration {}. mean: {} +/- std: {}, median: {}'.format( - iteration_num, mean, std_dev, median)) + logger.info( + "Evaluation results at iteration {}. mean: {} +/- std: {}, median: {}".format( + iteration_num, mean, std_dev, median + ) + ) if self._writer is not None: minimum = np.min(returns) maximum = np.max(returns) # From python 3.6 or above, the dictionary preserves insertion order scalar_results = {} - scalar_results['mean'] = mean - scalar_results['std_dev'] = std_dev - scalar_results['min'] = minimum - scalar_results['max'] = maximum - scalar_results['median'] = median + scalar_results["mean"] = mean + scalar_results["std_dev"] = std_dev + scalar_results["min"] = minimum + scalar_results["max"] = maximum + scalar_results["median"] = median self._writer.write_scalar(algorithm.iteration_num, scalar_results) histogram_results = {} - histogram_results['returns'] = returns + histogram_results["returns"] = returns self._writer.write_histogram(algorithm.iteration_num, histogram_results) diff --git a/nnabla_rl/hooks/iteration_state_hook.py b/nnabla_rl/hooks/iteration_state_hook.py index 5d4444cc..f6aba1e5 100644 --- a/nnabla_rl/hooks/iteration_state_hook.py +++ b/nnabla_rl/hooks/iteration_state_hook.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
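# Illustrative sketch (not part of the upstream diff): how the hook/writer protocol
# reformatted around here fits together. A writer only needs write_scalar /
# write_histogram / write_image methods; the import locations follow the file layout
# in the diff, and the gym environment id and set_hooks()/train() wiring are
# assumptions to be checked against the installed nnabla-rl API.
import gym
from nnabla_rl.hooks import EvaluationHook, IterationStateHook

class PrintWriter:
    # duck-typed writer, called every `timing` iterations by the hooks
    def write_scalar(self, iteration_num, scalars):
        print(iteration_num, scalars)
    def write_histogram(self, iteration_num, histograms):
        pass
    def write_image(self, iteration_num, images):
        pass

eval_env = gym.make("CartPole-v1")
hooks = [
    EvaluationHook(eval_env, timing=5000, writer=PrintWriter()),  # evaluate every 5000 iterations
    IterationStateHook(timing=100, writer=PrintWriter()),         # dump scalar losses every 100 iterations
]
# algorithm.set_hooks(hooks)                        # assumed attachment point
# algorithm.train(train_env, total_iterations=100000)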
@@ -34,19 +34,18 @@ def __init__(self, writer=None, timing=1000): self._writer = writer def on_hook_called(self, algorithm): - logger.info('Iteration state at iteration {}'.format( - algorithm.iteration_num)) + logger.info("Iteration state at iteration {}".format(algorithm.iteration_num)) latest_iteration_state = algorithm.latest_iteration_state - if 'scalar' in latest_iteration_state: - logger.info(pprint.pformat(latest_iteration_state['scalar'])) + if "scalar" in latest_iteration_state: + logger.info(pprint.pformat(latest_iteration_state["scalar"])) if self._writer is not None: for key, value in latest_iteration_state.items(): - if key == 'scalar': + if key == "scalar": self._writer.write_scalar(algorithm.iteration_num, value) - if key == 'histogram': + if key == "histogram": self._writer.write_histogram(algorithm.iteration_num, value) - if key == 'image': + if key == "image": self._writer.write_image(algorithm.iteration_num, value) diff --git a/nnabla_rl/hooks/time_measuring_hook.py b/nnabla_rl/hooks/time_measuring_hook.py index bdaf9618..afc02b8c 100644 --- a/nnabla_rl/hooks/time_measuring_hook.py +++ b/nnabla_rl/hooks/time_measuring_hook.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,7 +33,6 @@ def __init__(self, timing=1): def on_hook_called(self, algorithm): current_time = time.time() - logger.info("time spent since previous hook: {} seconds".format( - current_time - self._prev_time)) + logger.info("time spent since previous hook: {} seconds".format(current_time - self._prev_time)) self._prev_time = current_time diff --git a/nnabla_rl/initializers.py b/nnabla_rl/initializers.py index 2053c479..3924eac2 100644 --- a/nnabla_rl/initializers.py +++ b/nnabla_rl/initializers.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ import nnabla.initializer as NI -def HeNormal(inmaps, outmaps, kernel=(1, 1), factor=2.0, mode='fan_in'): +def HeNormal(inmaps, outmaps, kernel=(1, 1), factor=2.0, mode="fan_in"): """Create Weight initializer proposed by He et al. 
(Normal distribution version) @@ -34,19 +34,17 @@ def HeNormal(inmaps, outmaps, kernel=(1, 1), factor=2.0, mode='fan_in'): Raises: NotImplementedError: mode other than 'fan_in' or 'fan_out' is given """ - if mode == 'fan_in': - s = calc_normal_std_he_forward( - inmaps, outmaps, kernel, factor) - elif mode == 'fan_out': - s = calc_normal_std_he_backward( - inmaps, outmaps, kernel, factor) + if mode == "fan_in": + s = calc_normal_std_he_forward(inmaps, outmaps, kernel, factor) + elif mode == "fan_out": + s = calc_normal_std_he_backward(inmaps, outmaps, kernel, factor) else: - raise NotImplementedError('Unknown init mode: {}'.format(mode)) + raise NotImplementedError("Unknown init mode: {}".format(mode)) return NI.NormalInitializer(s) -def LeCunNormal(inmaps, outmaps, kernel=(1, 1), factor=1.0, mode='fan_in'): +def LeCunNormal(inmaps, outmaps, kernel=(1, 1), factor=1.0, mode="fan_in"): """Create Weight initializer proposed in LeCun 98, Efficient Backprop (Normal distribution version) @@ -62,15 +60,15 @@ def LeCunNormal(inmaps, outmaps, kernel=(1, 1), factor=1.0, mode='fan_in'): Raises: NotImplementedError: mode other than 'fan_in' is given """ - if mode == 'fan_in': + if mode == "fan_in": s = calc_normal_std_he_forward(inmaps, outmaps, kernel, factor) else: - raise NotImplementedError('Unknown init mode: {}'.format(mode)) + raise NotImplementedError("Unknown init mode: {}".format(mode)) return NI.NormalInitializer(s) -def HeUniform(inmaps, outmaps, kernel=(1, 1), factor=2.0, mode='fan_in', rng=None): +def HeUniform(inmaps, outmaps, kernel=(1, 1), factor=2.0, mode="fan_in", rng=None): """Create Weight initializer proposed by He et al. (Uniform distribution version) @@ -86,12 +84,12 @@ def HeUniform(inmaps, outmaps, kernel=(1, 1), factor=2.0, mode='fan_in', rng=Non Raises: NotImplementedError: mode other than 'fan_in' or 'fan_out' is given """ - if mode == 'fan_in': + if mode == "fan_in": lim = calc_uniform_lim_he_forward(inmaps, outmaps, kernel, factor) - elif mode == 'fan_out': + elif mode == "fan_out": lim = calc_uniform_lim_he_backward(inmaps, outmaps, kernel, factor) else: - raise NotImplementedError('Unknown init mode: {}'.format(mode)) + raise NotImplementedError("Unknown init mode: {}".format(mode)) return NI.UniformInitializer(lim=(-lim, lim), rng=rng) @@ -102,7 +100,7 @@ def GlorotUniform(inmaps, outmaps, kernel=(1, 1), rng=None): class NormcInitializer(NI.BaseInitializer): - ''' Create Normc initializer + """Create Normc initializer See: https://github.com/openai/baselines/blob/master/baselines/common/tf_util.py Initializes the parameter which normalized along 'axis' dimension. @@ -114,7 +112,7 @@ class NormcInitializer(NI.BaseInitializer): When None, nnabla's default random nunmber generator will be used. Returns: NormcInitializer : weight initialzier - ''' + """ def __init__(self, std=1.0, axis=0, rng=None): if rng is None: diff --git a/nnabla_rl/logger.py b/nnabla_rl/logger.py index 2ea9d2a5..c05a07df 100644 --- a/nnabla_rl/logger.py +++ b/nnabla_rl/logger.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
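# Illustrative sketch (not part of the upstream diff): using the initializer factories
# reformatted above as w_init of an nnabla parametric function. Layer sizes and the
# RI/NPF import aliases are arbitrary choices, not library conventions.
import nnabla as nn
import nnabla.parametric_functions as NPF
import nnabla_rl.initializers as RI

x = nn.Variable((32, 64))                                   # batch of 64-dim inputs
w_init = RI.HeNormal(inmaps=64, outmaps=256, mode="fan_in")
with nn.parameter_scope("fc1"):
    h = NPF.affine(x, 256, w_init=w_init)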
@@ -37,7 +37,7 @@ def level(self): return self.logger.level -logger = TqdmAdapter(logging.getLogger('nnabla_rl'), {}) +logger = TqdmAdapter(logging.getLogger("nnabla_rl"), {}) logger.disabled = False diff --git a/nnabla_rl/model_trainers/decision_transformer/__init__.py b/nnabla_rl/model_trainers/decision_transformer/__init__.py index 46e57247..77c8bc38 100644 --- a/nnabla_rl/model_trainers/decision_transformer/__init__.py +++ b/nnabla_rl/model_trainers/decision_transformer/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,5 +13,8 @@ # limitations under the License. from nnabla_rl.model_trainers.decision_transformer.decision_transformer_trainer import ( # noqa - StochasticDecisionTransformerTrainer, StochasticDecisionTransformerTrainerConfig, - DeterministicDecisionTransformerTrainer, DeterministicDecisionTransformerTrainerConfig) + StochasticDecisionTransformerTrainer, + StochasticDecisionTransformerTrainerConfig, + DeterministicDecisionTransformerTrainer, + DeterministicDecisionTransformerTrainerConfig, +) diff --git a/nnabla_rl/model_trainers/decision_transformer/decision_transformer_trainer.py b/nnabla_rl/model_trainers/decision_transformer/decision_transformer_trainer.py index 097fa7c8..6eb6aa47 100644 --- a/nnabla_rl/model_trainers/decision_transformer/decision_transformer_trainer.py +++ b/nnabla_rl/model_trainers/decision_transformer/decision_transformer_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
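# Illustrative sketch (not part of the upstream diff): the module logger touched above
# wraps a standard `logging` logger named "nnabla_rl", so ordinary logging configuration
# controls its verbosity (a minimal, library-agnostic setup).
import logging

logging.basicConfig(level=logging.INFO)
logging.getLogger("nnabla_rl").setLevel(logging.INFO)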
@@ -46,33 +46,38 @@ class StochasticDecisionTransformerTrainerConfig(DecisionTransformerTrainerConfi class DecisionTransformerTrainer(ModelTrainer): """Decision transformer trainer for Stochastic environment.""" + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See https://mypy.readthedocs.io/en/stable/class_basics.html for details _config: DecisionTransformerTrainerConfig _pi_loss: nn.Variable - def __init__(self, - models: Union[DecisionTransformerModel, Sequence[DecisionTransformerModel]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - wd_solvers: Optional[Dict[str, nn.solver.Solver]], - config: DecisionTransformerTrainerConfig): + def __init__( + self, + models: Union[DecisionTransformerModel, Sequence[DecisionTransformerModel]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + wd_solvers: Optional[Dict[str, nn.solver.Solver]], + config: DecisionTransformerTrainerConfig, + ): self._wd_solvers = {} if wd_solvers is None else wd_solvers super(DecisionTransformerTrainer, self).__init__(models, solvers, env_info, config) - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) set_data_to_variable(t.a_current, b.a_current) - set_data_to_variable(t.extra['timesteps'], b.extra['timesteps']) - set_data_to_variable(t.extra['rtg'], b.extra['rtg']) - set_data_to_variable(t.extra['target'], b.extra['target']) + set_data_to_variable(t.extra["timesteps"], b.extra["timesteps"]) + set_data_to_variable(t.extra["rtg"], b.extra["rtg"]) + set_data_to_variable(t.extra["target"], b.extra["target"]) # update model for solver in solvers.values(): @@ -87,7 +92,7 @@ def _update_model(self, wd_solver.update() trainer_state = {} - trainer_state['loss'] = self._pi_loss.d.copy() + trainer_state["loss"] = self._pi_loss.d.copy() return trainer_state def _setup_training_variables(self, batch_size) -> TrainingVariables: @@ -99,14 +104,15 @@ def _setup_training_variables(self, batch_size) -> TrainingVariables: rtg_var = create_variable(batch_size, (self._config.context_length, 1)) extra = {} - extra['target'] = target_var - extra['timesteps'] = timesteps_var - extra['rtg'] = rtg_var + extra["target"] = target_var + extra["timesteps"] = timesteps_var + extra["rtg"] = rtg_var return TrainingVariables(batch_size, s_current_var, a_current_var, extra=extra) def _setup_solver(self): def _should_decay(param_key): - return 'affine/W' in param_key or 'conv/W' in param_key + return "affine/W" in param_key or "conv/W" in param_key + for model in self._models: if model.scope_name in self._wd_solvers.keys(): solver = self._solvers[model.scope_name] @@ -151,12 +157,14 @@ class StochasticDecisionTransformerTrainer(DecisionTransformerTrainer): _config: StochasticDecisionTransformerTrainerConfig _pi_loss: nn.Variable - def __init__(self, - models: Union[StochasticDecisionTransformer, Sequence[StochasticDecisionTransformer]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - wd_solvers: Optional[Dict[str, nn.solver.Solver]] = None, - config: 
StochasticDecisionTransformerTrainerConfig = StochasticDecisionTransformerTrainerConfig()): + def __init__( + self, + models: Union[StochasticDecisionTransformer, Sequence[StochasticDecisionTransformer]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + wd_solvers: Optional[Dict[str, nn.solver.Solver]] = None, + config: StochasticDecisionTransformerTrainerConfig = StochasticDecisionTransformerTrainerConfig(), + ): super(StochasticDecisionTransformerTrainer, self).__init__(models, solvers, env_info, wd_solvers, config) def _build_training_graph(self, models: Sequence[Model], training_variables: TrainingVariables): @@ -165,9 +173,9 @@ def _build_training_graph(self, models: Sequence[Model], training_variables: Tra for policy in models: s = training_variables.s_current a = training_variables.a_current - rtg = training_variables.extra['rtg'] - timesteps = training_variables.extra['timesteps'] - target = training_variables.extra['target'] + rtg = training_variables.extra["rtg"] + timesteps = training_variables.extra["timesteps"] + target = training_variables.extra["target"] distribution = policy.pi(s, a, rtg, timesteps) # This loss calculation should be same as cross entropy loss loss = -distribution.log_prob(target) @@ -183,13 +191,14 @@ class DeterministicDecisionTransformerTrainer(DecisionTransformerTrainer): _config: DeterministicDecisionTransformerTrainerConfig _pi_loss: nn.Variable - def __init__(self, - models: Union[DeterministicDecisionTransformer, Sequence[DeterministicDecisionTransformer]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - wd_solvers: Optional[Dict[str, nn.solver.Solver]] = None, - config: DeterministicDecisionTransformerTrainerConfig = - DeterministicDecisionTransformerTrainerConfig()): + def __init__( + self, + models: Union[DeterministicDecisionTransformer, Sequence[DeterministicDecisionTransformer]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + wd_solvers: Optional[Dict[str, nn.solver.Solver]] = None, + config: DeterministicDecisionTransformerTrainerConfig = DeterministicDecisionTransformerTrainerConfig(), + ): super(DeterministicDecisionTransformerTrainer, self).__init__(models, solvers, env_info, wd_solvers, config) def _build_training_graph(self, models: Sequence[Model], training_variables: TrainingVariables): @@ -198,9 +207,9 @@ def _build_training_graph(self, models: Sequence[Model], training_variables: Tra for policy in models: s = training_variables.s_current a = training_variables.a_current - rtg = training_variables.extra['rtg'] - timesteps = training_variables.extra['timesteps'] - target = training_variables.extra['target'] + rtg = training_variables.extra["rtg"] + timesteps = training_variables.extra["timesteps"] + target = training_variables.extra["target"] actions = policy.pi(s, a, rtg, timesteps) loss = RF.mean_squared_error(actions, target) self._pi_loss += loss diff --git a/nnabla_rl/model_trainers/dynamics/__init__.py b/nnabla_rl/model_trainers/dynamics/__init__.py index 9e9c2d91..66fc24fd 100644 --- a/nnabla_rl/model_trainers/dynamics/__init__.py +++ b/nnabla_rl/model_trainers/dynamics/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,4 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
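# Illustrative sketch (not part of the upstream diff): the parameter split used by the
# decision transformer trainer above applies weight decay only to affine/convolution
# weight matrices; the example parameter keys below are hypothetical.
def should_decay(param_key: str) -> bool:
    return "affine/W" in param_key or "conv/W" in param_key

param_keys = ["block1/attention/affine/W", "block1/attention/affine/b", "token_embed/W"]
decayed = [k for k in param_keys if should_decay(k)]        # ['block1/attention/affine/W']
not_decayed = [k for k in param_keys if not should_decay(k)]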
-from nnabla_rl.model_trainers.dynamics.mppi_dynamics_trainer import MPPIDynamicsTrainerConfig, MPPIDynamicsTrainer # noqa +from nnabla_rl.model_trainers.dynamics.mppi_dynamics_trainer import ( # noqa + MPPIDynamicsTrainerConfig, + MPPIDynamicsTrainer, +) diff --git a/nnabla_rl/model_trainers/dynamics/mppi_dynamics_trainer.py b/nnabla_rl/model_trainers/dynamics/mppi_dynamics_trainer.py index 34955e1a..67f07ae3 100644 --- a/nnabla_rl/model_trainers/dynamics/mppi_dynamics_trainer.py +++ b/nnabla_rl/model_trainers/dynamics/mppi_dynamics_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,8 +19,14 @@ import nnabla as nn import nnabla_rl.functions as RF -from nnabla_rl.model_trainers.model_trainer import (LossIntegration, ModelTrainer, TrainerConfig, TrainingBatch, - TrainingVariables, rnn_support) +from nnabla_rl.model_trainers.model_trainer import ( + LossIntegration, + ModelTrainer, + TrainerConfig, + TrainingBatch, + TrainingVariables, + rnn_support, +) from nnabla_rl.models import DeterministicDynamics, Model from nnabla_rl.utils.data import convert_to_list_if_not_list, set_data_to_variable from nnabla_rl.utils.misc import create_variable, create_variables @@ -32,7 +38,7 @@ class MPPIDynamicsTrainerConfig(TrainerConfig): def __post_init__(self): super().__post_init__() - self._assert_positive(self.dt, 'dt') + self._assert_positive(self.dt, "dt") class MPPIDynamicsTrainer(ModelTrainer): @@ -43,23 +49,27 @@ class MPPIDynamicsTrainer(ModelTrainer): _loss: nn.Variable _prev_dynamics_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - models: Union[DeterministicDynamics, Sequence[DeterministicDynamics]], - solvers: Dict[str, nn.solver.Solver], - env_info, - config: MPPIDynamicsTrainerConfig = MPPIDynamicsTrainerConfig()): + def __init__( + self, + models: Union[DeterministicDynamics, Sequence[DeterministicDynamics]], + solvers: Dict[str, nn.solver.Solver], + env_info, + config: MPPIDynamicsTrainerConfig = MPPIDynamicsTrainerConfig(), + ): self._prev_dynamics_rnn_states = {} super(MPPIDynamicsTrainer, self).__init__(models, solvers, env_info, config) def support_rnn(self) -> bool: return True - def _update_model(self, - models: Iterable[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Iterable[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) set_data_to_variable(t.a_current, b.a_current) @@ -88,11 +98,10 @@ def _update_model(self, solver.update() trainer_state: Dict[str, np.ndarray] = {} - trainer_state['dynamics_loss'] = self._loss.d.copy() + trainer_state["dynamics_loss"] = self._loss.d.copy() return trainer_state - def _build_training_graph(self, models: Union[Model, Sequence[Model]], - training_variables: TrainingVariables): + def _build_training_graph(self, models: Union[Model, Sequence[Model]], training_variables: TrainingVariables): models = convert_to_list_if_not_list(models) models = cast(Sequence[DeterministicDynamics], models) @@ -104,10 +113,7 @@ def _build_training_graph(self, models: Union[Model, Sequence[Model]], ignore_loss 
= is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - def _build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool): models = cast(Sequence[DeterministicDynamics], models) train_rnn_states = training_variables.rnn_states for model in models: @@ -117,8 +123,8 @@ def _build_one_step_graph(self, s_current = cast(nn.Variable, training_variables.s_current) s_next = cast(nn.Variable, training_variables.s_next) - q_dot = s_current[:, self._env_info.state_dim // 2:] - q_dot_next = s_next[:, self._env_info.state_dim // 2:] + q_dot = s_current[:, self._env_info.state_dim // 2 :] + q_dot_next = s_next[:, self._env_info.state_dim // 2 :] target_a = (q_dot_next - q_dot) / self._config.dt target_a.need_grad = False loss = RF.mean_squared_error(predicted_a, target_a) @@ -136,9 +142,14 @@ def _setup_training_variables(self, batch_size): rnn_state_variables = create_variables(batch_size, model.internal_state_shapes()) rnn_states[model.scope_name] = rnn_state_variables - return TrainingVariables(batch_size, s_current_var, a_current_var, s_next=s_next_var, - non_terminal=non_terminal_var, - rnn_states=rnn_states) + return TrainingVariables( + batch_size, + s_current_var, + a_current_var, + s_next=s_next_var, + non_terminal=non_terminal_var, + rnn_states=rnn_states, + ) @property def loss_variables(self) -> Dict[str, nn.Variable]: diff --git a/nnabla_rl/model_trainers/encoder/__init__.py b/nnabla_rl/model_trainers/encoder/__init__.py index c1b6af53..354eef9d 100644 --- a/nnabla_rl/model_trainers/encoder/__init__.py +++ b/nnabla_rl/model_trainers/encoder/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nnabla_rl.model_trainers.encoder.hyar_vae_trainer import ( # noqa - HyARVAETrainer, HyARVAETrainerConfig) +from nnabla_rl.model_trainers.encoder.hyar_vae_trainer import HyARVAETrainer, HyARVAETrainerConfig # noqa from nnabla_rl.model_trainers.encoder.kld_variational_auto_encoder_trainer import ( # noqa - KLDVariationalAutoEncoderTrainer, KLDVariationalAutoEncoderTrainerConfig) + KLDVariationalAutoEncoderTrainer, + KLDVariationalAutoEncoderTrainerConfig, +) diff --git a/nnabla_rl/model_trainers/encoder/hyar_vae_trainer.py b/nnabla_rl/model_trainers/encoder/hyar_vae_trainer.py index 0062cb59..80704510 100644 --- a/nnabla_rl/model_trainers/encoder/hyar_vae_trainer.py +++ b/nnabla_rl/model_trainers/encoder/hyar_vae_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
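# Illustrative sketch (not part of the upstream diff): the finite-difference target built
# by the MPPI dynamics trainer above. The second half of the state vector is read as
# velocities, and the regression target for the predicted acceleration is their change
# divided by the control period dt. All values below are made up.
import numpy as np

dt = 0.05
state_dim = 4                                    # e.g. [q1, q2, q1_dot, q2_dot]
s_current = np.array([[0.10, 0.20, 1.00, -0.50]])
s_next = np.array([[0.15, 0.17, 1.10, -0.65]])

q_dot = s_current[:, state_dim // 2:]
q_dot_next = s_next[:, state_dim // 2:]
target_acceleration = (q_dot_next - q_dot) / dt  # [[2.0, -3.0]]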
@@ -23,8 +23,13 @@ import nnabla_rl.functions as RNF from nnabla_rl.distributions import Gaussian from nnabla_rl.environments.environment_info import EnvironmentInfo -from nnabla_rl.model_trainers.model_trainer import (LossIntegration, ModelTrainer, TrainerConfig, TrainingBatch, - TrainingVariables) +from nnabla_rl.model_trainers.model_trainer import ( + LossIntegration, + ModelTrainer, + TrainerConfig, + TrainingBatch, + TrainingVariables, +) from nnabla_rl.models import HyARVAE, Model from nnabla_rl.utils.data import set_data_to_variable from nnabla_rl.utils.misc import create_variable @@ -42,19 +47,23 @@ class HyARVAETrainer(ModelTrainer): _config: HyARVAETrainerConfig _encoder_loss: nn.Variable # Training loss/output - def __init__(self, - models: Union[HyARVAE, Sequence[HyARVAE]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - config: HyARVAETrainerConfig = HyARVAETrainerConfig()): + def __init__( + self, + models: Union[HyARVAE, Sequence[HyARVAE]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + config: HyARVAETrainerConfig = HyARVAETrainerConfig(), + ): super().__init__(models, solvers, env_info, config) - def _update_model(self, - models: Iterable[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Iterable[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) set_data_to_variable(t.a_current, b.a_current) @@ -69,10 +78,10 @@ def _update_model(self, solver.update() trainer_state = {} - trainer_state['encoder_loss'] = self._encoder_loss.d.copy() - trainer_state['kl_loss'] = self._kl_loss.d.copy() - trainer_state['reconstruction_loss'] = self._reconstruction_loss.d.copy() - trainer_state['dyn_loss'] = self._dyn_loss.d.copy() + trainer_state["encoder_loss"] = self._encoder_loss.d.copy() + trainer_state["kl_loss"] = self._kl_loss.d.copy() + trainer_state["reconstruction_loss"] = self._reconstruction_loss.d.copy() + trainer_state["dyn_loss"] = self._dyn_loss.d.copy() return trainer_state def _build_training_graph(self, models: Union[Model, Sequence[Model]], training_variables: TrainingVariables): @@ -85,10 +94,7 @@ def _build_training_graph(self, models: Union[Model, Sequence[Model]], training_ ignore_loss = is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - def _build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool): batch_size = training_variables.batch_size models = cast(Sequence[HyARVAE], models) @@ -96,14 +102,14 @@ def _build_one_step_graph(self, action1, action2 = training_variables.a_current action_space = cast(gym.spaces.Tuple, self._env_info.action_space) xk = action1 if isinstance(action_space[0], gym.spaces.Box) else action2 - latent_distribution, (xk_tilde, ds_tilde) = vae.encode_and_decode(x=xk, - state=training_variables.s_current, - action=training_variables.a_current) + latent_distribution, (xk_tilde, ds_tilde) = vae.encode_and_decode( + x=xk, state=training_variables.s_current, action=training_variables.a_current + ) latent_shape 
= (batch_size, latent_distribution.ndim) target_latent_distribution = Gaussian( mean=nn.Variable.from_numpy_array(np.zeros(shape=latent_shape, dtype=np.float32)), - ln_var=nn.Variable.from_numpy_array(np.zeros(shape=latent_shape, dtype=np.float32)) + ln_var=nn.Variable.from_numpy_array(np.zeros(shape=latent_shape, dtype=np.float32)), ) reconstruction_loss = RNF.mean_squared_error(xk, xk_tilde) diff --git a/nnabla_rl/model_trainers/encoder/kld_variational_auto_encoder_trainer.py b/nnabla_rl/model_trainers/encoder/kld_variational_auto_encoder_trainer.py index 21f3090f..129e931e 100644 --- a/nnabla_rl/model_trainers/encoder/kld_variational_auto_encoder_trainer.py +++ b/nnabla_rl/model_trainers/encoder/kld_variational_auto_encoder_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,8 +22,13 @@ import nnabla_rl.functions as RNF from nnabla_rl.distributions import Gaussian from nnabla_rl.environments.environment_info import EnvironmentInfo -from nnabla_rl.model_trainers.model_trainer import (LossIntegration, ModelTrainer, TrainerConfig, TrainingBatch, - TrainingVariables) +from nnabla_rl.model_trainers.model_trainer import ( + LossIntegration, + ModelTrainer, + TrainerConfig, + TrainingBatch, + TrainingVariables, +) from nnabla_rl.models import Model, VariationalAutoEncoder from nnabla_rl.utils.data import set_data_to_variable from nnabla_rl.utils.misc import create_variable @@ -41,19 +46,23 @@ class KLDVariationalAutoEncoderTrainer(ModelTrainer): _config: KLDVariationalAutoEncoderTrainerConfig _encoder_loss: nn.Variable # Training loss/output - def __init__(self, - models: Union[VariationalAutoEncoder, Sequence[VariationalAutoEncoder]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - config: KLDVariationalAutoEncoderTrainerConfig = KLDVariationalAutoEncoderTrainerConfig()): + def __init__( + self, + models: Union[VariationalAutoEncoder, Sequence[VariationalAutoEncoder]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + config: KLDVariationalAutoEncoderTrainerConfig = KLDVariationalAutoEncoderTrainerConfig(), + ): super(KLDVariationalAutoEncoderTrainer, self).__init__(models, solvers, env_info, config) - def _update_model(self, - models: Iterable[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Iterable[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) set_data_to_variable(t.a_current, b.a_current) @@ -67,7 +76,7 @@ def _update_model(self, solver.update() trainer_state = {} - trainer_state['encoder_loss'] = self._encoder_loss.d.copy() + trainer_state["encoder_loss"] = self._encoder_loss.d.copy() return trainer_state def _build_training_graph(self, models: Union[Model, Sequence[Model]], training_variables: TrainingVariables): @@ -80,21 +89,19 @@ def _build_training_graph(self, models: Union[Model, Sequence[Model]], training_ ignore_loss = is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - def 
_build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool): batch_size = training_variables.batch_size models = cast(Sequence[VariationalAutoEncoder], models) for vae in models: - latent_distribution, reconstructed_action = vae.encode_and_decode(training_variables.s_current, - action=training_variables.a_current) + latent_distribution, reconstructed_action = vae.encode_and_decode( + training_variables.s_current, action=training_variables.a_current + ) latent_shape = (batch_size, latent_distribution.ndim) target_latent_distribution = Gaussian( mean=nn.Variable.from_numpy_array(np.zeros(shape=latent_shape, dtype=np.float32)), - ln_var=nn.Variable.from_numpy_array(np.zeros(shape=latent_shape, dtype=np.float32)) + ln_var=nn.Variable.from_numpy_array(np.zeros(shape=latent_shape, dtype=np.float32)), ) reconstruction_loss = RNF.mean_squared_error(training_variables.a_current, reconstructed_action) diff --git a/nnabla_rl/model_trainers/hybrid/__init__.py b/nnabla_rl/model_trainers/hybrid/__init__.py index 4dfee5c3..41eecfba 100644 --- a/nnabla_rl/model_trainers/hybrid/__init__.py +++ b/nnabla_rl/model_trainers/hybrid/__init__.py @@ -12,4 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nnabla_rl.model_trainers.hybrid.srsac_actor_critic_trainer import SRSACActorCriticTrainer, SRSACActorCriticTrainerConfig # noqa +from nnabla_rl.model_trainers.hybrid.srsac_actor_critic_trainer import ( # noqa + SRSACActorCriticTrainer, + SRSACActorCriticTrainerConfig, +) diff --git a/nnabla_rl/model_trainers/hybrid/srsac_actor_critic_trainer.py b/nnabla_rl/model_trainers/hybrid/srsac_actor_critic_trainer.py index 4ae0dae4..b5458a0f 100644 --- a/nnabla_rl/model_trainers/hybrid/srsac_actor_critic_trainer.py +++ b/nnabla_rl/model_trainers/hybrid/srsac_actor_critic_trainer.py @@ -42,6 +42,7 @@ def __post_init__(self): class SRSACActorCriticTrainer(ModelTrainer): """Efficient implementation of SAC style training that trains a policy and a q-function in parallel.""" + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See https://mypy.readthedocs.io/en/stable/class_basics.html for details @@ -56,20 +57,22 @@ class SRSACActorCriticTrainer(ModelTrainer): _temperature_losses: Sequence[nn.Variable] _config: SRSACActorCriticTrainerConfig - def __init__(self, - pi: StochasticPolicy, - pi_solvers: Dict[str, nn.solver.Solver], - q_functions: Tuple[QFunction, QFunction], - q_solvers: Dict[str, nn.solver.Solver], - target_q_functions: Tuple[QFunction, QFunction], - temperature: AdjustableTemperature, - temperature_solver: Optional[nn.solver.Solver], - env_info: EnvironmentInfo, - config: SRSACActorCriticTrainerConfig = SRSACActorCriticTrainerConfig()): + def __init__( + self, + pi: StochasticPolicy, + pi_solvers: Dict[str, nn.solver.Solver], + q_functions: Tuple[QFunction, QFunction], + q_solvers: Dict[str, nn.solver.Solver], + target_q_functions: Tuple[QFunction, QFunction], + temperature: AdjustableTemperature, + temperature_solver: Optional[nn.solver.Solver], + env_info: EnvironmentInfo, + config: SRSACActorCriticTrainerConfig = SRSACActorCriticTrainerConfig(), + ): if len(q_functions) != 2: - raise ValueError('Two q functions should be provided') + raise ValueError("Two q functions 
should be provided") if not config.fixed_temperature and temperature_solver is None: - raise ValueError('Please set solver for temperature model') + raise ValueError("Please set solver for temperature model") self._pi_solver = pi_solvers[pi.scope_name] self._q_functions = q_functions self._q_solvers = [q_solvers[q_function.scope_name] for q_function in q_functions] @@ -90,12 +93,14 @@ def support_rnn(self) -> bool: def _total_timesteps(self) -> int: return self._config.replay_ratio - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) set_data_to_variable(t.a_current, b.a_current) @@ -113,10 +118,10 @@ def _update_model(self, self._pi_training(pi_loss, self._pi_solver, temperature_loss, self._temperature_solver) trainer_state = {} - trainer_state['pi_loss'] = self._pi_losses[-1].d.copy() - trainer_state['td_errors'] = self._td_errors[-1].d.copy() - trainer_state['q_loss'] = self._q_losses[-1].d.copy() - trainer_state['temperature_loss'] = self._temperature_losses[-1].d.copy() + trainer_state["pi_loss"] = self._pi_losses[-1].d.copy() + trainer_state["td_errors"] = self._td_errors[-1].d.copy() + trainer_state["q_loss"] = self._q_losses[-1].d.copy() + trainer_state["temperature_loss"] = self._temperature_losses[-1].d.copy() return trainer_state def _q_training(self, q_loss, q_solvers): @@ -145,9 +150,7 @@ def get_temperature(self) -> nn.Variable: # Will return exponentiated log temperature. 
To keep temperature always positive return self._temperature() - def _build_training_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables): + def _build_training_graph(self, models: Sequence[Model], training_variables: TrainingVariables): self._q_losses = [] self._td_errors = [] self._pi_losses = [] @@ -159,9 +162,7 @@ def _build_training_graph(self, self._pi_losses.append(pi_loss) self._temperature_losses.append(temperature_loss) - def _build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables): policy = cast(StochasticPolicy, models[0]) q_functions = cast(Tuple[QFunction, QFunction], models[1:]) target_q_functions = self._target_q_functions @@ -174,11 +175,13 @@ def _build_one_step_graph(self, pi_loss, temperature_loss = self._build_policy_training_graph(policy, q_functions, training_variables) return q_loss, td_error, pi_loss, temperature_loss - def _build_q_training_graph(self, - q_functions: Sequence[QFunction], - target_q_functions: Sequence[QFunction], - target_policy: StochasticPolicy, - training_variables: TrainingVariables) -> Tuple[nn.Variable, nn.Variable]: + def _build_q_training_graph( + self, + q_functions: Sequence[QFunction], + target_q_functions: Sequence[QFunction], + target_policy: StochasticPolicy, + training_variables: TrainingVariables, + ) -> Tuple[nn.Variable, nn.Variable]: # NOTE: Target q value depends on underlying implementation target_q = self._compute_q_target(target_q_functions, target_policy, training_variables) target_q.need_grad = False @@ -189,15 +192,14 @@ def _build_q_training_graph(self, q_loss += loss # FIXME: using the last q function's td error for prioritized replay. Is this fine? 
- td_error = extra['td_error'] + td_error = extra["td_error"] td_error.persistent = True return q_loss, td_error - def _compute_squared_td_loss(self, - model: QFunction, - target_q: nn.Variable, - training_variables: TrainingVariables) -> Tuple[nn.Variable, Dict[str, nn.Variable]]: + def _compute_squared_td_loss( + self, model: QFunction, target_q: nn.Variable, training_variables: TrainingVariables + ) -> Tuple[nn.Variable, Dict[str, nn.Variable]]: s_current = training_variables.s_current a_current = training_variables.a_current @@ -206,14 +208,16 @@ def _compute_squared_td_loss(self, squared_td_error = training_variables.weight * NF.pow_scalar(td_error, 2.0) q_loss = NF.mean(squared_td_error) - extra = {'td_error': td_error} + extra = {"td_error": td_error} return q_loss, extra - def _compute_q_target(self, - target_q_functions: Sequence[QFunction], - target_policy: StochasticPolicy, - training_variables: TrainingVariables, - **kwargs) -> nn.Variable: + def _compute_q_target( + self, + target_q_functions: Sequence[QFunction], + target_policy: StochasticPolicy, + training_variables: TrainingVariables, + **kwargs, + ) -> nn.Variable: gamma = training_variables.gamma reward = training_variables.reward non_terminal = training_variables.non_terminal @@ -230,9 +234,9 @@ def _compute_q_target(self, target_q = RF.minimum_n(q_values) return reward + gamma * non_terminal * (target_q - self.get_temperature() * log_pi) - def _build_policy_training_graph(self, policy: StochasticPolicy, - q_functions: Sequence[QFunction], - training_variables: TrainingVariables) -> Tuple[nn.Variable, nn.Variable]: + def _build_policy_training_graph( + self, policy: StochasticPolicy, q_functions: Sequence[QFunction], training_variables: TrainingVariables + ) -> Tuple[nn.Variable, nn.Variable]: # Actor optimization graph policy_distribution = policy.pi(training_variables.s_current) action_var, log_pi = policy_distribution.sample_and_compute_log_prob() @@ -258,14 +262,16 @@ def _setup_training_variables(self, batch_size): non_terminal_var = create_variable(batch_size, 1) weight_var = create_variable(batch_size, 1) - training_variables = TrainingVariables(batch_size=batch_size, - s_current=s_current_var, - a_current=a_current_var, - reward=reward_var, - gamma=gamma_var, - non_terminal=non_terminal_var, - s_next=s_next_var, - weight=weight_var) + training_variables = TrainingVariables( + batch_size=batch_size, + s_current=s_current_var, + a_current=a_current_var, + reward=reward_var, + gamma=gamma_var, + non_terminal=non_terminal_var, + s_next=s_next_var, + weight=weight_var, + ) return training_variables def _setup_solver(self): @@ -275,6 +281,8 @@ def _setup_solver(self): @property def loss_variables(self) -> Dict[str, nn.Variable]: - return {"pi_loss": self._pi_losses[-1], - "q_loss": self._q_losses[-1], - "temperature_loss": self._temperature_losses[-1]} + return { + "pi_loss": self._pi_losses[-1], + "q_loss": self._q_losses[-1], + "temperature_loss": self._temperature_losses[-1], + } diff --git a/nnabla_rl/model_trainers/model_trainer.py b/nnabla_rl/model_trainers/model_trainer.py index f05a91e6..ed79840b 100644 --- a/nnabla_rl/model_trainers/model_trainer.py +++ b/nnabla_rl/model_trainers/model_trainer.py @@ -30,19 +30,23 @@ @contextlib.contextmanager -def rnn_support(model: Model, - prev_rnn_states: Dict[str, Dict[str, nn.Variable]], - train_rnn_states: Dict[str, Dict[str, nn.Variable]], - training_variables: 'TrainingVariables', - config: 'TrainerConfig'): +def rnn_support( + model: Model, + prev_rnn_states: 
Dict[str, Dict[str, nn.Variable]], + train_rnn_states: Dict[str, Dict[str, nn.Variable]], + training_variables: "TrainingVariables", + config: "TrainerConfig", +): def stop_backprop(rnn_states): for value in rnn_states.values(): value.need_grad = False + try: if model.is_recurrent(): scope_name = model.scope_name internal_states = retrieve_internal_states( - scope_name, prev_rnn_states, train_rnn_states, training_variables, config.reset_on_terminal) + scope_name, prev_rnn_states, train_rnn_states, training_variables, config.reset_on_terminal + ) model.set_internal_states(internal_states) yield finally: @@ -54,13 +58,14 @@ def stop_backprop(rnn_states): class LossIntegration(Enum): - ALL_TIMESTEPS = 1, 'Computed loss is summed over all timesteps' - LAST_TIMESTEP_ONLY = 2, 'Only the last timestep\'s loss is used.' + ALL_TIMESTEPS = 1, "Computed loss is summed over all timesteps" + LAST_TIMESTEP_ONLY = 2, "Only the last timestep's loss is used." @dataclass class TrainerConfig(Configuration): """Configuration class for ModelTrainer.""" + unroll_steps: int = 1 burn_in_steps: int = 0 reset_on_terminal: bool = True # Reset internal rnn state to given state if previous state is terminal. @@ -68,11 +73,11 @@ class TrainerConfig(Configuration): def __post_init__(self): super(TrainerConfig, self).__post_init__() - self._assert_positive(self.unroll_steps, 'unroll_steps') - self._assert_positive_or_zero(self.burn_in_steps, 'burn_in_steps') + self._assert_positive(self.unroll_steps, "unroll_steps") + self._assert_positive_or_zero(self.burn_in_steps, "burn_in_steps") -class TrainingBatch(): +class TrainingBatch: """Mini-Batch class for train. Args: @@ -89,6 +94,7 @@ class TrainingBatch(): the mini-batch for next step (used in n-step learning) rnn_states (Dict[str, Dict[str, np.array]]): the rnn internal state values """ + batch_size: int s_current: Union[np.ndarray, Tuple[np.ndarray, ...]] a_current: np.ndarray @@ -99,21 +105,23 @@ class TrainingBatch(): weight: np.ndarray extra: Dict[str, np.ndarray] # Used in n-step/rnn learning - next_step_batch: Optional['TrainingBatch'] + next_step_batch: Optional["TrainingBatch"] rnn_states: Dict[str, Dict[str, np.ndarray]] - def __init__(self, - batch_size: int, - s_current: Optional[Union[np.ndarray, Tuple[np.ndarray, ...]]] = None, - a_current: Optional[np.ndarray] = None, - reward: Optional[np.ndarray] = None, - gamma: Optional[float] = None, - non_terminal: Optional[np.ndarray] = None, - s_next: Optional[Union[np.ndarray, Tuple[np.ndarray, ...]]] = None, - weight: Optional[np.ndarray] = None, - extra: Optional[Dict[str, np.ndarray]] = None, - next_step_batch: Optional['TrainingBatch'] = None, - rnn_states: Optional[Dict[str, Dict[str, np.ndarray]]] = None): + def __init__( + self, + batch_size: int, + s_current: Optional[Union[np.ndarray, Tuple[np.ndarray, ...]]] = None, + a_current: Optional[np.ndarray] = None, + reward: Optional[np.ndarray] = None, + gamma: Optional[float] = None, + non_terminal: Optional[np.ndarray] = None, + s_next: Optional[Union[np.ndarray, Tuple[np.ndarray, ...]]] = None, + weight: Optional[np.ndarray] = None, + extra: Optional[Dict[str, np.ndarray]] = None, + next_step_batch: Optional["TrainingBatch"] = None, + rnn_states: Optional[Dict[str, Dict[str, np.ndarray]]] = None, + ): assert 0 < batch_size self.batch_size = batch_size if s_current is not None: @@ -159,7 +167,7 @@ def __len__(self): return num_steps -class TrainingVariables(): +class TrainingVariables: batch_size: int s_current: Union[nn.Variable, Tuple[nn.Variable, ...]] 
a_current: nn.Variable @@ -172,21 +180,23 @@ class TrainingVariables(): rnn_states: Dict[str, Dict[str, nn.Variable]] # Used in rnn learning - _next_step_variables: Optional['TrainingVariables'] - _prev_step_variables: Optional['TrainingVariables'] - - def __init__(self, - batch_size: int, - s_current: Optional[Union[nn.Variable, Tuple[nn.Variable, ...]]] = None, - a_current: Optional[nn.Variable] = None, - reward: Optional[nn.Variable] = None, - gamma: Optional[nn.Variable] = None, - non_terminal: Optional[nn.Variable] = None, - s_next: Optional[Union[nn.Variable, Tuple[nn.Variable, ...]]] = None, - weight: Optional[nn.Variable] = None, - extra: Optional[Dict[str, nn.Variable]] = None, - next_step_variables: Optional["TrainingVariables"] = None, - rnn_states: Optional[Dict[str, Dict[str, nn.Variable]]] = None): + _next_step_variables: Optional["TrainingVariables"] + _prev_step_variables: Optional["TrainingVariables"] + + def __init__( + self, + batch_size: int, + s_current: Optional[Union[nn.Variable, Tuple[nn.Variable, ...]]] = None, + a_current: Optional[nn.Variable] = None, + reward: Optional[nn.Variable] = None, + gamma: Optional[nn.Variable] = None, + non_terminal: Optional[nn.Variable] = None, + s_next: Optional[Union[nn.Variable, Tuple[nn.Variable, ...]]] = None, + weight: Optional[nn.Variable] = None, + extra: Optional[Dict[str, nn.Variable]] = None, + next_step_variables: Optional["TrainingVariables"] = None, + rnn_states: Optional[Dict[str, Dict[str, nn.Variable]]] = None, + ): assert 0 < batch_size self.batch_size = batch_size if s_current is not None: @@ -327,11 +337,13 @@ class ModelTrainer(metaclass=ABCMeta): _train_count: int _training_variables: TrainingVariables - def __init__(self, - models: Union[Model, Sequence[Model]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - config: TrainerConfig): + def __init__( + self, + models: Union[Model, Sequence[Model]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + config: TrainerConfig, + ): self._env_info = env_info self._config = config @@ -340,7 +352,7 @@ def __init__(self, self._models = convert_to_list_if_not_list(models) self._assert_no_duplicate_model(self._models) if self._need_rnn_support(self._models) and not self.support_rnn(): - raise NotImplementedError(f'{self.__name__} does not support RNN models!') + raise NotImplementedError(f"{self.__name__} does not support RNN models!") self._solvers = solvers # Initially create training variables with batch_size 1. @@ -365,7 +377,7 @@ def __name__(self): def train(self, batch: TrainingBatch, **kwargs) -> Dict[str, np.ndarray]: if self._models is None: - raise RuntimeError('Call setup_training() first. Model is not set!') + raise RuntimeError("Call setup_training() first. 
Model is not set!") self._train_count += 1 batch = self._setup_batch(batch) @@ -411,18 +423,18 @@ def _setup_batch(self, training_batch: TrainingBatch) -> TrainingBatch: return training_batch @abstractmethod - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: raise NotImplementedError @abstractmethod - def _build_training_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables): + def _build_training_graph(self, models: Sequence[Model], training_variables: TrainingVariables): raise NotImplementedError @abstractmethod @@ -439,9 +451,11 @@ def _setup_solver(self): def _assert_variable_length_equals_total_timesteps(self): total_timesptes = self._total_timesteps() if len(self._training_variables) != total_timesptes: - raise RuntimeError(f'Training variables length and rnn unroll + burn-in steps does not match!. \ + raise RuntimeError( + f"Training variables length and rnn unroll + burn-in steps does not match!. \ {len(self._training_variables)} != {total_timesptes}. \ - Check that the training method supports recurrent networks.') + Check that the training method supports recurrent networks." + ) @classmethod def _assert_no_duplicate_model(cls, models): diff --git a/nnabla_rl/model_trainers/perturbator/__init__.py b/nnabla_rl/model_trainers/perturbator/__init__.py index 59957070..8ae7cb7d 100644 --- a/nnabla_rl/model_trainers/perturbator/__init__.py +++ b/nnabla_rl/model_trainers/perturbator/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,4 +13,6 @@ # limitations under the License. from nnabla_rl.model_trainers.perturbator.bcq_perturbator_trainer import ( # noqa - BCQPerturbatorTrainer, BCQPerturbatorTrainerConfig) + BCQPerturbatorTrainer, + BCQPerturbatorTrainerConfig, +) diff --git a/nnabla_rl/model_trainers/perturbator/bcq_perturbator_trainer.py b/nnabla_rl/model_trainers/perturbator/bcq_perturbator_trainer.py index eb632b85..90c94132 100644 --- a/nnabla_rl/model_trainers/perturbator/bcq_perturbator_trainer.py +++ b/nnabla_rl/model_trainers/perturbator/bcq_perturbator_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,8 +18,13 @@ import nnabla as nn import nnabla.functions as NF from nnabla_rl.environments.environment_info import EnvironmentInfo -from nnabla_rl.model_trainers.model_trainer import (LossIntegration, ModelTrainer, TrainerConfig, TrainingBatch, - TrainingVariables) +from nnabla_rl.model_trainers.model_trainer import ( + LossIntegration, + ModelTrainer, + TrainerConfig, + TrainingBatch, + TrainingVariables, +) from nnabla_rl.models import Model, Perturbator, QFunction, VariationalAutoEncoder from nnabla_rl.utils.data import set_data_to_variable from nnabla_rl.utils.misc import create_variable @@ -27,10 +32,11 @@ @dataclass class BCQPerturbatorTrainerConfig(TrainerConfig): - ''' + """ Args: phi(float): action perturbator noise coefficient - ''' + """ + phi: float = 0.05 @@ -43,23 +49,27 @@ class BCQPerturbatorTrainer(ModelTrainer): _vae: VariationalAutoEncoder _perturbator_loss: nn.Variable - def __init__(self, - models: Union[Perturbator, Sequence[Perturbator]], - solvers: Dict[str, nn.solver.Solver], - q_function: QFunction, - vae: VariationalAutoEncoder, - env_info: EnvironmentInfo, - config: BCQPerturbatorTrainerConfig = BCQPerturbatorTrainerConfig()): + def __init__( + self, + models: Union[Perturbator, Sequence[Perturbator]], + solvers: Dict[str, nn.solver.Solver], + q_function: QFunction, + vae: VariationalAutoEncoder, + env_info: EnvironmentInfo, + config: BCQPerturbatorTrainerConfig = BCQPerturbatorTrainerConfig(), + ): self._q_function = q_function self._vae = vae super(BCQPerturbatorTrainer, self).__init__(models, solvers, env_info, config) - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) @@ -72,7 +82,7 @@ def _update_model(self, solver.update() trainer_state = {} - trainer_state['perturbator_loss'] = float(self._perturbator_loss.d.copy()) + trainer_state["perturbator_loss"] = float(self._perturbator_loss.d.copy()) return trainer_state def _build_training_graph(self, models: Union[Model, Sequence[Model]], training_variables: TrainingVariables): @@ -85,10 +95,7 @@ def _build_training_graph(self, models: Union[Model, Sequence[Model]], training_ ignore_loss = is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - def _build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool): assert training_variables.s_current is not None batch_size = training_variables.batch_size diff --git a/nnabla_rl/model_trainers/policy/__init__.py b/nnabla_rl/model_trainers/policy/__init__.py index 9f0f0454..09a79070 100644 --- a/nnabla_rl/model_trainers/policy/__init__.py +++ b/nnabla_rl/model_trainers/policy/__init__.py @@ -21,7 +21,16 @@ from nnabla_rl.model_trainers.policy.hyar_policy_trainer import HyARPolicyTrainer, HyARPolicyTrainerConfig # noqa from nnabla_rl.model_trainers.policy.ppo_policy_trainer import PPOPolicyTrainer, PPOPolicyTrainerConfig # noqa from nnabla_rl.model_trainers.policy.soft_policy_trainer import SoftPolicyTrainer, 
SoftPolicyTrainerConfig # noqa -from nnabla_rl.model_trainers.policy.reinforce_policy_trainer import REINFORCEPolicyTrainer, REINFORCEPolicyTrainerConfig # noqa +from nnabla_rl.model_trainers.policy.reinforce_policy_trainer import ( # noqa + REINFORCEPolicyTrainer, + REINFORCEPolicyTrainerConfig, +) from nnabla_rl.model_trainers.policy.trpo_policy_trainer import TRPOPolicyTrainer, TRPOPolicyTrainerConfig # noqa -from nnabla_rl.model_trainers.policy.xql_forward_policy_trainer import XQLForwardPolicyTrainer, XQLForwardPolicyTrainerConfig # noqa -from nnabla_rl.model_trainers.policy.xql_reverse_policy_trainer import XQLReversePolicyTrainer, XQLReversePolicyTrainerConfig # noqa +from nnabla_rl.model_trainers.policy.xql_forward_policy_trainer import ( # noqa + XQLForwardPolicyTrainer, + XQLForwardPolicyTrainerConfig, +) +from nnabla_rl.model_trainers.policy.xql_reverse_policy_trainer import ( # noqa + XQLReversePolicyTrainer, + XQLReversePolicyTrainerConfig, +) diff --git a/nnabla_rl/model_trainers/policy/a2c_policy_trainer.py b/nnabla_rl/model_trainers/policy/a2c_policy_trainer.py index 0d36c3ef..dc59711d 100644 --- a/nnabla_rl/model_trainers/policy/a2c_policy_trainer.py +++ b/nnabla_rl/model_trainers/policy/a2c_policy_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,8 +20,13 @@ import nnabla as nn import nnabla.functions as NF from nnabla_rl.environments.environment_info import EnvironmentInfo -from nnabla_rl.model_trainers.model_trainer import (LossIntegration, ModelTrainer, TrainerConfig, TrainingBatch, - TrainingVariables) +from nnabla_rl.model_trainers.model_trainer import ( + LossIntegration, + ModelTrainer, + TrainerConfig, + TrainingBatch, + TrainingVariables, +) from nnabla_rl.models import Model, StochasticPolicy from nnabla_rl.utils.data import set_data_to_variable from nnabla_rl.utils.misc import create_variable @@ -34,29 +39,34 @@ class A2CPolicyTrainerConfig(TrainerConfig): class A2CPolicyTrainer(ModelTrainer): """Advantaged Actor Critic style Policy Trainer.""" + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See https://mypy.readthedocs.io/en/stable/class_basics.html for details _config: A2CPolicyTrainerConfig _pi_loss: nn.Variable - def __init__(self, - models: Union[StochasticPolicy, Sequence[StochasticPolicy]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - config: A2CPolicyTrainerConfig = A2CPolicyTrainerConfig()): + def __init__( + self, + models: Union[StochasticPolicy, Sequence[StochasticPolicy]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + config: A2CPolicyTrainerConfig = A2CPolicyTrainerConfig(), + ): super(A2CPolicyTrainer, self).__init__(models, solvers, env_info, config) - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) set_data_to_variable(t.a_current, 
b.a_current) - set_data_to_variable(t.extra['advantage'], b.extra['advantage']) + set_data_to_variable(t.extra["advantage"], b.extra["advantage"]) # update model for solver in solvers.values(): @@ -67,7 +77,7 @@ def _update_model(self, solver.update() trainer_state = {} - trainer_state['pi_loss'] = self._pi_loss.d.copy() + trainer_state["pi_loss"] = self._pi_loss.d.copy() return trainer_state def _build_training_graph(self, models: Sequence[Model], training_variables: TrainingVariables): @@ -80,16 +90,13 @@ def _build_training_graph(self, models: Sequence[Model], training_variables: Tra ignore_loss = is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - def _build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool): models = cast(Sequence[StochasticPolicy], models) for policy in models: distribution = policy.pi(training_variables.s_current) log_prob = distribution.log_prob(training_variables.a_current) entropy = distribution.entropy() - advantage = training_variables.extra['advantage'] + advantage = training_variables.extra["advantage"] self._pi_loss += NF.mean(-advantage * log_prob - self._config.entropy_coefficient * entropy) @@ -100,7 +107,7 @@ def _setup_training_variables(self, batch_size): advantage_var = create_variable(batch_size, 1) extra = {} - extra['advantage'] = advantage_var + extra["advantage"] = advantage_var return TrainingVariables(batch_size, s_current_var, a_current_var, extra=extra) @property diff --git a/nnabla_rl/model_trainers/policy/amp_policy_trainer.py b/nnabla_rl/model_trainers/policy/amp_policy_trainer.py index 61e27e25..a1328088 100644 --- a/nnabla_rl/model_trainers/policy/amp_policy_trainer.py +++ b/nnabla_rl/model_trainers/policy/amp_policy_trainer.py @@ -23,8 +23,13 @@ from nnabla_rl.distributions.distribution import Distribution from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.functions import compute_std, normalize -from nnabla_rl.model_trainers.model_trainer import (LossIntegration, ModelTrainer, TrainerConfig, TrainingBatch, - TrainingVariables) +from nnabla_rl.model_trainers.model_trainer import ( + LossIntegration, + ModelTrainer, + TrainerConfig, + TrainingBatch, + TrainingVariables, +) from nnabla_rl.models import Model, StochasticPolicy from nnabla_rl.utils.data import add_batch_dimension, set_data_to_variable from nnabla_rl.utils.misc import create_variable @@ -49,11 +54,13 @@ class AMPPolicyTrainer(ModelTrainer): _config: AMPPolicyTrainerConfig _pi_loss: nn.Variable - def __init__(self, - models: Union[StochasticPolicy, Sequence[StochasticPolicy]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - config: AMPPolicyTrainerConfig = AMPPolicyTrainerConfig()): + def __init__( + self, + models: Union[StochasticPolicy, Sequence[StochasticPolicy]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + config: AMPPolicyTrainerConfig = AMPPolicyTrainerConfig(), + ): self._action_mean = None self._action_std = None @@ -61,17 +68,20 @@ def __init__(self, action_mean = add_batch_dimension(np.array(config.action_mean, dtype=np.float32)) self._action_mean = nn.Variable.from_numpy_array(action_mean) action_var = add_batch_dimension(np.array(config.action_var, dtype=np.float32)) - self._action_std = 
compute_std(nn.Variable.from_numpy_array(action_var), - epsilon=0.0, mode_for_floating_point_error="max") + self._action_std = compute_std( + nn.Variable.from_numpy_array(action_var), epsilon=0.0, mode_for_floating_point_error="max" + ) super(AMPPolicyTrainer, self).__init__(models, solvers, env_info, config) - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) set_data_to_variable(t.a_current, b.a_current) @@ -140,11 +150,12 @@ def _clip_loss(self, training_variables: TrainingVariables, distribution: Distri clipped_ratio = NF.clip_by_value(probability_ratio, 1 - self._config.epsilon, 1 + self._config.epsilon) advantage = training_variables.extra["advantage"] lower_bounds = NF.minimum2(probability_ratio * advantage, clipped_ratio * advantage) - clip_loss = - NF.mean(lower_bounds) + clip_loss = -NF.mean(lower_bounds) return clip_loss - def _bound_loss(self, mean: nn.Variable, bound_min: nn.Variable, bound_max: nn.Variable, axis: int = -1 - ) -> nn.Variable: + def _bound_loss( + self, mean: nn.Variable, bound_min: nn.Variable, bound_max: nn.Variable, axis: int = -1 + ) -> nn.Variable: violation_min = NF.minimum_scalar(mean - bound_min, 0.0) violation_max = NF.maximum_scalar(mean - bound_max, 0.0) violation = NF.sum((violation_min**2), axis=axis) + NF.sum((violation_max**2), axis=axis) diff --git a/nnabla_rl/model_trainers/policy/bear_policy_trainer.py b/nnabla_rl/model_trainers/policy/bear_policy_trainer.py index 89d0172b..29cde6bb 100644 --- a/nnabla_rl/model_trainers/policy/bear_policy_trainer.py +++ b/nnabla_rl/model_trainers/policy/bear_policy_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
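
The `_clip_loss` reformatted above (and the PPO trainer later in this diff) computes the standard clipped surrogate objective. A minimal NumPy sketch of that computation, using plain arrays instead of nnabla variables; the function name and inputs are illustrative, not part of nnabla-rl.

import numpy as np

def clipped_surrogate_loss(log_prob_new, log_prob_old, advantage, epsilon=0.2):
    # probability ratio pi_new(a|s) / pi_old(a|s)
    ratio = np.exp(log_prob_new - log_prob_old)
    clipped_ratio = np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon)
    # pessimistic (clipped) surrogate, negated because the trainer minimizes a loss
    lower_bound = np.minimum(ratio * advantage, clipped_ratio * advantage)
    return -np.mean(lower_bound)
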
@@ -21,8 +21,13 @@ import nnabla.functions as NF import nnabla_rl.functions as RF from nnabla_rl.environments.environment_info import EnvironmentInfo -from nnabla_rl.model_trainers.model_trainer import (LossIntegration, ModelTrainer, TrainerConfig, TrainingBatch, - TrainingVariables) +from nnabla_rl.model_trainers.model_trainer import ( + LossIntegration, + ModelTrainer, + TrainerConfig, + TrainingBatch, + TrainingVariables, +) from nnabla_rl.models import Model, QFunction, StochasticPolicy, VariationalAutoEncoder from nnabla_rl.utils.data import set_data_to_variable from nnabla_rl.utils.misc import create_variable @@ -35,8 +40,7 @@ class AdjustableLagrangeMultiplier(Model): _log_lagrange: nn.Variable def __init__(self, scope_name, initial_value=None): - super(AdjustableLagrangeMultiplier, self).__init__( - scope_name=scope_name) + super(AdjustableLagrangeMultiplier, self).__init__(scope_name=scope_name) if initial_value: initial_value = np.log(initial_value) else: @@ -44,9 +48,9 @@ def __init__(self, scope_name, initial_value=None): initializer = np.reshape(initial_value, newshape=(1, 1)) with nn.parameter_scope(scope_name): - self._log_lagrange = \ - nn.parameter.get_parameter_or_create( - name='log_lagrange', shape=(1, 1), initializer=initializer) + self._log_lagrange = nn.parameter.get_parameter_or_create( + name="log_lagrange", shape=(1, 1), initializer=initializer + ) # Dummy call. Just for initializing the parameters self() @@ -65,18 +69,19 @@ def value(self): class BEARPolicyTrainerConfig(TrainerConfig): num_mmd_actions: int = 10 mmd_sigma: float = 20.0 - mmd_type: str = 'gaussian' + mmd_type: str = "gaussian" epsilon: float = 0.05 fix_lagrange_multiplier: bool = False warmup_iterations: int = 20000 def __post_init__(self): - self._assert_one_of(self.mmd_type, ['gaussian', 'laplacian'], 'mmd_type') + self._assert_one_of(self.mmd_type, ["gaussian", "laplacian"], "mmd_type") class BEARPolicyTrainer(ModelTrainer): """Bootstrapping Error Accumulation Reduction (BEAR) style Policy Trainer.""" + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See https://mypy.readthedocs.io/en/stable/class_basics.html for details @@ -91,15 +96,17 @@ class BEARPolicyTrainer(ModelTrainer): _lagrange: AdjustableLagrangeMultiplier _lagrange_solver: Optional[nn.solver.Solver] - def __init__(self, - models: Union[StochasticPolicy, Sequence[StochasticPolicy]], - solvers: Dict[str, nn.solver.Solver], - q_ensembles: Sequence[QFunction], - vae: VariationalAutoEncoder, - lagrange_multiplier: AdjustableLagrangeMultiplier, - lagrange_solver: Optional[nn.solver.Solver], - env_info: EnvironmentInfo, - config: BEARPolicyTrainerConfig = BEARPolicyTrainerConfig()): + def __init__( + self, + models: Union[StochasticPolicy, Sequence[StochasticPolicy]], + solvers: Dict[str, nn.solver.Solver], + q_ensembles: Sequence[QFunction], + vae: VariationalAutoEncoder, + lagrange_multiplier: AdjustableLagrangeMultiplier, + lagrange_solver: Optional[nn.solver.Solver], + env_info: EnvironmentInfo, + config: BEARPolicyTrainerConfig = BEARPolicyTrainerConfig(), + ): self._q_ensembles = q_ensembles self._vae = vae @@ -108,12 +115,14 @@ def __init__(self, self._lagrange_solver = lagrange_solver super(BEARPolicyTrainer, self).__init__(models, solvers, env_info, config) - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> 
Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) @@ -136,14 +145,13 @@ def _update_model(self, self._lagrange.clip(-5.0, 10.0) trainer_state = {} - trainer_state['pi_loss'] = self._pi_loss.d.copy() + trainer_state["pi_loss"] = self._pi_loss.d.copy() return trainer_state def _repeat_state(self, s_var: nn.Variable, batch_size: int) -> nn.Variable: s_hat = RF.expand_dims(s_var, axis=0) s_hat = RF.repeat(s_hat, repeats=self._config.num_mmd_actions, axis=0) - s_hat = NF.reshape(s_hat, shape=(batch_size * self._config.num_mmd_actions, - s_var.shape[-1])) + s_hat = NF.reshape(s_hat, shape=(batch_size * self._config.num_mmd_actions, s_var.shape[-1])) return s_hat def _build_training_graph(self, models: Sequence[Model], training_variables: TrainingVariables): @@ -158,29 +166,26 @@ def _build_training_graph(self, models: Sequence[Model], training_variables: Tra ignore_loss = is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - def _build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool): models = cast(Sequence[StochasticPolicy], models) batch_size = training_variables.batch_size for policy in models: - sampled_actions = self._vae.decode_multiple(z=None, - decode_num=self._config.num_mmd_actions, - state=training_variables.s_current) + sampled_actions = self._vae.decode_multiple( + z=None, decode_num=self._config.num_mmd_actions, state=training_variables.s_current + ) policy_distribution = policy.pi(training_variables.s_current) policy_actions = policy_distribution.sample_multiple( - num_samples=self._config.num_mmd_actions, noise_clip=(-0.5, 0.5)) + num_samples=self._config.num_mmd_actions, noise_clip=(-0.5, 0.5) + ) - if self._config.mmd_type == 'gaussian': + if self._config.mmd_type == "gaussian": mmd_loss = _compute_gaussian_mmd(sampled_actions, policy_actions, sigma=self._config.mmd_sigma) - elif self._config.mmd_type == 'laplacian': + elif self._config.mmd_type == "laplacian": mmd_loss = _compute_laplacian_mmd(sampled_actions, policy_actions, sigma=self._config.mmd_sigma) else: - raise ValueError( - 'Unknown mmd type: {}'.format(self._config.mmd_type)) + raise ValueError("Unknown mmd type: {}".format(self._config.mmd_type)) assert mmd_loss.shape == (batch_size, 1) if isinstance(training_variables.s_current, tuple): @@ -216,8 +221,9 @@ def _build_one_step_graph(self, self._pi_warmup_loss += 0.0 if ignore_loss else NF.mean(self._lagrange() * mmd_loss) # Must forward pi_loss before forwarding lagrange_loss - self._lagrange_loss += 0.0 if ignore_loss else - \ - NF.mean(-q_min + self._lagrange() * (mmd_loss - self._config.epsilon)) + self._lagrange_loss += ( + 0.0 if ignore_loss else -NF.mean(-q_min + self._lagrange() * (mmd_loss - self._config.epsilon)) + ) def _setup_training_variables(self, batch_size): # Training input variables @@ -250,7 +256,7 @@ def _compute_gaussian_mmd(samples1, samples2, sigma): last_axis = len(k_yy.shape) - 1 sum_k_yy = NF.sum(NF.exp(-NF.sum(k_yy**2, axis=last_axis, keepdims=True) / (2.0 * sigma)), axis=(1, 2)) - mmd_squared = (sum_k_xx / 
(n*n) - 2.0 * sum_k_xy / (m*n) + sum_k_yy / (m*m)) + mmd_squared = sum_k_xx / (n * n) - 2.0 * sum_k_xy / (m * n) + sum_k_yy / (m * m) # Add 1e-6 to avoid numerical instability return RF.sqrt(mmd_squared + 1e-6) @@ -271,6 +277,6 @@ def _compute_laplacian_mmd(samples1, samples2, sigma): last_axis = len(k_yy.shape) - 1 sum_k_yy = NF.sum(NF.exp(-NF.sum(NF.abs(k_yy), axis=last_axis, keepdims=True) / (2.0 * sigma)), axis=(1, 2)) - mmd_squared = (sum_k_xx / (n*n) - 2.0 * sum_k_xy / (m*n) + sum_k_yy / (m*m)) + mmd_squared = sum_k_xx / (n * n) - 2.0 * sum_k_xy / (m * n) + sum_k_yy / (m * m) # Add 1e-6 to avoid numerical instability return RF.sqrt(mmd_squared + 1e-6) diff --git a/nnabla_rl/model_trainers/policy/demme_policy_trainer.py b/nnabla_rl/model_trainers/policy/demme_policy_trainer.py index a31dc93b..996b3d21 100644 --- a/nnabla_rl/model_trainers/policy/demme_policy_trainer.py +++ b/nnabla_rl/model_trainers/policy/demme_policy_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,8 +21,14 @@ import nnabla.functions as NF import nnabla_rl.functions as RNF from nnabla_rl.environments.environment_info import EnvironmentInfo -from nnabla_rl.model_trainers.model_trainer import (LossIntegration, ModelTrainer, TrainerConfig, TrainingBatch, - TrainingVariables, rnn_support) +from nnabla_rl.model_trainers.model_trainer import ( + LossIntegration, + ModelTrainer, + TrainerConfig, + TrainingBatch, + TrainingVariables, + rnn_support, +) from nnabla_rl.models import Model, QFunction, StochasticPolicy from nnabla_rl.utils.data import convert_to_list_if_not_list, set_data_to_variable from nnabla_rl.utils.misc import create_variable, create_variables @@ -36,6 +42,7 @@ def __post_init__(self): class DEMMEPolicyTrainer(ModelTrainer): """DEMME Policy Gradient style Policy Trainer.""" + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See https://mypy.readthedocs.io/en/stable/class_basics.html for details @@ -47,17 +54,19 @@ class DEMMEPolicyTrainer(ModelTrainer): _prev_q_rr_rnn_states: Dict[str, Dict[str, Dict[str, nn.Variable]]] _prev_q_re_rnn_states: Dict[str, Dict[str, Dict[str, nn.Variable]]] - def __init__(self, - models: Union[StochasticPolicy, Sequence[StochasticPolicy]], - solvers: Dict[str, nn.solver.Solver], - q_rr_functions: Sequence[QFunction], - q_re_functions: Sequence[QFunction], - env_info: EnvironmentInfo, - config: DEMMEPolicyTrainerConfig = DEMMEPolicyTrainerConfig()): + def __init__( + self, + models: Union[StochasticPolicy, Sequence[StochasticPolicy]], + solvers: Dict[str, nn.solver.Solver], + q_rr_functions: Sequence[QFunction], + q_re_functions: Sequence[QFunction], + env_info: EnvironmentInfo, + config: DEMMEPolicyTrainerConfig = DEMMEPolicyTrainerConfig(), + ): if len(q_rr_functions) < 2: - raise ValueError('Must provide at least 2 Qrr-functions for DEMME-training') + raise ValueError("Must provide at least 2 Qrr-functions for DEMME-training") if len(q_re_functions) < 2: - raise ValueError('Must provide at least 2 Qre-functions for DEMME-training') + raise ValueError("Must provide at least 2 Qre-functions for DEMME-training") self._q_rr_functions = q_rr_functions self._q_re_functions = q_re_functions @@ -73,12 +82,14 @@ def __init__(self, def support_rnn(self) -> bool: 
return True - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) set_data_to_variable(t.non_terminal, b.non_terminal) @@ -129,12 +140,10 @@ def _update_model(self, solver.update() trainer_state = {} - trainer_state['pi_loss'] = self._pi_loss.d.copy() + trainer_state["pi_loss"] = self._pi_loss.d.copy() return trainer_state - def _build_training_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables): + def _build_training_graph(self, models: Sequence[Model], training_variables: TrainingVariables): self._pi_loss = 0 ignore_intermediate_loss = self._config.loss_integration is LossIntegration.LAST_TIMESTEP_ONLY for step_index, variables in enumerate(training_variables): @@ -143,10 +152,7 @@ def _build_training_graph(self, ignore_loss = is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - def _build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool): train_rnn_states = training_variables.rnn_states for policy in models: assert isinstance(policy, StochasticPolicy) diff --git a/nnabla_rl/model_trainers/policy/dpg_policy_trainer.py b/nnabla_rl/model_trainers/policy/dpg_policy_trainer.py index aed1ca4d..7643ecd2 100644 --- a/nnabla_rl/model_trainers/policy/dpg_policy_trainer.py +++ b/nnabla_rl/model_trainers/policy/dpg_policy_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
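
The `_compute_gaussian_mmd` and `_compute_laplacian_mmd` hunks in the BEAR trainer above only reflow the squared-MMD expression; the quantity itself is the usual biased MMD estimate between VAE-sampled and policy-sampled actions. A small NumPy sketch of the Gaussian-kernel case for two unbatched sample sets; shapes and names are illustrative rather than the trainer's batched layout.

import numpy as np

def gaussian_mmd(x, y, sigma=20.0, eps=1e-6):
    # x: (n, d) and y: (m, d) are two sets of sampled actions
    n, m = len(x), len(y)
    k_xx = np.exp(-np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1) / (2.0 * sigma))
    k_xy = np.exp(-np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1) / (2.0 * sigma))
    k_yy = np.exp(-np.sum((y[:, None, :] - y[None, :, :]) ** 2, axis=-1) / (2.0 * sigma))
    mmd_squared = k_xx.sum() / (n * n) - 2.0 * k_xy.sum() / (m * n) + k_yy.sum() / (m * m)
    # small constant keeps the square root numerically stable, as in the trainer
    return np.sqrt(mmd_squared + eps)
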
@@ -20,8 +20,14 @@ import nnabla as nn import nnabla.functions as NF from nnabla_rl.environments.environment_info import EnvironmentInfo -from nnabla_rl.model_trainers.model_trainer import (LossIntegration, ModelTrainer, TrainerConfig, TrainingBatch, - TrainingVariables, rnn_support) +from nnabla_rl.model_trainers.model_trainer import ( + LossIntegration, + ModelTrainer, + TrainerConfig, + TrainingBatch, + TrainingVariables, + rnn_support, +) from nnabla_rl.models import DeterministicPolicy, Model, QFunction from nnabla_rl.utils.data import convert_to_list_if_not_list, set_data_to_variable from nnabla_rl.utils.misc import create_variable, create_variables @@ -34,6 +40,7 @@ class DPGPolicyTrainerConfig(TrainerConfig): class DPGPolicyTrainer(ModelTrainer): """Deterministic Policy Gradient (DPG) style Policy Trainer.""" + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See https://mypy.readthedocs.io/en/stable/class_basics.html for details @@ -43,12 +50,14 @@ class DPGPolicyTrainer(ModelTrainer): _prev_policy_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_q_rnn_states: Dict[str, Dict[str, Dict[str, nn.Variable]]] - def __init__(self, - models: Union[DeterministicPolicy, Sequence[DeterministicPolicy]], - solvers: Dict[str, nn.solver.Solver], - q_function: QFunction, - env_info: EnvironmentInfo, - config: DPGPolicyTrainerConfig = DPGPolicyTrainerConfig()): + def __init__( + self, + models: Union[DeterministicPolicy, Sequence[DeterministicPolicy]], + solvers: Dict[str, nn.solver.Solver], + q_function: QFunction, + env_info: EnvironmentInfo, + config: DPGPolicyTrainerConfig = DPGPolicyTrainerConfig(), + ): self._q_function = q_function self._prev_policy_rnn_states = {} self._prev_q_rnn_states = {} @@ -60,12 +69,14 @@ def __init__(self, def support_rnn(self) -> bool: return True - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) set_data_to_variable(t.non_terminal, b.non_terminal) @@ -97,7 +108,7 @@ def _update_model(self, solver.update() trainer_state = {} - trainer_state['pi_loss'] = self._pi_loss.d.copy() + trainer_state["pi_loss"] = self._pi_loss.d.copy() return trainer_state def _build_training_graph(self, models: Sequence[Model], training_variables: TrainingVariables): @@ -110,10 +121,7 @@ def _build_training_graph(self, models: Sequence[Model], training_variables: Tra ignore_loss = is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - def _build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool): models = cast(Sequence[DeterministicPolicy], models) train_rnn_states = training_variables.rnn_states for policy in models: diff --git a/nnabla_rl/model_trainers/policy/her_policy_trainer.py b/nnabla_rl/model_trainers/policy/her_policy_trainer.py index c0ab0796..ceb48978 100644 --- 
a/nnabla_rl/model_trainers/policy/her_policy_trainer.py +++ b/nnabla_rl/model_trainers/policy/her_policy_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -35,20 +35,19 @@ class HERPolicyTrainer(DPGPolicyTrainer): _train_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - models: Union[DeterministicPolicy, Sequence[DeterministicPolicy]], - solvers: Dict[str, nn.solver.Solver], - q_function: QFunction, - env_info: EnvironmentInfo, - config: HERPolicyTrainerConfig = HERPolicyTrainerConfig()): + def __init__( + self, + models: Union[DeterministicPolicy, Sequence[DeterministicPolicy]], + solvers: Dict[str, nn.solver.Solver], + q_function: QFunction, + env_info: EnvironmentInfo, + config: HERPolicyTrainerConfig = HERPolicyTrainerConfig(), + ): action_space = cast(gym.spaces.Box, env_info.action_space) self._max_action_value = float(action_space.high[0]) super(HERPolicyTrainer, self).__init__(models, solvers, q_function, env_info, config) - def _build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool): models = cast(Sequence[DeterministicPolicy], models) train_rnn_states = training_variables.rnn_states for policy in models: @@ -62,5 +61,8 @@ def _build_one_step_graph(self, self._prev_q_rnn_states[policy.scope_name] = prev_rnn_states self._pi_loss += 0.0 if ignore_loss else -NF.mean(q) - self._pi_loss += 0.0 if ignore_loss else self._config.action_loss_coef \ - * NF.mean(NF.pow_scalar(action / self._max_action_value, 2.0)) + self._pi_loss += ( + 0.0 + if ignore_loss + else self._config.action_loss_coef * NF.mean(NF.pow_scalar(action / self._max_action_value, 2.0)) + ) diff --git a/nnabla_rl/model_trainers/policy/hyar_policy_trainer.py b/nnabla_rl/model_trainers/policy/hyar_policy_trainer.py index 10e14618..3bbbaf91 100644 --- a/nnabla_rl/model_trainers/policy/hyar_policy_trainer.py +++ b/nnabla_rl/model_trainers/policy/hyar_policy_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
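
The DPG and HER policy losses touched above reduce to "maximize Q(s, pi(s))", with HER adding a quadratic penalty on the normalized action magnitude. A NumPy sketch of the HER variant; the inputs are illustrative arrays, not the nnabla graph the trainer builds.

import numpy as np

def her_policy_loss(q_values, actions, max_action_value, action_loss_coef=1.0):
    # maximize Q(s, pi(s)) -> minimize -mean(Q), plus an L2 penalty on scaled actions
    pi_loss = -np.mean(q_values)
    pi_loss += action_loss_coef * np.mean((actions / max_action_value) ** 2)
    return pi_loss
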
@@ -40,20 +40,24 @@ class HyARPolicyTrainer(DPGPolicyTrainer): _config: HyARPolicyTrainerConfig _action_and_grads: Dict[str, List[Tuple[nn.Variable, nn.Variable]]] - def __init__(self, - models: Union[DeterministicPolicy, Sequence[DeterministicPolicy]], - solvers: Dict[str, nn.solver.Solver], - q_function: QFunction, - env_info: EnvironmentInfo, - config: HyARPolicyTrainerConfig = HyARPolicyTrainerConfig()): + def __init__( + self, + models: Union[DeterministicPolicy, Sequence[DeterministicPolicy]], + solvers: Dict[str, nn.solver.Solver], + q_function: QFunction, + env_info: EnvironmentInfo, + config: HyARPolicyTrainerConfig = HyARPolicyTrainerConfig(), + ): super().__init__(models, solvers, q_function, env_info, config) - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) set_data_to_variable(t.non_terminal, b.non_terminal) @@ -86,7 +90,7 @@ def _update_model(self, for solver in solvers.values(): solver.update() - trainer_state: Dict[str, Any] = {'pi_loss': 0} + trainer_state: Dict[str, Any] = {"pi_loss": 0} return trainer_state def _build_training_graph(self, models: Sequence[Model], training_variables: TrainingVariables): @@ -125,7 +129,7 @@ def _compute_q_grad_wrt_action(self, state: nn.Variable, action: nn.Variable, q_ def _invert_gradients(self, p: nn.Variable, grads: nn.Variable, p_min: nn.Variable, p_max: nn.Variable): increasing = NF.greater_equal_scalar(grads, val=0) decreasing = NF.less_scalar(grads, val=0) - p_range = (p_max - p_min) + p_range = p_max - p_min return grads * increasing * (p_max - p) / p_range + grads * decreasing * (p - p_min) / p_range @property diff --git a/nnabla_rl/model_trainers/policy/ppo_policy_trainer.py b/nnabla_rl/model_trainers/policy/ppo_policy_trainer.py index c4945c04..859ea803 100644 --- a/nnabla_rl/model_trainers/policy/ppo_policy_trainer.py +++ b/nnabla_rl/model_trainers/policy/ppo_policy_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
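
`_invert_gradients`, reformatted in the HyAR hunk above, rescales each gradient by how much headroom the corresponding value still has inside its bounds, so updates that would push a value past its bound are damped. An equivalent NumPy sketch under the same assumptions; names are illustrative.

import numpy as np

def invert_gradients(p, grads, p_min, p_max):
    increasing = (grads >= 0).astype(grads.dtype)
    decreasing = (grads < 0).astype(grads.dtype)
    p_range = p_max - p_min
    # gradients pushing p upward are scaled by the remaining headroom to p_max,
    # gradients pushing p downward by the remaining headroom to p_min
    return grads * increasing * (p_max - p) / p_range + grads * decreasing * (p - p_min) / p_range
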
@@ -20,8 +20,13 @@ import nnabla as nn import nnabla.functions as NF from nnabla_rl.environments.environment_info import EnvironmentInfo -from nnabla_rl.model_trainers.model_trainer import (LossIntegration, ModelTrainer, TrainerConfig, TrainingBatch, - TrainingVariables) +from nnabla_rl.model_trainers.model_trainer import ( + LossIntegration, + ModelTrainer, + TrainerConfig, + TrainingBatch, + TrainingVariables, +) from nnabla_rl.models import Model, StochasticPolicy from nnabla_rl.utils.data import set_data_to_variable from nnabla_rl.utils.misc import create_variable @@ -35,31 +40,36 @@ class PPOPolicyTrainerConfig(TrainerConfig): class PPOPolicyTrainer(ModelTrainer): """Proximal Policy Optimization (PPO) style Policy Trainer.""" + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See https://mypy.readthedocs.io/en/stable/class_basics.html for details _config: PPOPolicyTrainerConfig _pi_loss: nn.Variable - def __init__(self, - models: Union[StochasticPolicy, Sequence[StochasticPolicy]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - config: PPOPolicyTrainerConfig = PPOPolicyTrainerConfig()): + def __init__( + self, + models: Union[StochasticPolicy, Sequence[StochasticPolicy]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + config: PPOPolicyTrainerConfig = PPOPolicyTrainerConfig(), + ): super(PPOPolicyTrainer, self).__init__(models, solvers, env_info, config) - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) set_data_to_variable(t.a_current, b.a_current) - set_data_to_variable(t.extra['log_prob'], b.extra['log_prob']) - set_data_to_variable(t.extra['advantage'], b.extra['advantage']) - set_data_to_variable(t.extra['alpha'], b.extra['alpha']) + set_data_to_variable(t.extra["log_prob"], b.extra["log_prob"]) + set_data_to_variable(t.extra["advantage"], b.extra["advantage"]) + set_data_to_variable(t.extra["alpha"], b.extra["alpha"]) # update model for solver in solvers.values(): @@ -70,7 +80,7 @@ def _update_model(self, solver.update() trainer_state = {} - trainer_state['pi_loss'] = self._pi_loss.d.copy() + trainer_state["pi_loss"] = self._pi_loss.d.copy() return trainer_state def _build_training_graph(self, models: Sequence[Model], training_variables: TrainingVariables): @@ -83,21 +93,18 @@ def _build_training_graph(self, models: Sequence[Model], training_variables: Tra ignore_loss = is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - def _build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool): models = cast(Sequence[StochasticPolicy], models) for policy in models: distribution = policy.pi(training_variables.s_current) log_prob_new = distribution.log_prob(training_variables.a_current) - log_prob_old = training_variables.extra['log_prob'] + 
log_prob_old = training_variables.extra["log_prob"] probability_ratio = NF.exp(log_prob_new - log_prob_old) - alpha = training_variables.extra['alpha'] - clipped_ratio = NF.clip_by_value(probability_ratio, - 1 - self._config.epsilon * alpha, - 1 + self._config.epsilon * alpha) - advantage = training_variables.extra['advantage'] + alpha = training_variables.extra["alpha"] + clipped_ratio = NF.clip_by_value( + probability_ratio, 1 - self._config.epsilon * alpha, 1 + self._config.epsilon * alpha + ) + advantage = training_variables.extra["advantage"] lower_bounds = NF.minimum2(probability_ratio * advantage, clipped_ratio * advantage) clip_loss = NF.mean(lower_bounds) @@ -115,9 +122,9 @@ def _setup_training_variables(self, batch_size) -> TrainingVariables: alpha_var = create_variable(batch_size, 1) extra = {} - extra['log_prob'] = log_prob_var - extra['advantage'] = advantage_var - extra['alpha'] = alpha_var + extra["log_prob"] = log_prob_var + extra["advantage"] = advantage_var + extra["alpha"] = alpha_var return TrainingVariables(batch_size, s_current_var, a_current_var, extra=extra) @property diff --git a/nnabla_rl/model_trainers/policy/reinforce_policy_trainer.py b/nnabla_rl/model_trainers/policy/reinforce_policy_trainer.py index 5162ed1c..481f7e75 100644 --- a/nnabla_rl/model_trainers/policy/reinforce_policy_trainer.py +++ b/nnabla_rl/model_trainers/policy/reinforce_policy_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -34,36 +34,41 @@ class REINFORCEPolicyTrainerConfig(SPGPolicyTrainerConfig): class REINFORCEPolicyTrainer(SPGPolicyTrainer): """REINFORCE style Stochastic Policy Trainer.""" + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See https://mypy.readthedocs.io/en/stable/class_basics.html for details _config: REINFORCEPolicyTrainerConfig - def __init__(self, - models: Union[StochasticPolicy, Sequence[StochasticPolicy]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - config: REINFORCEPolicyTrainerConfig = REINFORCEPolicyTrainerConfig()): + def __init__( + self, + models: Union[StochasticPolicy, Sequence[StochasticPolicy]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + config: REINFORCEPolicyTrainerConfig = REINFORCEPolicyTrainerConfig(), + ): super(REINFORCEPolicyTrainer, self).__init__(models, solvers, env_info, config) - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): - set_data_to_variable(t.extra['target_return'], b.extra['target_return']) + set_data_to_variable(t.extra["target_return"], b.extra["target_return"]) return super()._update_model(models, solvers, batch, training_variables, **kwargs) def _compute_target(self, training_variables: TrainingVariables) -> nn.Variable: - return training_variables.extra['target_return'] + return training_variables.extra["target_return"] def _setup_training_variables(self, 
batch_size) -> TrainingVariables: training_variables = super()._setup_training_variables(batch_size) extra = {} - extra['target_return'] = create_variable(batch_size, 1) + extra["target_return"] = create_variable(batch_size, 1) training_variables.extra.update(extra) return training_variables diff --git a/nnabla_rl/model_trainers/policy/soft_policy_trainer.py b/nnabla_rl/model_trainers/policy/soft_policy_trainer.py index 54a72b51..34dbe574 100644 --- a/nnabla_rl/model_trainers/policy/soft_policy_trainer.py +++ b/nnabla_rl/model_trainers/policy/soft_policy_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,8 +21,14 @@ import nnabla.functions as NF import nnabla_rl.functions as RNF from nnabla_rl.environments.environment_info import EnvironmentInfo -from nnabla_rl.model_trainers.model_trainer import (LossIntegration, ModelTrainer, TrainerConfig, TrainingBatch, - TrainingVariables, rnn_support) +from nnabla_rl.model_trainers.model_trainer import ( + LossIntegration, + ModelTrainer, + TrainerConfig, + TrainingBatch, + TrainingVariables, + rnn_support, +) from nnabla_rl.models import Model, QFunction, StochasticPolicy from nnabla_rl.utils.data import convert_to_list_if_not_list, set_data_to_variable from nnabla_rl.utils.misc import create_variable, create_variables @@ -38,9 +44,9 @@ def __init__(self, scope_name, initial_value=None): initializer = np.reshape(initial_value, newshape=(1, 1)) with nn.parameter_scope(scope_name): - self._log_temperature = nn.parameter.get_parameter_or_create(name='log_temperature', - shape=(1, 1), - initializer=initializer) + self._log_temperature = nn.parameter.get_parameter_or_create( + name="log_temperature", shape=(1, 1), initializer=initializer + ) def __call__(self): return NF.exp(self._log_temperature) @@ -57,6 +63,7 @@ def __post_init__(self): class SoftPolicyTrainer(ModelTrainer): """Soft Policy Gradient style Policy Trainer.""" + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See https://mypy.readthedocs.io/en/stable/class_basics.html for details @@ -69,19 +76,21 @@ class SoftPolicyTrainer(ModelTrainer): _prev_policy_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_q_rnn_states: Dict[str, Dict[str, Dict[str, nn.Variable]]] - def __init__(self, - models: Union[StochasticPolicy, Sequence[StochasticPolicy]], - solvers: Dict[str, nn.solver.Solver], - q_functions: Sequence[QFunction], - temperature: AdjustableTemperature, - temperature_solver: Optional[nn.solver.Solver], - env_info: EnvironmentInfo, - config: SoftPolicyTrainerConfig = SoftPolicyTrainerConfig()): + def __init__( + self, + models: Union[StochasticPolicy, Sequence[StochasticPolicy]], + solvers: Dict[str, nn.solver.Solver], + q_functions: Sequence[QFunction], + temperature: AdjustableTemperature, + temperature_solver: Optional[nn.solver.Solver], + env_info: EnvironmentInfo, + config: SoftPolicyTrainerConfig = SoftPolicyTrainerConfig(), + ): if len(q_functions) < 2: - raise ValueError('Must provide at least 2 Q-functions for soft-training') + raise ValueError("Must provide at least 2 Q-functions for soft-training") self._q_functions = q_functions if not config.fixed_temperature and temperature_solver is None: - raise ValueError('Please set solver for temperature model') 
+ raise ValueError("Please set solver for temperature model") self._temperature = temperature self._temperature_solver = temperature_solver @@ -97,12 +106,14 @@ def __init__(self, def support_rnn(self) -> bool: return True - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) set_data_to_variable(t.non_terminal, b.non_terminal) @@ -149,16 +160,14 @@ def _update_model(self, self._temperature_solver.update() trainer_state = {} - trainer_state['pi_loss'] = self._pi_loss.d.copy() + trainer_state["pi_loss"] = self._pi_loss.d.copy() return trainer_state def get_temperature(self) -> nn.Variable: # Will return exponentiated log temperature. To keep temperature always positive return self._temperature() - def _build_training_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables): + def _build_training_graph(self, models: Sequence[Model], training_variables: TrainingVariables): self._pi_loss = 0 ignore_intermediate_loss = self._config.loss_integration is LossIntegration.LAST_TIMESTEP_ONLY for step_index, variables in enumerate(training_variables): @@ -167,10 +176,7 @@ def _build_training_graph(self, ignore_loss = is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - def _build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool): train_rnn_states = training_variables.rnn_states for policy in models: assert isinstance(policy, StochasticPolicy) @@ -191,8 +197,7 @@ def _build_one_step_graph(self, if not self._config.fixed_temperature: assert isinstance(log_pi, nn.Variable) log_pi_unlinked = log_pi.get_unlinked_variable() - self._temperature_loss = -NF.mean(self.get_temperature() * - (log_pi_unlinked + self._config.target_entropy)) + self._temperature_loss = -NF.mean(self.get_temperature() * (log_pi_unlinked + self._config.target_entropy)) def _setup_training_variables(self, batch_size): # Training input variables diff --git a/nnabla_rl/model_trainers/policy/spg_policy_trainer.py b/nnabla_rl/model_trainers/policy/spg_policy_trainer.py index fdbf5f34..f5e32225 100644 --- a/nnabla_rl/model_trainers/policy/spg_policy_trainer.py +++ b/nnabla_rl/model_trainers/policy/spg_policy_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -20,8 +20,13 @@ import nnabla as nn import nnabla.functions as NF from nnabla_rl.environments.environment_info import EnvironmentInfo -from nnabla_rl.model_trainers.model_trainer import (LossIntegration, ModelTrainer, TrainerConfig, TrainingBatch, - TrainingVariables) +from nnabla_rl.model_trainers.model_trainer import ( + LossIntegration, + ModelTrainer, + TrainerConfig, + TrainingBatch, + TrainingVariables, +) from nnabla_rl.models import Model, StochasticPolicy from nnabla_rl.utils.data import set_data_to_variable from nnabla_rl.utils.misc import create_variable @@ -36,25 +41,30 @@ class SPGPolicyTrainerConfig(TrainerConfig): class SPGPolicyTrainer(ModelTrainer): """Stochastic Policy Gradient (SPG) style Policy Trainer Stochastic Policy Gradient is widely known as 'Policy Gradient algorithm'.""" + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See https://mypy.readthedocs.io/en/stable/class_basics.html for details _config: SPGPolicyTrainerConfig _pi_loss: nn.Variable - def __init__(self, - models: Union[StochasticPolicy, Sequence[StochasticPolicy]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - config: SPGPolicyTrainerConfig = SPGPolicyTrainerConfig()): + def __init__( + self, + models: Union[StochasticPolicy, Sequence[StochasticPolicy]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + config: SPGPolicyTrainerConfig = SPGPolicyTrainerConfig(), + ): super(SPGPolicyTrainer, self).__init__(models, solvers, env_info, config) - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) set_data_to_variable(t.a_current, b.a_current) @@ -71,7 +81,7 @@ def _update_model(self, solver.update() trainer_state = {} - trainer_state['pi_loss'] = self._pi_loss.d.copy() + trainer_state["pi_loss"] = self._pi_loss.d.copy() return trainer_state def _build_training_graph(self, models: Sequence[Model], training_variables: TrainingVariables): @@ -84,10 +94,7 @@ def _build_training_graph(self, models: Sequence[Model], training_variables: Tra ignore_loss = is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - def _build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool): models = cast(Sequence[StochasticPolicy], models) # Actor optimization graph target_value = self._compute_target(training_variables) @@ -96,10 +103,9 @@ def _build_one_step_graph(self, for policy in models: self._pi_loss += 0.0 if ignore_loss else self._compute_loss(policy, target_value, training_variables) - def _compute_loss(self, - model: StochasticPolicy, - target_value: nn.Variable, - training_variables: TrainingVariables) -> nn.Variable: + def _compute_loss( + self, model: StochasticPolicy, target_value: nn.Variable, training_variables: TrainingVariables + ) -> nn.Variable: distribution = 
model.pi(training_variables.s_current) log_prob = distribution.log_prob(training_variables.a_current) return NF.sum(-log_prob * target_value) * self._config.pi_loss_scalar @@ -111,9 +117,7 @@ def _setup_training_variables(self, batch_size) -> TrainingVariables: # Training input variables s_current_var = create_variable(batch_size, self._env_info.state_shape) a_current_var = create_variable(batch_size, self._env_info.action_shape) - return TrainingVariables(batch_size, - s_current_var, - a_current_var) + return TrainingVariables(batch_size, s_current_var, a_current_var) @property def loss_variables(self) -> Dict[str, nn.Variable]: diff --git a/nnabla_rl/model_trainers/policy/trpo_policy_trainer.py b/nnabla_rl/model_trainers/policy/trpo_policy_trainer.py index a149c8ea..86777c48 100644 --- a/nnabla_rl/model_trainers/policy/trpo_policy_trainer.py +++ b/nnabla_rl/model_trainers/policy/trpo_policy_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -81,8 +81,7 @@ def _update_network_params_by_flat_params(params, new_flat_params): for param in params.values(): param_shape = param.shape param_numbers = len(param.d.flatten()) - new_param = new_flat_params[total_param_numbers:total_param_numbers + - param_numbers].reshape(param_shape) + new_param = new_flat_params[total_param_numbers : total_param_numbers + param_numbers].reshape(param_shape) param.d = new_param total_param_numbers += param_numbers assert total_param_numbers == len(new_flat_params) @@ -98,10 +97,10 @@ class TRPOPolicyTrainerConfig(TrainerConfig): backtrack_coefficient: float = 0.5 def __post_init__(self): - self._assert_positive(self.sigma_kl_divergence_constraint, 'sigma_kl_divergence_constraint') - self._assert_positive(self.maximum_backtrack_numbers, 'maximum_backtrack_numbers') - self._assert_positive(self.conjugate_gradient_damping, 'conjugate_gradient_damping') - self._assert_positive(self.conjugate_gradient_iterations, 'conjugate_gradient_iterations') + self._assert_positive(self.sigma_kl_divergence_constraint, "sigma_kl_divergence_constraint") + self._assert_positive(self.maximum_backtrack_numbers, "maximum_backtrack_numbers") + self._assert_positive(self.conjugate_gradient_damping, "conjugate_gradient_damping") + self._assert_positive(self.conjugate_gradient_iterations, "conjugate_gradient_iterations") class TRPOPolicyTrainer(ModelTrainer): @@ -115,31 +114,33 @@ class TRPOPolicyTrainer(ModelTrainer): _kl_divergence_flat_grads: nn.Variable _old_policy: StochasticPolicy - def __init__(self, - model: StochasticPolicy, - env_info: EnvironmentInfo, - config: TRPOPolicyTrainerConfig = TRPOPolicyTrainerConfig()): + def __init__( + self, + model: StochasticPolicy, + env_info: EnvironmentInfo, + config: TRPOPolicyTrainerConfig = TRPOPolicyTrainerConfig(), + ): super(TRPOPolicyTrainer, self).__init__(model, {}, env_info, config) - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: s = batch.s_current a = batch.a_current - advantage = batch.extra['advantage'] + advantage = 
batch.extra["advantage"] policy = models[0] old_policy = self._old_policy - full_step_params_update = self._compute_full_step_params_update( - policy, s, a, advantage, training_variables) + full_step_params_update = self._compute_full_step_params_update(policy, s, a, advantage, training_variables) - self._linesearch_and_update_params(policy, s, a, advantage, - full_step_params_update, training_variables) + self._linesearch_and_update_params(policy, s, a, advantage, full_step_params_update, training_variables) copy_network_parameters(policy.get_parameters(), old_policy.get_parameters(), tau=1.0) @@ -158,11 +159,11 @@ def _gpu_batch_size(self, batch_size): def _build_training_graph(self, models: Sequence[Model], training_variables: TrainingVariables): if len(models) != 1: - raise RuntimeError('TRPO training only support 1 model for training') + raise RuntimeError("TRPO training only support 1 model for training") models = cast(Sequence[StochasticPolicy], models) policy = models[0] - if not hasattr(self, '_old_policy'): - self._old_policy = policy.deepcopy('old_policy') + if not hasattr(self, "_old_policy"): + self._old_policy = policy.deepcopy("old_policy") old_policy = self._old_policy # policy learning @@ -180,31 +181,35 @@ def _build_training_graph(self, models: Sequence[Model], training_variables: Tra old_log_prob = old_distribution.log_prob(training_variables.a_current) prob_ratio = NF.exp(log_prob - old_log_prob) - advantage = training_variables.extra['advantage'] - self._approximate_return = NF.mean(prob_ratio*advantage) + advantage = training_variables.extra["advantage"] + self._approximate_return = NF.mean(prob_ratio * advantage) _approximate_return_grads = nn.grad([self._approximate_return], policy.get_parameters().values()) self._approximate_return_flat_grads = NF.concatenate( - *[grad.reshape((-1,)) for grad in _approximate_return_grads]) + *[grad.reshape((-1,)) for grad in _approximate_return_grads] + ) self._approximate_return_flat_grads.need_grad = True copy_network_parameters(policy.get_parameters(), old_policy.get_parameters(), tau=1.0) def _compute_full_step_params_update(self, policy, s_batch, a_batch, adv_batch, training_variables): _, _, approximate_return_flat_grads = self._forward_all_variables( - s_batch, a_batch, adv_batch, training_variables) + s_batch, a_batch, adv_batch, training_variables + ) def fisher_vector_product_wrapper(step_direction): - return self._fisher_vector_product(policy, s_batch, a_batch, - step_direction, training_variables) + return self._fisher_vector_product(policy, s_batch, a_batch, step_direction, training_variables) step_direction = conjugate_gradient( - fisher_vector_product_wrapper, approximate_return_flat_grads, - max_iterations=self._config.conjugate_gradient_iterations) + fisher_vector_product_wrapper, + approximate_return_flat_grads, + max_iterations=self._config.conjugate_gradient_iterations, + ) fisher_vector_product = self._fisher_vector_product( - policy, s_batch, a_batch, step_direction, training_variables) + policy, s_batch, a_batch, step_direction, training_variables + ) sAs = float(np.dot(step_direction, fisher_vector_product)) # adding 1e-8 to avoid computational error @@ -220,34 +225,35 @@ def _fisher_vector_product(self, policy, s_batch, a_batch, vector, training_vari for block_index in range(total_blocks): start_idx = block_index * gpu_batch_size - set_data_to_variable(training_variables.s_current, s_batch[start_idx:start_idx+gpu_batch_size]) - set_data_to_variable(training_variables.a_current, 
a_batch[start_idx:start_idx+gpu_batch_size]) + set_data_to_variable(training_variables.s_current, s_batch[start_idx : start_idx + gpu_batch_size]) + set_data_to_variable(training_variables.a_current, a_batch[start_idx : start_idx + gpu_batch_size]) for param in policy.get_parameters().values(): param.grad.zero() self._kl_divergence_flat_grads.forward() - hessian_vector_product = _hessian_vector_product(self._kl_divergence_flat_grads, - policy.get_parameters().values(), - vector) + hessian_vector_product = _hessian_vector_product( + self._kl_divergence_flat_grads, policy.get_parameters().values(), vector + ) hessian_multiplied_vector = hessian_vector_product + self._config.conjugate_gradient_damping * vector sum_hessian_multiplied_vector += hessian_multiplied_vector return sum_hessian_multiplied_vector / total_blocks def _linesearch_and_update_params( - self, policy, s_batch, a_batch, adv_batch, full_step_params_update, training_variables): + self, policy, s_batch, a_batch, adv_batch, full_step_params_update, training_variables + ): current_flat_params = _concat_network_params_in_ndarray(policy.get_parameters()) - current_approximate_return, _, _ = self._forward_all_variables( - s_batch, a_batch, adv_batch, training_variables) + current_approximate_return, _, _ = self._forward_all_variables(s_batch, a_batch, adv_batch, training_variables) - for step_size in self._config.backtrack_coefficient**np.arange(self._config.maximum_backtrack_numbers): + for step_size in self._config.backtrack_coefficient ** np.arange(self._config.maximum_backtrack_numbers): new_flat_params = current_flat_params + step_size * full_step_params_update _update_network_params_by_flat_params(policy.get_parameters(), new_flat_params) approximate_return, kl_divergence, _ = self._forward_all_variables( - s_batch, a_batch, adv_batch, training_variables) + s_batch, a_batch, adv_batch, training_variables + ) - improved = approximate_return - current_approximate_return > 0. 
+ improved = approximate_return - current_approximate_return > 0.0 is_in_kl_divergence_constraint = kl_divergence < self._config.sigma_kl_divergence_constraint if improved and is_in_kl_divergence_constraint: @@ -271,13 +277,11 @@ def _forward_all_variables(self, s_batch, a_batch, adv_batch, training_variables for block_index in range(total_blocks): start_idx = block_index * gpu_batch_size - set_data_to_variable(training_variables.s_current, s_batch[start_idx:start_idx+gpu_batch_size]) - set_data_to_variable(training_variables.a_current, a_batch[start_idx:start_idx+gpu_batch_size]) - training_variables.extra['advantage'].d = adv_batch[start_idx:start_idx+gpu_batch_size] + set_data_to_variable(training_variables.s_current, s_batch[start_idx : start_idx + gpu_batch_size]) + set_data_to_variable(training_variables.a_current, a_batch[start_idx : start_idx + gpu_batch_size]) + training_variables.extra["advantage"].d = adv_batch[start_idx : start_idx + gpu_batch_size] - nn.forward_all([self._approximate_return, - self._kl_divergence, - self._approximate_return_flat_grads]) + nn.forward_all([self._approximate_return, self._kl_divergence, self._approximate_return_flat_grads]) sum_approximate_return += float(self._approximate_return.d) sum_kl_divergence += float(self._kl_divergence.d) @@ -294,7 +298,7 @@ def _setup_training_variables(self, batch_size: int) -> TrainingVariables: a_current_var = create_variable(gpu_batch_size, self._env_info.action_shape) advantage_var = create_variable(gpu_batch_size, 1) extra = {} - extra['advantage'] = advantage_var + extra["advantage"] = advantage_var return TrainingVariables(gpu_batch_size, s_current_var, a_current_var, extra=extra) @property diff --git a/nnabla_rl/model_trainers/policy/xql_forward_policy_trainer.py b/nnabla_rl/model_trainers/policy/xql_forward_policy_trainer.py index 28bb666c..58c8eea8 100644 --- a/nnabla_rl/model_trainers/policy/xql_forward_policy_trainer.py +++ b/nnabla_rl/model_trainers/policy/xql_forward_policy_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
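# A minimal NumPy sketch of the natural-gradient step that TRPOPolicyTrainer builds on
# nnabla graphs above: solve F x = g with conjugate gradient (only Fisher-vector
# products are needed), scale the step to the KL trust region, then backtrack until the
# surrogate return improves and the KL constraint holds. `fisher_vector_product` and
# `evaluate` are hypothetical stand-ins for the trainer's _fisher_vector_product and
# _forward_all_variables, not nnabla-rl APIs.
import numpy as np


def conjugate_gradient(fisher_vector_product, b, max_iterations=10, tol=1e-10):
    x = np.zeros_like(b)
    r = b.copy()  # residual of F x = b, starting from x = 0
    p = r.copy()
    rs_old = float(r.dot(r))
    for _ in range(max_iterations):
        Fp = fisher_vector_product(p)
        alpha = rs_old / (float(p.dot(Fp)) + 1e-10)
        x += alpha * p
        r -= alpha * Fp
        rs_new = float(r.dot(r))
        if rs_new < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


def linesearch_and_update(flat_params, full_step, evaluate, kl_limit,
                          backtrack_coefficient=0.5, maximum_backtrack_numbers=10):
    # evaluate(params) -> (approximate_return, kl_divergence) on the training batch.
    # The full step is typically step_direction * sqrt(2 * delta / (s^T F s + 1e-8)).
    current_return, _ = evaluate(flat_params)
    for step_size in backtrack_coefficient ** np.arange(maximum_backtrack_numbers):
        candidate = flat_params + step_size * full_step
        new_return, kl = evaluate(candidate)
        if new_return - current_return > 0.0 and kl < kl_limit:
            return candidate  # largest step that improves the return inside the trust region
    return flat_params  # no acceptable step found: keep the current parameters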
@@ -21,8 +21,14 @@ import nnabla.functions as NF import nnabla_rl.functions as RNF from nnabla_rl.environments.environment_info import EnvironmentInfo -from nnabla_rl.model_trainers.model_trainer import (LossIntegration, ModelTrainer, TrainerConfig, TrainingBatch, - TrainingVariables, rnn_support) +from nnabla_rl.model_trainers.model_trainer import ( + LossIntegration, + ModelTrainer, + TrainerConfig, + TrainingBatch, + TrainingVariables, + rnn_support, +) from nnabla_rl.models import Model, QFunction, StochasticPolicy, VFunction from nnabla_rl.utils.data import convert_to_list_if_not_list, set_data_to_variable from nnabla_rl.utils.misc import create_variable, create_variables @@ -39,6 +45,7 @@ def __post_init__(self): class XQLForwardPolicyTrainer(ModelTrainer): """EXtreme Q-learning style (w/ forward KL-divergence) Policy Trainer.""" + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See https://mypy.readthedocs.io/en/stable/class_basics.html for details @@ -50,15 +57,17 @@ class XQLForwardPolicyTrainer(ModelTrainer): _prev_q_rnn_states: Dict[str, Dict[str, Dict[str, nn.Variable]]] _prev_v_rnn_states: Dict[str, Dict[str, Dict[str, nn.Variable]]] - def __init__(self, - models: Union[StochasticPolicy, Sequence[StochasticPolicy]], - solvers: Dict[str, nn.solver.Solver], - q_functions: Sequence[QFunction], - v_function: VFunction, - env_info: EnvironmentInfo, - config: XQLForwardPolicyTrainerConfig = XQLForwardPolicyTrainerConfig()): + def __init__( + self, + models: Union[StochasticPolicy, Sequence[StochasticPolicy]], + solvers: Dict[str, nn.solver.Solver], + q_functions: Sequence[QFunction], + v_function: VFunction, + env_info: EnvironmentInfo, + config: XQLForwardPolicyTrainerConfig = XQLForwardPolicyTrainerConfig(), + ): if len(q_functions) < 2: - raise ValueError('Must provide at least 2 Q-functions for training') + raise ValueError("Must provide at least 2 Q-functions for training") self._q_functions = q_functions self._v_function = v_function @@ -73,12 +82,14 @@ def __init__(self, def support_rnn(self) -> bool: return True - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) set_data_to_variable(t.a_current, b.a_current) @@ -122,12 +133,10 @@ def _update_model(self, for solver in solvers.values(): solver.update() trainer_state = {} - trainer_state['pi_loss'] = self._pi_loss.d.copy() + trainer_state["pi_loss"] = self._pi_loss.d.copy() return trainer_state - def _build_training_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables): + def _build_training_graph(self, models: Sequence[Model], training_variables: TrainingVariables): self._pi_loss = 0 ignore_intermediate_loss = self._config.loss_integration is LossIntegration.LAST_TIMESTEP_ONLY for step_index, variables in enumerate(training_variables): @@ -136,10 +145,7 @@ def _build_training_graph(self, ignore_loss = is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._pi_loss += self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - def 
_build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool): train_rnn_states = training_variables.rnn_states for policy in models: assert isinstance(policy, StochasticPolicy) @@ -192,11 +198,9 @@ def _setup_training_variables(self, batch_size): rnn_state_variables = create_variables(batch_size, shapes) rnn_states[self._v_function.scope_name] = rnn_state_variables - return TrainingVariables(batch_size, - s_current_var, - a_current_var, - non_terminal=non_terminal_var, - rnn_states=rnn_states) + return TrainingVariables( + batch_size, s_current_var, a_current_var, non_terminal=non_terminal_var, rnn_states=rnn_states + ) @property def loss_variables(self) -> Dict[str, nn.Variable]: diff --git a/nnabla_rl/model_trainers/policy/xql_reverse_policy_trainer.py b/nnabla_rl/model_trainers/policy/xql_reverse_policy_trainer.py index f13184b4..316bb03c 100644 --- a/nnabla_rl/model_trainers/policy/xql_reverse_policy_trainer.py +++ b/nnabla_rl/model_trainers/policy/xql_reverse_policy_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,8 +20,14 @@ import nnabla as nn import nnabla_rl.functions as RNF from nnabla_rl.environments.environment_info import EnvironmentInfo -from nnabla_rl.model_trainers.model_trainer import (LossIntegration, ModelTrainer, TrainerConfig, TrainingBatch, - TrainingVariables, rnn_support) +from nnabla_rl.model_trainers.model_trainer import ( + LossIntegration, + ModelTrainer, + TrainerConfig, + TrainingBatch, + TrainingVariables, + rnn_support, +) from nnabla_rl.models import Model, QFunction, StochasticPolicy from nnabla_rl.utils.data import convert_to_list_if_not_list, set_data_to_variable from nnabla_rl.utils.misc import create_variable, create_variables @@ -37,6 +43,7 @@ def __post_init__(self): class XQLReversePolicyTrainer(ModelTrainer): """EXtreme Q-learning style (w/ reverse KL-divergence) Policy Trainer.""" + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See https://mypy.readthedocs.io/en/stable/class_basics.html for details @@ -47,15 +54,17 @@ class XQLReversePolicyTrainer(ModelTrainer): _prev_behavior_policy_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_q_rnn_states: Dict[str, Dict[str, Dict[str, nn.Variable]]] - def __init__(self, - models: Union[StochasticPolicy, Sequence[StochasticPolicy]], - solvers: Dict[str, nn.solver.Solver], - behavior_policy: StochasticPolicy, - q_functions: Sequence[QFunction], - env_info: EnvironmentInfo, - config: XQLReversePolicyTrainerConfig = XQLReversePolicyTrainerConfig()): + def __init__( + self, + models: Union[StochasticPolicy, Sequence[StochasticPolicy]], + solvers: Dict[str, nn.solver.Solver], + behavior_policy: StochasticPolicy, + q_functions: Sequence[QFunction], + env_info: EnvironmentInfo, + config: XQLReversePolicyTrainerConfig = XQLReversePolicyTrainerConfig(), + ): if len(q_functions) < 2: - raise ValueError('Must provide at least 2 Q-functions for training') + raise ValueError("Must provide at least 2 Q-functions for training") self._behavior_policy = behavior_policy self._q_functions = q_functions @@ -70,12 +79,14 @@ def __init__(self, 
def support_rnn(self) -> bool: return True - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) set_data_to_variable(t.a_current, b.a_current) @@ -119,12 +130,10 @@ def _update_model(self, for solver in solvers.values(): solver.update() trainer_state = {} - trainer_state['pi_loss'] = self._pi_loss.d.copy() + trainer_state["pi_loss"] = self._pi_loss.d.copy() return trainer_state - def _build_training_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables): + def _build_training_graph(self, models: Sequence[Model], training_variables: TrainingVariables): self._pi_loss = 0 ignore_intermediate_loss = self._config.loss_integration is LossIntegration.LAST_TIMESTEP_ONLY for step_index, variables in enumerate(training_variables): @@ -133,10 +142,7 @@ def _build_training_graph(self, ignore_loss = is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._pi_loss += self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - def _build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool): train_rnn_states = training_variables.rnn_states for policy in models: assert isinstance(policy, StochasticPolicy) @@ -147,11 +153,9 @@ def _build_one_step_graph(self, action, log_pi = policy_distribution.sample_and_compute_log_prob() prev_rnn_states = self._prev_behavior_policy_rnn_states[policy.scope_name] - with rnn_support(self._behavior_policy, - prev_rnn_states, - train_rnn_states, - training_variables, - self._config): + with rnn_support( + self._behavior_policy, prev_rnn_states, train_rnn_states, training_variables, self._config + ): behavior_distribution = self._behavior_policy.pi(training_variables.s_current) log_mu = behavior_distribution.log_prob(action) @@ -186,10 +190,7 @@ def _setup_training_variables(self, batch_size): rnn_state_variables = create_variables(batch_size, shapes) rnn_states[self._behavior_policy.scope_name] = rnn_state_variables - return TrainingVariables(batch_size, - s_current_var, - non_terminal=non_terminal_var, - rnn_states=rnn_states) + return TrainingVariables(batch_size, s_current_var, non_terminal=non_terminal_var, rnn_states=rnn_states) @property def loss_variables(self) -> Dict[str, nn.Variable]: diff --git a/nnabla_rl/model_trainers/q_value/__init__.py b/nnabla_rl/model_trainers/q_value/__init__.py index a7a12c71..46fea33a 100644 --- a/nnabla_rl/model_trainers/q_value/__init__.py +++ b/nnabla_rl/model_trainers/q_value/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,39 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from nnabla_rl.model_trainers.q_value.bcq_q_trainer import ( # noqa - BCQQTrainer, BCQQTrainerConfig) +from nnabla_rl.model_trainers.q_value.bcq_q_trainer import BCQQTrainer, BCQQTrainerConfig # noqa from nnabla_rl.model_trainers.q_value.categorical_dqn_q_trainer import ( # noqa - CategoricalDQNQTrainer, CategoricalDQNQTrainerConfig) + CategoricalDQNQTrainer, + CategoricalDQNQTrainerConfig, +) from nnabla_rl.model_trainers.q_value.categorical_ddqn_q_trainer import ( # noqa - CategoricalDDQNQTrainer, CategoricalDDQNQTrainerConfig) + CategoricalDDQNQTrainer, + CategoricalDDQNQTrainerConfig, +) from nnabla_rl.model_trainers.q_value.clipped_double_q_trainer import ( # noqa - ClippedDoubleQTrainer, ClippedDoubleQTrainerConfig) -from nnabla_rl.model_trainers.q_value.ddpg_q_trainer import ( # noqa - DDPGQTrainer, DDPGQTrainerConfig) -from nnabla_rl.model_trainers.q_value.ddqn_q_trainer import ( # noqa - DDQNQTrainer, DDQNQTrainerConfig) -from nnabla_rl.model_trainers.q_value.dqn_q_trainer import ( # noqa - DQNQTrainer, DQNQTrainerConfig) -from nnabla_rl.model_trainers.q_value.her_q_trainer import ( # noqa - HERQTrainer, HERQTrainerConfig) -from nnabla_rl.model_trainers.q_value.hyar_q_trainer import ( # noqa - HyARQTrainer, HyARQTrainerConfig) -from nnabla_rl.model_trainers.q_value.iqn_q_trainer import ( # noqa - IQNQTrainer, IQNQTrainerConfig) + ClippedDoubleQTrainer, + ClippedDoubleQTrainerConfig, +) +from nnabla_rl.model_trainers.q_value.ddpg_q_trainer import DDPGQTrainer, DDPGQTrainerConfig # noqa +from nnabla_rl.model_trainers.q_value.ddqn_q_trainer import DDQNQTrainer, DDQNQTrainerConfig # noqa +from nnabla_rl.model_trainers.q_value.dqn_q_trainer import DQNQTrainer, DQNQTrainerConfig # noqa +from nnabla_rl.model_trainers.q_value.her_q_trainer import HERQTrainer, HERQTrainerConfig # noqa +from nnabla_rl.model_trainers.q_value.hyar_q_trainer import HyARQTrainer, HyARQTrainerConfig # noqa +from nnabla_rl.model_trainers.q_value.iqn_q_trainer import IQNQTrainer, IQNQTrainerConfig # noqa from nnabla_rl.model_trainers.q_value.munchausen_rl_q_trainer import ( # noqa - MunchausenIQNQTrainer, MunchausenIQNQTrainerConfig, MunchausenDQNQTrainer, MunchausenDQNQTrainerConfig) -from nnabla_rl.model_trainers.q_value.qrdqn_q_trainer import ( # noqa - QRDQNQTrainer, QRDQNQTrainerConfig) -from nnabla_rl.model_trainers.q_value.qrsac_q_trainer import ( # noqa - QRSACQTrainer, QRSACQTrainerConfig) -from nnabla_rl.model_trainers.q_value.redq_q_trainer import ( # noqa - REDQQTrainer, REDQQTrainerConfig) -from nnabla_rl.model_trainers.q_value.soft_q_trainer import ( # noqa - SoftQTrainer, SoftQTrainerConfig) -from nnabla_rl.model_trainers.q_value.soft_q_decomposition_trainer import ( # noqa - SoftQDTrainer, SoftQDTrainerConfig) -from nnabla_rl.model_trainers.q_value.td3_q_trainer import ( # noqa - TD3QTrainer, TD3QTrainerConfig) -from nnabla_rl.model_trainers.q_value.v_targeted_q_trainer import ( # noqa - VTargetedQTrainer, VTargetedQTrainerConfig) + MunchausenIQNQTrainer, + MunchausenIQNQTrainerConfig, + MunchausenDQNQTrainer, + MunchausenDQNQTrainerConfig, +) +from nnabla_rl.model_trainers.q_value.qrdqn_q_trainer import QRDQNQTrainer, QRDQNQTrainerConfig # noqa +from nnabla_rl.model_trainers.q_value.qrsac_q_trainer import QRSACQTrainer, QRSACQTrainerConfig # noqa +from nnabla_rl.model_trainers.q_value.redq_q_trainer import REDQQTrainer, REDQQTrainerConfig # noqa +from nnabla_rl.model_trainers.q_value.soft_q_trainer import SoftQTrainer, SoftQTrainerConfig # noqa +from 
nnabla_rl.model_trainers.q_value.soft_q_decomposition_trainer import SoftQDTrainer, SoftQDTrainerConfig # noqa +from nnabla_rl.model_trainers.q_value.td3_q_trainer import TD3QTrainer, TD3QTrainerConfig # noqa +from nnabla_rl.model_trainers.q_value.v_targeted_q_trainer import VTargetedQTrainer, VTargetedQTrainerConfig # noqa diff --git a/nnabla_rl/model_trainers/q_value/bcq_q_trainer.py b/nnabla_rl/model_trainers/q_value/bcq_q_trainer.py index d81b532d..b6fd417c 100644 --- a/nnabla_rl/model_trainers/q_value/bcq_q_trainer.py +++ b/nnabla_rl/model_trainers/q_value/bcq_q_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,8 +20,10 @@ import nnabla_rl.functions as RNF from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingVariables, rnn_support -from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import (SquaredTDQFunctionTrainer, - SquaredTDQFunctionTrainerConfig) +from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import ( + SquaredTDQFunctionTrainer, + SquaredTDQFunctionTrainerConfig, +) from nnabla_rl.models import DeterministicPolicy, QFunction from nnabla_rl.utils.data import convert_to_list_if_not_list from nnabla_rl.utils.misc import create_variables @@ -42,13 +44,15 @@ class BCQQTrainer(SquaredTDQFunctionTrainer): _prev_target_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_q_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_functions: Union[QFunction, Sequence[QFunction]], - solvers: Dict[str, nn.solver.Solver], - target_functions: Union[QFunction, Sequence[QFunction]], - target_policy: DeterministicPolicy, - env_info: EnvironmentInfo, - config: BCQQTrainerConfig = BCQQTrainerConfig()): + def __init__( + self, + train_functions: Union[QFunction, Sequence[QFunction]], + solvers: Dict[str, nn.solver.Solver], + target_functions: Union[QFunction, Sequence[QFunction]], + target_policy: DeterministicPolicy, + env_info: EnvironmentInfo, + config: BCQQTrainerConfig = BCQQTrainerConfig(), + ): self._target_functions = convert_to_list_if_not_list(target_functions) self._assert_no_duplicate_model(self._target_functions) self._target_policy = target_policy @@ -84,8 +88,9 @@ def _compute_target(self, training_variables: TrainingVariables) -> nn.Variable: num_q_ensembles = len(self._target_functions) assert isinstance(q_values, nn.Variable) assert q_values.shape == (num_q_ensembles, batch_size * self._config.num_action_samples, 1) - weighted_q_minmax = self._config.lmb * NF.min(q_values, axis=0) + \ - (1.0 - self._config.lmb) * NF.max(q_values, axis=0) + weighted_q_minmax = self._config.lmb * NF.min(q_values, axis=0) + (1.0 - self._config.lmb) * NF.max( + q_values, axis=0 + ) assert weighted_q_minmax.shape == (batch_size * self._config.num_action_samples, 1) next_q_value = NF.max(NF.reshape(weighted_q_minmax, shape=(batch_size, -1)), axis=1, keepdims=True) diff --git a/nnabla_rl/model_trainers/q_value/categorical_ddqn_q_trainer.py b/nnabla_rl/model_trainers/q_value/categorical_ddqn_q_trainer.py index 7cd141de..116d98e9 100644 --- a/nnabla_rl/model_trainers/q_value/categorical_ddqn_q_trainer.py +++ b/nnabla_rl/model_trainers/q_value/categorical_ddqn_q_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. 
+# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,8 +21,10 @@ import nnabla.functions as NF from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingVariables, rnn_support -from nnabla_rl.model_trainers.q_value.categorical_dqn_q_trainer import (CategoricalDQNQTrainer, - CategoricalDQNQTrainerConfig) +from nnabla_rl.model_trainers.q_value.categorical_dqn_q_trainer import ( + CategoricalDQNQTrainer, + CategoricalDQNQTrainerConfig, +) from nnabla_rl.models import ValueDistributionFunction from nnabla_rl.utils.misc import create_variables @@ -40,12 +42,14 @@ class CategoricalDDQNQTrainer(CategoricalDQNQTrainer): _prev_train_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_target_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_function: ValueDistributionFunction, - solvers: Dict[str, nn.solver.Solver], - target_function: ValueDistributionFunction, - env_info: EnvironmentInfo, - config: CategoricalDDQNQTrainerConfig = CategoricalDDQNQTrainerConfig()): + def __init__( + self, + train_function: ValueDistributionFunction, + solvers: Dict[str, nn.solver.Solver], + target_function: ValueDistributionFunction, + env_info: EnvironmentInfo, + config: CategoricalDDQNQTrainerConfig = CategoricalDDQNQTrainerConfig(), + ): self._train_function = train_function self._target_function = target_function self._prev_train_rnn_states = {} @@ -68,7 +72,7 @@ def _compute_target(self, training_variables: TrainingVariables, **kwargs) -> nn prev_rnn_states = self._prev_train_rnn_states train_rnn_states = training_variables.rnn_states - with rnn_support(self._train_function, prev_rnn_states, train_rnn_states, training_variables, self._config): + with rnn_support(self._train_function, prev_rnn_states, train_rnn_states, training_variables, self._config): a_next = self._train_function.as_q_function().argmax_q(s_next) prev_rnn_states = self._prev_target_rnn_states diff --git a/nnabla_rl/model_trainers/q_value/categorical_dqn_q_trainer.py b/nnabla_rl/model_trainers/q_value/categorical_dqn_q_trainer.py index 86dcd13d..e85a989d 100644 --- a/nnabla_rl/model_trainers/q_value/categorical_dqn_q_trainer.py +++ b/nnabla_rl/model_trainers/q_value/categorical_dqn_q_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
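# A small NumPy sketch of the action selection CategoricalDDQNQTrainer performs above:
# a value distribution assigns probabilities to fixed atoms z_i, the induced Q-value is
# the expectation over those atoms, and (Double-DQN style) the greedy action comes from
# the online distribution while the target distribution supplies the mass to project.
# The shapes and the fake_distribution helper are assumptions for illustration only.
import numpy as np

v_min, v_max, n_atoms = -10.0, 10.0, 51
z = np.linspace(v_min, v_max, n_atoms)            # fixed support (atoms)

rng = np.random.default_rng(0)

def fake_distribution(n_actions):                 # placeholder for a network's softmax output
    logits = rng.normal(size=(n_actions, n_atoms))
    p = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return p / p.sum(axis=-1, keepdims=True)

p_online = fake_distribution(n_actions=4)         # p_online[a, i] = P(Z = z_i | s', a), online net
p_target = fake_distribution(n_actions=4)         # same, from the target net

q_online = (p_online * z).sum(axis=-1)            # Q(s', a) = sum_i z_i * p_i(s', a)
a_next = int(np.argmax(q_online))                 # greedy action from the online network
pj = p_target[a_next]                             # target distribution to be projected onto z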
@@ -22,7 +22,9 @@ from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingVariables, rnn_support from nnabla_rl.model_trainers.q_value.value_distribution_function_trainer import ( - ValueDistributionFunctionTrainer, ValueDistributionFunctionTrainerConfig) + ValueDistributionFunctionTrainer, + ValueDistributionFunctionTrainerConfig, +) from nnabla_rl.models import ValueDistributionFunction from nnabla_rl.utils.misc import create_variables @@ -39,12 +41,14 @@ class CategoricalDQNQTrainer(ValueDistributionFunctionTrainer): _target_function: ValueDistributionFunction _prev_target_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_functions: Union[ValueDistributionFunction, Sequence[ValueDistributionFunction]], - solvers: Dict[str, nn.solver.Solver], - target_function: ValueDistributionFunction, - env_info: EnvironmentInfo, - config: CategoricalDQNQTrainerConfig = CategoricalDQNQTrainerConfig()): + def __init__( + self, + train_functions: Union[ValueDistributionFunction, Sequence[ValueDistributionFunction]], + solvers: Dict[str, nn.solver.Solver], + target_function: ValueDistributionFunction, + env_info: EnvironmentInfo, + config: CategoricalDQNQTrainerConfig = CategoricalDQNQTrainerConfig(), + ): self._target_function = target_function self._prev_target_rnn_states = {} super(CategoricalDQNQTrainer, self).__init__(train_functions, solvers, env_info, config) @@ -100,7 +104,7 @@ def _compute_projection(self, Tz, pj, N, v_max, v_min): result_upper = NF.scatter_add(mi, ml_indices, pj * (upper - bj), axis=-1) result_lower = NF.scatter_add(mi, mu_indices, pj * (bj - lower), axis=-1) - return (result_upper + result_lower) + return result_upper + result_lower def _setup_training_variables(self, batch_size: int) -> TrainingVariables: training_variables = super()._setup_training_variables(batch_size) diff --git a/nnabla_rl/model_trainers/q_value/clipped_double_q_trainer.py b/nnabla_rl/model_trainers/q_value/clipped_double_q_trainer.py index f905e804..36483480 100644 --- a/nnabla_rl/model_trainers/q_value/clipped_double_q_trainer.py +++ b/nnabla_rl/model_trainers/q_value/clipped_double_q_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
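# A NumPy sketch of the categorical projection _compute_projection performs above with
# NF.scatter_add: the shifted atoms Tz = r + gamma * (1 - done) * z are clipped to
# [v_min, v_max] and each probability pj is split between the two neighbouring atoms in
# proportion to distance. A single transition is projected to keep the shapes flat; the
# explicit handling of exact atom hits is an assumption of this sketch.
import numpy as np

v_min, v_max, n_atoms = -10.0, 10.0, 51
z = np.linspace(v_min, v_max, n_atoms)             # fixed support
pj = np.full(n_atoms, 1.0 / n_atoms)               # target distribution of the chosen action

def project_categorical(Tz, pj, n_atoms, v_max, v_min):
    delta_z = (v_max - v_min) / (n_atoms - 1)
    Tz = np.clip(Tz, v_min, v_max)
    bj = (Tz - v_min) / delta_z                    # fractional atom index in [0, n_atoms - 1]
    lower = np.floor(bj).astype(int)
    upper = np.ceil(bj).astype(int)
    m = np.zeros(n_atoms)
    np.add.at(m, lower, pj * (upper - bj))         # mass assigned to the lower neighbour
    np.add.at(m, upper, pj * (bj - lower))         # mass assigned to the upper neighbour
    np.add.at(m, lower, pj * (lower == upper))     # Tz hit an atom exactly: keep its full mass
    return m

reward, gamma, non_terminal = 1.0, 0.99, 1.0
m = project_categorical(reward + gamma * non_terminal * z, pj, n_atoms, v_max, v_min)
assert np.isclose(m.sum(), 1.0)                    # the projection preserves probability mass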
@@ -19,8 +19,10 @@ import nnabla_rl.functions as RNF from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingVariables, rnn_support -from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import (SquaredTDQFunctionTrainer, - SquaredTDQFunctionTrainerConfig) +from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import ( + SquaredTDQFunctionTrainer, + SquaredTDQFunctionTrainerConfig, +) from nnabla_rl.models import QFunction from nnabla_rl.utils.data import convert_to_list_if_not_list from nnabla_rl.utils.misc import create_variables @@ -39,14 +41,16 @@ class ClippedDoubleQTrainer(SquaredTDQFunctionTrainer): _prev_q0_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_q_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_functions: Union[QFunction, Sequence[QFunction]], - solvers: Dict[str, nn.solver.Solver], - target_functions: Sequence[QFunction], - env_info: EnvironmentInfo, - config: ClippedDoubleQTrainerConfig = ClippedDoubleQTrainerConfig()): + def __init__( + self, + train_functions: Union[QFunction, Sequence[QFunction]], + solvers: Dict[str, nn.solver.Solver], + target_functions: Sequence[QFunction], + env_info: EnvironmentInfo, + config: ClippedDoubleQTrainerConfig = ClippedDoubleQTrainerConfig(), + ): if len(target_functions) < 2: - raise ValueError('Must have at least 2 target functions for training') + raise ValueError("Must have at least 2 target functions for training") self._target_functions = convert_to_list_if_not_list(target_functions) self._assert_no_duplicate_model(self._target_functions) self._prev_q0_rnn_states = {} diff --git a/nnabla_rl/model_trainers/q_value/ddpg_q_trainer.py b/nnabla_rl/model_trainers/q_value/ddpg_q_trainer.py index 903aa407..b25423c6 100644 --- a/nnabla_rl/model_trainers/q_value/ddpg_q_trainer.py +++ b/nnabla_rl/model_trainers/q_value/ddpg_q_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
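# A hedged sketch of the clipped double-Q target that ClippedDoubleQTrainer above keeps
# two or more target networks for: the TD target uses the elementwise minimum over the
# ensemble of target critics to curb overestimation. `target_q_functions` is a
# hypothetical list of callables standing in for the QFunction models; the target
# expression is standard clipped double Q-learning, not copied from the library.
import numpy as np

def clipped_double_q_target(reward, gamma, non_terminal, s_next, a_next, target_q_functions):
    q_values = np.stack([q(s_next, a_next) for q in target_q_functions], axis=0)
    return reward + gamma * non_terminal * q_values.min(axis=0)

# usage with dummy critics and a batch of 4 transitions
q1 = lambda s, a: np.full((s.shape[0], 1), 1.0)
q2 = lambda s, a: np.full((s.shape[0], 1), 1.5)
s_next, a_next = np.zeros((4, 3)), np.zeros((4, 1))
y = clipped_double_q_target(np.ones((4, 1)), 0.99, np.ones((4, 1)), s_next, a_next, [q1, q2])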
@@ -19,8 +19,10 @@ import nnabla_rl.functions as RNF from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingVariables, rnn_support -from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import (SquaredTDQFunctionTrainer, - SquaredTDQFunctionTrainerConfig) +from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import ( + SquaredTDQFunctionTrainer, + SquaredTDQFunctionTrainerConfig, +) from nnabla_rl.models import DeterministicPolicy, QFunction from nnabla_rl.utils.data import convert_to_list_if_not_list from nnabla_rl.utils.misc import create_variables @@ -40,13 +42,15 @@ class DDPGQTrainer(SquaredTDQFunctionTrainer): _prev_target_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_q_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_functions: Union[QFunction, Sequence[QFunction]], - solvers: Dict[str, nn.solver.Solver], - target_functions: Union[QFunction, Sequence[QFunction]], - target_policy: DeterministicPolicy, - env_info: EnvironmentInfo, - config: DDPGQTrainerConfig = DDPGQTrainerConfig()): + def __init__( + self, + train_functions: Union[QFunction, Sequence[QFunction]], + solvers: Dict[str, nn.solver.Solver], + target_functions: Union[QFunction, Sequence[QFunction]], + target_policy: DeterministicPolicy, + env_info: EnvironmentInfo, + config: DDPGQTrainerConfig = DDPGQTrainerConfig(), + ): self._target_policy = target_policy self._target_functions = convert_to_list_if_not_list(target_functions) self._assert_no_duplicate_model(self._target_functions) diff --git a/nnabla_rl/model_trainers/q_value/ddqn_q_trainer.py b/nnabla_rl/model_trainers/q_value/ddqn_q_trainer.py index 02cd8e72..0aafd1e1 100644 --- a/nnabla_rl/model_trainers/q_value/ddqn_q_trainer.py +++ b/nnabla_rl/model_trainers/q_value/ddqn_q_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,8 +18,10 @@ import nnabla as nn from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingVariables, rnn_support -from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import (SquaredTDQFunctionTrainer, - SquaredTDQFunctionTrainerConfig) +from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import ( + SquaredTDQFunctionTrainer, + SquaredTDQFunctionTrainerConfig, +) from nnabla_rl.models import QFunction from nnabla_rl.utils.misc import create_variables @@ -38,12 +40,14 @@ class DDQNQTrainer(SquaredTDQFunctionTrainer): _prev_train_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_target_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_function: QFunction, - solvers: Dict[str, nn.solver.Solver], - target_function: QFunction, - env_info: EnvironmentInfo, - config: DDQNQTrainerConfig = DDQNQTrainerConfig()): + def __init__( + self, + train_function: QFunction, + solvers: Dict[str, nn.solver.Solver], + target_function: QFunction, + env_info: EnvironmentInfo, + config: DDQNQTrainerConfig = DDQNQTrainerConfig(), + ): self._train_function = train_function self._target_function = target_function self._prev_train_rnn_states = {} @@ -69,7 +73,7 @@ def _compute_target(self, training_variables: TrainingVariables, **kwargs) -> nn a_next = self._train_function.argmax_q(s_next) prev_rnn_states = self._prev_target_rnn_states - with rnn_support(self._target_function, prev_rnn_states, train_rnn_states, training_variables, self._config): + with rnn_support(self._target_function, prev_rnn_states, train_rnn_states, training_variables, self._config): double_q_target = self._target_function.q(s_next, a_next) return reward + gamma * non_terminal * double_q_target diff --git a/nnabla_rl/model_trainers/q_value/dqn_q_trainer.py b/nnabla_rl/model_trainers/q_value/dqn_q_trainer.py index db12390d..9edbead6 100644 --- a/nnabla_rl/model_trainers/q_value/dqn_q_trainer.py +++ b/nnabla_rl/model_trainers/q_value/dqn_q_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
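# A NumPy sketch of the Double-DQN target DDQNQTrainer computes above, next to the plain
# DQN target used by DQNQTrainer below: Double DQN selects the greedy action with the
# online Q-values but evaluates it with the target Q-values. The q_online / q_target
# arrays of shape (batch, n_actions) are made up for illustration.
import numpy as np

rng = np.random.default_rng(0)
batch, n_actions = 4, 6
q_online = rng.normal(size=(batch, n_actions))     # Q(s', .) from the train network
q_target = rng.normal(size=(batch, n_actions))     # Q(s', .) from the target network
reward, gamma, non_terminal = np.ones((batch, 1)), 0.99, np.ones((batch, 1))

# DQN: action selection and evaluation both use the target network
dqn_target = reward + gamma * non_terminal * q_target.max(axis=1, keepdims=True)

# Double DQN: select with the online network, evaluate with the target network
a_next = q_online.argmax(axis=1)
double_q = q_target[np.arange(batch), a_next][:, np.newaxis]
ddqn_target = reward + gamma * non_terminal * double_q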
@@ -18,8 +18,10 @@ import nnabla as nn from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingVariables, rnn_support -from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import (SquaredTDQFunctionTrainer, - SquaredTDQFunctionTrainerConfig) +from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import ( + SquaredTDQFunctionTrainer, + SquaredTDQFunctionTrainerConfig, +) from nnabla_rl.models import QFunction from nnabla_rl.utils.misc import create_variables @@ -36,12 +38,14 @@ class DQNQTrainer(SquaredTDQFunctionTrainer): _target_function: QFunction _prev_target_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_functions: Union[QFunction, Sequence[QFunction]], - solvers: Dict[str, nn.solver.Solver], - target_function: QFunction, - env_info: EnvironmentInfo, - config: DQNQTrainerConfig = DQNQTrainerConfig()): + def __init__( + self, + train_functions: Union[QFunction, Sequence[QFunction]], + solvers: Dict[str, nn.solver.Solver], + target_function: QFunction, + env_info: EnvironmentInfo, + config: DQNQTrainerConfig = DQNQTrainerConfig(), + ): self._target_function = target_function self._prev_target_rnn_states = {} super(DQNQTrainer, self).__init__(train_functions, solvers, env_info, config) diff --git a/nnabla_rl/model_trainers/q_value/her_q_trainer.py b/nnabla_rl/model_trainers/q_value/her_q_trainer.py index 7ce3b83c..dc1d9120 100644 --- a/nnabla_rl/model_trainers/q_value/her_q_trainer.py +++ b/nnabla_rl/model_trainers/q_value/her_q_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -20,8 +20,10 @@ import nnabla_rl.functions as RNF from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingVariables, rnn_support -from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import (SquaredTDQFunctionTrainer, - SquaredTDQFunctionTrainerConfig) +from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import ( + SquaredTDQFunctionTrainer, + SquaredTDQFunctionTrainerConfig, +) from nnabla_rl.models import DeterministicPolicy, QFunction from nnabla_rl.utils.data import convert_to_list_if_not_list from nnabla_rl.utils.misc import create_variables @@ -42,13 +44,15 @@ class HERQTrainer(SquaredTDQFunctionTrainer): _prev_target_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_q_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_functions: Union[QFunction, Sequence[QFunction]], - solvers: Dict[str, nn.solver.Solver], - target_functions: Union[QFunction, Sequence[QFunction]], - target_policy: DeterministicPolicy, - env_info: EnvironmentInfo, - config: HERQTrainerConfig = HERQTrainerConfig()): + def __init__( + self, + train_functions: Union[QFunction, Sequence[QFunction]], + solvers: Dict[str, nn.solver.Solver], + target_functions: Union[QFunction, Sequence[QFunction]], + target_policy: DeterministicPolicy, + env_info: EnvironmentInfo, + config: HERQTrainerConfig = HERQTrainerConfig(), + ): self._target_policy = target_policy self._target_functions = convert_to_list_if_not_list(target_functions) self._prev_target_rnn_states = {} diff --git a/nnabla_rl/model_trainers/q_value/hyar_q_trainer.py b/nnabla_rl/model_trainers/q_value/hyar_q_trainer.py index 89726ae2..ff9da65c 100644 --- a/nnabla_rl/model_trainers/q_value/hyar_q_trainer.py +++ b/nnabla_rl/model_trainers/q_value/hyar_q_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -48,14 +48,16 @@ class HyARQTrainer(TD3QTrainer): _target_policy: DeterministicPolicy _config: HyARQTrainerConfig - def __init__(self, - train_functions: Union[QFunction, Sequence[QFunction]], - solvers: Dict[str, nn.solver.Solver], - target_functions: Union[QFunction, Sequence[QFunction]], - target_policy: DeterministicPolicy, - vae: HyARVAE, - env_info: EnvironmentInfo, - config: HyARQTrainerConfig = HyARQTrainerConfig()): + def __init__( + self, + train_functions: Union[QFunction, Sequence[QFunction]], + solvers: Dict[str, nn.solver.Solver], + target_functions: Union[QFunction, Sequence[QFunction]], + target_policy: DeterministicPolicy, + vae: HyARVAE, + env_info: EnvironmentInfo, + config: HyARQTrainerConfig = HyARQTrainerConfig(), + ): self._vae = vae super().__init__(train_functions, solvers, target_functions, target_policy, env_info, config) @@ -79,33 +81,36 @@ def _compute_target(self, training_variables: TrainingVariables, **kwargs) -> nn def _compute_noisy_action(self, state): a_next_var = self._target_policy.pi(state) - epsilon = NF.clip_by_value(NF.randn(sigma=self._config.train_action_noise_sigma, shape=a_next_var.shape), - min=-self._config.train_action_noise_abs, - max=self._config.train_action_noise_abs) + epsilon = NF.clip_by_value( + NF.randn(sigma=self._config.train_action_noise_sigma, shape=a_next_var.shape), + min=-self._config.train_action_noise_abs, + max=self._config.train_action_noise_abs, + ) a_tilde_var = a_next_var + epsilon a_tilde_var = NF.clip_by_value(a_tilde_var, self._config.noisy_action_min, self._config.noisy_action_max) return a_tilde_var - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, Any], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs): + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, Any], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ): for t, b in zip(training_variables, batch): - set_data_to_variable(t.extra['e'], b.extra['e']) - set_data_to_variable(t.extra['z'], b.extra['z']) - set_data_to_variable(t.extra['c_rate'], b.extra['c_rate']) - set_data_to_variable(t.extra['ds_rate'], b.extra['ds_rate']) + set_data_to_variable(t.extra["e"], b.extra["e"]) + set_data_to_variable(t.extra["z"], b.extra["z"]) + set_data_to_variable(t.extra["c_rate"], b.extra["c_rate"]) + set_data_to_variable(t.extra["ds_rate"], b.extra["ds_rate"]) result = super()._update_model(models, solvers, batch, training_variables, **kwargs) return result - def _compute_loss(self, - model: QFunction, - target_q: nn.Variable, - training_variables: TrainingVariables) -> Tuple[nn.Variable, Dict[str, nn.Variable]]: - e = training_variables.extra['e'] - z = training_variables.extra['z'] + def _compute_loss( + self, model: QFunction, target_q: nn.Variable, training_variables: TrainingVariables + ) -> Tuple[nn.Variable, Dict[str, nn.Variable]]: + e = training_variables.extra["e"] + z = training_variables.extra["z"] e, z = self._reweight_action(e, z, training_variables) latent_action = NF.concatenate(e, z) @@ -116,38 +121,39 @@ def _compute_loss(self, td_error = target_q - q q_loss = 0 - if self._config.loss_type == 'squared': + if self._config.loss_type == "squared": squared_td_error = training_variables.weight * NF.pow_scalar(td_error, 2.0) else: raise RuntimeError - if self._config.reduction_method == 'mean': + if self._config.reduction_method == "mean": q_loss += self._config.q_loss_scalar * NF.mean(squared_td_error) else: raise RuntimeError - extra = {'td_error': 
td_error} + extra = {"td_error": td_error} return q_loss, extra def _reweight_action(self, e, z, training_variables: TrainingVariables): - c_rate = training_variables.extra['c_rate'] - ds_rate = training_variables.extra['ds_rate'] + c_rate = training_variables.extra["c_rate"] + ds_rate = training_variables.extra["ds_rate"] s_current = training_variables.s_current s_next = training_variables.s_next action1, action2 = training_variables.a_current action_space = cast(gym.spaces.Tuple, self._env_info.action_space) - a_continuous, a_discrete = (action1, action2) if isinstance( - action_space[0], gym.spaces.Box) else (action2, action1) + a_continuous, a_discrete = ( + (action1, action2) if isinstance(action_space[0], gym.spaces.Box) else (action2, action1) + ) a_discrete_emb = self._vae.encode_discrete_action(a_discrete) - a_discrete_emb = NF.clip_by_value(a_discrete_emb, - self._config.embed_action_min, - self._config.embed_action_max) - noise = NF.clip_by_value(NF.randn(shape=a_discrete_emb.shape) * self._config.embed_action_noise_sigma, - -self._config.embed_action_noise_abs, - self._config.embed_action_noise_abs) - a_discrete_emb_with_noise = NF.clip_by_value(a_discrete_emb + noise, - self._config.embed_action_min, - self._config.embed_action_max) + a_discrete_emb = NF.clip_by_value(a_discrete_emb, self._config.embed_action_min, self._config.embed_action_max) + noise = NF.clip_by_value( + NF.randn(shape=a_discrete_emb.shape) * self._config.embed_action_noise_sigma, + -self._config.embed_action_noise_abs, + self._config.embed_action_noise_abs, + ) + a_discrete_emb_with_noise = NF.clip_by_value( + a_discrete_emb + noise, self._config.embed_action_min, self._config.embed_action_max + ) a_discrete_new = a_discrete a_discrete_old = self._vae.decode_discrete_action(e) @@ -181,10 +187,10 @@ def _setup_training_variables(self, batch_size: int) -> TrainingVariables: training_variables = super()._setup_training_variables(batch_size) extras = {} - extras['e'] = create_variable(batch_size, (self._config.embed_dim, )) - extras['z'] = create_variable(batch_size, (self._config.latent_dim, )) - extras['ds_rate'] = create_variable(1, (1, )) - extras['c_rate'] = create_variable(1, ((self._config.latent_dim, ), (self._config.latent_dim, ))) + extras["e"] = create_variable(batch_size, (self._config.embed_dim,)) + extras["z"] = create_variable(batch_size, (self._config.latent_dim,)) + extras["ds_rate"] = create_variable(1, (1,)) + extras["c_rate"] = create_variable(1, ((self._config.latent_dim,), (self._config.latent_dim,))) training_variables.extra.update(extras) return training_variables diff --git a/nnabla_rl/model_trainers/q_value/iqn_q_trainer.py b/nnabla_rl/model_trainers/q_value/iqn_q_trainer.py index 829a1d17..297b93da 100644 --- a/nnabla_rl/model_trainers/q_value/iqn_q_trainer.py +++ b/nnabla_rl/model_trainers/q_value/iqn_q_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
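# A NumPy sketch of the clipped-noise target action HyARQTrainer builds above in
# _compute_noisy_action (as in its TD3 parent): Gaussian noise around the target
# policy's action is clipped to a band, and the perturbed action is clipped again to the
# valid range. Parameter names are simplified versions of the config fields and the
# numeric values are made up.
import numpy as np

def compute_noisy_action(a_next, noise_sigma=0.2, noise_abs=0.5, a_min=-1.0, a_max=1.0, rng=None):
    rng = rng or np.random.default_rng()
    epsilon = np.clip(rng.normal(scale=noise_sigma, size=a_next.shape), -noise_abs, noise_abs)
    return np.clip(a_next + epsilon, a_min, a_max)

a_next = np.array([[0.9, -0.2], [0.1, 0.7]])       # actions proposed by the target policy
a_tilde = compute_noisy_action(a_next, rng=np.random.default_rng(1))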
@@ -19,7 +19,9 @@ from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingVariables, rnn_support from nnabla_rl.model_trainers.q_value.state_action_quantile_function_trainer import ( - StateActionQuantileFunctionTrainer, StateActionQuantileFunctionTrainerConfig) + StateActionQuantileFunctionTrainer, + StateActionQuantileFunctionTrainerConfig, +) from nnabla_rl.models import StateActionQuantileFunction from nnabla_rl.utils.misc import create_variables @@ -37,12 +39,14 @@ class IQNQTrainer(StateActionQuantileFunctionTrainer): _prev_a_star_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_z_tau_j_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_functions: Union[StateActionQuantileFunction, Sequence[StateActionQuantileFunction]], - solvers: Dict[str, nn.solver.Solver], - target_function: StateActionQuantileFunction, - env_info: EnvironmentInfo, - config: IQNQTrainerConfig = IQNQTrainerConfig()): + def __init__( + self, + train_functions: Union[StateActionQuantileFunction, Sequence[StateActionQuantileFunction]], + solvers: Dict[str, nn.solver.Solver], + target_function: StateActionQuantileFunction, + env_info: EnvironmentInfo, + config: IQNQTrainerConfig = IQNQTrainerConfig(), + ): self._target_function = target_function self._prev_a_star_rnn_states = {} self._prev_z_tau_j_rnn_states = {} diff --git a/nnabla_rl/model_trainers/q_value/multi_step_trainer.py b/nnabla_rl/model_trainers/q_value/multi_step_trainer.py index e6509691..98c8835f 100644 --- a/nnabla_rl/model_trainers/q_value/multi_step_trainer.py +++ b/nnabla_rl/model_trainers/q_value/multi_step_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -26,11 +26,12 @@ @dataclass class MultiStepTrainerConfig(TrainerConfig): """Configuration class for ModelTrainer.""" + num_steps: int = 1 def __post_init__(self): super(MultiStepTrainerConfig, self).__post_init__() - self._assert_positive(self.num_steps, 'num_steps') + self._assert_positive(self.num_steps, "num_steps") class MultiStepTrainer(ModelTrainer): @@ -39,11 +40,13 @@ class MultiStepTrainer(ModelTrainer): # See https://mypy.readthedocs.io/en/stable/class_basics.html for details _config: MultiStepTrainerConfig - def __init__(self, - models: Union[Model, Sequence[Model]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - config: MultiStepTrainerConfig): + def __init__( + self, + models: Union[Model, Sequence[Model]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + config: MultiStepTrainerConfig, + ): super(MultiStepTrainer, self).__init__(models, solvers, env_info, config) def _setup_batch(self, training_batch: TrainingBatch) -> TrainingBatch: @@ -74,16 +77,18 @@ def _setup_batch(self, training_batch: TrainingBatch) -> TrainingBatch: continue last_batch = training_batch_list[training_batch_length - 1 - (i - self._config.num_steps + 1)] - n_step_batch = TrainingBatch(batch_size=batch.batch_size, - s_current=batch.s_current, - a_current=batch.a_current, - reward=np.sum(n_step_reward, axis=1, keepdims=True), - gamma=np.prod(n_step_gamma, axis=1, keepdims=True), - non_terminal=np.prod(n_step_non_terminal, axis=1, keepdims=True), - s_next=last_batch.s_next, - weight=batch.weight, - extra=batch.extra, - next_step_batch=next_step_batch) + n_step_batch = TrainingBatch( + batch_size=batch.batch_size, + s_current=batch.s_current, + a_current=batch.a_current, + reward=np.sum(n_step_reward, axis=1, keepdims=True), + gamma=np.prod(n_step_gamma, axis=1, keepdims=True), + non_terminal=np.prod(n_step_non_terminal, axis=1, keepdims=True), + s_next=last_batch.s_next, + weight=batch.weight, + extra=batch.extra, + next_step_batch=next_step_batch, + ) next_step_batch = n_step_batch return cast(TrainingBatch, n_step_batch) diff --git a/nnabla_rl/model_trainers/q_value/munchausen_rl_q_trainer.py b/nnabla_rl/model_trainers/q_value/munchausen_rl_q_trainer.py index dfdeaa4d..6ced2f10 100644 --- a/nnabla_rl/model_trainers/q_value/munchausen_rl_q_trainer.py +++ b/nnabla_rl/model_trainers/q_value/munchausen_rl_q_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
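# A NumPy sketch of the n-step transition MultiStepTrainer assembles above: rewards over
# the next num_steps transitions fold into one discounted return, the effective discount
# becomes the product of the per-step gammas and non-terminals, and s_next is taken from
# the last transition in the window. This is a conceptual re-derivation with made-up
# arrays, not the trainer's exact batch bookkeeping.
import numpy as np

rewards = np.array([[1.0, 0.5, 0.25]])             # r_t, r_{t+1}, r_{t+2} for a batch of one
gammas = np.array([[0.99, 0.99, 0.99]])
non_terminals = np.array([[1.0, 1.0, 1.0]])        # a 0.0 entry would cut the return short

# discount applied to step k is the running product of gamma * non_terminal up to k - 1
running = np.cumprod(gammas * non_terminals, axis=1)
discounts = np.concatenate([np.ones((1, 1)), running[:, :-1]], axis=1)
n_step_reward = np.sum(discounts * rewards, axis=1, keepdims=True)
n_step_gamma = np.prod(gammas, axis=1, keepdims=True)
n_step_non_terminal = np.prod(non_terminals, axis=1, keepdims=True)
# the TD target is then n_step_reward + n_step_gamma * n_step_non_terminal * Q_target(s_next, a')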
@@ -20,10 +20,14 @@ import nnabla_rl.functions as RF from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingVariables, rnn_support -from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import (SquaredTDQFunctionTrainer, - SquaredTDQFunctionTrainerConfig) +from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import ( + SquaredTDQFunctionTrainer, + SquaredTDQFunctionTrainerConfig, +) from nnabla_rl.model_trainers.q_value.state_action_quantile_function_trainer import ( - StateActionQuantileFunctionTrainer, StateActionQuantileFunctionTrainerConfig) + StateActionQuantileFunctionTrainer, + StateActionQuantileFunctionTrainerConfig, +) from nnabla_rl.models import QFunction, StateActionQuantileFunction from nnabla_rl.utils.misc import create_variables @@ -33,14 +37,12 @@ def _pi(q_values: nn.Variable, max_q: nn.Variable, tau: float): def _all_tau_log_pi(q_values: nn.Variable, max_q: nn.Variable, tau: float): - logsumexp = tau * NF.log(NF.sum(x=NF.exp((q_values - max_q) / tau), - axis=(q_values.ndim - 1), keepdims=True)) + logsumexp = tau * NF.log(NF.sum(x=NF.exp((q_values - max_q) / tau), axis=(q_values.ndim - 1), keepdims=True)) return q_values - max_q - logsumexp def _tau_log_pi(q_k: nn.Variable, q_values: nn.Variable, max_q: nn.Variable, tau: float): - logsumexp = tau * NF.log(NF.sum(x=NF.exp((q_values - max_q) / tau), - axis=(q_values.ndim - 1), keepdims=True)) + logsumexp = tau * NF.log(NF.sum(x=NF.exp((q_values - max_q) / tau), axis=(q_values.ndim - 1), keepdims=True)) return q_k - max_q - logsumexp @@ -64,12 +66,14 @@ class MunchausenDQNQTrainer(SquaredTDQFunctionTrainer): _prev_all_current_q_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_max_current_q_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_functions: Union[QFunction, Sequence[QFunction]], - solvers: Dict[str, nn.solver.Solver], - target_function: QFunction, - env_info: EnvironmentInfo, - config: MunchausenDQNQTrainerConfig = MunchausenDQNQTrainerConfig()): + def __init__( + self, + train_functions: Union[QFunction, Sequence[QFunction]], + solvers: Dict[str, nn.solver.Solver], + target_function: QFunction, + env_info: EnvironmentInfo, + config: MunchausenDQNQTrainerConfig = MunchausenDQNQTrainerConfig(), + ): self._target_function = target_function self._prev_all_next_q_rnn_states = {} self._prev_max_next_q_rnn_states = {} @@ -103,7 +107,7 @@ def _compute_target(self, training_variables: TrainingVariables) -> nn.Variable: all_tau_log_pi = _all_tau_log_pi(all_next_q, max_next_q, self._config.tau) assert pi.shape == all_next_q.shape assert pi.shape == all_tau_log_pi.shape - soft_q_target = NF.sum(pi * (all_next_q - all_tau_log_pi), axis=(pi.ndim - 1), keepdims=True) + soft_q_target = NF.sum(pi * (all_next_q - all_tau_log_pi), axis=(pi.ndim - 1), keepdims=True) prev_rnn_states = self._prev_current_q_rnn_states with rnn_support(self._target_function, prev_rnn_states, train_rnn_states, training_variables, self._config): @@ -151,12 +155,14 @@ class MunchausenIQNQTrainer(StateActionQuantileFunctionTrainer): _prev_next_quantile_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_current_quantile_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_functions: Union[StateActionQuantileFunction, Sequence[StateActionQuantileFunction]], - solvers: Dict[str, nn.solver.Solver], - target_function: StateActionQuantileFunction, - env_info: EnvironmentInfo, - config: MunchausenIQNQTrainerConfig = 
MunchausenIQNQTrainerConfig()): + def __init__( + self, + train_functions: Union[StateActionQuantileFunction, Sequence[StateActionQuantileFunction]], + solvers: Dict[str, nn.solver.Solver], + target_function: StateActionQuantileFunction, + env_info: EnvironmentInfo, + config: MunchausenIQNQTrainerConfig = MunchausenIQNQTrainerConfig(), + ): self._target_function = target_function self._prev_next_quantile_rnn_states = {} self._prev_current_quantile_rnn_states = {} diff --git a/nnabla_rl/model_trainers/q_value/qrdqn_q_trainer.py b/nnabla_rl/model_trainers/q_value/qrdqn_q_trainer.py index 1a3e28e3..d36deb42 100644 --- a/nnabla_rl/model_trainers/q_value/qrdqn_q_trainer.py +++ b/nnabla_rl/model_trainers/q_value/qrdqn_q_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,7 +19,9 @@ from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingVariables, rnn_support from nnabla_rl.model_trainers.q_value.quantile_distribution_function_trainer import ( - QuantileDistributionFunctionTrainer, QuantileDistributionFunctionTrainerConfig) + QuantileDistributionFunctionTrainer, + QuantileDistributionFunctionTrainerConfig, +) from nnabla_rl.models import QuantileDistributionFunction from nnabla_rl.utils.misc import create_variables @@ -36,12 +38,14 @@ class QRDQNQTrainer(QuantileDistributionFunctionTrainer): _target_function: QuantileDistributionFunction _prev_target_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_functions: Union[QuantileDistributionFunction, Sequence[QuantileDistributionFunction]], - solvers: Dict[str, nn.solver.Solver], - target_function: QuantileDistributionFunction, - env_info: EnvironmentInfo, - config: QRDQNQTrainerConfig = QRDQNQTrainerConfig()): + def __init__( + self, + train_functions: Union[QuantileDistributionFunction, Sequence[QuantileDistributionFunction]], + solvers: Dict[str, nn.solver.Solver], + target_function: QuantileDistributionFunction, + env_info: EnvironmentInfo, + config: QRDQNQTrainerConfig = QRDQNQTrainerConfig(), + ): self._target_function = target_function self._prev_target_rnn_states = {} super(QRDQNQTrainer, self).__init__(train_functions, solvers, env_info, config) diff --git a/nnabla_rl/model_trainers/q_value/qrsac_q_trainer.py b/nnabla_rl/model_trainers/q_value/qrsac_q_trainer.py index 9c9d3fc3..cd439252 100644 --- a/nnabla_rl/model_trainers/q_value/qrsac_q_trainer.py +++ b/nnabla_rl/model_trainers/q_value/qrsac_q_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
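# A NumPy sketch of the softmax-policy quantities MunchausenDQNQTrainer uses above: the
# tau-scaled log-policy is computed in a numerically stable way by subtracting max_q
# before the log-sum-exp, matching _all_tau_log_pi, and the soft next-state target is
# the expectation of q - tau * log pi under that softmax policy. Recovering pi as
# exp(tau_log_pi / tau) is this sketch's assumption about the _pi helper.
import numpy as np

def pi_and_tau_log_pi(q_values, tau):
    max_q = q_values.max(axis=-1, keepdims=True)
    logsumexp = tau * np.log(np.sum(np.exp((q_values - max_q) / tau), axis=-1, keepdims=True))
    tau_log_pi = q_values - max_q - logsumexp      # tau * log softmax((q - max_q) / tau)
    pi = np.exp(tau_log_pi / tau)                  # the softmax policy itself
    return pi, tau_log_pi

q_next = np.array([[1.0, 2.0, 0.5]])
pi, tau_log_pi = pi_and_tau_log_pi(q_next, tau=0.03)
soft_q_target_term = np.sum(pi * (q_next - tau_log_pi), axis=-1, keepdims=True)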
@@ -21,7 +21,9 @@ from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingVariables, rnn_support from nnabla_rl.model_trainers.q_value.quantile_distribution_function_trainer import ( - QuantileDistributionFunctionTrainer, QuantileDistributionFunctionTrainerConfig) + QuantileDistributionFunctionTrainer, + QuantileDistributionFunctionTrainerConfig, +) from nnabla_rl.models import QuantileDistributionFunction from nnabla_rl.models.policy import StochasticPolicy from nnabla_rl.utils.data import convert_to_list_if_not_list @@ -42,14 +44,16 @@ class QRSACQTrainer(QuantileDistributionFunctionTrainer): _prev_quantile_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_q_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_functions: Union[QuantileDistributionFunction, Sequence[QuantileDistributionFunction]], - solvers: Dict[str, nn.solver.Solver], - target_functions: Union[QuantileDistributionFunction, Sequence[QuantileDistributionFunction]], - target_policy: StochasticPolicy, - temperature: nn.Variable, - env_info: EnvironmentInfo, - config: QRSACQTrainerConfig = QRSACQTrainerConfig()): + def __init__( + self, + train_functions: Union[QuantileDistributionFunction, Sequence[QuantileDistributionFunction]], + solvers: Dict[str, nn.solver.Solver], + target_functions: Union[QuantileDistributionFunction, Sequence[QuantileDistributionFunction]], + target_policy: StochasticPolicy, + temperature: nn.Variable, + env_info: EnvironmentInfo, + config: QRSACQTrainerConfig = QRSACQTrainerConfig(), + ): self._target_functions = convert_to_list_if_not_list(target_functions) self._assert_no_duplicate_model(self._target_functions) self._target_policy = target_policy diff --git a/nnabla_rl/model_trainers/q_value/quantile_distribution_function_trainer.py b/nnabla_rl/model_trainers/q_value/quantile_distribution_function_trainer.py index d22d3337..14e2bae3 100644 --- a/nnabla_rl/model_trainers/q_value/quantile_distribution_function_trainer.py +++ b/nnabla_rl/model_trainers/q_value/quantile_distribution_function_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -44,11 +44,13 @@ class QuantileDistributionFunctionTrainer(MultiStepTrainer): _quantile_huber_loss: nn.Variable _prev_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - models: Union[QuantileDistributionFunction, Sequence[QuantileDistributionFunction]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - config: QuantileDistributionFunctionTrainerConfig = QuantileDistributionFunctionTrainerConfig()): + def __init__( + self, + models: Union[QuantileDistributionFunction, Sequence[QuantileDistributionFunction]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + config: QuantileDistributionFunctionTrainerConfig = QuantileDistributionFunctionTrainerConfig(), + ): if config.kappa == 0.0: logger.info("kappa is set to 0.0. 
Quantile regression loss will be used for training") else: @@ -59,12 +61,14 @@ def __init__(self, self._prev_rnn_states = {} super(QuantileDistributionFunctionTrainer, self).__init__(models, solvers, env_info, config) - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) set_data_to_variable(t.a_current, b.a_current) @@ -94,12 +98,10 @@ def _update_model(self, solver.update() trainer_state = {} - trainer_state['q_loss'] = self._quantile_huber_loss.d.copy() + trainer_state["q_loss"] = self._quantile_huber_loss.d.copy() return trainer_state - def _build_training_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables): + def _build_training_graph(self, models: Sequence[Model], training_variables: TrainingVariables): self._quantile_huber_loss = 0 ignore_intermediate_loss = self._config.loss_integration is LossIntegration.LAST_TIMESTEP_ONLY for step_index, variables in enumerate(training_variables): @@ -108,10 +110,7 @@ def _build_training_graph(self, ignore_loss = is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - def _build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool): models = cast(Sequence[QuantileDistributionFunction], models) # Ttheta_j is the target quantile distribution @@ -129,10 +128,9 @@ def _build_one_step_graph(self, def _compute_target(self, training_variables: TrainingVariables) -> nn.Variable: raise NotImplementedError - def _compute_loss(self, - model: QuantileDistributionFunction, - target: nn.Variable, - training_variables: TrainingVariables) -> Tuple[nn.Variable, Dict[str, nn.Variable]]: + def _compute_loss( + self, model: QuantileDistributionFunction, target: nn.Variable, training_variables: TrainingVariables + ) -> Tuple[nn.Variable, Dict[str, nn.Variable]]: batch_size = training_variables.batch_size Ttheta_i = model.quantiles(training_variables.s_current, training_variables.a_current) Ttheta_i = RF.expand_dims(Ttheta_i, axis=2) @@ -145,9 +143,7 @@ def _compute_loss(self, # NOTE: target is same as Ttheta_j in the paper quantile_huber_loss = RF.quantile_huber_loss(target, Ttheta_i, self._config.kappa, tau_hat) - assert quantile_huber_loss.shape == (batch_size, - self._config.num_quantiles, - self._config.num_quantiles) + assert quantile_huber_loss.shape == (batch_size, self._config.num_quantiles, self._config.num_quantiles) quantile_huber_loss = NF.mean(quantile_huber_loss, axis=2) quantile_huber_loss = NF.sum(quantile_huber_loss, axis=1) @@ -168,21 +164,25 @@ def _setup_training_variables(self, batch_size) -> TrainingVariables: rnn_state_variables = create_variables(batch_size, model.internal_state_shapes()) rnn_states[model.scope_name] = rnn_state_variables - training_variables = TrainingVariables(batch_size=batch_size, - s_current=s_current_var, - a_current=a_current_var, - reward=reward_var, - gamma=gamma_var, - non_terminal=non_terminal_var, - 
s_next=s_next_var, - rnn_states=rnn_states) + training_variables = TrainingVariables( + batch_size=batch_size, + s_current=s_current_var, + a_current=a_current_var, + reward=reward_var, + gamma=gamma_var, + non_terminal=non_terminal_var, + s_next=s_next_var, + rnn_states=rnn_states, + ) return training_variables @staticmethod def _precompute_tau_hat(num_quantiles): - tau_hat = [(tau_prev + tau_i) / num_quantiles / 2.0 - for tau_prev, tau_i in zip(range(0, num_quantiles), range(1, num_quantiles+1))] + tau_hat = [ + (tau_prev + tau_i) / num_quantiles / 2.0 + for tau_prev, tau_i in zip(range(0, num_quantiles), range(1, num_quantiles + 1)) + ] return np.array(tau_hat, dtype=np.float32) @property diff --git a/nnabla_rl/model_trainers/q_value/redq_q_trainer.py b/nnabla_rl/model_trainers/q_value/redq_q_trainer.py index f7f197de..98697697 100644 --- a/nnabla_rl/model_trainers/q_value/redq_q_trainer.py +++ b/nnabla_rl/model_trainers/q_value/redq_q_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,8 +19,10 @@ import nnabla.functions as NF from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingVariables, rnn_support -from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import (SquaredTDQFunctionTrainer, - SquaredTDQFunctionTrainerConfig) +from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import ( + SquaredTDQFunctionTrainer, + SquaredTDQFunctionTrainerConfig, +) from nnabla_rl.models import QFunction, StochasticPolicy from nnabla_rl.utils.data import convert_to_list_if_not_list from nnabla_rl.utils.misc import create_variables @@ -42,14 +44,16 @@ class REDQQTrainer(SquaredTDQFunctionTrainer): _prev_target_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_q_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_functions: Union[QFunction, Sequence[QFunction]], - solvers: Dict[str, nn.solver.Solver], - target_functions: Union[QFunction, Sequence[QFunction]], - target_policy: StochasticPolicy, - temperature: nn.Variable, - env_info: EnvironmentInfo, - config: REDQQTrainerConfig = REDQQTrainerConfig()): + def __init__( + self, + train_functions: Union[QFunction, Sequence[QFunction]], + solvers: Dict[str, nn.solver.Solver], + target_functions: Union[QFunction, Sequence[QFunction]], + target_policy: StochasticPolicy, + temperature: nn.Variable, + env_info: EnvironmentInfo, + config: REDQQTrainerConfig = REDQQTrainerConfig(), + ): self._target_functions = convert_to_list_if_not_list(target_functions) self._N = len(self._target_functions) self._assert_no_duplicate_model(self._target_functions) diff --git a/nnabla_rl/model_trainers/q_value/soft_q_decomposition_trainer.py b/nnabla_rl/model_trainers/q_value/soft_q_decomposition_trainer.py index 5c2d72eb..e8d936d4 100644 --- a/nnabla_rl/model_trainers/q_value/soft_q_decomposition_trainer.py +++ b/nnabla_rl/model_trainers/q_value/soft_q_decomposition_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
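The `_precompute_tau_hat` helper reformatted just above computes the quantile midpoints tau_hat_i = (2i - 1) / (2N) used by the QR-DQN family of trainers. A standalone numpy sketch of the same computation, with a quick check for N = 4:

```python
import numpy as np

def precompute_tau_hat(num_quantiles: int) -> np.ndarray:
    # Midpoints of N equally spaced quantile fractions:
    # tau_hat_i = ((i - 1)/N + i/N) / 2 = (2i - 1) / (2N) for i = 1..N.
    tau_hat = [
        (tau_prev + tau_i) / num_quantiles / 2.0
        for tau_prev, tau_i in zip(range(0, num_quantiles), range(1, num_quantiles + 1))
    ]
    return np.array(tau_hat, dtype=np.float32)

# For four quantiles the midpoints are 1/8, 3/8, 5/8 and 7/8.
assert np.allclose(precompute_tau_hat(4), [0.125, 0.375, 0.625, 0.875])
```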
@@ -34,14 +34,16 @@ class SoftQDTrainerConfig(SoftQTrainerConfig): class SoftQDTrainer(SoftQTrainer): _target_functions: Sequence[FactoredContinuousQFunction] - def __init__(self, - train_functions: Union[FactoredContinuousQFunction, Sequence[FactoredContinuousQFunction]], - solvers: Dict[str, nn.solver.Solver], - target_functions: Union[FactoredContinuousQFunction, Sequence[FactoredContinuousQFunction]], - target_policy: StochasticPolicy, - temperature: nn.Variable, - env_info: EnvironmentInfo, - config: SoftQDTrainerConfig = SoftQDTrainerConfig()): + def __init__( + self, + train_functions: Union[FactoredContinuousQFunction, Sequence[FactoredContinuousQFunction]], + solvers: Dict[str, nn.solver.Solver], + target_functions: Union[FactoredContinuousQFunction, Sequence[FactoredContinuousQFunction]], + target_policy: StochasticPolicy, + temperature: nn.Variable, + env_info: EnvironmentInfo, + config: SoftQDTrainerConfig = SoftQDTrainerConfig(), + ): super().__init__( train_functions=train_functions, solvers=solvers, @@ -52,10 +54,9 @@ def __init__(self, config=config, ) - def _compute_loss(self, - model: QFunction, - target_q: nn.Variable, - training_variables: TrainingVariables) -> Tuple[nn.Variable, Dict[str, nn.Variable]]: + def _compute_loss( + self, model: QFunction, target_q: nn.Variable, training_variables: TrainingVariables + ) -> Tuple[nn.Variable, Dict[str, nn.Variable]]: assert isinstance(model, FactoredContinuousQFunction) s_current = training_variables.s_current @@ -72,14 +73,14 @@ def _compute_loss(self, maximum = nn.Variable.from_numpy_array(np.full(td_error.shape, clip_max)) td_error = NF.clip_grad_by_value(td_error, minimum, maximum) squared_td_error = training_variables.weight * NF.pow_scalar(td_error, 2.0) - if self._config.reduction_method == 'mean': + if self._config.reduction_method == "mean": q_loss += self._config.q_loss_scalar * NF.mean(squared_td_error) - elif self._config.reduction_method == 'sum': + elif self._config.reduction_method == "sum": q_loss += self._config.q_loss_scalar * NF.sum(squared_td_error) else: raise RuntimeError - extra = {'td_error': td_error} + extra = {"td_error": td_error} return q_loss, extra def _compute_target(self, training_variables: TrainingVariables, **kwargs) -> nn.Variable: diff --git a/nnabla_rl/model_trainers/q_value/soft_q_trainer.py b/nnabla_rl/model_trainers/q_value/soft_q_trainer.py index 7bbf705f..9d161d89 100644 --- a/nnabla_rl/model_trainers/q_value/soft_q_trainer.py +++ b/nnabla_rl/model_trainers/q_value/soft_q_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -19,8 +19,10 @@ import nnabla_rl.functions as RF from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingVariables, rnn_support -from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import (SquaredTDQFunctionTrainer, - SquaredTDQFunctionTrainerConfig) +from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import ( + SquaredTDQFunctionTrainer, + SquaredTDQFunctionTrainerConfig, +) from nnabla_rl.models import QFunction, StochasticPolicy from nnabla_rl.utils.data import convert_to_list_if_not_list from nnabla_rl.utils.misc import create_variables @@ -42,14 +44,16 @@ class SoftQTrainer(SquaredTDQFunctionTrainer): _prev_target_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_q_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_functions: Union[QFunction, Sequence[QFunction]], - solvers: Dict[str, nn.solver.Solver], - target_functions: Union[QFunction, Sequence[QFunction]], - target_policy: StochasticPolicy, - temperature: nn.Variable, - env_info: EnvironmentInfo, - config: SoftQTrainerConfig = SoftQTrainerConfig()): + def __init__( + self, + train_functions: Union[QFunction, Sequence[QFunction]], + solvers: Dict[str, nn.solver.Solver], + target_functions: Union[QFunction, Sequence[QFunction]], + target_policy: StochasticPolicy, + temperature: nn.Variable, + env_info: EnvironmentInfo, + config: SoftQTrainerConfig = SoftQTrainerConfig(), + ): self._target_functions = convert_to_list_if_not_list(target_functions) self._assert_no_duplicate_model(self._target_functions) self._target_policy = target_policy diff --git a/nnabla_rl/model_trainers/q_value/squared_td_q_function_trainer.py b/nnabla_rl/model_trainers/q_value/squared_td_q_function_trainer.py index a1b4f9dd..acd628f1 100644 --- a/nnabla_rl/model_trainers/q_value/squared_td_q_function_trainer.py +++ b/nnabla_rl/model_trainers/q_value/squared_td_q_function_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -29,19 +29,19 @@ @dataclass class SquaredTDQFunctionTrainerConfig(MultiStepTrainerConfig): - loss_type: str = 'squared' - reduction_method: str = 'mean' + loss_type: str = "squared" + reduction_method: str = "mean" grad_clip: Optional[tuple] = None q_loss_scalar: float = 1.0 reward_dimension: int = 1 huber_delta: Optional[float] = None def __post_init__(self): - self._assert_one_of(self.loss_type, ['squared', 'huber'], 'loss_type') - self._assert_one_of(self.reduction_method, ['sum', 'mean'], 'reduction_method') + self._assert_one_of(self.loss_type, ["squared", "huber"], "loss_type") + self._assert_one_of(self.reduction_method, ["sum", "mean"], "reduction_method") if self.grad_clip is not None: - self._assert_ascending_order(self.grad_clip, 'grad_clip') - self._assert_length(self.grad_clip, 2, 'grad_clip') + self._assert_ascending_order(self.grad_clip, "grad_clip") + self._assert_length(self.grad_clip, 2, "grad_clip") class SquaredTDQFunctionTrainer(MultiStepTrainer): @@ -51,20 +51,24 @@ class SquaredTDQFunctionTrainer(MultiStepTrainer): _q_loss: nn.Variable _prev_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - models: Union[QFunction, Sequence[QFunction]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - config: SquaredTDQFunctionTrainerConfig = SquaredTDQFunctionTrainerConfig()): + def __init__( + self, + models: Union[QFunction, Sequence[QFunction]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + config: SquaredTDQFunctionTrainerConfig = SquaredTDQFunctionTrainerConfig(), + ): self._prev_rnn_states = {} super(SquaredTDQFunctionTrainer, self).__init__(models, solvers, env_info, config) - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) set_data_to_variable(t.a_current, b.a_current) @@ -96,13 +100,11 @@ def _update_model(self, q_solver.update() trainer_state = {} - trainer_state['q_loss'] = self._q_loss.d.copy() - trainer_state['td_errors'] = self._td_error.d.copy() + trainer_state["q_loss"] = self._q_loss.d.copy() + trainer_state["td_errors"] = self._td_error.d.copy() return trainer_state - def _build_training_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables): + def _build_training_graph(self, models: Sequence[Model], training_variables: TrainingVariables): self._q_loss = 0 ignore_intermediate_loss = self._config.loss_integration is LossIntegration.LAST_TIMESTEP_ONLY for step_index, variables in enumerate(training_variables): @@ -111,10 +113,7 @@ def _build_training_graph(self, ignore_loss = is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - def _build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool): models = cast(Sequence[QFunction], models) # NOTE: Target q value depends on underlying implementation @@ -129,16 +128,15 @@ def _build_one_step_graph(self, self._q_loss += 0.0 if ignore_loss 
else q_loss # FIXME: using the last q function's td error for prioritized replay. Is this fine? - self._td_error = extra['td_error'] + self._td_error = extra["td_error"] self._td_error.persistent = True def _compute_target(self, training_variables: TrainingVariables) -> nn.Variable: raise NotImplementedError - def _compute_loss(self, - model: QFunction, - target_q: nn.Variable, - training_variables: TrainingVariables) -> Tuple[nn.Variable, Dict[str, nn.Variable]]: + def _compute_loss( + self, model: QFunction, target_q: nn.Variable, training_variables: TrainingVariables + ) -> Tuple[nn.Variable, Dict[str, nn.Variable]]: s_current = training_variables.s_current a_current = training_variables.a_current @@ -152,21 +150,21 @@ def _compute_loss(self, minimum = nn.Variable.from_numpy_array(np.full(td_error.shape, clip_min)) maximum = nn.Variable.from_numpy_array(np.full(td_error.shape, clip_max)) td_error = NF.clip_grad_by_value(td_error, minimum, maximum) - if self._config.loss_type == 'squared': + if self._config.loss_type == "squared": squared_td_error = training_variables.weight * NF.pow_scalar(td_error, 2.0) - elif self._config.loss_type == 'huber': + elif self._config.loss_type == "huber": zero = nn.Variable.from_numpy_array(np.zeros(shape=td_error.shape)) squared_td_error = training_variables.weight * NF.huber_loss(td_error, zero, delta=self._config.huber_delta) else: raise RuntimeError - if self._config.reduction_method == 'mean': + if self._config.reduction_method == "mean": q_loss += self._config.q_loss_scalar * NF.mean(squared_td_error) - elif self._config.reduction_method == 'sum': + elif self._config.reduction_method == "sum": q_loss += self._config.q_loss_scalar * NF.sum(squared_td_error) else: raise RuntimeError - extra = {'td_error': td_error} + extra = {"td_error": td_error} return q_loss, extra def _setup_training_variables(self, batch_size) -> TrainingVariables: @@ -185,15 +183,17 @@ def _setup_training_variables(self, batch_size) -> TrainingVariables: rnn_state_variables = create_variables(batch_size, model.internal_state_shapes()) rnn_states[model.scope_name] = rnn_state_variables - training_variables = TrainingVariables(batch_size=batch_size, - s_current=s_current_var, - a_current=a_current_var, - reward=reward_var, - gamma=gamma_var, - non_terminal=non_terminal_var, - s_next=s_next_var, - weight=weight_var, - rnn_states=rnn_states) + training_variables = TrainingVariables( + batch_size=batch_size, + s_current=s_current_var, + a_current=a_current_var, + reward=reward_var, + gamma=gamma_var, + non_terminal=non_terminal_var, + s_next=s_next_var, + weight=weight_var, + rnn_states=rnn_states, + ) return training_variables @property diff --git a/nnabla_rl/model_trainers/q_value/state_action_quantile_function_trainer.py b/nnabla_rl/model_trainers/q_value/state_action_quantile_function_trainer.py index 11873ea9..a6bafbd3 100644 --- a/nnabla_rl/model_trainers/q_value/state_action_quantile_function_trainer.py +++ b/nnabla_rl/model_trainers/q_value/state_action_quantile_function_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
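SquaredTDQFunctionTrainer._compute_loss, reformatted above, switches between a squared and a Huber penalty on the TD error and then reduces the weighted result with mean or sum. The numpy sketch below mirrors that selection logic; the Huber branch is written in a generic form and may differ from nnabla's NF.huber_loss by a constant factor, so treat it as illustrative only.

```python
import numpy as np

def td_q_loss(td_error, weight, loss_type="squared", reduction_method="mean",
              huber_delta=1.0, q_loss_scalar=1.0):
    # Importance-weighted per-sample penalty on the TD error.
    if loss_type == "squared":
        per_sample = weight * td_error ** 2
    elif loss_type == "huber":
        # Quadratic inside [-delta, delta], linear outside; the scaling here is
        # an assumption, nnabla's NF.huber_loss may use a different constant.
        abs_err = np.abs(td_error)
        quad = np.minimum(abs_err, huber_delta)
        per_sample = weight * (quad ** 2 + 2.0 * huber_delta * (abs_err - quad))
    else:
        raise RuntimeError(f"unknown loss_type: {loss_type}")
    # Scalar reduction selected by the config, scaled by q_loss_scalar.
    if reduction_method == "mean":
        return q_loss_scalar * per_sample.mean()
    if reduction_method == "sum":
        return q_loss_scalar * per_sample.sum()
    raise RuntimeError(f"unknown reduction_method: {reduction_method}")

# Example: mean-reduced squared TD loss over a batch of three transitions.
print(td_q_loss(np.array([0.5, -1.0, 2.0]), np.ones(3)))  # (0.25 + 1.0 + 4.0) / 3 = 1.75
```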
@@ -44,20 +44,24 @@ class StateActionQuantileFunctionTrainer(MultiStepTrainer): _quantile_huber_loss: nn.Variable _prev_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - models: Union[StateActionQuantileFunction, Sequence[StateActionQuantileFunction]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - config: StateActionQuantileFunctionTrainerConfig = StateActionQuantileFunctionTrainerConfig()): + def __init__( + self, + models: Union[StateActionQuantileFunction, Sequence[StateActionQuantileFunction]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + config: StateActionQuantileFunctionTrainerConfig = StateActionQuantileFunctionTrainerConfig(), + ): self._prev_rnn_states = {} super(StateActionQuantileFunctionTrainer, self).__init__(models, solvers, env_info, config) - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) set_data_to_variable(t.a_current, b.a_current) @@ -87,12 +91,10 @@ def _update_model(self, solver.update() trainer_state = {} - trainer_state['q_loss'] = self._quantile_huber_loss.d.copy() + trainer_state["q_loss"] = self._quantile_huber_loss.d.copy() return trainer_state - def _build_training_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables): + def _build_training_graph(self, models: Sequence[Model], training_variables: TrainingVariables): self._quantile_huber_loss = 0 ignore_intermediate_loss = self._config.loss_integration is LossIntegration.LAST_TIMESTEP_ONLY for step_index, variables in enumerate(training_variables): @@ -101,10 +103,7 @@ def _build_training_graph(self, ignore_loss = is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - def _build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool): models = cast(Sequence[StateActionQuantileFunction], models) batch_size = training_variables.batch_size @@ -124,16 +123,13 @@ def _build_one_step_graph(self, def _compute_target(self, training_variables: TrainingVariables): raise NotImplementedError - def _compute_loss(self, - model: StateActionQuantileFunction, - target: nn.Variable, - training_variables: TrainingVariables) -> nn.Variable: + def _compute_loss( + self, model: StateActionQuantileFunction, target: nn.Variable, training_variables: TrainingVariables + ) -> nn.Variable: batch_size = training_variables.batch_size tau_i = model.sample_tau(shape=(batch_size, self._config.N)) - Z_tau_i = model.quantile_values(training_variables.s_current, - training_variables.a_current, - tau_i) + Z_tau_i = model.quantile_values(training_variables.s_current, training_variables.a_current, tau_i) Z_tau_i = RF.expand_dims(Z_tau_i, axis=2) tau_i = RF.expand_dims(tau_i, axis=2) assert Z_tau_i.shape == (batch_size, self._config.N, 1) @@ -160,14 +156,16 @@ def _setup_training_variables(self, batch_size) -> TrainingVariables: rnn_state_variables = 
create_variables(batch_size, model.internal_state_shapes()) rnn_states[model.scope_name] = rnn_state_variables - training_variables = TrainingVariables(batch_size=batch_size, - s_current=s_current_var, - a_current=a_current_var, - reward=reward_var, - gamma=gamma_var, - non_terminal=non_terminal_var, - s_next=s_next_var, - rnn_states=rnn_states) + training_variables = TrainingVariables( + batch_size=batch_size, + s_current=s_current_var, + a_current=a_current_var, + reward=reward_var, + gamma=gamma_var, + non_terminal=non_terminal_var, + s_next=s_next_var, + rnn_states=rnn_states, + ) return training_variables diff --git a/nnabla_rl/model_trainers/q_value/td3_q_trainer.py b/nnabla_rl/model_trainers/q_value/td3_q_trainer.py index 137132e2..6d3fe337 100644 --- a/nnabla_rl/model_trainers/q_value/td3_q_trainer.py +++ b/nnabla_rl/model_trainers/q_value/td3_q_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,8 +20,10 @@ import nnabla_rl.functions as RF from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingVariables, rnn_support -from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import (SquaredTDQFunctionTrainer, - SquaredTDQFunctionTrainerConfig) +from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import ( + SquaredTDQFunctionTrainer, + SquaredTDQFunctionTrainerConfig, +) from nnabla_rl.models import DeterministicPolicy, QFunction from nnabla_rl.utils.data import convert_to_list_if_not_list from nnabla_rl.utils.misc import create_variables @@ -43,13 +45,15 @@ class TD3QTrainer(SquaredTDQFunctionTrainer): _prev_target_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_q_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_functions: Union[QFunction, Sequence[QFunction]], - solvers: Dict[str, nn.solver.Solver], - target_functions: Union[QFunction, Sequence[QFunction]], - target_policy: DeterministicPolicy, - env_info: EnvironmentInfo, - config: TD3QTrainerConfig = TD3QTrainerConfig()): + def __init__( + self, + train_functions: Union[QFunction, Sequence[QFunction]], + solvers: Dict[str, nn.solver.Solver], + target_functions: Union[QFunction, Sequence[QFunction]], + target_policy: DeterministicPolicy, + env_info: EnvironmentInfo, + config: TD3QTrainerConfig = TD3QTrainerConfig(), + ): self._target_policy = target_policy self._target_functions = convert_to_list_if_not_list(target_functions) self._assert_no_duplicate_model(self._target_functions) @@ -84,10 +88,11 @@ def _compute_target(self, training_variables: TrainingVariables, **kwargs) -> nn def _compute_noisy_action(self, state): a_next_var = self._target_policy.pi(state) - epsilon = NF.clip_by_value(NF.randn(sigma=self._config.train_action_noise_sigma, - shape=a_next_var.shape), - min=-self._config.train_action_noise_abs, - max=self._config.train_action_noise_abs) + epsilon = NF.clip_by_value( + NF.randn(sigma=self._config.train_action_noise_sigma, shape=a_next_var.shape), + min=-self._config.train_action_noise_abs, + max=self._config.train_action_noise_abs, + ) a_tilde_var = a_next_var + epsilon return a_tilde_var diff --git a/nnabla_rl/model_trainers/q_value/v_targeted_q_trainer.py b/nnabla_rl/model_trainers/q_value/v_targeted_q_trainer.py index a25016ac..e5ac50ab 100644 --- 
a/nnabla_rl/model_trainers/q_value/v_targeted_q_trainer.py +++ b/nnabla_rl/model_trainers/q_value/v_targeted_q_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,8 +19,10 @@ import nnabla_rl.functions as RNF from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingVariables, rnn_support -from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import (SquaredTDQFunctionTrainer, - SquaredTDQFunctionTrainerConfig) +from nnabla_rl.model_trainers.q_value.squared_td_q_function_trainer import ( + SquaredTDQFunctionTrainer, + SquaredTDQFunctionTrainerConfig, +) from nnabla_rl.models import QFunction, VFunction from nnabla_rl.utils.data import convert_to_list_if_not_list from nnabla_rl.utils.misc import create_variables @@ -35,6 +37,7 @@ class VTargetedQTrainerConfig(SquaredTDQFunctionTrainerConfig): :math:`target=\\gamma\\times V(s_{t+1})`.\ Used in Disentangled MME.\ """ + pure_exploration: bool = False @@ -46,12 +49,14 @@ class VTargetedQTrainer(SquaredTDQFunctionTrainer): _target_v_rnn_states: Dict[str, Dict[str, nn.Variable]] _config: VTargetedQTrainerConfig - def __init__(self, - train_functions: Union[QFunction, Sequence[QFunction]], - solvers: Dict[str, nn.solver.Solver], - target_functions: Union[VFunction, Sequence[VFunction]], - env_info: EnvironmentInfo, - config: VTargetedQTrainerConfig = VTargetedQTrainerConfig()): + def __init__( + self, + train_functions: Union[QFunction, Sequence[QFunction]], + solvers: Dict[str, nn.solver.Solver], + target_functions: Union[VFunction, Sequence[VFunction]], + env_info: EnvironmentInfo, + config: VTargetedQTrainerConfig = VTargetedQTrainerConfig(), + ): self._target_functions = convert_to_list_if_not_list(target_functions) self._assert_no_duplicate_model(self._target_functions) self._target_v_rnn_states = {} diff --git a/nnabla_rl/model_trainers/q_value/value_distribution_function_trainer.py b/nnabla_rl/model_trainers/q_value/value_distribution_function_trainer.py index 69e5091b..7212e247 100644 --- a/nnabla_rl/model_trainers/q_value/value_distribution_function_trainer.py +++ b/nnabla_rl/model_trainers/q_value/value_distribution_function_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
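The TD3QTrainer hunk above computes a noisy target action for target policy smoothing. A numpy sketch of that step; sigma=0.2 and noise_abs=0.5 are the TD3 paper defaults used here only as placeholders, since the trainer takes both values from its config.

```python
import numpy as np

def noisy_target_action(a_next, sigma=0.2, noise_abs=0.5, rng=None):
    # Target policy smoothing: perturb the target policy's action with Gaussian
    # noise clipped to [-noise_abs, noise_abs] before evaluating the target Q.
    rng = np.random.default_rng() if rng is None else rng
    epsilon = np.clip(rng.normal(0.0, sigma, size=a_next.shape), -noise_abs, noise_abs)
    return a_next + epsilon

# Deterministic generator for a reproducible example.
print(noisy_target_action(np.zeros((4, 2)), rng=np.random.default_rng(0)))
```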
@@ -29,13 +29,13 @@ @dataclass class ValueDistributionFunctionTrainerConfig(MultiStepTrainerConfig): - reduction_method: str = 'mean' + reduction_method: str = "mean" v_min: float = -10.0 v_max: float = 10.0 num_atoms: int = 51 def __post_init__(self): - self._assert_one_of(self.reduction_method, ['sum', 'mean'], 'reduction_method') + self._assert_one_of(self.reduction_method, ["sum", "mean"], "reduction_method") return super().__post_init__() @@ -50,20 +50,24 @@ class ValueDistributionFunctionTrainer(MultiStepTrainer): _cross_entropy_loss: nn.Variable _prev_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - models: Union[ValueDistributionFunction, Sequence[ValueDistributionFunction]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - config: ValueDistributionFunctionTrainerConfig): + def __init__( + self, + models: Union[ValueDistributionFunction, Sequence[ValueDistributionFunction]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + config: ValueDistributionFunctionTrainerConfig, + ): self._prev_rnn_states = {} super(ValueDistributionFunctionTrainer, self).__init__(models, solvers, env_info, config) - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) set_data_to_variable(t.a_current, b.a_current) @@ -96,13 +100,11 @@ def _update_model(self, # Kullbuck Leibler divergence is not actually the td_error itself # but is used for prioritizing the replay buffer and we save it as 'td_error' for convenience # See: https://arxiv.org/pdf/1710.02298.pdf - trainer_state['td_errors'] = self._kl_loss.d.copy() - trainer_state['cross_entropy_loss'] = float(self._cross_entropy_loss.d.copy()) + trainer_state["td_errors"] = self._kl_loss.d.copy() + trainer_state["cross_entropy_loss"] = float(self._cross_entropy_loss.d.copy()) return trainer_state - def _build_training_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables): + def _build_training_graph(self, models: Sequence[Model], training_variables: TrainingVariables): self._cross_entropy_loss = 0 ignore_intermediate_loss = self._config.loss_integration is LossIntegration.LAST_TIMESTEP_ONLY for step_index, variables in enumerate(training_variables): @@ -111,10 +113,7 @@ def _build_training_graph(self, ignore_loss = is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - def _build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool): models = cast(Sequence[ValueDistributionFunction], models) # Computing the target probabilities @@ -131,17 +130,16 @@ def _build_one_step_graph(self, # for prioritized experience replay # See: https://arxiv.org/pdf/1710.02298.pdf # keep kl_loss only for the last model for prioritized replay - kl_loss = extra_info['kl_loss'] + kl_loss = extra_info["kl_loss"] self._kl_loss = kl_loss self._kl_loss.persistent = True def _compute_target(self, 
training_variables: TrainingVariables) -> nn.Variable: raise NotImplementedError - def _compute_loss(self, - model: ValueDistributionFunction, - target: nn.Variable, - training_variables: TrainingVariables) -> Tuple[nn.Variable, Dict[str, nn.Variable]]: + def _compute_loss( + self, model: ValueDistributionFunction, target: nn.Variable, training_variables: TrainingVariables + ) -> Tuple[nn.Variable, Dict[str, nn.Variable]]: batch_size = training_variables.batch_size atom_probabilities = model.probs(training_variables.s_current, training_variables.a_current) atom_probabilities = NF.clip_by_value(atom_probabilities, 1e-10, 1.0) @@ -149,13 +147,13 @@ def _compute_loss(self, assert cross_entropy.shape == (batch_size, self._config.num_atoms) kl_loss = -NF.sum(cross_entropy, axis=1, keepdims=True) - if self._config.reduction_method == 'mean': + if self._config.reduction_method == "mean": loss = NF.mean(kl_loss * training_variables.weight) - elif self._config.reduction_method == 'sum': + elif self._config.reduction_method == "sum": loss = NF.sum(kl_loss * training_variables.weight) else: raise RuntimeError - extra = {'kl_loss': kl_loss} + extra = {"kl_loss": kl_loss} return loss, extra def _setup_training_variables(self, batch_size) -> TrainingVariables: @@ -174,15 +172,17 @@ def _setup_training_variables(self, batch_size) -> TrainingVariables: rnn_state_variables = create_variables(batch_size, model.internal_state_shapes()) rnn_states[model.scope_name] = rnn_state_variables - training_variables = TrainingVariables(batch_size=batch_size, - s_current=s_current_var, - a_current=a_current_var, - reward=reward_var, - gamma=gamma_var, - non_terminal=non_terminal_var, - s_next=s_next_var, - weight=weight_var, - rnn_states=rnn_states) + training_variables = TrainingVariables( + batch_size=batch_size, + s_current=s_current_var, + a_current=a_current_var, + reward=reward_var, + gamma=gamma_var, + non_terminal=non_terminal_var, + s_next=s_next_var, + weight=weight_var, + rnn_states=rnn_states, + ) return training_variables @property diff --git a/nnabla_rl/model_trainers/reward/__init__.py b/nnabla_rl/model_trainers/reward/__init__.py index 5000c078..d7fd42c7 100644 --- a/nnabla_rl/model_trainers/reward/__init__.py +++ b/nnabla_rl/model_trainers/reward/__init__.py @@ -13,6 +13,10 @@ # limitations under the License. 
from nnabla_rl.model_trainers.reward.gail_reward_function_trainer import ( # noqa - GAILRewardFunctionTrainer, GAILRewardFunctionTrainerConfig) + GAILRewardFunctionTrainer, + GAILRewardFunctionTrainerConfig, +) from nnabla_rl.model_trainers.reward.amp_reward_function_trainer import ( # noqa - AMPRewardFunctionTrainer, AMPRewardFunctionTrainerConfig) + AMPRewardFunctionTrainer, + AMPRewardFunctionTrainerConfig, +) diff --git a/nnabla_rl/model_trainers/reward/amp_reward_function_trainer.py b/nnabla_rl/model_trainers/reward/amp_reward_function_trainer.py index 47f3461e..f23f9000 100644 --- a/nnabla_rl/model_trainers/reward/amp_reward_function_trainer.py +++ b/nnabla_rl/model_trainers/reward/amp_reward_function_trainer.py @@ -20,8 +20,13 @@ import nnabla as nn import nnabla.functions as NF from nnabla_rl.environments.environment_info import EnvironmentInfo -from nnabla_rl.model_trainers.model_trainer import (LossIntegration, ModelTrainer, TrainerConfig, TrainingBatch, - TrainingVariables) +from nnabla_rl.model_trainers.model_trainer import ( + LossIntegration, + ModelTrainer, + TrainerConfig, + TrainingBatch, + TrainingVariables, +) from nnabla_rl.models import Model, RewardFunction from nnabla_rl.preprocessors.preprocessor import Preprocessor from nnabla_rl.utils.data import convert_to_list_if_not_list, set_data_to_variable @@ -50,21 +55,25 @@ class AMPRewardFunctionTrainer(ModelTrainer): _grad_penalty_loss: nn.Variable _regularization_loss: nn.Variable - def __init__(self, - models: Union[RewardFunction, Sequence[RewardFunction]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - state_preprocessor: Optional[Preprocessor] = None, - config: AMPRewardFunctionTrainerConfig = AMPRewardFunctionTrainerConfig()): + def __init__( + self, + models: Union[RewardFunction, Sequence[RewardFunction]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + state_preprocessor: Optional[Preprocessor] = None, + config: AMPRewardFunctionTrainerConfig = AMPRewardFunctionTrainerConfig(), + ): self._state_preprocessor = state_preprocessor super(AMPRewardFunctionTrainer, self).__init__(models, solvers, env_info, config) - def _update_model(self, - models: Iterable[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Iterable[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): for key in batch.extra.keys(): set_data_to_variable(t.extra[key], b.extra[key]) @@ -107,10 +116,12 @@ def _build_training_graph( ignore_loss = is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - self._total_reward_loss = (self._binary_regression_loss - + self._grad_penalty_loss - + self._regularization_loss - + self._extra_regularization_loss) + self._total_reward_loss = ( + self._binary_regression_loss + + self._grad_penalty_loss + + self._regularization_loss + + self._extra_regularization_loss + ) # To check all loss is built assert isinstance(self._binary_regression_loss, nn.Variable) @@ -154,12 +165,14 @@ def _setup_training_variables(self, batch_size): a_current_agent_var = create_variable(batch_size, self._env_info.action_shape) a_current_expert_var = create_variable(batch_size, self._env_info.action_shape) - variables = 
{"s_current_expert": s_current_expert_var, - "a_current_expert": a_current_expert_var, - "s_next_expert": s_next_expert_var, - "s_current_agent": s_current_agent_var, - "a_current_agent": a_current_agent_var, - "s_next_agent": s_next_agent_var} + variables = { + "s_current_expert": s_current_expert_var, + "a_current_expert": a_current_expert_var, + "s_next_expert": s_next_expert_var, + "s_current_agent": s_current_agent_var, + "a_current_agent": a_current_agent_var, + "s_next_agent": s_next_agent_var, + } training_variables = TrainingVariables(batch_size, extra=variables) @@ -172,8 +185,13 @@ def _build_adversarial_loss(self, model: RewardFunction, training_variables: Tra _apply_need_grad_true(s_n_expert) logits_real, logits_fake = self._compute_logits( - model, s_expert, training_variables.extra["a_current_expert"], s_n_expert, - s_agent, training_variables.extra["a_current_agent"], s_n_agent, + model, + s_expert, + training_variables.extra["a_current_expert"], + s_n_expert, + s_agent, + training_variables.extra["a_current_agent"], + s_n_agent, ) real_loss = 0.5 * NF.mean((logits_real - 1.0) ** 2) fake_loss = 0.5 * NF.mean((logits_fake + 1.0) ** 2) @@ -186,8 +204,9 @@ def _build_adversarial_loss(self, model: RewardFunction, training_variables: Tra next_state_grads = self._compute_gradient_wrt_state(logits_real, s_n_expert) next_state_grad_penalty = self._compute_gradient_penalty(next_state_grads) - grad_penalty_loss = self._config.gradient_penelty_coefficient * \ - (current_state_grad_penalty + next_state_grad_penalty) + grad_penalty_loss = self._config.gradient_penelty_coefficient * ( + current_state_grad_penalty + next_state_grad_penalty + ) return binary_regression_loss, grad_penalty_loss @@ -228,8 +247,9 @@ def _build_extra_regularization_penalty(self, model: RewardFunction): extra_regularization_loss = 0.0 model_params = model.get_parameters() for variable_name in self._config.extra_regularization_variable_names: - extra_regularization_loss += self._config.extra_regularization_coefficient * \ - 0.5 * NF.sum(model_params[variable_name] ** 2) + extra_regularization_loss += ( + self._config.extra_regularization_coefficient * 0.5 * NF.sum(model_params[variable_name] ** 2) + ) return extra_regularization_loss def _compute_gradient_wrt_state( diff --git a/nnabla_rl/model_trainers/reward/gail_reward_function_trainer.py b/nnabla_rl/model_trainers/reward/gail_reward_function_trainer.py index 5ed89cc8..4c61fd3f 100644 --- a/nnabla_rl/model_trainers/reward/gail_reward_function_trainer.py +++ b/nnabla_rl/model_trainers/reward/gail_reward_function_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
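The AMPRewardFunctionTrainer hunks above regress expert ("real") logits toward +1 and agent ("fake") logits toward -1, a least-squares discriminator objective. The numpy sketch below covers only these two terms; the gradient penalty and the regularization terms need the computation graph and are omitted.

```python
import numpy as np

def amp_adversarial_loss(logits_real, logits_fake):
    # Least-squares discriminator objective: expert logits toward +1,
    # agent logits toward -1 (penalty and regularization terms omitted).
    real_loss = 0.5 * np.mean((logits_real - 1.0) ** 2)
    fake_loss = 0.5 * np.mean((logits_fake + 1.0) ** 2)
    return real_loss + fake_loss

# A perfectly separated batch gives zero loss.
print(amp_adversarial_loss(np.ones(8), -np.ones(8)))  # 0.0
```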
@@ -19,8 +19,13 @@ import nnabla as nn import nnabla.functions as NF -from nnabla_rl.model_trainers.model_trainer import (LossIntegration, ModelTrainer, TrainerConfig, TrainingBatch, - TrainingVariables) +from nnabla_rl.model_trainers.model_trainer import ( + LossIntegration, + ModelTrainer, + TrainerConfig, + TrainingBatch, + TrainingVariables, +) from nnabla_rl.models import Model, RewardFunction from nnabla_rl.utils.data import convert_to_list_if_not_list, set_data_to_variable from nnabla_rl.utils.misc import create_variable @@ -43,19 +48,23 @@ class GAILRewardFunctionTrainer(ModelTrainer): _config: GAILRewardFunctionTrainerConfig _binary_classification_loss: nn.Variable - def __init__(self, - models: Union[RewardFunction, Sequence[RewardFunction]], - solvers: Dict[str, nn.solver.Solver], - env_info, - config=GAILRewardFunctionTrainerConfig()): + def __init__( + self, + models: Union[RewardFunction, Sequence[RewardFunction]], + solvers: Dict[str, nn.solver.Solver], + env_info, + config=GAILRewardFunctionTrainerConfig(), + ): super(GAILRewardFunctionTrainer, self).__init__(models, solvers, env_info, config) - def _update_model(self, - models: Iterable[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Iterable[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): for key in batch.extra.keys(): set_data_to_variable(t.extra[key], b.extra[key]) @@ -69,11 +78,10 @@ def _update_model(self, solver.update() trainer_state: Dict[str, np.ndarray] = {} - trainer_state['reward_loss'] = self._binary_classification_loss.d.copy() + trainer_state["reward_loss"] = self._binary_classification_loss.d.copy() return trainer_state - def _build_training_graph(self, models: Union[Model, Sequence[Model]], - training_variables: TrainingVariables): + def _build_training_graph(self, models: Union[Model, Sequence[Model]], training_variables: TrainingVariables): models = convert_to_list_if_not_list(models) models = cast(Sequence[RewardFunction], models) @@ -85,26 +93,27 @@ def _build_training_graph(self, models: Union[Model, Sequence[Model]], ignore_loss = is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - def _build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool): + def _build_one_step_graph(self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool): models = cast(Sequence[RewardFunction], models) for model in models: # fake path - logits_fake = model.r(training_variables.extra['s_current_agent'], - training_variables.extra['a_current_agent'], - training_variables.extra['s_next_agent']) + logits_fake = model.r( + training_variables.extra["s_current_agent"], + training_variables.extra["a_current_agent"], + training_variables.extra["s_next_agent"], + ) fake_loss = NF.mean(NF.sigmoid_cross_entropy(logits_fake, NF.constant(0, logits_fake.shape))) # real path - logits_real = model.r(training_variables.extra['s_current_expert'], - training_variables.extra['a_current_expert'], - training_variables.extra['s_next_expert']) + logits_real = model.r( + training_variables.extra["s_current_expert"], + training_variables.extra["a_current_expert"], + 
training_variables.extra["s_next_expert"], + ) real_loss = NF.mean(NF.sigmoid_cross_entropy(logits_real, NF.constant(1, logits_real.shape))) # entropy loss logits = NF.concatenate(logits_fake, logits_real, axis=0) - entropy = NF.mean((1. - NF.sigmoid(logits)) * logits - NF.log_sigmoid(logits)) - entropy_loss = - self._config.entropy_coef * entropy # maximize + entropy = NF.mean((1.0 - NF.sigmoid(logits)) * logits - NF.log_sigmoid(logits)) + entropy_loss = -self._config.entropy_coef * entropy # maximize self._binary_classification_loss += 0.0 if ignore_loss else fake_loss + real_loss + entropy_loss def _setup_training_variables(self, batch_size): @@ -115,12 +124,14 @@ def _setup_training_variables(self, batch_size): a_current_agent_var = create_variable(batch_size, self._env_info.action_shape) a_current_expert_var = create_variable(batch_size, self._env_info.action_shape) - variables = {'s_current_expert': s_current_expert_var, - 'a_current_expert': a_current_expert_var, - 's_next_expert': s_next_expert_var, - 's_current_agent': s_current_agent_var, - 'a_current_agent': a_current_agent_var, - 's_next_agent': s_next_agent_var} + variables = { + "s_current_expert": s_current_expert_var, + "a_current_expert": a_current_expert_var, + "s_next_expert": s_next_expert_var, + "s_current_agent": s_current_agent_var, + "a_current_agent": a_current_agent_var, + "s_next_agent": s_next_agent_var, + } return TrainingVariables(batch_size, extra=variables) diff --git a/nnabla_rl/model_trainers/v_value/__init__.py b/nnabla_rl/model_trainers/v_value/__init__.py index ce777c09..f3defa84 100644 --- a/nnabla_rl/model_trainers/v_value/__init__.py +++ b/nnabla_rl/model_trainers/v_value/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,13 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nnabla_rl.model_trainers.v_value.demme_v_trainer import ( # noqa - DEMMEVTrainer, DEMMEVTrainerConfig) -from nnabla_rl.model_trainers.v_value.xql_v_trainer import ( # noqa - XQLVTrainer, XQLVTrainerConfig) -from nnabla_rl.model_trainers.v_value.mme_v_trainer import ( # noqa - MMEVTrainer, MMEVTrainerConfig) -from nnabla_rl.model_trainers.v_value.monte_carlo_v_trainer import ( # noqa - MonteCarloVTrainer, MonteCarloVTrainerConfig) -from nnabla_rl.model_trainers.v_value.soft_v_trainer import ( # noqa - SoftVTrainer, SoftVTrainerConfig) +from nnabla_rl.model_trainers.v_value.demme_v_trainer import DEMMEVTrainer, DEMMEVTrainerConfig # noqa +from nnabla_rl.model_trainers.v_value.xql_v_trainer import XQLVTrainer, XQLVTrainerConfig # noqa +from nnabla_rl.model_trainers.v_value.mme_v_trainer import MMEVTrainer, MMEVTrainerConfig # noqa +from nnabla_rl.model_trainers.v_value.monte_carlo_v_trainer import MonteCarloVTrainer, MonteCarloVTrainerConfig # noqa +from nnabla_rl.model_trainers.v_value.soft_v_trainer import SoftVTrainer, SoftVTrainerConfig # noqa diff --git a/nnabla_rl/model_trainers/v_value/demme_v_trainer.py b/nnabla_rl/model_trainers/v_value/demme_v_trainer.py index 67f231d1..1fe4e50b 100644 --- a/nnabla_rl/model_trainers/v_value/demme_v_trainer.py +++ b/nnabla_rl/model_trainers/v_value/demme_v_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,8 +19,10 @@ import nnabla_rl.functions as RNF from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingVariables, rnn_support -from nnabla_rl.model_trainers.v_value.squared_td_v_function_trainer import (SquaredTDVFunctionTrainer, - SquaredTDVFunctionTrainerConfig) +from nnabla_rl.model_trainers.v_value.squared_td_v_function_trainer import ( + SquaredTDVFunctionTrainer, + SquaredTDVFunctionTrainerConfig, +) from nnabla_rl.models import QFunction, StochasticPolicy, VFunction from nnabla_rl.utils.data import convert_to_list_if_not_list from nnabla_rl.utils.misc import create_variables @@ -40,13 +42,15 @@ class DEMMEVTrainer(SquaredTDVFunctionTrainer): _prev_target_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_q_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_functions: Union[VFunction, Sequence[VFunction]], - solvers: Dict[str, nn.solver.Solver], - target_functions: Union[QFunction, Sequence[QFunction]], - target_policy: StochasticPolicy, - env_info: EnvironmentInfo, - config: DEMMEVTrainerConfig = DEMMEVTrainerConfig()): + def __init__( + self, + train_functions: Union[VFunction, Sequence[VFunction]], + solvers: Dict[str, nn.solver.Solver], + target_functions: Union[QFunction, Sequence[QFunction]], + target_policy: StochasticPolicy, + env_info: EnvironmentInfo, + config: DEMMEVTrainerConfig = DEMMEVTrainerConfig(), + ): self._target_functions = convert_to_list_if_not_list(target_functions) self._target_policy = target_policy self._prev_target_rnn_states = {} diff --git a/nnabla_rl/model_trainers/v_value/extreme_v_function_trainer.py b/nnabla_rl/model_trainers/v_value/extreme_v_function_trainer.py index f9f6fef1..98a1f188 100644 --- a/nnabla_rl/model_trainers/v_value/extreme_v_function_trainer.py +++ b/nnabla_rl/model_trainers/v_value/extreme_v_function_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
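For comparison with the AMP objective, the GAILRewardFunctionTrainer hunk a few chunks above scores agent transitions with label 0 and expert transitions with label 1 via sigmoid cross entropy, plus a small entropy bonus on the discriminator output. A numpy sketch of that loss; the entropy_coef value is a placeholder, since the config default is not shown in this diff.

```python
import numpy as np

def sigmoid_cross_entropy(logits, target):
    # Numerically stable binary cross entropy on logits:
    # max(x, 0) - x * t + log(1 + exp(-|x|)).
    return np.maximum(logits, 0.0) - logits * target + np.log1p(np.exp(-np.abs(logits)))

def gail_discriminator_loss(logits_fake, logits_real, entropy_coef=1e-3):
    fake_loss = np.mean(sigmoid_cross_entropy(logits_fake, 0.0))  # agent labeled 0
    real_loss = np.mean(sigmoid_cross_entropy(logits_real, 1.0))  # expert labeled 1
    logits = np.concatenate([logits_fake, logits_real], axis=0)
    log_sigmoid = -np.logaddexp(0.0, -logits)
    # Bernoulli entropy of the discriminator output, maximized through the
    # negative coefficient on the total loss.
    entropy = np.mean((1.0 - 1.0 / (1.0 + np.exp(-logits))) * logits - log_sigmoid)
    return fake_loss + real_loss - entropy_coef * entropy

print(gail_discriminator_loss(np.array([-2.0, -1.0]), np.array([1.0, 3.0])))
```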
@@ -30,7 +30,7 @@ class ExtremeVFunctionTrainerConfig(VFunctionTrainerConfig): def __post_init__(self): super(ExtremeVFunctionTrainerConfig, self).__post_init__() - self._assert_positive(self.beta, 'beta') + self._assert_positive(self.beta, "beta") class ExtremeVFunctionTrainer(VFunctionTrainer): @@ -40,18 +40,19 @@ class ExtremeVFunctionTrainer(VFunctionTrainer): _config: ExtremeVFunctionTrainerConfig _prev_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - models: Union[VFunction, Sequence[VFunction]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - config: ExtremeVFunctionTrainerConfig = ExtremeVFunctionTrainerConfig()): + def __init__( + self, + models: Union[VFunction, Sequence[VFunction]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + config: ExtremeVFunctionTrainerConfig = ExtremeVFunctionTrainerConfig(), + ): self._prev_rnn_states = {} super(ExtremeVFunctionTrainer, self).__init__(models, solvers, env_info, config) - def _compute_loss(self, - model: VFunction, - target_value: nn.Variable, - training_variables: TrainingVariables) -> nn.Variable: + def _compute_loss( + self, model: VFunction, target_value: nn.Variable, training_variables: TrainingVariables + ) -> nn.Variable: prev_rnn_states = self._prev_rnn_states train_rnn_states = training_variables.rnn_states with rnn_support(model, prev_rnn_states, train_rnn_states, training_variables, self._config): @@ -65,4 +66,4 @@ def _compute_loss(self, # original code seems to rescale the gumbel loss by max(z) # i.e. exp(z) / max(z) - z / max(z) # - NF.exp(-max_z) <- this term exists in the original code but this should take no effect - return NF.exp(z-max_z) - z * NF.exp(-max_z) - NF.exp(-max_z) + return NF.exp(z - max_z) - z * NF.exp(-max_z) - NF.exp(-max_z) diff --git a/nnabla_rl/model_trainers/v_value/mme_v_trainer.py b/nnabla_rl/model_trainers/v_value/mme_v_trainer.py index fdcd99ac..b2aa1c48 100644 --- a/nnabla_rl/model_trainers/v_value/mme_v_trainer.py +++ b/nnabla_rl/model_trainers/v_value/mme_v_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
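The ExtremeVFunctionTrainer hunk above stabilizes the Gumbel regression loss exp(z) - z - 1 by dividing every term by exp(max_z), which keeps the exponential from overflowing for large positive z. A numpy sketch of that rescaling; z stands for the scaled error the trainer passes in, which is not shown in this hunk.

```python
import numpy as np

def rescaled_gumbel_loss(z, max_z):
    # (exp(z) - z - 1) * exp(-max_z): the Gumbel loss rescaled so that exp()
    # is only evaluated at z - max_z <= 0 when max_z = max(z).
    return np.exp(z - max_z) - z * np.exp(-max_z) - np.exp(-max_z)

z = np.array([-1.0, 0.0, 5.0, 50.0])
print(rescaled_gumbel_loss(z, max_z=np.max(z)))  # finite even for z = 50
```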
@@ -21,8 +21,10 @@ from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingVariables, rnn_support from nnabla_rl.model_trainers.policy.soft_policy_trainer import AdjustableTemperature -from nnabla_rl.model_trainers.v_value.squared_td_v_function_trainer import (SquaredTDVFunctionTrainer, - SquaredTDVFunctionTrainerConfig) +from nnabla_rl.model_trainers.v_value.squared_td_v_function_trainer import ( + SquaredTDVFunctionTrainer, + SquaredTDVFunctionTrainerConfig, +) from nnabla_rl.models import QFunction, StochasticPolicy, VFunction from nnabla_rl.utils.data import convert_to_list_if_not_list from nnabla_rl.utils.misc import create_variables @@ -42,14 +44,16 @@ class MMEVTrainer(SquaredTDVFunctionTrainer): _prev_target_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_q_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_functions: Union[VFunction, Sequence[VFunction]], - solvers: Dict[str, nn.solver.Solver], - target_functions: Union[QFunction, Sequence[QFunction]], - target_policy: StochasticPolicy, - temperature: AdjustableTemperature, # value temperature alpha_q - env_info: EnvironmentInfo, - config: MMEVTrainerConfig = MMEVTrainerConfig()): + def __init__( + self, + train_functions: Union[VFunction, Sequence[VFunction]], + solvers: Dict[str, nn.solver.Solver], + target_functions: Union[QFunction, Sequence[QFunction]], + target_policy: StochasticPolicy, + temperature: AdjustableTemperature, # value temperature alpha_q + env_info: EnvironmentInfo, + config: MMEVTrainerConfig = MMEVTrainerConfig(), + ): self._target_functions = convert_to_list_if_not_list(target_functions) self._target_policy = target_policy self._temperature = temperature diff --git a/nnabla_rl/model_trainers/v_value/monte_carlo_v_trainer.py b/nnabla_rl/model_trainers/v_value/monte_carlo_v_trainer.py index d4f30afe..1f1e3395 100644 --- a/nnabla_rl/model_trainers/v_value/monte_carlo_v_trainer.py +++ b/nnabla_rl/model_trainers/v_value/monte_carlo_v_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -20,8 +20,10 @@ import nnabla as nn from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingBatch, TrainingVariables -from nnabla_rl.model_trainers.v_value.squared_td_v_function_trainer import (SquaredTDVFunctionTrainer, - SquaredTDVFunctionTrainerConfig) +from nnabla_rl.model_trainers.v_value.squared_td_v_function_trainer import ( + SquaredTDVFunctionTrainer, + SquaredTDVFunctionTrainerConfig, +) from nnabla_rl.models import VFunction from nnabla_rl.models.model import Model from nnabla_rl.utils.data import set_data_to_variable @@ -34,31 +36,35 @@ class MonteCarloVTrainerConfig(SquaredTDVFunctionTrainerConfig): class MonteCarloVTrainer(SquaredTDVFunctionTrainer): - def __init__(self, - train_functions: Union[VFunction, Sequence[VFunction]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - config: MonteCarloVTrainerConfig = MonteCarloVTrainerConfig()): + def __init__( + self, + train_functions: Union[VFunction, Sequence[VFunction]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + config: MonteCarloVTrainerConfig = MonteCarloVTrainerConfig(), + ): super(MonteCarloVTrainer, self).__init__(train_functions, solvers, env_info, config) - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): - set_data_to_variable(t.extra['v_target'], b.extra['v_target']) + set_data_to_variable(t.extra["v_target"], b.extra["v_target"]) return super()._update_model(models, solvers, batch, training_variables, **kwargs) def _compute_target(self, training_variables: TrainingVariables, **kwargs) -> nn.Variable: - return training_variables.extra['v_target'] + return training_variables.extra["v_target"] def _setup_training_variables(self, batch_size) -> TrainingVariables: training_variables = super()._setup_training_variables(batch_size) extra = {} - extra['v_target'] = create_variable(batch_size, 1) + extra["v_target"] = create_variable(batch_size, 1) training_variables.extra.update(extra) return training_variables diff --git a/nnabla_rl/model_trainers/v_value/soft_v_trainer.py b/nnabla_rl/model_trainers/v_value/soft_v_trainer.py index d6d04d49..0ef21c48 100644 --- a/nnabla_rl/model_trainers/v_value/soft_v_trainer.py +++ b/nnabla_rl/model_trainers/v_value/soft_v_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -19,8 +19,10 @@ import nnabla_rl.functions as RNF from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingVariables, rnn_support -from nnabla_rl.model_trainers.v_value.squared_td_v_function_trainer import (SquaredTDVFunctionTrainer, - SquaredTDVFunctionTrainerConfig) +from nnabla_rl.model_trainers.v_value.squared_td_v_function_trainer import ( + SquaredTDVFunctionTrainer, + SquaredTDVFunctionTrainerConfig, +) from nnabla_rl.models import QFunction, StochasticPolicy, VFunction from nnabla_rl.utils.data import convert_to_list_if_not_list from nnabla_rl.utils.misc import create_variables @@ -40,13 +42,15 @@ class SoftVTrainer(SquaredTDVFunctionTrainer): _prev_target_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_q_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_functions: Union[VFunction, Sequence[VFunction]], - solvers: Dict[str, nn.solver.Solver], - target_functions: Union[QFunction, Sequence[QFunction]], - target_policy: StochasticPolicy, - env_info: EnvironmentInfo, - config: SoftVTrainerConfig = SoftVTrainerConfig()): + def __init__( + self, + train_functions: Union[VFunction, Sequence[VFunction]], + solvers: Dict[str, nn.solver.Solver], + target_functions: Union[QFunction, Sequence[QFunction]], + target_policy: StochasticPolicy, + env_info: EnvironmentInfo, + config: SoftVTrainerConfig = SoftVTrainerConfig(), + ): self._target_functions = convert_to_list_if_not_list(target_functions) self._target_policy = target_policy self._prev_target_rnn_states = {} diff --git a/nnabla_rl/model_trainers/v_value/squared_td_v_function_trainer.py b/nnabla_rl/model_trainers/v_value/squared_td_v_function_trainer.py index 07d91476..dfeccab3 100644 --- a/nnabla_rl/model_trainers/v_value/squared_td_v_function_trainer.py +++ b/nnabla_rl/model_trainers/v_value/squared_td_v_function_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -35,18 +35,19 @@ class SquaredTDVFunctionTrainer(VFunctionTrainer): _config: SquaredTDVFunctionTrainerConfig _prev_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - models: Union[VFunction, Sequence[VFunction]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - config: SquaredTDVFunctionTrainerConfig = SquaredTDVFunctionTrainerConfig()): + def __init__( + self, + models: Union[VFunction, Sequence[VFunction]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + config: SquaredTDVFunctionTrainerConfig = SquaredTDVFunctionTrainerConfig(), + ): self._prev_rnn_states = {} super(SquaredTDVFunctionTrainer, self).__init__(models, solvers, env_info, config) - def _compute_loss(self, - model: VFunction, - target_value: nn.Variable, - training_variables: TrainingVariables) -> nn.Variable: + def _compute_loss( + self, model: VFunction, target_value: nn.Variable, training_variables: TrainingVariables + ) -> nn.Variable: prev_rnn_states = self._prev_rnn_states train_rnn_states = training_variables.rnn_states with rnn_support(model, prev_rnn_states, train_rnn_states, training_variables, self._config): diff --git a/nnabla_rl/model_trainers/v_value/v_function_trainer.py b/nnabla_rl/model_trainers/v_value/v_function_trainer.py index 66af905f..279de65b 100644 --- a/nnabla_rl/model_trainers/v_value/v_function_trainer.py +++ b/nnabla_rl/model_trainers/v_value/v_function_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,8 +21,13 @@ import nnabla as nn import nnabla.functions as NF from nnabla_rl.environments.environment_info import EnvironmentInfo -from nnabla_rl.model_trainers.model_trainer import (LossIntegration, ModelTrainer, TrainerConfig, TrainingBatch, - TrainingVariables) +from nnabla_rl.model_trainers.model_trainer import ( + LossIntegration, + ModelTrainer, + TrainerConfig, + TrainingBatch, + TrainingVariables, +) from nnabla_rl.models import Model, VFunction from nnabla_rl.utils.data import set_data_to_variable from nnabla_rl.utils.misc import create_variable, create_variables @@ -30,12 +35,12 @@ @dataclass class VFunctionTrainerConfig(TrainerConfig): - reduction_method: str = 'mean' + reduction_method: str = "mean" v_loss_scalar: float = 1.0 def __post_init__(self): super(VFunctionTrainerConfig, self).__post_init__() - self._assert_one_of(self.reduction_method, ['sum', 'mean'], 'reduction_method') + self._assert_one_of(self.reduction_method, ["sum", "mean"], "reduction_method") class VFunctionTrainer(ModelTrainer): @@ -45,19 +50,23 @@ class VFunctionTrainer(ModelTrainer): _config: VFunctionTrainerConfig _v_loss: nn.Variable # Training loss/output - def __init__(self, - models: Union[VFunction, Sequence[VFunction]], - solvers: Dict[str, nn.solver.Solver], - env_info: EnvironmentInfo, - config: VFunctionTrainerConfig = VFunctionTrainerConfig()): + def __init__( + self, + models: Union[VFunction, Sequence[VFunction]], + solvers: Dict[str, nn.solver.Solver], + env_info: EnvironmentInfo, + config: VFunctionTrainerConfig = VFunctionTrainerConfig(), + ): super(VFunctionTrainer, self).__init__(models, solvers, env_info, config) - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + 
self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.s_current, b.s_current) if self.support_rnn() and self._config.reset_on_terminal and self._need_rnn_support(models): @@ -85,7 +94,7 @@ def _update_model(self, solver.update() trainer_state: Dict[str, np.ndarray] = {} - trainer_state['v_loss'] = self._v_loss.d.copy() + trainer_state["v_loss"] = self._v_loss.d.copy() return trainer_state def _build_training_graph(self, models: Union[Model, Sequence[Model]], training_variables: TrainingVariables): @@ -98,10 +107,9 @@ def _build_training_graph(self, models: Union[Model, Sequence[Model]], training_ ignore_loss = is_burn_in_steps or (is_intermediate_steps and ignore_intermediate_loss) self._v_loss += self._build_one_step_graph(models, variables, ignore_loss=ignore_loss) - def _build_one_step_graph(self, - models: Sequence[Model], - training_variables: TrainingVariables, - ignore_loss: bool) -> nn.Variable: + def _build_one_step_graph( + self, models: Sequence[Model], training_variables: TrainingVariables, ignore_loss: bool + ) -> nn.Variable: # value function learning target_v = self._compute_target(training_variables) target_v.need_grad = False @@ -110,9 +118,9 @@ def _build_one_step_graph(self, for v_function in models: v_function = cast(VFunction, v_function) v_loss = self._compute_loss(v_function, target_v, training_variables) - if self._config.reduction_method == 'mean': + if self._config.reduction_method == "mean": v_loss = self._config.v_loss_scalar * NF.mean(v_loss) - elif self._config.reduction_method == 'sum': + elif self._config.reduction_method == "sum": v_loss = self._config.v_loss_scalar * NF.sum(v_loss) else: raise RuntimeError @@ -120,15 +128,13 @@ def _build_one_step_graph(self, return loss @abstractmethod - def _compute_loss(self, - v_function: VFunction, - target_v: nn.Variable, - traininge_variables: TrainingVariables) -> nn.Variable: + def _compute_loss( + self, v_function: VFunction, target_v: nn.Variable, traininge_variables: TrainingVariables + ) -> nn.Variable: raise NotImplementedError @abstractmethod - def _compute_target(self, - training_variables: TrainingVariables) -> nn.Variable: + def _compute_target(self, training_variables: TrainingVariables) -> nn.Variable: raise NotImplementedError def _setup_training_variables(self, batch_size) -> TrainingVariables: @@ -143,10 +149,9 @@ def _setup_training_variables(self, batch_size) -> TrainingVariables: rnn_state_variables = create_variables(batch_size, model.internal_state_shapes()) rnn_states[model.scope_name] = rnn_state_variables - training_variables = TrainingVariables(batch_size=batch_size, - s_current=s_current_var, - non_terminal=non_terminal_var, - rnn_states=rnn_states) + training_variables = TrainingVariables( + batch_size=batch_size, s_current=s_current_var, non_terminal=non_terminal_var, rnn_states=rnn_states + ) return training_variables @property diff --git a/nnabla_rl/model_trainers/v_value/xql_v_trainer.py b/nnabla_rl/model_trainers/v_value/xql_v_trainer.py index 90d88e12..53ebc353 100644 --- a/nnabla_rl/model_trainers/v_value/xql_v_trainer.py +++ b/nnabla_rl/model_trainers/v_value/xql_v_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. 
+# Copyright 2023,2024 Sony Group Corporation.
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,8 +21,10 @@ import nnabla_rl.functions as RNF from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.model_trainers.model_trainer import TrainingBatch, TrainingVariables, rnn_support -from nnabla_rl.model_trainers.v_value.extreme_v_function_trainer import (ExtremeVFunctionTrainer, - ExtremeVFunctionTrainerConfig) +from nnabla_rl.model_trainers.v_value.extreme_v_function_trainer import ( + ExtremeVFunctionTrainer, + ExtremeVFunctionTrainerConfig, +) from nnabla_rl.models import Model, QFunction, StochasticPolicy, VFunction from nnabla_rl.utils.data import convert_to_list_if_not_list, set_data_to_variable from nnabla_rl.utils.misc import create_variable @@ -41,13 +43,15 @@ class XQLVTrainer(ExtremeVFunctionTrainer): _prev_q_rnn_states: Dict[str, Dict[str, nn.Variable]] _prev_pi_rnn_states: Dict[str, Dict[str, nn.Variable]] - def __init__(self, - train_functions: Union[VFunction, Sequence[VFunction]], - solvers: Dict[str, nn.solver.Solver], - target_functions: Union[QFunction, Sequence[QFunction]], - env_info: EnvironmentInfo, - target_policy: Optional[StochasticPolicy] = None, - config: XQLVTrainerConfig = XQLVTrainerConfig()): + def __init__( + self, + train_functions: Union[VFunction, Sequence[VFunction]], + solvers: Dict[str, nn.solver.Solver], + target_functions: Union[QFunction, Sequence[QFunction]], + env_info: EnvironmentInfo, + target_policy: Optional[StochasticPolicy] = None, + config: XQLVTrainerConfig = XQLVTrainerConfig(), + ): self._target_policy = target_policy self._prev_pi_rnn_states = {} @@ -59,12 +63,14 @@ def __init__(self, def support_rnn(self) -> bool: return True - def _update_model(self, - models: Sequence[Model], - solvers: Dict[str, nn.solver.Solver], - batch: TrainingBatch, - training_variables: TrainingVariables, - **kwargs) -> Dict[str, np.ndarray]: + def _update_model( + self, + models: Sequence[Model], + solvers: Dict[str, nn.solver.Solver], + batch: TrainingBatch, + training_variables: TrainingVariables, + **kwargs, + ) -> Dict[str, np.ndarray]: for t, b in zip(training_variables, batch): set_data_to_variable(t.a_current, b.a_current) diff --git a/nnabla_rl/models/__init__.py b/nnabla_rl/models/__init__.py index 2d455cb7..a0978129 100644 --- a/nnabla_rl/models/__init__.py +++ b/nnabla_rl/models/__init__.py @@ -13,29 +13,48 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from nnabla_rl.models.decision_transformer import (DecisionTransformer, # noqa - DeterministicDecisionTransformer, - StochasticDecisionTransformer) -from nnabla_rl.models.distributional_function import (ValueDistributionFunction, # noqa - DiscreteValueDistributionFunction, - ContinuousValueDistributionFunction) -from nnabla_rl.models.distributional_function import (QuantileDistributionFunction, # noqa - DiscreteQuantileDistributionFunction, - ContinuousQuantileDistributionFunction) -from nnabla_rl.models.distributional_function import (StateActionQuantileFunction, # noqa - DiscreteStateActionQuantileFunction, - ContinuousStateActionQuantileFunction) +from nnabla_rl.models.decision_transformer import ( # noqa + DecisionTransformer, + DeterministicDecisionTransformer, + StochasticDecisionTransformer, +) +from nnabla_rl.models.distributional_function import ( # noqa + ValueDistributionFunction, + DiscreteValueDistributionFunction, + ContinuousValueDistributionFunction, +) +from nnabla_rl.models.distributional_function import ( # noqa + QuantileDistributionFunction, + DiscreteQuantileDistributionFunction, + ContinuousQuantileDistributionFunction, +) +from nnabla_rl.models.distributional_function import ( # noqa + StateActionQuantileFunction, + DiscreteStateActionQuantileFunction, + ContinuousStateActionQuantileFunction, +) from nnabla_rl.models.dynamics import Dynamics, DeterministicDynamics, StochasticDynamics # noqa from nnabla_rl.models.model import Model # noqa from nnabla_rl.models.perturbator import Perturbator # noqa from nnabla_rl.models.policy import Policy, DeterministicPolicy, StochasticPolicy # noqa -from nnabla_rl.models.q_function import QFunction, DiscreteQFunction, ContinuousQFunction, FactoredContinuousQFunction # noqa +from nnabla_rl.models.q_function import ( # noqa + QFunction, + DiscreteQFunction, + ContinuousQFunction, + FactoredContinuousQFunction, +) from nnabla_rl.models.v_function import VFunction # noqa from nnabla_rl.models.reward_function import RewardFunction # noqa from nnabla_rl.models.encoder import Encoder, VariationalAutoEncoder # noqa from nnabla_rl.models.mujoco.policies import TD3Policy, SACPolicy, BEARPolicy, TRPOPolicy # noqa -from nnabla_rl.models.mujoco.q_functions import TD3QFunction, SACQFunction, SACDQFunction, HERQFunction, XQLQFunction # noqa +from nnabla_rl.models.mujoco.q_functions import ( # noqa + TD3QFunction, + SACQFunction, + SACDQFunction, + HERQFunction, + XQLQFunction, +) from nnabla_rl.models.mujoco.decision_transformers import MujocoDecisionTransformer # noqa from nnabla_rl.models.mujoco.distributional_functions import QRSACQuantileDistributionFunction # noqa from nnabla_rl.models.mujoco.v_functions import SACVFunction, TRPOVFunction, ATRPOVFunction # noqa @@ -58,12 +77,14 @@ from nnabla_rl.models.atari.v_functions import PPOVFunction as PPOAtariVFunction # noqa from nnabla_rl.models.atari.v_functions import A3CVFunction # noqa from nnabla_rl.models.atari.shared_functions import PPOSharedFunctionHead, A3CSharedFunctionHead # noqa -from nnabla_rl.models.atari.distributional_functions import (C51ValueDistributionFunction, # noqa - RainbowValueDistributionFunction, - RainbowNoDuelValueDistributionFunction, - RainbowNoNoisyValueDistributionFunction, - QRDQNQuantileDistributionFunction, - IQNQuantileFunction) +from nnabla_rl.models.atari.distributional_functions import ( # noqa + C51ValueDistributionFunction, + RainbowValueDistributionFunction, + RainbowNoDuelValueDistributionFunction, + RainbowNoNoisyValueDistributionFunction, + 
QRDQNQuantileDistributionFunction, + IQNQuantileFunction, +) from nnabla_rl.models.atari.policies import ICML2015TRPOPolicy as ICML2015TRPOAtariPolicy # noqa from nnabla_rl.models.pybullet.q_functions import ICRA2018QtOptQFunction # noqa diff --git a/nnabla_rl/models/atari/__init__.py b/nnabla_rl/models/atari/__init__.py index 8a5dfccc..d3f66d40 100644 --- a/nnabla_rl/models/atari/__init__.py +++ b/nnabla_rl/models/atari/__init__.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/nnabla_rl/models/atari/decision_transformers.py b/nnabla_rl/models/atari/decision_transformers.py index 0e874501..c73a9cb6 100644 --- a/nnabla_rl/models/atari/decision_transformers.py +++ b/nnabla_rl/models/atari/decision_transformers.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,13 +30,15 @@ class AtariDecisionTransformer(StochasticDecisionTransformer): - def __init__(self, - scope_name: str, - action_num: int, - max_timestep: int, - context_length: int, - num_heads: int = 8, - embedding_dim: int = 128): + def __init__( + self, + scope_name: str, + action_num: int, + max_timestep: int, + context_length: int, + num_heads: int = 8, + embedding_dim: int = 128, + ): super().__init__(scope_name, num_heads, embedding_dim) self._action_num = action_num self._max_timestep = max_timestep @@ -68,13 +70,13 @@ def pi(self, s: nn.Variable, a: nn.Variable, rtg: nn.Variable, t: nn.Variable) - x = NF.dropout(x, p=dropout) attention_mask = create_attention_mask(block_size, block_size) for i in range(self._attention_layers): - with nn.parameter_scope(f'attention_block{i}'): + with nn.parameter_scope(f"attention_block{i}"): x = self._attention_block(x, attention_mask=attention_mask) - with nn.parameter_scope('layer_norm'): + with nn.parameter_scope("layer_norm"): fix_parameters = rl.is_eval_scope() # 0.003 is almost sqrt(10^-5) x = NPF.layer_normalization(x, batch_axis=(0, 1), fix_parameters=fix_parameters, eps=0.003) - with nn.parameter_scope('affine'): + with nn.parameter_scope("affine"): logits = NPF.affine(x, n_outmaps=self._action_num, with_bias=False, base_axis=2) # Use predictions from state embeddings logits = logits[:, 1::3, :] @@ -99,10 +101,9 @@ def _embed_state(self, s: nn.Variable) -> nn.Variable: def _embed_action(self, a: nn.Variable) -> nn.Variable: with nn.parameter_scope("action_embedding"): with nn.parameter_scope("embed"): - embedding = NPF.embed(a, - n_inputs=self._action_num, - n_features=self._embedding_dim, - initializer=NI.NormalInitializer(0.02)) + embedding = NPF.embed( + a, n_inputs=self._action_num, n_features=self._embedding_dim, initializer=NI.NormalInitializer(0.02) + ) embedding = NF.reshape(embedding, shape=(embedding.shape[0], embedding.shape[1], -1)) return NF.tanh(embedding) @@ -120,9 +121,9 @@ def _embed_t(self, timesteps: nn.Variable, block_size, context_length: int, T: i self._timesteps = timesteps # global: position embedding for timestep number - global_position_embedding = get_parameter_or_create('global_position_embedding', - shape=(1, T, self._embedding_dim), - initializer=NI.ConstantInitializer(0)) + global_position_embedding = 
get_parameter_or_create( + "global_position_embedding", shape=(1, T, self._embedding_dim), initializer=NI.ConstantInitializer(0) + ) # use same position embedding for all data in batch # (1, T, embedding_dim) -> (batch_size, T, embedding_dim) global_position_embedding = RF.repeat(global_position_embedding, repeats=batch_size, axis=0) @@ -134,37 +135,41 @@ def _embed_t(self, timesteps: nn.Variable, block_size, context_length: int, T: i # block: position embedding for block's position # block_size changes depending on the input # block_size <= context_length * 3 - block_position_embedding = get_parameter_or_create('block_position_embedding', - shape=(1, context_length * 3, self._embedding_dim), - initializer=NI.ConstantInitializer(0))[:, :block_size, :] + block_position_embedding = get_parameter_or_create( + "block_position_embedding", + shape=(1, context_length * 3, self._embedding_dim), + initializer=NI.ConstantInitializer(0), + )[:, :block_size, :] return global_position_embedding + block_position_embedding def _attention_block(self, x: nn.Variable, attention_mask=None) -> nn.Variable: - with nn.parameter_scope('layer_norm1'): + with nn.parameter_scope("layer_norm1"): fix_parameters = rl.is_eval_scope() normalized_x1 = NPF.layer_normalization(x, batch_axis=(0, 1), fix_parameters=fix_parameters, eps=0.003) - with nn.parameter_scope('causal_self_attention'): + with nn.parameter_scope("causal_self_attention"): attention_dropout = None if rl.is_eval_scope() else 0.1 output_dropout = None if rl.is_eval_scope() else 0.1 - x = x + RPF.causal_self_attention(normalized_x1, - embed_dim=self._embedding_dim, - num_heads=self._num_heads, - mask=attention_mask, - attention_dropout=attention_dropout, - output_dropout=output_dropout) - with nn.parameter_scope('layer_norm2'): + x = x + RPF.causal_self_attention( + normalized_x1, + embed_dim=self._embedding_dim, + num_heads=self._num_heads, + mask=attention_mask, + attention_dropout=attention_dropout, + output_dropout=output_dropout, + ) + with nn.parameter_scope("layer_norm2"): fix_parameters = rl.is_eval_scope() normalized_x2 = NPF.layer_normalization(x, batch_axis=(0, 1), fix_parameters=fix_parameters, eps=0.003) - with nn.parameter_scope('mlp'): + with nn.parameter_scope("mlp"): block_dropout = None if rl.is_eval_scope() else 0.1 x = x + self._block_mlp(normalized_x2, block_dropout) return x def _block_mlp(self, x: nn.Variable, dropout: Optional[float] = None) -> nn.Variable: - with nn.parameter_scope('linear1'): + with nn.parameter_scope("linear1"): x = NPF.affine(x, n_outmaps=4 * self._embedding_dim, base_axis=2, w_init=NI.NormalInitializer(0.02)) x = NF.gelu(x) - with nn.parameter_scope('linear2'): + with nn.parameter_scope("linear2"): x = NPF.affine(x, n_outmaps=self._embedding_dim, base_axis=2, w_init=NI.NormalInitializer(0.02)) if dropout is not None: x = NF.dropout(x, p=dropout) diff --git a/nnabla_rl/models/atari/distributional_functions.py b/nnabla_rl/models/atari/distributional_functions.py index fa801ad3..338f1a10 100644 --- a/nnabla_rl/models/atari/distributional_functions.py +++ b/nnabla_rl/models/atari/distributional_functions.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
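The AtariDecisionTransformer hunks above only re-quote scope names and re-wrap calls; the token layout they rely on is unchanged. Assuming the usual decision-transformer ordering of (return-to-go, state, action) tokens per timestep, `logits[:, 1::3, :]` keeps exactly the predictions made at state positions, as the "Use predictions from state embeddings" comment says. A small NumPy sketch of that slicing, with hypothetical sizes:

```py
import numpy as np

# State tokens sit at indices 1, 4, 7, ... of the 3 * context_length sequence,
# so x[:, 1::3, :] selects one prediction per timestep.
batch_size, context_length, embedding_dim = 2, 4, 8  # hypothetical sizes
x = np.random.randn(batch_size, 3 * context_length, embedding_dim)

state_tokens = x[:, 1::3, :]
assert state_tokens.shape == (batch_size, context_length, embedding_dim)
```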
@@ -22,8 +22,11 @@ import nnabla.parametric_functions as NPF import nnabla_rl.functions as RF import nnabla_rl.parametric_functions as RPF -from nnabla_rl.models import (DiscreteQuantileDistributionFunction, DiscreteStateActionQuantileFunction, - DiscreteValueDistributionFunction) +from nnabla_rl.models import ( + DiscreteQuantileDistributionFunction, + DiscreteStateActionQuantileFunction, + DiscreteValueDistributionFunction, +) class C51ValueDistributionFunction(DiscreteValueDistributionFunction): @@ -44,8 +47,7 @@ def all_probs(self, s: nn.Variable) -> nn.Variable: h = NPF.affine(h, n_outmaps=512) h = NF.relu(x=h) with nn.parameter_scope("affine2"): - h = NPF.affine( - h, n_outmaps=self._n_action * self._n_atom) + h = NPF.affine(h, n_outmaps=self._n_action * self._n_atom) h = NF.reshape(h, (-1, self._n_action, self._n_atom)) assert h.shape == (batch_size, self._n_action, self._n_atom) return NF.softmax(h, axis=2) @@ -168,8 +170,7 @@ def all_quantiles(self, s: nn.Variable) -> nn.Variable: h = NF.relu(x=h) with nn.parameter_scope("affine2"): h = NPF.affine(h, n_outmaps=self._n_action * self._n_quantile) - quantiles = NF.reshape( - h, (-1, self._n_action, self._n_quantile)) + quantiles = NF.reshape(h, (-1, self._n_action, self._n_quantile)) assert quantiles.shape == (batch_size, self._n_action, self._n_quantile) return quantiles @@ -180,8 +181,14 @@ class IQNQuantileFunction(DiscreteStateActionQuantileFunction): # See https://mypy.readthedocs.io/en/stable/class_basics.html for details _embedding_dim: int - def __init__(self, scope_name: str, n_action: int, embedding_dim: int, K: int, - risk_measure_function: Callable[[nn.Variable], nn.Variable]): + def __init__( + self, + scope_name: str, + n_action: int, + embedding_dim: int, + K: int, + risk_measure_function: Callable[[nn.Variable], nn.Variable], + ): super(IQNQuantileFunction, self).__init__(scope_name, n_action, K, risk_measure_function) self._embedding_dim = embedding_dim diff --git a/nnabla_rl/models/atari/policies.py b/nnabla_rl/models/atari/policies.py index 6b40a364..da763e63 100644 --- a/nnabla_rl/models/atari/policies.py +++ b/nnabla_rl/models/atari/policies.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
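The C51/QR-DQN heads above were only re-joined onto single lines; the shape contract is unchanged: a flat affine of size n_action * n_atom is reshaped to (batch, n_action, n_atom) and normalized over the atom axis. A plain-NumPy sketch of that contract, with hypothetical sizes:

```py
import numpy as np

# Mirrors: h = NPF.affine(h, n_outmaps=n_action * n_atom)
#          h = NF.reshape(h, (-1, n_action, n_atom)); NF.softmax(h, axis=2)
batch_size, n_action, n_atom = 4, 6, 51  # hypothetical sizes
h = np.random.randn(batch_size, n_action * n_atom)

probs = h.reshape(-1, n_action, n_atom)
probs = np.exp(probs - probs.max(axis=2, keepdims=True))
probs /= probs.sum(axis=2, keepdims=True)

assert probs.shape == (batch_size, n_action, n_atom)
assert np.allclose(probs.sum(axis=2), 1.0)
```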
@@ -46,8 +46,7 @@ def pi(self, s: nn.Variable) -> Distribution: h = self._hidden(s) with nn.parameter_scope(self.scope_name): with nn.parameter_scope("linear_pi"): - z = NPF.affine(h, n_outmaps=self._action_dim, - w_init=RI.NormcInitializer(std=0.01)) + z = NPF.affine(h, n_outmaps=self._action_dim, w_init=RI.NormcInitializer(std=0.01)) return D.Softmax(z=z) def _hidden(self, s: nn.Variable) -> nn.Variable: @@ -74,11 +73,9 @@ def pi(self, s: nn.Variable) -> Distribution: batch_size = s.shape[0] with nn.parameter_scope(self.scope_name): with nn.parameter_scope("conv1"): - h = NF.tanh(NPF.convolution( - s, 16, (4, 4), stride=(2, 2))) + h = NF.tanh(NPF.convolution(s, 16, (4, 4), stride=(2, 2))) with nn.parameter_scope("conv2"): - h = NF.tanh(NPF.convolution( - h, 16, (4, 4), pad=(1, 1), stride=(2, 2))) + h = NF.tanh(NPF.convolution(h, 16, (4, 4), pad=(1, 1), stride=(2, 2))) h = NF.reshape(h, (batch_size, -1), inplace=False) with nn.parameter_scope("affine1"): h = NF.tanh(NPF.affine(h, 20)) diff --git a/nnabla_rl/models/atari/q_functions.py b/nnabla_rl/models/atari/q_functions.py index 9bb8937b..39a65669 100644 --- a/nnabla_rl/models/atari/q_functions.py +++ b/nnabla_rl/models/atari/q_functions.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,14 +23,14 @@ class DQNQFunction(DiscreteQFunction): - ''' + """ Q function proposed by DeepMind in DQN paper for atari environment. See: https://deepmind.com/research/publications/human-level-control-through-deep-reinforcement-learning Args: scope_name (str): the scope name n_action (int): the number of discrete action - ''' + """ # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar @@ -46,49 +46,39 @@ def all_q(self, s: nn.Variable) -> nn.Variable: with nn.parameter_scope(self.scope_name): with nn.parameter_scope("conv1"): - h = NF.relu(NPF.convolution(s, 32, (8, 8), stride=(4, 4), - w_init=RI.HeNormal(s.shape[1], - 32, - kernel=(8, 8)) - )) + h = NF.relu( + NPF.convolution(s, 32, (8, 8), stride=(4, 4), w_init=RI.HeNormal(s.shape[1], 32, kernel=(8, 8))) + ) with nn.parameter_scope("conv2"): - h = NF.relu(NPF.convolution(h, 64, (4, 4), stride=(2, 2), - w_init=RI.HeNormal(h.shape[1], - 64, - kernel=(4, 4)) - )) + h = NF.relu( + NPF.convolution(h, 64, (4, 4), stride=(2, 2), w_init=RI.HeNormal(h.shape[1], 64, kernel=(4, 4))) + ) with nn.parameter_scope("conv3"): - h = NF.relu(NPF.convolution(h, 64, (3, 3), stride=(1, 1), - w_init=RI.HeNormal(h.shape[1], - 64, - kernel=(3, 3)) - )) + h = NF.relu( + NPF.convolution(h, 64, (3, 3), stride=(1, 1), w_init=RI.HeNormal(h.shape[1], 64, kernel=(3, 3))) + ) h = NF.reshape(h, (-1, 3136)) with nn.parameter_scope("affine1"): - h = NF.relu(NPF.affine(h, 512, - w_init=RI.HeNormal(h.shape[1], 512) - )) + h = NF.relu(NPF.affine(h, 512, w_init=RI.HeNormal(h.shape[1], 512))) with nn.parameter_scope("affine2"): - h = NPF.affine(h, self._n_action, - w_init=RI.HeNormal(h.shape[1], self._n_action) - ) + h = NPF.affine(h, self._n_action, w_init=RI.HeNormal(h.shape[1], self._n_action)) return h class DRQNQFunction(DiscreteQFunction): - ''' + """ Q function with LSTM layer proposed by M. Hausknecht et al. used in DRQN paper for atari environment. 
See: https://arxiv.org/pdf/1507.06527.pdf Args: scope_name (str): the scope name n_action (int): the number of discrete action - ''' + """ # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar @@ -115,25 +105,19 @@ def all_q(self, s: nn.Variable) -> nn.Variable: with nn.parameter_scope(self.scope_name): with nn.parameter_scope("conv1"): - h = NF.relu(NPF.convolution(s, 32, (8, 8), stride=(4, 4), - w_init=RI.HeNormal(s.shape[1], - 32, - kernel=(8, 8)) - )) + h = NF.relu( + NPF.convolution(s, 32, (8, 8), stride=(4, 4), w_init=RI.HeNormal(s.shape[1], 32, kernel=(8, 8))) + ) with nn.parameter_scope("conv2"): - h = NF.relu(NPF.convolution(h, 64, (4, 4), stride=(2, 2), - w_init=RI.HeNormal(h.shape[1], - 64, - kernel=(4, 4)) - )) + h = NF.relu( + NPF.convolution(h, 64, (4, 4), stride=(2, 2), w_init=RI.HeNormal(h.shape[1], 64, kernel=(4, 4))) + ) with nn.parameter_scope("conv3"): - h = NF.relu(NPF.convolution(h, 64, (3, 3), stride=(1, 1), - w_init=RI.HeNormal(h.shape[1], - 64, - kernel=(3, 3)) - )) + h = NF.relu( + NPF.convolution(h, 64, (3, 3), stride=(1, 1), w_init=RI.HeNormal(h.shape[1], 64, kernel=(3, 3))) + ) h = NF.reshape(h, (-1, 3136)) @@ -147,8 +131,7 @@ def all_q(self, s: nn.Variable) -> nn.Variable: h = self._h with nn.parameter_scope("affine2"): - h = NPF.affine(h, self._n_action, - w_init=RI.HeNormal(h.shape[1], self._n_action)) + h = NPF.affine(h, self._n_action, w_init=RI.HeNormal(h.shape[1], self._n_action)) return h def is_recurrent(self) -> bool: @@ -156,14 +139,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -173,8 +156,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) diff --git a/nnabla_rl/models/atari/shared_functions.py b/nnabla_rl/models/atari/shared_functions.py index 78081a13..7e9a5b50 100644 --- a/nnabla_rl/models/atari/shared_functions.py +++ b/nnabla_rl/models/atari/shared_functions.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
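The `NF.reshape(h, (-1, 3136))` kept in both the DQN and DRQN heads above assumes the standard 84x84 Atari input (the frame size is an assumption, not stated in this patch); the flatten size follows from the three unpadded convolutions. A quick check of that arithmetic:

```py
def conv_out(size: int, kernel: int, stride: int) -> int:
    # Output spatial size of an unpadded convolution.
    return (size - kernel) // stride + 1


side = 84                      # assumed Atari frame size (not part of this patch)
side = conv_out(side, 8, 4)    # conv1: 32 maps, 8x8, stride 4 -> 20
side = conv_out(side, 4, 2)    # conv2: 64 maps, 4x4, stride 2 -> 9
side = conv_out(side, 3, 1)    # conv3: 64 maps, 3x3, stride 1 -> 7

assert side * side * 64 == 3136  # matches NF.reshape(h, (-1, 3136))
```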
@@ -32,21 +32,17 @@ def __call__(self, s): with nn.parameter_scope(self.scope_name): with nn.parameter_scope("conv1"): - h = NPF.convolution(s, outmaps=32, kernel=(8, 8), stride=(4, 4), - w_init=RI.NormcInitializer(std=1.0)) + h = NPF.convolution(s, outmaps=32, kernel=(8, 8), stride=(4, 4), w_init=RI.NormcInitializer(std=1.0)) h = NF.relu(x=h) with nn.parameter_scope("conv2"): - h = NPF.convolution(h, outmaps=64, kernel=(4, 4), stride=(2, 2), - w_init=RI.NormcInitializer(std=1.0)) + h = NPF.convolution(h, outmaps=64, kernel=(4, 4), stride=(2, 2), w_init=RI.NormcInitializer(std=1.0)) h = NF.relu(x=h) with nn.parameter_scope("conv3"): - h = NPF.convolution(h, outmaps=64, kernel=(3, 3), stride=(1, 1), - w_init=RI.NormcInitializer(std=1.0)) + h = NPF.convolution(h, outmaps=64, kernel=(3, 3), stride=(1, 1), w_init=RI.NormcInitializer(std=1.0)) h = NF.relu(x=h) h = NF.reshape(h, shape=(batch_size, -1)) with nn.parameter_scope("linear1"): - h = NPF.affine(h, n_outmaps=512, - w_init=RI.NormcInitializer(std=1.0)) + h = NPF.affine(h, n_outmaps=512, w_init=RI.NormcInitializer(std=1.0)) h = NF.relu(x=h) return h @@ -62,16 +58,13 @@ def __call__(self, s): with nn.parameter_scope(self.scope_name): with nn.parameter_scope("conv1"): - h = NPF.convolution(s, outmaps=16, kernel=(8, 8), stride=(4, 4), - w_init=RI.NormcInitializer(std=1.0)) + h = NPF.convolution(s, outmaps=16, kernel=(8, 8), stride=(4, 4), w_init=RI.NormcInitializer(std=1.0)) h = NF.relu(x=h) with nn.parameter_scope("conv2"): - h = NPF.convolution(h, outmaps=32, kernel=(4, 4), stride=(2, 2), - w_init=RI.NormcInitializer(std=1.0)) + h = NPF.convolution(h, outmaps=32, kernel=(4, 4), stride=(2, 2), w_init=RI.NormcInitializer(std=1.0)) h = NF.relu(x=h) h = NF.reshape(h, shape=(batch_size, -1)) with nn.parameter_scope("linear1"): - h = NPF.affine(h, n_outmaps=256, - w_init=RI.NormcInitializer(std=1.0)) + h = NPF.affine(h, n_outmaps=256, w_init=RI.NormcInitializer(std=1.0)) h = NF.relu(x=h) return h diff --git a/nnabla_rl/models/atari/v_functions.py b/nnabla_rl/models/atari/v_functions.py index 1b75e520..0ffd16ae 100644 --- a/nnabla_rl/models/atari/v_functions.py +++ b/nnabla_rl/models/atari/v_functions.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -41,8 +41,7 @@ def v(self, s: nn.Variable) -> nn.Variable: h = self._hidden(s) with nn.parameter_scope(self.scope_name): with nn.parameter_scope("linear_v"): - v = NPF.affine(h, n_outmaps=1, - w_init=RI.NormcInitializer(std=0.01)) + v = NPF.affine(h, n_outmaps=1, w_init=RI.NormcInitializer(std=0.01)) return v def _hidden(self, s: nn.Variable) -> nn.Variable: diff --git a/nnabla_rl/models/classic_control/dynamics.py b/nnabla_rl/models/classic_control/dynamics.py index 4f6ba7fa..02fa6989 100644 --- a/nnabla_rl/models/classic_control/dynamics.py +++ b/nnabla_rl/models/classic_control/dynamics.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -36,7 +36,7 @@ def __init__(self, scope_name: str, dt: float): def next_state(self, x: nn.Variable, u: nn.Variable) -> nn.Variable: assert x.shape[-1] % 2 == 0 # must have even number of states (state and its time derivative pairs) a = self.acceleration(x, u) - time_derivatives = NF.concatenate(x[:, x.shape[-1] // 2:], a, axis=len(a.shape)-1) + time_derivatives = NF.concatenate(x[:, x.shape[-1] // 2 :], a, axis=len(a.shape) - 1) return x + time_derivatives * self._dt def acceleration(self, x: nn.Variable, u: nn.Variable) -> nn.Variable: diff --git a/nnabla_rl/models/classic_control/policies.py b/nnabla_rl/models/classic_control/policies.py index 90f92b07..ab080ae4 100644 --- a/nnabla_rl/models/classic_control/policies.py +++ b/nnabla_rl/models/classic_control/policies.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -43,14 +43,11 @@ def __init__(self, scope_name: str, action_dim: int): def pi(self, s: nn.Variable) -> nn.Variable: with nn.parameter_scope(self.scope_name): - h = NPF.affine(s, n_outmaps=200, name="linear1", - w_init=RI.HeNormal(s.shape[1], 200)) + h = NPF.affine(s, n_outmaps=200, name="linear1", w_init=RI.HeNormal(s.shape[1], 200)) h = NF.leaky_relu(h) - h = NPF.affine(h, n_outmaps=200, name="linear2", - w_init=RI.HeNormal(s.shape[1], 200)) + h = NPF.affine(h, n_outmaps=200, name="linear2", w_init=RI.HeNormal(s.shape[1], 200)) h = NF.leaky_relu(h) - z = NPF.affine(h, n_outmaps=self._action_dim, - name="linear3", w_init=RI.LeCunNormal(s.shape[1], 200)) + z = NPF.affine(h, n_outmaps=self._action_dim, name="linear3", w_init=RI.LeCunNormal(s.shape[1], 200)) return D.Softmax(z=z) @@ -79,14 +76,11 @@ def __init__(self, scope_name: str, action_dim: int, fixed_ln_var: Union[np.ndar def pi(self, s: nn.Variable) -> nn.Variable: batch_size = s.shape[0] with nn.parameter_scope(self.scope_name): - h = NPF.affine(s, n_outmaps=200, name="linear1", - w_init=RI.HeNormal(s.shape[1], 200)) + h = NPF.affine(s, n_outmaps=200, name="linear1", w_init=RI.HeNormal(s.shape[1], 200)) h = NF.leaky_relu(h) - h = NPF.affine(h, n_outmaps=200, name="linear2", - w_init=RI.HeNormal(s.shape[1], 200)) + h = NPF.affine(h, n_outmaps=200, name="linear2", w_init=RI.HeNormal(s.shape[1], 200)) h = NF.leaky_relu(h) - z = NPF.affine(h, n_outmaps=self._action_dim, - name="linear3", w_init=RI.HeNormal(s.shape[1], 200)) + z = NPF.affine(h, n_outmaps=self._action_dim, name="linear3", w_init=RI.HeNormal(s.shape[1], 200)) ln_var = nn.Variable.from_numpy_array(np.tile(self._fixed_ln_var, (batch_size, 1))) return D.Gaussian(z, ln_var) diff --git a/nnabla_rl/models/decision_transformer.py b/nnabla_rl/models/decision_transformer.py index 0cf7a086..43829ed1 100644 --- a/nnabla_rl/models/decision_transformer.py +++ b/nnabla_rl/models/decision_transformer.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
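The slice-spacing fix in `DeterministicDynamics.next_state` above does not change its meaning: the state packs [position, velocity] halves, the time derivative is [velocity, acceleration], and the update is one explicit Euler step of size `dt`. A NumPy sketch with hypothetical numbers:

```py
import numpy as np

dt = 0.05                        # hypothetical timestep
x = np.array([[0.0, 1.0]])       # (batch, 2): position 0.0, velocity 1.0
a = np.array([[2.0]])            # acceleration(x, u) for the same batch

# Mirrors: NF.concatenate(x[:, x.shape[-1] // 2 :], a, ...); x + derivative * dt
time_derivatives = np.concatenate([x[:, x.shape[-1] // 2 :], a], axis=-1)
x_next = x + time_derivatives * dt

assert np.allclose(x_next, [[0.05, 1.1]])
```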
@@ -48,7 +48,7 @@ def pi(self, s: nn.Variable, a: nn.Variable, rtg: nn.Variable, t: nn.Variable) - class StochasticDecisionTransformer(DecisionTransformer): @abstractmethod - def pi(self, s: nn.Variable, a: nn.Variable, rtg: nn.Variable, t: nn.Variable) -> Distribution: + def pi(self, s: nn.Variable, a: nn.Variable, rtg: nn.Variable, t: nn.Variable) -> Distribution: """Compute action distribution for given state, action, and return to go (rtg) diff --git a/nnabla_rl/models/distributional_function.py b/nnabla_rl/models/distributional_function.py index 4deb1cf3..f5ffe852 100644 --- a/nnabla_rl/models/distributional_function.py +++ b/nnabla_rl/models/distributional_function.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -115,6 +115,7 @@ def _compute_z(self, n_atom: int, v_min: float, v_max: float) -> nn.Variable: class DiscreteValueDistributionFunction(ValueDistributionFunction): """Base value distribution class for discrete action envs.""" + @abstractmethod def all_probs(self, s: nn.Variable) -> nn.Variable: raise NotImplementedError @@ -130,9 +131,9 @@ def max_q_probs(self, s: nn.Variable) -> nn.Variable: def as_q_function(self) -> QFunction: class Wrapper(QFunction): - _value_distribution_function: 'DiscreteValueDistributionFunction' + _value_distribution_function: "DiscreteValueDistributionFunction" - def __init__(self, value_distribution_function: 'DiscreteValueDistributionFunction'): + def __init__(self, value_distribution_function: "DiscreteValueDistributionFunction"): super(Wrapper, self).__init__(value_distribution_function.scope_name) self._value_distribution_function = value_distribution_function @@ -203,6 +204,7 @@ def _to_one_hot(self, a: nn.Variable) -> nn.Variable: class ContinuousValueDistributionFunction(ValueDistributionFunction): """Base value distribution class for continuous action envs.""" + pass @@ -301,9 +303,9 @@ def max_q_quantiles(self, s: nn.Variable) -> nn.Variable: def as_q_function(self) -> QFunction: class Wrapper(QFunction): - _quantile_distribution_function: 'DiscreteQuantileDistributionFunction' + _quantile_distribution_function: "DiscreteQuantileDistributionFunction" - def __init__(self, quantile_distribution_function: 'DiscreteQuantileDistributionFunction'): + def __init__(self, quantile_distribution_function: "DiscreteQuantileDistributionFunction"): super(Wrapper, self).__init__(quantile_distribution_function.scope_name) self._quantile_distribution_function = quantile_distribution_function @@ -370,9 +372,9 @@ def _to_one_hot(self, a: nn.Variable) -> nn.Variable: class ContinuousQuantileDistributionFunction(QuantileDistributionFunction): def as_q_function(self) -> QFunction: class Wrapper(QFunction): - _quantile_distribution_function: 'ContinuousQuantileDistributionFunction' + _quantile_distribution_function: "ContinuousQuantileDistributionFunction" - def __init__(self, quantile_distribution_function: 'ContinuousQuantileDistributionFunction'): + def __init__(self, quantile_distribution_function: "ContinuousQuantileDistributionFunction"): super(Wrapper, self).__init__(quantile_distribution_function.scope_name) self._quantile_distribution_function = quantile_distribution_function @@ -427,11 +429,13 @@ class StateActionQuantileFunction(Model, metaclass=ABCMeta): _K: int # _risk_measure_funciton: 
Callable[[nn.Variable], nn.Variable] - def __init__(self, - scope_name: str, - n_action: int, - K: int, - risk_measure_function: Callable[[nn.Variable], nn.Variable] = risk_neutral_measure): + def __init__( + self, + scope_name: str, + n_action: int, + K: int, + risk_measure_function: Callable[[nn.Variable], nn.Variable] = risk_neutral_measure, + ): super(StateActionQuantileFunction, self).__init__(scope_name) self._n_action = n_action self._K = K @@ -514,7 +518,7 @@ def quantile_values(self, s: nn.Variable, a: nn.Variable, tau: nn.Variable) -> n def max_q_quantile_values(self, s: nn.Variable, tau: nn.Variable) -> nn.Variable: if self.is_recurrent(): - raise RuntimeError('max_q_quantile_values should be reimplemented in inherited class to support RNN layers') + raise RuntimeError("max_q_quantile_values should be reimplemented in inherited class to support RNN layers") batch_size = s.shape[0] tau_k = self._sample_risk_measured_tau(shape=(batch_size, self._K)) @@ -534,10 +538,11 @@ def as_q_function(self) -> QFunction: nnabla_rl.models.q_function.QFunction: QFunction instance which computes the q-values based on the return_samples. """ + class Wrapper(QFunction): - _quantile_function: 'DiscreteStateActionQuantileFunction' + _quantile_function: "DiscreteStateActionQuantileFunction" - def __init__(self, quantile_function: 'DiscreteStateActionQuantileFunction'): + def __init__(self, quantile_function: "DiscreteStateActionQuantileFunction"): super(Wrapper, self).__init__(quantile_function.scope_name) self._quantile_function = quantile_function diff --git a/nnabla_rl/models/dynamics.py b/nnabla_rl/models/dynamics.py index 3f6fccaa..cdfc09b8 100644 --- a/nnabla_rl/models/dynamics.py +++ b/nnabla_rl/models/dynamics.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -32,6 +32,7 @@ class DeterministicDynamics(Dynamics, metaclass=ABCMeta): This dynamics returns next state for given state and control input (action). """ + @abstractmethod def next_state(self, x: nn.Variable, u: nn.Variable) -> nn.Variable: """Next_state. @@ -64,6 +65,7 @@ class StochasticDynamics(Dynamics, metaclass=ABCMeta): This dynamics returns the probability distribution of next state for given state and control input (action). """ + @abstractmethod def next_state(self, x: nn.Variable, u: nn.Variable) -> Distribution: """Next_state. diff --git a/nnabla_rl/models/hybrid_env/encoders.py b/nnabla_rl/models/hybrid_env/encoders.py index b674b785..b5279200 100644 --- a/nnabla_rl/models/hybrid_env/encoders.py +++ b/nnabla_rl/models/hybrid_env/encoders.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -32,12 +32,7 @@ class HyARVAE(VariationalAutoEncoder): See: https://arxiv.org/abs/2109.05490 """ - def __init__(self, - scope_name: str, - state_dim, - action_dim, - encode_dim, - embed_dim): + def __init__(self, scope_name: str, state_dim, action_dim, encode_dim, embed_dim): super().__init__(scope_name) self._state_dim = state_dim self._action_dim = action_dim @@ -55,85 +50,82 @@ def encode(self, x: nn.Variable, **kwargs) -> nn.Variable: return latent_distribution.sample() def encode_and_decode(self, x: nn.Variable, **kwargs) -> Tuple[Distribution, Any]: - if 'action' in kwargs: - (d_action, _) = kwargs['action'] + if "action" in kwargs: + (d_action, _) = kwargs["action"] e = self.encode_discrete_action(d_action) - elif 'e' in kwargs: - e = kwargs['e'] + elif "e" in kwargs: + e = kwargs["e"] else: raise NotImplementedError - latent_distribution = self.latent_distribution(x, e=e, state=kwargs['state']) + latent_distribution = self.latent_distribution(x, e=e, state=kwargs["state"]) z = latent_distribution.sample() - reconstructed = self.decode(z, e=e, state=kwargs['state']) + reconstructed = self.decode(z, e=e, state=kwargs["state"]) return latent_distribution, reconstructed def decode(self, z: Any, **kwargs) -> nn.Variable: - state = kwargs['state'] - if 'action' in kwargs: - (d_action, _) = kwargs['action'] + state = kwargs["state"] + if "action" in kwargs: + (d_action, _) = kwargs["action"] action = self.encode_discrete_action(d_action) - elif 'e' in kwargs: - action = kwargs['e'] + elif "e" in kwargs: + action = kwargs["e"] else: raise NotImplementedError with nn.parameter_scope(self._scope_name): - with nn.parameter_scope('decoder'): + with nn.parameter_scope("decoder"): c = NF.concatenate(state, action) - linear1_init = RI.HeUniform(inmaps=c.shape[1], outmaps=256, factor=1/3) - c = NF.relu(NPF.affine(c, n_outmaps=256, name='linear1', w_init=linear1_init, b_init=linear1_init)) - linear2_init = RI.HeUniform(inmaps=z.shape[1], outmaps=256, factor=1/3) - z = NF.relu(NPF.affine(z, n_outmaps=256, name='linear2', w_init=linear2_init, b_init=linear2_init)) + linear1_init = RI.HeUniform(inmaps=c.shape[1], outmaps=256, factor=1 / 3) + c = NF.relu(NPF.affine(c, n_outmaps=256, name="linear1", w_init=linear1_init, b_init=linear1_init)) + linear2_init = RI.HeUniform(inmaps=z.shape[1], outmaps=256, factor=1 / 3) + z = NF.relu(NPF.affine(z, n_outmaps=256, name="linear2", w_init=linear2_init, b_init=linear2_init)) h = z * c - linear3_init = RI.HeUniform(inmaps=h.shape[1], outmaps=256, factor=1/3) - h = NF.relu(NPF.affine(h, n_outmaps=256, name='linear3', w_init=linear3_init, b_init=linear3_init)) - linear4_init = RI.HeUniform(inmaps=h.shape[1], outmaps=256, factor=1/3) - h = NF.relu(NPF.affine(h, n_outmaps=256, name='linear4', w_init=linear4_init, b_init=linear4_init)) - - linear5_init = RI.HeUniform(inmaps=h.shape[1], outmaps=256, factor=1/3) - ds = NF.relu(NPF.affine(h, n_outmaps=256, name='linear5', w_init=linear5_init, b_init=linear5_init)) - linear6_init = RI.HeUniform(inmaps=ds.shape[1], outmaps=self._state_dim, factor=1/3) - ds = NPF.affine(ds, n_outmaps=self._state_dim, name='linear6', w_init=linear6_init, b_init=linear6_init) - - linear7_init = RI.HeUniform(inmaps=h.shape[1], outmaps=self._latent_dim, factor=1/3) - x = NPF.affine(h, n_outmaps=self._latent_dim, name='linear7', w_init=linear7_init, b_init=linear7_init) + linear3_init = RI.HeUniform(inmaps=h.shape[1], outmaps=256, factor=1 / 3) + h = NF.relu(NPF.affine(h, n_outmaps=256, name="linear3", w_init=linear3_init, 
b_init=linear3_init)) + linear4_init = RI.HeUniform(inmaps=h.shape[1], outmaps=256, factor=1 / 3) + h = NF.relu(NPF.affine(h, n_outmaps=256, name="linear4", w_init=linear4_init, b_init=linear4_init)) + + linear5_init = RI.HeUniform(inmaps=h.shape[1], outmaps=256, factor=1 / 3) + ds = NF.relu(NPF.affine(h, n_outmaps=256, name="linear5", w_init=linear5_init, b_init=linear5_init)) + linear6_init = RI.HeUniform(inmaps=ds.shape[1], outmaps=self._state_dim, factor=1 / 3) + ds = NPF.affine(ds, n_outmaps=self._state_dim, name="linear6", w_init=linear6_init, b_init=linear6_init) + + linear7_init = RI.HeUniform(inmaps=h.shape[1], outmaps=self._latent_dim, factor=1 / 3) + x = NPF.affine(h, n_outmaps=self._latent_dim, name="linear7", w_init=linear7_init, b_init=linear7_init) return NF.tanh(x), NF.tanh(ds) def decode_multiple(self, z, decode_num: int, **kwargs): raise NotImplementedError def latent_distribution(self, x: nn.Variable, **kwargs) -> Distribution: - state = kwargs['state'] - if 'action' in kwargs: - (d_action, _) = kwargs['action'] + state = kwargs["state"] + if "action" in kwargs: + (d_action, _) = kwargs["action"] action = self.encode_discrete_action(d_action) - elif 'e' in kwargs: - action = kwargs['e'] + elif "e" in kwargs: + action = kwargs["e"] else: raise NotImplementedError with nn.parameter_scope(self._scope_name): - with nn.parameter_scope('encoder'): + with nn.parameter_scope("encoder"): c = NF.concatenate(state, action) - linear1_init = RI.HeUniform(inmaps=c.shape[1], outmaps=256, factor=1/3) - c = NF.relu(NPF.affine(c, n_outmaps=256, name='linear1', - w_init=linear1_init, b_init=linear1_init)) - linear2_init = RI.HeUniform(inmaps=x.shape[1], outmaps=256, factor=1/3) - x = NF.relu(NPF.affine(x, n_outmaps=256, name='linear2', - w_init=linear2_init, b_init=linear2_init)) + linear1_init = RI.HeUniform(inmaps=c.shape[1], outmaps=256, factor=1 / 3) + c = NF.relu(NPF.affine(c, n_outmaps=256, name="linear1", w_init=linear1_init, b_init=linear1_init)) + linear2_init = RI.HeUniform(inmaps=x.shape[1], outmaps=256, factor=1 / 3) + x = NF.relu(NPF.affine(x, n_outmaps=256, name="linear2", w_init=linear2_init, b_init=linear2_init)) h = x * c - linear3_init = RI.HeUniform(inmaps=h.shape[1], outmaps=256, factor=1/3) - h = NF.relu(NPF.affine(h, n_outmaps=256, name='linear3', - w_init=linear3_init, b_init=linear3_init)) - linear4_init = RI.HeUniform(inmaps=h.shape[1], outmaps=256, factor=1/3) - h = NF.relu(NPF.affine(h, n_outmaps=256, name='linear4', - w_init=linear4_init, b_init=linear4_init)) - linear5_init = RI.HeUniform(inmaps=h.shape[1], outmaps=self._encode_dim*2, factor=1/3) - h = NPF.affine(h, n_outmaps=self._encode_dim*2, name='linear5', - w_init=linear5_init, b_init=linear5_init) + linear3_init = RI.HeUniform(inmaps=h.shape[1], outmaps=256, factor=1 / 3) + h = NF.relu(NPF.affine(h, n_outmaps=256, name="linear3", w_init=linear3_init, b_init=linear3_init)) + linear4_init = RI.HeUniform(inmaps=h.shape[1], outmaps=256, factor=1 / 3) + h = NF.relu(NPF.affine(h, n_outmaps=256, name="linear4", w_init=linear4_init, b_init=linear4_init)) + linear5_init = RI.HeUniform(inmaps=h.shape[1], outmaps=self._encode_dim * 2, factor=1 / 3) + h = NPF.affine( + h, n_outmaps=self._encode_dim * 2, name="linear5", w_init=linear5_init, b_init=linear5_init + ) reshaped = NF.reshape(h, shape=(-1, 2, self._encode_dim)) mean, ln_var = NF.split(reshaped, axis=1) ln_var = NF.clip_by_value(ln_var, min=-8, max=30) @@ -143,10 +135,9 @@ def latent_distribution(self, x: nn.Variable, **kwargs) -> Distribution: def 
encode_discrete_action(self, action): with nn.parameter_scope(self.scope_name): with nn.parameter_scope("embed"): - embedding = NPF.embed(action, - n_inputs=self._class_num, - n_features=self._embed_dim, - initializer=NI.UniformInitializer()) + embedding = NPF.embed( + action, n_inputs=self._class_num, n_features=self._embed_dim, initializer=NI.UniformInitializer() + ) embedding = NF.reshape(embedding, shape=(-1, self._embed_dim)) embedding = NF.tanh(embedding) return embedding @@ -154,10 +145,12 @@ def encode_discrete_action(self, action): def decode_discrete_action(self, action_embedding): with nn.parameter_scope(self.scope_name): with nn.parameter_scope("embed"): - label_embedding = NPF.embed(self._labels, - n_inputs=self._class_num, - n_features=self._embed_dim, - initializer=NI.UniformInitializer()) + label_embedding = NPF.embed( + self._labels, + n_inputs=self._class_num, + n_features=self._embed_dim, + initializer=NI.UniformInitializer(), + ) label_embedding = NF.reshape(label_embedding, shape=(-1, self._class_num, self._embed_dim)) label_embedding = NF.tanh(label_embedding) @@ -168,7 +161,6 @@ def decode_discrete_action(self, action_embedding): @property def _labels(self) -> nn.Variable: - labels = np.array( - [label for label in range(self._class_num)], dtype=np.int32) + labels = np.array([label for label in range(self._class_num)], dtype=np.int32) labels = np.reshape(labels, newshape=(1, self._class_num)) return nn.Variable.from_numpy_array(labels) diff --git a/nnabla_rl/models/hybrid_env/policies.py b/nnabla_rl/models/hybrid_env/policies.py index c265cd29..93cfa071 100644 --- a/nnabla_rl/models/hybrid_env/policies.py +++ b/nnabla_rl/models/hybrid_env/policies.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ class HyARPolicy(DeterministicPolicy): in the HyAR paper. 
See: https://arxiv.org/abs/2109.05490 """ + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See https://mypy.readthedocs.io/en/stable/class_basics.html for details @@ -37,12 +38,12 @@ def __init__(self, scope_name: str, action_dim: int, max_action_value: float): def pi(self, s: nn.Variable) -> nn.Variable: with nn.parameter_scope(self.scope_name): - linear1_init = RI.HeUniform(inmaps=s.shape[1], outmaps=256, factor=1/3) + linear1_init = RI.HeUniform(inmaps=s.shape[1], outmaps=256, factor=1 / 3) h = NPF.affine(s, n_outmaps=256, name="linear1", w_init=linear1_init, b_init=linear1_init) h = NF.relu(x=h) - linear2_init = RI.HeUniform(inmaps=h.shape[1], outmaps=256, factor=1/3) + linear2_init = RI.HeUniform(inmaps=h.shape[1], outmaps=256, factor=1 / 3) h = NPF.affine(h, n_outmaps=256, name="linear2", w_init=linear2_init, b_init=linear2_init) h = NF.relu(x=h) - linear3_init = RI.HeUniform(inmaps=h.shape[1], outmaps=self._action_dim, factor=1/3) + linear3_init = RI.HeUniform(inmaps=h.shape[1], outmaps=self._action_dim, factor=1 / 3) h = NPF.affine(h, n_outmaps=self._action_dim, name="linear3", w_init=linear3_init, b_init=linear3_init) return NF.tanh(h) * self._max_action_value diff --git a/nnabla_rl/models/hybrid_env/q_functions.py b/nnabla_rl/models/hybrid_env/q_functions.py index e2ecb9a5..8a59ef47 100644 --- a/nnabla_rl/models/hybrid_env/q_functions.py +++ b/nnabla_rl/models/hybrid_env/q_functions.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ class HyARQFunction(ContinuousQFunction): in the HyAR paper. 
See: https://arxiv.org/abs/2109.05490 """ + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See https://mypy.readthedocs.io/en/stable/class_basics.html for details @@ -35,14 +36,11 @@ def __init__(self, scope_name: str): def q(self, s: nn.Variable, a: nn.Variable) -> nn.Variable: with nn.parameter_scope(self.scope_name): h = NF.concatenate(s, a) - linear1_init = RI.HeUniform(inmaps=h.shape[1], outmaps=256, factor=1/3) - h = NPF.affine(h, n_outmaps=256, name="linear1", - w_init=linear1_init, b_init=linear1_init) + linear1_init = RI.HeUniform(inmaps=h.shape[1], outmaps=256, factor=1 / 3) + h = NPF.affine(h, n_outmaps=256, name="linear1", w_init=linear1_init, b_init=linear1_init) h = NF.relu(x=h) - linear2_init = RI.HeUniform(inmaps=h.shape[1], outmaps=256, factor=1/3) - h = NPF.affine(h, n_outmaps=256, name="linear2", - w_init=linear2_init, b_init=linear2_init) + linear2_init = RI.HeUniform(inmaps=h.shape[1], outmaps=256, factor=1 / 3) + h = NPF.affine(h, n_outmaps=256, name="linear2", w_init=linear2_init, b_init=linear2_init) h = NF.relu(x=h) - linear3_init = RI.HeUniform(inmaps=h.shape[1], outmaps=1, factor=1/3) - return NPF.affine(h, n_outmaps=1, name="linear3", - w_init=linear3_init, b_init=linear3_init) + linear3_init = RI.HeUniform(inmaps=h.shape[1], outmaps=1, factor=1 / 3) + return NPF.affine(h, n_outmaps=1, name="linear3", w_init=linear3_init, b_init=linear3_init) diff --git a/nnabla_rl/models/model.py b/nnabla_rl/models/model.py index d02d7cd0..1d828ba9 100644 --- a/nnabla_rl/models/model.py +++ b/nnabla_rl/models/model.py @@ -20,7 +20,7 @@ import nnabla as nn from nnabla_rl.logger import logger -T = TypeVar('T', bound='Model') +T = TypeVar("T", bound="Model") class Model(object): @@ -153,7 +153,7 @@ def deepcopy(self: T, new_scope_name: str) -> T: Raises: ValueError: Given scope name is same as the model or already exists. """ - assert new_scope_name != self._scope_name, 'Can not use same scope_name!' + assert new_scope_name != self._scope_name, "Can not use same scope_name!" copied = copy.deepcopy(self) copied._scope_name = new_scope_name # copy current parameter if is already created @@ -161,9 +161,10 @@ def deepcopy(self: T, new_scope_name: str) -> T: with nn.parameter_scope(new_scope_name): for param_name, param in params.items(): if nn.parameter.get_parameter(param_name) is not None: - raise RuntimeError(f'Model with scope_name: {new_scope_name} already exists!!') + raise RuntimeError(f"Model with scope_name: {new_scope_name} already exists!!") logger.info( - f'copying param with name: {self.scope_name}/{param_name} ---> {new_scope_name}/{param_name}') + f"copying param with name: {self.scope_name}/{param_name} ---> {new_scope_name}/{param_name}" + ) nn.parameter.get_parameter_or_create(param_name, shape=param.shape, initializer=param.d) return copied diff --git a/nnabla_rl/models/mujoco/decision_transformers.py b/nnabla_rl/models/mujoco/decision_transformers.py index 703affed..486b875b 100644 --- a/nnabla_rl/models/mujoco/decision_transformers.py +++ b/nnabla_rl/models/mujoco/decision_transformers.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -28,13 +28,15 @@ class MujocoDecisionTransformer(DeterministicDecisionTransformer): - def __init__(self, - scope_name: str, - action_dim: int, - max_timestep: int, - context_length: int, - num_heads: int = 1, - embedding_dim: int = 128): + def __init__( + self, + scope_name: str, + action_dim: int, + max_timestep: int, + context_length: int, + num_heads: int = 1, + embedding_dim: int = 128, + ): super().__init__(scope_name, num_heads, embedding_dim) self._action_dim = action_dim self._max_timestep = max_timestep @@ -60,7 +62,7 @@ def pi(self, s: nn.Variable, a: nn.Variable, rtg: nn.Variable, t: nn.Variable) - position_embedding = self._embed_t(t, self._max_timestep + 1) x = token_embedding + position_embedding - with nn.parameter_scope('layer_norm'): + with nn.parameter_scope("layer_norm"): fix_parameters = rl.is_eval_scope() # 0.003 is almost sqrt(10^-5) x = NPF.layer_normalization(x, batch_axis=(0, 1), fix_parameters=fix_parameters, eps=0.003) @@ -68,12 +70,12 @@ def pi(self, s: nn.Variable, a: nn.Variable, rtg: nn.Variable, t: nn.Variable) - block_size = token_embedding.shape[1] attention_mask = create_attention_mask(block_size, block_size) for i in range(self._attention_layers): - with nn.parameter_scope(f'attention_block{i}'): + with nn.parameter_scope(f"attention_block{i}"): x = self._attention_block(x, attention_mask=attention_mask) # Use predictions from state embeddings x = x[:, 1::3, :] - with nn.parameter_scope('affine'): + with nn.parameter_scope("affine"): actions = NPF.affine(x, n_outmaps=self._action_dim, base_axis=2) return NF.tanh(actions) @@ -99,22 +101,24 @@ def _embed_t(self, timesteps: nn.Variable, T: int) -> nn.Variable: return NF.reshape(embedding, shape=(embedding.shape[0], embedding.shape[1], -1)) def _attention_block(self, x: nn.Variable, attention_mask=None) -> nn.Variable: - with nn.parameter_scope('layer_norm1'): + with nn.parameter_scope("layer_norm1"): fix_parameters = rl.is_eval_scope() normalized_x1 = NPF.layer_normalization(x, batch_axis=(0, 1), fix_parameters=fix_parameters, eps=0.003) - with nn.parameter_scope('causal_self_attention'): + with nn.parameter_scope("causal_self_attention"): attention_dropout = None if rl.is_eval_scope() else 0.1 output_dropout = None if rl.is_eval_scope() else 0.1 - x = x + RPF.causal_self_attention(normalized_x1, - embed_dim=self._embedding_dim, - num_heads=self._num_heads, - mask=attention_mask, - attention_dropout=attention_dropout, - output_dropout=output_dropout) - with nn.parameter_scope('layer_norm2'): + x = x + RPF.causal_self_attention( + normalized_x1, + embed_dim=self._embedding_dim, + num_heads=self._num_heads, + mask=attention_mask, + attention_dropout=attention_dropout, + output_dropout=output_dropout, + ) + with nn.parameter_scope("layer_norm2"): fix_parameters = rl.is_eval_scope() normalized_x2 = NPF.layer_normalization(x, batch_axis=(0, 1), fix_parameters=fix_parameters, eps=0.003) - with nn.parameter_scope('mlp'): + with nn.parameter_scope("mlp"): block_dropout = None if rl.is_eval_scope() else 0.1 x = x + self._block_mlp(normalized_x2, block_dropout) return x @@ -122,11 +126,11 @@ def _attention_block(self, x: nn.Variable, attention_mask=None) -> nn.Variable: def _block_mlp(self, x: nn.Variable, dropout: Optional[float] = None) -> nn.Variable: # NOTE: original code uses Conv1D operation defined below but it's same as affine layer. 
# https://github.com/huggingface/transformers/blob/main/src/transformers/pytorch_utils.py#L91 - with nn.parameter_scope('linear1'): + with nn.parameter_scope("linear1"): x = NPF.affine(x, n_outmaps=4 * self._embedding_dim, base_axis=2, w_init=NI.NormalInitializer(0.02)) # Relu is used for mujoco x = NF.relu(x) - with nn.parameter_scope('linear2'): + with nn.parameter_scope("linear2"): x = NPF.affine(x, n_outmaps=self._embedding_dim, base_axis=2, w_init=NI.NormalInitializer(0.02)) if dropout is not None: x = NF.dropout(x, p=dropout) diff --git a/nnabla_rl/models/mujoco/encoders.py b/nnabla_rl/models/mujoco/encoders.py index 80db6131..ef02e9aa 100644 --- a/nnabla_rl/models/mujoco/encoders.py +++ b/nnabla_rl/models/mujoco/encoders.py @@ -1,5 +1,5 @@ # Copyright 2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -39,14 +39,14 @@ def __init__(self, scope_name, state_dim, action_dim, latent_dim): self._latent_dim = latent_dim def encode_and_decode(self, x: nn.Variable, **kwargs) -> Tuple[Distribution, nn.Variable]: - ''' + """ Args: x (nn.Variable): encoder input. Returns: [Distribution, nn.Variable]: Reconstructed input and latent distribution - ''' - a = kwargs['action'] + """ + a = kwargs["action"] h = NF.concatenate(x, a) latent_distribution = self.latent_distribution(h) z = latent_distribution.sample() @@ -54,13 +54,13 @@ def encode_and_decode(self, x: nn.Variable, **kwargs) -> Tuple[Distribution, nn. return latent_distribution, reconstructed def encode(self, x: nn.Variable, **kwargs) -> nn.Variable: - a = kwargs['action'] + a = kwargs["action"] x = NF.concatenate(x, a) latent_distribution = self.latent_distribution(x) return latent_distribution.sample() def decode(self, z: nn.Variable, **kwargs) -> nn.Variable: - s = kwargs['state'] + s = kwargs["state"] if z is None: z = NF.randn(shape=(s.shape[0], self._latent_dim)) z = NF.clip_by_value(z, -0.5, 0.5) @@ -74,7 +74,7 @@ def decode(self, z: nn.Variable, **kwargs) -> nn.Variable: return h def decode_multiple(self, z: nn.Variable, decode_num: int, **kwargs) -> nn.Variable: - s = kwargs['state'] + s = kwargs["state"] if z is None: z = NF.randn(shape=(s.shape[0], decode_num, self._latent_dim)) z = NF.clip_by_value(z, -0.5, 0.5) @@ -100,7 +100,7 @@ def latent_distribution(self, x: nn.Variable, **kwargs) -> Distribution: h = NF.relu(x=h) h = NPF.affine(h, n_outmaps=750, name="linear2") h = NF.relu(x=h) - h = NPF.affine(h, n_outmaps=self._latent_dim*2, name="linear3") + h = NPF.affine(h, n_outmaps=self._latent_dim * 2, name="linear3") reshaped = NF.reshape(h, shape=(-1, 2, self._latent_dim)) mean, ln_var = NF.split(reshaped, axis=1) # Clip for numerical stability diff --git a/nnabla_rl/models/mujoco/policies.py b/nnabla_rl/models/mujoco/policies.py index 6f67af73..edda6f4c 100644 --- a/nnabla_rl/models/mujoco/policies.py +++ b/nnabla_rl/models/mujoco/policies.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
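Editor's note: several of the reformatted models share one output-head pattern — the VAE-style encoder's `latent_distribution` above and the stochastic policies in the hunks that follow all run a single affine layer of width `2 * dim`, reshape it to `(-1, 2, dim)`, and split it into a mean and a log-variance (or log-sigma). The sketch below isolates that pattern; the layer width, batch size, and variable names are illustrative rather than copied from any one model.

```python
import numpy as np
import nnabla as nn
import nnabla.functions as NF
import nnabla.parametric_functions as NPF

latent_dim = 8
x = nn.Variable.from_numpy_array(np.random.randn(16, 32).astype(np.float32))

with nn.parameter_scope("latent_head"):
    # One affine layer produces both halves of the distribution parameters.
    h = NPF.affine(x, n_outmaps=latent_dim * 2, name="linear")
    # Reshape to (batch, 2, latent_dim) and split along axis 1, as in the patch.
    reshaped = NF.reshape(h, shape=(-1, 2, latent_dim))
    mean, ln_var = NF.split(reshaped, axis=1)

assert mean.shape == (16, latent_dim)
assert ln_var.shape == (16, latent_dim)
```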
@@ -48,20 +48,14 @@ def __init__(self, scope_name: str, action_dim: int, max_action_value: float): def pi(self, s: nn.Variable) -> nn.Variable: with nn.parameter_scope(self.scope_name): - linear1_init = RI.HeUniform( - inmaps=s.shape[1], outmaps=400, factor=1/3) - h = NPF.affine(s, n_outmaps=400, name="linear1", - w_init=linear1_init, b_init=linear1_init) + linear1_init = RI.HeUniform(inmaps=s.shape[1], outmaps=400, factor=1 / 3) + h = NPF.affine(s, n_outmaps=400, name="linear1", w_init=linear1_init, b_init=linear1_init) h = NF.relu(x=h) - linear2_init = RI.HeUniform( - inmaps=400, outmaps=300, factor=1/3) - h = NPF.affine(h, n_outmaps=300, name="linear2", - w_init=linear2_init, b_init=linear2_init) + linear2_init = RI.HeUniform(inmaps=400, outmaps=300, factor=1 / 3) + h = NPF.affine(h, n_outmaps=300, name="linear2", w_init=linear2_init, b_init=linear2_init) h = NF.relu(x=h) - linear3_init = RI.HeUniform( - inmaps=300, outmaps=self._action_dim, factor=1/3) - h = NPF.affine(h, n_outmaps=self._action_dim, name="linear3", - w_init=linear3_init, b_init=linear3_init) + linear3_init = RI.HeUniform(inmaps=300, outmaps=self._action_dim, factor=1 / 3) + h = NPF.affine(h, n_outmaps=self._action_dim, name="linear3", w_init=linear3_init, b_init=linear3_init) return NF.tanh(h) * self._max_action_value @@ -80,12 +74,14 @@ class SACPolicy(StochasticPolicy): _min_log_sigma: float _max_log_sigma: float - def __init__(self, - scope_name: str, - action_dim: int, - clip_log_sigma: bool = True, - min_log_sigma: float = -20.0, - max_log_sigma: float = 2.0): + def __init__( + self, + scope_name: str, + action_dim: int, + clip_log_sigma: bool = True, + min_log_sigma: float = -20.0, + max_log_sigma: float = 2.0, + ): super(SACPolicy, self).__init__(scope_name) self._action_dim = action_dim self._clip_log_sigma = clip_log_sigma @@ -98,14 +94,13 @@ def pi(self, s: nn.Variable) -> Distribution: h = NF.relu(x=h) h = NPF.affine(h, n_outmaps=256, name="linear2") h = NF.relu(x=h) - h = NPF.affine(h, n_outmaps=self._action_dim*2, name="linear3") + h = NPF.affine(h, n_outmaps=self._action_dim * 2, name="linear3") reshaped = NF.reshape(h, shape=(-1, 2, self._action_dim)) mean, ln_sigma = NF.split(reshaped, axis=1) assert mean.shape == ln_sigma.shape assert mean.shape == (s.shape[0], self._action_dim) if self._clip_log_sigma: - ln_sigma = NF.clip_by_value( - ln_sigma, min=self._min_log_sigma, max=self._max_log_sigma) + ln_sigma = NF.clip_by_value(ln_sigma, min=self._min_log_sigma, max=self._max_log_sigma) ln_var = ln_sigma * 2.0 return D.SquashedGaussian(mean=mean, ln_var=ln_var) @@ -128,20 +123,14 @@ def __init__(self, scope_name: str, action_dim: int): def pi(self, s: nn.Variable) -> Distribution: with nn.parameter_scope(self.scope_name): - linear1_init = RI.HeUniform( - inmaps=s.shape[1], outmaps=400, factor=1/3) - h = NPF.affine(s, n_outmaps=400, name="linear1", - w_init=linear1_init, b_init=linear1_init) + linear1_init = RI.HeUniform(inmaps=s.shape[1], outmaps=400, factor=1 / 3) + h = NPF.affine(s, n_outmaps=400, name="linear1", w_init=linear1_init, b_init=linear1_init) h = NF.relu(x=h) - linear2_init = RI.HeUniform( - inmaps=400, outmaps=300, factor=1/3) - h = NPF.affine(h, n_outmaps=300, name="linear2", - w_init=linear2_init, b_init=linear2_init) + linear2_init = RI.HeUniform(inmaps=400, outmaps=300, factor=1 / 3) + h = NPF.affine(h, n_outmaps=300, name="linear2", w_init=linear2_init, b_init=linear2_init) h = NF.relu(x=h) - linear3_init = RI.HeUniform( - inmaps=300, outmaps=self._action_dim*2, factor=1/3) - h = 
NPF.affine(h, n_outmaps=self._action_dim*2, name="linear3", - w_init=linear3_init, b_init=linear3_init) + linear3_init = RI.HeUniform(inmaps=300, outmaps=self._action_dim * 2, factor=1 / 3) + h = NPF.affine(h, n_outmaps=self._action_dim * 2, name="linear3", w_init=linear3_init, b_init=linear3_init) reshaped = NF.reshape(h, shape=(-1, 2, self._action_dim)) mean, ln_var = NF.split(reshaped, axis=1) assert mean.shape == ln_var.shape @@ -168,18 +157,15 @@ def __init__(self, scope_name: str, action_dim: int): def pi(self, s: nn.Variable) -> Distribution: with nn.parameter_scope(self.scope_name): - h = NPF.affine(s, n_outmaps=64, name="linear1", - w_init=RI.NormcInitializer(std=1.0)) + h = NPF.affine(s, n_outmaps=64, name="linear1", w_init=RI.NormcInitializer(std=1.0)) h = NF.tanh(x=h) - h = NPF.affine(h, n_outmaps=64, name="linear2", - w_init=RI.NormcInitializer(std=1.0)) + h = NPF.affine(h, n_outmaps=64, name="linear2", w_init=RI.NormcInitializer(std=1.0)) h = NF.tanh(x=h) - mean = NPF.affine(h, n_outmaps=self._action_dim, name="linear3", - w_init=RI.NormcInitializer(std=0.01)) + mean = NPF.affine(h, n_outmaps=self._action_dim, name="linear3", w_init=RI.NormcInitializer(std=0.01)) ln_sigma = nn.parameter.get_parameter_or_create( - "ln_sigma", shape=(1, self._action_dim), initializer=NI.ConstantInitializer(0.)) - ln_var = NF.broadcast( - ln_sigma, (s.shape[0], self._action_dim)) * 2.0 + "ln_sigma", shape=(1, self._action_dim), initializer=NI.ConstantInitializer(0.0) + ) + ln_var = NF.broadcast(ln_sigma, (s.shape[0], self._action_dim)) * 2.0 assert mean.shape == ln_var.shape assert mean.shape == (s.shape[0], self._action_dim) return D.Gaussian(mean=mean, ln_var=ln_var) @@ -205,9 +191,8 @@ def pi(self, s: nn.Variable) -> Distribution: with nn.parameter_scope(self.scope_name): h = NPF.affine(s, n_outmaps=30, name="linear1") h = NF.tanh(x=h) - h = NPF.affine(h, n_outmaps=self._action_dim*2, name="linear2") - reshaped = NF.reshape( - h, shape=(-1, 2, self._action_dim), inplace=False) + h = NPF.affine(h, n_outmaps=self._action_dim * 2, name="linear2") + reshaped = NF.reshape(h, shape=(-1, 2, self._action_dim), inplace=False) mean, ln_sigma = NF.split(reshaped, axis=1) assert mean.shape == ln_sigma.shape assert mean.shape == (s.shape[0], self._action_dim) @@ -233,20 +218,19 @@ def __init__(self, scope_name: str, action_dim: int): def pi(self, s: nn.Variable) -> Distribution: with nn.parameter_scope(self.scope_name): - h = NPF.affine(s, n_outmaps=64, name="linear1", - w_init=NI.OrthogonalInitializer(np.sqrt(2.))) + h = NPF.affine(s, n_outmaps=64, name="linear1", w_init=NI.OrthogonalInitializer(np.sqrt(2.0))) h = NF.tanh(x=h) - h = NPF.affine(h, n_outmaps=64, name="linear2", - w_init=NI.OrthogonalInitializer(np.sqrt(2.))) + h = NPF.affine(h, n_outmaps=64, name="linear2", w_init=NI.OrthogonalInitializer(np.sqrt(2.0))) h = NF.tanh(x=h) - mean = NPF.affine(h, n_outmaps=self._action_dim, name="linear3", - w_init=NI.OrthogonalInitializer(np.sqrt(2.))) + mean = NPF.affine( + h, n_outmaps=self._action_dim, name="linear3", w_init=NI.OrthogonalInitializer(np.sqrt(2.0)) + ) assert mean.shape == (s.shape[0], self._action_dim) ln_sigma = get_parameter_or_create( - "ln_sigma", shape=(1, self._action_dim), initializer=NI.ConstantInitializer(0.)) - ln_var = NF.broadcast( - ln_sigma, (s.shape[0], self._action_dim)) * 2.0 + "ln_sigma", shape=(1, self._action_dim), initializer=NI.ConstantInitializer(0.0) + ) + ln_var = NF.broadcast(ln_sigma, (s.shape[0], self._action_dim)) * 2.0 return D.Gaussian(mean, ln_var) @@ -268,23 
+252,35 @@ def __init__(self, scope_name: str, action_dim: int): def pi(self, s: nn.Variable) -> Distribution: with nn.parameter_scope(self.scope_name): - h = NPF.affine(s, n_outmaps=64, name="linear1", - w_init=RI.HeUniform(inmaps=64, outmaps=64, factor=1./3.), - b_init=RI.HeUniform(inmaps=64, outmaps=64, factor=1./3.)) + h = NPF.affine( + s, + n_outmaps=64, + name="linear1", + w_init=RI.HeUniform(inmaps=64, outmaps=64, factor=1.0 / 3.0), + b_init=RI.HeUniform(inmaps=64, outmaps=64, factor=1.0 / 3.0), + ) h = NF.tanh(x=h) - h = NPF.affine(h, n_outmaps=64, name="linear2", - w_init=RI.HeUniform(inmaps=64, outmaps=64, factor=1./3.), - b_init=RI.HeUniform(inmaps=64, outmaps=64, factor=1./3.)) + h = NPF.affine( + h, + n_outmaps=64, + name="linear2", + w_init=RI.HeUniform(inmaps=64, outmaps=64, factor=1.0 / 3.0), + b_init=RI.HeUniform(inmaps=64, outmaps=64, factor=1.0 / 3.0), + ) h = NF.tanh(x=h) - mean = NPF.affine(h, n_outmaps=self._action_dim, name="linear3", - w_init=RI.HeUniform(inmaps=64, outmaps=self._action_dim, factor=0.01/3.), - b_init=NI.ConstantInitializer(0.)) + mean = NPF.affine( + h, + n_outmaps=self._action_dim, + name="linear3", + w_init=RI.HeUniform(inmaps=64, outmaps=self._action_dim, factor=0.01 / 3.0), + b_init=NI.ConstantInitializer(0.0), + ) assert mean.shape == (s.shape[0], self._action_dim) ln_sigma = get_parameter_or_create( - "ln_sigma", shape=(1, self._action_dim), initializer=NI.ConstantInitializer(-0.5)) - ln_var = NF.broadcast( - ln_sigma, (s.shape[0], self._action_dim)) * 2.0 + "ln_sigma", shape=(1, self._action_dim), initializer=NI.ConstantInitializer(-0.5) + ) + ln_var = NF.broadcast(ln_sigma, (s.shape[0], self._action_dim)) * 2.0 return D.Gaussian(mean, ln_var) @@ -300,20 +296,17 @@ def __init__(self, scope_name: str, action_dim: str): def pi(self, s: nn.Variable) -> Distribution: with nn.parameter_scope(self.scope_name): - h = NPF.affine(s, n_outmaps=100, name="linear1", - w_init=RI.NormcInitializer(std=1.0)) + h = NPF.affine(s, n_outmaps=100, name="linear1", w_init=RI.NormcInitializer(std=1.0)) h = NF.tanh(x=h) - h = NPF.affine(h, n_outmaps=100, name="linear2", - w_init=RI.NormcInitializer(std=1.0)) + h = NPF.affine(h, n_outmaps=100, name="linear2", w_init=RI.NormcInitializer(std=1.0)) h = NF.tanh(x=h) - mean = NPF.affine(h, n_outmaps=self._action_dim, name="linear3", - w_init=RI.NormcInitializer(std=0.01)) + mean = NPF.affine(h, n_outmaps=self._action_dim, name="linear3", w_init=RI.NormcInitializer(std=0.01)) assert mean.shape == (s.shape[0], self._action_dim) ln_sigma = get_parameter_or_create( - "ln_sigma", shape=(1, self._action_dim), initializer=NI.ConstantInitializer(0.)) - ln_var = NF.broadcast( - ln_sigma, (s.shape[0], self._action_dim)) * 2.0 + "ln_sigma", shape=(1, self._action_dim), initializer=NI.ConstantInitializer(0.0) + ) + ln_var = NF.broadcast(ln_sigma, (s.shape[0], self._action_dim)) * 2.0 return D.Gaussian(mean, ln_var) @@ -329,16 +322,16 @@ def pi(self, s: Tuple[nn.Variable, nn.Variable, nn.Variable]) -> nn.Variable: with nn.parameter_scope(self.scope_name): h = NF.concatenate(obs, goal, axis=1) linear1_init = RI.GlorotUniform(inmaps=h.shape[1], outmaps=64) - h = NPF.affine(h, n_outmaps=64, name='linear1', w_init=linear1_init) + h = NPF.affine(h, n_outmaps=64, name="linear1", w_init=linear1_init) h = NF.relu(h) linear2_init = RI.GlorotUniform(inmaps=h.shape[1], outmaps=64) - h = NPF.affine(h, n_outmaps=64, name='linear2', w_init=linear2_init) + h = NPF.affine(h, n_outmaps=64, name="linear2", w_init=linear2_init) h = NF.relu(h) 
linear3_init = RI.GlorotUniform(inmaps=h.shape[1], outmaps=64) - h = NPF.affine(h, n_outmaps=64, name='linear3', w_init=linear3_init) + h = NPF.affine(h, n_outmaps=64, name="linear3", w_init=linear3_init) h = NF.relu(h) action_init = RI.GlorotUniform(inmaps=h.shape[1], outmaps=self._action_dim) - h = NPF.affine(h, n_outmaps=self._action_dim, name='action', w_init=action_init) + h = NPF.affine(h, n_outmaps=self._action_dim, name="action", w_init=action_init) return NF.tanh(h) * self._max_action_value @@ -357,12 +350,14 @@ class XQLPolicy(StochasticPolicy): _min_log_sigma: float _max_log_sigma: float - def __init__(self, - scope_name: str, - action_dim: int, - clip_log_sigma: bool = True, - min_log_sigma: float = -5.0, - max_log_sigma: float = 2.0): + def __init__( + self, + scope_name: str, + action_dim: int, + clip_log_sigma: bool = True, + min_log_sigma: float = -5.0, + max_log_sigma: float = 2.0, + ): super(XQLPolicy, self).__init__(scope_name) self._action_dim = action_dim self._clip_log_sigma = clip_log_sigma @@ -380,8 +375,9 @@ def pi(self, s: nn.Variable) -> Distribution: mean = NF.tanh(h) # create parameter with shape (1, self._action_dim) # because these parameters should be independent from states (and from batches also) - ln_sigma = get_parameter_or_create("ln_sigma", shape=(1, self._action_dim), - initializer=NI.ConstantInitializer(0.)) + ln_sigma = get_parameter_or_create( + "ln_sigma", shape=(1, self._action_dim), initializer=NI.ConstantInitializer(0.0) + ) ln_sigma = NF.broadcast(ln_sigma, (s.shape[0], self._action_dim)) assert mean.shape == ln_sigma.shape assert mean.shape == (s.shape[0], self._action_dim) diff --git a/nnabla_rl/models/mujoco/q_functions.py b/nnabla_rl/models/mujoco/q_functions.py index 27f0c85d..bed1b8b5 100644 --- a/nnabla_rl/models/mujoco/q_functions.py +++ b/nnabla_rl/models/mujoco/q_functions.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -45,24 +45,18 @@ def __init__(self, scope_name: str, optimal_policy: Optional[DeterministicPolicy def q(self, s: nn.Variable, a: nn.Variable) -> nn.Variable: with nn.parameter_scope(self.scope_name): h = NF.concatenate(s, a) - linear1_init = RI.HeUniform( - inmaps=h.shape[1], outmaps=400, factor=1/3) - h = NPF.affine(h, n_outmaps=400, name="linear1", - w_init=linear1_init, b_init=linear1_init) + linear1_init = RI.HeUniform(inmaps=h.shape[1], outmaps=400, factor=1 / 3) + h = NPF.affine(h, n_outmaps=400, name="linear1", w_init=linear1_init, b_init=linear1_init) h = NF.relu(x=h) - linear2_init = RI.HeUniform( - inmaps=400, outmaps=300, factor=1/3) - h = NPF.affine(h, n_outmaps=300, name="linear2", - w_init=linear2_init, b_init=linear2_init) + linear2_init = RI.HeUniform(inmaps=400, outmaps=300, factor=1 / 3) + h = NPF.affine(h, n_outmaps=300, name="linear2", w_init=linear2_init, b_init=linear2_init) h = NF.relu(x=h) - linear3_init = RI.HeUniform( - inmaps=300, outmaps=1, factor=1/3) - h = NPF.affine(h, n_outmaps=1, name="linear3", - w_init=linear3_init, b_init=linear3_init) + linear3_init = RI.HeUniform(inmaps=300, outmaps=1, factor=1 / 3) + h = NPF.affine(h, n_outmaps=1, name="linear3", w_init=linear3_init, b_init=linear3_init) return h def max_q(self, s: nn.Variable) -> nn.Variable: - assert self._optimal_policy, 'Optimal policy is not set!' 
+ assert self._optimal_policy, "Optimal policy is not set!" optimal_action = self._optimal_policy.pi(s) return self.q(s, optimal_action) @@ -94,7 +88,7 @@ def q(self, s: nn.Variable, a: nn.Variable) -> nn.Variable: return h def max_q(self, s: nn.Variable) -> nn.Variable: - assert self._optimal_policy, 'Optimal policy is not set!' + assert self._optimal_policy, "Optimal policy is not set!" optimal_action = self._optimal_policy.pi(s) return self.q(s, optimal_action) @@ -131,7 +125,7 @@ def factored_q(self, s: nn.Variable, a: nn.Variable) -> nn.Variable: return h def max_q(self, s: nn.Variable) -> nn.Variable: - assert self._optimal_policy, 'Optimal policy is not set!' + assert self._optimal_policy, "Optimal policy is not set!" optimal_action = self._optimal_policy.pi(s) return self.q(s, optimal_action) @@ -150,20 +144,20 @@ def q(self, s: Tuple[nn.Variable, nn.Variable, nn.Variable], a: nn.Variable) -> with nn.parameter_scope(self.scope_name): h = NF.concatenate(obs, goal, a, axis=1) linear1_init = RI.GlorotUniform(inmaps=h.shape[1], outmaps=64) - h = NPF.affine(h, n_outmaps=64, name='linear1', w_init=linear1_init) + h = NPF.affine(h, n_outmaps=64, name="linear1", w_init=linear1_init) h = NF.relu(h) linear2_init = RI.GlorotUniform(inmaps=h.shape[1], outmaps=64) - h = NPF.affine(h, n_outmaps=64, name='linear2', w_init=linear2_init) + h = NPF.affine(h, n_outmaps=64, name="linear2", w_init=linear2_init) h = NF.relu(h) linear3_init = RI.GlorotUniform(inmaps=h.shape[1], outmaps=64) - h = NPF.affine(h, n_outmaps=64, name='linear3', w_init=linear3_init) + h = NPF.affine(h, n_outmaps=64, name="linear3", w_init=linear3_init) h = NF.relu(h) pred_q_init = RI.GlorotUniform(inmaps=h.shape[1], outmaps=1) - q = NPF.affine(h, n_outmaps=1, name='pred_q', w_init=pred_q_init) + q = NPF.affine(h, n_outmaps=1, name="pred_q", w_init=pred_q_init) return q def max_q(self, s: nn.Variable) -> nn.Variable: - assert self._optimal_policy, 'Optimal policy is not set!' + assert self._optimal_policy, "Optimal policy is not set!" optimal_action = self._optimal_policy.pi(s) return self.q(s, optimal_action) @@ -197,6 +191,6 @@ def q(self, s: nn.Variable, a: nn.Variable) -> nn.Variable: return h def max_q(self, s: nn.Variable) -> nn.Variable: - assert self._optimal_policy, 'Optimal policy is not set!' + assert self._optimal_policy, "Optimal policy is not set!" optimal_action = self._optimal_policy.pi(s) return self.q(s, optimal_action) diff --git a/nnabla_rl/models/mujoco/reward_functions.py b/nnabla_rl/models/mujoco/reward_functions.py index edba2647..da033f18 100644 --- a/nnabla_rl/models/mujoco/reward_functions.py +++ b/nnabla_rl/models/mujoco/reward_functions.py @@ -1,5 +1,5 @@ # Copyright 2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,19 +31,16 @@ def __init__(self, scope_name: str): super(GAILDiscriminator, self).__init__(scope_name) def r(self, s_current: nn.Variable, a_current: nn.Variable, s_next: nn.Variable) -> nn.Variable: - ''' + """ Notes: In gail, we don't use the next state. 
- ''' + """ h = NF.concatenate(s_current, a_current, axis=1) with nn.parameter_scope(self.scope_name): - h = NPF.affine(h, n_outmaps=100, name="linear1", - w_init=RI.GlorotUniform(h.shape[1], 100)) + h = NPF.affine(h, n_outmaps=100, name="linear1", w_init=RI.GlorotUniform(h.shape[1], 100)) h = NF.tanh(x=h) - h = NPF.affine(h, n_outmaps=100, name="linear2", - w_init=RI.GlorotUniform(h.shape[1], 100)) + h = NPF.affine(h, n_outmaps=100, name="linear2", w_init=RI.GlorotUniform(h.shape[1], 100)) h = NF.tanh(x=h) - h = NPF.affine(h, n_outmaps=1, name="linear3", - w_init=RI.GlorotUniform(h.shape[1], 1)) + h = NPF.affine(h, n_outmaps=1, name="linear3", w_init=RI.GlorotUniform(h.shape[1], 1)) return h diff --git a/nnabla_rl/models/mujoco/v_functions.py b/nnabla_rl/models/mujoco/v_functions.py index 00e8370a..8143a607 100644 --- a/nnabla_rl/models/mujoco/v_functions.py +++ b/nnabla_rl/models/mujoco/v_functions.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -49,14 +49,11 @@ class TRPOVFunction(VFunction): def v(self, s: nn.Variable) -> nn.Variable: with nn.parameter_scope(self.scope_name): - h = NPF.affine(s, n_outmaps=64, name="linear1", - w_init=NI.OrthogonalInitializer(np.sqrt(2.))) + h = NPF.affine(s, n_outmaps=64, name="linear1", w_init=NI.OrthogonalInitializer(np.sqrt(2.0))) h = NF.tanh(x=h) - h = NPF.affine(h, n_outmaps=64, name="linear2", - w_init=NI.OrthogonalInitializer(np.sqrt(2.))) + h = NPF.affine(h, n_outmaps=64, name="linear2", w_init=NI.OrthogonalInitializer(np.sqrt(2.0))) h = NF.tanh(x=h) - h = NPF.affine(h, n_outmaps=1, name="linear3", - w_init=NI.OrthogonalInitializer(np.sqrt(2.))) + h = NPF.affine(h, n_outmaps=1, name="linear3", w_init=NI.OrthogonalInitializer(np.sqrt(2.0))) return h @@ -71,16 +68,13 @@ class PPOVFunction(VFunction): def v(self, s: nn.Variable) -> nn.Variable: with nn.parameter_scope(self.scope_name): with nn.parameter_scope("linear1"): - h = NPF.affine(s, n_outmaps=64, - w_init=RI.NormcInitializer(std=1.0)) + h = NPF.affine(s, n_outmaps=64, w_init=RI.NormcInitializer(std=1.0)) h = NF.tanh(x=h) with nn.parameter_scope("linear2"): - h = NPF.affine(h, n_outmaps=64, - w_init=RI.NormcInitializer(std=1.0)) + h = NPF.affine(h, n_outmaps=64, w_init=RI.NormcInitializer(std=1.0)) h = NF.tanh(x=h) with nn.parameter_scope("linear_v"): - v = NPF.affine(h, n_outmaps=1, - w_init=RI.NormcInitializer(std=1.0)) + v = NPF.affine(h, n_outmaps=1, w_init=RI.NormcInitializer(std=1.0)) return v @@ -95,14 +89,11 @@ def __init__(self, scope_name: str): def v(self, s: nn.Variable) -> nn.Variable: with nn.parameter_scope(self.scope_name): - h = NPF.affine(s, n_outmaps=100, name="linear1", - w_init=RI.NormcInitializer(std=1.0)) + h = NPF.affine(s, n_outmaps=100, name="linear1", w_init=RI.NormcInitializer(std=1.0)) h = NF.tanh(x=h) - h = NPF.affine(h, n_outmaps=100, name="linear2", - w_init=RI.NormcInitializer(std=1.0)) + h = NPF.affine(h, n_outmaps=100, name="linear2", w_init=RI.NormcInitializer(std=1.0)) h = NF.tanh(x=h) - h = NPF.affine(h, n_outmaps=1, name="linear3", - w_init=RI.NormcInitializer(std=1.0)) + h = NPF.affine(h, n_outmaps=1, name="linear3", w_init=RI.NormcInitializer(std=1.0)) return h @@ -115,17 +106,29 @@ class ATRPOVFunction(VFunction): def v(self, s: nn.Variable) -> nn.Variable: with 
nn.parameter_scope(self.scope_name): - h = NPF.affine(s, n_outmaps=64, name="linear1", - w_init=RI.HeUniform(inmaps=64, outmaps=64, factor=1./3.), - b_init=RI.HeUniform(inmaps=64, outmaps=64, factor=1./3.)) + h = NPF.affine( + s, + n_outmaps=64, + name="linear1", + w_init=RI.HeUniform(inmaps=64, outmaps=64, factor=1.0 / 3.0), + b_init=RI.HeUniform(inmaps=64, outmaps=64, factor=1.0 / 3.0), + ) h = NF.tanh(x=h) - h = NPF.affine(h, n_outmaps=64, name="linear2", - w_init=RI.HeUniform(inmaps=64, outmaps=64, factor=1./3.), - b_init=RI.HeUniform(inmaps=64, outmaps=64, factor=1./3.)) + h = NPF.affine( + h, + n_outmaps=64, + name="linear2", + w_init=RI.HeUniform(inmaps=64, outmaps=64, factor=1.0 / 3.0), + b_init=RI.HeUniform(inmaps=64, outmaps=64, factor=1.0 / 3.0), + ) h = NF.tanh(x=h) - h = NPF.affine(h, n_outmaps=1, name="linear3", - w_init=RI.HeUniform(inmaps=64, outmaps=1, factor=0.01/3.), - b_init=NI.ConstantInitializer(0.)) + h = NPF.affine( + h, + n_outmaps=1, + name="linear3", + w_init=RI.HeUniform(inmaps=64, outmaps=1, factor=0.01 / 3.0), + b_init=NI.ConstantInitializer(0.0), + ) return h @@ -139,16 +142,16 @@ class XQLVFunction(VFunction): def v(self, s: nn.Variable) -> nn.Variable: w_init = NI.OrthogonalInitializer(np.sqrt(2.0)) with nn.parameter_scope(self.scope_name): - with nn.parameter_scope('linear1'): + with nn.parameter_scope("linear1"): h = NPF.affine(s, n_outmaps=256, w_init=w_init) - with nn.parameter_scope('layer_norm1'): + with nn.parameter_scope("layer_norm1"): h = NPF.layer_normalization(h, eps=1e-6) h = NF.relu(x=h) - with nn.parameter_scope('linear2'): + with nn.parameter_scope("linear2"): h = NPF.affine(h, n_outmaps=256, w_init=w_init) - with nn.parameter_scope('layer_norm2'): + with nn.parameter_scope("layer_norm2"): h = NPF.layer_normalization(h, eps=1e-6) h = NF.relu(x=h) - with nn.parameter_scope('linear3'): + with nn.parameter_scope("linear3"): h = NPF.affine(h, n_outmaps=1, w_init=w_init) return h diff --git a/nnabla_rl/models/policy.py b/nnabla_rl/models/policy.py index 72a77000..79bea520 100644 --- a/nnabla_rl/models/policy.py +++ b/nnabla_rl/models/policy.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,16 +30,17 @@ class DeterministicPolicy(Policy, metaclass=ABCMeta): This policy returns an action for the given state. """ + @abstractmethod def pi(self, s: nn.Variable) -> nn.Variable: - '''pi + """pi Args: state (nnabla.Variable): State variable Returns: nnabla.Variable : Action for the given state - ''' + """ raise NotImplementedError @@ -49,14 +50,15 @@ class StochasticPolicy(Policy, metaclass=ABCMeta): This policy returns a probability distribution of action for the given state. 
""" + @abstractmethod def pi(self, s: nn.Variable) -> Distribution: - '''pi + """pi Args: state (nnabla.Variable): State variable Returns: nnabla_rl.distributions.Distribution: Probability distribution of the action for the given state - ''' + """ raise NotImplementedError diff --git a/nnabla_rl/models/pybullet/policy.py b/nnabla_rl/models/pybullet/policy.py index 4ecef435..9bf4f0f4 100644 --- a/nnabla_rl/models/pybullet/policy.py +++ b/nnabla_rl/models/pybullet/policy.py @@ -51,24 +51,27 @@ def pi(self, s: Tuple[nn.Variable, nn.Variable, nn.Variable]) -> Distribution: s_for_pi_v, _, _ = s batch_size = s_for_pi_v.shape[0] with nn.parameter_scope(self.scope_name): - h = NPF.affine(s_for_pi_v, - n_outmaps=1024, - name="linear1", - w_init=RI.GlorotUniform(inmaps=s_for_pi_v.shape[1], outmaps=1024)) + h = NPF.affine( + s_for_pi_v, + n_outmaps=1024, + name="linear1", + w_init=RI.GlorotUniform(inmaps=s_for_pi_v.shape[1], outmaps=1024), + ) h = NF.relu(x=h) - h = NPF.affine(h, - n_outmaps=512, - name="linear2", - w_init=RI.GlorotUniform(inmaps=h.shape[1], outmaps=512)) + h = NPF.affine(h, n_outmaps=512, name="linear2", w_init=RI.GlorotUniform(inmaps=h.shape[1], outmaps=512)) h = NF.relu(x=h) - mean = NPF.affine(h, - n_outmaps=self._action_dim, - name="linear3_mean", - w_init=NI.UniformInitializer((-1.0 * self._output_layer_initializer_scale, - self._output_layer_initializer_scale)), - b_init=NI.ConstantInitializer(0.0)) + mean = NPF.affine( + h, + n_outmaps=self._action_dim, + name="linear3_mean", + w_init=NI.UniformInitializer( + (-1.0 * self._output_layer_initializer_scale, self._output_layer_initializer_scale) + ), + b_init=NI.ConstantInitializer(0.0), + ) ln_sigma = nn.Variable.from_numpy_array( - np.ones((batch_size, self._action_dim), dtype=np.float32) * np.log(0.05)) + np.ones((batch_size, self._action_dim), dtype=np.float32) * np.log(0.05) + ) ln_var = ln_sigma * 2.0 assert mean.shape == ln_var.shape assert mean.shape == (s_for_pi_v.shape[0], self._action_dim) @@ -95,17 +98,20 @@ def __init__(self, scope_name: str, action_dim: int, output_layer_initializer_sc assert output_layer_initializer_scale > 0.0, f"{output_layer_initializer_scale} should be larger than 0.0" self._output_layer_initializer_scale = output_layer_initializer_scale - def pi(self, s: Tuple[nn.Variable, nn.Variable, nn.Variable, nn.Variable, nn.Variable, nn.Variable, nn.Variable]) \ - -> Distribution: + def pi( + self, s: Tuple[nn.Variable, nn.Variable, nn.Variable, nn.Variable, nn.Variable, nn.Variable, nn.Variable] + ) -> Distribution: assert len(s) == 7 s_for_pi_v, _, _, goal, *_ = s batch_size = s_for_pi_v.shape[0] with nn.parameter_scope(self.scope_name): # gated block - g_z = NPF.affine(goal, - n_outmaps=128, - name="gate_initial_linear", - w_init=RI.GlorotUniform(inmaps=goal.shape[1], outmaps=128)) + g_z = NPF.affine( + goal, + n_outmaps=128, + name="gate_initial_linear", + w_init=RI.GlorotUniform(inmaps=goal.shape[1], outmaps=128), + ) g_z = NF.relu(g_z) h = NF.concatenate(s_for_pi_v, goal, axis=-1) @@ -113,17 +119,22 @@ def pi(self, s: Tuple[nn.Variable, nn.Variable, nn.Variable, nn.Variable, nn.Var # block1 h = NPF.affine(h, n_outmaps=1024, name="linear_1", w_init=RI.GlorotUniform(inmaps=h.shape[1], outmaps=1024)) - gate_h = NPF.affine(g_z, n_outmaps=64, name="gate_linear1", - w_init=RI.GlorotUniform(inmaps=g_z.shape[1], outmaps=64)) + gate_h = NPF.affine( + g_z, n_outmaps=64, name="gate_linear1", w_init=RI.GlorotUniform(inmaps=g_z.shape[1], outmaps=64) + ) gate_h = NF.relu(gate_h) - gate_h_bias = 
NPF.affine(gate_h, - n_outmaps=1024, - name="gate_bias_linear1", - w_init=RI.GlorotUniform(inmaps=gate_h.shape[1], outmaps=1024)) - gate_h_scale = NPF.affine(gate_h, - n_outmaps=1024, - name="gate_scale_linear1", - w_init=RI.GlorotUniform(inmaps=gate_h.shape[1], outmaps=1024)) + gate_h_bias = NPF.affine( + gate_h, + n_outmaps=1024, + name="gate_bias_linear1", + w_init=RI.GlorotUniform(inmaps=gate_h.shape[1], outmaps=1024), + ) + gate_h_scale = NPF.affine( + gate_h, + n_outmaps=1024, + name="gate_scale_linear1", + w_init=RI.GlorotUniform(inmaps=gate_h.shape[1], outmaps=1024), + ) h = h * 2.0 * NF.sigmoid(gate_h_scale) + gate_h_bias h = NF.relu(h) @@ -131,30 +142,39 @@ def pi(self, s: Tuple[nn.Variable, nn.Variable, nn.Variable, nn.Variable, nn.Var # block2 h = NPF.affine(h, n_outmaps=512, name="linear_2", w_init=RI.GlorotUniform(inmaps=h.shape[1], outmaps=512)) - gate_h = NPF.affine(g_z, n_outmaps=64, name="gate_linear2", - w_init=RI.GlorotUniform(inmaps=h.shape[1], outmaps=64)) + gate_h = NPF.affine( + g_z, n_outmaps=64, name="gate_linear2", w_init=RI.GlorotUniform(inmaps=h.shape[1], outmaps=64) + ) gate_h = NF.relu(gate_h) - gate_h_bias = NPF.affine(gate_h, - n_outmaps=512, - name="gate_bias_linear2", - w_init=RI.GlorotUniform(inmaps=gate_h.shape[1], outmaps=512)) - gate_h_scale = NPF.affine(gate_h, - n_outmaps=512, - name="gate_scale_linear2", - w_init=RI.GlorotUniform(inmaps=gate_h.shape[1], outmaps=512)) + gate_h_bias = NPF.affine( + gate_h, + n_outmaps=512, + name="gate_bias_linear2", + w_init=RI.GlorotUniform(inmaps=gate_h.shape[1], outmaps=512), + ) + gate_h_scale = NPF.affine( + gate_h, + n_outmaps=512, + name="gate_scale_linear2", + w_init=RI.GlorotUniform(inmaps=gate_h.shape[1], outmaps=512), + ) h = h * 2.0 * NF.sigmoid(gate_h_scale) + gate_h_bias h = NF.relu(h) # output block - mean = NPF.affine(h, - n_outmaps=self._action_dim, - name="linear3_mean", - w_init=NI.UniformInitializer((-1.0 * self._output_layer_initializer_scale, - self._output_layer_initializer_scale)), - b_init=NI.ConstantInitializer(0.0)) + mean = NPF.affine( + h, + n_outmaps=self._action_dim, + name="linear3_mean", + w_init=NI.UniformInitializer( + (-1.0 * self._output_layer_initializer_scale, self._output_layer_initializer_scale) + ), + b_init=NI.ConstantInitializer(0.0), + ) ln_sigma = nn.Variable.from_numpy_array( - np.ones((batch_size, self._action_dim), dtype=np.float32) * np.log(0.05)) + np.ones((batch_size, self._action_dim), dtype=np.float32) * np.log(0.05) + ) ln_var = ln_sigma * 2.0 assert mean.shape == ln_var.shape assert mean.shape == (batch_size, self._action_dim) diff --git a/nnabla_rl/models/pybullet/q_functions.py b/nnabla_rl/models/pybullet/q_functions.py index 98ccc942..bfe110ca 100644 --- a/nnabla_rl/models/pybullet/q_functions.py +++ b/nnabla_rl/models/pybullet/q_functions.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
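Editor's note: the gated policy in nnabla_rl/models/pybullet/policy.py above conditions each hidden block on a goal embedding; a small gate branch produces a scale and a bias, and the main activation is modulated as `h * 2 * sigmoid(scale) + bias` before the ReLU. A condensed sketch of one such block follows; the 1024/64 widths mirror the first block in the hunk, the function and scope names are illustrative, and `RI` is assumed to be the `nnabla_rl.initializers` alias used throughout the patch.

```python
import numpy as np
import nnabla as nn
import nnabla.functions as NF
import nnabla.parametric_functions as NPF
import nnabla_rl.initializers as RI  # assumed alias, matches RI.GlorotUniform in the patch


def gated_block(h, g_z, width, name):
    """One gated block: the goal embedding g_z rescales and shifts the main path."""
    with nn.parameter_scope(name):
        h = NPF.affine(h, n_outmaps=width, name="linear",
                       w_init=RI.GlorotUniform(inmaps=h.shape[1], outmaps=width))
        gate_h = NPF.affine(g_z, n_outmaps=64, name="gate_linear",
                            w_init=RI.GlorotUniform(inmaps=g_z.shape[1], outmaps=64))
        gate_h = NF.relu(gate_h)
        gate_bias = NPF.affine(gate_h, n_outmaps=width, name="gate_bias",
                               w_init=RI.GlorotUniform(inmaps=gate_h.shape[1], outmaps=width))
        gate_scale = NPF.affine(gate_h, n_outmaps=width, name="gate_scale",
                                w_init=RI.GlorotUniform(inmaps=gate_h.shape[1], outmaps=width))
        # Multiplicative gate plus additive bias, as in the gated models above.
        h = h * 2.0 * NF.sigmoid(gate_scale) + gate_bias
        return NF.relu(h)


state = nn.Variable.from_numpy_array(np.zeros((4, 128), dtype=np.float32))
goal_embedding = nn.Variable.from_numpy_array(np.zeros((4, 128), dtype=np.float32))
h = gated_block(state, goal_embedding, width=1024, name="block1")
assert h.shape == (4, 1024)
```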
@@ -42,15 +42,16 @@ def __init__( cem_num_elites: int = 50, cem_num_iterations: int = 100, cem_alpha: float = 0.0, - random_sample_size: int = 500 + random_sample_size: int = 500, ): super(ICRA2018QtOptQFunction, self).__init__(scope_name) self._action_high = action_high self._action_low = action_low self._cem_initial_mean_numpy = np.zeros(action_dim) if cem_initial_mean is None else np.array(cem_initial_mean) - self._cem_initial_variance_numpy = 0.5 * \ - np.ones(action_dim) if cem_initial_variance is None else np.array(cem_initial_variance) + self._cem_initial_variance_numpy = ( + 0.5 * np.ones(action_dim) if cem_initial_variance is None else np.array(cem_initial_variance) + ) self._cem_sample_size = cem_sample_size self._cem_num_elites = cem_num_elites self._cem_num_iterations = cem_num_iterations @@ -67,31 +68,31 @@ def q(self, s: Tuple[nn.Variable, nn.Variable], a: nn.Variable) -> nn.Variable: tiled_time_step = NF.reshape(tiled_time_step, (batch_size, 1, 7, 7)) with nn.parameter_scope(self.scope_name): - with nn.parameter_scope('state_conv1'): + with nn.parameter_scope("state_conv1"): h = NF.relu(NPF.convolution(image, 32, (3, 3), stride=(2, 2))) - with nn.parameter_scope('state_conv2'): + with nn.parameter_scope("state_conv2"): h = NF.relu(NPF.convolution(h, 32, (3, 3), stride=(2, 2))) - with nn.parameter_scope('state_conv3'): + with nn.parameter_scope("state_conv3"): h = NF.relu(NPF.convolution(h, 32, (3, 3), stride=(2, 2))) encoded_state = NF.concatenate(tiled_time_step, h, axis=1) - with nn.parameter_scope('action_affine1'): + with nn.parameter_scope("action_affine1"): h = NF.relu(NPF.affine(a, 33)) encoded_action = NF.reshape(h, (batch_size, 33, 1, 1)) h = encoded_state + encoded_action h = NF.reshape(h, (batch_size, -1)) - with nn.parameter_scope('affine1'): + with nn.parameter_scope("affine1"): h = NF.relu(NPF.affine(h, 32)) - with nn.parameter_scope('affine2'): + with nn.parameter_scope("affine2"): h = NF.relu(NPF.affine(h, 32)) - with nn.parameter_scope('affine3'): + with nn.parameter_scope("affine3"): q_value = NPF.affine(h, 1) return q_value @@ -106,7 +107,7 @@ def argmax_q(self, s: Tuple[nn.Variable, nn.Variable]) -> nn.Variable: def objective_function(a: nn.Variable) -> nn.Variable: batch_size, sample_size, action_dim = a.shape - a = a.reshape((batch_size*sample_size, action_dim)) + a = a.reshape((batch_size * sample_size, action_dim)) q_value = self.q(tiled_s, a) q_value = q_value.reshape((batch_size, sample_size, 1)) return q_value @@ -114,7 +115,8 @@ def objective_function(a: nn.Variable) -> nn.Variable: if is_eval_scope(): initial_mean_var = nn.Variable.from_numpy_array(np.tile(self._cem_initial_mean_numpy, (batch_size, 1))) initial_variance_var = nn.Variable.from_numpy_array( - np.tile(self._cem_initial_variance_numpy, (batch_size, 1))) + np.tile(self._cem_initial_variance_numpy, (batch_size, 1)) + ) optimized_action, _ = RF.gaussian_cross_entropy_method( objective_function, initial_mean_var, @@ -122,7 +124,7 @@ def objective_function(a: nn.Variable) -> nn.Variable: sample_size=self._cem_sample_size, num_elites=self._cem_num_elites, num_iterations=self._cem_num_iterations, - alpha=self._cem_alpha + alpha=self._cem_alpha, ) else: upper_bound = np.tile(self._action_high, (batch_size, 1)) @@ -131,15 +133,24 @@ def objective_function(a: nn.Variable) -> nn.Variable: objective_function, upper_bound=upper_bound, lower_bound=lower_bound, - sample_size=self._random_sample_size + sample_size=self._random_sample_size, ) return optimized_action def _tile_state(self, s: 
nn.Variable, tile_size: int) -> nn.Variable: - tile_reps = [tile_size, ] + [1, ] * len(s.shape) + tile_reps = [ + tile_size, + ] + [ + 1, + ] * len(s.shape) s = NF.tile(s, tile_reps) - transpose_reps = [1, 0, ] + list(range(len(s.shape)))[2:] + transpose_reps = [ + 1, + 0, + ] + list( + range(len(s.shape)) + )[2:] s = NF.transpose(s, transpose_reps) s = NF.reshape(s, (-1, *s.shape[2:])) return s diff --git a/nnabla_rl/models/pybullet/reward_functions.py b/nnabla_rl/models/pybullet/reward_functions.py index 5d168500..6b12d86a 100644 --- a/nnabla_rl/models/pybullet/reward_functions.py +++ b/nnabla_rl/models/pybullet/reward_functions.py @@ -34,8 +34,9 @@ def __init__(self, scope_name: str, output_layer_initializer_scale: float): assert output_layer_initializer_scale > 0.0, f"{output_layer_initializer_scale} should be larger than 0.0" self._output_layer_initializer_scale = output_layer_initializer_scale - def r(self, s_current: Tuple[nn.Variable, ...], a_current: nn.Variable, s_next: Tuple[nn.Variable, ...] - ) -> nn.Variable: + def r( + self, s_current: Tuple[nn.Variable, ...], a_current: nn.Variable, s_next: Tuple[nn.Variable, ...] + ) -> nn.Variable: assert len(s_current) == 7 or len(s_current) == 3 for _s in s_current[1:]: assert s_current[0].shape[0] == _s.shape[0] @@ -45,19 +46,21 @@ def r(self, s_current: Tuple[nn.Variable, ...], a_current: nn.Variable, s_next: # NOTE: s_for_reward has s and s_next. # See author's enviroment implmentation and our env wrapper implementation. with nn.parameter_scope(self.scope_name): - h = NPF.affine(s_for_reward, - n_outmaps=1024, - name="linear1", - w_init=RI.GlorotUniform(inmaps=s_for_reward.shape[1], outmaps=1024)) + h = NPF.affine( + s_for_reward, + n_outmaps=1024, + name="linear1", + w_init=RI.GlorotUniform(inmaps=s_for_reward.shape[1], outmaps=1024), + ) h = NF.relu(x=h) - h = NPF.affine(h, - n_outmaps=512, - name="linear2", - w_init=RI.GlorotUniform(inmaps=h.shape[1], outmaps=512)) + h = NPF.affine(h, n_outmaps=512, name="linear2", w_init=RI.GlorotUniform(inmaps=h.shape[1], outmaps=512)) h = NF.relu(x=h) - h = NPF.affine(h, - n_outmaps=1, - name="logits", - w_init=NI.UniformInitializer((-1.0 * self._output_layer_initializer_scale, - self._output_layer_initializer_scale))) + h = NPF.affine( + h, + n_outmaps=1, + name="logits", + w_init=NI.UniformInitializer( + (-1.0 * self._output_layer_initializer_scale, self._output_layer_initializer_scale) + ), + ) return h diff --git a/nnabla_rl/models/pybullet/v_functions.py b/nnabla_rl/models/pybullet/v_functions.py index dd438212..34278b9b 100644 --- a/nnabla_rl/models/pybullet/v_functions.py +++ b/nnabla_rl/models/pybullet/v_functions.py @@ -34,10 +34,12 @@ def v(self, s: nn.Variable) -> nn.Variable: assert len(s) == 3 s_for_pi_v, _, _ = s with nn.parameter_scope(self.scope_name): - h = NPF.affine(s_for_pi_v, - n_outmaps=1024, - name="linear1", - w_init=RI.GlorotUniform(inmaps=s_for_pi_v.shape[1], outmaps=1024)) + h = NPF.affine( + s_for_pi_v, + n_outmaps=1024, + name="linear1", + w_init=RI.GlorotUniform(inmaps=s_for_pi_v.shape[1], outmaps=1024), + ) h = NF.relu(x=h) h = NPF.affine(h, n_outmaps=512, name="linear2", w_init=RI.GlorotUniform(inmaps=h.shape[1], outmaps=512)) h = NF.relu(x=h) @@ -54,16 +56,19 @@ class AMPGatedVFunction(VFunction): def __init__(self, scope_name: str): super(AMPGatedVFunction, self).__init__(scope_name) - def v(self, s: Tuple[nn.Variable, nn.Variable, nn.Variable, nn.Variable, nn.Variable, nn.Variable, nn.Variable]) \ - -> nn.Variable: + def v( + self, s: Tuple[nn.Variable, 
nn.Variable, nn.Variable, nn.Variable, nn.Variable, nn.Variable, nn.Variable] + ) -> nn.Variable: assert len(s) == 7 s_for_pi_v, _, _, goal, *_ = s with nn.parameter_scope(self.scope_name): # gated block - g_z = NPF.affine(goal, - n_outmaps=128, - name="gate_initial_linear", - w_init=RI.GlorotUniform(inmaps=goal.shape[1], outmaps=128)) + g_z = NPF.affine( + goal, + n_outmaps=128, + name="gate_initial_linear", + w_init=RI.GlorotUniform(inmaps=goal.shape[1], outmaps=128), + ) g_z = NF.relu(g_z) h = NF.concatenate(s_for_pi_v, goal, axis=-1) @@ -71,17 +76,22 @@ def v(self, s: Tuple[nn.Variable, nn.Variable, nn.Variable, nn.Variable, nn.Vari # block1 h = NPF.affine(h, n_outmaps=1024, name="linear_1", w_init=RI.GlorotUniform(inmaps=h.shape[1], outmaps=1024)) - gate_h = NPF.affine(g_z, n_outmaps=64, name="gate_linear1", - w_init=RI.GlorotUniform(inmaps=g_z.shape[1], outmaps=64)) + gate_h = NPF.affine( + g_z, n_outmaps=64, name="gate_linear1", w_init=RI.GlorotUniform(inmaps=g_z.shape[1], outmaps=64) + ) gate_h = NF.relu(gate_h) - gate_h_bias = NPF.affine(gate_h, - n_outmaps=1024, - name="gate_bias_linear1", - w_init=RI.GlorotUniform(inmaps=gate_h.shape[1], outmaps=1024)) - gate_h_scale = NPF.affine(gate_h, - n_outmaps=1024, - name="gate_scale_linear1", - w_init=RI.GlorotUniform(inmaps=gate_h.shape[1], outmaps=1024)) + gate_h_bias = NPF.affine( + gate_h, + n_outmaps=1024, + name="gate_bias_linear1", + w_init=RI.GlorotUniform(inmaps=gate_h.shape[1], outmaps=1024), + ) + gate_h_scale = NPF.affine( + gate_h, + n_outmaps=1024, + name="gate_scale_linear1", + w_init=RI.GlorotUniform(inmaps=gate_h.shape[1], outmaps=1024), + ) h = h * 2.0 * NF.sigmoid(gate_h_scale) + gate_h_bias h = NF.relu(h) @@ -89,17 +99,22 @@ def v(self, s: Tuple[nn.Variable, nn.Variable, nn.Variable, nn.Variable, nn.Vari # block2 h = NPF.affine(h, n_outmaps=512, name="linear_2", w_init=RI.GlorotUniform(inmaps=h.shape[1], outmaps=512)) - gate_h = NPF.affine(g_z, n_outmaps=64, name="gate_linear2", - w_init=RI.GlorotUniform(inmaps=g_z.shape[1], outmaps=64)) + gate_h = NPF.affine( + g_z, n_outmaps=64, name="gate_linear2", w_init=RI.GlorotUniform(inmaps=g_z.shape[1], outmaps=64) + ) gate_h = NF.relu(gate_h) - gate_h_bias = NPF.affine(gate_h, - n_outmaps=512, - name="gate_bias_linear2", - w_init=RI.GlorotUniform(inmaps=gate_h.shape[1], outmaps=512)) - gate_h_scale = NPF.affine(gate_h, - n_outmaps=512, - name="gate_scale_linear2", - w_init=RI.GlorotUniform(inmaps=gate_h.shape[1], outmaps=512)) + gate_h_bias = NPF.affine( + gate_h, + n_outmaps=512, + name="gate_bias_linear2", + w_init=RI.GlorotUniform(inmaps=gate_h.shape[1], outmaps=512), + ) + gate_h_scale = NPF.affine( + gate_h, + n_outmaps=512, + name="gate_scale_linear2", + w_init=RI.GlorotUniform(inmaps=gate_h.shape[1], outmaps=512), + ) h = h * 2.0 * NF.sigmoid(gate_h_scale) + gate_h_bias h = NF.relu(h) diff --git a/nnabla_rl/models/q_function.py b/nnabla_rl/models/q_function.py index 54ddf7fc..58ec972d 100644 --- a/nnabla_rl/models/q_function.py +++ b/nnabla_rl/models/q_function.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ class QFunction(Model, metaclass=ABCMeta): """Base QFunction Class.""" + @abstractmethod def q(self, s: nn.Variable, a: nn.Variable) -> nn.Variable: """Compute Q-value for given state and action. 
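Editor's note: the `QFunction` hunk above defines the `q(s, a)` interface; the continuous-action Q-functions in nnabla_rl/models/mujoco/q_functions.py earlier in the patch implement it by concatenating state and action and stacking HeUniform-initialized affine layers with `factor=1/3`. A compressed sketch of that pattern is below; the import paths and the assumption that `q()` is the only abstract method to override are inferred from the surrounding hunks, and the class/scope names are illustrative.

```python
import nnabla as nn
import nnabla.functions as NF
import nnabla.parametric_functions as NPF
import nnabla_rl.initializers as RI  # assumed alias
from nnabla_rl.models.q_function import ContinuousQFunction  # assumed import path


class MyQFunction(ContinuousQFunction):
    """Minimal continuous Q-function following the pattern used throughout this patch."""

    def q(self, s: nn.Variable, a: nn.Variable) -> nn.Variable:
        with nn.parameter_scope(self.scope_name):
            # Concatenate state and action, then two HeUniform(factor=1/3) affine layers.
            h = NF.concatenate(s, a)
            init = RI.HeUniform(inmaps=h.shape[1], outmaps=256, factor=1 / 3)
            h = NPF.affine(h, n_outmaps=256, name="linear1", w_init=init, b_init=init)
            h = NF.relu(x=h)
            init = RI.HeUniform(inmaps=h.shape[1], outmaps=1, factor=1 / 3)
            return NPF.affine(h, n_outmaps=1, name="linear2", w_init=init, b_init=init)


# Usage sketch (hypothetical variables):
#   my_q = MyQFunction("my_q"); q_value = my_q.q(state_variable, action_variable)
```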
@@ -72,15 +73,16 @@ def argmax_q(self, s: nn.Variable) -> nn.Variable: class DiscreteQFunction(QFunction): """Base QFunction Class for discrete action environment.""" + @abstractmethod def all_q(self, s: nn.Variable) -> nn.Variable: raise NotImplementedError def q(self, s: nn.Variable, a: nn.Variable) -> nn.Variable: q_values = self.all_q(s) - q_value = NF.sum(q_values * NF.one_hot(NF.reshape(a, (-1, 1), inplace=False), (q_values.shape[1],)), - axis=1, - keepdims=True) # get q value of a + q_value = NF.sum( + q_values * NF.one_hot(NF.reshape(a, (-1, 1), inplace=False), (q_values.shape[1],)), axis=1, keepdims=True + ) # get q value of a return q_value @@ -95,12 +97,14 @@ def argmax_q(self, s: nn.Variable) -> nn.Variable: class ContinuousQFunction(QFunction): """Base QFunction Class for continuous action environment.""" + pass class FactoredContinuousQFunction(ContinuousQFunction): """Base FactoredContinuousQFunction Class for continuous action environment.""" + @abstractmethod def factored_q(self, s: nn.Variable, a: nn.Variable) -> nn.Variable: """Compute factored Q-value for given state. diff --git a/nnabla_rl/numpy_model_trainers/distribution_parameters/gmm_parameter_trainer.py b/nnabla_rl/numpy_model_trainers/distribution_parameters/gmm_parameter_trainer.py index e2373720..84af1df1 100644 --- a/nnabla_rl/numpy_model_trainers/distribution_parameters/gmm_parameter_trainer.py +++ b/nnabla_rl/numpy_model_trainers/distribution_parameters/gmm_parameter_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -43,9 +43,7 @@ class GMMParameterTrainer(NumpyModelTrainer): _parameter: GMMParameter _config: GMMParameterTrainerConfig - def __init__(self, - parameter: GMMParameter, - config: GMMParameterTrainerConfig = GMMParameterTrainerConfig()): + def __init__(self, parameter: GMMParameter, config: GMMParameterTrainerConfig = GMMParameterTrainerConfig()): super(GMMParameterTrainer, self).__init__(config) self._parameter = parameter @@ -55,7 +53,7 @@ def update(self, data: np.ndarray) -> None: Args: data (np.ndarray): data, shape(num_data, dim) """ - prev_log_likelihood = -float('inf') + prev_log_likelihood = -float("inf") for _ in range(self._config.num_iterations_per_update): probs, responsibility = self._e_step(data) @@ -90,7 +88,7 @@ def _m_step(self, data: np.ndarray, responsibility: np.ndarray) -> None: def _has_converged(self, new_log_likelihood, prev_log_likelihood): if np.abs(new_log_likelihood - prev_log_likelihood) < self._config.threshold: - logger.debug('GMM converged before reaching max iterations') + logger.debug("GMM converged before reaching max iterations") return True else: return False diff --git a/nnabla_rl/numpy_models/cost_function.py b/nnabla_rl/numpy_models/cost_function.py index 420c7f5b..4f6157e8 100644 --- a/nnabla_rl/numpy_models/cost_function.py +++ b/nnabla_rl/numpy_models/cost_function.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
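Editor's note: `DiscreteQFunction.q` in the hunk above selects Q(s, a) out of `all_q(s)` by multiplying with a one-hot encoding of the action index and summing over the action axis. A small NumPy illustration of the same indexing (no nnabla graph involved) makes the shape handling explicit; the concrete numbers are made up.

```python
import numpy as np

# all_q: Q-values for every discrete action, shape (batch, num_actions).
all_q = np.array([[1.0, 5.0, -2.0],
                  [0.5, 0.1, 3.0]])
# a: chosen action index per sample, shape (batch, 1).
a = np.array([[1], [2]])

# One-hot mask of shape (batch, num_actions), playing the role of NF.one_hot in the patch.
one_hot = np.eye(all_q.shape[1])[a.reshape(-1)]

# Multiply and sum over the action axis, keeping dims: result has shape (batch, 1).
q_value = np.sum(all_q * one_hot, axis=1, keepdims=True)
assert np.allclose(q_value, [[5.0], [3.0]])
```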
@@ -26,7 +26,7 @@ def __init__(self) -> None: @abstractmethod def evaluate( - self, x: np.ndarray, u: Optional[np.ndarray], t: int, final_state: bool = False, batched: bool = False + self, x: np.ndarray, u: Optional[np.ndarray], t: int, final_state: bool = False, batched: bool = False ) -> np.ndarray: """Evaluate cost for given state and action. @@ -94,7 +94,7 @@ def hessian( def __add__(self, o): if not isinstance(o, CostFunction): - raise ValueError('Only cost function can be added together') + raise ValueError("Only cost function can be added together") return SumCost([self, o]) diff --git a/nnabla_rl/numpy_models/distribution_parameters/gmm_parameter.py b/nnabla_rl/numpy_models/distribution_parameters/gmm_parameter.py index 91b312cd..496be02a 100644 --- a/nnabla_rl/numpy_models/distribution_parameters/gmm_parameter.py +++ b/nnabla_rl/numpy_models/distribution_parameters/gmm_parameter.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,19 +24,17 @@ class GMMParameter(DistributionParameter): _covarinces: np.ndarray _mixing_coefficients: np.ndarray - def __init__(self, means: np.ndarray, - covariances: np.ndarray, - mixing_coefficients: np.ndarray) -> None: + def __init__(self, means: np.ndarray, covariances: np.ndarray, mixing_coefficients: np.ndarray) -> None: super().__init__() self._num_classes, self._dim = means.shape assert (self._num_classes, self._dim, self._dim) == covariances.shape - assert (self._num_classes, ) == mixing_coefficients.shape + assert (self._num_classes,) == mixing_coefficients.shape self._means = means self._covariances = covariances self._mixing_coefficients = mixing_coefficients @staticmethod - def from_data(data: np.ndarray, num_classes: int) -> 'GMMParameter': + def from_data(data: np.ndarray, num_classes: int) -> "GMMParameter": """Create GMM from data by random class assignnment. Args: @@ -66,10 +64,9 @@ def from_data(data: np.ndarray, num_classes: int) -> 'GMMParameter': return GMMParameter(means, covariances, mixing_coefficients) - def update_parameter(self, # type: ignore - new_means: np.ndarray, - new_covariances: np.ndarray, - new_mixing_coefficients: np.ndarray) -> None: + def update_parameter( # type: ignore + self, new_means: np.ndarray, new_covariances: np.ndarray, new_mixing_coefficients: np.ndarray + ) -> None: self._means = new_means self._covariances = new_covariances self._mixing_coefficients = new_mixing_coefficients diff --git a/nnabla_rl/numpy_models/dynamics.py b/nnabla_rl/numpy_models/dynamics.py index e214e937..6108b12e 100644 --- a/nnabla_rl/numpy_models/dynamics.py +++ b/nnabla_rl/numpy_models/dynamics.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,8 +33,9 @@ def action_dim(self) -> int: raise NotImplementedError @abstractmethod - def next_state(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, Dict[str, Any]]: + def next_state( + self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False + ) -> Tuple[np.ndarray, Dict[str, Any]]: """Predict next state. if the dynamics is probabilistic, will return the mean of the next_state. 
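Editor's note: the `GMMParameter` hunk above and the `GMMParameterTrainer` hunk just before it implement a standard EM loop — `from_data` initializes means, covariances, and mixing coefficients from a random class assignment, and `update` alternates the E-step and M-step until the log-likelihood improvement drops below the configured threshold. A minimal usage sketch follows; the import paths are assumed from the file names touched in this patch and the toy data is made up.

```python
import numpy as np

# Assumed import paths, taken from the files touched in this patch.
from nnabla_rl.numpy_models.distribution_parameters.gmm_parameter import GMMParameter
from nnabla_rl.numpy_model_trainers.distribution_parameters.gmm_parameter_trainer import (
    GMMParameterTrainer,
    GMMParameterTrainerConfig,
)

# Toy data: two well-separated 2D clusters, shape (num_data, dim).
rng = np.random.default_rng(0)
data = np.concatenate([rng.normal(-3.0, 0.5, size=(100, 2)),
                       rng.normal(3.0, 0.5, size=(100, 2))])

# Random class assignment initializes the mixture parameters.
parameter = GMMParameter.from_data(data, num_classes=2)

# Run EM until convergence (or the configured number of iterations).
trainer = GMMParameterTrainer(parameter=parameter, config=GMMParameterTrainerConfig())
trainer.update(data)
```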
@@ -56,8 +57,7 @@ def next_state(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False """ raise NotImplementedError - def gradient(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, np.ndarray]: + def gradient(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) -> Tuple[np.ndarray, np.ndarray]: """Gradient of the dynamics with respect to the state and action. .. math:: @@ -82,8 +82,9 @@ def gradient(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) """ raise NotImplementedError - def hessian(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + def hessian( + self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Hessian of the dynamics with respect to the state and action. .. math:: diff --git a/nnabla_rl/parametric_functions.py b/nnabla_rl/parametric_functions.py index d080fe0c..d5d15f7d 100644 --- a/nnabla_rl/parametric_functions.py +++ b/nnabla_rl/parametric_functions.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,22 +27,24 @@ from nnabla_rl.initializers import HeUniform -def noisy_net(inp: nn.Variable, - n_outmap: int, - base_axis: int = 1, - w_init: Optional[Callable[[Tuple[int, ...]], np.ndarray]] = None, - b_init: Optional[Callable[[Tuple[int, ...]], np.ndarray]] = None, - noisy_w_init: Optional[Callable[[Tuple[int, ...]], np.ndarray]] = None, - noisy_b_init: Optional[Callable[[Tuple[int, ...]], np.ndarray]] = None, - fix_parameters: bool = False, - rng: Optional[np.random.RandomState] = None, - with_bias: bool = True, - with_noisy_bias: bool = True, - apply_w: Optional[Callable[[nn.Variable], nn.Variable]] = None, - apply_b: Optional[Callable[[nn.Variable], nn.Variable]] = None, - apply_noisy_w: Optional[Callable[[nn.Variable], nn.Variable]] = None, - apply_noisy_b: Optional[Callable[[nn.Variable], nn.Variable]] = None, - seed: int = -1) -> nn.Variable: +def noisy_net( + inp: nn.Variable, + n_outmap: int, + base_axis: int = 1, + w_init: Optional[Callable[[Tuple[int, ...]], np.ndarray]] = None, + b_init: Optional[Callable[[Tuple[int, ...]], np.ndarray]] = None, + noisy_w_init: Optional[Callable[[Tuple[int, ...]], np.ndarray]] = None, + noisy_b_init: Optional[Callable[[Tuple[int, ...]], np.ndarray]] = None, + fix_parameters: bool = False, + rng: Optional[np.random.RandomState] = None, + with_bias: bool = True, + with_noisy_bias: bool = True, + apply_w: Optional[Callable[[nn.Variable], nn.Variable]] = None, + apply_b: Optional[Callable[[nn.Variable], nn.Variable]] = None, + apply_noisy_w: Optional[Callable[[nn.Variable], nn.Variable]] = None, + apply_noisy_b: Optional[Callable[[nn.Variable], nn.Variable]] = None, + seed: int = -1, +) -> nn.Variable: """Noisy linear layer with factorized gaussian noise proposed by Fortunato et al. in the paper "Noisy networks for exploration". See: https://arxiv.org/abs/1706.10295 for details. 
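Editor's note: `noisy_net` above is the factorized-Gaussian noisy linear layer of Fortunato et al. (https://arxiv.org/abs/1706.10295); its defaults, visible in the next hunk, fall back to `HeUniform(factor=1/3)` for the deterministic weights and a constant `0.5 / sqrt(inmaps)` initializer for the noise parameters. Below is a usage sketch that swaps it in for a plain affine layer in a Q-network head; the `RPF` alias and the toy shapes are assumptions.

```python
import numpy as np
import nnabla as nn
import nnabla.functions as NF
import nnabla_rl.parametric_functions as RPF  # assumed alias for the module patched above

s = nn.Variable.from_numpy_array(np.zeros((32, 64), dtype=np.float32))

# Each layer gets its own parameter scope because noisy_net creates its
# parameters ("W", "b", "noisy_b", ...) directly in the current scope.
with nn.parameter_scope("noisy1"):
    h = NF.relu(RPF.noisy_net(s, n_outmap=256))
with nn.parameter_scope("noisy2"):
    q_values = RPF.noisy_net(h, n_outmap=4)

assert q_values.shape == (32, 4)
```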
@@ -83,7 +85,7 @@ def noisy_net(inp: nn.Variable, inmaps = int(np.prod(inp.shape[base_axis:])) if w_init is None: - w_init = HeUniform(inmaps, n_outmap, factor=1.0/3.0, rng=rng) + w_init = HeUniform(inmaps, n_outmap, factor=1.0 / 3.0, rng=rng) if noisy_w_init is None: noisy_w_init = ConstantInitializer(0.5 / np.sqrt(inmaps)) w = get_parameter_or_create("W", (inmaps, n_outmap), w_init, True, not fix_parameters) @@ -97,8 +99,8 @@ def noisy_net(inp: nn.Variable, b = None if with_bias: if b_init is None: - b_init = HeUniform(inmaps, n_outmap, factor=1.0/3.0, rng=rng) - b = get_parameter_or_create("b", (n_outmap, ), b_init, True, not fix_parameters) + b_init = HeUniform(inmaps, n_outmap, factor=1.0 / 3.0, rng=rng) + b = get_parameter_or_create("b", (n_outmap,), b_init, True, not fix_parameters) if apply_b is not None: b = apply_b(b) @@ -106,7 +108,7 @@ def noisy_net(inp: nn.Variable, if with_noisy_bias: if noisy_b_init is None: noisy_b_init = ConstantInitializer(0.5 / np.sqrt(inmaps)) - noisy_b = get_parameter_or_create("noisy_b", (n_outmap, ), noisy_b_init, True, not fix_parameters) + noisy_b = get_parameter_or_create("noisy_b", (n_outmap,), noisy_b_init, True, not fix_parameters) if apply_noisy_b is not None: noisy_b = apply_noisy_b(noisy_b) @@ -142,7 +144,7 @@ def _f(x): return NF.affine(inp, weight, bias, base_axis) -def spatial_softmax(inp: nn.Variable, alpha_init: float = 1., fix_alpha: bool = False) -> nn.Variable: +def spatial_softmax(inp: nn.Variable, alpha_init: float = 1.0, fix_alpha: bool = False) -> nn.Variable: r"""Spatial softmax layer proposed in https://arxiv.org/abs/1509.06113. Computes. @@ -168,33 +170,36 @@ def spatial_softmax(inp: nn.Variable, alpha_init: float = 1., fix_alpha: bool = """ assert len(inp.shape) == 4 (batch_size, channel, height, width) = inp.shape - alpha = get_parameter_or_create("alpha", shape=(1, 1), initializer=ConstantInitializer(alpha_init), - need_grad=True, as_need_grad=not fix_alpha) + alpha = get_parameter_or_create( + "alpha", shape=(1, 1), initializer=ConstantInitializer(alpha_init), need_grad=True, as_need_grad=not fix_alpha + ) - features = NF.reshape(inp, (-1, height*width)) + features = NF.reshape(inp, (-1, height * width)) softmax_attention = NF.softmax(features / alpha) # Image positions are normalized and defined by -1 to 1. # This normalization is referring to the original Guided Policy Search implementation. 
# See: https://github.com/cbfinn/gps/blob/master/python/gps/algorithm/policy_opt/tf_model_example.py#L238 - pos_x, pos_y = np.meshgrid(np.linspace(-1., 1., height), np.linspace(-1., 1., width)) - pos_x = nn.Variable.from_numpy_array(pos_x.reshape(-1, (height*width))) - pos_y = nn.Variable.from_numpy_array(pos_y.reshape(-1, (height*width))) + pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, height), np.linspace(-1.0, 1.0, width)) + pos_x = nn.Variable.from_numpy_array(pos_x.reshape(-1, (height * width))) + pos_y = nn.Variable.from_numpy_array(pos_y.reshape(-1, (height * width))) - expected_x = NF.sum(pos_x*softmax_attention, axis=1, keepdims=True) - expected_y = NF.sum(pos_y*softmax_attention, axis=1, keepdims=True) + expected_x = NF.sum(pos_x * softmax_attention, axis=1, keepdims=True) + expected_y = NF.sum(pos_y * softmax_attention, axis=1, keepdims=True) expected_xy = NF.concatenate(expected_x, expected_y, axis=1) - feature_points = NF.reshape(expected_xy, (batch_size, channel*2)) + feature_points = NF.reshape(expected_xy, (batch_size, channel * 2)) return feature_points -@parametric_function_api("lstm", [ - ('affine/W', 'Stacked weight matrixes of LSTM block', - '(inmaps, 4, state_size)', True), - ('affine/b', 'Stacked bias vectors of LSTM block', '(4, state_size,)', True), -]) +@parametric_function_api( + "lstm", + [ + ("affine/W", "Stacked weight matrixes of LSTM block", "(inmaps, 4, state_size)", True), + ("affine/b", "Stacked bias vectors of LSTM block", "(4, state_size,)", True), + ], +) def lstm_cell(x, h, c, state_size, w_init=None, b_init=None, fix_parameters=False, base_axis=1): """Long Short-Term Memory with base_axis. @@ -230,22 +235,27 @@ def lstm_cell(x, h, c, state_size, w_init=None, b_init=None, fix_parameters=Fals :class:`~nnabla.Variable` """ xh = NF.concatenate(*(x, h), axis=base_axis) - iofc = NPF.affine(xh, (4, state_size), base_axis=base_axis, w_init=w_init, - b_init=b_init, fix_parameters=fix_parameters) + iofc = NPF.affine( + xh, (4, state_size), base_axis=base_axis, w_init=w_init, b_init=b_init, fix_parameters=fix_parameters + ) i_t, o_t, f_t, gate = NF.split(iofc, axis=base_axis) c_t = NF.sigmoid(f_t) * c + NF.sigmoid(i_t) * NF.tanh(gate) h_t = NF.sigmoid(o_t) * NF.tanh(c_t) return h_t, c_t -def causal_self_attention(x: nn.Variable, embed_dim: int, num_heads: int, - mask: Optional[nn.Variable] = None, - attention_dropout: Optional[float] = None, - output_dropout: Optional[float] = None, - w_init_key: Optional[Callable[[Any], Any]] = NI.NormalInitializer(0.02), - w_init_query: Optional[Callable[[Any], Any]] = NI.NormalInitializer(0.02), - w_init_value: Optional[Callable[[Any], Any]] = NI.NormalInitializer(0.02), - w_init_proj: Optional[Callable[[Any], Any]] = NI.NormalInitializer(0.02)) -> nn.Variable: +def causal_self_attention( + x: nn.Variable, + embed_dim: int, + num_heads: int, + mask: Optional[nn.Variable] = None, + attention_dropout: Optional[float] = None, + output_dropout: Optional[float] = None, + w_init_key: Optional[Callable[[Any], Any]] = NI.NormalInitializer(0.02), + w_init_query: Optional[Callable[[Any], Any]] = NI.NormalInitializer(0.02), + w_init_value: Optional[Callable[[Any], Any]] = NI.NormalInitializer(0.02), + w_init_proj: Optional[Callable[[Any], Any]] = NI.NormalInitializer(0.02), +) -> nn.Variable: """Causal self attention used in https://arxiv.org/abs/2106.01345. 
Args: @@ -265,15 +275,15 @@ def causal_self_attention(x: nn.Variable, embed_dim: int, num_heads: int, nn.Variables: Encoded vector """ batch_size, timesteps, _ = x.shape - with nn.parameter_scope('key'): + with nn.parameter_scope("key"): k = NPF.affine(x, n_outmaps=embed_dim, base_axis=2, w_init=w_init_key) k = NF.reshape(k, shape=(batch_size, timesteps, num_heads, embed_dim // num_heads)) k = NF.transpose(k, axes=(0, 2, 1, 3)) - with nn.parameter_scope('query'): + with nn.parameter_scope("query"): q = NPF.affine(x, n_outmaps=embed_dim, base_axis=2, w_init=w_init_query) q = NF.reshape(q, shape=(batch_size, timesteps, num_heads, embed_dim // num_heads)) q = NF.transpose(q, axes=(0, 2, 1, 3)) - with nn.parameter_scope('value'): + with nn.parameter_scope("value"): v = NPF.affine(x, n_outmaps=embed_dim, base_axis=2, w_init=w_init_value) v = NF.reshape(v, shape=(batch_size, timesteps, num_heads, embed_dim // num_heads)) v = NF.transpose(v, axes=(0, 2, 1, 3)) @@ -296,7 +306,7 @@ def causal_self_attention(x: nn.Variable, embed_dim: int, num_heads: int, output = NF.reshape(output, shape=(batch_size, timesteps, -1)) assert output.shape == (batch_size, timesteps, embed_dim) - with nn.parameter_scope('proj'): + with nn.parameter_scope("proj"): output = NPF.affine(output, n_outmaps=embed_dim, base_axis=2, w_init=w_init_proj) if output_dropout is not None: output = NF.dropout(output, p=output_dropout) diff --git a/nnabla_rl/preprocessors/her_preprocessor.py b/nnabla_rl/preprocessors/her_preprocessor.py index 86bd4d82..38d3616f 100644 --- a/nnabla_rl/preprocessors/her_preprocessor.py +++ b/nnabla_rl/preprocessors/her_preprocessor.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
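For reference, a minimal usage sketch of the `causal_self_attention` signature reformatted above (illustration only; the batch size, sequence length, embedding width, and the parameter scope name below are arbitrary choices, not taken from this diff):

```python
import numpy as np
import nnabla as nn
from nnabla_rl.parametric_functions import causal_self_attention

batch_size, timesteps, embed_dim = 4, 16, 64
x = nn.Variable.from_numpy_array(
    np.random.randn(batch_size, timesteps, embed_dim).astype(np.float32)
)

with nn.parameter_scope("attention_block"):  # hypothetical scope name
    # embed_dim must be divisible by num_heads: the layer splits the embedding across heads
    out = causal_self_attention(x, embed_dim=embed_dim, num_heads=4)

out.forward()
print(out.shape)  # (4, 16, 64): same (batch, timesteps, embed_dim) layout as the input
```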
@@ -26,7 +26,7 @@ def __init__(self, scope_name, shape, epsilon=1e-8, value_clip=None): def process(self, x): assert 0 < self._count.d - std = NF.maximum2(self._var ** 0.5, self._fixed_epsilon) + std = NF.maximum2(self._var**0.5, self._fixed_epsilon) normalized = (x - self._mean) / std if self._value_clip is not None: normalized = NF.clip_by_value(normalized, min=self._value_clip[0], max=self._value_clip[1]) @@ -35,7 +35,7 @@ def process(self, x): @property def _fixed_epsilon(self): - if not hasattr(self, '_epsilon_var'): + if not hasattr(self, "_epsilon_var"): self._epsilon_var = create_variable(batch_size=1, shape=self._shape) self._epsilon_var.d = self._epsilon return self._epsilon_var @@ -46,14 +46,12 @@ def __init__(self, scope_name, shape, epsilon=1e-8, value_clip=None): super(HERPreprocessor, self).__init__(scope_name) observation_shape, goal_shape, _ = shape - self._observation_preprocessor = HERMeanNormalizer(scope_name=f'{scope_name}/observation', - shape=observation_shape, - epsilon=epsilon, - value_clip=value_clip) - self._goal_preprocessor = HERMeanNormalizer(scope_name=f'{scope_name}/goal', - shape=goal_shape, - epsilon=epsilon, - value_clip=value_clip) + self._observation_preprocessor = HERMeanNormalizer( + scope_name=f"{scope_name}/observation", shape=observation_shape, epsilon=epsilon, value_clip=value_clip + ) + self._goal_preprocessor = HERMeanNormalizer( + scope_name=f"{scope_name}/goal", shape=goal_shape, epsilon=epsilon, value_clip=value_clip + ) def process(self, x): observation, goal, achived_goal = x diff --git a/nnabla_rl/preprocessors/running_mean_normalizer.py b/nnabla_rl/preprocessors/running_mean_normalizer.py index e1026494..93bace9d 100644 --- a/nnabla_rl/preprocessors/running_mean_normalizer.py +++ b/nnabla_rl/preprocessors/running_mean_normalizer.py @@ -53,19 +53,20 @@ class RunningMeanNormalizer(Preprocessor, Model): The computation of a running variance is started from this value. Defaults to NI.ConstantInitializer(1.0). 
""" - def __init__(self, - scope_name: str, - shape: Shape, - epsilon: float = 1e-8, - value_clip: Optional[Tuple[float, float]] = None, - mode_for_floating_point_error: str = "add", - mean_initializer: Union[NI.BaseInitializer, np.ndarray] = NI.ConstantInitializer(0.0), - var_initializer: Union[NI.BaseInitializer, np.ndarray] = NI.ConstantInitializer(1.0)): + def __init__( + self, + scope_name: str, + shape: Shape, + epsilon: float = 1e-8, + value_clip: Optional[Tuple[float, float]] = None, + mode_for_floating_point_error: str = "add", + mean_initializer: Union[NI.BaseInitializer, np.ndarray] = NI.ConstantInitializer(0.0), + var_initializer: Union[NI.BaseInitializer, np.ndarray] = NI.ConstantInitializer(1.0), + ): super(RunningMeanNormalizer, self).__init__(scope_name) if value_clip is not None and value_clip[0] > value_clip[1]: - raise ValueError( - f"Unexpected clipping value range: {value_clip[0]} > {value_clip[1]}") + raise ValueError(f"Unexpected clipping value range: {value_clip[0]} > {value_clip[1]}") self._value_clip = value_clip if isinstance(shape, int): @@ -117,17 +118,20 @@ def update(self, data): @property def _mean(self): with nn.parameter_scope(self.scope_name): - return nn.parameter.get_parameter_or_create(name='mean', shape=(1, *self._shape), - initializer=self._mean_initializer) + return nn.parameter.get_parameter_or_create( + name="mean", shape=(1, *self._shape), initializer=self._mean_initializer + ) @property def _var(self): with nn.parameter_scope(self.scope_name): - return nn.parameter.get_parameter_or_create(name='var', shape=(1, *self._shape), - initializer=self._var_initializer) + return nn.parameter.get_parameter_or_create( + name="var", shape=(1, *self._shape), initializer=self._var_initializer + ) @property def _count(self): with nn.parameter_scope(self.scope_name): - return nn.parameter.get_parameter_or_create(name='count', shape=(1, 1), - initializer=NI.ConstantInitializer(1e-4)) + return nn.parameter.get_parameter_or_create( + name="count", shape=(1, 1), initializer=NI.ConstantInitializer(1e-4) + ) diff --git a/nnabla_rl/replay_buffer.py b/nnabla_rl/replay_buffer.py index 14e69c55..f2e28772 100644 --- a/nnabla_rl/replay_buffer.py +++ b/nnabla_rl/replay_buffer.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -73,8 +73,9 @@ def append_all(self, experiences: Sequence[Experience]): for experience in experiences: self.append(experience) - def sample(self, num_samples: int = 1, num_steps: int = 1) \ - -> Tuple[Union[Sequence[Experience], Tuple[Sequence[Experience], ...]], Dict[str, Any]]: + def sample( + self, num_samples: int = 1, num_steps: int = 1 + ) -> Tuple[Union[Sequence[Experience], Tuple[Sequence[Experience], ...]], Dict[str, Any]]: """Randomly sample num_samples experiences from the replay buffer. 
Args: @@ -96,12 +97,13 @@ def sample(self, num_samples: int = 1, num_steps: int = 1) \ """ max_index = len(self) - num_steps + 1 if num_samples > max_index: - raise ValueError(f'num_samples: {num_samples} is greater than the size of buffer: {max_index}') + raise ValueError(f"num_samples: {num_samples} is greater than the size of buffer: {max_index}") indices = self._random_indices(num_samples=num_samples, max_index=max_index) return self.sample_indices(indices, num_steps=num_steps) - def sample_indices(self, indices: Sequence[int], num_steps: int = 1) \ - -> Tuple[Union[Sequence[Experience], Tuple[Sequence[Experience], ...]], Dict[str, Any]]: + def sample_indices( + self, indices: Sequence[int], num_steps: int = 1 + ) -> Tuple[Union[Sequence[Experience], Tuple[Sequence[Experience], ...]], Dict[str, Any]]: """Sample experiences for given indices from the replay buffer. Args: @@ -118,14 +120,14 @@ def sample_indices(self, indices: Sequence[int], num_steps: int = 1) \ ValueError: If indices are empty or num_steps is 0 or negative. """ if len(indices) == 0: - raise ValueError('Indices are empty') + raise ValueError("Indices are empty") if num_steps < 1: - raise ValueError(f'num_steps: {num_steps} should be greater than 0!') + raise ValueError(f"num_steps: {num_steps} should be greater than 0!") experiences: Union[Sequence[Experience], Tuple[Sequence[Experience], ...]] if num_steps == 1: experiences = [self.__getitem__(index) for index in indices] else: - experiences = tuple([self.__getitem__(index+i) for index in indices] for i in range(num_steps)) + experiences = tuple([self.__getitem__(index + i) for index in indices] for i in range(num_steps)) weights = np.ones([len(indices), 1]) return experiences, dict(weights=weights) diff --git a/nnabla_rl/replay_buffers/__init__.py b/nnabla_rl/replay_buffers/__init__.py index a8a6bc31..76c5d309 100644 --- a/nnabla_rl/replay_buffers/__init__.py +++ b/nnabla_rl/replay_buffers/__init__.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
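As a quick reference for the `sample()` / `sample_indices()` contract reformatted above, here is a minimal sketch with toy one-dimensional experiences (illustration only; the `(s, a, r, non_terminal, s_next, info)` layout follows the convention used elsewhere in this diff):

```python
import numpy as np
from nnabla_rl.replay_buffer import ReplayBuffer

buffer = ReplayBuffer(capacity=100)
for t in range(10):
    s = np.array([float(t)])
    s_next = np.array([float(t + 1)])
    # (state, action, reward, non_terminal, next_state, info)
    buffer.append((s, np.array([0.0]), 1.0, 1.0, s_next, {}))

# num_steps=1: a flat list of experiences plus an info dict holding uniform weights
experiences, info = buffer.sample(num_samples=4, num_steps=1)
print(len(experiences), info["weights"].shape)  # 4 (4, 1)

# num_steps>1: a tuple of num_steps lists, one list of consecutive experiences per step offset
multi_step, _ = buffer.sample_indices([0, 1, 2], num_steps=2)
print(len(multi_step), len(multi_step[0]))  # 2 3
```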
@@ -15,13 +15,17 @@ from nnabla_rl.replay_buffers.buffer_iterator import BufferIterator # noqa from nnabla_rl.replay_buffers.hindsight_replay_buffer import HindsightReplayBuffer # noqa -from nnabla_rl.replay_buffers.memory_efficient_atari_buffer import (MemoryEfficientAtariBuffer, # noqa - MemoryEfficientAtariTrajectoryBuffer, - ProportionalPrioritizedAtariBuffer, - RankBasedPrioritizedAtariBuffer) +from nnabla_rl.replay_buffers.memory_efficient_atari_buffer import ( # noqa + MemoryEfficientAtariBuffer, + MemoryEfficientAtariTrajectoryBuffer, + ProportionalPrioritizedAtariBuffer, + RankBasedPrioritizedAtariBuffer, +) from nnabla_rl.replay_buffers.decorable_replay_buffer import DecorableReplayBuffer # noqa from nnabla_rl.replay_buffers.replacement_sampling_replay_buffer import ReplacementSamplingReplayBuffer # noqa -from nnabla_rl.replay_buffers.prioritized_replay_buffer import (PrioritizedReplayBuffer, # noqa - ProportionalPrioritizedReplayBuffer, - RankBasedPrioritizedReplayBuffer) +from nnabla_rl.replay_buffers.prioritized_replay_buffer import ( # noqa + PrioritizedReplayBuffer, + ProportionalPrioritizedReplayBuffer, + RankBasedPrioritizedReplayBuffer, +) from nnabla_rl.replay_buffers.trajectory_replay_buffer import TrajectoryReplayBuffer # noqa diff --git a/nnabla_rl/replay_buffers/buffer_iterator.py b/nnabla_rl/replay_buffers/buffer_iterator.py index d9257037..1c669b28 100644 --- a/nnabla_rl/replay_buffers/buffer_iterator.py +++ b/nnabla_rl/replay_buffers/buffer_iterator.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -41,21 +41,19 @@ def next(self): if self.is_new_epoch(): self._new_epoch = False raise StopIteration - indices = \ - self._indices[self._index:self._index + self._batch_size] - if (len(indices) < self._batch_size): + indices = self._indices[self._index : self._index + self._batch_size] + if len(indices) < self._batch_size: if self._repeat: rest = self._batch_size - len(indices) self.reset() - indices = np.append( - indices, self._indices[self._index:self._index + rest]) + indices = np.append(indices, self._indices[self._index : self._index + rest]) self._index += rest else: self._index = len(self._replay_buffer) self._new_epoch = True else: self._index += self._batch_size - self._new_epoch = (len(self._replay_buffer) <= self._index) + self._new_epoch = len(self._replay_buffer) <= self._index return self._sample(indices) __next__ = next diff --git a/nnabla_rl/replay_buffers/hindsight_replay_buffer.py b/nnabla_rl/replay_buffers/hindsight_replay_buffer.py index 7bf50c32..3a2bf978 100644 --- a/nnabla_rl/replay_buffers/hindsight_replay_buffer.py +++ b/nnabla_rl/replay_buffers/hindsight_replay_buffer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -22,10 +22,12 @@ class HindsightReplayBuffer(ReplayBuffer): - def __init__(self, - reward_function: Callable[[np.ndarray, np.ndarray, Dict[str, Any]], Any], - hindsight_prob: float = 0.8, - capacity: Optional[int] = None): + def __init__( + self, + reward_function: Callable[[np.ndarray, np.ndarray, Dict[str, Any]], Any], + hindsight_prob: float = 0.8, + capacity: Optional[int] = None, + ): super(HindsightReplayBuffer, self).__init__(capacity=capacity) self._reward_function = reward_function self._hindsight_prob = hindsight_prob @@ -38,17 +40,16 @@ def __init__(self, def append(self, experience: Experience): # experience = (s, a, r, non_terminal, s_next, info) if not isinstance(experience[0], tuple): - raise RuntimeError('Hindsight replay only supports tuple observation environment') + raise RuntimeError("Hindsight replay only supports tuple observation environment") if not len(experience[0]) == 3: - raise RuntimeError('Observation is not a tuple of 3 elements: (observation, desired_goal, achieved_goal)') + raise RuntimeError("Observation is not a tuple of 3 elements: (observation, desired_goal, achieved_goal)") # Here, info will be updated. if not isinstance(experience[5], dict): raise ValueError non_terminal = experience[3] done = non_terminal == 0 self._episode_end_index[0] = self._index_in_episode # end index is shared among episode - update_info = {'index_in_episode': self._index_in_episode, - 'episode_end_index': self._episode_end_index} + update_info = {"index_in_episode": self._index_in_episode, "episode_end_index": self._episode_end_index} experience[5].update(update_info) super().append(experience) @@ -65,7 +66,7 @@ def sample_indices(self, indices: Sequence[int], num_steps: int = 1) -> Tuple[Se raise NotImplementedError if len(indices) == 0: - raise ValueError('Indices are empty') + raise ValueError("Indices are empty") weights = np.ones([len(indices), 1]) return [self._sample_experience(index) for index in indices], dict(weights=weights) @@ -80,8 +81,8 @@ def _make_hindsight_experience(self, index: int) -> Experience: # state = (observation, desired_goal, achieved_goal) experience = self.__getitem__(index) experience_info = experience[5] - index_in_episode = experience_info['index_in_episode'] - episode_end_index = int(experience_info['episode_end_index']) # NOTE: episode_end_index is saved as np.ndarray + index_in_episode = experience_info["index_in_episode"] + episode_end_index = int(experience_info["episode_end_index"]) # NOTE: episode_end_index is saved as np.ndarray distance_to_end = episode_end_index - index_in_episode # sample index for hindsight goal @@ -93,7 +94,7 @@ def _make_hindsight_experience(self, index: int) -> Experience: new_experience = self._replace_goal(experience, future_experience) # save for test - new_experience[-1].update({'future_index': future_index}) + new_experience[-1].update({"future_index": future_index}) return new_experience diff --git a/nnabla_rl/replay_buffers/memory_efficient_atari_buffer.py b/nnabla_rl/replay_buffers/memory_efficient_atari_buffer.py index 6b3bfff4..d6a0ebf2 100644 --- a/nnabla_rl/replay_buffers/memory_efficient_atari_buffer.py +++ b/nnabla_rl/replay_buffers/memory_efficient_atari_buffer.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
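The checks above require goal-conditioned observations, so a minimal wiring sketch may help (illustration only; the argument order of `reward_function` and the sparse-reward threshold are assumptions, not taken from this diff):

```python
import numpy as np
from nnabla_rl.replay_buffers import HindsightReplayBuffer

def reward_function(achieved_goal, desired_goal, info):
    # assumed sparse reward: 0 when close enough to the desired goal, otherwise -1
    return 0.0 if np.linalg.norm(achieved_goal - desired_goal) < 0.05 else -1.0

buffer = HindsightReplayBuffer(reward_function=reward_function, hindsight_prob=0.8, capacity=100000)

# Observations must be 3-element tuples: (observation, desired_goal, achieved_goal)
obs = (np.zeros(4), np.ones(2), np.zeros(2))
next_obs = (np.ones(4), np.ones(2), np.full(2, 0.5))
buffer.append((obs, np.array([0.1]), -1.0, 1.0, next_obs, {}))
```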
@@ -19,8 +19,10 @@ import numpy as np from nnabla_rl.replay_buffer import ReplayBuffer -from nnabla_rl.replay_buffers.prioritized_replay_buffer import (ProportionalPrioritizedReplayBuffer, - RankBasedPrioritizedReplayBuffer) +from nnabla_rl.replay_buffers.prioritized_replay_buffer import ( + ProportionalPrioritizedReplayBuffer, + RankBasedPrioritizedReplayBuffer, +) from nnabla_rl.replay_buffers.trajectory_replay_buffer import TrajectoryReplayBuffer from nnabla_rl.typing import Trajectory from nnabla_rl.utils.data import RingBuffer @@ -42,6 +44,7 @@ class MemoryEfficientAtariBuffer(ReplayBuffer): consists of "stacked_frames" number of concatenated grayscaled frames and its values are normalized between 0 and 1) """ + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See https://mypy.readthedocs.io/en/stable/class_basics.html for details @@ -52,7 +55,7 @@ def __init__(self, capacity: int, stacked_frames: int = 4): super(MemoryEfficientAtariBuffer, self).__init__(capacity=capacity) self._reset = True self._buffer = RingBuffer(maxlen=capacity) - self._sub_buffer = deque(maxlen=stacked_frames-1) + self._sub_buffer = deque(maxlen=stacked_frames - 1) self._stacked_frames = stacked_frames def append(self, experience): @@ -75,7 +78,7 @@ def __getitem__(self, key): elif isinstance(key, int): return self._buffer[key] else: - raise TypeError('Invalid key type') + raise TypeError("Invalid key type") class MemoryEfficientAtariTrajectoryBuffer(TrajectoryReplayBuffer): @@ -125,29 +128,34 @@ class ProportionalPrioritizedAtariBuffer(ProportionalPrioritizedReplayBuffer): concatenated grayscaled frames and its values are normalized between 0 and 1) """ + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See https://mypy.readthedocs.io/en/stable/class_basics.html for details _sub_buffer: deque - def __init__(self, - capacity: int, - alpha: float = 0.6, - beta: float = 0.4, - betasteps: int = 50000000, - error_clip: Optional[Tuple[float, float]] = (-1, 1), - epsilon: float = 1e-8, - normalization_method: str = "buffer_max", - stacked_frames: int = 4): - super(ProportionalPrioritizedAtariBuffer, self).__init__(capacity=capacity, - alpha=alpha, - beta=beta, - betasteps=betasteps, - error_clip=error_clip, - epsilon=epsilon, - normalization_method=normalization_method) + def __init__( + self, + capacity: int, + alpha: float = 0.6, + beta: float = 0.4, + betasteps: int = 50000000, + error_clip: Optional[Tuple[float, float]] = (-1, 1), + epsilon: float = 1e-8, + normalization_method: str = "buffer_max", + stacked_frames: int = 4, + ): + super(ProportionalPrioritizedAtariBuffer, self).__init__( + capacity=capacity, + alpha=alpha, + beta=beta, + betasteps=betasteps, + error_clip=error_clip, + epsilon=epsilon, + normalization_method=normalization_method, + ) self._reset = True - self._sub_buffer = deque(maxlen=stacked_frames-1) + self._sub_buffer = deque(maxlen=stacked_frames - 1) self._stacked_frames = stacked_frames def append(self, experience): @@ -167,29 +175,34 @@ class RankBasedPrioritizedAtariBuffer(RankBasedPrioritizedReplayBuffer): concatenated grayscaled frames and its values are normalized between 0 and 1) """ + # type declarations to type check with mypy # NOTE: declared variables are instance variable and NOT class variable, unless it is marked with ClassVar # See 
https://mypy.readthedocs.io/en/stable/class_basics.html for details _sub_buffer: deque - def __init__(self, - capacity: int, - alpha: float = 0.7, - beta: float = 0.5, - betasteps: int = 50000000, - error_clip: Optional[Tuple[float, float]] = (-1, 1), - reset_segment_interval: int = 1000, - sort_interval: int = 1000000, - stacked_frames: int = 4): - super(RankBasedPrioritizedAtariBuffer, self).__init__(capacity=capacity, - alpha=alpha, - beta=beta, - betasteps=betasteps, - error_clip=error_clip, - reset_segment_interval=reset_segment_interval, - sort_interval=sort_interval) + def __init__( + self, + capacity: int, + alpha: float = 0.7, + beta: float = 0.5, + betasteps: int = 50000000, + error_clip: Optional[Tuple[float, float]] = (-1, 1), + reset_segment_interval: int = 1000, + sort_interval: int = 1000000, + stacked_frames: int = 4, + ): + super(RankBasedPrioritizedAtariBuffer, self).__init__( + capacity=capacity, + alpha=alpha, + beta=beta, + betasteps=betasteps, + error_clip=error_clip, + reset_segment_interval=reset_segment_interval, + sort_interval=sort_interval, + ) self._reset = True - self._sub_buffer = deque(maxlen=stacked_frames-1) + self._sub_buffer = deque(maxlen=stacked_frames - 1) self._stacked_frames = stacked_frames def append(self, experience): @@ -228,7 +241,7 @@ def _append_to_buffer(experience, buffer, sub_buffer, reset_flag): removed = buffer.append_with_removed_item_check(experience) if removed is not None: sub_buffer.append(removed) - return (0 == non_terminal) + return 0 == non_terminal def _getitem_from_buffer(index, buffer, sub_buffer, stacked_frames): @@ -241,12 +254,12 @@ def _getitem_from_buffer(index, buffer, sub_buffer, stacked_frames): else: (s, _, _, _, _, _, reset) = sub_buffer[buffer_index] assert s.shape == (84, 84) - tail_index = stacked_frames-i + tail_index = stacked_frames - i if reset: states[0:tail_index] = s break else: - states[tail_index-1] = s + states[tail_index - 1] = s s = _normalize_state(states) assert s.shape == (stacked_frames, 84, 84) diff --git a/nnabla_rl/replay_buffers/prioritized_replay_buffer.py b/nnabla_rl/replay_buffers/prioritized_replay_buffer.py index 9d8eed03..341d8189 100644 --- a/nnabla_rl/replay_buffers/prioritized_replay_buffer.py +++ b/nnabla_rl/replay_buffers/prioritized_replay_buffer.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,7 +26,7 @@ from nnabla_rl.typing import Experience from nnabla_rl.utils.data import DataHolder, RingBuffer -T = TypeVar('T') +T = TypeVar("T") # NOTE: index naming convention used in this module @@ -35,6 +35,7 @@ # tree index: 0: root of the tree. 2 * capacity - 1: right most leaf of the tree. # heap index: 0: head of the heap. If max heap, maximum value is saved in this index. capacity - 1: tail of the heap. 
+ @dataclass class Node(Generic[T]): value: T @@ -62,7 +63,7 @@ def __init__(self, capacity: int, init_node_value: T): self._init_node_value = init_node_value self._tail_index = 0 self._length = 0 - self._tree = [self._make_init_node(i) for i in range(2*capacity-1)] + self._tree = [self._make_init_node(i) for i in range(2 * capacity - 1)] def __len__(self): return self._length @@ -343,7 +344,7 @@ def append_with_removed_item_check(self, data): def get_priority(self, relative_index: int): absolute_index = self._relative_to_absolute_index(relative_index) heap_index = self._max_heap.absolute_to_heap_index(absolute_index) - rank = (heap_index + 1) + rank = heap_index + 1 return self._compute_priority(rank) def get_relative_index_from_heap_index(self, heap_index: int): @@ -366,12 +367,9 @@ def _compute_priority(self, rank: int): class _PrioritizedReplayBuffer(ReplayBuffer): - def __init__(self, - capacity: int, - alpha: float, - beta: float, - betasteps: int, - error_clip: Optional[Tuple[float, float]]): + def __init__( + self, capacity: int, alpha: float, beta: float, betasteps: int, error_clip: Optional[Tuple[float, float]] + ): # Do not call super class' constructor self._capacity_check(capacity) self._capacity = capacity @@ -397,15 +395,17 @@ def sample(self, num_samples: int = 1, num_steps: int = 1): def sample_indices(self, indices: Sequence[int], num_steps: int = 1): if len(indices) == 0: - raise ValueError('Indices are empty') + raise ValueError("Indices are empty") if self._last_sampled_indices is not None: - raise RuntimeError('Trying to sample data from buffer without updating priority. ' - 'Check that the algorithm supports prioritized replay buffer.') + raise RuntimeError( + "Trying to sample data from buffer without updating priority. " + "Check that the algorithm supports prioritized replay buffer." 
+ ) experiences: Union[Sequence[Experience], Tuple[Sequence[Experience], ...]] if num_steps == 1: experiences = [self.__getitem__(index) for index in indices] else: - experiences = tuple([self.__getitem__(index+i) for index in indices] for i in range(num_steps)) + experiences = tuple([self.__getitem__(index + i) for index in indices] for i in range(num_steps)) weights = self._get_weights(indices, self._alpha, self._beta) info = dict(weights=weights) @@ -427,7 +427,7 @@ def _get_weights(self, indices: Sequence[int], alpha: float, beta: float): def _capacity_check(self, capacity: int): if capacity is None or capacity <= 0: - error_msg = 'buffer size must be greater than 0' + error_msg = "buffer size must be greater than 0" raise ValueError(error_msg) @@ -438,18 +438,21 @@ class ProportionalPrioritizedReplayBuffer(_PrioritizedReplayBuffer): _buffer: SumTreeDataHolder _epsilon: float - def __init__(self, capacity: int, - alpha: float = 0.6, - beta: float = 0.4, - betasteps: int = 10000, - error_clip: Optional[Tuple[float, float]] = (-1, 1), - epsilon: float = 1e-8, - init_max_error: float = 1.0, - normalization_method: str = "buffer_max"): + def __init__( + self, + capacity: int, + alpha: float = 0.6, + beta: float = 0.4, + betasteps: int = 10000, + error_clip: Optional[Tuple[float, float]] = (-1, 1), + epsilon: float = 1e-8, + init_max_error: float = 1.0, + normalization_method: str = "buffer_max", + ): super(ProportionalPrioritizedReplayBuffer, self).__init__(capacity, alpha, beta, betasteps, error_clip) assert normalization_method in ("batch_max", "buffer_max") self._normalization_method = normalization_method - keep_min = (self._normalization_method == "buffer_max") + keep_min = self._normalization_method == "buffer_max" self._buffer = SumTreeDataHolder(capacity=capacity, initial_max_priority=init_max_error, keep_min=keep_min) self._epsilon = epsilon @@ -459,10 +462,10 @@ def append(self, experience): def sample(self, num_samples: int = 1, num_steps: int = 1): buffer_length = len(self) if num_samples > buffer_length: - error_msg = f'num_samples: {num_samples} is greater than the size of buffer: {buffer_length}' + error_msg = f"num_samples: {num_samples} is greater than the size of buffer: {buffer_length}" raise ValueError(error_msg) if buffer_length - num_steps < 0: - raise RuntimeError(f'Insufficient buffer length. buffer: {buffer_length} < steps: {num_steps}') + raise RuntimeError(f"Insufficient buffer length. buffer: {buffer_length} < steps: {num_steps}") # In paper, # "To sample a minibatch of size k, the range [0, ptotal] is divided equally into k ranges. 
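For context, a minimal construction sketch of the proportional variant with the keyword arguments shown in this hunk (illustration only; the values are arbitrary). Note the RuntimeError above: after each sample, the algorithm is expected to update priorities before sampling again.

```python
from nnabla_rl.replay_buffers import ProportionalPrioritizedReplayBuffer

per_buffer = ProportionalPrioritizedReplayBuffer(
    capacity=1_000_000,
    alpha=0.6,           # how strongly the (clipped) TD error skews sampling
    beta=0.4,            # importance-sampling correction, annealed over betasteps
    betasteps=10_000,
    error_clip=(-1, 1),
    normalization_method="buffer_max",  # or "batch_max"
)
```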
@@ -510,13 +513,16 @@ class RankBasedPrioritizedReplayBuffer(_PrioritizedReplayBuffer): _prev_num_steps: int _appends_since_prev_start: int - def __init__(self, capacity: int, - alpha: float = 0.7, - beta: float = 0.5, - betasteps: int = 10000, - error_clip: Optional[Tuple[float, float]] = (-1, 1), - reset_segment_interval: int = 1000, - sort_interval: int = 1000000): + def __init__( + self, + capacity: int, + alpha: float = 0.7, + beta: float = 0.5, + betasteps: int = 10000, + error_clip: Optional[Tuple[float, float]] = (-1, 1), + reset_segment_interval: int = 1000, + sort_interval: int = 1000000, + ): super(RankBasedPrioritizedReplayBuffer, self).__init__(capacity, alpha, beta, betasteps, error_clip) self._buffer = MaxHeapDataHolder(capacity, alpha) @@ -540,15 +546,16 @@ def append(self, experience): def sample(self, num_samples: int = 1, num_steps: int = 1): buffer_length = len(self) if num_samples > buffer_length: - error_msg = f'num_samples: {num_samples} is greater than the size of buffer: {buffer_length}' + error_msg = f"num_samples: {num_samples} is greater than the size of buffer: {buffer_length}" raise ValueError(error_msg) if buffer_length - num_steps < 0: - raise RuntimeError( - f'Insufficient buffer length. buffer: {buffer_length} < steps: {num_steps}') - if (num_samples != self._prev_num_samples) or \ - (num_steps != self._prev_num_steps) or \ - (buffer_length % self._reset_segment_interval == 0 and buffer_length != self._capacity) or \ - (len(self._boundaries) == 0): + raise RuntimeError(f"Insufficient buffer length. buffer: {buffer_length} < steps: {num_steps}") + if ( + (num_samples != self._prev_num_samples) + or (num_steps != self._prev_num_steps) + or (buffer_length % self._reset_segment_interval == 0 and buffer_length != self._capacity) + or (len(self._boundaries) == 0) + ): self._boundaries = self._compute_segment_boundaries(N=buffer_length, k=num_samples) self._prev_num_samples = num_samples self._prev_num_steps = num_steps @@ -586,7 +593,7 @@ def _compute_segment_boundaries(self, N: int, k: int): if N < k: raise ValueError(f"Batch size {k} is greater than buffer size {N}") boundaries: List[int] = [] - denominator = self._ps_cumsum[N-1] + denominator = self._ps_cumsum[N - 1] for i in range(N): if (len(boundaries) + 1) / k <= self._ps_cumsum[i] / denominator: boundaries.append(i + 1) @@ -601,36 +608,37 @@ def _get_weights(self, indices: Sequence[int], alpha: float, beta: float): class PrioritizedReplayBuffer(ReplayBuffer): - _variants: ClassVar[Sequence[str]] = ['proportional', 'rank_based'] + _variants: ClassVar[Sequence[str]] = ["proportional", "rank_based"] _buffer_impl: _PrioritizedReplayBuffer - def __init__(self, - capacity: int, - alpha: float = 0.6, - beta: float = 0.4, - betasteps: int = 10000, - error_clip: Optional[Tuple[float, float]] = (-1, 1), - epsilon: float = 1e-8, - reset_segment_interval: int = 1000, - sort_interval: int = 1000000, - variant: str = 'proportional'): + def __init__( + self, + capacity: int, + alpha: float = 0.6, + beta: float = 0.4, + betasteps: int = 10000, + error_clip: Optional[Tuple[float, float]] = (-1, 1), + epsilon: float = 1e-8, + reset_segment_interval: int = 1000, + sort_interval: int = 1000000, + variant: str = "proportional", + ): if variant not in PrioritizedReplayBuffer._variants: - raise ValueError(f'Unknown prioritized replay buffer variant: {variant}') - if variant == 'proportional': - self._buffer_impl = ProportionalPrioritizedReplayBuffer(capacity=capacity, - alpha=alpha, - beta=beta, - betasteps=betasteps, - 
error_clip=error_clip, - epsilon=epsilon) - elif variant == 'rank_based': - self._buffer_impl = RankBasedPrioritizedReplayBuffer(capacity=capacity, - alpha=alpha, - beta=beta, - betasteps=betasteps, - error_clip=error_clip, - reset_segment_interval=reset_segment_interval, - sort_interval=sort_interval) + raise ValueError(f"Unknown prioritized replay buffer variant: {variant}") + if variant == "proportional": + self._buffer_impl = ProportionalPrioritizedReplayBuffer( + capacity=capacity, alpha=alpha, beta=beta, betasteps=betasteps, error_clip=error_clip, epsilon=epsilon + ) + elif variant == "rank_based": + self._buffer_impl = RankBasedPrioritizedReplayBuffer( + capacity=capacity, + alpha=alpha, + beta=beta, + betasteps=betasteps, + error_clip=error_clip, + reset_segment_interval=reset_segment_interval, + sort_interval=sort_interval, + ) else: raise NotImplementedError diff --git a/nnabla_rl/replay_buffers/trajectory_replay_buffer.py b/nnabla_rl/replay_buffers/trajectory_replay_buffer.py index f7b8fe07..046d5d2d 100644 --- a/nnabla_rl/replay_buffers/trajectory_replay_buffer.py +++ b/nnabla_rl/replay_buffers/trajectory_replay_buffer.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -66,22 +66,24 @@ def append_trajectory(self, trajectory: Trajectory): self._num_experiences = num_experiences self._cumsum_experiences = cumsum_experiences - def sample_indices(self, indices: Sequence[int], num_steps: int = 1) \ - -> Tuple[Union[Sequence[Experience], Tuple[Sequence[Experience], ...]], Dict[str, Any]]: + def sample_indices( + self, indices: Sequence[int], num_steps: int = 1 + ) -> Tuple[Union[Sequence[Experience], Tuple[Sequence[Experience], ...]], Dict[str, Any]]: if len(indices) == 0: - raise ValueError('Indices are empty') + raise ValueError("Indices are empty") if num_steps < 1: - raise ValueError(f'num_steps: {num_steps} should be greater than 0!') + raise ValueError(f"num_steps: {num_steps} should be greater than 0!") experiences: Union[Sequence[Experience], Tuple[Sequence[Experience], ...]] if num_steps == 1: experiences = [self._get_experience(index) for index in indices] else: - experiences = tuple([self._get_experience(index+i) for index in indices] for i in range(num_steps)) + experiences = tuple([self._get_experience(index + i) for index in indices] for i in range(num_steps)) weights = np.ones([len(indices), 1]) return experiences, dict(weights=weights) - def sample_trajectories(self, num_samples: int = 1) -> Tuple[Union[Trajectory, Tuple[Trajectory, ...]], - Dict[str, Any]]: + def sample_trajectories( + self, num_samples: int = 1 + ) -> Tuple[Union[Trajectory, Tuple[Trajectory, ...]], Dict[str, Any]]: """Randomly sample num_samples trajectories from the replay buffer. 
Args: @@ -95,14 +97,16 @@ def sample_trajectories(self, num_samples: int = 1) -> Tuple[Union[Trajectory, T max_index = self.trajectory_num if num_samples > max_index: raise ValueError( - f'num_samples: {num_samples} is greater than the number of trajectories saved in buffer: {max_index}') + f"num_samples: {num_samples} is greater than the number of trajectories saved in buffer: {max_index}" + ) indices = self._random_trajectory_indices(num_samples=num_samples, max_index=max_index) return self.sample_indices_trajectory(indices) - def sample_indices_trajectory(self, indices: Sequence[int]) \ - -> Tuple[Union[Trajectory, Tuple[Trajectory, ...]], Dict[str, Any]]: + def sample_indices_trajectory( + self, indices: Sequence[int] + ) -> Tuple[Union[Trajectory, Tuple[Trajectory, ...]], Dict[str, Any]]: if len(indices) == 0: - raise ValueError('Indices are empty') + raise ValueError("Indices are empty") trajectories: Union[Trajectory, Tuple[Trajectory, ...]] if len(indices) == 1: trajectories = self.get_trajectory(indices[0]) @@ -111,9 +115,9 @@ def sample_indices_trajectory(self, indices: Sequence[int]) \ weights = np.ones([len(indices), 1]) return trajectories, dict(weights=weights) - def sample_trajectories_portion(self, - num_samples: int = 1, - portion_length: int = 1) -> Tuple[Tuple[Trajectory, ...], Dict[str, Any]]: + def sample_trajectories_portion( + self, num_samples: int = 1, portion_length: int = 1 + ) -> Tuple[Tuple[Trajectory, ...], Dict[str, Any]]: """Randomly sample num_samples trajectories with length portion_length from the replay buffer. (i.e. Each trajectory length will be portion_length) Trajectory will be sampled as follows. First, a @@ -142,16 +146,17 @@ def sample_trajectories_portion(self, sliced_trajectories: MutableSequence[Trajectory] = [] for trajectory in trajectories: - max_index = len(trajectory)-portion_length + max_index = len(trajectory) - portion_length if max_index < 0: - raise RuntimeError(f'Trajectory length is shorter than portion length: {portion_length}') - initial_index = rl.random.drng.choice(max_index+1, replace=False) - sliced_trajectories.append(trajectory[initial_index:initial_index+portion_length]) + raise RuntimeError(f"Trajectory length is shorter than portion length: {portion_length}") + initial_index = rl.random.drng.choice(max_index + 1, replace=False) + sliced_trajectories.append(trajectory[initial_index : initial_index + portion_length]) weights = np.ones([len(trajectories), 1]) return tuple(sliced_trajectories), dict(weights=weights) - def sample_indices_portion(self, indices: Sequence[int], portion_length: int = 1) -> \ - Tuple[Tuple[Trajectory, ...], Dict[str, Any]]: + def sample_indices_portion( + self, indices: Sequence[int], portion_length: int = 1 + ) -> Tuple[Tuple[Trajectory, ...], Dict[str, Any]]: """Sample trajectory portions from the buffer. (i.e. Each trajectory length will be portion_length) Trajectory from given index to index+portion_length-1 will be sampled. Index should be the index of a @@ -177,23 +182,23 @@ def sample_indices_portion(self, indices: Sequence[int], portion_length: int = 1 RuntimeError: Trajectory's length is below portion_length. 
""" if len(indices) == 0: - raise ValueError('Indices are empty') + raise ValueError("Indices are empty") if portion_length < 1: - raise ValueError(f'portion_length: {portion_length} should be greater than 0!') + raise ValueError(f"portion_length: {portion_length} should be greater than 0!") sliced_trajectories: MutableSequence[Trajectory] = [] for index in indices: trajectory_index = np.argwhere(np.asarray(self._cumsum_experiences) > index)[0][0] trajectory = self.get_trajectory(trajectory_index) if len(trajectory) < portion_length: - raise RuntimeError(f'Trajectory length is shorter than portion length: {portion_length}') + raise RuntimeError(f"Trajectory length is shorter than portion length: {portion_length}") if 0 < trajectory_index: experience_index = index - self._cumsum_experiences[trajectory_index - 1] else: experience_index = index experience_index = min(experience_index, len(trajectory) - portion_length) - sliced_trajectories.append(trajectory[experience_index:experience_index+portion_length]) + sliced_trajectories.append(trajectory[experience_index : experience_index + portion_length]) weights = np.ones([len(indices), 1]) return tuple(sliced_trajectories), dict(weights=weights) @@ -222,4 +227,4 @@ def _get_experience(self, experience_index) -> Experience: experience: Experience = trajectory[experience_index - prev_cumsum] return experience prev_cumsum = cumsum - raise ValueError(f'index {experience_index} is out of range') + raise ValueError(f"index {experience_index} is out of range") diff --git a/nnabla_rl/typing.py b/nnabla_rl/typing.py index 743ff63a..f28e27a6 100644 --- a/nnabla_rl/typing.py +++ b/nnabla_rl/typing.py @@ -1,5 +1,5 @@ # Copyright 2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ import numpy as np -F = TypeVar('F', bound=Callable[..., Any]) +F = TypeVar("F", bound=Callable[..., Any]) State = Union[np.ndarray, Tuple[np.ndarray, ...]] @@ -76,6 +76,7 @@ def dummy_function(x, y, z, non_shape_args=False): def dummy_function(x, y, z, non_shape_args=False): pass """ + def accepted_shapes_wrapper(f: F) -> F: signature_f = signature(f) @@ -85,14 +86,14 @@ def wrapped_with_accepted_shapes(*args, **kwargs): return f(*args, **kwargs) return cast(F, wrapped_with_accepted_shapes) + return accepted_shapes_wrapper def _is_same_shape(actual_shape: Tuple[int], expected_shape: Tuple[int]) -> bool: if len(actual_shape) != len(expected_shape): return False - return all([actual == expected or expected is None - for actual, expected in zip(actual_shape, expected_shape)]) + return all([actual == expected or expected is None for actual, expected in zip(actual_shape, expected_shape)]) def _check_kwargs_shape(kwargs, expected_kwargs_shapes): diff --git a/nnabla_rl/utils/context.py b/nnabla_rl/utils/context.py index 55ef0301..64a1fd35 100644 --- a/nnabla_rl/utils/context.py +++ b/nnabla_rl/utils/context.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -32,13 +32,14 @@ def get_nnabla_context(gpu_id): if gpu_id in contexts: return contexts[gpu_id] if gpu_id < 0: - ctx = get_extension_context('cpu') + ctx = get_extension_context("cpu") else: try: - ctx = get_extension_context('cudnn', device_id=gpu_id) + ctx = get_extension_context("cudnn", device_id=gpu_id) except ModuleNotFoundError: - warnings.warn('Could not get CUDA context and cuDNN context. Fallback to CPU context instead', - RuntimeWarning) - ctx = get_extension_context('cpu') + warnings.warn( + "Could not get CUDA context and cuDNN context. Fallback to CPU context instead", RuntimeWarning + ) + ctx = get_extension_context("cpu") contexts[gpu_id] = ctx return ctx diff --git a/nnabla_rl/utils/data.py b/nnabla_rl/utils/data.py index 4c8487ff..7a76c0d9 100644 --- a/nnabla_rl/utils/data.py +++ b/nnabla_rl/utils/data.py @@ -21,7 +21,7 @@ from nnabla_rl.logger import logger from nnabla_rl.typing import TupledData -T = TypeVar('T') +T = TypeVar("T") def add_axis_if_single_dim(data): @@ -85,7 +85,7 @@ def marshal_dict_experiences(dict_experiences: Sequence[Dict[str, Any]]) -> Dict marshaled_experiences.update({key: add_axis_if_single_dim(np.asarray(data))}) except ValueError as e: # do nothing - logger.warn(f'key: {key} contains inconsistent elements!. Details: {e}') + logger.warn(f"key: {key} contains inconsistent elements!. Details: {e}") return marshaled_experiences @@ -108,8 +108,9 @@ def convert_to_list_if_not_list(value: Union[Iterable[T], T]) -> List[T]: return [value] -def set_data_to_variable(variable: Union[nn.Variable, Tuple[nn.Variable, ...]], - data: Union[float, np.ndarray, Tuple[np.ndarray, ...]]) -> None: +def set_data_to_variable( + variable: Union[nn.Variable, Tuple[nn.Variable, ...]], data: Union[float, np.ndarray, Tuple[np.ndarray, ...]] +) -> None: """Set data to variable. Args: @@ -213,10 +214,9 @@ def append_with_removed_item_check(self, data): return removed -def normalize_ndarray(ndarray: np.ndarray, - mean: np.ndarray, - std: np.ndarray, - value_clip: Optional[Tuple[float, float]] = None) -> np.ndarray: +def normalize_ndarray( + ndarray: np.ndarray, mean: np.ndarray, std: np.ndarray, value_clip: Optional[Tuple[float, float]] = None +) -> np.ndarray: """Normalize the given ndarray. Args: @@ -234,10 +234,9 @@ def normalize_ndarray(ndarray: np.ndarray, return normalized -def unnormalize_ndarray(ndarray: np.ndarray, - mean: np.ndarray, - std: np.ndarray, - value_clip: Optional[Tuple[float, float]] = None) -> np.ndarray: +def unnormalize_ndarray( + ndarray: np.ndarray, mean: np.ndarray, std: np.ndarray, value_clip: Optional[Tuple[float, float]] = None +) -> np.ndarray: """Unnormalize the given ndarray. Args: diff --git a/nnabla_rl/utils/debugging.py b/nnabla_rl/utils/debugging.py index c7c0ee6a..b584aa58 100644 --- a/nnabla_rl/utils/debugging.py +++ b/nnabla_rl/utils/debugging.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
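A small round-trip sketch for the normalize/unnormalize helpers whose signatures were reformatted above (illustration only, assuming they implement the usual `(x - mean) / std` transform and its inverse):

```python
import numpy as np
from nnabla_rl.utils.data import normalize_ndarray, unnormalize_ndarray

data = np.array([[1.0, 2.0, 3.0]])
mean = np.array([[2.0, 2.0, 2.0]])
std = np.array([[0.5, 0.5, 0.5]])

normalized = normalize_ndarray(data, mean, std, value_clip=(-5.0, 5.0))
restored = unnormalize_ndarray(normalized, mean, std)
print(np.allclose(data, restored))  # True is expected here since nothing gets clipped
```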
@@ -31,6 +31,7 @@ def accept_nnabla_func(nnabla_func): print(nnabla_func.inputs) print(nnabla_func.outputs) print(nnabla_func.info.args) + x.visit(accept_nnabla_func) @@ -45,12 +46,12 @@ def save_graph(x, file_path, verbose=False): def count_parameter_number(parameters): - ''' + """ Args: parameters (dict): parameters in dictionary form Returns: parameter_number (int): parameter number - ''' + """ parameter_number = 0 for parameter in parameters.values(): parameter_number += parameter.size @@ -61,7 +62,7 @@ def profile_graph( output_variable: nn.Variable, csv_file_path: Union[str, pathlib.Path], solver: Optional[S.Solver] = None, - ext_name: str = 'cudnn', + ext_name: str = "cudnn", device_id: int = 0, n_run: int = 1000, ) -> None: @@ -107,8 +108,9 @@ def on_hook_called(self, _): summarized = summary.summarize(all_objects) self.print_summary(summarized) - def print_summary(self, rows, limit=30, sort='size', order='descending'): + def print_summary(self, rows, limit=30, sort="size", order="descending"): for line in summary.format_(rows, limit=limit, sort=sort, order=order): logger.debug(line) + except ModuleNotFoundError: pass diff --git a/nnabla_rl/utils/evaluator.py b/nnabla_rl/utils/evaluator.py index 12d4676c..82bd8e2e 100644 --- a/nnabla_rl/utils/evaluator.py +++ b/nnabla_rl/utils/evaluator.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ from nnabla_rl.typing import Experience -class EpisodicEvaluator(): +class EpisodicEvaluator: def __init__(self, run_per_evaluation=10): self._num_episodes = run_per_evaluation @@ -31,12 +31,12 @@ def __call__(self, algorithm, env): reward_sum, *_ = run_one_episode(algorithm, env) returns.append(reward_sum) logger.info( - 'Finished evaluation run: #{} out of {}. Total reward: {}' - .format(num, self._num_episodes, reward_sum)) + "Finished evaluation run: #{} out of {}. Total reward: {}".format(num, self._num_episodes, reward_sum) + ) return returns -class TimestepEvaluator(): +class TimestepEvaluator: def __init__(self, num_timesteps): self._num_timesteps = num_timesteps @@ -56,17 +56,18 @@ def limit_checker(t): returns.append(reward_sum) logger.info( - 'Finished evaluation run: Time step #{} out of {}, Episode #{}. Total reward: {}' - .format(timesteps, self._num_timesteps, len(returns), reward_sum)) + "Finished evaluation run: Time step #{} out of {}, Episode #{}. Total reward: {}".format( + timesteps, self._num_timesteps, len(returns), reward_sum + ) + ) if len(returns) == 0: # In case the time limit reaches on first episode, save the return received up to that time returns.append(reward_sum) return returns -class EpisodicSuccessEvaluator(): - def __init__(self, check_success: Callable[[List[Experience]], Union[bool, float]], - run_per_evaluation=10): +class EpisodicSuccessEvaluator: + def __init__(self, check_success: Callable[[List[Experience]], Union[bool, float]], run_per_evaluation=10): self._num_episodes = run_per_evaluation self._compute_success_func = check_success @@ -76,11 +77,8 @@ def __call__(self, algorithm, env): _, _, experiences = run_one_episode(algorithm, env) success = self._compute_success_func(experiences) results.append(success) - success_tag = 'Success' if success else 'Failed' - logger.info( - 'Finished evaluation run: #{} out of {}. 
{}' - .format(num, self._num_episodes, success_tag) - ) + success_tag = "Success" if success else "Failed" + logger.info("Finished evaluation run: #{} out of {}. {}".format(num, self._num_episodes, success_tag)) return results @@ -90,7 +88,7 @@ def run_one_episode(algorithm, env, timestep_limit=lambda t: False): rewards = [] timesteps = 0 state = env.reset() - extra_info = {'reward': 0} + extra_info = {"reward": 0} action = algorithm.compute_eval_action(state, begin_of_episode=True, extra_info=extra_info) while True: next_state, reward, done, info = env.step(action) @@ -104,6 +102,6 @@ def run_one_episode(algorithm, env, timestep_limit=lambda t: False): break else: state = next_state - extra_info['reward'] = reward + extra_info["reward"] = reward action = algorithm.compute_eval_action(state, begin_of_episode=False, extra_info=extra_info) return np.sum(rewards), timesteps, experiences diff --git a/nnabla_rl/utils/files.py b/nnabla_rl/utils/files.py index e5cd8223..2338522a 100644 --- a/nnabla_rl/utils/files.py +++ b/nnabla_rl/utils/files.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ def create_dir_if_not_exist(outdir): """ if file_exists(outdir): if not os.path.isdir(outdir): - raise RuntimeError('{} is not a directory'.format(outdir)) + raise RuntimeError("{} is not a directory".format(outdir)) else: return os.makedirs(outdir) @@ -51,7 +51,7 @@ def read_text_from_file(file_path): Returns: data (str): Text read from the file """ - with open(file_path, 'r') as f: + with open(file_path, "r") as f: return f.read() @@ -62,5 +62,5 @@ def write_text_to_file(file_path, data): file_path (str or pathlib.Path): Path of the file to write data data (str): Text to write to the file """ - with open(file_path, 'w') as f: + with open(file_path, "w") as f: f.write(data) diff --git a/nnabla_rl/utils/matrices.py b/nnabla_rl/utils/matrices.py index 1077eb05..f8625d64 100644 --- a/nnabla_rl/utils/matrices.py +++ b/nnabla_rl/utils/matrices.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
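Since the three evaluator classes above share the same call signature, a small construction sketch may help (illustration only; `algorithm` and `env` stand for a trained nnabla-rl algorithm and a gym-style environment and are not defined here):

```python
from nnabla_rl.utils.evaluator import (
    EpisodicEvaluator,
    EpisodicSuccessEvaluator,
    TimestepEvaluator,
)

def check_success(experiences):
    # experiences: list of (s, a, r, non_terminal, s_next, info) tuples from one episode.
    # Toy criterion: call the episode a success if any reward was positive.
    return any(r > 0 for (_, _, r, _, _, _) in experiences)

episodic = EpisodicEvaluator(run_per_evaluation=10)    # list of per-episode returns
timestep = TimestepEvaluator(num_timesteps=10000)      # evaluates within a timestep budget
success = EpisodicSuccessEvaluator(check_success=check_success, run_per_evaluation=10)

# Each evaluator is invoked the same way once algorithm and env exist:
# returns = episodic(algorithm, env)
```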
@@ -32,14 +32,12 @@ def compute_hessian(y, x): param.grad.zero() grads = nn.grad([y], x) if len(grads) > 1: - flat_grads = NF.concatenate( - *[NF.reshape(grad, (-1,), inplace=False) for grad in grads]) + flat_grads = NF.concatenate(*[NF.reshape(grad, (-1,), inplace=False) for grad in grads]) else: flat_grads = NF.reshape(grads[0], (-1,), inplace=False) flat_grads.need_grad = True - hessian = np.zeros( - (flat_grads.shape[0], flat_grads.shape[0]), dtype=np.float32) + hessian = np.zeros((flat_grads.shape[0], flat_grads.shape[0]), dtype=np.float32) for i in range(flat_grads.shape[0]): flat_grads[i].forward() @@ -50,7 +48,7 @@ def compute_hessian(y, x): num_index = 0 for param in x: grad = param.g.flatten() # grad of grad so this is hessian - hessian[i, num_index:num_index+len(grad)] = grad + hessian[i, num_index : num_index + len(grad)] = grad num_index += len(grad) return hessian diff --git a/nnabla_rl/utils/misc.py b/nnabla_rl/utils/misc.py index 901c69b9..354bf76a 100644 --- a/nnabla_rl/utils/misc.py +++ b/nnabla_rl/utils/misc.py @@ -1,5 +1,5 @@ # Copyright 2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ def sync_model(src: Model, dst: Model, tau: float = 1.0): def copy_network_parameters(origin_params, target_params, tau=1.0): if not ((0.0 <= tau) & (tau <= 1.0)): - raise ValueError('tau must lie between [0.0, 1.0]') + raise ValueError("tau must lie between [0.0, 1.0]") for key in target_params.keys(): target_params[key].data.copy_from(origin_params[key].data * tau + target_params[key].data * (1 - tau)) @@ -66,11 +66,13 @@ def create_variables(batch_size: int, shapes: Dict[str, Tuple[int, ...]]) -> Dic return variables -def retrieve_internal_states(scope_name: str, - prev_rnn_states: Dict[str, Dict[str, nn.Variable]], - train_rnn_states: Dict[str, Dict[str, nn.Variable]], - training_variables: 'TrainingVariables', - reset_on_terminal: bool) -> Dict[str, nn.Variable]: +def retrieve_internal_states( + scope_name: str, + prev_rnn_states: Dict[str, Dict[str, nn.Variable]], + train_rnn_states: Dict[str, Dict[str, nn.Variable]], + training_variables: "TrainingVariables", + reset_on_terminal: bool, +) -> Dict[str, nn.Variable]: internal_states: Dict[str, nn.Variable] = {} if training_variables.is_initial_step(): internal_states = train_rnn_states[scope_name] diff --git a/nnabla_rl/utils/multiprocess.py b/nnabla_rl/utils/multiprocess.py index 2a739d3c..135169c5 100644 --- a/nnabla_rl/utils/multiprocess.py +++ b/nnabla_rl/utils/multiprocess.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -85,6 +85,5 @@ def copy_mp_arrays_to_params(mp_arrays, params): param_shape = params[key].shape # FIXME: force using float32. 
# This is a workaround for compensating nnabla's parameter initialization with float64 - np_array = mp_to_np_array( - mp_array, np_shape=param_shape, dtype=np.float32) + np_array = mp_to_np_array(mp_array, np_shape=param_shape, dtype=np.float32) params[key].d = np_array diff --git a/nnabla_rl/utils/optimization.py b/nnabla_rl/utils/optimization.py index a83b877d..13e67be1 100644 --- a/nnabla_rl/utils/optimization.py +++ b/nnabla_rl/utils/optimization.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -48,7 +48,7 @@ def conjugate_gradient(compute_Ax, b, max_iterations=10, residual_tol=1e-10): break if max_iterations is not None: - if iteration_number >= max_iterations-1: + if iteration_number >= max_iterations - 1: break beta = new_square_r / square_r diff --git a/nnabla_rl/utils/reproductions.py b/nnabla_rl/utils/reproductions.py index f18ae72d..98e2d514 100644 --- a/nnabla_rl/utils/reproductions.py +++ b/nnabla_rl/utils/reproductions.py @@ -22,8 +22,13 @@ import nnabla as nn import nnabla_rl as rl from nnabla_rl.environments.environment_info import EnvironmentInfo -from nnabla_rl.environments.wrappers import (Gymnasium2GymWrapper, NumpyFloat32Env, ScreenRenderEnv, make_atari, - wrap_deepmind) +from nnabla_rl.environments.wrappers import ( + Gymnasium2GymWrapper, + NumpyFloat32Env, + ScreenRenderEnv, + make_atari, + wrap_deepmind, +) from nnabla_rl.logger import logger @@ -48,15 +53,17 @@ def build_classic_control_env(id_or_env, seed=None, render=False): return env -def build_atari_env(id_or_env, - test=False, - seed=None, - render=False, - print_info=True, - max_frames_per_episode=None, - frame_stack=True, - flicker_probability=0.0, - use_gymnasium=False): +def build_atari_env( + id_or_env, + test=False, + seed=None, + render=False, + print_info=True, + max_frames_per_episode=None, + frame_stack=True, + flicker_probability=0.0, + use_gymnasium=False, +): if isinstance(id_or_env, gym.Env): env = id_or_env elif isinstance(id_or_env, gymnasium.Env): @@ -66,11 +73,13 @@ def build_atari_env(id_or_env, if print_info: print_env_info(env) - env = wrap_deepmind(env, - episode_life=not test, - clip_rewards=not test, - frame_stack=frame_stack, - flicker_probability=flicker_probability) + env = wrap_deepmind( + env, + episode_life=not test, + clip_rewards=not test, + frame_stack=frame_stack, + flicker_probability=flicker_probability, + ) env = NumpyFloat32Env(env) if render: @@ -125,12 +134,10 @@ def build_dmc_env(id_or_env, test=False, seed=None, render=False, print_info=Tru elif id_or_env.startswith("FakeDMControl"): env = gym.make(id_or_env) else: - domain_name, task_name = id_or_env.split('-') - env = DMCEnv(domain_name, - task_name=task_name, - task_kwargs={'random': seed}) + domain_name, task_name = id_or_env.split("-") + env = DMCEnv(domain_name, task_name=task_name, task_kwargs={"random": seed}) env = gym.wrappers.FlattenObservation(env) - env = gym.wrappers.RescaleAction(env, min_action=-1., max_action=1.) 
+ env = gym.wrappers.RescaleAction(env, min_action=-1.0, max_action=1.0) if print_info: print_env_info(env) @@ -144,11 +151,11 @@ def build_dmc_env(id_or_env, test=False, seed=None, render=False, print_info=Tru def d4rl_dataset_to_experiences(dataset, size=1000000): - size = min(dataset['observations'].shape[0], size) - states = dataset['observations'][:size] - actions = dataset['actions'][:size] - rewards = dataset['rewards'][:size].reshape(size, 1) - non_terminals = 1.0 - dataset['terminals'][:size].reshape(size, 1) + size = min(dataset["observations"].shape[0], size) + states = dataset["observations"][:size] + actions = dataset["actions"][:size] + rewards = dataset["rewards"][:size].reshape(size, 1) + non_terminals = 1.0 - dataset["terminals"][:size].reshape(size, 1) next_states = np.concatenate([states[1:size, :], np.zeros(shape=states[0].shape)[np.newaxis, :]], axis=0) infos = [{} for _ in range(size)] assert len(states) == size @@ -164,11 +171,11 @@ def print_env_info(env): if env.unwrapped.spec is not None: env_name = env.unwrapped.spec.id else: - env_name = 'Unknown' + env_name = "Unknown" env_info = EnvironmentInfo.from_env(env) - info = f'''env: {env_name}, + info = f"""env: {env_name}, state_dim: {env_info.state_dim}, state_shape: {env_info.state_shape}, state_high: {env_info.state_high}, @@ -177,5 +184,5 @@ def print_env_info(env): action_shape: {env_info.action_shape}, action_high: {env_info.action_high}, action_low: {env_info.action_low}, - max_episode_steps: {env.spec.max_episode_steps}''' + max_episode_steps: {env.spec.max_episode_steps}""" logger.info(info) diff --git a/nnabla_rl/utils/serializers.py b/nnabla_rl/utils/serializers.py index 159ba933..7eed529c 100644 --- a/nnabla_rl/utils/serializers.py +++ b/nnabla_rl/utils/serializers.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,14 +25,14 @@ from nnabla_rl.algorithm import Algorithm from nnabla_rl.environments.environment_info import EnvironmentInfo -_TRAINING_INFO_FILENAME = 'training_info.json' +_TRAINING_INFO_FILENAME = "training_info.json" -_KEY_ALGORITHM_NAME = 'algorithm_name' -_KEY_ALGORITHM_CLASS_NAME = 'algorithm_class_name' -_KEY_ALGORITHM_CONFIG = 'algorithm_config' -_KEY_ITERATION_NUM = 'iteration_num' -_KEY_MODELS = 'models' -_KEY_SOLVERS = 'solvers' +_KEY_ALGORITHM_NAME = "algorithm_name" +_KEY_ALGORITHM_CLASS_NAME = "algorithm_class_name" +_KEY_ALGORITHM_CONFIG = "algorithm_config" +_KEY_ITERATION_NUM = "iteration_num" +_KEY_MODELS = "models" +_KEY_SOLVERS = "solvers" def save_snapshot(path, algorithm): @@ -48,7 +48,7 @@ def save_snapshot(path, algorithm): assert isinstance(algorithm, Algorithm) if isinstance(path, str): path = pathlib.Path(path) - dirname = 'iteration-' + str(algorithm.iteration_num) + dirname = "iteration-" + str(algorithm.iteration_num) outdir = path / dirname files.create_dir_if_not_exist(outdir=outdir) @@ -60,9 +60,7 @@ def save_snapshot(path, algorithm): return outdir -def load_snapshot(path, - env_or_env_info, - algorithm_kwargs={}): +def load_snapshot(path, env_or_env_info, algorithm_kwargs={}): """Load training snapshot from file. 
Args: @@ -76,11 +74,11 @@ def load_snapshot(path, path = pathlib.Path(path) if not isinstance(env_or_env_info, (gym.Env, EnvironmentInfo)): raise RuntimeError( - 'load_snapshot requires training gym.Env or EnvironmentInfo. ' - 'Automatic loading of env_info is no longer supported since v0.10.0') + "load_snapshot requires training gym.Env or EnvironmentInfo. " + "Automatic loading of env_info is no longer supported since v0.10.0" + ) training_info = _load_training_info(path) - algorithm = _instantiate_algorithm_from_training_info( - training_info, env_or_env_info, **algorithm_kwargs) + algorithm = _instantiate_algorithm_from_training_info(training_info, env_or_env_info, **algorithm_kwargs) _load_network_parameters(path, algorithm) _load_solver_states(path, algorithm) return algorithm @@ -90,12 +88,12 @@ def _instantiate_algorithm_from_training_info(training_info, env_info, **kwargs) algorithm_name = training_info[_KEY_ALGORITHM_CLASS_NAME] (algorithm_klass, config_klass) = A.get_class_of(algorithm_name) - config = kwargs.get('config', None) + config = kwargs.get("config", None) if not isinstance(config, config_klass): saved_config = training_info[_KEY_ALGORITHM_CONFIG] saved_config = config_klass(**saved_config) config = dataclasses.replace(saved_config, **config) if isinstance(config, dict) else saved_config - kwargs['config'] = config + kwargs["config"] = config algorithm = algorithm_klass(env_info, **kwargs) algorithm._iteration_num = training_info[_KEY_ITERATION_NUM] return algorithm @@ -115,34 +113,34 @@ def _create_training_info(algorithm): def _save_training_info(path, training_info): filepath = path / _TRAINING_INFO_FILENAME - with open(filepath, 'w+') as outfile: + with open(filepath, "w+") as outfile: json.dump(training_info, outfile) def _load_training_info(path): filepath = path / _TRAINING_INFO_FILENAME - with open(filepath, 'r') as infile: + with open(filepath, "r") as infile: training_info = json.load(infile) return training_info def _save_network_parameters(path, algorithm): for scope_name, model in algorithm._models().items(): - filename = scope_name + '.h5' + filename = scope_name + ".h5" filepath = path / filename model.save_parameters(filepath) def _load_network_parameters(path, algorithm): for scope_name, model in algorithm._models().items(): - filename = scope_name + '.h5' + filename = scope_name + ".h5" filepath = path / filename model.load_parameters(filepath) def _save_solver_states(path, algorithm): for scope_name, solver in algorithm._solvers().items(): - filename = scope_name + '_solver' + '.h5' + filename = scope_name + "_solver" + ".h5" filepath = path / filename solver.save_states(filepath) @@ -150,7 +148,7 @@ def _save_solver_states(path, algorithm): def _load_solver_states(path, algorithm): models = algorithm._models() for scope_name, solver in algorithm._solvers().items(): - filename = scope_name + '_solver' + '.h5' + filename = scope_name + "_solver" + ".h5" filepath = path / filename if not filepath.exists(): warnings.warn(f"No solver file found in: {filepath}. Ommitting...") diff --git a/nnabla_rl/writer.py b/nnabla_rl/writer.py index 0535b8b8..30f8d830 100644 --- a/nnabla_rl/writer.py +++ b/nnabla_rl/writer.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
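The serializers.py hunks above only change quoting and call wrapping; the save/load contract stays as shown: save_snapshot writes an iteration-&lt;N&gt; directory holding training_info.json plus one .h5 file per model and solver scope, and load_snapshot rebuilds the algorithm from it, requiring an explicit env or EnvironmentInfo since v0.10.0. A rough round-trip sketch follows; the algorithm, env id and output path are illustrative assumptions, not part of this patch.

```python
# Hedged round-trip sketch for nnabla_rl.utils.serializers (algorithm, env id
# and output path are illustrative; gpu_id=-1 keeps everything on CPU).
import gym

import nnabla_rl.algorithms as A
from nnabla_rl.utils import serializers

env = gym.make("Pendulum-v1")
ddpg = A.DDPG(env, config=A.DDPGConfig(gpu_id=-1))

# Writes ./snapshots/iteration-0/ with training_info.json, <scope>.h5 and <scope>_solver.h5 files.
snapshot_dir = serializers.save_snapshot("./snapshots", ddpg)

# Since v0.10.0 the env (or an EnvironmentInfo) must be passed explicitly when loading.
loaded = serializers.load_snapshot(snapshot_dir, env, algorithm_kwargs={"config": A.DDPGConfig(gpu_id=-1)})
assert isinstance(loaded, A.DDPG)
```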
@@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + class Writer(object): def __init__(self): pass diff --git a/nnabla_rl/writers/file_writer.py b/nnabla_rl/writers/file_writer.py index ef47c7ba..676e8b9a 100644 --- a/nnabla_rl/writers/file_writer.py +++ b/nnabla_rl/writers/file_writer.py @@ -32,30 +32,28 @@ def __init__(self, outdir, file_prefix, fmt="%.3f"): self._fmt = fmt def write_scalar(self, iteration_num, scalar): - outfile = self._outdir / (self._file_prefix + '_scalar.tsv') + outfile = self._outdir / (self._file_prefix + "_scalar.tsv") len_scalar = len(scalar.values()) out_scalar = {} - out_scalar['iteration'] = iteration_num + out_scalar["iteration"] = iteration_num out_scalar.update(scalar) self._create_file_if_not_exists(outfile, out_scalar.keys()) - with open(outfile, 'a') as f: - np.savetxt(f, [list(out_scalar.values())], - fmt=['%i'] + [self._fmt] * len_scalar, - delimiter='\t') + with open(outfile, "a") as f: + np.savetxt(f, [list(out_scalar.values())], fmt=["%i"] + [self._fmt] * len_scalar, delimiter="\t") def write_histogram(self, iteration_num, histogram): - outfile = self._outdir / (self._file_prefix + '_histogram.tsv') + outfile = self._outdir / (self._file_prefix + "_histogram.tsv") - self._create_file_if_not_exists(outfile, ['iteration(key)', 'values']) + self._create_file_if_not_exists(outfile, ["iteration(key)", "values"]) - with open(outfile, 'a') as f: + with open(outfile, "a") as f: for key, values in histogram.items(): - np.savetxt(f, [[iteration_num] + [*values]], - fmt=[f'%i ({key})'] + [self._fmt] * len(values), - delimiter='\t') + np.savetxt( + f, [[iteration_num] + [*values]], fmt=[f"%i ({key})"] + [self._fmt] * len(values), delimiter="\t" + ) def write_image(self, iteration_num, image): pass @@ -66,5 +64,5 @@ def _create_file_if_not_exists(self, outfile, header_keys): self._write_file_header(outfile, header_keys) def _write_file_header(self, filepath, keys): - with open(filepath, 'w+') as f: - np.savetxt(f, [list(keys)], fmt='%s', delimiter='\t') + with open(filepath, "w+") as f: + np.savetxt(f, [list(keys)], fmt="%s", delimiter="\t") diff --git a/nnabla_rl/writers/monitor_writer.py b/nnabla_rl/writers/monitor_writer.py index dfce274f..7dff98e4 100644 --- a/nnabla_rl/writers/monitor_writer.py +++ b/nnabla_rl/writers/monitor_writer.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ def __init__(self, outdir, file_prefix): self._monitors = {} def write_scalar(self, iteration_num, scalar): - prefix = self._file_prefix + '_scalar_' + prefix = self._file_prefix + "_scalar_" for name, value in scalar.items(): monitor = self._create_or_get_monitor_series(prefix + name) monitor.add(iteration_num, value) diff --git a/nnabla_rl/writers/writing_distributor.py b/nnabla_rl/writers/writing_distributor.py index 4b02c87e..460a3a31 100644 --- a/nnabla_rl/writers/writing_distributor.py +++ b/nnabla_rl/writers/writing_distributor.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
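The FileWriter hunks above keep the same tab-separated layout after reformatting: write_scalar prepends an iteration column to the scalar dict it receives and appends one row per call, formatting values with fmt (default "%.3f"). A small sketch of that behaviour, with an invented output directory and numbers:

```python
# Sketch of FileWriter's scalar output; the outdir and values are invented.
from nnabla_rl.writers import FileWriter

writer = FileWriter(outdir="./tmp_results", file_prefix="evaluation_result")
writer.write_scalar(250000, {"mean": 195.3, "std_dev": 12.1})
writer.write_scalar(500000, {"mean": 201.8, "std_dev": 9.7})

# ./tmp_results/evaluation_result_scalar.tsv should then contain:
# iteration	mean	std_dev
# 250000	195.300	12.100
# 500000	201.800	9.700
```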
@@ -36,5 +36,5 @@ def write_image(self, iteration_num, image): writer.write_image(iteration_num, image) def _write_file_header(self, filepath, keys): - with open(filepath, 'w+') as f: - np.savetxt(f, [list(keys)], fmt='%s', delimiter='\t') + with open(filepath, "w+") as f: + np.savetxt(f, [list(keys)], fmt="%s", delimiter="\t") diff --git a/pyproject.toml b/pyproject.toml index cc3fc41d..0c09d377 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,8 +79,8 @@ dev = [ "pytest-cov", "mypy != 1.11.0", "typing-extensions", - "isort", - "autopep8", + "isort > 5.0.0", + "black", "docformatter" ] deploy = [ @@ -114,11 +114,18 @@ exclude = [ testpaths = ["tests"] addopts = "-s" -[tool.autopep8] -max_line_length = 120 -recursive = true +[tool.black] +line-length = 120 +target-version = ['py38'] +include = '\.pyi?$' +extend-exclude = ''' +/( + nnabla_rl/external +) +''' [tool.isort] +profile = "black" line_length = 120 honor_noqa = true known_first_party = ["nnabla"] diff --git a/reproductions/algorithms/atari/a2c/a2c_reproduction.py b/reproductions/algorithms/atari/a2c/a2c_reproduction.py index b0aea81d..f063bdae 100644 --- a/reproductions/algorithms/atari/a2c/a2c_reproduction.py +++ b/reproductions/algorithms/atari/a2c/a2c_reproduction.py @@ -30,11 +30,12 @@ def run_training(args): set_global_seed(args.seed) train_env = build_atari_env(args.env, seed=args.seed, use_gymnasium=args.use_gymnasium) eval_env = build_atari_env( - args.env, test=True, seed=args.seed + 100, render=args.render, use_gymnasium=args.use_gymnasium) + args.env, test=True, seed=args.seed + 100, render=args.render, use_gymnasium=args.use_gymnasium + ) iteration_num_hook = H.IterationNumHook(timing=100) - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) writer = FileWriter(outdir, "evaluation_result") @@ -55,36 +56,37 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError('Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=False, - use_gymnasium=args.use_gymnasium) + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_atari_env( + args.env, test=True, seed=args.seed + 200, render=False, use_gymnasium=args.use_gymnasium + ) config = A.A2CConfig(gpu_id=args.gpu) a2c = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(a2c, A.A2C): - raise ValueError('Loaded snapshot is not trained with A2C!') + raise ValueError("Loaded snapshot is not trained with A2C!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) returns = evaluator(a2c, eval_env) mean = np.mean(returns) std_dev = np.std(returns) median = np.median(returns) - logger.info('Evaluation results. mean: {} +/- std: {}, median: {}'.format(mean, std_dev, median)) + logger.info("Evaluation results. 
mean: {} +/- std: {}, median: {}".format(mean, std_dev, median)) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4') - parser.add_argument('--save-dir', type=str, default="") - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=50000000) - parser.add_argument('--save_timing', type=int, default=250000) - parser.add_argument('--eval_timing', type=int, default=250000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") + parser.add_argument("--save-dir", type=str, default="") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=50000000) + parser.add_argument("--save_timing", type=int, default=250000) + parser.add_argument("--eval_timing", type=int, default=250000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -94,5 +96,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/atari/c51/c51_reproduction.py b/reproductions/algorithms/atari/c51/c51_reproduction.py index 4b95c82e..4ca701db 100644 --- a/reproductions/algorithms/atari/c51/c51_reproduction.py +++ b/reproductions/algorithms/atari/c51/c51_reproduction.py @@ -31,26 +31,28 @@ def __call__(self, env_info, algorithm_config, **kwargs): def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render, - use_gymnasium=args.use_gymnasium) + eval_env = build_atari_env( + args.env, test=True, seed=args.seed + 100, render=args.render, use_gymnasium=args.use_gymnasium + ) evaluator = TimestepEvaluator(num_timesteps=125000) evaluation_hook = H.EvaluationHook( - eval_env, evaluator, timing=args.eval_timing, writer=W.FileWriter(outdir=outdir, - file_prefix='evaluation_result')) + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=100) train_env = build_atari_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium) config = A.CategoricalDQNConfig(gpu_id=args.gpu) - categorical_dqn = A.CategoricalDQN(train_env, - config=config, - replay_buffer_builder=MemoryEfficientBufferBuilder()) + categorical_dqn = A.CategoricalDQN(train_env, config=config, replay_buffer_builder=MemoryEfficientBufferBuilder()) hooks = [iteration_num_hook, save_snapshot_hook, evaluation_hook] categorical_dqn.set_hooks(hooks) @@ -62,14 +64,14 @@ 
def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_atari_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) config = A.CategoricalDQNConfig(gpu_id=args.gpu) categorical_dqn = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(categorical_dqn, A.CategoricalDQN): - raise ValueError('Loaded snapshot is not trained with CategoricalDQN!') + raise ValueError("Loaded snapshot is not trained with CategoricalDQN!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(categorical_dqn, eval_env) @@ -77,18 +79,18 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4') - parser.add_argument('--save-dir', type=str, default="") - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=50000000) - parser.add_argument('--save_timing', type=int, default=250000) - parser.add_argument('--eval_timing', type=int, default=250000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") + parser.add_argument("--save-dir", type=str, default="") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=50000000) + parser.add_argument("--save_timing", type=int, default=250000) + parser.add_argument("--eval_timing", type=int, default=250000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -98,5 +100,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/atari/ddqn/ddqn_reproduction.py b/reproductions/algorithms/atari/ddqn/ddqn_reproduction.py index 7d30496f..7bd92b2a 100644 --- a/reproductions/algorithms/atari/ddqn/ddqn_reproduction.py +++ b/reproductions/algorithms/atari/ddqn/ddqn_reproduction.py @@ -34,14 +34,15 @@ def __call__(self, env_info, algorithm_config, **kwargs): def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) writer = FileWriter(outdir, "evaluation_result") - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render, - use_gymnasium=args.use_gymnasium) + eval_env = build_atari_env( + args.env, test=True, seed=args.seed + 100, render=args.render, 
use_gymnasium=args.use_gymnasium + ) evaluator = TimestepEvaluator(num_timesteps=125000) evaluation_hook = H.EvaluationHook(eval_env, evaluator, timing=args.eval_timing, writer=writer) @@ -51,9 +52,7 @@ def run_training(args): train_env = build_atari_env(args.env, seed=args.seed, use_gymnasium=args.use_gymnasium) config = A.DDQNConfig(gpu_id=args.gpu) - ddqn = A.DDQN(train_env, - config=config, - replay_buffer_builder=MemoryEfficientBufferBuilder()) + ddqn = A.DDQN(train_env, config=config, replay_buffer_builder=MemoryEfficientBufferBuilder()) ddqn.set_hooks(hooks=[iteration_num_hook, save_snapshot_hook, evaluation_hook]) ddqn.train(train_env, total_iterations=args.total_iterations) @@ -64,38 +63,37 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_atari_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) config = A.DDQNConfig(gpu_id=args.gpu) ddqn = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(ddqn, A.DDQN): - raise ValueError('Loaded snapshot is not trained with DDQN!') + raise ValueError("Loaded snapshot is not trained with DDQN!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) returns = evaluator(ddqn, eval_env) mean = np.mean(returns) std_dev = np.std(returns) median = np.median(returns) - logger.info('Evaluation results. mean: {} +/- std: {}, median: {}'.format(mean, std_dev, median)) + logger.info("Evaluation results. 
mean: {} +/- std: {}, median: {}".format(mean, std_dev, median)) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, - default='BreakoutNoFrameskip-v4') - parser.add_argument('--save-dir', type=str, default="") - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=50000000) - parser.add_argument('--save_timing', type=int, default=250000) - parser.add_argument('--eval_timing', type=int, default=250000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") + parser.add_argument("--save-dir", type=str, default="") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=50000000) + parser.add_argument("--save_timing", type=int, default=250000) + parser.add_argument("--eval_timing", type=int, default=250000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -105,5 +103,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/atari/decision_transformer/atari_dataset_loader.py b/reproductions/algorithms/atari/decision_transformer/atari_dataset_loader.py index cf797b32..215170db 100755 --- a/reproductions/algorithms/atari/decision_transformer/atari_dataset_loader.py +++ b/reproductions/algorithms/atari/decision_transformer/atari_dataset_loader.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
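The a2c, c51 and ddqn reproduction scripts above (and the dqn/drqn scripts further below) share one skeleton; only the algorithm class, its config and the hook timings differ. A condensed sketch of that common pattern, with an illustrative env id, GPU id and timings:

```python
# Condensed sketch of the shared reproduction-script pattern (DQN and the
# Breakout env id are examples only; seeds and timings are illustrative).
import nnabla_rl.algorithms as A
import nnabla_rl.hooks as H
import nnabla_rl.writers as W
from nnabla_rl.utils.evaluator import TimestepEvaluator
from nnabla_rl.utils.reproductions import build_atari_env, set_global_seed

set_global_seed(0)
outdir = "BreakoutNoFrameskip-v4_results/seed-0"

train_env = build_atari_env("BreakoutNoFrameskip-v4", seed=0)
eval_env = build_atari_env("BreakoutNoFrameskip-v4", test=True, seed=100)

evaluation_hook = H.EvaluationHook(
    eval_env,
    TimestepEvaluator(num_timesteps=125000),
    timing=250000,
    writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"),
)
dqn = A.DQN(train_env, config=A.DQNConfig(gpu_id=0))
dqn.set_hooks([H.IterationNumHook(timing=100), H.SaveSnapshotHook(outdir, timing=250000), evaluation_hook])
dqn.train(train_env, total_iterations=50000000)

eval_env.close()
train_env.close()
```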
@@ -20,7 +20,7 @@ def load_data_from_gz(gzfile): - with gzip.open(gzfile, mode='rb') as f: + with gzip.open(gzfile, mode="rb") as f: data = np.load(f, allow_pickle=False) return data @@ -49,8 +49,7 @@ def load_experiences(gzfiles, num_experiences_to_load): experiences = experience else: experiences.extend(experience) - print('loaded experiences: {} / {}'.format(len(experiences), - num_experiences_to_load)) + print("loaded experiences: {} / {}".format(len(experiences), num_experiences_to_load)) if num_experiences_to_load <= len(experiences): break return experiences[:num_experiences_to_load] @@ -58,16 +57,16 @@ def load_experiences(gzfiles, num_experiences_to_load): def load_dataset(dataset_dir, percentage): dataset_dir = pathlib.Path(dataset_dir) - observation_files = find_all_file_with_name(dataset_dir, 'observation') + observation_files = find_all_file_with_name(dataset_dir, "observation") observation_files.sort() - action_files = find_all_file_with_name(dataset_dir, 'action') + action_files = find_all_file_with_name(dataset_dir, "action") action_files.sort() - reward_files = find_all_file_with_name(dataset_dir, 'reward') + reward_files = find_all_file_with_name(dataset_dir, "reward") reward_files.sort() - terminal_files = find_all_file_with_name(dataset_dir, 'terminal') + terminal_files = find_all_file_with_name(dataset_dir, "terminal") terminal_files.sort() file_num = len(observation_files) @@ -84,14 +83,16 @@ def load_dataset(dataset_dir, percentage): def load_dataset_by_dataset_num(dataset_dir, dataset_num): dataset_dir = pathlib.Path(dataset_dir) - observation_file = dataset_dir / f'$store$_observation_ckpt.{dataset_num}.gz' - action_file = dataset_dir / f'$store$_action_ckpt.{dataset_num}.gz' - reward_file = dataset_dir / f'$store$_reward_ckpt.{dataset_num}.gz' - terminal_file = dataset_dir / f'$store$_terminal_ckpt.{dataset_num}.gz' + observation_file = dataset_dir / f"$store$_observation_ckpt.{dataset_num}.gz" + action_file = dataset_dir / f"$store$_action_ckpt.{dataset_num}.gz" + reward_file = dataset_dir / f"$store$_reward_ckpt.{dataset_num}.gz" + terminal_file = dataset_dir / f"$store$_terminal_ckpt.{dataset_num}.gz" with futures.ThreadPoolExecutor(max_workers=4) as executor: - data_futures = [executor.submit(load_experience, file_name) - for file_name in (observation_file, action_file, reward_file, terminal_file)] + data_futures = [ + executor.submit(load_experience, file_name) + for file_name in (observation_file, action_file, reward_file, terminal_file) + ] observations = data_futures[0].result() actions = data_futures[1].result() rewards = data_futures[2].result() diff --git a/reproductions/algorithms/atari/decision_transformer/compute_gamer_normalized_score.py b/reproductions/algorithms/atari/decision_transformer/compute_gamer_normalized_score.py index dff22548..e09bcf17 100644 --- a/reproductions/algorithms/atari/decision_transformer/compute_gamer_normalized_score.py +++ b/reproductions/algorithms/atari/decision_transformer/compute_gamer_normalized_score.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
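The dataset loader above reads gzipped numpy checkpoints named $store$_&lt;kind&gt;_ckpt.&lt;N&gt;.gz. Reading a single file by hand uses the same two calls that load_data_from_gz wraps; the path below is only a placeholder for a downloaded dataset file.

```python
# Minimal sketch of what load_data_from_gz does for one checkpoint file;
# the path is a placeholder for a downloaded Atari dataset checkpoint.
import gzip

import numpy as np

with gzip.open("datasets/Breakout/1/replay_logs/$store$_action_ckpt.0.gz", mode="rb") as f:
    actions = np.load(f, allow_pickle=False)
print(actions.shape, actions.dtype)
```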
@@ -25,7 +25,7 @@ def load_histogram_data(path, dtype=float): histogram = [] with open(path) as f: - tsv_reader = reader(f, delimiter='\t') + tsv_reader = reader(f, delimiter="\t") for i, row in enumerate(tsv_reader): if i == 0: continue @@ -54,8 +54,8 @@ def extract_iteration_num_and_returns(histogram_data): returns = [] for i in range(len(histogram_data)): data_row = histogram_data[i] - if 'returns' in data_row[0]: - iteration_nums.append(int(data_row[0].split(' ')[0])) + if "returns" in data_row[0]: + iteration_nums.append(int(data_row[0].split(" ")[0])) scores = data_row[1][0:].astype(float) returns.append(scores) @@ -82,8 +82,8 @@ def create_gamer_normalized_score_file(histograms, file_outdir, gamer_score): std_dev = np.std(normalized_r) * 100 scalar_results = {} - scalar_results['mean'] = mean - scalar_results['std_dev'] = std_dev + scalar_results["mean"] = mean + scalar_results["std_dev"] = std_dev writer.write_scalar(i, scalar_results) @@ -93,12 +93,12 @@ def compile_results(args): histograms = {} histogram_directories = list_all_directory_with(rootdir, args.eval_histogram_filename) - print(f'files: {histogram_directories}') + print(f"files: {histogram_directories}") for directory in histogram_directories: if args.resultdir not in str(directory): continue relative_dir = directory.relative_to(rootdir) - env_name = str(relative_dir).split('/')[0] + env_name = str(relative_dir).split("/")[0] histogram_file = directory / args.eval_histogram_filename print(f"found histogram file of env: {env_name} at: {histogram_file}") if histogram_file.exists(): @@ -112,16 +112,18 @@ def compile_results(args): create_gamer_normalized_score_file(histograms, file_outdir, args.gamer_score) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--outdir', type=str, required=True, help='output directory') - parser.add_argument('--resultdir', type=str, required=True, help='result directory') - parser.add_argument('--gamer-score', type=float, required=True, help='gamer score') - parser.add_argument('--eval-histogram-filename', - type=str, - default="evaluation_result_histogram.tsv", - help='eval result(histogram) filename') + parser.add_argument("--outdir", type=str, required=True, help="output directory") + parser.add_argument("--resultdir", type=str, required=True, help="result directory") + parser.add_argument("--gamer-score", type=float, required=True, help="gamer score") + parser.add_argument( + "--eval-histogram-filename", + type=str, + default="evaluation_result_histogram.tsv", + help="eval result(histogram) filename", + ) args = parser.parse_args() diff --git a/reproductions/algorithms/atari/decision_transformer/dataset_viewer.py b/reproductions/algorithms/atari/decision_transformer/dataset_viewer.py index e5e32c73..c5957b43 100755 --- a/reproductions/algorithms/atari/decision_transformer/dataset_viewer.py +++ b/reproductions/algorithms/atari/decision_transformer/dataset_viewer.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
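compute_gamer_normalized_score.py consumes the evaluation_result_histogram.tsv files written by FileWriter.write_histogram earlier in this patch; after the header, each data row starts with an "&lt;iteration&gt; (&lt;key&gt;)" cell followed by the raw values. A small sketch of reading such a file back, using the script's default filename (the parsing here is a simplified stand-in, not the script itself):

```python
# Sketch mirroring load_histogram_data / extract_iteration_num_and_returns:
# data rows look like "250000 (returns)\t<v1>\t<v2>...", after a header row.
from csv import reader

iteration_nums, returns = [], []
with open("evaluation_result_histogram.tsv") as f:
    for i, row in enumerate(reader(f, delimiter="\t")):
        if i == 0:  # header row: "iteration(key)", "values"
            continue
        if "returns" not in row[0]:
            continue
        iteration_nums.append(int(row[0].split(" ")[0]))
        returns.append([float(v) for v in row[1:]])
print(iteration_nums[:3], [len(r) for r in returns[:3]])
```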
@@ -22,30 +22,30 @@ def view_dataset(args): dataset_dir = pathlib.Path(args.dataset_dir) (o, a, r, t) = load_expert_dataset(dataset_dir) - print('observation data shape: ', o.shape) - print('action data shape: ', a.shape) - print('reward data shape: ', r.shape) - print('terminal data shape: ', t.shape) + print("observation data shape: ", o.shape) + print("action data shape: ", a.shape) + print("reward data shape: ", r.shape) + print("terminal data shape: ", t.shape) show_observations(o) def show_observations(observations): - print('press q to quit displaying observation') + print("press q to quit displaying observation") for observation in observations: - cv2.imshow('obs0', observation) - if cv2.waitKey(5) & 0xFF == ord('q'): + cv2.imshow("obs0", observation) + if cv2.waitKey(5) & 0xFF == ord("q"): break cv2.destroyAllWindows() def main(): parser = argparse.ArgumentParser() - parser.add_argument('--dataset-dir', type=str, default='datasets/Breakout/1/replay_logs') + parser.add_argument("--dataset-dir", type=str, default="datasets/Breakout/1/replay_logs") args = parser.parse_args() view_dataset(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/atari/decision_transformer/decision_transformer_reproduction.py b/reproductions/algorithms/atari/decision_transformer/decision_transformer_reproduction.py index 8fd80c20..6a62ce99 100755 --- a/reproductions/algorithms/atari/decision_transformer/decision_transformer_reproduction.py +++ b/reproductions/algorithms/atari/decision_transformer/decision_transformer_reproduction.py @@ -61,8 +61,9 @@ def get_learning_rate(self, iter): if self._processed_tokens < self._warmup_tokens: new_learning_rate *= float(self._processed_tokens) / max(1, self._warmup_tokens) else: - progress = float(self._processed_tokens - self._warmup_tokens) / \ - max(1, self._final_tokens - self._warmup_tokens) + progress = float(self._processed_tokens - self._warmup_tokens) / max( + 1, self._final_tokens - self._warmup_tokens + ) new_learning_rate *= max(0.1, 0.5 * (1.0 + np.cos(np.pi * progress))) return new_learning_rate @@ -75,15 +76,17 @@ def __init__(self, warmup_tokens, final_tokens) -> None: self._final_tokens = final_tokens def build_scheduler(self, env_info, algorithm_config, **kwargs) -> BaseLearningRateScheduler: - return AtariLearningRateScheduler(initial_learning_rate=algorithm_config.learning_rate, - context_length=algorithm_config.context_length, - batch_size=algorithm_config.batch_size, - warmup_tokens=self._warmup_tokens, - final_tokens=self._final_tokens) + return AtariLearningRateScheduler( + initial_learning_rate=algorithm_config.learning_rate, + context_length=algorithm_config.context_length, + batch_size=algorithm_config.batch_size, + warmup_tokens=self._warmup_tokens, + final_tokens=self._final_tokens, + ) def num_datasets(dataset_path): - return len(find_all_file_with_name(dataset_path, 'observation')) + return len(find_all_file_with_name(dataset_path, "observation")) def get_next_trajectory(dataset, trajectory_length): @@ -93,23 +96,23 @@ def get_next_trajectory(dataset, trajectory_length): actions = np.copy(a[:end]) rewards = np.copy(r[:end]) non_terminals = np.copy(1 - t[:end]) - if end+1 < len(s): - next_states = np.copy(s[1:end+1]) + if end + 1 < len(s): + next_states = np.copy(s[1 : end + 1]) else: state_shape = s[0].shape next_states = np.concatenate((s[1:end], np.zeros(shape=(1, *state_shape), dtype=np.uint8)), axis=0) info = [{} for _ in range(trajectory_length)] for timestep, d in 
enumerate(info): - d['rtg'] = np.sum(rewards[timestep:]) - d['timesteps'] = timestep + d["rtg"] = np.sum(rewards[timestep:]) + d["timesteps"] = timestep assert all([len(data) == len(states) for data in (actions, rewards, non_terminals, next_states, info)]) timesteps = len(info) - 1 return list(zip(states, actions, rewards, non_terminals, next_states, info)), timesteps def load_dataset(dataset_dir, buffer_size, context_length, trajectories_per_buffer): - print(f'start loading dataset from: {dataset_dir}') + print(f"start loading dataset from: {dataset_dir}") # NOTE: actual number of loaded trajectories could be less than maximum possible trajectories max_possible_trajectories = buffer_size // context_length buffer = MemoryEfficientAtariTrajectoryBuffer(num_trajectories=max_possible_trajectories) @@ -119,7 +122,7 @@ def load_dataset(dataset_dir, buffer_size, context_length, trajectories_per_buff dataset_seek_index = np.zeros(max_datasets, dtype=int) while len(buffer) < buffer_size: dataset_num = rl.random.drng.integers(low=0, high=max_datasets) - print(f'loading dataset: #{dataset_num}') + print(f"loading dataset: #{dataset_num}") appended_trajectories = 0 seek_index = dataset_seek_index[dataset_num] s, a, r, t = load_dataset_by_dataset_num(dataset_dir, dataset_num) @@ -129,7 +132,7 @@ def load_dataset(dataset_dir, buffer_size, context_length, trajectories_per_buff r = r[seek_index:] t = t[seek_index:] if len(s) < context_length: - print(f'all available trajectories in dataset #{dataset_num} has been loaded') + print(f"all available trajectories in dataset #{dataset_num} has been loaded") break done_indices, *_ = np.where(t == 1) @@ -140,58 +143,59 @@ def load_dataset(dataset_dir, buffer_size, context_length, trajectories_per_buff max_timesteps = max(max_timesteps, timesteps) buffer.append_trajectory(trajectory) appended_trajectories += 1 - print(f'loaded trajectories: {appended_trajectories}') + print(f"loaded trajectories: {appended_trajectories}") # Set next index seek_index = trajectory_length dataset_seek_index[dataset_num] += trajectory_length - print(f'loaded buffer size: {len(buffer)}') - print(f'buffer size: {len(buffer)}, max timestep: {max_timesteps}') + print(f"loaded buffer size: {len(buffer)}") + print(f"buffer size: {len(buffer)}, max timestep: {max_timesteps}") return buffer, max_timesteps def get_target_return(env_name): - if 'Breakout' in env_name: + if "Breakout" in env_name: return 90 - if 'Seaquest' in env_name: + if "Seaquest" in env_name: return 1150 - if 'Qbert' in env_name: + if "Qbert" in env_name: return 14000 - if 'Pong' in env_name: + if "Pong" in env_name: return 20 - raise NotImplementedError(f'No return is defined for: {env_name}') + raise NotImplementedError(f"No return is defined for: {env_name}") def get_batch_size(env_name): - if 'Breakout' in env_name or 'Seaquest' in env_name or 'Qbert' in env_name: + if "Breakout" in env_name or "Seaquest" in env_name or "Qbert" in env_name: return 128 - if 'Pong' in env_name: + if "Pong" in env_name: return 512 - raise NotImplementedError(f'No batch_size is defined for: {env_name}') + raise NotImplementedError(f"No batch_size is defined for: {env_name}") def get_context_length(env_name): - if 'Breakout' in env_name or 'Seaquest' in env_name or 'Qbert' in env_name: + if "Breakout" in env_name or "Seaquest" in env_name or "Qbert" in env_name: return 30 - if 'Pong' in env_name: + if "Pong" in env_name: return 50 - raise NotImplementedError(f'No context_length is defined for: {env_name}') + raise NotImplementedError(f"No 
context_length is defined for: {env_name}") def guess_dataset_path(env_name): - game = re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', env_name)[0] - return f'datasets/{game}/1/replay_logs' + game = re.findall(r"[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))", env_name)[0] + return f"datasets/{game}/1/replay_logs" def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) writer = FileWriter(outdir, "evaluation_result") - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render, - use_gymnasium=args.use_gymnasium) + eval_env = build_atari_env( + args.env, test=True, seed=args.seed + 100, render=args.render, use_gymnasium=args.use_gymnasium + ) evaluator = EpisodicEvaluator(run_per_evaluation=10) evaluation_hook = H.EvaluationHook(eval_env, evaluator, timing=args.eval_timing, writer=writer) @@ -200,28 +204,28 @@ def run_training(args): dataset_path = args.dataset_path if args.dataset_path is not None else guess_dataset_path(args.env) context_length = args.context_length if args.context_length is not None else get_context_length(args.env) - dataset, max_timesteps = load_dataset(dataset_path, - args.buffer_size, - context_length, - args.trajectories_per_buffer) + dataset, max_timesteps = load_dataset(dataset_path, args.buffer_size, context_length, args.trajectories_per_buffer) final_tokens = 2 * len(dataset) * context_length * 3 target_return = args.target_return if args.target_return is not None else get_target_return(args.env) batch_size = args.batch_size if args.batch_size is not None else get_batch_size(args.env) - config = A.DecisionTransformerConfig(gpu_id=args.gpu, - context_length=context_length, - max_timesteps=max_timesteps, - batch_size=batch_size, - target_return=target_return) + config = A.DecisionTransformerConfig( + gpu_id=args.gpu, + context_length=context_length, + max_timesteps=max_timesteps, + batch_size=batch_size, + target_return=target_return, + ) env_info = EnvironmentInfo.from_env(eval_env) decision_transformer = A.DecisionTransformer( env_info, config=config, transformer_wd_solver_builder=AtariDecaySolverBuilder(), - lr_scheduler_builder=AtariLearningRateSchedulerBuilder(args.warmup_tokens, final_tokens)) + lr_scheduler_builder=AtariLearningRateSchedulerBuilder(args.warmup_tokens, final_tokens), + ) decision_transformer.set_hooks(hooks=[epoch_num_hook, save_snapshot_hook, evaluation_hook]) - print(f'total epochs: {args.total_epochs}') + print(f"total epochs: {args.total_epochs}") # decision transformer runs 1 epoch per iteration decision_transformer.train(dataset, total_iterations=args.total_epochs) @@ -230,43 +234,44 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError('Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) - config = {'gpu_id': args.gpu} + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_atari_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) + config = {"gpu_id": args.gpu} decision_transformer = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(decision_transformer, A.DecisionTransformer): - raise ValueError('Loaded snapshot is not trained 
with DecisionTransformer!') + raise ValueError("Loaded snapshot is not trained with DecisionTransformer!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) returns = evaluator(decision_transformer, eval_env) mean = np.mean(returns) std_dev = np.std(returns) median = np.median(returns) - logger.info('Evaluation results. mean: {} +/- std: {}, median: {}'.format(mean, std_dev, median)) + logger.info("Evaluation results. mean: {} +/- std: {}, median: {}".format(mean, std_dev, median)) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4') - parser.add_argument('--dataset-path', type=str, default=None) - parser.add_argument('--save-dir', type=str, default="") - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--total-epochs', type=int, default=5) - parser.add_argument('--trajectories-per-buffer', type=int, default=10) - parser.add_argument('--buffer-size', type=int, default=500000) - parser.add_argument('--batch-size', type=int, default=None) - parser.add_argument('--context-length', type=int, default=None) - parser.add_argument('--warmup-tokens', type=int, default=512*20) - parser.add_argument('--save_timing', type=int, default=1) - parser.add_argument('--eval_timing', type=int, default=1) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--target-return', type=int, default=None) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") + parser.add_argument("--dataset-path", type=str, default=None) + parser.add_argument("--save-dir", type=str, default="") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--total-epochs", type=int, default=5) + parser.add_argument("--trajectories-per-buffer", type=int, default=10) + parser.add_argument("--buffer-size", type=int, default=500000) + parser.add_argument("--batch-size", type=int, default=None) + parser.add_argument("--context-length", type=int, default=None) + parser.add_argument("--warmup-tokens", type=int, default=512 * 20) + parser.add_argument("--save_timing", type=int, default=1) + parser.add_argument("--eval_timing", type=int, default=1) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--target-return", type=int, default=None) + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -276,5 +281,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/atari/dqn/dqn_reproduction.py b/reproductions/algorithms/atari/dqn/dqn_reproduction.py index 973ecfd6..3a3c07d0 100644 --- a/reproductions/algorithms/atari/dqn/dqn_reproduction.py +++ b/reproductions/algorithms/atari/dqn/dqn_reproduction.py @@ -34,14 +34,15 @@ def __call__(self, env_info, algorithm_config, **kwargs): def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if 
args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) writer = FileWriter(outdir, "evaluation_result") - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render, - use_gymnasium=args.use_gymnasium) + eval_env = build_atari_env( + args.env, test=True, seed=args.seed + 100, render=args.render, use_gymnasium=args.use_gymnasium + ) evaluator = TimestepEvaluator(num_timesteps=125000) evaluation_hook = H.EvaluationHook(eval_env, evaluator, timing=args.eval_timing, writer=writer) @@ -51,9 +52,7 @@ def run_training(args): train_env = build_atari_env(args.env, seed=args.seed, use_gymnasium=args.use_gymnasium) config = A.DQNConfig(gpu_id=args.gpu) - dqn = A.DQN(train_env, - config=config, - replay_buffer_builder=MemoryEfficientBufferBuilder()) + dqn = A.DQN(train_env, config=config, replay_buffer_builder=MemoryEfficientBufferBuilder()) dqn.set_hooks(hooks=[iteration_num_hook, save_snapshot_hook, evaluation_hook]) dqn.train(train_env, total_iterations=args.total_iterations) @@ -64,38 +63,37 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_atari_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) config = A.DQNConfig(gpu_id=args.gpu) dqn = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(dqn, A.DQN): - raise ValueError('Loaded snapshot is not trained with DQN!') + raise ValueError("Loaded snapshot is not trained with DQN!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) returns = evaluator(dqn, eval_env) mean = np.mean(returns) std_dev = np.std(returns) median = np.median(returns) - logger.info('Evaluation results. mean: {} +/- std: {}, median: {}'.format(mean, std_dev, median)) + logger.info("Evaluation results. 
mean: {} +/- std: {}, median: {}".format(mean, std_dev, median)) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, - default='BreakoutNoFrameskip-v4') - parser.add_argument('--save-dir', type=str, default="") - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=50000000) - parser.add_argument('--save_timing', type=int, default=250000) - parser.add_argument('--eval_timing', type=int, default=250000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") + parser.add_argument("--save-dir", type=str, default="") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=50000000) + parser.add_argument("--save_timing", type=int, default=250000) + parser.add_argument("--eval_timing", type=int, default=250000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -105,5 +103,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/atari/drqn/drqn_reproduction.py b/reproductions/algorithms/atari/drqn/drqn_reproduction.py index 4dd9a3c9..e04f1e79 100644 --- a/reproductions/algorithms/atari/drqn/drqn_reproduction.py +++ b/reproductions/algorithms/atari/drqn/drqn_reproduction.py @@ -34,33 +34,37 @@ def __call__(self, env_info, algorithm_config, **kwargs): def run_training(args): - suffix = '-with-flicker' if args.flicker else '' - outdir = f'{args.env}{suffix}_results/seed-{args.seed}' + suffix = "-with-flicker" if args.flicker else "" + outdir = f"{args.env}{suffix}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) writer = FileWriter(outdir, "evaluation_result") flicker_probability = 0.5 if args.flicker else 0.0 - eval_env = build_atari_env(args.env, - test=True, - seed=args.seed + 100, - render=args.render, - frame_stack=False, - flicker_probability=flicker_probability, - use_gymnasium=args.use_gymnasium) + eval_env = build_atari_env( + args.env, + test=True, + seed=args.seed + 100, + render=args.render, + frame_stack=False, + flicker_probability=flicker_probability, + use_gymnasium=args.use_gymnasium, + ) evaluator = EpisodicEvaluator(run_per_evaluation=10) evaluation_hook = H.EvaluationHook(eval_env, evaluator, timing=args.eval_timing, writer=writer) iteration_num_hook = H.IterationNumHook(timing=100) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) - train_env = build_atari_env(args.env, - seed=args.seed, - render=args.render, - frame_stack=False, - flicker_probability=flicker_probability, - use_gymnasium=args.use_gymnasium) + train_env = build_atari_env( + args.env, + seed=args.seed, + render=args.render, + frame_stack=False, + flicker_probability=flicker_probability, 
+ use_gymnasium=args.use_gymnasium, + ) config = A.DRQNConfig(gpu_id=args.gpu) drqn = A.DRQN(train_env, config=config, replay_buffer_builder=MemoryEfficientBufferBuilder()) @@ -74,43 +78,45 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError('Please specify the snapshot dir for showcasing') + raise ValueError("Please specify the snapshot dir for showcasing") flicker_probability = 0.5 if args.flicker else 0.0 - eval_env = build_atari_env(args.env, - test=True, - seed=args.seed + 200, - render=args.render, - frame_stack=False, - flicker_probability=flicker_probability, - use_gymnasium=args.use_gymnasium) + eval_env = build_atari_env( + args.env, + test=True, + seed=args.seed + 200, + render=args.render, + frame_stack=False, + flicker_probability=flicker_probability, + use_gymnasium=args.use_gymnasium, + ) config = A.DRQNConfig(gpu_id=args.gpu) drqn = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(drqn, A.DRQN): - raise ValueError('Loaded snapshot is not trained with DRQN!') + raise ValueError("Loaded snapshot is not trained with DRQN!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) returns = evaluator(drqn, eval_env) mean = np.mean(returns) std_dev = np.std(returns) median = np.median(returns) - logger.info('Evaluation results. mean: {} +/- std: {}, median: {}'.format(mean, std_dev, median)) + logger.info("Evaluation results. mean: {} +/- std: {}, median: {}".format(mean, std_dev, median)) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4') - parser.add_argument('--save-dir', type=str, default="") - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=10000000) - parser.add_argument('--save_timing', type=int, default=50000) - parser.add_argument('--eval_timing', type=int, default=50000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--flicker', action='store_true') - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") + parser.add_argument("--save-dir", type=str, default="") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=10000000) + parser.add_argument("--save_timing", type=int, default=50000) + parser.add_argument("--eval_timing", type=int, default=50000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--flicker", action="store_true") + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -120,5 +126,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/atari/icml2015trpo/icml2015trpo_reproduction.py b/reproductions/algorithms/atari/icml2015trpo/icml2015trpo_reproduction.py index d3f55da9..b1db356e 100644 --- 
a/reproductions/algorithms/atari/icml2015trpo/icml2015trpo_reproduction.py +++ b/reproductions/algorithms/atari/icml2015trpo/icml2015trpo_reproduction.py @@ -24,22 +24,22 @@ def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) writer = FileWriter(outdir, "evaluation_result") - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render, - use_gymnasium=args.use_gymnasium) + eval_env = build_atari_env( + args.env, test=True, seed=args.seed + 100, render=args.render, use_gymnasium=args.use_gymnasium + ) evaluator = EpisodicEvaluator() evaluation_hook = H.EvaluationHook(eval_env, evaluator, timing=args.eval_timing, writer=writer) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=int(1e5)) - train_env = build_atari_env(args.env, seed=args.seed, render=args.render, - use_gymnasium=args.use_gymnasium) + train_env = build_atari_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium) config = A.ICML2015TRPOConfig(gpu_id=args.gpu, gpu_batch_size=args.gpu_batch_size) trpo = A.ICML2015TRPO(train_env, config=config) @@ -54,13 +54,14 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError('Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_atari_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) config = A.ICML2015TRPOConfig(gpu_id=args.gpu) trpo = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(trpo, A.ICML2015TRPO): - raise ValueError('Loaded snapshot is not trained with ICML2015TRPO') + raise ValueError("Loaded snapshot is not trained with ICML2015TRPO") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(trpo, eval_env) @@ -68,19 +69,19 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='PongNoFrameskip-v4') - parser.add_argument('--save-dir', type=str, default="") - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--gpu_batch_size', type=int, default=2500) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=50000000) - parser.add_argument('--save_timing', type=int, default=100000) - parser.add_argument('--eval_timing', type=int, default=100000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--save-dir", type=str, default="") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--gpu_batch_size", type=int, default=2500) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + 
parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=50000000) + parser.add_argument("--save_timing", type=int, default=100000) + parser.add_argument("--eval_timing", type=int, default=100000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -90,5 +91,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/atari/iqn/iqn_reproduction.py b/reproductions/algorithms/atari/iqn/iqn_reproduction.py index 0afbbbf3..8c385d7b 100644 --- a/reproductions/algorithms/atari/iqn/iqn_reproduction.py +++ b/reproductions/algorithms/atari/iqn/iqn_reproduction.py @@ -31,16 +31,21 @@ def build_replay_buffer(self, env_info, algorithm_config, **kwargs): def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render, - use_gymnasium=args.use_gymnasium) + eval_env = build_atari_env( + args.env, test=True, seed=args.seed + 100, render=args.render, use_gymnasium=args.use_gymnasium + ) evaluator = TimestepEvaluator(num_timesteps=125000) - evaluation_hook = H.EvaluationHook(eval_env, evaluator, timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) iteration_num_hook = H.IterationNumHook(timing=100) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) @@ -60,14 +65,14 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_atari_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) config = A.IQNConfig(gpu_id=args.gpu) iqn = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(iqn, A.IQN): - raise ValueError('Loaded snapshot is not trained with IQN!') + raise ValueError("Loaded snapshot is not trained with IQN!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(iqn, eval_env) @@ -75,18 +80,18 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4') - parser.add_argument('--save-dir', type=str, default="") - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=50000000) - parser.add_argument('--save_timing', type=int, default=250000) - parser.add_argument('--eval_timing', type=int, default=250000) - parser.add_argument('--showcase_runs', type=int, default=10) - 
parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") + parser.add_argument("--save-dir", type=str, default="") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=50000000) + parser.add_argument("--save_timing", type=int, default=250000) + parser.add_argument("--eval_timing", type=int, default=250000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -96,5 +101,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/atari/munchausen_dqn/munchausen_dqn_reproduction.py b/reproductions/algorithms/atari/munchausen_dqn/munchausen_dqn_reproduction.py index 8791daee..7c6beaea 100644 --- a/reproductions/algorithms/atari/munchausen_dqn/munchausen_dqn_reproduction.py +++ b/reproductions/algorithms/atari/munchausen_dqn/munchausen_dqn_reproduction.py @@ -34,7 +34,7 @@ def build_replay_buffer(self, env_info, algorithm_config, **kwargs): def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) @@ -59,38 +59,37 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError('Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_atari_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) config = A.MunchausenDQNConfig(gpu_id=args.gpu) m_dqn = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(m_dqn, A.MunchausenDQN): - raise ValueError('Loaded snapshot is not trained with DQN!') + raise ValueError("Loaded snapshot is not trained with DQN!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) returns = evaluator(m_dqn, eval_env) mean = np.mean(returns) std_dev = np.std(returns) median = np.median(returns) - logger.info('Evaluation results. mean: {} +/- std: {}, median: {}'.format( - mean, std_dev, median)) + logger.info("Evaluation results. 
mean: {} +/- std: {}, median: {}".format(mean, std_dev, median)) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, - default='BreakoutNoFrameskip-v4') - parser.add_argument('--save-dir', type=str, default="") - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=50000000) - parser.add_argument('--save_timing', type=int, default=250000) - parser.add_argument('--eval_timing', type=int, default=250000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") + parser.add_argument("--save-dir", type=str, default="") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=50000000) + parser.add_argument("--save_timing", type=int, default=250000) + parser.add_argument("--eval_timing", type=int, default=250000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -100,5 +99,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/atari/munchausen_iqn/munchausen_iqn_reproduction.py b/reproductions/algorithms/atari/munchausen_iqn/munchausen_iqn_reproduction.py index 3e998f9f..0216200c 100644 --- a/reproductions/algorithms/atari/munchausen_iqn/munchausen_iqn_reproduction.py +++ b/reproductions/algorithms/atari/munchausen_iqn/munchausen_iqn_reproduction.py @@ -34,7 +34,7 @@ def build_replay_buffer(self, env_info, algorithm_config, **kwargs): def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) @@ -59,37 +59,37 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError('Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_atari_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) config = A.MunchausenIQNConfig(gpu_id=args.gpu) m_iqn = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(m_iqn, A.MunchausenIQN): - raise ValueError('Loaded snapshot is not trained with IQN!') + raise ValueError("Loaded snapshot is not trained with IQN!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) returns = evaluator(m_iqn, eval_env) mean = np.mean(returns) std_dev = np.std(returns) median = np.median(returns) - logger.info('Evaluation results. mean: {} +/- std: {}, median: {}'.format(mean, std_dev, median)) + logger.info("Evaluation results. 
mean: {} +/- std: {}, median: {}".format(mean, std_dev, median)) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, - default='BreakoutNoFrameskip-v4') - parser.add_argument('--save-dir', type=str, default="") - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=50000000) - parser.add_argument('--save_timing', type=int, default=250000) - parser.add_argument('--eval_timing', type=int, default=250000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") + parser.add_argument("--save-dir", type=str, default="") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=50000000) + parser.add_argument("--save_timing", type=int, default=250000) + parser.add_argument("--eval_timing", type=int, default=250000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -99,5 +99,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/atari/ppo/ppo_reproduction.py b/reproductions/algorithms/atari/ppo/ppo_reproduction.py index 9a043834..13f90564 100644 --- a/reproductions/algorithms/atari/ppo/ppo_reproduction.py +++ b/reproductions/algorithms/atari/ppo/ppo_reproduction.py @@ -24,29 +24,35 @@ def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render, - use_gymnasium=args.use_gymnasium) + eval_env = build_atari_env( + args.env, test=True, seed=args.seed + 100, render=args.render, use_gymnasium=args.use_gymnasium + ) evaluator = TimestepEvaluator(num_timesteps=125000) evaluation_hook = H.EvaluationHook( - eval_env, evaluator, timing=args.evaluate_timing, writer=W.FileWriter(outdir=outdir, - file_prefix='evaluation_result')) + eval_env, + evaluator, + timing=args.evaluate_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) iteration_num_hook = H.IterationNumHook(timing=100) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) actor_num = 8 train_env = build_atari_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium) - config = A.PPOConfig(gpu_id=args.gpu, - actor_num=actor_num, - total_timesteps=args.total_iterations, - timelimit_as_terminal=True, - seed=args.seed, - preprocess_state=False) + config = A.PPOConfig( + gpu_id=args.gpu, + actor_num=actor_num, + total_timesteps=args.total_iterations, + timelimit_as_terminal=True, + seed=args.seed, + preprocess_state=False, + ) ppo = A.PPO(train_env, config=config) hooks = [iteration_num_hook, 
save_snapshot_hook, evaluation_hook] @@ -60,17 +66,14 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) - config = A.PPOConfig(gpu_id=args.gpu, - timelimit_as_terminal=True, - seed=args.seed, - preprocess_state=False) + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_atari_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) + config = A.PPOConfig(gpu_id=args.gpu, timelimit_as_terminal=True, seed=args.seed, preprocess_state=False) ppo = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(ppo, A.PPO): - raise ValueError('Loaded snapshot is not trained with PPO!') + raise ValueError("Loaded snapshot is not trained with PPO!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(ppo, eval_env) @@ -78,18 +81,18 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4') - parser.add_argument('--save-dir', type=str, default="") - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=10000000) - parser.add_argument('--save_timing', type=int, default=250000) - parser.add_argument('--evaluate_timing', type=int, default=250000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") + parser.add_argument("--save-dir", type=str, default="") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=10000000) + parser.add_argument("--save_timing", type=int, default=250000) + parser.add_argument("--evaluate_timing", type=int, default=250000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -99,5 +102,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/atari/qrdqn/qrdqn_reproduction.py b/reproductions/algorithms/atari/qrdqn/qrdqn_reproduction.py index 0242f4fa..f1ec91be 100644 --- a/reproductions/algorithms/atari/qrdqn/qrdqn_reproduction.py +++ b/reproductions/algorithms/atari/qrdqn/qrdqn_reproduction.py @@ -31,16 +31,21 @@ def __call__(self, env_info, algorithm_config, **kwargs): def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 100, render=args.render, - 
use_gymnasium=args.use_gymnasium) + eval_env = build_atari_env( + args.env, test=True, seed=args.seed + 100, render=args.render, use_gymnasium=args.use_gymnasium + ) evaluator = TimestepEvaluator(num_timesteps=125000) - evaluation_hook = H.EvaluationHook(eval_env, evaluator, timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) iteration_num_hook = H.IterationNumHook(timing=100) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) @@ -60,14 +65,14 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') - eval_env = build_atari_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_atari_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) config = A.QRDQNConfig(gpu_id=args.gpu) qrdqn = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(qrdqn, A.QRDQN): - raise ValueError('Loaded snapshot is not trained with QRDQN!') + raise ValueError("Loaded snapshot is not trained with QRDQN!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(qrdqn, eval_env) @@ -75,18 +80,18 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4') - parser.add_argument('--save-dir', type=str, default="") - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=50000000) - parser.add_argument('--save_timing', type=int, default=250000) - parser.add_argument('--eval_timing', type=int, default=250000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") + parser.add_argument("--save-dir", type=str, default="") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=50000000) + parser.add_argument("--save_timing", type=int, default=250000) + parser.add_argument("--eval_timing", type=int, default=250000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -96,5 +101,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/atari/rainbow/rainbow_reproduction.py b/reproductions/algorithms/atari/rainbow/rainbow_reproduction.py index f6a1c7f2..7ccf4502 100644 --- a/reproductions/algorithms/atari/rainbow/rainbow_reproduction.py +++ 
b/reproductions/algorithms/atari/rainbow/rainbow_reproduction.py @@ -23,8 +23,11 @@ from nnabla_rl.algorithms import RainbowConfig from nnabla_rl.builders import ExplorerBuilder, ModelBuilder, ReplayBufferBuilder from nnabla_rl.environments.environment_info import EnvironmentInfo -from nnabla_rl.models import (RainbowNoDuelValueDistributionFunction, RainbowNoNoisyValueDistributionFunction, - ValueDistributionFunction) +from nnabla_rl.models import ( + RainbowNoDuelValueDistributionFunction, + RainbowNoNoisyValueDistributionFunction, + ValueDistributionFunction, +) from nnabla_rl.utils import serializers from nnabla_rl.utils.evaluator import EpisodicEvaluator, TimestepEvaluator from nnabla_rl.utils.reproductions import build_atari_env, set_global_seed @@ -33,12 +36,14 @@ class MemoryEfficientPrioritizedBufferBuilder(ReplayBufferBuilder): def build_replay_buffer(self, env_info, algorithm_config, **kwargs): # Some of hyper-parameters was taken from: https://github.com/deepmind/dqn_zoo - return RB.ProportionalPrioritizedAtariBuffer(capacity=algorithm_config.replay_buffer_size, - alpha=algorithm_config.alpha, - beta=algorithm_config.beta, - betasteps=algorithm_config.betasteps, - error_clip=(-100, 100), - normalization_method="batch_max") + return RB.ProportionalPrioritizedAtariBuffer( + capacity=algorithm_config.replay_buffer_size, + alpha=algorithm_config.alpha, + beta=algorithm_config.beta, + betasteps=algorithm_config.betasteps, + error_clip=(-100, 100), + normalization_method="batch_max", + ) class MemoryEfficientNonPrioritizedBufferBuilder(ReplayBufferBuilder): @@ -47,29 +52,29 @@ def build_replay_buffer(self, env_info, algorithm_config, **kwargs): class NoDuelValueDistributionFunctionBuilder(ModelBuilder[ValueDistributionFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: RainbowConfig, - **kwargs) -> ValueDistributionFunction: - return RainbowNoDuelValueDistributionFunction(scope_name, - env_info.action_dim, - algorithm_config.num_atoms, - algorithm_config.v_min, - algorithm_config.v_max) + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: RainbowConfig, + **kwargs, + ) -> ValueDistributionFunction: + return RainbowNoDuelValueDistributionFunction( + scope_name, env_info.action_dim, algorithm_config.num_atoms, algorithm_config.v_min, algorithm_config.v_max + ) class NoNoisyValueDistributionFunctionBuilder(ModelBuilder[ValueDistributionFunction]): - def build_model(self, # type: ignore[override] - scope_name: str, - env_info: EnvironmentInfo, - algorithm_config: RainbowConfig, - **kwargs) -> ValueDistributionFunction: - return RainbowNoNoisyValueDistributionFunction(scope_name, - env_info.action_dim, - algorithm_config.num_atoms, - algorithm_config.v_min, - algorithm_config.v_max) + def build_model( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: RainbowConfig, + **kwargs, + ) -> ValueDistributionFunction: + return RainbowNoNoisyValueDistributionFunction( + scope_name, env_info.action_dim, algorithm_config.num_atoms, algorithm_config.v_min, algorithm_config.v_max + ) class EpsilonGreedyExplorerBuilder(ExplorerBuilder): @@ -79,13 +84,14 @@ def build_explorer(self, env_info, algorithm_config, algorithm, **kwargs): initial_step_num=algorithm.iteration_num, initial_epsilon=algorithm_config.initial_epsilon, final_epsilon=algorithm_config.final_epsilon, - 
max_explore_steps=algorithm_config.max_explore_steps + max_explore_steps=algorithm_config.max_explore_steps, ) explorer = EE.LinearDecayEpsilonGreedyExplorer( greedy_action_selector=algorithm._exploration_action_selector, random_action_selector=algorithm._random_action_selector, env_info=env_info, - config=explorer_config) + config=explorer_config, + ) return explorer @@ -96,37 +102,35 @@ def setup_no_double_rainbow(train_env, args): def setup_no_prior_rainbow(train_env, args): config = A.RainbowConfig(gpu_id=args.gpu) - return A.Rainbow(train_env, - config=config, - replay_buffer_builder=MemoryEfficientNonPrioritizedBufferBuilder()) + return A.Rainbow(train_env, config=config, replay_buffer_builder=MemoryEfficientNonPrioritizedBufferBuilder()) def setup_no_duel_rainbow(train_env, args): config = A.RainbowConfig(gpu_id=args.gpu) - return A.Rainbow(train_env, - config=config, - value_distribution_builder=NoDuelValueDistributionFunctionBuilder(), - replay_buffer_builder=MemoryEfficientPrioritizedBufferBuilder()) + return A.Rainbow( + train_env, + config=config, + value_distribution_builder=NoDuelValueDistributionFunctionBuilder(), + replay_buffer_builder=MemoryEfficientPrioritizedBufferBuilder(), + ) def setup_no_n_steps_rainbow(train_env, args): config = A.RainbowConfig(gpu_id=args.gpu, num_steps=1) - return A.Rainbow(train_env, - config=config, - replay_buffer_builder=MemoryEfficientPrioritizedBufferBuilder()) + return A.Rainbow(train_env, config=config, replay_buffer_builder=MemoryEfficientPrioritizedBufferBuilder()) def setup_no_noisy_rainbow(train_env, args): - config = A.RainbowConfig(gpu_id=args.gpu, - initial_epsilon=1.0, - final_epsilon=0.01, - test_epsilon=0.001, - max_explore_steps=250000 // 4) - return A.Rainbow(train_env, - config=config, - value_distribution_builder=NoNoisyValueDistributionFunctionBuilder(), - replay_buffer_builder=MemoryEfficientPrioritizedBufferBuilder(), - explorer_builder=EpsilonGreedyExplorerBuilder()) + config = A.RainbowConfig( + gpu_id=args.gpu, initial_epsilon=1.0, final_epsilon=0.01, test_epsilon=0.001, max_explore_steps=250000 // 4 + ) + return A.Rainbow( + train_env, + config=config, + value_distribution_builder=NoNoisyValueDistributionFunctionBuilder(), + replay_buffer_builder=MemoryEfficientPrioritizedBufferBuilder(), + explorer_builder=EpsilonGreedyExplorerBuilder(), + ) def setup_full_rainbow(train_env, args): @@ -182,34 +186,43 @@ def load_rainbow(env, args): def run_training(args): suffix = suffix_from_algorithm_options(args) - outdir = f'{args.env}{suffix}_results/seed-{args.seed}' + outdir = f"{args.env}{suffix}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) max_frames_per_episode = 30 * 60 * 60 # 30 min * 60 seconds * 60 fps - eval_env = build_atari_env(args.env, - test=True, seed=args.seed + 100, - render=args.render, - max_frames_per_episode=max_frames_per_episode, - use_gymnasium=args.use_gymnasium) + eval_env = build_atari_env( + args.env, + test=True, + seed=args.seed + 100, + render=args.render, + max_frames_per_episode=max_frames_per_episode, + use_gymnasium=args.use_gymnasium, + ) evaluator = TimestepEvaluator(num_timesteps=125000) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) 
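
Note on the no-noisy Rainbow setup above: EpsilonGreedyExplorerBuilder wires a LinearDecayEpsilonGreedyExplorer with initial_epsilon=1.0, final_epsilon=0.01, and max_explore_steps=250000 // 4. A minimal sketch of a linear decay schedule of that shape follows; the exact interpolation and clamping inside nnabla-rl are assumptions here, not spelled out by this patch.

# Illustrative linear epsilon decay matching the no-noisy Rainbow configuration above.
# Clamping at final_epsilon once the schedule is exhausted is an assumption.
def linear_decay_epsilon(step, initial=1.0, final=0.01, max_explore_steps=250000 // 4):
    fraction = min(step / max_explore_steps, 1.0)
    return initial + (final - initial) * fraction


print(linear_decay_epsilon(0))       # 1.0 at the first exploration step
print(linear_decay_epsilon(31250))   # roughly 0.505 halfway through the schedule
print(linear_decay_epsilon(10**7))   # stays at 0.01 after max_explore_steps
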
save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=100) - train_env = build_atari_env(args.env, seed=args.seed, render=args.render, - max_frames_per_episode=max_frames_per_episode, - use_gymnasium=args.use_gymnasium) + train_env = build_atari_env( + args.env, + seed=args.seed, + render=args.render, + max_frames_per_episode=max_frames_per_episode, + use_gymnasium=args.use_gymnasium, + ) rainbow = setup_rainbow(train_env, args) hooks = [iteration_num_hook, save_snapshot_hook, evaluation_hook] rainbow.set_hooks(hooks) - print(f'current Rainbow config: {rainbow._config}') + print(f"current Rainbow config: {rainbow._config}") rainbow.train_online(train_env, total_iterations=args.total_iterations) eval_env.close() @@ -218,18 +231,19 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') + raise ValueError("Please specify the snapshot dir for showcasing") max_frames_per_episode = 30 * 60 * 60 # 30 min * 60 seconds * 60 fps - eval_env = build_atari_env(args.env, - test=True, - seed=args.seed + 200, - render=args.render, - max_frames_per_episode=max_frames_per_episode, - use_gymnasium=args.use_gymnasium) + eval_env = build_atari_env( + args.env, + test=True, + seed=args.seed + 200, + render=args.render, + max_frames_per_episode=max_frames_per_episode, + use_gymnasium=args.use_gymnasium, + ) rainbow = load_rainbow(eval_env, args) if not isinstance(rainbow, A.Rainbow): - raise ValueError('Loaded snapshot is not trained with Rainbow!') + raise ValueError("Loaded snapshot is not trained with Rainbow!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(rainbow, eval_env) @@ -237,27 +251,27 @@ def run_showcase(args): def add_algorithm_options(parser): group = parser.add_mutually_exclusive_group() - group.add_argument('--no-double', action='store_true') - group.add_argument('--no-prior', action='store_true') - group.add_argument('--no-n-steps', action='store_true') - group.add_argument('--no-noisy', action='store_true') - group.add_argument('--no-duel', action='store_true') + group.add_argument("--no-double", action="store_true") + group.add_argument("--no-prior", action="store_true") + group.add_argument("--no-n-steps", action="store_true") + group.add_argument("--no-noisy", action="store_true") + group.add_argument("--no-duel", action="store_true") def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4') - parser.add_argument('--save-dir', type=str, default="") - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=50000000) - parser.add_argument('--save_timing', type=int, default=250000) - parser.add_argument('--eval_timing', type=int, default=250000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") + parser.add_argument("--save-dir", type=str, default="") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + 
parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=50000000) + parser.add_argument("--save_timing", type=int, default=250000) + parser.add_argument("--eval_timing", type=int, default=250000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--use-gymnasium", action="store_true") add_algorithm_options(parser) args = parser.parse_args() @@ -268,5 +282,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/atari/rainbow/summarize_train_all_results.py b/reproductions/algorithms/atari/rainbow/summarize_train_all_results.py index 4e980706..6cbba025 100644 --- a/reproductions/algorithms/atari/rainbow/summarize_train_all_results.py +++ b/reproductions/algorithms/atari/rainbow/summarize_train_all_results.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,25 +18,25 @@ def find_tsvs(result_dir: pathlib.Path): - print(f'result dir {result_dir}') + print(f"result dir {result_dir}") tsvs = [] for f in result_dir.iterdir(): if f.is_dir(): tsvs.extend(find_tsvs(f)) - elif f.name == 'evaluation_result_average_scalar.tsv': + elif f.name == "evaluation_result_average_scalar.tsv": tsvs.append(f) tsvs.sort() return tsvs def extract_label(tsv_path: pathlib.Path): - pattern = r'.*-v4[- ](.*)_results' + pattern = r".*-v4[- ](.*)_results" regex = re.compile(pattern) - env_method = str(tsv_path).split('/')[2] + env_method = str(tsv_path).split("/")[2] print(env_method) result = regex.findall(env_method) if len(result) == 0: - return 'rainbow' + return "rainbow" else: return result[0] @@ -61,16 +61,22 @@ def main(): tsvroot = tsvs[0].parent.parent tsvpaths = [str(tsv) for tsv in tsvs] tsvlabels = labels - print(f'tsvroot: {tsvroot}') - print(f'tsvs: {tsvs}') - print(f'labels: {labels}') + print(f"tsvroot: {tsvroot}") + print(f"tsvs: {tsvs}") + print(f"labels: {labels}") - command = ['plot_result', '--tsvpaths'] + tsvpaths + \ - ['--tsvlabels'] + tsvlabels + \ - ['--no-stddev'] + \ - ['--smooth-k'] + ['10'] + \ - ['--outdir'] + [f'{str(tsvroot)}'] - print(f'command: {command}') + command = ( + ["plot_result", "--tsvpaths"] + + tsvpaths + + ["--tsvlabels"] + + tsvlabels + + ["--no-stddev"] + + ["--smooth-k"] + + ["10"] + + ["--outdir"] + + [f"{str(tsvroot)}"] + ) + print(f"command: {command}") subprocess.run(command) diff --git a/reproductions/algorithms/classic_control/ddp/ddp_pendulum.py b/reproductions/algorithms/classic_control/ddp/ddp_pendulum.py index 091844bd..a72a3477 100644 --- a/reproductions/algorithms/classic_control/ddp/ddp_pendulum.py +++ b/reproductions/algorithms/classic_control/ddp/ddp_pendulum.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -44,8 +44,9 @@ def __init__(self): self.m = 1.0 self.length = 1.0 - def next_state(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, Dict[str, Any]]: + def next_state( + self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False + ) -> Tuple[np.ndarray, Dict[str, Any]]: if batched: raise NotImplementedError # x.shape = (state_dim, 1) and u.shape = (input_dim, 1) @@ -63,18 +64,19 @@ def gradient(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) th = x.flatten()[0] # th := theta Fx = np.zeros((self.state_dim(), self.state_dim())) - Fx[0, 0] = 1. - Fx[0, 1] = 1. * self.dt + Fx[0, 0] = 1.0 + Fx[0, 1] = 1.0 * self.dt Fx[1, 0] = 3 * self.g / (2 * self.length) * np.cos(th) * self.dt - Fx[1, 1] = 1. + Fx[1, 1] = 1.0 Fu = np.zeros((self.state_dim(), self.action_dim())) Fu[1, 0] = (3.0 / (self.m * self.length**2)) * self.dt return Fx, Fu - def hessian(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + def hessian( + self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: state_dim = self.state_dim() action_dim = self.action_dim() Fxx = np.zeros((state_dim, state_dim, state_dim)) @@ -97,7 +99,7 @@ def action_dim(self) -> int: class PendulumCostFunction(CostFunction): def __init__(self): super().__init__() - self._weight_u = 1. + self._weight_u = 1.0 self._weight_th = 2.5 self._weight_thdot = 0.5 @@ -114,18 +116,23 @@ def evaluate( if final_state: return state_cost else: - act_cost = self._weight_u * (u.flatten()**2) + act_cost = self._weight_u * (u.flatten() ** 2) return state_cost + act_cost def gradient( - self, x: np.ndarray, u: Optional[np.ndarray], t: int, final_state: bool = False, batched: bool = False, + self, + x: np.ndarray, + u: Optional[np.ndarray], + t: int, + final_state: bool = False, + batched: bool = False, ) -> Tuple[np.ndarray, Optional[np.ndarray]]: if batched: raise NotImplementedError # x.shape = (state_dim, 1) and u.shape = (input_dim, 1) th = x.flatten()[0] # th := theta thdot = x.flatten()[1] - Cx = np.array([[2 * self._weight_th * self._angle_normalize(th)], [2. 
* self._weight_thdot * thdot]]) + Cx = np.array([[2 * self._weight_th * self._angle_normalize(th)], [2.0 * self._weight_thdot * thdot]]) if final_state: return (Cx, None) else: @@ -181,7 +188,7 @@ def run_control(args): start = time.time() improved_trajectory, trajectory_info = ilqr.compute_trajectory(initial_trajectory) end = time.time() - print(f'optimization time: {end - start} [s]') + print(f"optimization time: {end - start} [s]") next_state, reward, done, info = env.step(improved_trajectory[0][1]) total_reward += reward @@ -198,13 +205,13 @@ def run_control(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--T', type=int, default=25) - parser.add_argument('--num_episodes', type=int, default=10) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--T", type=int, default=25) + parser.add_argument("--num_episodes", type=int, default=10) args = parser.parse_args() run_control(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/classic_control/ilqr/ilqr_acrobot.py b/reproductions/algorithms/classic_control/ilqr/ilqr_acrobot.py index e9b4f99f..93560cb9 100644 --- a/reproductions/algorithms/classic_control/ilqr/ilqr_acrobot.py +++ b/reproductions/algorithms/classic_control/ilqr/ilqr_acrobot.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,9 +33,7 @@ class ContinuousAcrobot(AcrobotEnv): def __init__(self): super().__init__() self.dt = 0.02 - high = np.array( - [np.pi, np.pi, self.MAX_VEL_1, self.MAX_VEL_2], dtype=np.float32 - ) + high = np.array([np.pi, np.pi, self.MAX_VEL_1, self.MAX_VEL_2], dtype=np.float32) low = -high self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32) self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,)) @@ -86,8 +84,9 @@ def __init__(self): self.max_link1_dtheta = 4 * np.pi self.max_link2_dtheta = 9 * np.pi - def next_state(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, Dict[str, Any]]: + def next_state( + self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False + ) -> Tuple[np.ndarray, Dict[str, Any]]: if batched: raise NotImplementedError x = x.flatten() @@ -145,19 +144,23 @@ def non_wrapped_next_state(x, u, t): fx_du = non_wrapped_next_state(x, du, t).flatten() (fx1_du, fx2_du, fx3_du, fx4_du) = fx_du[0], fx_du[1], fx_du[2], fx_du[3] - Fx = np.asarray([[(fx1_dx1 - fx1) / eps, (fx1_dx2 - fx1) / eps, (fx1_dx3 - fx1) / eps, (fx1_dx4 - fx1) / eps], - [(fx2_dx1 - fx2) / eps, (fx2_dx2 - fx2) / eps, (fx2_dx3 - fx2) / eps, (fx2_dx4 - fx2) / eps], - [(fx3_dx1 - fx3) / eps, (fx3_dx2 - fx3) / eps, (fx3_dx3 - fx3) / eps, (fx3_dx4 - fx3) / eps], - [(fx4_dx1 - fx4) / eps, (fx4_dx2 - fx4) / eps, (fx4_dx3 - fx4) / eps, (fx4_dx4 - fx4) / eps]]) - Fu = np.asarray([[(fx1_du - fx1) / eps], - [(fx2_du - fx2) / eps], - [(fx3_du - fx3) / eps], - [(fx4_du - fx4) / eps]]) + Fx = np.asarray( + [ + [(fx1_dx1 - fx1) / eps, (fx1_dx2 - fx1) / eps, (fx1_dx3 - fx1) / eps, (fx1_dx4 - fx1) / eps], + [(fx2_dx1 - fx2) / eps, (fx2_dx2 - fx2) / eps, (fx2_dx3 - fx2) / eps, (fx2_dx4 - fx2) / eps], + [(fx3_dx1 - fx3) / eps, (fx3_dx2 - fx3) / eps, 
(fx3_dx3 - fx3) / eps, (fx3_dx4 - fx3) / eps], + [(fx4_dx1 - fx4) / eps, (fx4_dx2 - fx4) / eps, (fx4_dx3 - fx4) / eps, (fx4_dx4 - fx4) / eps], + ] + ) + Fu = np.asarray( + [[(fx1_du - fx1) / eps], [(fx2_du - fx2) / eps], [(fx3_du - fx3) / eps], [(fx4_du - fx4) / eps]] + ) return Fx, Fu - def hessian(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + def hessian( + self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: raise NotImplementedError def state_dim(self) -> int: @@ -187,8 +190,9 @@ def _dsdt(self, x_and_u): + (self.m1 * self.lc1 + self.m2 * self.l1) * self.g * np.cos(theta1 - np.pi / 2) + phi2 ) - ddtheta2 = (u + d2 / d1 * phi1 - self.m2 * self.l1 * self.lc2 * dtheta1**2 * - np.sin(theta2) - phi2) / (self.m2 * self.lc2**2 + self.I2 - d2**2 / d1) + ddtheta2 = (u + d2 / d1 * phi1 - self.m2 * self.l1 * self.lc2 * dtheta1**2 * np.sin(theta2) - phi2) / ( + self.m2 * self.lc2**2 + self.I2 - d2**2 / d1 + ) ddtheta1 = -(d2 * ddtheta2 + phi1) / d1 return (dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0) @@ -204,10 +208,9 @@ def __init__(self, T): self._x_target = np.asarray([[np.pi], [0.0], [0.0], [0.0]], dtype=np.float32) self.T = T self.Q = np.zeros(shape=(4, 4)) - self.Q_final = np.asarray([[weight_th1, 0, 0, 0], - [0, weight_th2, 0, 0], - [0, 0, weight_th1dot, 0], - [0, 0, 0, weight_th2dot]]) + self.Q_final = np.asarray( + [[weight_th1, 0, 0, 0], [0, weight_th2, 0, 0], [0, 0, weight_th1dot, 0], [0, 0, 0, weight_th2dot]] + ) def evaluate( self, x: np.ndarray, u: Optional[np.ndarray], t: int, final_state: bool = False, batched: bool = False @@ -229,7 +232,7 @@ def evaluate( # x.shape = (state_dim, 1) and u.shape = (input_dim, 1) # error vector state_cost = (x).T.dot(self.Q).dot(x) - act_cost = self._weight_u * (u.flatten()**2) + act_cost = self._weight_u * (u.flatten() ** 2) return state_cost + act_cost def gradient( @@ -305,7 +308,7 @@ def run_control(args): start = time.time() improved_trajectory, trajectory_info = ilqr.compute_trajectory(initial_trajectory) end = time.time() - print(f'optimization time: {end - start} [s]') + print(f"optimization time: {end - start} [s]") u = improved_trajectory[0][1] next_state, reward, done, info = env.step(u) @@ -316,11 +319,11 @@ def run_control(args): theta2 = state[1] / np.pi * 180.0 dtheta1 = state[2] / np.pi * 180.0 dtheta2 = state[3] / np.pi * 180.0 - print(f'control input torque: {u}') - print(f'theta1: {theta1} [deg]') - print(f'theta2: {theta2} [deg]') - print(f'dtheta1: {dtheta1} [deg/s]') - print(f'dtheta2: {dtheta2} [deg/s]') + print(f"control input torque: {u}") + print(f"theta1: {theta1} [deg]") + print(f"theta2: {theta2} [deg]") + print(f"dtheta1: {dtheta1} [deg/s]") + print(f"dtheta2: {dtheta2} [deg/s]") state = next_state state = np.reshape(state, (4, 1)) @@ -332,13 +335,13 @@ def run_control(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--T', type=int, default=100) - parser.add_argument('--num_episodes', type=int, default=25) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--T", type=int, default=100) + parser.add_argument("--num_episodes", type=int, default=25) args = parser.parse_args() run_control(args) -if __name__ == '__main__': +if __name__ == "__main__": main() 
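
The ContinuousAcrobot dynamics above have no closed-form Jacobian, so ilqr_acrobot.py approximates Fx and Fu with forward finite differences, perturbing one state or control coordinate at a time. The same pattern can be written generically as a short sketch; next_state stands for any function returning the next state as a flat array, and eps plays the same role as the perturbation used above.

import numpy as np


def finite_difference_jacobians(next_state, x, u, t, eps=1e-5):
    """Forward-difference approximation of d(next_state)/dx and d(next_state)/du.

    next_state(x, u, t) must return the next state as a 1-D array;
    x and u are 1-D state and control vectors.
    """
    fx = next_state(x, u, t)
    Fx = np.zeros((fx.size, x.size))
    Fu = np.zeros((fx.size, u.size))
    for i in range(x.size):
        dx = x.copy()
        dx[i] += eps
        Fx[:, i] = (next_state(dx, u, t) - fx) / eps
    for j in range(u.size):
        du = u.copy()
        du[j] += eps
        Fu[:, j] = (next_state(x, du, t) - fx) / eps
    return Fx, Fu
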
diff --git a/reproductions/algorithms/classic_control/ilqr/ilqr_pendulum.py b/reproductions/algorithms/classic_control/ilqr/ilqr_pendulum.py index a51e8c44..e454ae6d 100644 --- a/reproductions/algorithms/classic_control/ilqr/ilqr_pendulum.py +++ b/reproductions/algorithms/classic_control/ilqr/ilqr_pendulum.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -44,8 +44,9 @@ def __init__(self): self.m = 1.0 self.length = 1.0 - def next_state(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, Dict[str, Any]]: + def next_state( + self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False + ) -> Tuple[np.ndarray, Dict[str, Any]]: if batched: raise NotImplementedError # x.shape = (state_dim, 1) and u.shape = (input_dim, 1) @@ -63,18 +64,19 @@ def gradient(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) th = x.flatten()[0] # th := theta Fx = np.zeros((self.state_dim(), self.state_dim())) - Fx[0, 0] = 1. - Fx[0, 1] = 1. * self.dt + Fx[0, 0] = 1.0 + Fx[0, 1] = 1.0 * self.dt Fx[1, 0] = 3 * self.g / (2 * self.length) * np.cos(th) * self.dt - Fx[1, 1] = 1. + Fx[1, 1] = 1.0 Fu = np.zeros((self.state_dim(), self.action_dim())) Fu[1, 0] = (3.0 / (self.m * self.length**2)) * self.dt return Fx, Fu - def hessian(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + def hessian( + self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: raise NotImplementedError def state_dim(self) -> int: @@ -87,7 +89,7 @@ def action_dim(self) -> int: class PendulumCostFunction(CostFunction): def __init__(self): super().__init__() - self._weight_u = 1. + self._weight_u = 1.0 self._weight_th = 2.5 self._weight_thdot = 0.5 @@ -104,18 +106,23 @@ def evaluate( if final_state: return state_cost else: - act_cost = self._weight_u * (u.flatten()**2) + act_cost = self._weight_u * (u.flatten() ** 2) return state_cost + act_cost def gradient( - self, x: np.ndarray, u: Optional[np.ndarray], t: int, final_state: bool = False, batched: bool = False, + self, + x: np.ndarray, + u: Optional[np.ndarray], + t: int, + final_state: bool = False, + batched: bool = False, ) -> Tuple[np.ndarray, Optional[np.ndarray]]: if batched: raise NotImplementedError # x.shape = (state_dim, 1) and u.shape = (input_dim, 1) th = x.flatten()[0] # th := theta thdot = x.flatten()[1] - Cx = np.array([[2 * self._weight_th * self._angle_normalize(th)], [2. 
* self._weight_thdot * thdot]]) + Cx = np.array([[2 * self._weight_th * self._angle_normalize(th)], [2.0 * self._weight_thdot * thdot]]) if final_state: return (Cx, None) else: @@ -171,9 +178,9 @@ def run_control(args): start = time.time() improved_trajectory, trajectory_info = ilqr.compute_trajectory(initial_trajectory) end = time.time() - print(f'optimization time: {end - start} [s]') + print(f"optimization time: {end - start} [s]") - u = improved_trajectory[0][1].reshape((1, )) + u = improved_trajectory[0][1].reshape((1,)) next_state, reward, done, *_ = env.step(u) total_reward += reward @@ -189,13 +196,13 @@ def run_control(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--T', type=int, default=25) - parser.add_argument('--num_episodes', type=int, default=10) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--T", type=int, default=25) + parser.add_argument("--num_episodes", type=int, default=10) args = parser.parse_args() run_control(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/classic_control/lqr/lqr_cartpole.py b/reproductions/algorithms/classic_control/lqr/lqr_cartpole.py index b632fff7..ac6456c5 100644 --- a/reproductions/algorithms/classic_control/lqr/lqr_cartpole.py +++ b/reproductions/algorithms/classic_control/lqr/lqr_cartpole.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -44,11 +44,9 @@ def step(self, action): # For the interested reader: # https://coneural.org/florian/papers/05_cart_pole.pdf - temp = ( - force + self.polemass_length * theta_dot ** 2 * sintheta - ) / self.total_mass + temp = (force + self.polemass_length * theta_dot**2 * sintheta) / self.total_mass thetaacc = (self.gravity * sintheta - costheta * temp) / ( - self.length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass) + self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass) ) xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass @@ -104,33 +102,34 @@ def __init__(self): self.force_mag = 10.0 self.tau = 0.02 # seconds between state updates self._A = np.zeros((4, 4)) - self._A[0, 1] = 1. - self._A[1, 2] = - self.polemass_length / self.total_mass - self._A[2, 3] = 1. - denom = self.length * (4. / 3. - (self.masspole / self.total_mass)) + self._A[0, 1] = 1.0 + self._A[1, 2] = -self.polemass_length / self.total_mass + self._A[2, 3] = 1.0 + denom = self.length * (4.0 / 3.0 - (self.masspole / self.total_mass)) self._A[3, 2] = self.gravity / denom self._A = self._A * self.tau + np.eye(4) self._B = np.zeros((4, 1)) - self._B[1, 0] = 1. / self.total_mass - self._B[3, 0] = (- 1. 
/ self.total_mass) / denom + self._B[1, 0] = 1.0 / self.total_mass + self._B[3, 0] = (-1.0 / self.total_mass) / denom self._B = self._B * self.tau - def next_state(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, Dict[str, Any]]: + def next_state( + self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False + ) -> Tuple[np.ndarray, Dict[str, Any]]: if batched: raise NotImplementedError # x.shape = (state_dim, 1) and u.shape = (input_dim, 1) return self._A.dot(x) + self._B.dot(u), {} - def gradient(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, np.ndarray]: + def gradient(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) -> Tuple[np.ndarray, np.ndarray]: if batched: raise NotImplementedError return self._A, self._B - def hessian(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + def hessian( + self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: raise NotImplementedError def state_dim(self) -> int: @@ -156,7 +155,7 @@ def evaluate( return x.T.dot(self._Q).dot(x) else: # Assuming that target state is zero - return x.T.dot(self._Q).dot(x) + 2.0*x.T.dot(self._F).dot(u) + u.T.dot(self._R).dot(u) + return x.T.dot(self._Q).dot(x) + 2.0 * x.T.dot(self._F).dot(u) + u.T.dot(self._R).dot(u) def gradient( self, x: np.ndarray, u: Optional[np.ndarray], t: int, final_state: bool = False, batched: bool = False @@ -208,7 +207,7 @@ def run_control(args): start = time.time() improved_trajectory, trajectory_info = lqr.compute_trajectory(initial_trajectory) end = time.time() - print(f'optimization time: {end - start} [s]') + print(f"optimization time: {end - start} [s]") next_state, reward, done, info = env.step(improved_trajectory[0][1].flatten().astype(np.float32)) total_reward += reward @@ -225,13 +224,13 @@ def run_control(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--T', type=int, default=50) - parser.add_argument('--num_episodes', type=int, default=10) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--T", type=int, default=50) + parser.add_argument("--num_episodes", type=int, default=10) args = parser.parse_args() run_control(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/classic_control/mppi/mppi_pendulum.py b/reproductions/algorithms/classic_control/mppi/mppi_pendulum.py index c059acca..ed3af3df 100644 --- a/reproductions/algorithms/classic_control/mppi/mppi_pendulum.py +++ b/reproductions/algorithms/classic_control/mppi/mppi_pendulum.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
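
Stepping back to the lqr_cartpole.py hunks above: they build the discrete-time model by Euler-discretizing the continuous-time linearization around the upright equilibrium (A_d = I + A_c * tau, B_d = B_c * tau). A standalone sketch of that construction follows; the physical constants are passed in with gym's usual CartPole defaults, which are an assumption rather than part of this patch.

import numpy as np


def discretized_cartpole(gravity=9.8, masscart=1.0, masspole=0.1, length=0.5, tau=0.02):
    """Linearized cart-pole model x_{t+1} = A x_t + B u_t, Euler-discretized."""
    total_mass = masscart + masspole
    polemass_length = masspole * length
    denom = length * (4.0 / 3.0 - masspole / total_mass)

    A = np.zeros((4, 4))
    A[0, 1] = 1.0
    A[1, 2] = -polemass_length / total_mass
    A[2, 3] = 1.0
    A[3, 2] = gravity / denom
    A = A * tau + np.eye(4)  # A_d = I + A_c * tau

    B = np.zeros((4, 1))
    B[1, 0] = 1.0 / total_mass
    B[3, 0] = (-1.0 / total_mass) / denom
    B = B * tau  # B_d = B_c * tau
    return A, B


A, B = discretized_cartpole()
x = np.zeros((4, 1))        # assumed ordering: cart position, cart velocity, pole angle, pole angular velocity
u = np.array([[1.0]])       # a unit force input
print(A.dot(x) + B.dot(u))  # one-step prediction, the same A x + B u returned by next_state above
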
@@ -65,8 +65,9 @@ def __init__(self): self.m = 1.0 self.length = 1.0 - def next_state(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, Dict[str, Any]]: + def next_state( + self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False + ) -> Tuple[np.ndarray, Dict[str, Any]]: # x.shape = (state_dim, 1) and u.shape = (input_dim, 1) if batched: th, thdot = np.split(x, x.shape[-1], axis=-1) @@ -88,8 +89,9 @@ def next_state(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False def gradient(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) -> Tuple[np.ndarray, np.ndarray]: raise NotImplementedError - def hessian(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + def hessian( + self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: raise NotImplementedError def state_dim(self) -> int: @@ -108,11 +110,9 @@ def __init__(self): self._weight_th = 2.0 self._weight_thdot = 0.5 - def evaluate(self, x: np.ndarray, - u: Optional[np.ndarray], - t: int, - final_state: bool = False, - batched: bool = False) -> np.ndarray: + def evaluate( + self, x: np.ndarray, u: Optional[np.ndarray], t: int, final_state: bool = False, batched: bool = False + ) -> np.ndarray: # x.shape = (state_dim, 1) and u.shape = (input_dim, 1) if batched: th, thdot = np.split(x, x.shape[-1], axis=-1) @@ -122,17 +122,14 @@ def evaluate(self, x: np.ndarray, state_cost = self._weight_th * self._angle_normalize(th) ** 2 + self._weight_thdot * thdot**2 return state_cost - def gradient(self, x: np.ndarray, u: Optional[np.ndarray], t: int, final_state: bool = False, - batched: bool = False) -> Tuple[np.ndarray, Optional[np.ndarray]]: + def gradient( + self, x: np.ndarray, u: Optional[np.ndarray], t: int, final_state: bool = False, batched: bool = False + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: raise NotImplementedError - def hessian(self, - x: np.ndarray, - u: Optional[np.ndarray], - t: int, - final_state: bool = False, - batched: bool = False) \ - -> Tuple[np.ndarray, Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: + def hessian( + self, x: np.ndarray, u: Optional[np.ndarray], t: int, final_state: bool = False, batched: bool = False + ) -> Tuple[np.ndarray, Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: raise NotImplementedError def support_batch(self) -> bool: @@ -168,7 +165,7 @@ def on_hook_called(self, algorithm): self._dummy_states[0] = x_next self._control_inputs[0:-1] = self._control_inputs[1:] total_reward += reward - print(f'total reward: {total_reward}') + print(f"total reward: {total_reward}") def compute_initial_trajectory(x0, dynamics, T, u): @@ -199,13 +196,15 @@ def run_control(args): covariance = np.eye(N=env_info.action_dim) * 0.5 known_dynamics = PendulumKnownDynamics() - config = MPPIConfig(gpu_id=args.gpu, - T=args.T, - covariance=covariance, - K=1000, - use_known_dynamics=args.use_known_dynamics, - training_iterations=1000, - dt=known_dynamics.dt) + config = MPPIConfig( + gpu_id=args.gpu, + T=args.T, + covariance=covariance, + K=1000, + use_known_dynamics=args.use_known_dynamics, + training_iterations=1000, + dt=known_dynamics.dt, + ) def normalize_state(x): th, thdot = np.split(x, x.shape[-1], axis=-1) @@ -223,15 +222,15 @@ def normalize_state(x): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--gpu', type=int, default=0) - 
parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--T', type=int, default=20) - parser.add_argument('--num_episodes', type=int, default=10) - parser.add_argument('--use-known-dynamics', action='store_true') + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--T", type=int, default=20) + parser.add_argument("--num_episodes", type=int, default=10) + parser.add_argument("--use-known-dynamics", action="store_true") args = parser.parse_args() run_control(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/delayed_mujoco/demme_sac/demme_sac_reproduction.py b/reproductions/algorithms/delayed_mujoco/demme_sac/demme_sac_reproduction.py index 195ca4ae..2087cc7f 100644 --- a/reproductions/algorithms/delayed_mujoco/demme_sac/demme_sac_reproduction.py +++ b/reproductions/algorithms/delayed_mujoco/demme_sac/demme_sac_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,57 +26,58 @@ def select_start_timesteps(env_name): - if env_name in ['DelayedAnt-v1', 'DelayedHalfCheetah-v1']: + if env_name in ["DelayedAnt-v1", "DelayedHalfCheetah-v1"]: timesteps = 10000 else: timesteps = 1000 - print(f'Selected start timesteps: {timesteps}') + print(f"Selected start timesteps: {timesteps}") return timesteps def select_total_iterations(env_name): - if env_name in ['DelayedHpper-v1']: + if env_name in ["DelayedHpper-v1"]: total_iterations = 3000000 else: total_iterations = 5000000 - print(f'Selected total iterations: {total_iterations}') + print(f"Selected total iterations: {total_iterations}") return total_iterations def select_alpha_pi(env_name): alpha_pi = 0.2 - print(f'Selected alpha_pi: {alpha_pi}') + print(f"Selected alpha_pi: {alpha_pi}") return alpha_pi def select_alpha_q(env_name): - if env_name in ['DelayedHopper-v1']: + if env_name in ["DelayedHopper-v1"]: alpha_q = 2.0 - elif env_name in ['DelayedHalfCheetah-v1']: + elif env_name in ["DelayedHalfCheetah-v1"]: alpha_q = 2.0 - elif env_name in ['DelayedWalker2d-v1']: + elif env_name in ["DelayedWalker2d-v1"]: alpha_q = 2.0 - elif env_name in ['DelayedAnt-v1']: + elif env_name in ["DelayedAnt-v1"]: alpha_q = 0.1 else: alpha_q = 1.0 - print(f'Selected alpha_q: {alpha_q}') + print(f"Selected alpha_q: {alpha_q}") return alpha_q def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, - file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=100) @@ -85,10 +86,7 @@ def run_training(args): timesteps = 
select_start_timesteps(args.env) alpha_pi = select_alpha_pi(args.env) alpha_q = select_alpha_q(args.env) - config = A.DEMMESACConfig(gpu_id=args.gpu, - start_timesteps=timesteps, - alpha_pi=alpha_pi, - alpha_q=alpha_q) + config = A.DEMMESACConfig(gpu_id=args.gpu, start_timesteps=timesteps, alpha_pi=alpha_pi, alpha_q=alpha_q) demme_sac = A.DEMMESAC(train_env, config=config) hooks = [iteration_num_hook, save_snapshot_hook, evaluation_hook] @@ -103,13 +101,12 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') + raise ValueError("Please specify the snapshot dir for showcasing") eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render) config = A.DEMMESACConfig(gpu_id=args.gpu) demme_sac = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(demme_sac, A.DEMMESAC): - raise ValueError('Loaded snapshot is not trained with DEMMESAC!') + raise ValueError("Loaded snapshot is not trained with DEMMESAC!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(demme_sac, eval_env) @@ -117,17 +114,17 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='DelayedAnt-v1') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=None) - parser.add_argument('--save_timing', type=int, default=5000) - parser.add_argument('--eval_timing', type=int, default=5000) - parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument("--env", type=str, default="DelayedAnt-v1") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=None) + parser.add_argument("--save_timing", type=int, default=5000) + parser.add_argument("--eval_timing", type=int, default=5000) + parser.add_argument("--showcase_runs", type=int, default=10) args = parser.parse_args() @@ -137,5 +134,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/delayed_mujoco/environment/delayed_mujoco/__init__.py b/reproductions/algorithms/delayed_mujoco/environment/delayed_mujoco/__init__.py index 2eb1436b..9b72a06d 100644 --- a/reproductions/algorithms/delayed_mujoco/environment/delayed_mujoco/__init__.py +++ b/reproductions/algorithms/delayed_mujoco/environment/delayed_mujoco/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
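
The DEMME-SAC hunks above, like the other reproduction scripts in this diff, share one showcase flow: rebuild the evaluation environment, restore the algorithm with `serializers.load_snapshot`, type-check the result, and run an `EpisodicEvaluator`. A condensed sketch of that flow follows; the snapshot path is a placeholder, and the `build_mujoco_env` import path is assumed from the repository's conventions since the import block is not part of the shown hunks.

```python
import nnabla_rl.algorithms as A
from nnabla_rl.utils import serializers
from nnabla_rl.utils.evaluator import EpisodicEvaluator
from nnabla_rl.utils.reproductions import build_mujoco_env

import delayed_mujoco  # noqa: F401  # registers the Delayed*-v1 environments (see the hunks below)

# Rebuild the evaluation environment the same way run_showcase() does above.
eval_env = build_mujoco_env("DelayedAnt-v1", test=True, seed=200, render=False)

# Restore the trained algorithm from a saved snapshot directory (placeholder path).
config = A.DEMMESACConfig(gpu_id=0)
demme_sac = serializers.load_snapshot(
    "./DelayedAnt-v1_results/seed-0/iteration-5000000", eval_env, algorithm_kwargs={"config": config}
)
if not isinstance(demme_sac, A.DEMMESAC):
    raise ValueError("Loaded snapshot is not trained with DEMMESAC!")

# Run a fixed number of evaluation episodes with the restored policy.
evaluator = EpisodicEvaluator(run_per_evaluation=10)
evaluator(demme_sac, eval_env)
```
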
@@ -15,28 +15,28 @@ from gym.envs.registration import register register( - id='DelayedHalfCheetah-v1', - entry_point='delayed_mujoco.delayed_mujoco:DelayedHalfCheetahEnv', + id="DelayedHalfCheetah-v1", + entry_point="delayed_mujoco.delayed_mujoco:DelayedHalfCheetahEnv", max_episode_steps=1000, reward_threshold=4800.0, ) register( - id='DelayedHopper-v1', - entry_point='delayed_mujoco.delayed_mujoco:DelayedHopperEnv', + id="DelayedHopper-v1", + entry_point="delayed_mujoco.delayed_mujoco:DelayedHopperEnv", max_episode_steps=1000, reward_threshold=3800.0, ) register( - id='DelayedWalker2d-v1', + id="DelayedWalker2d-v1", max_episode_steps=1000, - entry_point='delayed_mujoco.delayed_mujoco:DelayedWalker2dEnv', + entry_point="delayed_mujoco.delayed_mujoco:DelayedWalker2dEnv", ) register( - id='DelayedAnt-v1', - entry_point='delayed_mujoco.delayed_mujoco:DelayedAntEnv', + id="DelayedAnt-v1", + entry_point="delayed_mujoco.delayed_mujoco:DelayedAntEnv", max_episode_steps=1000, reward_threshold=6000.0, ) diff --git a/reproductions/algorithms/delayed_mujoco/environment/setup.py b/reproductions/algorithms/delayed_mujoco/environment/setup.py index 9deb05f5..b4d159ff 100644 --- a/reproductions/algorithms/delayed_mujoco/environment/setup.py +++ b/reproductions/algorithms/delayed_mujoco/environment/setup.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,4 +14,4 @@ from setuptools import setup -setup(name='delayed_mujoco', version='0.0.1', install_requires=['gym', 'mujoco-py']) +setup(name="delayed_mujoco", version="0.0.1", install_requires=["gym", "mujoco-py"]) diff --git a/reproductions/algorithms/delayed_mujoco/icml2018sac/icml2018sac_reproduction.py b/reproductions/algorithms/delayed_mujoco/icml2018sac/icml2018sac_reproduction.py index b82f31f2..a3e0ef49 100644 --- a/reproductions/algorithms/delayed_mujoco/icml2018sac/icml2018sac_reproduction.py +++ b/reproductions/algorithms/delayed_mujoco/icml2018sac/icml2018sac_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
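
The `register()` calls above, together with the `delayed_mujoco` package defined in the reformatted setup.py, expose the delayed variants through the standard gym factory. A minimal usage sketch, assuming the package has been installed first (for example with `pip install -e .`); the import is needed only for its registration side effect.

```python
import gym

import delayed_mujoco  # noqa: F401  # importing triggers the register() calls above

env = gym.make("DelayedHopper-v1")
observation = env.reset()
observation, reward, done, info = env.step(env.action_space.sample())
env.close()
```
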
@@ -26,42 +26,43 @@ def select_start_timesteps(env_name): - if env_name in ['DelayedAnt-v1', 'DelayedHalfCheetah-v1']: + if env_name in ["DelayedAnt-v1", "DelayedHalfCheetah-v1"]: timesteps = 10000 else: timesteps = 1000 - print(f'Selected start timesteps: {timesteps}') + print(f"Selected start timesteps: {timesteps}") return timesteps def select_total_iterations(env_name): - if env_name in ['DelayedHopper-v1']: + if env_name in ["DelayedHopper-v1"]: total_iterations = 3000000 else: total_iterations = 5000000 - print(f'Selected total iterations: {total_iterations}') + print(f"Selected total iterations: {total_iterations}") return total_iterations def select_reward_scalar(env_name): scalar = 5.0 - print(f'Selected reward scalar: {scalar}') + print(f"Selected reward scalar: {scalar}") return scalar def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, - file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=100) @@ -69,9 +70,7 @@ def run_training(args): train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render) timesteps = select_start_timesteps(args.env) reward_scalar = select_reward_scalar(args.env) - config = A.ICML2018SACConfig(gpu_id=args.gpu, - start_timesteps=timesteps, - reward_scalar=reward_scalar) + config = A.ICML2018SACConfig(gpu_id=args.gpu, start_timesteps=timesteps, reward_scalar=reward_scalar) icml2018sac = A.ICML2018SAC(train_env, config=config) hooks = [iteration_num_hook, save_snapshot_hook, evaluation_hook] @@ -86,13 +85,12 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') + raise ValueError("Please specify the snapshot dir for showcasing") eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render) config = A.ICML2018SACConfig(gpu_id=args.gpu) icml2018sac = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(icml2018sac, A.ICML2018SAC): - raise ValueError('Loaded snapshot is not trained with ICML2018SAC!') + raise ValueError("Loaded snapshot is not trained with ICML2018SAC!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(icml2018sac, eval_env) @@ -100,17 +98,17 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='DelayedAnt-v1') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=None) - parser.add_argument('--save_timing', type=int, 
default=5000) - parser.add_argument('--eval_timing', type=int, default=5000) - parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument("--env", type=str, default="DelayedAnt-v1") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=None) + parser.add_argument("--save_timing", type=int, default=5000) + parser.add_argument("--eval_timing", type=int, default=5000) + parser.add_argument("--showcase_runs", type=int, default=10) args = parser.parse_args() @@ -120,5 +118,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/delayed_mujoco/mme_sac/mme_sac_reproduction.py b/reproductions/algorithms/delayed_mujoco/mme_sac/mme_sac_reproduction.py index 87c0484a..235dc89e 100644 --- a/reproductions/algorithms/delayed_mujoco/mme_sac/mme_sac_reproduction.py +++ b/reproductions/algorithms/delayed_mujoco/mme_sac/mme_sac_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,57 +26,58 @@ def select_start_timesteps(env_name): - if env_name in ['DelayedAnt-v1', 'DelayedHalfCheetah-v1']: + if env_name in ["DelayedAnt-v1", "DelayedHalfCheetah-v1"]: timesteps = 10000 else: timesteps = 1000 - print(f'Selected start timesteps: {timesteps}') + print(f"Selected start timesteps: {timesteps}") return timesteps def select_total_iterations(env_name): - if env_name in ['DelayedHpper-v1']: + if env_name in ["DelayedHpper-v1"]: total_iterations = 3000000 else: total_iterations = 5000000 - print(f'Selected total iterations: {total_iterations}') + print(f"Selected total iterations: {total_iterations}") return total_iterations def select_alpha_pi(env_name): alpha_pi = 0.2 - print(f'Selected alpha_pi: {alpha_pi}') + print(f"Selected alpha_pi: {alpha_pi}") return alpha_pi def select_alpha_q(env_name): - if env_name in ['DelayedHopper-v1']: + if env_name in ["DelayedHopper-v1"]: alpha_q = 1.0 - elif env_name in ['DelayedHalfCheetah-v1']: + elif env_name in ["DelayedHalfCheetah-v1"]: alpha_q = 2.0 - elif env_name in ['DelayedWalker2d-v1']: + elif env_name in ["DelayedWalker2d-v1"]: alpha_q = 0.5 - elif env_name in ['DelayedAnt-v1']: + elif env_name in ["DelayedAnt-v1"]: alpha_q = 0.2 else: alpha_q = 1.0 - print(f'Selected alpha_q: {alpha_q}') + print(f"Selected alpha_q: {alpha_q}") return alpha_q def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, - file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) 
save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=100) @@ -85,10 +86,7 @@ def run_training(args): timesteps = select_start_timesteps(args.env) alpha_pi = select_alpha_pi(args.env) alpha_q = select_alpha_q(args.env) - config = A.MMESACConfig(gpu_id=args.gpu, - start_timesteps=timesteps, - alpha_pi=alpha_pi, - alpha_q=alpha_q) + config = A.MMESACConfig(gpu_id=args.gpu, start_timesteps=timesteps, alpha_pi=alpha_pi, alpha_q=alpha_q) mme_sac = A.MMESAC(train_env, config=config) hooks = [iteration_num_hook, save_snapshot_hook, evaluation_hook] @@ -103,13 +101,12 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') + raise ValueError("Please specify the snapshot dir for showcasing") eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render) config = A.MMESACConfig(gpu_id=args.gpu) mme_sac = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(mme_sac, A.MMESAC): - raise ValueError('Loaded snapshot is not trained with MMESAC!') + raise ValueError("Loaded snapshot is not trained with MMESAC!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(mme_sac, eval_env) @@ -117,17 +114,17 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='DelayedAnt-v1') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=None) - parser.add_argument('--save_timing', type=int, default=5000) - parser.add_argument('--eval_timing', type=int, default=5000) - parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument("--env", type=str, default="DelayedAnt-v1") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=None) + parser.add_argument("--save_timing", type=int, default=5000) + parser.add_argument("--eval_timing", type=int, default=5000) + parser.add_argument("--showcase_runs", type=int, default=10) args = parser.parse_args() @@ -137,5 +134,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/dm_control/srsac/performance_evaluation.py b/reproductions/algorithms/dm_control/srsac/performance_evaluation.py index 7eb90ce2..d078997c 100644 --- a/reproductions/algorithms/dm_control/srsac/performance_evaluation.py +++ b/reproductions/algorithms/dm_control/srsac/performance_evaluation.py @@ -19,21 +19,21 @@ from rliable import metrics DMC15_TASKS = [ - 'acrobot-swingup', - 'cheetah-run', - 'finger-turn_hard', - 'fish-swim', - 'hopper-hop', - 'hopper-stand', - 'humanoid-run', - 'humanoid-stand', - 'humanoid-walk', - 'pendulum-swingup', - 'quadruped-run', - 'quadruped-walk', - 
'reacher-hard', - 'swimmer-swimmer6', - 'walker-run' + "acrobot-swingup", + "cheetah-run", + "finger-turn_hard", + "fish-swim", + "hopper-hop", + "hopper-stand", + "humanoid-run", + "humanoid-stand", + "humanoid-walk", + "pendulum-swingup", + "quadruped-run", + "quadruped-walk", + "reacher-hard", + "swimmer-swimmer6", + "walker-run", ] @@ -88,33 +88,30 @@ def data_to_matrix(data): def evaluate(args): training_data = load_training_data(args.rootdir) - print(f'training_data: {training_data}') + print(f"training_data: {training_data}") data_matrix = data_to_matrix(training_data) - print(f'data_matrix: {data_matrix}. shape: {data_matrix.shape}') + print(f"data_matrix: {data_matrix}. shape: {data_matrix.shape}") # data should be num_runs x num_tasks - data = {'score': data_matrix} + data = {"score": data_matrix} def aggregate_func(x): - return np.array([metrics.aggregate_iqm(x), - metrics.aggregate_median(x), - metrics.aggregate_mean(x)]) - aggregate_scores, aggregate_cis = rly.get_interval_estimates( - data, aggregate_func, reps=50000 - ) - print(f'aggregate scores: {aggregate_scores}') - print(f'aggregate cis: {aggregate_cis}') + return np.array([metrics.aggregate_iqm(x), metrics.aggregate_median(x), metrics.aggregate_mean(x)]) + + aggregate_scores, aggregate_cis = rly.get_interval_estimates(data, aggregate_func, reps=50000) + print(f"aggregate scores: {aggregate_scores}") + print(f"aggregate cis: {aggregate_cis}") def main(): parser = argparse.ArgumentParser() - parser.add_argument('--rootdir', type=str, required=True) - parser.add_argument('--target-file-name', type=str, default=None) + parser.add_argument("--rootdir", type=str, required=True) + parser.add_argument("--target-file-name", type=str, default=None) args = parser.parse_args() evaluate(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/dm_control/srsac/srsac_reproduction.py b/reproductions/algorithms/dm_control/srsac/srsac_reproduction.py index 014bbd79..95f4cde7 100644 --- a/reproductions/algorithms/dm_control/srsac/srsac_reproduction.py +++ b/reproductions/algorithms/dm_control/srsac/srsac_reproduction.py @@ -24,34 +24,34 @@ def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_dmc_env(args.env, test=True, seed=args.seed + 100) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) iteration_num_hook = H.IterationNumHook(timing=100) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) train_env = build_dmc_env(args.env, seed=args.seed, render=args.render) if args.enable_rnn: - config = A.SRSACConfig(gpu_id=args.gpu, - replay_ratio=args.replay_ratio, - reset_interval=args.reset_interval, - initial_temperature=1.0) + config = A.SRSACConfig( + gpu_id=args.gpu, replay_ratio=args.replay_ratio, reset_interval=args.reset_interval, initial_temperature=1.0 + ) srsac = A.SRSAC(train_env, config=config) else: - config = A.EfficientSRSACConfig(gpu_id=args.gpu, - replay_ratio=args.replay_ratio, - reset_interval=args.reset_interval, - initial_temperature=1.0) + 
config = A.EfficientSRSACConfig( + gpu_id=args.gpu, replay_ratio=args.replay_ratio, reset_interval=args.reset_interval, initial_temperature=1.0 + ) srsac = A.EfficientSRSAC(train_env, config=config) hooks = [iteration_num_hook, save_snapshot_hook, evaluation_hook] @@ -65,19 +65,18 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') + raise ValueError("Please specify the snapshot dir for showcasing") eval_env = build_dmc_env(args.env, test=True, seed=args.seed + 200, render=args.render) if args.enable_rnn: config = A.SRSACConfig(gpu_id=args.gpu) srsac = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(srsac, A.SRSAC): - raise ValueError('Loaded snapshot is not trained with SRSAC!') + raise ValueError("Loaded snapshot is not trained with SRSAC!") else: config = A.EfficientSRSACConfig(gpu_id=args.gpu) srsac = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(srsac, A.EfficientSRSAC): - raise ValueError('Loaded snapshot is not trained with EfficientSRSAC!') + raise ValueError("Loaded snapshot is not trained with EfficientSRSAC!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(srsac, eval_env) @@ -85,25 +84,39 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - choices = ['acrobot-swingup', 'cheetah-run', 'finger-turn_hard', 'fish-swim', 'hopper-hop', - 'hopper-stand', 'humanoid-run', 'humanoid-stand', 'humanoid-walk', 'pendulum-swingup', - 'quadruped-run', 'quadruped-walk', 'reacher-hard', 'swimmer-swimmer6', 'walker-run'] - parser.add_argument('--env', type=str, default='acrobot-swingup', help=f'DMC15 benchmark: {choices}') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=500000) - parser.add_argument('--save_timing', type=int, default=5000) - parser.add_argument('--eval_timing', type=int, default=5000) - parser.add_argument('--showcase_runs', type=int, default=10) + choices = [ + "acrobot-swingup", + "cheetah-run", + "finger-turn_hard", + "fish-swim", + "hopper-hop", + "hopper-stand", + "humanoid-run", + "humanoid-stand", + "humanoid-walk", + "pendulum-swingup", + "quadruped-run", + "quadruped-walk", + "reacher-hard", + "swimmer-swimmer6", + "walker-run", + ] + parser.add_argument("--env", type=str, default="acrobot-swingup", help=f"DMC15 benchmark: {choices}") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=500000) + parser.add_argument("--save_timing", type=int, default=5000) + parser.add_argument("--eval_timing", type=int, default=5000) + parser.add_argument("--showcase_runs", type=int, default=10) # SRSAC algorithm config - parser.add_argument('--replay-ratio', type=int, default=32) - 
parser.add_argument('--reset-interval', type=int, default=2560000) - parser.add_argument('--enable-rnn', action='store_true', help='Turn this option true if you want to use rnn models') + parser.add_argument("--replay-ratio", type=int, default=32) + parser.add_argument("--reset-interval", type=int, default=2560000) + parser.add_argument("--enable-rnn", action="store_true", help="Turn this option true if you want to use rnn models") args = parser.parse_args() @@ -113,5 +126,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/factored_mujoco/sacd/sacd_reproduction.py b/reproductions/algorithms/factored_mujoco/sacd/sacd_reproduction.py index f34d7aee..1d5223cf 100644 --- a/reproductions/algorithms/factored_mujoco/sacd/sacd_reproduction.py +++ b/reproductions/algorithms/factored_mujoco/sacd/sacd_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,13 +24,13 @@ def select_total_iterations(env_name): - if env_name in ['FactoredAntV4NNablaRL-v1', 'FactoredHalfCheetahV4NNablaRL-v1', 'FactoredWalker2dV4NNablaRL-v1']: + if env_name in ["FactoredAntV4NNablaRL-v1", "FactoredHalfCheetahV4NNablaRL-v1", "FactoredWalker2dV4NNablaRL-v1"]: total_iterations = 3000000 - elif env_name in ['FactoredHumanoidV4NNablaRL-v1']: + elif env_name in ["FactoredHumanoidV4NNablaRL-v1"]: total_iterations = 10000000 else: total_iterations = 1000000 - print(f'Selected total iterations: {total_iterations}') + print(f"Selected total iterations: {total_iterations}") return total_iterations @@ -39,24 +39,27 @@ def select_reward_dimension(env_name): def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) iteration_num_hook = H.IterationNumHook(timing=100) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render) - config = A.SACDConfig(gpu_id=args.gpu, fix_temperature=args.fix_temperature, - reward_dimension=select_reward_dimension(args.env)) + config = A.SACDConfig( + gpu_id=args.gpu, fix_temperature=args.fix_temperature, reward_dimension=select_reward_dimension(args.env) + ) sac = A.SACD(train_env, config=config) hooks = [iteration_num_hook, save_snapshot_hook, evaluation_hook] @@ -71,13 +74,12 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') + raise ValueError("Please specify the snapshot dir for showcasing") eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render) config = A.SACDConfig(gpu_id=args.gpu, reward_dimension=select_reward_dimension(args.env)) sacd = 
serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(sacd, A.SACD): - raise ValueError('Loaded snapshot is not trained with SACD!') + raise ValueError("Loaded snapshot is not trained with SACD!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(sacd, eval_env) @@ -85,20 +87,20 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='FactoredAntV4NNablaRL-v1') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=None) - parser.add_argument('--save_timing', type=int, default=5000) - parser.add_argument('--eval_timing', type=int, default=5000) - parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument("--env", type=str, default="FactoredAntV4NNablaRL-v1") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=None) + parser.add_argument("--save_timing", type=int, default=5000) + parser.add_argument("--eval_timing", type=int, default=5000) + parser.add_argument("--showcase_runs", type=int, default=10) # SAC algorithm config - parser.add_argument('--fix-temperature', action='store_true') + parser.add_argument("--fix-temperature", action="store_true") args = parser.parse_args() @@ -108,5 +110,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/hybrid_env/hyar/goal_env_wrapper.py b/reproductions/algorithms/hybrid_env/hyar/goal_env_wrapper.py index 81a2f8f5..fb0a72e4 100644 --- a/reproductions/algorithms/hybrid_env/hyar/goal_env_wrapper.py +++ b/reproductions/algorithms/hybrid_env/hyar/goal_env_wrapper.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,19 +29,20 @@ def __init__(self, env): low = np.zeros(extended_observation_shape) low[:14] = original_observation_space.low - low[14] = -1. - low[15] = -1. + low[14] = -1.0 + low[15] = -1.0 low[16] = -GOAL_WIDTH / 2 high = np.ones(extended_observation_shape) high[:14] = original_observation_space.high - high[14] = 1. - high[15] = 1. 
+ high[14] = 1.0 + high[15] = 1.0 high[16] = GOAL_WIDTH max_steps = 200 - self.observation_space = gym.spaces.Tuple((gym.spaces.Box(low=low, high=high, dtype=np.float32), - gym.spaces.Discrete(max_steps))) + self.observation_space = gym.spaces.Tuple( + (gym.spaces.Box(low=low, high=high, dtype=np.float32), gym.spaces.Discrete(max_steps)) + ) def observation(self, obs): state, steps = obs diff --git a/reproductions/algorithms/hybrid_env/hyar/hyar_reproduction.py b/reproductions/algorithms/hybrid_env/hyar/hyar_reproduction.py index 64916cdc..c73814d1 100644 --- a/reproductions/algorithms/hybrid_env/hyar/hyar_reproduction.py +++ b/reproductions/algorithms/hybrid_env/hyar/hyar_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,8 +22,13 @@ from nnabla_rl.algorithms import HyAR, HyARConfig from nnabla_rl.environments.wrappers import NumpyFloat32Env, ScreenRenderEnv from nnabla_rl.environments.wrappers.common import PrintEpisodeResultEnv -from nnabla_rl.environments.wrappers.hybrid_env import (FlattenActionWrapper, MergeBoxActionWrapper, RemoveStepWrapper, - ScaleActionWrapper, ScaleStateWrapper) +from nnabla_rl.environments.wrappers.hybrid_env import ( + FlattenActionWrapper, + MergeBoxActionWrapper, + RemoveStepWrapper, + ScaleActionWrapper, + ScaleStateWrapper, +) from nnabla_rl.utils import serializers from nnabla_rl.utils.evaluator import EpisodicEvaluator from nnabla_rl.utils.reproductions import print_env_info, set_global_seed @@ -61,7 +66,7 @@ def setup_goal_env(env): def build_env(env_name, test=False, seed=None, render=False, print_episode_result=False): env = gym.make(env_name) - if env_name == 'Goal-v0': + if env_name == "Goal-v0": env = setup_goal_env(env) elif env_name == "Platform-v0": env = setup_platform_env(env) @@ -80,16 +85,18 @@ def build_env(env_name, test=False, seed=None, render=False, print_episode_resul def setup_hyar(env, args): - config = HyARConfig(gpu_id=args.gpu, - learning_rate=3e-4, - batch_size=128, - start_timesteps=128, - train_action_noise_abs=1.0, - train_action_noise_sigma=0.1, - replay_buffer_size=int(1e5), - vae_learning_rate=1e-4, - vae_pretrain_episodes=args.vae_pretrain_episodes, - vae_pretrain_times=args.vae_pretrain_times) + config = HyARConfig( + gpu_id=args.gpu, + learning_rate=3e-4, + batch_size=128, + start_timesteps=128, + train_action_noise_abs=1.0, + train_action_noise_sigma=0.1, + replay_buffer_size=int(1e5), + vae_learning_rate=1e-4, + vae_pretrain_episodes=args.vae_pretrain_episodes, + vae_pretrain_times=args.vae_pretrain_times, + ) return HyAR(env, config=config) @@ -98,17 +105,19 @@ def setup_algorithm(env, args): def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_env(args.env, test=True, seed=args.seed + 100) evaluator = EpisodicEvaluator(run_per_evaluation=100) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) iteration_num_hook = H.IterationNumHook(timing=100) 
save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) @@ -126,12 +135,12 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError('Please specify the snapshot dir for showcasing') + raise ValueError("Please specify the snapshot dir for showcasing") eval_env = build_env(args.env, seed=args.seed + 200, render=args.render) config = HyARConfig(gpu_id=args.gpu) hyar = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(hyar, HyAR): - raise ValueError('Loaded snapshot is not trained with PPO!') + raise ValueError("Loaded snapshot is not trained with PPO!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(hyar, eval_env) @@ -139,19 +148,19 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='Goal-v0') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=300000) - parser.add_argument('--save_timing', type=int, default=5000) - parser.add_argument('--eval_timing', type=int, default=5000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--vae-pretrain-episodes', type=int, default=20000) - parser.add_argument('--vae-pretrain-times', type=int, default=5000) + parser.add_argument("--env", type=str, default="Goal-v0") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=300000) + parser.add_argument("--save_timing", type=int, default=5000) + parser.add_argument("--eval_timing", type=int, default=5000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--vae-pretrain-episodes", type=int, default=20000) + parser.add_argument("--vae-pretrain-times", type=int, default=5000) args = parser.parse_args() @@ -161,5 +170,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/mujoco/atrpo/atrpo_reproduction.py b/reproductions/algorithms/mujoco/atrpo/atrpo_reproduction.py index 99ea954f..a2ce979d 100644 --- a/reproductions/algorithms/mujoco/atrpo/atrpo_reproduction.py +++ b/reproductions/algorithms/mujoco/atrpo/atrpo_reproduction.py @@ -28,18 +28,19 @@ def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, - file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + 
timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=5000) @@ -60,37 +61,37 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') + raise ValueError("Please specify the snapshot dir for showcasing") config = A.ATRPOConfig(gpu_id=args.gpu) - eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) + eval_env = build_mujoco_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) atrpo = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(atrpo, A.ATRPO): - raise ValueError('Loaded snapshot is not trained with ATRPO!') + raise ValueError("Loaded snapshot is not trained with ATRPO!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) returns = evaluator(atrpo, eval_env) mean = np.mean(returns) std_dev = np.std(returns) median = np.median(returns) - logger.info('Evaluation results. mean: {} +/- std: {}, median: {}'.format(mean, std_dev, median)) + logger.info("Evaluation results. mean: {} +/- std: {}, median: {}".format(mean, std_dev, median)) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='Ant-v3') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=10000000) - parser.add_argument('--save_timing', type=int, default=1000000) - parser.add_argument('--eval_timing', type=int, default=50000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="Ant-v3") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=10000000) + parser.add_argument("--save_timing", type=int, default=1000000) + parser.add_argument("--eval_timing", type=int, default=50000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -100,5 +101,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/mujoco/bcq/bcq_reproduction.py b/reproductions/algorithms/mujoco/bcq/bcq_reproduction.py index 0c15edae..b64d544f 100644 --- a/reproductions/algorithms/mujoco/bcq/bcq_reproduction.py +++ b/reproductions/algorithms/mujoco/bcq/bcq_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,18 +27,19 @@ def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, - file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=100) @@ -64,13 +65,12 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') + raise ValueError("Please specify the snapshot dir for showcasing") eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render) config = A.BCQConfig(gpu_id=args.gpu) bcq = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(bcq, A.BCQ): - raise ValueError('Loaded snapshot is not trained with BCQ!') + raise ValueError("Loaded snapshot is not trained with BCQ!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(bcq, eval_env) @@ -78,17 +78,17 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='ant-expert-v0') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=1000000) - parser.add_argument('--save_timing', type=int, default=5000) - parser.add_argument('--eval_timing', type=int, default=5000) - parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument("--env", type=str, default="ant-expert-v0") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=1000000) + parser.add_argument("--save_timing", type=int, default=5000) + parser.add_argument("--eval_timing", type=int, default=5000) + parser.add_argument("--showcase_runs", type=int, default=10) args = parser.parse_args() @@ -98,5 +98,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/mujoco/bear/bear_reproduction.py b/reproductions/algorithms/mujoco/bear/bear_reproduction.py index f9bfb679..19e7b2fb 100644 --- a/reproductions/algorithms/mujoco/bear/bear_reproduction.py +++ b/reproductions/algorithms/mujoco/bear/bear_reproduction.py @@ -1,4 +1,4 @@ -# 
Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,29 +27,30 @@ def select_mmd_sigma(env_name, mmd_kernel): - if mmd_kernel == 'gaussian': + if mmd_kernel == "gaussian": mmd_sigma = 20.0 - elif mmd_kernel == 'laplacian': - mmd_sigma = 20.0 if 'walker2d' in env_name else 10.0 + elif mmd_kernel == "laplacian": + mmd_sigma = 20.0 if "walker2d" in env_name else 10.0 else: - raise ValueError(f'Unknown mmd kernel: {mmd_kernel}') - print(f'selected mmd sigma: {mmd_sigma}') + raise ValueError(f"Unknown mmd kernel: {mmd_kernel}") + print(f"selected mmd sigma: {mmd_sigma}") return mmd_sigma def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, - file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=100) @@ -66,8 +67,7 @@ def run_training(args): config = A.BEARConfig(gpu_id=args.gpu, mmd_sigma=mmd_sigma, mmd_type=args.mmd_kernel) bear = A.BEAR(train_env, config=config) - hooks = [save_snapshot_hook, evaluation_hook, - iteration_num_hook, iteration_state_hook] + hooks = [save_snapshot_hook, evaluation_hook, iteration_num_hook, iteration_state_hook] bear.set_hooks(hooks) bear.train_offline(buffer, total_iterations=args.total_iterations) @@ -78,13 +78,12 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') + raise ValueError("Please specify the snapshot dir for showcasing") eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render) config = A.BEARConfig(gpu_id=args.gpu) bear = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(bear, A.BEAR): - raise ValueError('Loaded snapshot is not trained with BEAR!') + raise ValueError("Loaded snapshot is not trained with BEAR!") evaluator = EpisodicEvaluator(args.showcase_runs) evaluator(bear, eval_env) @@ -92,19 +91,18 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='ant-expert-v0') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--mmd-kernel', type=str, - default="gaussian", choices=["laplacian", "gaussian"]) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=1000000) - parser.add_argument('--save_timing', type=int, default=5000) - parser.add_argument('--eval_timing', type=int, default=5000) - 
parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument("--env", type=str, default="ant-expert-v0") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--mmd-kernel", type=str, default="gaussian", choices=["laplacian", "gaussian"]) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=1000000) + parser.add_argument("--save_timing", type=int, default=5000) + parser.add_argument("--eval_timing", type=int, default=5000) + parser.add_argument("--showcase_runs", type=int, default=10) args = parser.parse_args() @@ -114,5 +112,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/mujoco/ddpg/ddpg_reproduction.py b/reproductions/algorithms/mujoco/ddpg/ddpg_reproduction.py index 1f8c326f..336cfcb9 100644 --- a/reproductions/algorithms/mujoco/ddpg/ddpg_reproduction.py +++ b/reproductions/algorithms/mujoco/ddpg/ddpg_reproduction.py @@ -24,26 +24,28 @@ def select_start_timesteps(env_name): - if env_name in ['Ant-v2', 'HalfCheetah-v2']: + if env_name in ["Ant-v2", "HalfCheetah-v2"]: timesteps = 10000 else: timesteps = 1000 - print(f'Selected start timesteps: {timesteps}') + print(f"Selected start timesteps: {timesteps}") return timesteps def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) iteration_num_hook = H.IterationNumHook(timing=100) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) @@ -64,15 +66,15 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') + raise ValueError("Please specify the snapshot dir for showcasing") - eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, - render=args.render, use_gymnasium=args.use_gymnasium) + eval_env = build_mujoco_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) config = A.DDPGConfig(gpu_id=args.gpu) ddpg = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(ddpg, A.DDPG): - raise ValueError('Loaded snapshot is not trained with DDPG!') + raise ValueError("Loaded snapshot is not trained with DDPG!") evaluator = EpisodicEvaluator(args.showcase_runs) evaluator(ddpg, eval_env) @@ -80,18 +82,18 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='Ant-v2') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - 
parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=1000000) - parser.add_argument('--save_timing', type=int, default=5000) - parser.add_argument('--eval_timing', type=int, default=5000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="Ant-v2") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=1000000) + parser.add_argument("--save_timing", type=int, default=5000) + parser.add_argument("--eval_timing", type=int, default=5000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -101,5 +103,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/mujoco/decision_transformer/compute_expert_normalized_score.py b/reproductions/algorithms/mujoco/decision_transformer/compute_expert_normalized_score.py index dfe7b2c4..25792a3b 100644 --- a/reproductions/algorithms/mujoco/decision_transformer/compute_expert_normalized_score.py +++ b/reproductions/algorithms/mujoco/decision_transformer/compute_expert_normalized_score.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
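
For context, every `run_training()` touched in this diff follows the same hook-based setup that the DDPG hunks above show most compactly: an `EvaluationHook` writing results through a `FileWriter`, a `SaveSnapshotHook`, and an `IterationNumHook`, all attached before training starts. Below is a condensed sketch assembled from those hunks; the module aliases (`A`, `H`, `W`) and the `build_mujoco_env`/`set_global_seed` import paths follow the repository's usual import block, which is not shown in these hunks, and the hyperparameter values are placeholders.

```python
import nnabla_rl.algorithms as A
import nnabla_rl.hooks as H
import nnabla_rl.writers as W
from nnabla_rl.utils.evaluator import EpisodicEvaluator
from nnabla_rl.utils.reproductions import build_mujoco_env, set_global_seed

outdir = "Ant-v2_results/seed-0"
set_global_seed(0)

# Periodic evaluation on a separate environment; results are written under outdir by the FileWriter.
eval_env = build_mujoco_env("Ant-v2", test=True, seed=100)
evaluation_hook = H.EvaluationHook(
    eval_env,
    EpisodicEvaluator(run_per_evaluation=10),
    timing=5000,
    writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"),
)
iteration_num_hook = H.IterationNumHook(timing=100)
save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=5000)

# Train online with the hooks attached.
train_env = build_mujoco_env("Ant-v2", seed=0)
ddpg = A.DDPG(train_env, config=A.DDPGConfig(gpu_id=0, start_timesteps=10000))
ddpg.set_hooks([iteration_num_hook, save_snapshot_hook, evaluation_hook])
ddpg.train(train_env, total_iterations=1000000)
```

The showcase path then reloads the saved snapshot with `serializers.load_snapshot`, as sketched earlier in this section.
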
@@ -26,7 +26,7 @@ def load_histogram_data(path, dtype=float): histogram = [] with open(path) as f: - tsv_reader = reader(f, delimiter='\t') + tsv_reader = reader(f, delimiter="\t") for i, row in enumerate(tsv_reader): if i == 0: continue @@ -55,8 +55,8 @@ def extract_iteration_num_and_returns(histogram_data): returns = [] for i in range(len(histogram_data)): data_row = histogram_data[i] - if 'returns' in data_row[0]: - iteration_nums.append(int(data_row[0].split(' ')[0])) + if "returns" in data_row[0]: + iteration_nums.append(int(data_row[0].split(" ")[0])) scores = data_row[1][0:].astype(float) returns.append(scores) @@ -83,21 +83,21 @@ def create_expert_normalized_score_file(histograms, file_outdir, d4rl_env_name): std_dev = np.std(normalized_r) * 100 scalar_results = {} - scalar_results['mean'] = mean - scalar_results['std_dev'] = std_dev + scalar_results["mean"] = mean + scalar_results["std_dev"] = std_dev writer.write_scalar(i, scalar_results) def to_d4rl_env_name(env_name): - if 'HalfCheetah' in env_name: - task_name = 'halfcheetah' - elif 'Hopper' in env_name: - task_name = 'hopper' - elif 'Walker2d' in env_name: - task_name = 'walker2d' + if "HalfCheetah" in env_name: + task_name = "halfcheetah" + elif "Hopper" in env_name: + task_name = "hopper" + elif "Walker2d" in env_name: + task_name = "walker2d" # fix to medium dataset - return f'{task_name}-medium-v2' + return f"{task_name}-medium-v2" def compile_results(args): @@ -105,12 +105,12 @@ def compile_results(args): histograms = {} histogram_directories = list_all_directory_with(rootdir, args.eval_histogram_filename) - print(f'files: {histogram_directories}') + print(f"files: {histogram_directories}") for directory in histogram_directories: if args.resultdir not in str(directory): continue relative_dir = directory.relative_to(rootdir) - env_name = str(relative_dir).split('/')[1] + env_name = str(relative_dir).split("/")[1] histogram_file = directory / args.eval_histogram_filename print(f"found histogram file of env: {env_name} at: {histogram_file}") if histogram_file.exists(): @@ -125,15 +125,17 @@ def compile_results(args): create_expert_normalized_score_file(histograms, file_outdir, d4rl_env_name) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--outdir', type=str, required=True, help='output directory') - parser.add_argument('--resultdir', type=str, required=True, help='result directory') - parser.add_argument('--eval-histogram-filename', - type=str, - default="evaluation_result_histogram.tsv", - help='eval result(histogram) filename') + parser.add_argument("--outdir", type=str, required=True, help="output directory") + parser.add_argument("--resultdir", type=str, required=True, help="result directory") + parser.add_argument( + "--eval-histogram-filename", + type=str, + default="evaluation_result_histogram.tsv", + help="eval result(histogram) filename", + ) args = parser.parse_args() diff --git a/reproductions/algorithms/mujoco/decision_transformer/decision_transformer_reproduction.py b/reproductions/algorithms/mujoco/decision_transformer/decision_transformer_reproduction.py index e34e45c5..43b49bdd 100755 --- a/reproductions/algorithms/mujoco/decision_transformer/decision_transformer_reproduction.py +++ b/reproductions/algorithms/mujoco/decision_transformer/decision_transformer_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -78,11 +78,10 @@ def build_scheduler(self, env_info, algorithm_config, **kwargs) -> BaseLearningR class MujocoDecisionTransformerBuilder(ModelBuilder): def build_model(self, scope_name, env_info, algorithm_config, **kwargs): - max_timesteps = cast(int, kwargs['max_timesteps']) - return MujocoDecisionTransformer(scope_name, - env_info.action_dim, - max_timestep=max_timesteps, - context_length=algorithm_config.context_length) + max_timesteps = cast(int, kwargs["max_timesteps"]) + return MujocoDecisionTransformer( + scope_name, env_info.action_dim, max_timestep=max_timesteps, context_length=algorithm_config.context_length + ) class MujocoSolverBuilder(SolverBuilder): @@ -95,13 +94,13 @@ def build_solver(self, env_info, algorithm_config, **kwargs) -> nn.solver.Solver def load_d4rl_dataset(env_name, dataset_type): - if 'HalfCheetah' in env_name: - task_name = 'halfcheetah' - elif 'Hopper' in env_name: - task_name = 'hopper' - elif 'Walker2d' in env_name: - task_name = 'walker2d' - d4rl_name = f'{task_name}-{dataset_type}-v2' + if "HalfCheetah" in env_name: + task_name = "halfcheetah" + elif "Hopper" in env_name: + task_name = "hopper" + elif "Walker2d" in env_name: + task_name = "walker2d" + d4rl_name = f"{task_name}-{dataset_type}-v2" d4rl_env = gym.make(d4rl_name) return d4rl_env.get_dataset() @@ -111,18 +110,18 @@ def load_dataset_from_path(dataset_dir): import pathlib def load_data_from_gz(gzfile): - with gzip.open(gzfile, mode='rb') as f: + with gzip.open(gzfile, mode="rb") as f: data = np.load(f, allow_pickle=False) return data dataset = {} dataset_dir = pathlib.Path(dataset_dir) - observation_file = dataset_dir / '$store$_observation_ckpt.0.gz' - action_file = dataset_dir / '$store$_action_ckpt.0.gz' - reward_file = dataset_dir / '$store$_reward_ckpt.0.gz' - terminal_file = dataset_dir / '$store$_terminal_ckpt.0.gz' - next_observation_file = dataset_dir / '$store$_next_observation_ckpt.0.gz' + observation_file = dataset_dir / "$store$_observation_ckpt.0.gz" + action_file = dataset_dir / "$store$_action_ckpt.0.gz" + reward_file = dataset_dir / "$store$_reward_ckpt.0.gz" + terminal_file = dataset_dir / "$store$_terminal_ckpt.0.gz" + next_observation_file = dataset_dir / "$store$_next_observation_ckpt.0.gz" observations = load_data_from_gz(observation_file) actions = load_data_from_gz(action_file) @@ -130,52 +129,52 @@ def load_data_from_gz(gzfile): terminals = load_data_from_gz(terminal_file) next_observations = load_data_from_gz(next_observation_file) - dataset['observations'] = observations - dataset['actions'] = actions - dataset['rewards'] = rewards - dataset['terminals'] = terminals - dataset['next_observations'] = next_observations + dataset["observations"] = observations + dataset["actions"] = actions + dataset["rewards"] = rewards + dataset["terminals"] = terminals + dataset["next_observations"] = next_observations return dataset def compute_state_mean_and_std(d4rl_dataset): - state_mean = np.mean(d4rl_dataset['observations'], axis=0) - state_std = np.std(d4rl_dataset['observations'], axis=0) + 1e-6 + state_mean = np.mean(d4rl_dataset["observations"], axis=0) + state_std = np.std(d4rl_dataset["observations"], axis=0) + 1e-6 return state_mean, state_std def load_dataset(d4rl_dataset, buffer_size, context_length, reward_scale): - use_timeouts = 'timeouts' in d4rl_dataset + use_timeouts = "timeouts" in d4rl_dataset max_possible_trajectories = 
buffer_size // context_length buffer = TrajectoryReplayBuffer(num_trajectories=max_possible_trajectories) - dataset_size = d4rl_dataset['rewards'].shape[0] + dataset_size = d4rl_dataset["rewards"].shape[0] max_timesteps = 1 episode_step = 0 start_index = 0 state_mean, state_std = compute_state_mean_and_std(d4rl_dataset) for i in range(dataset_size): - done = bool(d4rl_dataset['terminals'][i]) + done = bool(d4rl_dataset["terminals"][i]) episode_step = i - start_index - final_timestep = d4rl_dataset['timeouts'][i] if use_timeouts else (episode_step == 1000 - 1) + final_timestep = d4rl_dataset["timeouts"][i] if use_timeouts else (episode_step == 1000 - 1) if done or final_timestep: end_index = i - states = (d4rl_dataset['observations'][start_index:end_index+1] - state_mean) / state_std - actions = d4rl_dataset['actions'][start_index:end_index+1] - rewards = d4rl_dataset['rewards'][start_index:end_index+1] * reward_scale - non_terminals = 1.0 - d4rl_dataset['terminals'][start_index:end_index+1] - next_states = (d4rl_dataset['next_observations'][start_index:end_index+1] - state_mean) / state_std + states = (d4rl_dataset["observations"][start_index : end_index + 1] - state_mean) / state_std + actions = d4rl_dataset["actions"][start_index : end_index + 1] + rewards = d4rl_dataset["rewards"][start_index : end_index + 1] * reward_scale + non_terminals = 1.0 - d4rl_dataset["terminals"][start_index : end_index + 1] + next_states = (d4rl_dataset["next_observations"][start_index : end_index + 1] - state_mean) / state_std start_index = end_index + 1 info = [{} for _ in range(len(states))] for timestep, d in enumerate(info): - d['rtg'] = np.sum(rewards[timestep:]) - d['timesteps'] = timestep + d["rtg"] = np.sum(rewards[timestep:]) + d["timesteps"] = timestep assert all([len(data) == len(states) for data in (actions, rewards, non_terminals, next_states, info)]) timesteps = len(info) - 1 trajectory = list(zip(states, actions, rewards, non_terminals, next_states, info)) @@ -186,22 +185,22 @@ def load_dataset(d4rl_dataset, buffer_size, context_length, reward_scale): def get_target_return(env_name): - if 'HalfCheetah' in env_name: + if "HalfCheetah" in env_name: return 6000 - if 'Hopper' in env_name: + if "Hopper" in env_name: return 3600 - if 'Walker' in env_name: + if "Walker" in env_name: return 5000 - raise NotImplementedError(f'No target_return is defined for: {env_name}') + raise NotImplementedError(f"No target_return is defined for: {env_name}") def get_reward_scale(env_name): - if 'HalfCheetah' in env_name: - return 1/1000 - if 'Hopper' in env_name: - return 1/1000 - if 'Walker' in env_name: - return 1/1000 + if "HalfCheetah" in env_name: + return 1 / 1000 + if "Hopper" in env_name: + return 1 / 1000 + if "Walker" in env_name: + return 1 / 1000 return 1.0 @@ -210,7 +209,7 @@ def get_context_length(env_name): def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) @@ -235,15 +234,17 @@ def run_training(args): save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) target_return = args.target_return if args.target_return is not None else get_target_return(args.env) - config = A.DecisionTransformerConfig(gpu_id=args.gpu, - context_length=context_length, - max_timesteps=max_timesteps, - batch_size=args.batch_size, - target_return=target_return, - grad_clip_norm=0.25, - learning_rate=1.0e-4, - weight_decay=1.0e-4, - 
reward_scale=reward_scale) + config = A.DecisionTransformerConfig( + gpu_id=args.gpu, + context_length=context_length, + max_timesteps=max_timesteps, + batch_size=args.batch_size, + target_return=target_return, + grad_clip_norm=0.25, + learning_rate=1.0e-4, + weight_decay=1.0e-4, + reward_scale=reward_scale, + ) env_info = EnvironmentInfo.from_env(eval_env) decision_transformer = A.DecisionTransformer( env_info, @@ -251,10 +252,11 @@ def run_training(args): transformer_builder=MujocoDecisionTransformerBuilder(), transformer_solver_builder=MujocoSolverBuilder(), transformer_wd_solver_builder=None, - lr_scheduler_builder=MujocoLearningRateSchedulerBuilder(args.warmup_steps)) + lr_scheduler_builder=MujocoLearningRateSchedulerBuilder(args.warmup_steps), + ) decision_transformer.set_hooks(hooks=[epoch_num_hook, save_snapshot_hook, evaluation_hook]) - print(f'total epochs: {args.total_epochs}') + print(f"total epochs: {args.total_epochs}") # decision transformer runs 1 epoch per iteration decision_transformer.train(dataset, total_iterations=args.total_epochs) @@ -263,7 +265,7 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError('Please specify the snapshot dir for showcasing') + raise ValueError("Please specify the snapshot dir for showcasing") if args.dataset_path is None: dataset = load_d4rl_dataset(args.env, args.dataset_type) else: @@ -272,44 +274,45 @@ def run_showcase(args): eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render) eval_env = StateNormalizationWrapper(eval_env, state_mean, state_std) - config = {'gpu_id': args.gpu} + config = {"gpu_id": args.gpu} decision_transformer = serializers.load_snapshot( args.snapshot_dir, eval_env, - algorithm_kwargs={"config": config, "transformer_builder": MujocoDecisionTransformerBuilder()}) + algorithm_kwargs={"config": config, "transformer_builder": MujocoDecisionTransformerBuilder()}, + ) if not isinstance(decision_transformer, A.DecisionTransformer): - raise ValueError('Loaded snapshot is not trained with DecisionTransformer!') + raise ValueError("Loaded snapshot is not trained with DecisionTransformer!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) returns = evaluator(decision_transformer, eval_env) mean = np.mean(returns) std_dev = np.std(returns) median = np.median(returns) - logger.info('Evaluation results. mean: {} +/- std: {}, median: {}'.format(mean, std_dev, median)) + logger.info("Evaluation results. 
mean: {} +/- std: {}, median: {}".format(mean, std_dev, median)) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='HalfCheetah-v3') - parser.add_argument('--dataset-path', type=str, default=None) - parser.add_argument('--dataset-type', type=str, default='medium', choices=['medium', 'expert']) - parser.add_argument('--save-dir', type=str, default="") - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--total-epochs', type=int, default=5) - parser.add_argument('--trajectories-per-buffer', type=int, default=10) - parser.add_argument('--warmup-steps', type=int, default=10000) - parser.add_argument('--buffer-size', type=int, default=500000) - parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--context-length', type=int, default=None) - parser.add_argument('--save_timing', type=int, default=1) - parser.add_argument('--eval_timing', type=int, default=1) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--target-return', type=int, default=None) - parser.add_argument('--reward-scale', type=float, default=None) + parser.add_argument("--env", type=str, default="HalfCheetah-v3") + parser.add_argument("--dataset-path", type=str, default=None) + parser.add_argument("--dataset-type", type=str, default="medium", choices=["medium", "expert"]) + parser.add_argument("--save-dir", type=str, default="") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--total-epochs", type=int, default=5) + parser.add_argument("--trajectories-per-buffer", type=int, default=10) + parser.add_argument("--warmup-steps", type=int, default=10000) + parser.add_argument("--buffer-size", type=int, default=500000) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--context-length", type=int, default=None) + parser.add_argument("--save_timing", type=int, default=1) + parser.add_argument("--eval_timing", type=int, default=1) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--target-return", type=int, default=None) + parser.add_argument("--reward-scale", type=float, default=None) args = parser.parse_args() @@ -319,5 +322,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/mujoco/gail/gail_reproduction.py b/reproductions/algorithms/mujoco/gail/gail_reproduction.py index c3c36e23..c6186a14 100644 --- a/reproductions/algorithms/mujoco/gail/gail_reproduction.py +++ b/reproductions/algorithms/mujoco/gail/gail_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -25,18 +25,19 @@ def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, - file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=50000) @@ -62,15 +63,16 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError('Please specify the snapshot dir for showcasing') + raise ValueError("Please specify the snapshot dir for showcasing") config = A.GAILConfig(gpu_id=args.gpu) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render) - gail = serializers.load_snapshot(args.snapshot_dir, - eval_env, - algorithm_kwargs={"config": config, - "expert_buffer": ReplacementSamplingReplayBuffer()}) + gail = serializers.load_snapshot( + args.snapshot_dir, + eval_env, + algorithm_kwargs={"config": config, "expert_buffer": ReplacementSamplingReplayBuffer()}, + ) if not isinstance(gail, A.GAIL): - raise ValueError('Loaded snapshot is not trained with GAIL!') + raise ValueError("Loaded snapshot is not trained with GAIL!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(gail, eval_env) @@ -78,18 +80,18 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='halfcheetah-medium-v2') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--datasetsize', type=int, default=4000) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=25000000) - parser.add_argument('--save_timing', type=int, default=50000) - parser.add_argument('--eval_timing', type=int, default=50000) - parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument("--env", type=str, default="halfcheetah-medium-v2") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--datasetsize", type=int, default=4000) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=25000000) + parser.add_argument("--save_timing", type=int, default=50000) + parser.add_argument("--eval_timing", type=int, default=50000) + parser.add_argument("--showcase_runs", type=int, default=10) args = parser.parse_args() @@ -99,5 +101,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/mujoco/her/her_reproduction.py 
b/reproductions/algorithms/mujoco/her/her_reproduction.py index caf61d14..9f606f2e 100644 --- a/reproductions/algorithms/mujoco/her/her_reproduction.py +++ b/reproductions/algorithms/mujoco/her/her_reproduction.py @@ -40,17 +40,17 @@ def select_n_cycles(env_name): - if env_name in ['FetchReach-v1']: + if env_name in ["FetchReach-v1"]: n_cycles = 10 else: n_cycles = 50 - print(f'Selected start n_cycles: {n_cycles}') + print(f"Selected start n_cycles: {n_cycles}") return n_cycles def check_success(experiences: List[Experience]) -> bool: last_info = experiences[-1][-1] - if last_info['is_success'] == 1.0: + if last_info["is_success"] == 1.0: return True else: return False @@ -98,43 +98,50 @@ def build_mujoco_goal_conditioned_env(id_or_env, test=False, seed=None, render=F def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) n_cycles = select_n_cycles(env_name=args.env) - train_env = build_mujoco_goal_conditioned_env(args.env, seed=args.seed, render=args.render, - use_gymnasium=args.use_gymnasium) - eval_env = build_mujoco_goal_conditioned_env(args.env, test=True, seed=args.seed + 100, render=args.render, - use_gymnasium=args.use_gymnasium) + train_env = build_mujoco_goal_conditioned_env( + args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium + ) + eval_env = build_mujoco_goal_conditioned_env( + args.env, test=True, seed=args.seed + 100, render=args.render, use_gymnasium=args.use_gymnasium + ) max_timesteps = train_env.spec.max_episode_steps iteration_per_epoch = n_cycles * n_update * max_timesteps evaluator = EpisodicSuccessEvaluator(check_success=check_success, run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, evaluator, - timing=iteration_per_epoch, - writer=W.FileWriter(outdir=outdir, file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=iteration_per_epoch, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) epoch_num_hook = H.EpochNumHook(iteration_per_epoch=iteration_per_epoch) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) - return_clip_min = -1. / (1. 
- gamma) + return_clip_min = -1.0 / (1.0 - gamma) return_clip_max = 0.0 return_clip = (return_clip_min, return_clip_max) - config = A.HERConfig(gpu_id=args.gpu, - gamma=gamma, - tau=tau, - exploration_noise_sigma=exploration_noise_sigma, - max_timesteps=max_timesteps, - n_cycles=n_cycles, - n_rollout=n_rollout, - n_update=n_update, - start_timesteps=start_timesteps, - batch_size=batch_size, - exploration_epsilon=exploration_epsilon, - return_clip=return_clip) + config = A.HERConfig( + gpu_id=args.gpu, + gamma=gamma, + tau=tau, + exploration_noise_sigma=exploration_noise_sigma, + max_timesteps=max_timesteps, + n_cycles=n_cycles, + n_rollout=n_rollout, + n_update=n_update, + start_timesteps=start_timesteps, + batch_size=batch_size, + exploration_epsilon=exploration_epsilon, + return_clip=return_clip, + ) her = A.HER(train_env, config=config) hooks = [epoch_num_hook, save_snapshot_hook, evaluation_hook] @@ -148,31 +155,31 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError('Please specify the snapshot dir for showcasing') - eval_env = build_mujoco_goal_conditioned_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_mujoco_goal_conditioned_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) config = A.HERConfig(gpu_id=args.gpu) her = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(her, A.HER): - raise ValueError('Loaded snapshot is not trained with HER!') - evaluator = EpisodicSuccessEvaluator(check_success=check_success, - run_per_evaluation=args.showcase_runs) + raise ValueError("Loaded snapshot is not trained with HER!") + evaluator = EpisodicSuccessEvaluator(check_success=check_success, run_per_evaluation=args.showcase_runs) evaluator(her, eval_env) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='FetchPush-v1') - parser.add_argument('--gpu', type=int, default=-1) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save_timing', type=int, default=100000) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=20000000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="FetchPush-v1") + parser.add_argument("--gpu", type=int, default=-1) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save_timing", type=int, default=100000) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=20000000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -182,5 +189,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git 
a/reproductions/algorithms/mujoco/icml2015trpo/icml2015trpo_reproduction.py b/reproductions/algorithms/mujoco/icml2015trpo/icml2015trpo_reproduction.py index 1aec7e3b..a9420e2d 100644 --- a/reproductions/algorithms/mujoco/icml2015trpo/icml2015trpo_reproduction.py +++ b/reproductions/algorithms/mujoco/icml2015trpo/icml2015trpo_reproduction.py @@ -24,27 +24,27 @@ def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, - file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=1000) train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium) - config = A.ICML2015TRPOConfig(gpu_id=args.gpu, - num_steps_per_iteration=1000000, - batch_size=1000000, - gpu_batch_size=100000) + config = A.ICML2015TRPOConfig( + gpu_id=args.gpu, num_steps_per_iteration=1000000, batch_size=1000000, gpu_batch_size=100000 + ) trpo = A.ICML2015TRPO(train_env, config=config) hooks = [iteration_num_hook, save_snapshot_hook, evaluation_hook] @@ -58,14 +58,14 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') - eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_mujoco_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) config = A.ICML2015TRPOConfig(gpu_id=args.gpu) trpo = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(trpo, A.ICML2015TRPO): - raise ValueError('Loaded snapshot is not trained with ICML2015TRPO!') + raise ValueError("Loaded snapshot is not trained with ICML2015TRPO!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(trpo, eval_env) @@ -73,18 +73,18 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='Hopper-v2') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=200000000) - parser.add_argument('--save_timing', type=int, default=1000000) - parser.add_argument('--eval_timing', type=int, default=1000000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="Hopper-v2") + parser.add_argument("--gpu", type=int, default=0) 
+ parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=200000000) + parser.add_argument("--save_timing", type=int, default=1000000) + parser.add_argument("--eval_timing", type=int, default=1000000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -94,5 +94,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/mujoco/icml2018sac/icml2018sac_reproduction.py b/reproductions/algorithms/mujoco/icml2018sac/icml2018sac_reproduction.py index aee0350f..7a89b618 100644 --- a/reproductions/algorithms/mujoco/icml2018sac/icml2018sac_reproduction.py +++ b/reproductions/algorithms/mujoco/icml2018sac/icml2018sac_reproduction.py @@ -24,47 +24,48 @@ def select_start_timesteps(env_name): - if env_name in ['Ant-v2', 'HalfCheetah-v2']: + if env_name in ["Ant-v2", "HalfCheetah-v2"]: timesteps = 10000 else: timesteps = 1000 - print(f'Selected start timesteps: {timesteps}') + print(f"Selected start timesteps: {timesteps}") return timesteps def select_total_iterations(env_name): - if env_name in ['Hopper-v2', 'Walker2d-v2']: + if env_name in ["Hopper-v2", "Walker2d-v2"]: total_iterations = 1000000 - elif env_name in ['Humanoid-v2']: + elif env_name in ["Humanoid-v2"]: total_iterations = 10000000 else: total_iterations = 3000000 - print(f'Selected total iterations: {total_iterations}') + print(f"Selected total iterations: {total_iterations}") return total_iterations def select_reward_scalar(env_name): - if env_name in ['Humanoid-v2']: + if env_name in ["Humanoid-v2"]: scalar = 20.0 else: scalar = 5.0 - print(f'Selected reward scalar: {scalar}') + print(f"Selected reward scalar: {scalar}") return scalar def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, - file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=100) @@ -72,9 +73,7 @@ def run_training(args): train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium) timesteps = select_start_timesteps(args.env) reward_scalar = select_reward_scalar(args.env) - config = A.ICML2018SACConfig(gpu_id=args.gpu, - start_timesteps=timesteps, - reward_scalar=reward_scalar) + config = A.ICML2018SACConfig(gpu_id=args.gpu, start_timesteps=timesteps, reward_scalar=reward_scalar) icml2018sac = A.ICML2018SAC(train_env, config=config) hooks = [iteration_num_hook, save_snapshot_hook, evaluation_hook] @@ -89,14 +88,14 @@ def run_training(args): def 
run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') - eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_mujoco_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) config = A.ICML2018SACConfig(gpu_id=args.gpu) icml2018sac = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(icml2018sac, A.ICML2018SAC): - raise ValueError('Loaded snapshot is not trained with ICML2018SAC!') + raise ValueError("Loaded snapshot is not trained with ICML2018SAC!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(icml2018sac, eval_env) @@ -104,18 +103,18 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='Ant-v2') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=None) - parser.add_argument('--save_timing', type=int, default=5000) - parser.add_argument('--eval_timing', type=int, default=5000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="Ant-v2") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=None) + parser.add_argument("--save_timing", type=int, default=5000) + parser.add_argument("--eval_timing", type=int, default=5000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -125,5 +124,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/mujoco/ppo/ppo_reproduction.py b/reproductions/algorithms/mujoco/ppo/ppo_reproduction.py index 6637491e..e86d6f0b 100644 --- a/reproductions/algorithms/mujoco/ppo/ppo_reproduction.py +++ b/reproductions/algorithms/mujoco/ppo/ppo_reproduction.py @@ -24,7 +24,7 @@ def select_timelimit_as_terminal(env_name): - if 'Swimmer' in env_name: + if "Swimmer" in env_name: timelimit_as_terminal = True else: timelimit_as_terminal = False @@ -32,35 +32,38 @@ def select_timelimit_as_terminal(env_name): def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - 
timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, - file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) iteration_num_hook = H.IterationNumHook(timing=1000) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium) timelimit_as_terminal = select_timelimit_as_terminal(args.env) - config = A.PPOConfig(gpu_id=args.gpu, - epsilon=0.2, - entropy_coefficient=0.0, - actor_timesteps=2048, - epochs=10, - batch_size=64, - learning_rate=3.0*1e-4, - actor_num=1, - decrease_alpha=False, - timelimit_as_terminal=timelimit_as_terminal, - seed=args.seed) + config = A.PPOConfig( + gpu_id=args.gpu, + epsilon=0.2, + entropy_coefficient=0.0, + actor_timesteps=2048, + epochs=10, + batch_size=64, + learning_rate=3.0 * 1e-4, + actor_num=1, + decrease_alpha=False, + timelimit_as_terminal=timelimit_as_terminal, + seed=args.seed, + ) ppo = A.PPO(train_env, config=config) hooks = [iteration_num_hook, save_snapshot_hook, evaluation_hook] @@ -74,14 +77,14 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') - eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_mujoco_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) config = A.PPOConfig(gpu_id=args.gpu) ppo = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(ppo, A.PPO): - raise ValueError('Loaded snapshot is not trained with PPO!') + raise ValueError("Loaded snapshot is not trained with PPO!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(ppo, eval_env) @@ -89,18 +92,18 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='Ant-v2') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=1000000) - parser.add_argument('--save_timing', type=int, default=5000) - parser.add_argument('--eval_timing', type=int, default=5000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="Ant-v2") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=1000000) + parser.add_argument("--save_timing", type=int, default=5000) + parser.add_argument("--eval_timing", type=int, default=5000) + parser.add_argument("--showcase_runs", type=int, 
default=10) + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -110,5 +113,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/mujoco/qrsac/qrsac_reproduction.py b/reproductions/algorithms/mujoco/qrsac/qrsac_reproduction.py index b2b8ba86..82b3e545 100644 --- a/reproductions/algorithms/mujoco/qrsac/qrsac_reproduction.py +++ b/reproductions/algorithms/mujoco/qrsac/qrsac_reproduction.py @@ -24,37 +24,41 @@ def select_total_iterations(env_name): - if env_name in ['Ant-v2', 'HalfCheetah-v2', 'Walker2d-v2']: + if env_name in ["Ant-v2", "HalfCheetah-v2", "Walker2d-v2"]: total_iterations = 3000000 - elif env_name in ['Humanoid-v2']: + elif env_name in ["Humanoid-v2"]: total_iterations = 10000000 else: total_iterations = 1000000 - print(f'Selected total iterations: {total_iterations}') + print(f"Selected total iterations: {total_iterations}") return total_iterations def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) iteration_num_hook = H.IterationNumHook(timing=100) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render, use_gymnasium=args.use_gymnasium) - config = A.QRSACConfig(gpu_id=args.gpu, - fix_temperature=args.fix_temperature, - initial_temperature=args.initial_temperature, - num_steps=args.num_steps) + config = A.QRSACConfig( + gpu_id=args.gpu, + fix_temperature=args.fix_temperature, + initial_temperature=args.initial_temperature, + num_steps=args.num_steps, + ) sac = A.QRSAC(train_env, config=config) hooks = [iteration_num_hook, save_snapshot_hook, evaluation_hook] @@ -69,14 +73,14 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') - eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_mujoco_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) config = A.QRSACConfig(gpu_id=args.gpu) sac = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(sac, A.QRSAC): - raise ValueError('Loaded snapshot is not trained with QRSAC!') + raise ValueError("Loaded snapshot is not trained with QRSAC!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(sac, eval_env) @@ -84,23 +88,23 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='Ant-v2') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - 
parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=None) - parser.add_argument('--save_timing', type=int, default=5000) - parser.add_argument('--eval_timing', type=int, default=5000) - parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument("--env", type=str, default="Ant-v2") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=None) + parser.add_argument("--save_timing", type=int, default=5000) + parser.add_argument("--eval_timing", type=int, default=5000) + parser.add_argument("--showcase_runs", type=int, default=10) # QRSAC algorithm config - parser.add_argument('--fix-temperature', action='store_true') - parser.add_argument('--initial-temperature', type=float, default=None) - parser.add_argument('--num-steps', type=int, default=1, help='number of steps for n-step Q target') - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--fix-temperature", action="store_true") + parser.add_argument("--initial-temperature", type=float, default=None) + parser.add_argument("--num-steps", type=int, default=1, help="number of steps for n-step Q target") + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -110,5 +114,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/mujoco/redq/redq_reproduction.py b/reproductions/algorithms/mujoco/redq/redq_reproduction.py index f0c7e78d..70b5ce0f 100644 --- a/reproductions/algorithms/mujoco/redq/redq_reproduction.py +++ b/reproductions/algorithms/mujoco/redq/redq_reproduction.py @@ -24,26 +24,28 @@ def select_total_iterations(env_name): - if env_name in ['Hopper-v2']: + if env_name in ["Hopper-v2"]: total_iterations = 120000 else: total_iterations = 300000 - print(f'Selected total iterations: {total_iterations}') + print(f"Selected total iterations: {total_iterations}") return total_iterations def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) iteration_num_hook = H.IterationNumHook(timing=100) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) @@ -64,14 +66,14 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') - eval_env = 
build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_mujoco_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) config = A.REDQConfig(gpu_id=args.gpu) redq = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(redq, A.REDQ): - raise ValueError('Loaded snapshot is not trained with REDQ!') + raise ValueError("Loaded snapshot is not trained with REDQ!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(redq, eval_env) @@ -79,21 +81,21 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='Ant-v2') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=None) - parser.add_argument('--save_timing', type=int, default=5000) - parser.add_argument('--eval_timing', type=int, default=5000) - parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument("--env", type=str, default="Ant-v2") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=None) + parser.add_argument("--save_timing", type=int, default=5000) + parser.add_argument("--eval_timing", type=int, default=5000) + parser.add_argument("--showcase_runs", type=int, default=10) # REDQ algorithm config - parser.add_argument('--fix-temperature', action='store_true') - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--fix-temperature", action="store_true") + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -103,5 +105,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/mujoco/sac/sac_reproduction.py b/reproductions/algorithms/mujoco/sac/sac_reproduction.py index f5751ddb..693e957b 100644 --- a/reproductions/algorithms/mujoco/sac/sac_reproduction.py +++ b/reproductions/algorithms/mujoco/sac/sac_reproduction.py @@ -24,28 +24,30 @@ def select_total_iterations(env_name): - if env_name in ['Ant-v2', 'HalfCheetah-v2', 'Walker2d-v2']: + if env_name in ["Ant-v2", "HalfCheetah-v2", "Walker2d-v2"]: total_iterations = 3000000 - elif env_name in ['Humanoid-v2']: + elif env_name in ["Humanoid-v2"]: total_iterations = 10000000 else: total_iterations = 1000000 - print(f'Selected total iterations: {total_iterations}') + print(f"Selected total iterations: {total_iterations}") return total_iterations def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = 
build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) iteration_num_hook = H.IterationNumHook(timing=100) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) @@ -66,14 +68,14 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') - eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_mujoco_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) config = A.SACConfig(gpu_id=args.gpu) sac = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(sac, A.SAC): - raise ValueError('Loaded snapshot is not trained with SAC!') + raise ValueError("Loaded snapshot is not trained with SAC!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(sac, eval_env) @@ -81,21 +83,21 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='Ant-v2') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=None) - parser.add_argument('--save_timing', type=int, default=5000) - parser.add_argument('--eval_timing', type=int, default=5000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="Ant-v2") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=None) + parser.add_argument("--save_timing", type=int, default=5000) + parser.add_argument("--eval_timing", type=int, default=5000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--use-gymnasium", action="store_true") # SAC algorithm config - parser.add_argument('--fix-temperature', action='store_true') + parser.add_argument("--fix-temperature", action="store_true") args = parser.parse_args() @@ -105,5 +107,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/mujoco/td3/td3_reproduction.py b/reproductions/algorithms/mujoco/td3/td3_reproduction.py index b5d44272..9d44c387 100644 --- a/reproductions/algorithms/mujoco/td3/td3_reproduction.py +++ 
b/reproductions/algorithms/mujoco/td3/td3_reproduction.py @@ -24,26 +24,28 @@ def select_start_timesteps(env_name): - if env_name in ['Ant-v2', 'HalfCheetah-v2']: + if env_name in ["Ant-v2", "HalfCheetah-v2"]: timesteps = 10000 else: timesteps = 1000 - print(f'Selected start timesteps: {timesteps}') + print(f"Selected start timesteps: {timesteps}") return timesteps def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) iteration_num_hook = H.IterationNumHook(timing=100) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) @@ -63,14 +65,14 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') - eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_mujoco_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) config = A.TD3Config(gpu_id=args.gpu) td3 = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(td3, A.TD3): - raise ValueError('Loaded snapshot is not trained with TD3!') + raise ValueError("Loaded snapshot is not trained with TD3!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(td3, eval_env) @@ -78,18 +80,18 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='Ant-v2') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=1000000) - parser.add_argument('--save_timing', type=int, default=5000) - parser.add_argument('--eval_timing', type=int, default=5000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="Ant-v2") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=1000000) + parser.add_argument("--save_timing", type=int, default=5000) + parser.add_argument("--eval_timing", type=int, default=5000) + parser.add_argument("--showcase_runs", type=int, 
default=10) + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -99,5 +101,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/mujoco/trpo/trpo_reproduction.py b/reproductions/algorithms/mujoco/trpo/trpo_reproduction.py index c698821f..c01f77cb 100644 --- a/reproductions/algorithms/mujoco/trpo/trpo_reproduction.py +++ b/reproductions/algorithms/mujoco/trpo/trpo_reproduction.py @@ -27,18 +27,19 @@ def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100, use_gymnasium=args.use_gymnasium) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, - file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=5000) @@ -58,37 +59,37 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') - eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render, - use_gymnasium=args.use_gymnasium) + raise ValueError("Please specify the snapshot dir for showcasing") + eval_env = build_mujoco_env( + args.env, test=True, seed=args.seed + 200, render=args.render, use_gymnasium=args.use_gymnasium + ) config = A.TRPOConfig(gpu_id=args.gpu) trpo = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(trpo, A.TRPO): - raise ValueError('Loaded snapshot is not trained with TRPO!') + raise ValueError("Loaded snapshot is not trained with TRPO!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) returns = evaluator(trpo, eval_env) mean = np.mean(returns) std_dev = np.std(returns) median = np.median(returns) - logger.info('Evaluation results. mean: {} +/- std: {}, median: {}'.format(mean, std_dev, median)) + logger.info("Evaluation results. 
mean: {} +/- std: {}, median: {}".format(mean, std_dev, median)) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='Hopper-v2') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=1000000) - parser.add_argument('--save_timing', type=int, default=10000) - parser.add_argument('--eval_timing', type=int, default=10000) - parser.add_argument('--showcase_runs', type=int, default=10) - parser.add_argument('--use-gymnasium', action='store_true') + parser.add_argument("--env", type=str, default="Hopper-v2") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=1000000) + parser.add_argument("--save_timing", type=int, default=10000) + parser.add_argument("--eval_timing", type=int, default=10000) + parser.add_argument("--showcase_runs", type=int, default=10) + parser.add_argument("--use-gymnasium", action="store_true") args = parser.parse_args() @@ -98,5 +99,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/mujoco/xql/compute_expert_normalized_score.py b/reproductions/algorithms/mujoco/xql/compute_expert_normalized_score.py index f0d8153a..cd535819 100755 --- a/reproductions/algorithms/mujoco/xql/compute_expert_normalized_score.py +++ b/reproductions/algorithms/mujoco/xql/compute_expert_normalized_score.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -26,7 +26,7 @@ def load_histogram_data(path, dtype=float): histogram = [] with open(path) as f: - tsv_reader = reader(f, delimiter='\t') + tsv_reader = reader(f, delimiter="\t") for i, row in enumerate(tsv_reader): if i == 0: continue @@ -55,8 +55,8 @@ def extract_iteration_num_and_returns(histogram_data): returns = [] for i in range(len(histogram_data)): data_row = histogram_data[i] - if 'returns' in data_row[0]: - iteration_nums.append(int(data_row[0].split(' ')[0])) + if "returns" in data_row[0]: + iteration_nums.append(int(data_row[0].split(" ")[0])) scores = data_row[1][0:].astype(float) returns.append(scores) @@ -83,8 +83,8 @@ def create_expert_normalized_score_file(histograms, file_outdir, d4rl_env_name): std_dev = np.std(normalized_r) * 100 scalar_results = {} - scalar_results['mean'] = mean - scalar_results['std_dev'] = std_dev + scalar_results["mean"] = mean + scalar_results["std_dev"] = std_dev writer.write_scalar(i, scalar_results) @@ -94,12 +94,12 @@ def compile_results(args): histograms = {} histogram_directories = list_all_directory_with(rootdir, args.eval_histogram_filename) - print(f'files: {histogram_directories}') + print(f"files: {histogram_directories}") for directory in histogram_directories: if args.resultdir not in str(directory): continue relative_dir = directory.relative_to(rootdir) - env_name = str(relative_dir).split('/')[1] + env_name = str(relative_dir).split("/")[1] histogram_file = directory / args.eval_histogram_filename print(f"found histogram file of env: {env_name} at: {histogram_file}") if histogram_file.exists(): @@ -110,19 +110,21 @@ def compile_results(args): for env_name, histograms in histograms.items(): file_outdir = pathlib.Path(args.outdir) / pathlib.Path(env_name) - d4rl_env_name = env_name.replace('_results', '') + d4rl_env_name = env_name.replace("_results", "") create_expert_normalized_score_file(histograms, file_outdir, d4rl_env_name) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--outdir', type=str, required=True, help='output directory') - parser.add_argument('--resultdir', type=str, required=True, help='result directory') - parser.add_argument('--eval-histogram-filename', - type=str, - default="evaluation_result_histogram.tsv", - help='eval result(histogram) filename') + parser.add_argument("--outdir", type=str, required=True, help="output directory") + parser.add_argument("--resultdir", type=str, required=True, help="result directory") + parser.add_argument( + "--eval-histogram-filename", + type=str, + default="evaluation_result_histogram.tsv", + help="eval result(histogram) filename", + ) args = parser.parse_args() diff --git a/reproductions/algorithms/mujoco/xql/xql_reproduction.py b/reproductions/algorithms/mujoco/xql/xql_reproduction.py index 60a783f8..82091d6e 100755 --- a/reproductions/algorithms/mujoco/xql/xql_reproduction.py +++ b/reproductions/algorithms/mujoco/xql/xql_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
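As a reading aid for compute_expert_normalized_score.py above: the histogram file is tab-separated, the first row is a header, and each data row of interest starts with a cell like "10000 returns" whose leading token is the iteration number. The sketch below assumes the per-episode returns occupy the remaining columns of such a row; the script's own load_histogram_data helper may arrange the rows differently.

```python
from csv import reader

import numpy as np


def load_returns(path):
    """Collect (iteration, returns) pairs from an evaluation_result_histogram.tsv-style file."""
    iterations, all_returns = [], []
    with open(path) as f:
        for i, row in enumerate(reader(f, delimiter="\t")):
            if i == 0:
                continue  # skip the header row
            if "returns" in row[0]:
                iterations.append(int(row[0].split(" ")[0]))
                all_returns.append(np.asarray(row[1:], dtype=float))
    return iterations, all_returns
```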
@@ -53,17 +53,17 @@ def build_solver(self, env_info, algorithm_config, **kwargs): def clip_actions_in_dataset(dataset): eps = 1e-5 lim = 1.0 - eps - dataset['actions'] = np.clip(dataset['actions'], -lim, lim) + dataset["actions"] = np.clip(dataset["actions"], -lim, lim) return dataset def normalize_dataset_score(dataset): def dataset_to_trajectories(dataset): - states = dataset['observations'] - actions = dataset['actions'] - rewards = dataset['rewards'] - terminals = dataset['terminals'] - timeouts = dataset['timeouts'] + states = dataset["observations"] + actions = dataset["actions"] + rewards = dataset["rewards"] + terminals = dataset["terminals"] + timeouts = dataset["timeouts"] trajectories = [] trajectory = [] @@ -89,11 +89,11 @@ def max_min_returns(trajectories): trajectories = dataset_to_trajectories(dataset) max_return, min_return = max_min_returns(trajectories) - print(f'len trajectories: {len(trajectories)}') - print(f'max return: {max_return}, min return: {min_return}') + print(f"len trajectories: {len(trajectories)}") + print(f"max return: {max_return}, min return: {min_return}") - dataset['rewards'] /= (max_return - min_return) - dataset['rewards'] *= 1000.0 + dataset["rewards"] /= max_return - min_return + dataset["rewards"] *= 1000.0 return dataset @@ -107,7 +107,7 @@ def build_env_and_dataset(env_name, seed=None): def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) @@ -126,10 +126,7 @@ def run_training(args): # Author's code sets different temperature for value/policy training. # If we set policy_temperature=value_temperature, the performace have slightly decreased. - config = A.XQLConfig(gpu_id=args.gpu, - batch_size=args.batch_size, - value_temperature=2.0, - policy_temperature=1/3.0) + config = A.XQLConfig(gpu_id=args.gpu, batch_size=args.batch_size, value_temperature=2.0, policy_temperature=1 / 3.0) env_info = EnvironmentInfo.from_env(eval_env) xql = A.XQL(env_info, config=config, policy_solver_builder=CosineDecayPolicySolverBuilder(args.total_iterations)) xql.set_hooks(hooks=[iteration_num_hook, save_snapshot_hook, evaluation_hook]) @@ -141,12 +138,12 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError('Please specify the snapshot dir for showcasing') + raise ValueError("Please specify the snapshot dir for showcasing") eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render) config = A.XQLConfig(gpu_id=args.gpu) xql = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(xql, A.XQL): - raise ValueError('Loaded snapshot is not trained with XQL!') + raise ValueError("Loaded snapshot is not trained with XQL!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(xql, eval_env) @@ -154,18 +151,18 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='halfcheetah-expert-v2') - parser.add_argument('--save-dir', type=str, default="") - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=1000000) - 
parser.add_argument('--batch-size', type=int, default=256) - parser.add_argument('--save_timing', type=int, default=5000) - parser.add_argument('--eval_timing', type=int, default=5000) - parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument("--env", type=str, default="halfcheetah-expert-v2") + parser.add_argument("--save-dir", type=str, default="") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=1000000) + parser.add_argument("--batch-size", type=int, default=256) + parser.add_argument("--save_timing", type=int, default=5000) + parser.add_argument("--eval_timing", type=int, default=5000) + parser.add_argument("--showcase_runs", type=int, default=10) args = parser.parse_args() @@ -175,5 +172,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/pybullet/amp/amp_reproduction.py b/reproductions/algorithms/pybullet/amp/amp_reproduction.py index f54dac22..1609804e 100644 --- a/reproductions/algorithms/pybullet/amp/amp_reproduction.py +++ b/reproductions/algorithms/pybullet/amp/amp_reproduction.py @@ -23,8 +23,10 @@ from deepmimic_utils.deepmimic_env import DeepMimicEnv, DeepMimicGoalEnv, DeepMimicWindowViewer from deepmimic_utils.deepmimic_evaluator import DeepMimicEpisodicEvaluator from deepmimic_utils.deepmimic_explorer import DeepMimicExplorer -from deepmimic_utils.deepmimic_normalizer import (DeepMimicGoalTupleRunningMeanNormalizer, - DeepMimicTupleRunningMeanNormalizer) +from deepmimic_utils.deepmimic_normalizer import ( + DeepMimicGoalTupleRunningMeanNormalizer, + DeepMimicTupleRunningMeanNormalizer, +) import nnabla_rl.algorithms as A import nnabla_rl.environment_explorers as EE @@ -41,8 +43,13 @@ class DeepMimicTupleStatePreprocessorBuilder(PreprocessorBuilder): - def build_preprocessor(self, scope_name: str, env_info: EnvironmentInfo, # type: ignore[override] - algorithm_config: A.AMPConfig, **kwargs) -> Preprocessor: + def build_preprocessor( # type: ignore[override] + self, + scope_name: str, + env_info: EnvironmentInfo, + algorithm_config: A.AMPConfig, + **kwargs, + ) -> Preprocessor: assert algorithm_config.state_mean_initializer is not None assert algorithm_config.state_var_initializer is not None @@ -59,7 +66,8 @@ def build_preprocessor(self, scope_name: str, env_info: EnvironmentInfo, # type goal_state_mean_initializer=np.array(algorithm_config.state_mean_initializer[3], dtype=np.float32), goal_state_var_initializer=np.array(algorithm_config.state_var_initializer[3], dtype=np.float32), epsilon=0.02, - mode_for_floating_point_error="max") + mode_for_floating_point_error="max", + ) else: return DeepMimicTupleRunningMeanNormalizer( scope_name, @@ -70,54 +78,69 @@ def build_preprocessor(self, scope_name: str, env_info: EnvironmentInfo, # type reward_state_mean_initializer=np.array(algorithm_config.state_mean_initializer[1], dtype=np.float32), reward_state_var_initializer=np.array(algorithm_config.state_var_initializer[1], dtype=np.float32), epsilon=0.02, - mode_for_floating_point_error="max") + mode_for_floating_point_error="max", + ) class DeepMimicExplorerBuilder(ExplorerBuilder): - def build_explorer(self, env_info: EnvironmentInfo, algorithm_config: A.AMPConfig, # type: 
ignore[override] - algorithm: A.AMP, **kwargs) -> EnvironmentExplorer: + def build_explorer( # type: ignore[override] + self, + env_info: EnvironmentInfo, + algorithm_config: A.AMPConfig, + algorithm: A.AMP, + **kwargs, + ) -> EnvironmentExplorer: explorer_config = EE.LinearDecayEpsilonGreedyExplorerConfig( initial_step_num=0, timelimit_as_terminal=algorithm_config.timelimit_as_terminal, initial_epsilon=1.0, final_epsilon=algorithm_config.final_explore_rate, max_explore_steps=algorithm_config.max_explore_steps, - append_explorer_info=True) - explorer = DeepMimicExplorer(greedy_action_selector=kwargs["greedy_action_selector"], - random_action_selector=kwargs["random_action_selector"], - env_info=env_info, - config=explorer_config) + append_explorer_info=True, + ) + explorer = DeepMimicExplorer( + greedy_action_selector=kwargs["greedy_action_selector"], + random_action_selector=kwargs["random_action_selector"], + env_info=env_info, + config=explorer_config, + ) return explorer class DeepMimicReplayBufferBuilder(ReplayBufferBuilder): - def build_replay_buffer( - self, env_info: EnvironmentInfo, algorithm_config: A.AMPConfig, **kwargs # type: ignore[override] + def build_replay_buffer( # type: ignore[override] + self, env_info: EnvironmentInfo, algorithm_config: A.AMPConfig, **kwargs ) -> RandomRemovalReplayBuffer: return RandomRemovalReplayBuffer( capacity=int(np.ceil(algorithm_config.discriminator_agent_replay_buffer_size / algorithm_config.actor_num)) ) -def build_deepmimic_env(args_file_path: str, - goal_conditioned_env: bool, - seed: int, - eval_mode: bool, - print_env: bool, - num_processes: int, - render_env: bool) -> gym.Env: +def build_deepmimic_env( + args_file_path: str, + goal_conditioned_env: bool, + seed: int, + eval_mode: bool, + print_env: bool, + num_processes: int, + render_env: bool, +) -> gym.Env: env: gym.Env if args_file_path == "FakeAMPNNablaRL-v1": # NOTE: FakeAMPNNablaRL-v1 is for the algorithm test. 
env = gym.make(args_file_path) else: if goal_conditioned_env: - env = GoalConditionedTupleObservationEnv(DeepMimicGoalEnv( - args_file_path, eval_mode, num_processes=num_processes, step_until_action_needed=not render_env)) + env = GoalConditionedTupleObservationEnv( + DeepMimicGoalEnv( + args_file_path, eval_mode, num_processes=num_processes, step_until_action_needed=not render_env + ) + ) env = FlattenNestedTupleStateWrapper(env) else: - env = DeepMimicEnv(args_file_path, eval_mode, num_processes=num_processes, - step_until_action_needed=not render_env) + env = DeepMimicEnv( + args_file_path, eval_mode, num_processes=num_processes, step_until_action_needed=not render_env + ) # dummy reset for generating core env.reset() @@ -139,80 +162,95 @@ def build_config(args, train_env: Union[DeepMimicEnv, DeepMimicGoalEnv]): achieved_goal_mean = tuple([mean.tolist() for mean in train_env.unwrapped.observation_mean["achieved_goal"]]) achieved_goal_var = tuple([var.tolist() for var in train_env.unwrapped.observation_var["achieved_goal"]]) - config = A.AMPConfig(gpu_id=args.gpu, - seed=args.seed, - normalize_action=True, - preprocess_state=True, - use_reward_from_env=True, - gamma=0.99, - action_mean=tuple(train_env.unwrapped.action_mean.tolist()), - action_var=tuple(train_env.unwrapped.action_var.tolist()), - state_mean_initializer=tuple([*observation_mean, *desired_goal_mean, *achieved_goal_mean]), - state_var_initializer=tuple([*observation_var, *desired_goal_var, *achieved_goal_var]), - value_at_task_fail=train_env.unwrapped.reward_at_task_fail / (1.0 - 0.99), - value_at_task_success=train_env.unwrapped.reward_at_task_success / (1.0 - 0.99), - target_value_clip=(train_env.unwrapped.reward_range[0] / (1.0 - 0.99), - train_env.unwrapped.reward_range[1] / (1.0 - 0.99)), - v_function_learning_rate=2e-05, - policy_learning_rate=4e-06, - actor_num=args.actor_num, - actor_timesteps=4096 // args.actor_num, - max_explore_steps=200000000 // args.actor_num) + config = A.AMPConfig( + gpu_id=args.gpu, + seed=args.seed, + normalize_action=True, + preprocess_state=True, + use_reward_from_env=True, + gamma=0.99, + action_mean=tuple(train_env.unwrapped.action_mean.tolist()), + action_var=tuple(train_env.unwrapped.action_var.tolist()), + state_mean_initializer=tuple([*observation_mean, *desired_goal_mean, *achieved_goal_mean]), + state_var_initializer=tuple([*observation_var, *desired_goal_var, *achieved_goal_var]), + value_at_task_fail=train_env.unwrapped.reward_at_task_fail / (1.0 - 0.99), + value_at_task_success=train_env.unwrapped.reward_at_task_success / (1.0 - 0.99), + target_value_clip=( + train_env.unwrapped.reward_range[0] / (1.0 - 0.99), + train_env.unwrapped.reward_range[1] / (1.0 - 0.99), + ), + v_function_learning_rate=2e-05, + policy_learning_rate=4e-06, + actor_num=args.actor_num, + actor_timesteps=4096 // args.actor_num, + max_explore_steps=200000000 // args.actor_num, + ) else: - config = A.AMPConfig(gpu_id=args.gpu, - seed=args.seed, - normalize_action=True, - preprocess_state=True, - gamma=0.95, - action_mean=tuple(train_env.unwrapped.action_mean.tolist()), - action_var=tuple(train_env.unwrapped.action_var.tolist()), - state_mean_initializer=tuple([mean.tolist() - for mean in train_env.unwrapped.observation_mean]), - state_var_initializer=tuple([var.tolist() for var in train_env.unwrapped.observation_var]), - value_at_task_fail=train_env.unwrapped.reward_at_task_fail / (1.0 - 0.95), - value_at_task_success=train_env.unwrapped.reward_at_task_success / (1.0 - 0.95), - 
target_value_clip=(train_env.unwrapped.reward_range[0] / (1.0 - 0.95), - train_env.unwrapped.reward_range[1] / (1.0 - 0.95)), - actor_num=args.actor_num, - actor_timesteps=4096 // args.actor_num, - max_explore_steps=200000000 // args.actor_num) + config = A.AMPConfig( + gpu_id=args.gpu, + seed=args.seed, + normalize_action=True, + preprocess_state=True, + gamma=0.95, + action_mean=tuple(train_env.unwrapped.action_mean.tolist()), + action_var=tuple(train_env.unwrapped.action_var.tolist()), + state_mean_initializer=tuple([mean.tolist() for mean in train_env.unwrapped.observation_mean]), + state_var_initializer=tuple([var.tolist() for var in train_env.unwrapped.observation_var]), + value_at_task_fail=train_env.unwrapped.reward_at_task_fail / (1.0 - 0.95), + value_at_task_success=train_env.unwrapped.reward_at_task_success / (1.0 - 0.95), + target_value_clip=( + train_env.unwrapped.reward_range[0] / (1.0 - 0.95), + train_env.unwrapped.reward_range[1] / (1.0 - 0.95), + ), + actor_num=args.actor_num, + actor_timesteps=4096 // args.actor_num, + max_explore_steps=200000000 // args.actor_num, + ) return config def run_training(args): - env_name = str(pathlib.Path(args.args_file_path).name).replace('_args.txt', '').replace('train_amp_', '') + env_name = str(pathlib.Path(args.args_file_path).name).replace("_args.txt", "").replace("train_amp_", "") outdir = f"{env_name}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) - train_env = build_deepmimic_env(args.args_file_path, - goal_conditioned_env=args.goal_conditioned_env, - seed=args.seed, - eval_mode=False, - print_env=True, - num_processes=args.actor_num, - render_env=False) + train_env = build_deepmimic_env( + args.args_file_path, + goal_conditioned_env=args.goal_conditioned_env, + seed=args.seed, + eval_mode=False, + print_env=True, + num_processes=args.actor_num, + render_env=False, + ) config = build_config(args, train_env) - amp = A.AMP(train_env, - config=config, - env_explorer_builder=DeepMimicExplorerBuilder(), - state_preprocessor_builder=DeepMimicTupleStatePreprocessorBuilder(), - discriminator_replay_buffer_builder=DeepMimicReplayBufferBuilder()) - - eval_env = build_deepmimic_env(args.args_file_path, - goal_conditioned_env=args.goal_conditioned_env, - seed=args.seed + 100, - eval_mode=True, - print_env=False, - num_processes=1, - render_env=False) + amp = A.AMP( + train_env, + config=config, + env_explorer_builder=DeepMimicExplorerBuilder(), + state_preprocessor_builder=DeepMimicTupleStatePreprocessorBuilder(), + discriminator_replay_buffer_builder=DeepMimicReplayBufferBuilder(), + ) + + eval_env = build_deepmimic_env( + args.args_file_path, + goal_conditioned_env=args.goal_conditioned_env, + seed=args.seed + 100, + eval_mode=True, + print_env=False, + num_processes=1, + render_env=False, + ) evaluator = DeepMimicEpisodicEvaluator(run_per_evaluation=32) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result")) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) iteration_state_hook = H.IterationStateHook( writer=W.FileWriter(outdir=outdir, file_prefix="iteration_state"), timing=args.iteration_state_timing ) @@ -231,19 +269,21 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: raise ValueError("Please specify the 
snapshot dir for showcasing") - eval_env = build_deepmimic_env(args.args_file_path, - goal_conditioned_env=args.goal_conditioned_env, - seed=args.seed + 200, - eval_mode=True, - print_env=True, - num_processes=1, - render_env=args.render_in_showcase) + eval_env = build_deepmimic_env( + args.args_file_path, + goal_conditioned_env=args.goal_conditioned_env, + seed=args.seed + 200, + eval_mode=True, + print_env=True, + num_processes=1, + render_env=args.render_in_showcase, + ) config = build_config(args, eval_env) amp = serializers.load_snapshot( args.snapshot_dir, eval_env, - algorithm_kwargs={"config": config, - "state_preprocessor_builder": DeepMimicTupleStatePreprocessorBuilder()}) + algorithm_kwargs={"config": config, "state_preprocessor_builder": DeepMimicTupleStatePreprocessorBuilder()}, + ) if not isinstance(amp, A.AMP): raise ValueError("Loaded snapshot is not trained with AMP!") diff --git a/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_env.py b/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_env.py index 03d4e8f2..4a7067c9 100644 --- a/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_env.py +++ b/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_env.py @@ -23,35 +23,70 @@ from nnabla_rl.typing import Experience import sys # noqa + sys.path.append(str(pathlib.Path(__file__).parent)) # noqa -from deepmimic_env_utils import (update_core_for_num_substeps, initialize_env, compile_observation, # noqa - generate_dummy_goal_env_state, compile_goal_env_observation, - load_goal_env_observation_and_action_space, load_observation_and_action_space, - record_invalid_or_valid_state, label_task_result) +from deepmimic_env_utils import ( # noqa + update_core_for_num_substeps, + initialize_env, + compile_observation, + generate_dummy_goal_env_state, + compile_goal_env_observation, + load_goal_env_observation_and_action_space, + load_observation_and_action_space, + record_invalid_or_valid_state, + label_task_result, +) try: sys.path.append(str(pathlib.Path(__file__).parent.parent / "DeepMimic")) # noqa from DeepMimicCore import DeepMimicCore # noqa except ModuleNotFoundError: from nnabla_rl.logger import logger + logger.info("No DeepMimicCore file. Please build the DeepMimic environment and generate the python file.") try: - from OpenGL.GLUT import (GLUT_DEPTH, GLUT_DOUBLE, GLUT_ELAPSED_TIME, GLUT_RGBA, glutCreateWindow, glutDisplayFunc, - glutGet, glutInit, glutInitDisplayMode, glutInitWindowSize, glutKeyboardFunc, - glutLeaveMainLoop, glutMainLoop, glutMotionFunc, glutMouseFunc, glutPostRedisplay, - glutReshapeFunc, glutSwapBuffers, glutTimerFunc) + from OpenGL.GLUT import ( + GLUT_DEPTH, + GLUT_DOUBLE, + GLUT_ELAPSED_TIME, + GLUT_RGBA, + glutCreateWindow, + glutDisplayFunc, + glutGet, + glutInit, + glutInitDisplayMode, + glutInitWindowSize, + glutKeyboardFunc, + glutLeaveMainLoop, + glutMainLoop, + glutMotionFunc, + glutMouseFunc, + glutPostRedisplay, + glutReshapeFunc, + glutSwapBuffers, + glutTimerFunc, + ) except ModuleNotFoundError: from nnabla_rl.logger import logger - logger.info("No OpenGL lib. Please build the DeepMimic environment and generate the python file, " - "OpenGL lib is installed automatically.") + + logger.info( + "No OpenGL lib. Please build the DeepMimic environment and generate the python file, " + "OpenGL lib is installed automatically." 
+ ) class DeepMimicEnv(AMPEnv): unwrapped: "DeepMimicEnv" - def __init__(self, args_file: str, eval_mode: bool, - fps: int = 60, num_processes: int = 1, step_until_action_needed: bool = True) -> None: + def __init__( + self, + args_file: str, + eval_mode: bool, + fps: int = 60, + num_processes: int = 1, + step_until_action_needed: bool = True, + ) -> None: assert fps > 0 assert num_processes > 0 @@ -66,16 +101,18 @@ def __init__(self, args_file: str, eval_mode: bool, self._initialized = False self._action_needed = True - (self.reward_range, - self.observation_space, - self.observation_mean, - self.observation_var, - self.action_space, - self.action_mean, - self.action_var, - self.reward_at_task_fail, - self.reward_at_task_success) = self._reward_range_state_and_action_space() - self.spec = EnvSpec(pathlib.Path(args_file).name.replace('train_amp_', '').replace('_args.txt', '-v0')) + ( + self.reward_range, + self.observation_space, + self.observation_mean, + self.observation_var, + self.action_space, + self.action_mean, + self.action_var, + self.reward_at_task_fail, + self.reward_at_task_success, + ) = self._reward_range_state_and_action_space() + self.spec = EnvSpec(pathlib.Path(args_file).name.replace("train_amp_", "").replace("_args.txt", "-v0")) super().__init__() @@ -137,9 +174,12 @@ def _step(self, action): if self._action_needed: self._core.SetAction(self._agent_id, np.array(action, dtype=np.float32).tolist()) - done, info = update_core_for_num_substeps(until_action_needed=self._step_until_action_needed, - core=self._core, update_timesteps=self._update_timesteps, - agent_id=self._agent_id) + done, info = update_core_for_num_substeps( + until_action_needed=self._step_until_action_needed, + core=self._core, + update_timesteps=self._update_timesteps, + agent_id=self._agent_id, + ) self._action_needed = info["action_needed"] next_state = compile_observation(self._core, self._num_timesteps, self._agent_id) @@ -156,32 +196,42 @@ def _reward_range_state_and_action_space(self): dummy_core.Init() assert dummy_core.GetGoalSize(self._agent_id) == 0, "This env has a goal! Use DeepMimicGoalEnv." 
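A short aside on the (1.0 - 0.99) and (1.0 - 0.95) divisors used in build_config of amp_reproduction.py above: a constant per-step reward r under discount gamma sums to r / (1 - gamma), so dividing terminal rewards and the reward range by (1 - gamma) puts them on the scale of state values. A minimal sketch with a made-up reward value:

```python
# Hypothetical numbers; the real ones come from the DeepMimic core.
gamma = 0.95
reward_at_task_fail = -1.0

value_at_task_fail = reward_at_task_fail / (1.0 - gamma)

# Matches the truncated geometric series r + gamma*r + gamma**2*r + ...
series = sum(reward_at_task_fail * gamma**t for t in range(10_000))
assert abs(value_at_task_fail - series) < 1e-6
print(value_at_task_fail)  # approximately -20
```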
- (observation_space, - observation_mean, - observation_var, - action_space, - action_mean, - action_var, - reward_at_task_fail, - reward_at_task_success) = load_observation_and_action_space(dummy_core, self._agent_id) + ( + observation_space, + observation_mean, + observation_var, + action_space, + action_mean, + action_var, + reward_at_task_fail, + reward_at_task_success, + ) = load_observation_and_action_space(dummy_core, self._agent_id) reward_range = (dummy_core.GetRewardMin(self._agent_id), dummy_core.GetRewardMax(self._agent_id)) - return (reward_range, - observation_space, - observation_mean, - observation_var, - action_space, - action_mean, - action_var, - reward_at_task_fail, - reward_at_task_success) + return ( + reward_range, + observation_space, + observation_mean, + observation_var, + action_space, + action_mean, + action_var, + reward_at_task_fail, + reward_at_task_success, + ) class DeepMimicGoalEnv(AMPGoalEnv): unwrapped: "DeepMimicGoalEnv" - def __init__(self, args_file: str, eval_mode: bool, - fps: int = 60, num_processes: int = 1, step_until_action_needed: bool = True) -> None: + def __init__( + self, + args_file: str, + eval_mode: bool, + fps: int = 60, + num_processes: int = 1, + step_until_action_needed: bool = True, + ) -> None: assert fps > 0 assert num_processes > 0 @@ -196,17 +246,19 @@ def __init__(self, args_file: str, eval_mode: bool, self._initialized = False self._action_needed = True - (self.reward_range, - self.observation_space, - self.observation_mean, - self.observation_var, - self.action_space, - self.action_mean, - self.action_var, - self.reward_at_task_fail, - self.reward_at_task_success) = self._reward_range_state_and_action_space() + ( + self.reward_range, + self.observation_space, + self.observation_mean, + self.observation_var, + self.action_space, + self.action_mean, + self.action_var, + self.reward_at_task_fail, + self.reward_at_task_success, + ) = self._reward_range_state_and_action_space() assert self.reward_at_task_fail < self.reward_at_task_success - self.spec = EnvSpec(pathlib.Path(args_file).name.replace('train_amp_', '').replace('_args.txt', '-v0')) + self.spec = EnvSpec(pathlib.Path(args_file).name.replace("train_amp_", "").replace("_args.txt", "-v0")) super().__init__() @@ -274,9 +326,12 @@ def _step(self, action): if self._action_needed: self._core.SetAction(self._agent_id, np.array(action, dtype=np.float32).tolist()) - done, info = update_core_for_num_substeps(until_action_needed=self._step_until_action_needed, - core=self._core, update_timesteps=self._update_timesteps, - agent_id=self._agent_id) + done, info = update_core_for_num_substeps( + until_action_needed=self._step_until_action_needed, + core=self._core, + update_timesteps=self._update_timesteps, + agent_id=self._agent_id, + ) self._action_needed = info["action_needed"] next_state = compile_goal_env_observation(self._core, self._num_timesteps, self._agent_id) @@ -293,34 +348,40 @@ def _reward_range_state_and_action_space(self): dummy_core.Init() assert dummy_core.GetGoalSize(self._agent_id) != 0, "This env does not have a goal! Use DeepMimicEnv." 
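The spec id constructed just above maps a DeepMimic args file name onto a gym-style id by stripping the train_amp_ prefix and swapping the _args.txt suffix for -v0. A small illustration with a hypothetical file name:

```python
import pathlib

args_file = "args/train_amp_humanoid3d_walk_args.txt"  # hypothetical path
spec_id = pathlib.Path(args_file).name.replace("train_amp_", "").replace("_args.txt", "-v0")
print(spec_id)  # humanoid3d_walk-v0
```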
- (observation_space, - observation_mean, - observation_var, - action_space, - action_mean, - action_var, - reward_at_task_fail, - reward_at_task_success) = load_goal_env_observation_and_action_space(dummy_core, self._agent_id) + ( + observation_space, + observation_mean, + observation_var, + action_space, + action_mean, + action_var, + reward_at_task_fail, + reward_at_task_success, + ) = load_goal_env_observation_and_action_space(dummy_core, self._agent_id) reward_range = (dummy_core.GetRewardMin(self._agent_id), dummy_core.GetRewardMax(self._agent_id)) - return (reward_range, - observation_space, - observation_mean, - observation_var, - action_space, - action_mean, - action_var, - reward_at_task_fail, - reward_at_task_success) + return ( + reward_range, + observation_space, + observation_mean, + observation_var, + action_space, + action_mean, + action_var, + reward_at_task_fail, + reward_at_task_success, + ) class DeepMimicWindowViewer: - def __init__(self, - env: gym.Env, - policy_callback_function: Optional[Callable[[np.ndarray, np.ndarray], np.ndarray]] = None, - width: int = 800, - height: int = 450, - playback_speed: int = 1) -> None: + def __init__( + self, + env: gym.Env, + policy_callback_function: Optional[Callable[[np.ndarray, np.ndarray], np.ndarray]] = None, + width: int = 800, + height: int = 450, + playback_speed: int = 1, + ) -> None: self._env = env assert isinstance(env.unwrapped, DeepMimicEnv) or isinstance(env.unwrapped, DeepMimicGoalEnv) self._env_unwrapped = env.unwrapped @@ -339,8 +400,12 @@ def __init__(self, # NOTE: Create window first, then rendering should be enabled. self._initialize_window() - self._core = initialize_env(self._env_unwrapped._args_file, self._env_unwrapped._agent_id, - seed=self._env_unwrapped._seed, enable_window_view=True) + self._core = initialize_env( + self._env_unwrapped._args_file, + self._env_unwrapped._agent_id, + seed=self._env_unwrapped._seed, + enable_window_view=True, + ) self._env_unwrapped._core = self._core # Force to overwrite core self._setup_draw() @@ -391,7 +456,7 @@ def _keyboard(self, key, x, y): key_val = int.from_bytes(key, byteorder="big") self._core.Keyboard(key_val, x, y) - if (key == b"r"): + if key == b"r": self._state = self._env.reset() self._initial_step = False diff --git a/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_env_utils.py b/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_env_utils.py index a9ea5784..8aa54590 100644 --- a/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_env_utils.py +++ b/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_env_utils.py @@ -24,16 +24,19 @@ from nnabla_rl.typing import State import sys # noqa + try: sys.path.append(str(pathlib.Path(__file__).parent.parent / "DeepMimic")) # noqa from DeepMimicCore import DeepMimicCore # noqa except ModuleNotFoundError: from nnabla_rl.logger import logger + logger.info("No DeepMimicCore file. 
Please build the DeepMimic environment and generate the python file.") -def update_core(core: "DeepMimicCore.cDeepMimicCore", update_timesteps: int, agent_id: int - ) -> Tuple[bool, bool, Dict[str, bool]]: +def update_core( + core: "DeepMimicCore.cDeepMimicCore", update_timesteps: int, agent_id: int +) -> Tuple[bool, bool, Dict[str, bool]]: num_substeps = core.GetNumUpdateSubsteps() timestep = float(update_timesteps) / float(num_substeps) num_steps = 0 @@ -49,9 +52,15 @@ def update_core(core: "DeepMimicCore.cDeepMimicCore", update_timesteps: int, age terminate = core.CheckTerminate(agent_id) # 0 is Null, 1 is Fail and 2 is Success # See: https://github.com/xbpeng/DeepMimic/blob/70e7c6b22b775bb9342d4e15e6ef0bd91a55c6c0/env/env.py#L7 - return False, True, {"_valid_episode": valid_episode, - "_task_fail": True if terminate == 1 else False, - "_task_success": True if terminate == 2 else False} + return ( + False, + True, + { + "_valid_episode": valid_episode, + "_task_fail": True if terminate == 1 else False, + "_task_success": True if terminate == 2 else False, + }, + ) if core.NeedNewAction(agent_id): assert num_steps >= num_substeps @@ -60,8 +69,9 @@ def update_core(core: "DeepMimicCore.cDeepMimicCore", update_timesteps: int, age return False, False, {"_valid_episode": valid_episode, "_task_fail": False, "_task_success": False} -def update_core_for_num_substeps(until_action_needed: bool, core: "DeepMimicCore.cDeepMimicCore", - update_timesteps: int, agent_id: int) -> Tuple[bool, Dict[str, Any]]: +def update_core_for_num_substeps( + until_action_needed: bool, core: "DeepMimicCore.cDeepMimicCore", update_timesteps: int, agent_id: int +) -> Tuple[bool, Dict[str, Any]]: if until_action_needed: action_needed = False done = False @@ -82,24 +92,32 @@ def compile_observation(core: "DeepMimicCore.cDeepMimicCore", num_timesteps: int if num_timesteps == 0: # In an initial step, a dummy state given as the concatenated st and st+1. 
dummy_state = np.zeros(core.GetAMPObsSize(), dtype=np.float32) - return (np.array(core.RecordState(agent_id), dtype=np.float32), - dummy_state, - record_invalid_or_valid_state(num_timesteps)) + return ( + np.array(core.RecordState(agent_id), dtype=np.float32), + dummy_state, + record_invalid_or_valid_state(num_timesteps), + ) else: - return (np.array(core.RecordState(agent_id), dtype=np.float32), - np.array(core.RecordAMPObsAgent(agent_id), dtype=np.float32), - record_invalid_or_valid_state(num_timesteps)) + return ( + np.array(core.RecordState(agent_id), dtype=np.float32), + np.array(core.RecordAMPObsAgent(agent_id), dtype=np.float32), + record_invalid_or_valid_state(num_timesteps), + ) -def compile_goal_env_observation(core: "DeepMimicCore.cDeepMimicCore", num_timesteps: int, agent_id: int - ) -> Dict[str, State]: +def compile_goal_env_observation( + core: "DeepMimicCore.cDeepMimicCore", num_timesteps: int, agent_id: int +) -> Dict[str, State]: observation = compile_observation(core, num_timesteps, agent_id) - goal_env_observation = {"observation": observation, - "desired_goal": (np.array(core.RecordGoal(agent_id), dtype=np.float32), - np.ones((1,), dtype=np.float32)), - # not use an achieved goal - "achieved_goal": (np.array(core.RecordGoal(agent_id), dtype=np.float32) * 0.0, - np.zeros((1,), dtype=np.float32))} + goal_env_observation = { + "observation": observation, + "desired_goal": (np.array(core.RecordGoal(agent_id), dtype=np.float32), np.ones((1,), dtype=np.float32)), + # not use an achieved goal + "achieved_goal": ( + np.array(core.RecordGoal(agent_id), dtype=np.float32) * 0.0, + np.zeros((1,), dtype=np.float32), + ), + } return goal_env_observation @@ -128,8 +146,9 @@ def label_task_result(state, reward, done, info) -> TaskResult: return TaskResult.UNKNOWN -def initialize_env(args_file: str, agent_id: int, - enable_window_view: bool = False, seed: Optional[int] = None) -> "DeepMimicCore.cDeepMimicCore": +def initialize_env( + args_file: str, agent_id: int, enable_window_view: bool = False, seed: Optional[int] = None +) -> "DeepMimicCore.cDeepMimicCore": core = DeepMimicCore.cDeepMimicCore(enable_window_view) if seed is not None: core.SeedRand(seed) @@ -155,10 +174,9 @@ def initialize_env(args_file: str, agent_id: int, return core -def load_observation_and_action_space(dummy_core: "DeepMimicCore.cDeepMimicCore", - agent_id: int) -> Tuple[gym.Space, Tuple[np.ndarray, ...], Tuple[np.ndarray, ...], - gym.Space, np.ndarray, - np.ndarray, float, float]: +def load_observation_and_action_space( + dummy_core: "DeepMimicCore.cDeepMimicCore", agent_id: int +) -> Tuple[gym.Space, Tuple[np.ndarray, ...], Tuple[np.ndarray, ...], gym.Space, np.ndarray, np.ndarray, float, float]: action_space = spaces.Box( low=np.array(dummy_core.BuildActionBoundMin(agent_id), dtype=np.float32), high=np.array(dummy_core.BuildActionBoundMax(agent_id), dtype=np.float32), @@ -169,18 +187,15 @@ def load_observation_and_action_space(dummy_core: "DeepMimicCore.cDeepMimicCore" # observation for discriminator # DeepMimic env returns the concatenated st and st+1. 
# valid or invalid - observation_space = spaces.Tuple([spaces.Box(low=-np.inf, # type: ignore - high=np.inf, - shape=(dummy_core.GetStateSize(agent_id),), - dtype=np.float32), - spaces.Box(low=-np.inf, - high=np.inf, - shape=(dummy_core.GetAMPObsSize(),), - dtype=np.float32), - spaces.Box(low=0.0, - high=1.0, - shape=(1,), - dtype=np.float32)]) + observation_space = spaces.Tuple( + [ + spaces.Box( + low=-np.inf, high=np.inf, shape=(dummy_core.GetStateSize(agent_id),), dtype=np.float32 # type: ignore + ), + spaces.Box(low=-np.inf, high=np.inf, shape=(dummy_core.GetAMPObsSize(),), dtype=np.float32), + spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32), + ] + ) # Offset means default + offset, so mean is negative of the offset. obs_for_policy_mean = -1.0 * np.array(dummy_core.BuildStateOffset(agent_id), dtype=np.float32) obs_for_policy_var = (1.0 / np.array(dummy_core.BuildStateScale(agent_id), dtype=np.float32)) ** 2 @@ -192,60 +207,82 @@ def load_observation_and_action_space(dummy_core: "DeepMimicCore.cDeepMimicCore" action_var = (1.0 / np.array(dummy_core.BuildActionScale(agent_id), dtype=np.float32)) ** 2 reward_at_task_fail = float(dummy_core.GetRewardFail(agent_id)) reward_at_task_success = float(dummy_core.GetRewardSucc(agent_id)) - return (observation_space, - observation_mean, - observation_var, - action_space, - action_mean, - action_var, - reward_at_task_fail, - reward_at_task_success) - - -def load_goal_env_observation_and_action_space(dummy_core: "DeepMimicCore.cDeepMimicCore", agent_id: int - ) -> Tuple[gym.Space, Dict[str, Tuple[np.ndarray, ...]], - Dict[str, Tuple[np.ndarray, ...]], - gym.Space, np.ndarray, np.ndarray, - float, float]: + return ( + observation_space, + observation_mean, + observation_var, + action_space, + action_mean, + action_var, + reward_at_task_fail, + reward_at_task_success, + ) + + +def load_goal_env_observation_and_action_space(dummy_core: "DeepMimicCore.cDeepMimicCore", agent_id: int) -> Tuple[ + gym.Space, + Dict[str, Tuple[np.ndarray, ...]], + Dict[str, Tuple[np.ndarray, ...]], + gym.Space, + np.ndarray, + np.ndarray, + float, + float, +]: assert dummy_core.GetGoalSize(agent_id) != 0, "This env does not have a goal! Use DeepMimicEnv." 
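On the "Offset means default + offset, so mean is negative of the offset" comments above: DeepMimic exposes normalization as an additive offset and a multiplicative scale, and the code converts that into the mean/variance form a running-mean normalizer expects. A toy check of that identity (the real arrays come from BuildStateOffset and BuildStateScale):

```python
import numpy as np

offset = np.array([-1.0, 2.0], dtype=np.float32)  # toy values
scale = np.array([0.5, 4.0], dtype=np.float32)    # toy values

mean = -1.0 * offset
var = (1.0 / scale) ** 2

x = np.array([3.0, -2.5], dtype=np.float32)
# DeepMimic-style normalization (x + offset) * scale equals (x - mean) / std.
assert np.allclose((x + offset) * scale, (x - mean) / np.sqrt(var))
```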
- (observation_space, - observation_mean, - observation_var, - action_space, - action_mean, - action_var, - reward_at_task_fail, - reward_at_task_success) = load_observation_and_action_space(dummy_core, agent_id) + ( + observation_space, + observation_mean, + observation_var, + action_space, + action_mean, + action_var, + reward_at_task_fail, + reward_at_task_success, + ) = load_observation_and_action_space(dummy_core, agent_id) # Add goal state to observation space - goal_state_space = spaces.Tuple([spaces.Box(low=-np.inf, # type: ignore[operator] - high=np.inf, - shape=(dummy_core.GetGoalSize(agent_id),), - dtype=np.float32), - # valid or invalid - spaces.Box(low=0.0, - high=1.0, - shape=(1,), - dtype=np.float32)]) - goal_env_observation_space = spaces.Dict({"observation": observation_space, - "desired_goal": goal_state_space, - "achieved_goal": goal_state_space}) + goal_state_space = spaces.Tuple( + [ + spaces.Box( + low=-np.inf, # type: ignore[operator] + high=np.inf, + shape=(dummy_core.GetGoalSize(agent_id),), + dtype=np.float32, + ), + # valid or invalid + spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32), + ] + ) + goal_env_observation_space = spaces.Dict( + {"observation": observation_space, "desired_goal": goal_state_space, "achieved_goal": goal_state_space} + ) # Offset means default + offset, so mean is negative of the offset. - obs_for_goal_mean = (-1.0 * np.array(dummy_core.BuildGoalOffset(agent_id), dtype=np.float32), - np.zeros((1,), dtype=np.float32)) - obs_for_goal_var = ((1.0 / np.array(dummy_core.BuildGoalScale(agent_id), dtype=np.float32)) ** 2, - np.ones((1,), dtype=np.float32)) - - goal_env_observation_mean = {"observation": observation_mean, - "desired_goal": obs_for_goal_mean, - "achieved_goal": obs_for_goal_mean} - goal_env_observation_var = {"observation": observation_var, - "desired_goal": obs_for_goal_var, - "achieved_goal": obs_for_goal_var} - return (goal_env_observation_space, - goal_env_observation_mean, - goal_env_observation_var, - action_space, - action_mean, - action_var, - reward_at_task_fail, - reward_at_task_success) + obs_for_goal_mean = ( + -1.0 * np.array(dummy_core.BuildGoalOffset(agent_id), dtype=np.float32), + np.zeros((1,), dtype=np.float32), + ) + obs_for_goal_var = ( + (1.0 / np.array(dummy_core.BuildGoalScale(agent_id), dtype=np.float32)) ** 2, + np.ones((1,), dtype=np.float32), + ) + + goal_env_observation_mean = { + "observation": observation_mean, + "desired_goal": obs_for_goal_mean, + "achieved_goal": obs_for_goal_mean, + } + goal_env_observation_var = { + "observation": observation_var, + "desired_goal": obs_for_goal_var, + "achieved_goal": obs_for_goal_var, + } + return ( + goal_env_observation_space, + goal_env_observation_mean, + goal_env_observation_var, + action_space, + action_mean, + action_var, + reward_at_task_fail, + reward_at_task_success, + ) diff --git a/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_evaluator.py b/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_evaluator.py index 03f64bfa..f3c066a7 100644 --- a/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_evaluator.py +++ b/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_evaluator.py @@ -35,8 +35,11 @@ def __call__(self, algorithm, env: Union[AMPEnv, AMPGoalEnv]) -> List[float]: if info["valid_episode"]: returns.append(reward_sum) - logger.info("Finished evaluation run: #{} out of {}. 
Total reward: {}".format( - len(returns), self._num_episodes, reward_sum)) + logger.info( + "Finished evaluation run: #{} out of {}. Total reward: {}".format( + len(returns), self._num_episodes, reward_sum + ) + ) else: logger.info("Invalid episode. Skip to add this episode to evaluation") return returns diff --git a/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_explorer.py b/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_explorer.py index bdb96cd6..32034c2b 100644 --- a/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_explorer.py +++ b/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_explorer.py @@ -17,8 +17,10 @@ import gym from nnabla_rl.environment_explorer import EnvironmentExplorer, EnvironmentExplorerConfig -from nnabla_rl.environment_explorers.epsilon_greedy_explorer import (LinearDecayEpsilonGreedyExplorer, - LinearDecayEpsilonGreedyExplorerConfig) +from nnabla_rl.environment_explorers.epsilon_greedy_explorer import ( + LinearDecayEpsilonGreedyExplorer, + LinearDecayEpsilonGreedyExplorerConfig, +) from nnabla_rl.environments.amp_env import AMPEnv, AMPGoalEnv from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.typing import ActionSelector, Experience, State @@ -60,11 +62,13 @@ def step(self, env: gym.Env, n: int = 1, break_if_done: bool = False) -> List[Ex class DeepMimicExplorer(LinearDecayEpsilonGreedyExplorer, ExploreUntilValidEnvironmentExplorer): - def __init__(self, - greedy_action_selector: ActionSelector, - random_action_selector: ActionSelector, - env_info: EnvironmentInfo, - config: LinearDecayEpsilonGreedyExplorerConfig = LinearDecayEpsilonGreedyExplorerConfig()): + def __init__( + self, + greedy_action_selector: ActionSelector, + random_action_selector: ActionSelector, + env_info: EnvironmentInfo, + config: LinearDecayEpsilonGreedyExplorerConfig = LinearDecayEpsilonGreedyExplorerConfig(), + ): super().__init__( env_info=env_info, config=config, diff --git a/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_normalizer.py b/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_normalizer.py index 68b66471..b4ebd21d 100644 --- a/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_normalizer.py +++ b/reproductions/algorithms/pybullet/amp/deepmimic_utils/deepmimic_normalizer.py @@ -24,31 +24,37 @@ class DeepMimicTupleRunningMeanNormalizer(Preprocessor, Model): _normalizers: List[RunningMeanNormalizer] - def __init__(self, - scope_name: str, - policy_state_shape: Tuple[Tuple[int, ...], ...], - reward_state_shape: Tuple[Tuple[int, ...], ...], - policy_state_mean_initializer: np.ndarray, - policy_state_var_initializer: np.ndarray, - reward_state_mean_initializer: np.ndarray, - reward_state_var_initializer: np.ndarray, - epsilon: float = 1e-2, - mode_for_floating_point_error: str = "max"): + def __init__( + self, + scope_name: str, + policy_state_shape: Tuple[Tuple[int, ...], ...], + reward_state_shape: Tuple[Tuple[int, ...], ...], + policy_state_mean_initializer: np.ndarray, + policy_state_var_initializer: np.ndarray, + reward_state_mean_initializer: np.ndarray, + reward_state_var_initializer: np.ndarray, + epsilon: float = 1e-2, + mode_for_floating_point_error: str = "max", + ): super(DeepMimicTupleRunningMeanNormalizer, self).__init__(scope_name) self._normalizers = [] - policy_state_normalizer = RunningMeanNormalizer(scope_name + "/policy_state", - shape=policy_state_shape, - epsilon=epsilon, - 
mode_for_floating_point_error=mode_for_floating_point_error, - mean_initializer=policy_state_mean_initializer, - var_initializer=policy_state_var_initializer) - reward_state_normalizer = RunningMeanNormalizer(scope_name + "/reward_state", - shape=reward_state_shape, - epsilon=epsilon, - mode_for_floating_point_error=mode_for_floating_point_error, - mean_initializer=reward_state_mean_initializer, - var_initializer=reward_state_var_initializer) + policy_state_normalizer = RunningMeanNormalizer( + scope_name + "/policy_state", + shape=policy_state_shape, + epsilon=epsilon, + mode_for_floating_point_error=mode_for_floating_point_error, + mean_initializer=policy_state_mean_initializer, + var_initializer=policy_state_var_initializer, + ) + reward_state_normalizer = RunningMeanNormalizer( + scope_name + "/reward_state", + shape=reward_state_shape, + epsilon=epsilon, + mode_for_floating_point_error=mode_for_floating_point_error, + mean_initializer=reward_state_mean_initializer, + var_initializer=reward_state_var_initializer, + ) self._normalizers = [policy_state_normalizer, reward_state_normalizer] def process(self, x): @@ -68,34 +74,40 @@ def update(self, data): class DeepMimicGoalTupleRunningMeanNormalizer(DeepMimicTupleRunningMeanNormalizer): _normalizers: List[RunningMeanNormalizer] - def __init__(self, - scope_name: str, - policy_state_shape: Tuple[Tuple[int, ...], ...], - reward_state_shape: Tuple[Tuple[int, ...], ...], - goal_state_shape: Tuple[Tuple[int, ...], ...], - policy_state_mean_initializer: np.ndarray, - policy_state_var_initializer: np.ndarray, - reward_state_mean_initializer: np.ndarray, - reward_state_var_initializer: np.ndarray, - goal_state_mean_initializer: np.ndarray, - goal_state_var_initializer: np.ndarray, - epsilon: float = 1e-2, - mode_for_floating_point_error: str = "max"): - super().__init__(scope_name, - policy_state_shape, - reward_state_shape, - policy_state_mean_initializer, - policy_state_var_initializer, - reward_state_mean_initializer, - reward_state_var_initializer, - epsilon, - mode_for_floating_point_error) - goal_state_normalizer = RunningMeanNormalizer(scope_name + "/goal_state", - shape=goal_state_shape, - epsilon=epsilon, - mode_for_floating_point_error=mode_for_floating_point_error, - mean_initializer=goal_state_mean_initializer, - var_initializer=goal_state_var_initializer) + def __init__( + self, + scope_name: str, + policy_state_shape: Tuple[Tuple[int, ...], ...], + reward_state_shape: Tuple[Tuple[int, ...], ...], + goal_state_shape: Tuple[Tuple[int, ...], ...], + policy_state_mean_initializer: np.ndarray, + policy_state_var_initializer: np.ndarray, + reward_state_mean_initializer: np.ndarray, + reward_state_var_initializer: np.ndarray, + goal_state_mean_initializer: np.ndarray, + goal_state_var_initializer: np.ndarray, + epsilon: float = 1e-2, + mode_for_floating_point_error: str = "max", + ): + super().__init__( + scope_name, + policy_state_shape, + reward_state_shape, + policy_state_mean_initializer, + policy_state_var_initializer, + reward_state_mean_initializer, + reward_state_var_initializer, + epsilon, + mode_for_floating_point_error, + ) + goal_state_normalizer = RunningMeanNormalizer( + scope_name + "/goal_state", + shape=goal_state_shape, + epsilon=epsilon, + mode_for_floating_point_error=mode_for_floating_point_error, + mean_initializer=goal_state_mean_initializer, + var_initializer=goal_state_var_initializer, + ) self._normalizers.append(goal_state_normalizer) def process(self, x): diff --git 
a/reproductions/algorithms/pybullet/icra2018qtopt/external_grasping_env/kuka.py b/reproductions/algorithms/pybullet/icra2018qtopt/external_grasping_env/kuka.py index 4d3beb00..86d05805 100644 --- a/reproductions/algorithms/pybullet/icra2018qtopt/external_grasping_env/kuka.py +++ b/reproductions/algorithms/pybullet/icra2018qtopt/external_grasping_env/kuka.py @@ -30,8 +30,7 @@ class Kuka: - def __init__(self, urdfRootPath='', timeStep=0.01, clientId=0, ikFix=False, - returnPos=True): + def __init__(self, urdfRootPath="", timeStep=0.01, clientId=0, ikFix=False, returnPos=True): """Creates a Kuka robot. Args: @@ -50,7 +49,7 @@ def __init__(self, urdfRootPath='', timeStep=0.01, clientId=0, ikFix=False, self.ikFix = ikFix self.returnPos = returnPos - self.maxForce = 200. + self.maxForce = 200.0 self.fingerAForce = 6 self.fingerBForce = 5.5 self.fingerTipForce = 6 @@ -60,27 +59,24 @@ def __init__(self, urdfRootPath='', timeStep=0.01, clientId=0, ikFix=False, self.useOrientation = 1 self.kukaEndEffectorIndex = 6 # lower limits for null space - self.ll = [-.967, -2, -2.96, 0.19, -2.96, -2.09, -3.05] + self.ll = [-0.967, -2, -2.96, 0.19, -2.96, -2.09, -3.05] # upper limits for null space - self.ul = [.967, 2, 2.96, 2.29, 2.96, 2.09, 3.05] + self.ul = [0.967, 2, 2.96, 2.29, 2.96, 2.09, 3.05] # joint ranges for null space self.jr = [5.8, 4, 5.8, 4, 5.8, 4, 6] # restposes for null space - self.rp = [0, 0, 0, 0.5*math.pi, 0, -math.pi*0.5*0.66, 0] + self.rp = [0, 0, 0, 0.5 * math.pi, 0, -math.pi * 0.5 * 0.66, 0] # joint damping coefficents - self.jd = [.1] * 12 + self.jd = [0.1] * 12 kuka_path = os.path.join(urdfRootPath, "kuka_iiwa/kuka_with_gripper2.sdf") self.kukaUid = pybullet.loadSDF(kuka_path, physicsClientId=self.cid)[0] tray_path = os.path.join(urdfRootPath, "tray/tray.urdf") - self.trayUid = pybullet.loadURDF(tray_path, - [0.64, 0.075, -0.19], - [0.0, 0.0, 1.0, 0.0], - physicsClientId=self.cid) + self.trayUid = pybullet.loadURDF( + tray_path, [0.64, 0.075, -0.19], [0.0, 0.0, 1.0, 0.0], physicsClientId=self.cid + ) self.reset() - def reset(self, - base_pos=None, - endeffector_pos=None): + def reset(self, base_pos=None, endeffector_pos=None): """Resets the kuka base and joint positions. 
Args: @@ -94,26 +90,39 @@ def reset(self, if endeffector_pos is None: endeffector_pos = [0.537, 0.0, 0.5] - pybullet.resetBasePositionAndOrientation(self.kukaUid, - base_pos, - [0.000000, 0.000000, 0.000000, 1.000000], - physicsClientId=self.cid) - self.jointPositions = [0.006418, 0.413184, -0.011401, -1.589317, 0.005379, - 1.137684, -0.006539, 0.000048, -0.299912, 0.000000, - -0.000043, 0.299960, 0.000000, -0.000200] + pybullet.resetBasePositionAndOrientation( + self.kukaUid, base_pos, [0.000000, 0.000000, 0.000000, 1.000000], physicsClientId=self.cid + ) + self.jointPositions = [ + 0.006418, + 0.413184, + -0.011401, + -1.589317, + 0.005379, + 1.137684, + -0.006539, + 0.000048, + -0.299912, + 0.000000, + -0.000043, + 0.299960, + 0.000000, + -0.000200, + ] self.numJoints = pybullet.getNumJoints(self.kukaUid, physicsClientId=self.cid) for jointIndex in range(self.numJoints): - pybullet.resetJointState(self.kukaUid, - jointIndex, - self.jointPositions[jointIndex], - physicsClientId=self.cid) + pybullet.resetJointState( + self.kukaUid, jointIndex, self.jointPositions[jointIndex], physicsClientId=self.cid + ) if self.useSimulation: - pybullet.setJointMotorControl2(self.kukaUid, - jointIndex, - pybullet.POSITION_CONTROL, - targetPosition=self.jointPositions[jointIndex], - force=self.maxForce, - physicsClientId=self.cid) + pybullet.setJointMotorControl2( + self.kukaUid, + jointIndex, + pybullet.POSITION_CONTROL, + targetPosition=self.jointPositions[jointIndex], + force=self.maxForce, + physicsClientId=self.cid, + ) # Set the endeffector height to endEffectorPos. self.endEffectorPos = endeffector_pos @@ -131,7 +140,7 @@ def reset(self, self.motorIndices.append(i) def getActionDimension(self): - if (self.useInverseKinematics): + if self.useInverseKinematics: return len(self.motorIndices) return 6 # Position x,y,z and roll/pitch/yaw euler angles of end effector. @@ -140,8 +149,7 @@ def getObservationDimension(self): def getObservation(self): observation = [] - state = pybullet.getLinkState( - self.kukaUid, self.kukaEndEffectorIndex, physicsClientId=self.cid) + state = pybullet.getLinkState(self.kukaUid, self.kukaEndEffectorIndex, physicsClientId=self.cid) if self.ikFix: # state[0] is the linkWorldPosition, the center of mass of the link. # However, the IK solver uses localInertialFrameOrientation, the inertial @@ -160,28 +168,50 @@ def applyFingerAngle(self, fingerAngle): # TODO(ejang) - replace with pybullet.setJointMotorControlArray (more # efficient). 
pybullet.setJointMotorControl2( - self.kukaUid, 7, pybullet.POSITION_CONTROL, - targetPosition=self.endEffectorAngle, force=self.maxForce, - physicsClientId=self.cid) + self.kukaUid, + 7, + pybullet.POSITION_CONTROL, + targetPosition=self.endEffectorAngle, + force=self.maxForce, + physicsClientId=self.cid, + ) pybullet.setJointMotorControl2( - self.kukaUid, 8, pybullet.POSITION_CONTROL, - targetPosition=-fingerAngle, force=self.fingerAForce, - physicsClientId=self.cid) + self.kukaUid, + 8, + pybullet.POSITION_CONTROL, + targetPosition=-fingerAngle, + force=self.fingerAForce, + physicsClientId=self.cid, + ) pybullet.setJointMotorControl2( - self.kukaUid, 11, pybullet.POSITION_CONTROL, - targetPosition=fingerAngle, force=self.fingerBForce, - physicsClientId=self.cid) + self.kukaUid, + 11, + pybullet.POSITION_CONTROL, + targetPosition=fingerAngle, + force=self.fingerBForce, + physicsClientId=self.cid, + ) pybullet.setJointMotorControl2( - self.kukaUid, 10, pybullet.POSITION_CONTROL, targetPosition=0, - force=self.fingerTipForce, physicsClientId=self.cid) + self.kukaUid, + 10, + pybullet.POSITION_CONTROL, + targetPosition=0, + force=self.fingerTipForce, + physicsClientId=self.cid, + ) pybullet.setJointMotorControl2( - self.kukaUid, 13, pybullet.POSITION_CONTROL, targetPosition=0, - force=self.fingerTipForce, physicsClientId=self.cid) + self.kukaUid, + 13, + pybullet.POSITION_CONTROL, + targetPosition=0, + force=self.fingerTipForce, + physicsClientId=self.cid, + ) def applyAction(self, motorCommands): pos = None - if (self.useInverseKinematics): + if self.useInverseKinematics: dx = motorCommands[0] dy = motorCommands[1] @@ -189,71 +219,99 @@ def applyAction(self, motorCommands): da = motorCommands[3] fingerAngle = motorCommands[4] - state = pybullet.getLinkState( - self.kukaUid, self.kukaEndEffectorIndex, physicsClientId=self.cid) + state = pybullet.getLinkState(self.kukaUid, self.kukaEndEffectorIndex, physicsClientId=self.cid) if self.ikFix: actualEndEffectorPos = state[4] self.endEffectorPos = list(actualEndEffectorPos) else: actualEndEffectorPos = state[0] - self.endEffectorPos[0] = self.endEffectorPos[0]+dx - if (self.endEffectorPos[0] > 0.75): + self.endEffectorPos[0] = self.endEffectorPos[0] + dx + if self.endEffectorPos[0] > 0.75: self.endEffectorPos[0] = 0.75 - if (self.endEffectorPos[0] < 0.45): + if self.endEffectorPos[0] < 0.45: self.endEffectorPos[0] = 0.45 - self.endEffectorPos[1] = self.endEffectorPos[1]+dy - if (self.endEffectorPos[1] < -0.22): + self.endEffectorPos[1] = self.endEffectorPos[1] + dy + if self.endEffectorPos[1] < -0.22: self.endEffectorPos[1] = -0.22 - if (self.endEffectorPos[1] > 0.22): + if self.endEffectorPos[1] > 0.22: self.endEffectorPos[1] = 0.22 - if (dz > 0 or actualEndEffectorPos[2] > 0.10): - self.endEffectorPos[2] = self.endEffectorPos[2]+dz - if (actualEndEffectorPos[2] < 0.10): - self.endEffectorPos[2] = self.endEffectorPos[2]+0.0001 + if dz > 0 or actualEndEffectorPos[2] > 0.10: + self.endEffectorPos[2] = self.endEffectorPos[2] + dz + if actualEndEffectorPos[2] < 0.10: + self.endEffectorPos[2] = self.endEffectorPos[2] + 0.0001 self.endEffectorAngle = self.endEffectorAngle + da pos = self.endEffectorPos orn = pybullet.getQuaternionFromEuler([0, -math.pi, 0]) # -math.pi,yaw]) - if (self.useNullSpace == 1): - if (self.useOrientation == 1): + if self.useNullSpace == 1: + if self.useOrientation == 1: jointPoses = pybullet.calculateInverseKinematics( - self.kukaUid, self.kukaEndEffectorIndex, pos, - orn, self.ll, self.ul, self.jr, self.rp, - 
maxNumIterations=1, physicsClientId=self.cid) + self.kukaUid, + self.kukaEndEffectorIndex, + pos, + orn, + self.ll, + self.ul, + self.jr, + self.rp, + maxNumIterations=1, + physicsClientId=self.cid, + ) else: jointPoses = pybullet.calculateInverseKinematics( - self.kukaUid, self.kukaEndEffectorIndex, pos, lowerLimits=self.ll, - upperLimits=self.ul, jointRanges=self.jr, - restPoses=self.rp, maxNumIterations=1, - physicsClientId=self.cid) + self.kukaUid, + self.kukaEndEffectorIndex, + pos, + lowerLimits=self.ll, + upperLimits=self.ul, + jointRanges=self.jr, + restPoses=self.rp, + maxNumIterations=1, + physicsClientId=self.cid, + ) else: - if (self.useOrientation == 1): + if self.useOrientation == 1: if self.ikFix: jointPoses = pybullet.calculateInverseKinematics( - self.kukaUid, self.kukaEndEffectorIndex, - pos, orn, jointDamping=self.jd, maxNumIterations=50, - residualThreshold=.001, - physicsClientId=self.cid) + self.kukaUid, + self.kukaEndEffectorIndex, + pos, + orn, + jointDamping=self.jd, + maxNumIterations=50, + residualThreshold=0.001, + physicsClientId=self.cid, + ) else: jointPoses = pybullet.calculateInverseKinematics( - self.kukaUid, self.kukaEndEffectorIndex, - pos, orn, jointDamping=self.jd, maxNumIterations=1, - physicsClientId=self.cid) + self.kukaUid, + self.kukaEndEffectorIndex, + pos, + orn, + jointDamping=self.jd, + maxNumIterations=1, + physicsClientId=self.cid, + ) else: jointPoses = pybullet.calculateInverseKinematics( - self.kukaUid, self.kukaEndEffectorIndex, pos, - maxNumIterations=1, physicsClientId=self.cid) - if (self.useSimulation): - for i in range(self.kukaEndEffectorIndex+1): + self.kukaUid, self.kukaEndEffectorIndex, pos, maxNumIterations=1, physicsClientId=self.cid + ) + if self.useSimulation: + for i in range(self.kukaEndEffectorIndex + 1): # print(i) pybullet.setJointMotorControl2( - bodyIndex=self.kukaUid, jointIndex=i, + bodyIndex=self.kukaUid, + jointIndex=i, controlMode=pybullet.POSITION_CONTROL, targetPosition=jointPoses[i], - targetVelocity=0, force=self.maxForce, positionGain=0.03, - velocityGain=1, physicsClientId=self.cid) + targetVelocity=0, + force=self.maxForce, + positionGain=0.03, + velocityGain=1, + physicsClientId=self.cid, + ) else: # Reset the joint state (ignoring all dynamics, not recommended to use # during simulation). @@ -261,8 +319,7 @@ def applyAction(self, motorCommands): # TODO(b/72742371) Figure out why if useSimulation = 0, # len(jointPoses) = 12 and self.numJoints = 14. for i in range(len(jointPoses)): - pybullet.resetJointState(self.kukaUid, i, jointPoses[i], - physicsClientId=self.cid) + pybullet.resetJointState(self.kukaUid, i, jointPoses[i], physicsClientId=self.cid) # Move fingers. self.applyFingerAngle(fingerAngle) @@ -270,9 +327,13 @@ def applyAction(self, motorCommands): for action in range(len(motorCommands)): motor = self.motorIndices[action] pybullet.setJointMotorControl2( - self.kukaUid, motor, pybullet.POSITION_CONTROL, - targetPosition=motorCommands[action], force=self.maxForce, - physicsClientId=self.cid) + self.kukaUid, + motor, + pybullet.POSITION_CONTROL, + targetPosition=motorCommands[action], + force=self.maxForce, + physicsClientId=self.cid, + ) if self.returnPos: # Return the target position for metrics later. 
return pos diff --git a/reproductions/algorithms/pybullet/icra2018qtopt/external_grasping_env/kuka_grasping_procedural_env.py b/reproductions/algorithms/pybullet/icra2018qtopt/external_grasping_env/kuka_grasping_procedural_env.py index 74f62862..53e35602 100644 --- a/reproductions/algorithms/pybullet/icra2018qtopt/external_grasping_env/kuka_grasping_procedural_env.py +++ b/reproductions/algorithms/pybullet/icra2018qtopt/external_grasping_env/kuka_grasping_procedural_env.py @@ -36,9 +36,10 @@ INTERNAL_BULLET_ROOT = None if INTERNAL_BULLET_ROOT is None: import pybullet_data + OSS_DATA_ROOT = pybullet_data.getDataPath() else: - OSS_DATA_ROOT = '' + OSS_DATA_ROOT = "" # pylint: enable=bad-import-order # pylint: enable=g-import-not-at-top @@ -47,28 +48,29 @@ class KukaGraspingProceduralEnv(gym.Env): """Simplified grasping environment with discrete and continuous actions.""" def __init__( - self, - block_random=0.3, - camera_random=0, - simple_observations=False, - continuous=False, - remove_height_hack=False, - urdf_list=None, - render_mode='GUI', - num_objects=5, - dv=0.06, - target=False, - target_filenames=None, - non_target_filenames=None, - num_resets_per_setup=1, - render_width=128, - render_height=128, - downsample_width=64, - downsample_height=64, - test=False, - allow_duplicate_objects=True, - max_num_training_models=900, - max_num_test_models=100): + self, + block_random=0.3, + camera_random=0, + simple_observations=False, + continuous=False, + remove_height_hack=False, + urdf_list=None, + render_mode="GUI", + num_objects=5, + dv=0.06, + target=False, + target_filenames=None, + non_target_filenames=None, + num_resets_per_setup=1, + render_width=128, + render_height=128, + downsample_width=64, + downsample_height=64, + test=False, + allow_duplicate_objects=True, + max_num_training_models=900, + max_num_test_models=100, + ): """Creates a KukaGraspingEnv. Args: @@ -99,16 +101,16 @@ def __init__( max_num_test_models: The number of distinct models to choose from when selecting the num_objects placed in the tray for testing. """ - self._time_step = 1. / 200. + self._time_step = 1.0 / 200.0 self._max_steps = 15 # Open-source search paths. self._urdf_root = OSS_DATA_ROOT - self._models_dir = os.path.join(self._urdf_root, 'random_urdfs') + self._models_dir = os.path.join(self._urdf_root, "random_urdfs") self._action_repeat = 200 self._env_step = 0 - self._renders = render_mode in ['GUI', 'TCP'] + self._renders = render_mode in ["GUI", "TCP"] # Size we render at. 
self._width = render_width self._height = render_height @@ -122,10 +124,8 @@ def __init__( if target_filenames: target_filenames = [self._get_urdf_path(f) for f in target_filenames] if non_target_filenames: - non_target_filenames = [ - self._get_urdf_path(f) for f in non_target_filenames] - self._object_filenames = (target_filenames or []) + ( - non_target_filenames or []) + non_target_filenames = [self._get_urdf_path(f) for f in non_target_filenames] + self._object_filenames = (target_filenames or []) + (non_target_filenames or []) self._target_filenames = target_filenames or [] self._block_random = block_random self._cam_random = camera_random @@ -139,13 +139,13 @@ def __init__( self._max_num_training_models = max_num_training_models self._max_num_test_models = max_num_test_models - if render_mode == 'GUI': + if render_mode == "GUI": self.cid = pybullet.connect(pybullet.GUI) pybullet.resetDebugVisualizerCamera(1.3, 180, -41, [0.52, -0.2, -0.33]) - elif render_mode == 'DIRECT': + elif render_mode == "DIRECT": self.cid = pybullet.connect(pybullet.DIRECT) - elif render_mode == 'TCP': - self.cid = pybullet.connect(pybullet.TCP, 'localhost', 6667) + elif render_mode == "TCP": + self.cid = pybullet.connect(pybullet.TCP, "localhost", 6667) self.setup() if self._continuous: @@ -165,10 +165,7 @@ def __init__( self.observation_space = spaces.Box(low=-100, high=100, shape=(14,)) else: # image (self._height, self._width, 3) x position of the gripper (3,) - img_space = spaces.Box( - low=0, - high=255, - shape=(self._downsample_height, self._downsample_width, 3)) + img_space = spaces.Box(low=0, high=255, shape=(self._downsample_height, self._downsample_width, 3)) # self.observation_space = spaces.Tuple((img_space, pos_space)) self.observation_space = img_space self.viewer = None @@ -185,29 +182,23 @@ def setup(self): ) self._urdf_list = self._object_filenames pybullet.resetSimulation(physicsClientId=self.cid) - pybullet.setPhysicsEngineParameter( - numSolverIterations=150, physicsClientId=self.cid) + pybullet.setPhysicsEngineParameter(numSolverIterations=150, physicsClientId=self.cid) pybullet.setTimeStep(self._time_step, physicsClientId=self.cid) pybullet.setGravity(0, 0, -10, physicsClientId=self.cid) - plane_path = os.path.join(self._urdf_root, 'plane.urdf') + plane_path = os.path.join(self._urdf_root, "plane.urdf") pybullet.loadURDF(plane_path, [0, 0, -1], physicsClientId=self.cid) - table_path = os.path.join(self._urdf_root, 'table/table.urdf') - pybullet.loadURDF( - table_path, [0.5, 0.0, -.82], [0., 0., 0., 1.], - physicsClientId=self.cid) - self._kuka = kuka.Kuka( - urdfRootPath=self._urdf_root, - timeStep=self._time_step, - clientId=self.cid) + table_path = os.path.join(self._urdf_root, "table/table.urdf") + pybullet.loadURDF(table_path, [0.5, 0.0, -0.82], [0.0, 0.0, 0.0, 1.0], physicsClientId=self.cid) + self._kuka = kuka.Kuka(urdfRootPath=self._urdf_root, timeStep=self._time_step, clientId=self.cid) self._block_uids = [] for urdf_name in self._urdf_list: xpos = 0.4 + self._block_random * random.random() - ypos = self._block_random * (random.random() - .5) + ypos = self._block_random * (random.random() - 0.5) angle = np.pi / 2 + self._block_random * np.pi * random.random() ori = pybullet.getQuaternionFromEuler([0, 0, angle]) uid = pybullet.loadURDF( - urdf_name, [xpos, ypos, .15], [ori[0], ori[1], ori[2], ori[3]], - physicsClientId=self.cid) + urdf_name, [xpos, ypos, 0.15], [ori[0], ori[1], ori[2], ori[3]], physicsClientId=self.cid + ) self._block_uids.append(uid) for _ in range(500): 
pybullet.stepSimulation(physicsClientId=self.cid) @@ -220,30 +211,27 @@ def reset(self): self._attempted_grasp = False look = [0.23, 0.2, 0.54] - distance = 1. + distance = 1.0 pitch = -56 + self._cam_random * np.random.uniform(-3, 3) yaw = 245 + self._cam_random * np.random.uniform(-3, 3) roll = 0 - self._view_matrix = pybullet.computeViewMatrixFromYawPitchRoll( - look, distance, yaw, pitch, roll, 2) - fov = 20. + self._cam_random * np.random.uniform(-2, 2) + self._view_matrix = pybullet.computeViewMatrixFromYawPitchRoll(look, distance, yaw, pitch, roll, 2) + fov = 20.0 + self._cam_random * np.random.uniform(-2, 2) aspect = self._width / self._height near = 0.1 far = 10 - self._proj_matrix = pybullet.computeProjectionMatrixFOV( - fov, aspect, near, far) + self._proj_matrix = pybullet.computeProjectionMatrixFOV(fov, aspect, near, far) self._env_step = 0 for i in range(len(self._urdf_list)): xpos = 0.4 + self._block_random * random.random() - ypos = self._block_random * (random.random() - .5) + ypos = self._block_random * (random.random() - 0.5) # random angle angle = np.pi / 2 + self._block_random * np.pi * random.random() ori = pybullet.getQuaternionFromEuler([0, 0, angle]) pybullet.resetBasePositionAndOrientation( - self._block_uids[i], [xpos, ypos, .15], - [ori[0], ori[1], ori[2], ori[3]], - physicsClientId=self.cid) + self._block_uids[i], [xpos, ypos, 0.15], [ori[0], ori[1], ori[2], ori[3]], physicsClientId=self.cid + ) # Let each object fall to the tray individual, to prevent object # intersection. for _ in range(500): @@ -267,11 +255,13 @@ def _get_observation(self): return self._get_image_observation() def _get_image_observation(self): - results = pybullet.getCameraImage(width=self._width, - height=self._height, - viewMatrix=self._view_matrix, - projectionMatrix=self._proj_matrix, - physicsClientId=self.cid) + results = pybullet.getCameraImage( + width=self._width, + height=self._height, + viewMatrix=self._view_matrix, + projectionMatrix=self._proj_matrix, + physicsClientId=self.cid, + ) rgba = results[2] np_img_arr = np.reshape(rgba, (self._height, self._width, 4)) # Extract RGB components only. @@ -288,17 +278,14 @@ def _get_simple_observation(self): Numpy array containing location and orientation of nearest block and location of end-effector. """ - state = pybullet.getLinkState( - self._kuka.kukaUid, self._kuka.kukaEndEffectorIndex, - physicsClientId=self.cid) + state = pybullet.getLinkState(self._kuka.kukaUid, self._kuka.kukaEndEffectorIndex, physicsClientId=self.cid) end_effector_pos = np.array(state[0]) end_effector_ori = np.array(state[1]) distances = [] pos_and_ori = [] for uid in self._block_uids: - pos, ori = pybullet.getBasePositionAndOrientation( - uid, physicsClientId=self.cid) + pos, ori = pybullet.getBasePositionAndOrientation(uid, physicsClientId=self.cid) pos, ori = np.array(pos), np.array(ori) pos_and_ori.append((pos, ori)) distances.append(np.linalg.norm(end_effector_pos - pos)) @@ -354,9 +341,7 @@ def _step_continuous(self, action): break # If we are close to the bin, attempt grasp. 
- state = pybullet.getLinkState(self._kuka.kukaUid, - self._kuka.kukaEndEffectorIndex, - physicsClientId=self.cid) + state = pybullet.getLinkState(self._kuka.kukaUid, self._kuka.kukaEndEffectorIndex, physicsClientId=self.cid) end_effector_pos = state[0] if end_effector_pos[2] <= 0.1: finger_angle = 0.3 @@ -364,7 +349,7 @@ def _step_continuous(self, action): grasp_action = [0, 0, 0.001, 0, finger_angle] self._kuka.applyAction(grasp_action) pybullet.stepSimulation(physicsClientId=self.cid) - finger_angle -= 0.3/100. + finger_angle -= 0.3 / 100.0 if finger_angle < 0: finger_angle = 0 self._attempted_grasp = True @@ -372,12 +357,10 @@ def _step_continuous(self, action): done = self._termination() reward = self._reward() - debug = { - 'grasp_success': self._grasp_success - } + debug = {"grasp_success": self._grasp_success} return observation, reward, done, debug - def _render(self, mode='human'): + def _render(self, mode="human"): return def _termination(self): @@ -388,13 +371,12 @@ def _reward(self): self._grasp_success = 0 if self._target: - target_uids = self._block_uids[0:len(self._target_filenames)] + target_uids = self._block_uids[0 : len(self._target_filenames)] else: target_uids = self._block_uids for uid in target_uids: - pos, _ = pybullet.getBasePositionAndOrientation( - uid, physicsClientId=self.cid) + pos, _ = pybullet.getBasePositionAndOrientation(uid, physicsClientId=self.cid) # If any block is above height, provide reward. if pos[2] > 0.2: self._grasp_success = 1 @@ -423,17 +405,16 @@ def _get_random_objects(self, num_objects, test, replace=True): A list of urdf filenames. """ if test: - urdf_pattern = os.path.join(self._models_dir, '*0/*.urdf') + urdf_pattern = os.path.join(self._models_dir, "*0/*.urdf") max_num_objects = self._max_num_test_models else: - urdf_pattern = os.path.join(self._models_dir, '*[^0]/*.urdf') + urdf_pattern = os.path.join(self._models_dir, "*[^0]/*.urdf") max_num_objects = self._max_num_training_models found_object_directories = glob.glob(urdf_pattern) total_num_objects = len(found_object_directories) if total_num_objects > max_num_objects: total_num_objects = max_num_objects - selected_objects = np.random.choice( - np.arange(total_num_objects), num_objects, replace=replace) + selected_objects = np.random.choice(np.arange(total_num_objects), num_objects, replace=replace) selected_objects_filenames = [] for object_index in selected_objects: selected_objects_filenames += [found_object_directories[object_index]] diff --git a/reproductions/algorithms/pybullet/icra2018qtopt/icra2018qtopt_reproduction.py b/reproductions/algorithms/pybullet/icra2018qtopt/icra2018qtopt_reproduction.py index 60524600..0dcf562b 100644 --- a/reproductions/algorithms/pybullet/icra2018qtopt/icra2018qtopt_reproduction.py +++ b/reproductions/algorithms/pybullet/icra2018qtopt/icra2018qtopt_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
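KukaGraspingProceduralEnv, whose reformatted definition appears above, is a plain gym.Env, so the constructor options documented in its __init__ (render_mode, num_objects, continuous, and so on) can be exercised with a short smoke test. The sketch below is illustrative rather than part of the change: the import path is an assumption based on the file location in this diff, and the action is just a random sample from the action space.

# Hypothetical smoke test (not part of the diff).
from external_grasping_env.kuka_grasping_procedural_env import KukaGraspingProceduralEnv  # assumed path

env = KukaGraspingProceduralEnv(render_mode="DIRECT", num_objects=5, continuous=True)
obs = env.reset()
obs, reward, done, debug = env.step(env.action_space.sample())
print(obs.shape, reward, done, debug["grasp_success"])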
@@ -65,7 +65,7 @@ def build_kuka_grasping_procedural_env(test: bool = False, render: bool = False) num_objects=5, max_num_training_models=100 if test else 900, target=False, - test=test + test=test, ) env = Float32RewardEnv(env) env = HWCToCHWEnv(env) @@ -123,9 +123,7 @@ def collect_data( def run_training(args): - outdir = ( - f"KukaGraspingProceduralEnv_{args.num_collection_episodes}_results/seed-{args.seed}" - ) + outdir = f"KukaGraspingProceduralEnv_{args.num_collection_episodes}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) @@ -159,9 +157,7 @@ def run_training(args): # NOTE: Downbiased for z-axis # See: https://github.com/google-research/google-research/blob/master/dql_grasping/policies.py#L298 - config = A.ICRA2018QtOptConfig( - gpu_id=args.gpu, cem_initial_mean=(0.0, 0.0, -1.0, 0.0), batch_size=args.batch_size - ) + config = A.ICRA2018QtOptConfig(gpu_id=args.gpu, cem_initial_mean=(0.0, 0.0, -1.0, 0.0), batch_size=args.batch_size) icra2018qtopt = A.ICRA2018QtOpt(train_env, config=config) hooks = [iteration_num_hook, save_snapshot_hook, evaluation_hook, latest_state_hook] @@ -178,9 +174,7 @@ def run_showcase(args): eval_env = build_kuka_grasping_procedural_env(test=True, render=args.render) config = A.ICRA2018QtOptConfig(gpu_id=args.gpu) - icra2018qtopt = serializers.load_snapshot( - args.snapshot_dir, eval_env, algorithm_kwargs={"config": config} - ) + icra2018qtopt = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(icra2018qtopt, A.ICRA2018QtOpt): raise ValueError("Loaded snapshot is not trained with ICRA2018QtOpt!") @@ -204,9 +198,7 @@ def main(): parser.add_argument("--num_collection_episodes", type=int, default=1000000) parser.add_argument("--multi_process", action="store_true") parser.add_argument("--ncpu", type=int, default=None) - parser.add_argument( - "--replay_buffer_file_path", type=str, default="./replay_buffer_1m_data.pkl" - ) + parser.add_argument("--replay_buffer_file_path", type=str, default="./replay_buffer_1m_data.pkl") args = parser.parse_args() diff --git a/reproductions/algorithms/sparse_mujoco/demme_sac/demme_sac_reproduction.py b/reproductions/algorithms/sparse_mujoco/demme_sac/demme_sac_reproduction.py index 0e365994..aa77e571 100644 --- a/reproductions/algorithms/sparse_mujoco/demme_sac/demme_sac_reproduction.py +++ b/reproductions/algorithms/sparse_mujoco/demme_sac/demme_sac_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
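In the QtOpt reproduction script above, run_showcase rebuilds the algorithm from a saved snapshot with serializers.load_snapshot and checks its type before evaluating. A condensed sketch of that loading pattern, detached from argparse, might look as follows; the serializers import path is an assumption, the snapshot directory is a placeholder, and build_kuka_grasping_procedural_env is the helper defined in the same script.

# Hypothetical standalone loading sketch (not part of the diff).
import nnabla_rl.algorithms as A
from nnabla_rl import serializers  # assumed import path

eval_env = build_kuka_grasping_procedural_env(test=True, render=False)
config = A.ICRA2018QtOptConfig(gpu_id=0)
algorithm = serializers.load_snapshot("path/to/snapshot_dir", eval_env, algorithm_kwargs={"config": config})
assert isinstance(algorithm, A.ICRA2018QtOpt)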
@@ -26,66 +26,67 @@ def select_start_timesteps(env_name): - if env_name in ['SparseAnt-v1', 'SparseHalfCheetah-v1']: + if env_name in ["SparseAnt-v1", "SparseHalfCheetah-v1"]: timesteps = 10000 else: timesteps = 1000 - print(f'Selected start timesteps: {timesteps}') + print(f"Selected start timesteps: {timesteps}") return timesteps def select_total_iterations(env_name): - if env_name in ['SparseAnt-v1']: + if env_name in ["SparseAnt-v1"]: total_iterations = 3000000 else: total_iterations = 1000000 - print(f'Selected total iterations: {total_iterations}') + print(f"Selected total iterations: {total_iterations}") return total_iterations def select_alpha_pi(env_name): - if env_name in ['SparseHopper-v1']: + if env_name in ["SparseHopper-v1"]: alpha_pi = 0.04 - elif env_name in ['SparseHalfCheetah-v1']: + elif env_name in ["SparseHalfCheetah-v1"]: alpha_pi = 0.02 - elif env_name in ['SparseWalker2d-v1']: + elif env_name in ["SparseWalker2d-v1"]: alpha_pi = 0.02 - elif env_name in ['SparseAnt-v1']: + elif env_name in ["SparseAnt-v1"]: alpha_pi = 0.01 else: alpha_pi = 1.0 - print(f'Selected alpha_pi: {alpha_pi}') + print(f"Selected alpha_pi: {alpha_pi}") return alpha_pi def select_alpha_q(env_name): - if env_name in ['SparseHopper-v1']: + if env_name in ["SparseHopper-v1"]: alpha_q = 2.0 - elif env_name in ['SparseHalfCheetah-v1']: + elif env_name in ["SparseHalfCheetah-v1"]: alpha_q = 2.0 - elif env_name in ['SparseWalker2d-v1']: + elif env_name in ["SparseWalker2d-v1"]: alpha_q = 2.0 - elif env_name in ['SparseAnt-v1']: + elif env_name in ["SparseAnt-v1"]: alpha_q = 0.1 else: alpha_q = 1.0 - print(f'Selected alpha_q: {alpha_q}') + print(f"Selected alpha_q: {alpha_q}") return alpha_q def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, - file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=100) @@ -94,10 +95,7 @@ def run_training(args): timesteps = select_start_timesteps(args.env) alpha_pi = select_alpha_pi(args.env) alpha_q = select_alpha_q(args.env) - config = A.DEMMESACConfig(gpu_id=args.gpu, - start_timesteps=timesteps, - alpha_pi=alpha_pi, - alpha_q=alpha_q) + config = A.DEMMESACConfig(gpu_id=args.gpu, start_timesteps=timesteps, alpha_pi=alpha_pi, alpha_q=alpha_q) demme_sac = A.DEMMESAC(train_env, config=config) hooks = [iteration_num_hook, save_snapshot_hook, evaluation_hook] @@ -112,13 +110,12 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') + raise ValueError("Please specify the snapshot dir for showcasing") eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render) config = A.DEMMESACConfig(gpu_id=args.gpu) demme_sac = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(demme_sac, A.DEMMESAC): - raise 
ValueError('Loaded snapshot is not trained with DEMMESAC!') + raise ValueError("Loaded snapshot is not trained with DEMMESAC!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(demme_sac, eval_env) @@ -126,17 +123,17 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='SparseAnt-v1') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=None) - parser.add_argument('--save_timing', type=int, default=5000) - parser.add_argument('--eval_timing', type=int, default=5000) - parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument("--env", type=str, default="SparseAnt-v1") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=None) + parser.add_argument("--save_timing", type=int, default=5000) + parser.add_argument("--eval_timing", type=int, default=5000) + parser.add_argument("--showcase_runs", type=int, default=10) args = parser.parse_args() @@ -146,5 +143,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/sparse_mujoco/environment/setup.py b/reproductions/algorithms/sparse_mujoco/environment/setup.py index beb1efaa..1b2f40e7 100644 --- a/reproductions/algorithms/sparse_mujoco/environment/setup.py +++ b/reproductions/algorithms/sparse_mujoco/environment/setup.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,4 +14,4 @@ from setuptools import setup -setup(name='sparse_mujoco', version='0.0.1', install_requires=['gym', 'mujoco-py']) +setup(name="sparse_mujoco", version="0.0.1", install_requires=["gym", "mujoco-py"]) diff --git a/reproductions/algorithms/sparse_mujoco/environment/sparse_mujoco/__init__.py b/reproductions/algorithms/sparse_mujoco/environment/sparse_mujoco/__init__.py index b4844819..c63f50ee 100644 --- a/reproductions/algorithms/sparse_mujoco/environment/sparse_mujoco/__init__.py +++ b/reproductions/algorithms/sparse_mujoco/environment/sparse_mujoco/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
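As a concrete reading of the selector helpers in the DEMME-SAC reproduction above: for the script's default environment, SparseAnt-v1, they resolve to start_timesteps=10000, total_iterations=3000000, alpha_pi=0.01, and alpha_q=0.1, which run_training then feeds into A.DEMMESACConfig. The snippet below only spells that out; it is illustrative, assumes it runs inside that script's module, and is not part of the change.

# Hypothetical illustration of the resolved defaults for SparseAnt-v1.
env_name = "SparseAnt-v1"
assert select_start_timesteps(env_name) == 10000
assert select_total_iterations(env_name) == 3000000
assert select_alpha_pi(env_name) == 0.01
assert select_alpha_q(env_name) == 0.1
config = A.DEMMESACConfig(gpu_id=0, start_timesteps=10000, alpha_pi=0.01, alpha_q=0.1)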
@@ -15,28 +15,28 @@ from gym.envs.registration import register register( - id='SparseHalfCheetah-v1', - entry_point='sparse_mujoco.sparse_half_cheetah:SparseHalfCheetahEnv', + id="SparseHalfCheetah-v1", + entry_point="sparse_mujoco.sparse_half_cheetah:SparseHalfCheetahEnv", max_episode_steps=1000, reward_threshold=4800.0, ) register( - id='SparseHopper-v1', - entry_point='sparse_mujoco.sparse_hopper:SparseHopperEnv', + id="SparseHopper-v1", + entry_point="sparse_mujoco.sparse_hopper:SparseHopperEnv", max_episode_steps=1000, reward_threshold=3800.0, ) register( - id='SparseWalker2d-v1', + id="SparseWalker2d-v1", max_episode_steps=1000, - entry_point='sparse_mujoco.sparse_walker2d:SparseWalker2dEnv', + entry_point="sparse_mujoco.sparse_walker2d:SparseWalker2dEnv", ) register( - id='SparseAnt-v1', - entry_point='sparse_mujoco.sparse_ant:SparseAntEnv', + id="SparseAnt-v1", + entry_point="sparse_mujoco.sparse_ant:SparseAntEnv", max_episode_steps=1000, reward_threshold=6000.0, ) diff --git a/reproductions/algorithms/sparse_mujoco/environment/sparse_mujoco/sparse_hopper.py b/reproductions/algorithms/sparse_mujoco/environment/sparse_mujoco/sparse_hopper.py index 621fac14..34c5d6f7 100644 --- a/reproductions/algorithms/sparse_mujoco/environment/sparse_mujoco/sparse_hopper.py +++ b/reproductions/algorithms/sparse_mujoco/environment/sparse_mujoco/sparse_hopper.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,14 +21,9 @@ def step(self, a): self.do_simulation(a, self.frame_skip) posafter, height, ang = self.sim.data.qpos[0:3] - reward = 1.0 if posafter - self.init_qpos[0] > 1. else 0.0 + reward = 1.0 if posafter - self.init_qpos[0] > 1.0 else 0.0 s = self.state_vector() - done = not ( - np.isfinite(s).all() - and (np.abs(s[2:]) < 100).all() - and (height > 0.7) - and (abs(ang) < 0.2) - ) + done = not (np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and (height > 0.7) and (abs(ang) < 0.2)) ob = self._get_obs() return ob, reward, done, {} diff --git a/reproductions/algorithms/sparse_mujoco/environment/sparse_mujoco/sparse_walker2d.py b/reproductions/algorithms/sparse_mujoco/environment/sparse_mujoco/sparse_walker2d.py index f2aeb586..5bad9048 100644 --- a/reproductions/algorithms/sparse_mujoco/environment/sparse_mujoco/sparse_walker2d.py +++ b/reproductions/algorithms/sparse_mujoco/environment/sparse_mujoco/sparse_walker2d.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ def step(self, a): self.do_simulation(a, self.frame_skip) posafter, height, ang = self.sim.data.qpos[0:3] - reward = 1.0 if posafter - self.init_qpos[0] > 1. 
else 0.0 + reward = 1.0 if posafter - self.init_qpos[0] > 1.0 else 0.0 done = not (height > 0.8 and height < 2.0 and ang > -1.0 and ang < 1.0) ob = self._get_obs() diff --git a/reproductions/algorithms/sparse_mujoco/icml2018sac/icml2018sac_reproduction.py b/reproductions/algorithms/sparse_mujoco/icml2018sac/icml2018sac_reproduction.py index 1fab0698..97314d3d 100644 --- a/reproductions/algorithms/sparse_mujoco/icml2018sac/icml2018sac_reproduction.py +++ b/reproductions/algorithms/sparse_mujoco/icml2018sac/icml2018sac_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,47 +26,48 @@ def select_start_timesteps(env_name): - if env_name in ['SparseAnt-v1', 'SparseHalfCheetah-v1']: + if env_name in ["SparseAnt-v1", "SparseHalfCheetah-v1"]: timesteps = 10000 else: timesteps = 1000 - print(f'Selected start timesteps: {timesteps}') + print(f"Selected start timesteps: {timesteps}") return timesteps def select_total_iterations(env_name): - if env_name in ['SparseAnt-v1']: + if env_name in ["SparseAnt-v1"]: total_iterations = 3000000 else: total_iterations = 1000000 - print(f'Selected total iterations: {total_iterations}') + print(f"Selected total iterations: {total_iterations}") return total_iterations def select_reward_scalar(env_name): - if env_name in ['SparseAnt-v1']: + if env_name in ["SparseAnt-v1"]: scalar = 100.0 - elif env_name in ['SparseHopper-v1', 'SparseHalfCheetah-v1']: + elif env_name in ["SparseHopper-v1", "SparseHalfCheetah-v1"]: scalar = 50.0 else: scalar = 25.0 - print(f'Selected reward scalar: {scalar}') + print(f"Selected reward scalar: {scalar}") return scalar def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, - file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=100) @@ -74,9 +75,7 @@ def run_training(args): train_env = build_mujoco_env(args.env, seed=args.seed, render=args.render) timesteps = select_start_timesteps(args.env) reward_scalar = select_reward_scalar(args.env) - config = A.ICML2018SACConfig(gpu_id=args.gpu, - start_timesteps=timesteps, - reward_scalar=reward_scalar) + config = A.ICML2018SACConfig(gpu_id=args.gpu, start_timesteps=timesteps, reward_scalar=reward_scalar) icml2018sac = A.ICML2018SAC(train_env, config=config) hooks = [iteration_num_hook, save_snapshot_hook, evaluation_hook] @@ -91,13 +90,12 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') + raise ValueError("Please specify the snapshot dir for showcasing") eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render) config = A.ICML2018SACConfig(gpu_id=args.gpu) icml2018sac = 
serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(icml2018sac, A.ICML2018SAC): - raise ValueError('Loaded snapshot is not trained with ICML2018SAC!') + raise ValueError("Loaded snapshot is not trained with ICML2018SAC!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(icml2018sac, eval_env) @@ -105,17 +103,17 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='SparseAnt-v1') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=None) - parser.add_argument('--save_timing', type=int, default=5000) - parser.add_argument('--eval_timing', type=int, default=5000) - parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument("--env", type=str, default="SparseAnt-v1") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=None) + parser.add_argument("--save_timing", type=int, default=5000) + parser.add_argument("--eval_timing", type=int, default=5000) + parser.add_argument("--showcase_runs", type=int, default=10) args = parser.parse_args() @@ -125,5 +123,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/reproductions/algorithms/sparse_mujoco/mme_sac/mme_sac_reproduction.py b/reproductions/algorithms/sparse_mujoco/mme_sac/mme_sac_reproduction.py index 14d5a590..17686419 100644 --- a/reproductions/algorithms/sparse_mujoco/mme_sac/mme_sac_reproduction.py +++ b/reproductions/algorithms/sparse_mujoco/mme_sac/mme_sac_reproduction.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -26,66 +26,67 @@ def select_start_timesteps(env_name): - if env_name in ['SparseAnt-v1', 'SparseHalfCheetah-v1']: + if env_name in ["SparseAnt-v1", "SparseHalfCheetah-v1"]: timesteps = 10000 else: timesteps = 1000 - print(f'Selected start timesteps: {timesteps}') + print(f"Selected start timesteps: {timesteps}") return timesteps def select_total_iterations(env_name): - if env_name in ['SparseAnt-v1']: + if env_name in ["SparseAnt-v1"]: total_iterations = 3000000 else: total_iterations = 1000000 - print(f'Selected total iterations: {total_iterations}') + print(f"Selected total iterations: {total_iterations}") return total_iterations def select_alpha_pi(env_name): - if env_name in ['SparseHopper-v1']: + if env_name in ["SparseHopper-v1"]: alpha_pi = 0.04 - elif env_name in ['SparseHalfCheetah-v1']: + elif env_name in ["SparseHalfCheetah-v1"]: alpha_pi = 0.02 - elif env_name in ['SparseWalker2d-v1']: + elif env_name in ["SparseWalker2d-v1"]: alpha_pi = 0.02 - elif env_name in ['SparseAnt-v1']: + elif env_name in ["SparseAnt-v1"]: alpha_pi = 0.01 else: alpha_pi = 1.0 - print(f'Selected alpha_pi: {alpha_pi}') + print(f"Selected alpha_pi: {alpha_pi}") return alpha_pi def select_alpha_q(env_name): - if env_name in ['SparseHopper-v1']: + if env_name in ["SparseHopper-v1"]: alpha_q = 1.0 - elif env_name in ['SparseHalfCheetah-v1']: + elif env_name in ["SparseHalfCheetah-v1"]: alpha_q = 2.0 - elif env_name in ['SparseWalker2d-v1']: + elif env_name in ["SparseWalker2d-v1"]: alpha_q = 0.5 - elif env_name in ['SparseAnt-v1']: + elif env_name in ["SparseAnt-v1"]: alpha_q = 0.2 else: alpha_q = 1.0 - print(f'Selected alpha_q: {alpha_q}') + print(f"Selected alpha_q: {alpha_q}") return alpha_q def run_training(args): - outdir = f'{args.env}_results/seed-{args.seed}' + outdir = f"{args.env}_results/seed-{args.seed}" if args.save_dir: outdir = os.path.join(os.path.abspath(args.save_dir), outdir) set_global_seed(args.seed) eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 100) evaluator = EpisodicEvaluator(run_per_evaluation=10) - evaluation_hook = H.EvaluationHook(eval_env, - evaluator, - timing=args.eval_timing, - writer=W.FileWriter(outdir=outdir, - file_prefix='evaluation_result')) + evaluation_hook = H.EvaluationHook( + eval_env, + evaluator, + timing=args.eval_timing, + writer=W.FileWriter(outdir=outdir, file_prefix="evaluation_result"), + ) save_snapshot_hook = H.SaveSnapshotHook(outdir, timing=args.save_timing) iteration_num_hook = H.IterationNumHook(timing=100) @@ -94,10 +95,7 @@ def run_training(args): timesteps = select_start_timesteps(args.env) alpha_pi = select_alpha_pi(args.env) alpha_q = select_alpha_q(args.env) - config = A.MMESACConfig(gpu_id=args.gpu, - start_timesteps=timesteps, - alpha_pi=alpha_pi, - alpha_q=alpha_q) + config = A.MMESACConfig(gpu_id=args.gpu, start_timesteps=timesteps, alpha_pi=alpha_pi, alpha_q=alpha_q) mme_sac = A.MMESAC(train_env, config=config) hooks = [iteration_num_hook, save_snapshot_hook, evaluation_hook] @@ -112,13 +110,12 @@ def run_training(args): def run_showcase(args): if args.snapshot_dir is None: - raise ValueError( - 'Please specify the snapshot dir for showcasing') + raise ValueError("Please specify the snapshot dir for showcasing") eval_env = build_mujoco_env(args.env, test=True, seed=args.seed + 200, render=args.render) config = A.MMESACConfig(gpu_id=args.gpu) mme_sac = serializers.load_snapshot(args.snapshot_dir, eval_env, algorithm_kwargs={"config": config}) if not isinstance(mme_sac, A.MMESAC): - raise ValueError('Loaded snapshot is 
not trained with MMESAC!') + raise ValueError("Loaded snapshot is not trained with MMESAC!") evaluator = EpisodicEvaluator(run_per_evaluation=args.showcase_runs) evaluator(mme_sac, eval_env) @@ -126,17 +123,17 @@ def run_showcase(args): def main(): parser = argparse.ArgumentParser() - parser.add_argument('--env', type=str, default='SparseAnt-v1') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--render', action='store_true') - parser.add_argument('--showcase', action='store_true') - parser.add_argument('--snapshot-dir', type=str, default=None) - parser.add_argument('--save-dir', type=str, default=None) - parser.add_argument('--total_iterations', type=int, default=None) - parser.add_argument('--save_timing', type=int, default=5000) - parser.add_argument('--eval_timing', type=int, default=5000) - parser.add_argument('--showcase_runs', type=int, default=10) + parser.add_argument("--env", type=str, default="SparseAnt-v1") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--render", action="store_true") + parser.add_argument("--showcase", action="store_true") + parser.add_argument("--snapshot-dir", type=str, default=None) + parser.add_argument("--save-dir", type=str, default=None) + parser.add_argument("--total_iterations", type=int, default=None) + parser.add_argument("--save_timing", type=int, default=5000) + parser.add_argument("--eval_timing", type=int, default=5000) + parser.add_argument("--showcase_runs", type=int, default=10) args = parser.parse_args() @@ -146,5 +143,5 @@ def main(): run_training(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/requirements.txt b/requirements.txt index 74875d4a..81577db5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,8 +10,8 @@ pytest pytest-cov mypy!=1.11.0 typing-extensions -isort -autopep8 +isort>5.0.0 +black packaging docformatter gymnasium \ No newline at end of file diff --git a/test_resources/reproductions/atari-dataset/dummy_dataset_generator.py b/test_resources/reproductions/atari-dataset/dummy_dataset_generator.py index b23e4896..fd65433d 100644 --- a/test_resources/reproductions/atari-dataset/dummy_dataset_generator.py +++ b/test_resources/reproductions/atari-dataset/dummy_dataset_generator.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -20,12 +20,12 @@ def save_dataset(filepath, data): - with gzip.GzipFile(filepath, 'w') as f: + with gzip.GzipFile(filepath, "w") as f: np.save(f, data, allow_pickle=False) def main(): - fake_env = build_atari_env('FakeAtariNNablaRLNoFrameskip-v1', test=True) + fake_env = build_atari_env("FakeAtariNNablaRLNoFrameskip-v1", test=True) dataset_size = 20 @@ -47,11 +47,11 @@ def main(): state = fake_env.reset() if done else next_state - save_dataset('./$store$_observation_ckpt.0.gz', np.asarray(states)) - save_dataset('./$store$_action_ckpt.0.gz', np.asarray(actions)) - save_dataset('./$store$_reward_ckpt.0.gz', np.asarray(rewards)) - save_dataset('./$store$_terminal_ckpt.0.gz', np.asarray(dones)) + save_dataset("./$store$_observation_ckpt.0.gz", np.asarray(states)) + save_dataset("./$store$_action_ckpt.0.gz", np.asarray(actions)) + save_dataset("./$store$_reward_ckpt.0.gz", np.asarray(rewards)) + save_dataset("./$store$_terminal_ckpt.0.gz", np.asarray(dones)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/test_resources/reproductions/mujoco-dataset/dummy_dataset_generator.py b/test_resources/reproductions/mujoco-dataset/dummy_dataset_generator.py index 9eaa5bac..1d3736cb 100644 --- a/test_resources/reproductions/mujoco-dataset/dummy_dataset_generator.py +++ b/test_resources/reproductions/mujoco-dataset/dummy_dataset_generator.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,12 +20,12 @@ def save_dataset(filepath, data): - with gzip.GzipFile(filepath, 'w') as f: + with gzip.GzipFile(filepath, "w") as f: np.save(f, data, allow_pickle=False) def main(): - fake_env = build_mujoco_env('FakeMujocoNNablaRL-v1', test=True) + fake_env = build_mujoco_env("FakeMujocoNNablaRL-v1", test=True) dataset_size = 20 @@ -49,12 +49,12 @@ def main(): state = fake_env.reset() if done else next_state - save_dataset('./$store$_observation_ckpt.0.gz', np.asarray(states)) - save_dataset('./$store$_action_ckpt.0.gz', np.asarray(actions)) - save_dataset('./$store$_reward_ckpt.0.gz', np.asarray(rewards)) - save_dataset('./$store$_terminal_ckpt.0.gz', np.asarray(dones)) - save_dataset('./$store$_next_observation_ckpt.0.gz', np.asarray(next_states)) + save_dataset("./$store$_observation_ckpt.0.gz", np.asarray(states)) + save_dataset("./$store$_action_ckpt.0.gz", np.asarray(actions)) + save_dataset("./$store$_reward_ckpt.0.gz", np.asarray(rewards)) + save_dataset("./$store$_terminal_ckpt.0.gz", np.asarray(dones)) + save_dataset("./$store$_next_observation_ckpt.0.gz", np.asarray(next_states)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tests/algorithms/test_a2c.py b/tests/algorithms/test_a2c.py index 44893899..8067c22f 100644 --- a/tests/algorithms/test_a2c.py +++ b/tests/algorithms/test_a2c.py @@ -1,5 +1,5 @@ # Copyright 2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
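The dummy dataset generators above write each array with np.save into a gzip stream, so reading the files back only needs the mirror-image call. The loader below is a hedged sketch: load_dataset is a hypothetical helper that these scripts do not define, while the file names are the ones the generators actually produce.

# Hypothetical round-trip check for the generated dummy dataset files.
import gzip

import numpy as np


def load_dataset(filepath):
    # Mirror of save_dataset: np.save wrote the array into a gzip stream.
    with gzip.GzipFile(filepath, "r") as f:
        return np.load(f, allow_pickle=False)


observations = load_dataset("./$store$_observation_ckpt.0.gz")
actions = load_dataset("./$store$_action_ckpt.0.gz")
print(observations.shape, actions.shape)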
@@ -28,7 +28,7 @@ def test_algorithm_name(self): dummy_env = E.DummyDiscreteImg() a2c = A.A2C(dummy_env) - assert a2c.__name__ == 'A2C' + assert a2c.__name__ == "A2C" def test_continuous_action_env_unsupported(self): """Check that error occurs when training on continuous action env.""" @@ -47,7 +47,7 @@ def test_run_online_discrete_env_training(self): config = A.A2CConfig(n_steps=n_steps, actor_num=actor_num) a2c = A.A2C(dummy_env, config=config) - a2c.train_online(dummy_env, total_iterations=n_steps*actor_num) + a2c.train_online(dummy_env, total_iterations=n_steps * actor_num) def test_run_offline_training(self): """Check that no error occurs when calling offline training.""" @@ -80,14 +80,14 @@ def test_latest_iteration_state(self): dummy_env = E.DummyDiscreteImg() a2c = A.A2C(dummy_env) - a2c._policy_trainer_state = {'pi_loss': 0.} - a2c._v_function_trainer_state = {'v_loss': 1.} + a2c._policy_trainer_state = {"pi_loss": 0.0} + a2c._v_function_trainer_state = {"v_loss": 1.0} latest_iteration_state = a2c.latest_iteration_state - assert 'pi_loss' in latest_iteration_state['scalar'] - assert 'v_loss' in latest_iteration_state['scalar'] - assert latest_iteration_state['scalar']['pi_loss'] == 0. - assert latest_iteration_state['scalar']['v_loss'] == 1. + assert "pi_loss" in latest_iteration_state["scalar"] + assert "v_loss" in latest_iteration_state["scalar"] + assert latest_iteration_state["scalar"]["pi_loss"] == 0.0 + assert latest_iteration_state["scalar"]["v_loss"] == 1.0 if __name__ == "__main__": diff --git a/tests/algorithms/test_amp.py b/tests/algorithms/test_amp.py index a20e94c6..79246b63 100644 --- a/tests/algorithms/test_amp.py +++ b/tests/algorithms/test_amp.py @@ -20,9 +20,13 @@ import nnabla as nn import nnabla_rl.algorithms as A import nnabla_rl.environments as E -from nnabla_rl.algorithms.amp import (_compute_v_target_and_advantage_with_clipping_and_overwriting, _concatenate_state, - _copy_np_array_to_mp_array, _EquallySampleBufferIterator, - _sample_experiences_from_buffers) +from nnabla_rl.algorithms.amp import ( + _compute_v_target_and_advantage_with_clipping_and_overwriting, + _concatenate_state, + _copy_np_array_to_mp_array, + _EquallySampleBufferIterator, + _sample_experiences_from_buffers, +) from nnabla_rl.environments.amp_env import TaskResult from nnabla_rl.environments.wrappers.common import FlattenNestedTupleStateWrapper from nnabla_rl.environments.wrappers.goal_conditioned import GoalConditionedTupleObservationEnv @@ -38,9 +42,9 @@ def __init__(self): def v(self, s): with nn.parameter_scope(self.scope_name): if isinstance(s, tuple): - h = s[0] * 2. + s[1] * 2. + h = s[0] * 2.0 + s[1] * 2.0 else: - h = s * 2. 
+ h = s * 2.0 return h @@ -52,7 +56,7 @@ def test_algorithm_name(self): dummy_env = E.DummyAMPEnv() amp = A.AMP(dummy_env) - assert amp.__name__ == 'AMP' + assert amp.__name__ == "AMP" def test_run_online_amp_env_training(self): """Check that no error occurs when calling online training (amp env)""" @@ -63,7 +67,7 @@ def test_run_online_amp_env_training(self): config = A.AMPConfig(batch_size=5, actor_timesteps=actor_timesteps, actor_num=actor_num) amp = A.AMP(dummy_env, config=config) - amp.train_online(dummy_env, total_iterations=actor_timesteps*actor_num) + amp.train_online(dummy_env, total_iterations=actor_timesteps * actor_num) def test_run_online_amp_goal_env_training(self): """Check that no error occurs when calling online training (emp goal @@ -74,11 +78,12 @@ def test_run_online_amp_goal_env_training(self): dummy_env = FlattenNestedTupleStateWrapper(dummy_env) actor_timesteps = 10 actor_num = 2 - config = A.AMPConfig(batch_size=5, actor_timesteps=actor_timesteps, - actor_num=actor_num, use_reward_from_env=True) + config = A.AMPConfig( + batch_size=5, actor_timesteps=actor_timesteps, actor_num=actor_num, use_reward_from_env=True + ) amp = A.AMP(dummy_env, config=config) - amp.train_online(dummy_env, total_iterations=actor_timesteps*actor_num) + amp.train_online(dummy_env, total_iterations=actor_timesteps * actor_num) def test_run_online_with_invalid_env_trainig(self): """Check that error occurs when calling online training (invalid env, @@ -181,17 +186,17 @@ def test_latest_iteration_state(self): dummy_env = E.DummyAMPEnv() amp = A.AMP(dummy_env) - amp._policy_trainer_state = {'pi_loss': 0.} - amp._v_function_trainer_state = {'v_loss': 1.} - amp._discriminator_trainer_state = {'reward_loss': 2.} + amp._policy_trainer_state = {"pi_loss": 0.0} + amp._v_function_trainer_state = {"v_loss": 1.0} + amp._discriminator_trainer_state = {"reward_loss": 2.0} latest_iteration_state = amp.latest_iteration_state - assert 'pi_loss' in latest_iteration_state['scalar'] - assert 'v_loss' in latest_iteration_state['scalar'] - assert 'reward_loss' in latest_iteration_state['scalar'] - assert latest_iteration_state['scalar']['pi_loss'] == 0. - assert latest_iteration_state['scalar']['v_loss'] == 1. - assert latest_iteration_state['scalar']['reward_loss'] == 2. 
+ assert "pi_loss" in latest_iteration_state["scalar"] + assert "v_loss" in latest_iteration_state["scalar"] + assert "reward_loss" in latest_iteration_state["scalar"] + assert latest_iteration_state["scalar"]["pi_loss"] == 0.0 + assert latest_iteration_state["scalar"]["v_loss"] == 1.0 + assert latest_iteration_state["scalar"]["reward_loss"] == 2.0 def test_copy_np_array_to_mp_array(self): shape = (10, 9, 8, 7) @@ -254,9 +259,14 @@ def test_concatenate_state(self): e_s_next = (np.random.rand(3), np.random.rand(3)) v_target = (np.random.rand(1), np.random.rand(1)) advantage = (np.random.rand(1), np.random.rand(1)) - dummy_experiences_per_agent = [[tuple(experience) for experience in - zip(s, a, r, non_terminal, n_s, log_prob, - non_greedy, e_s, e_a, e_s_next, v_target, advantage)]] + dummy_experiences_per_agent = [ + [ + tuple(experience) + for experience in zip( + s, a, r, non_terminal, n_s, log_prob, non_greedy, e_s, e_a, e_s_next, v_target, advantage + ) + ] + ] actual_concat_s, actual_concat_e_s = _concatenate_state(dummy_experiences_per_agent) assert np.allclose(actual_concat_s, np.stack(s, axis=0)) assert np.allclose(actual_concat_e_s, np.stack(e_s, axis=0)) @@ -264,7 +274,7 @@ def test_concatenate_state(self): def test_sample_experiences_from_buffers(self): buffers = [ReplayBuffer() for _ in range(2)] for buffer in buffers: - buffer.sample = MagicMock(return_value=(((1, ), (2, ), (3, )), {})) + buffer.sample = MagicMock(return_value=(((1,), (2,), (3,)), {})) _sample_experiences_from_buffers(buffers=buffers, batch_size=6) @@ -272,98 +282,117 @@ def test_sample_experiences_from_buffers(self): for buffer in buffers: buffer.sample.assert_called_once_with(num_samples=3) - @pytest.mark.parametrize("gamma, lmb, value_at_task_fail, value_at_task_success," - "value_clip, expected_adv, expected_vtarg", - [[1., 0., 0., 1., None, np.array([[1.], [1.], [1.]]), np.array([[3.], [3.], [3.]])], - [1., 1., 0., 1., None, np.array([[3.], [2.], [1.]]), np.array([[5.], [4.], [3.]])], - [0.9, 0.7, 0., 1., None, np.array([[1.62152], [1.304], [0.8]]), - np.array([[3.62152], [3.304], [2.8]])], - [1., 1., 0., 1., (-1.2, 1.2), np.array([[3.], [2.], [1.]]), - np.array([[4.2], [3.2], [2.2]])] - ]) + @pytest.mark.parametrize( + "gamma, lmb, value_at_task_fail, value_at_task_success," "value_clip, expected_adv, expected_vtarg", + [ + [1.0, 0.0, 0.0, 1.0, None, np.array([[1.0], [1.0], [1.0]]), np.array([[3.0], [3.0], [3.0]])], + [1.0, 1.0, 0.0, 1.0, None, np.array([[3.0], [2.0], [1.0]]), np.array([[5.0], [4.0], [3.0]])], + [0.9, 0.7, 0.0, 1.0, None, np.array([[1.62152], [1.304], [0.8]]), np.array([[3.62152], [3.304], [2.8]])], + [1.0, 1.0, 0.0, 1.0, (-1.2, 1.2), np.array([[3.0], [2.0], [1.0]]), np.array([[4.2], [3.2], [2.2]])], + ], + ) def test_compute_v_target_and_advantage_with_clipping_and_overwriting_unknown_task_result( - self, gamma, lmb, value_at_task_fail, value_at_task_success, - value_clip, expected_adv, expected_vtarg): + self, gamma, lmb, value_at_task_fail, value_at_task_success, value_clip, expected_adv, expected_vtarg + ): dummy_v_function = DummyVFunction() dummy_experience = self._collect_dummy_experience_unknown_task_result() r = np.ones(3) actual_vtarg, actual_adv = _compute_v_target_and_advantage_with_clipping_and_overwriting( - dummy_v_function, dummy_experience, r, gamma, lmb, value_at_task_fail, value_at_task_success, value_clip) + dummy_v_function, dummy_experience, r, gamma, lmb, value_at_task_fail, value_at_task_success, value_clip + ) assert np.allclose(actual_adv, expected_adv) assert 
np.allclose(actual_vtarg, expected_vtarg) - @pytest.mark.parametrize("gamma, lmb, value_at_task_fail, value_at_task_success," - "value_clip, expected_adv, expected_vtarg", - [[1., 0., -1., 1., None, np.array([[1.], [1.], [-2.]]), np.array([[3.], [3.], [0.]])], - [1., 1., -1., 1., None, np.array([[0.], [-1.], [-2.]]), np.array([[2.], [1.], [0.]])], - [1., 1., -1., 1., (-1.2, 1.2), np.array([[0.8], [-0.2], [-1.2]]), - np.array([[2.], [1.], [0.]])] - ]) + @pytest.mark.parametrize( + "gamma, lmb, value_at_task_fail, value_at_task_success," "value_clip, expected_adv, expected_vtarg", + [ + [1.0, 0.0, -1.0, 1.0, None, np.array([[1.0], [1.0], [-2.0]]), np.array([[3.0], [3.0], [0.0]])], + [1.0, 1.0, -1.0, 1.0, None, np.array([[0.0], [-1.0], [-2.0]]), np.array([[2.0], [1.0], [0.0]])], + [1.0, 1.0, -1.0, 1.0, (-1.2, 1.2), np.array([[0.8], [-0.2], [-1.2]]), np.array([[2.0], [1.0], [0.0]])], + ], + ) def test_compute_v_target_and_advantage_with_clipping_and_overwriting_unknown_task_fail( - self, gamma, lmb, value_at_task_fail, value_at_task_success, - value_clip, expected_adv, expected_vtarg): + self, gamma, lmb, value_at_task_fail, value_at_task_success, value_clip, expected_adv, expected_vtarg + ): dummy_v_function = DummyVFunction() dummy_experience = self._collect_dummy_experience_unknown_task_result( - task_result=TaskResult(TaskResult.FAIL.value)) + task_result=TaskResult(TaskResult.FAIL.value) + ) r = np.ones(3) actual_vtarg, actual_adv = _compute_v_target_and_advantage_with_clipping_and_overwriting( - dummy_v_function, dummy_experience, r, gamma, lmb, value_at_task_fail, value_at_task_success, value_clip) + dummy_v_function, dummy_experience, r, gamma, lmb, value_at_task_fail, value_at_task_success, value_clip + ) assert np.allclose(actual_adv, expected_adv, atol=1e-6) assert np.allclose(actual_vtarg, expected_vtarg, atol=1e-6) - @pytest.mark.parametrize("gamma, lmb, value_at_task_fail, value_at_task_success," - "value_clip, expected_adv, expected_vtarg", - [[1., 0., -1., 5., None, np.array([[1.], [1.], [4.]]), np.array([[3.], [3.], [6.]])], - [1., 1., -1., 5., None, np.array([[6.], [5.], [4.]]), np.array([[8.], [7.], [6.]])], - [1., 1., -1., 5., (-1.2, 1.2), np.array([[6.8], [5.8], [4.8]]), - np.array([[8.], [7.], [6.]])] - ]) + @pytest.mark.parametrize( + "gamma, lmb, value_at_task_fail, value_at_task_success," "value_clip, expected_adv, expected_vtarg", + [ + [1.0, 0.0, -1.0, 5.0, None, np.array([[1.0], [1.0], [4.0]]), np.array([[3.0], [3.0], [6.0]])], + [1.0, 1.0, -1.0, 5.0, None, np.array([[6.0], [5.0], [4.0]]), np.array([[8.0], [7.0], [6.0]])], + [1.0, 1.0, -1.0, 5.0, (-1.2, 1.2), np.array([[6.8], [5.8], [4.8]]), np.array([[8.0], [7.0], [6.0]])], + ], + ) def test_compute_v_target_and_advantage_with_clipping_and_overwriting_unknown_task_success( - self, gamma, lmb, value_at_task_fail, value_at_task_success, - value_clip, expected_adv, expected_vtarg): + self, gamma, lmb, value_at_task_fail, value_at_task_success, value_clip, expected_adv, expected_vtarg + ): dummy_v_function = DummyVFunction() dummy_experience = self._collect_dummy_experience_unknown_task_result( - task_result=TaskResult(TaskResult.SUCCESS.value)) + task_result=TaskResult(TaskResult.SUCCESS.value) + ) r = np.ones(3) actual_vtarg, actual_adv = _compute_v_target_and_advantage_with_clipping_and_overwriting( - dummy_v_function, dummy_experience, r, gamma, lmb, value_at_task_fail, value_at_task_success, value_clip) + dummy_v_function, dummy_experience, r, gamma, lmb, value_at_task_fail, value_at_task_success, value_clip + ) 
assert np.allclose(actual_adv, expected_adv, atol=1e-6) assert np.allclose(actual_vtarg, expected_vtarg, atol=1e-6) - def _collect_dummy_experience_unknown_task_result(self, - num_episodes=1, episode_length=3, - task_result=TaskResult(TaskResult.UNKNOWN.value)): + def _collect_dummy_experience_unknown_task_result( + self, num_episodes=1, episode_length=3, task_result=TaskResult(TaskResult.UNKNOWN.value) + ): experience = [] for _ in range(num_episodes): for i in range(episode_length): - s_current = np.ones(1, ) - a = np.ones(1, ) - s_next = np.ones(1, ) - r = np.ones(1, ) - non_terminal = np.ones(1, ) + s_current = np.ones( + 1, + ) + a = np.ones( + 1, + ) + s_next = np.ones( + 1, + ) + r = np.ones( + 1, + ) + non_terminal = np.ones( + 1, + ) info = {"task_result": TaskResult(0)} - if i == episode_length-1: - non_terminal = np.zeros(1, ) + if i == episode_length - 1: + non_terminal = np.zeros( + 1, + ) info = {"task_result": task_result} experience.append((s_current, a, r, non_terminal, s_next, info)) return experience -class TestEquallySampleBufferIterator(): +class TestEquallySampleBufferIterator: def test_equally_sample_buffer_iterator_iterates_correct_number_of_times(self): buffer_size = 5 buffers = [ReplayBuffer(buffer_size) for _ in range(2)] for i, buffer in enumerate(buffers): - buffer.append_all(np.arange(buffer_size) * (i+1)) + buffer.append_all(np.arange(buffer_size) * (i + 1)) batch_size = 6 total_num_iterations = 10 @@ -381,7 +410,7 @@ def test_equally_sample_buffer_iterator_iterates_correct_data(self): buffers = [ReplayBuffer(buffer_size) for _ in range(2)] for i, buffer in enumerate(buffers): - dummy_experience = [((j+1)*(i+1), ) for j in range(buffer_size)] + dummy_experience = [((j + 1) * (i + 1),) for j in range(buffer_size)] buffer.append_all(dummy_experience) batch_size = 4 diff --git a/tests/algorithms/test_atrpo.py b/tests/algorithms/test_atrpo.py index e47e068f..525229c3 100644 --- a/tests/algorithms/test_atrpo.py +++ b/tests/algorithms/test_atrpo.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
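To make the parametrized expectations in the advantage tests above easier to follow: with the DummyVFunction (v(s) = 2.0 for s = 1), rewards of 1.0, and gamma = lmb = 1.0, every per-step TD residual is 1.0 + 2.0 - 2.0 = 1.0 when the terminal bootstrap is kept (the UNKNOWN task-result case), so the backward sums give advantages [3, 2, 1] and value targets [5, 4, 3], exactly the expected_adv and expected_vtarg rows. The sketch below just reproduces that arithmetic; it restates the test fixtures and is not the algorithm's reference implementation.

# Hypothetical hand computation of the gamma=1, lmb=1, UNKNOWN-task-result case.
v = 2.0  # DummyVFunction: v(s) = s * 2.0 with s = 1
rewards = [1.0, 1.0, 1.0]
deltas = [r + v - v for r in rewards]  # terminal bootstrap kept -> [1.0, 1.0, 1.0]
advantages = []
acc = 0.0
for delta in reversed(deltas):  # gamma = lmb = 1: plain backward sum
    acc += delta
    advantages.insert(0, acc)
v_targets = [a + v for a in advantages]
print(advantages, v_targets)  # [3.0, 2.0, 1.0] [5.0, 4.0, 3.0]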
@@ -20,7 +20,7 @@ import nnabla_rl.environments as E -class TestTRPO(): +class TestTRPO: def setup_method(self): nn.clear_parameters() @@ -28,7 +28,7 @@ def test_algorithm_name(self): dummy_env = E.DummyContinuous() atrpo = A.ATRPO(dummy_env) - assert atrpo.__name__ == 'ATRPO' + assert atrpo.__name__ == "ATRPO" def test_discrete_action_env_unsupported(self): """Check that error occurs when training on discrete action env.""" @@ -41,12 +41,14 @@ def test_run_online_training(self): dummy_env = E.DummyContinuous() dummy_env = EpisodicEnv(dummy_env, min_episode_length=3) - config = A.ATRPOConfig(num_steps_per_iteration=5, - gpu_batch_size=5, - pi_batch_size=5, - vf_batch_size=2, - sigma_kl_divergence_constraint=10.0, - maximum_backtrack_numbers=50) + config = A.ATRPOConfig( + num_steps_per_iteration=5, + gpu_batch_size=5, + pi_batch_size=5, + vf_batch_size=2, + sigma_kl_divergence_constraint=10.0, + maximum_backtrack_numbers=50, + ) atrpo = A.ATRPO(dummy_env, config=config) atrpo.train_online(dummy_env, total_iterations=5) @@ -102,15 +104,16 @@ def test_latest_iteration_state(self): dummy_env = E.DummyContinuous() atrpo = A.ATRPO(dummy_env) - atrpo._v_function_trainer_state = {'v_loss': 0.} + atrpo._v_function_trainer_state = {"v_loss": 0.0} latest_iteration_state = atrpo.latest_iteration_state - assert 'v_loss' in latest_iteration_state['scalar'] - assert latest_iteration_state['scalar']['v_loss'] == 0. + assert "v_loss" in latest_iteration_state["scalar"] + assert latest_iteration_state["scalar"]["v_loss"] == 0.0 if __name__ == "__main__": from testing_utils import EpisodicEnv + pytest.main() else: from ..testing_utils import EpisodicEnv diff --git a/tests/algorithms/test_bcq.py b/tests/algorithms/test_bcq.py index c41f855b..8e4a4ef5 100644 --- a/tests/algorithms/test_bcq.py +++ b/tests/algorithms/test_bcq.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ def test_algorithm_name(self): dummy_env = E.DummyContinuous() bcq = A.BCQ(dummy_env) - assert bcq.__name__ == 'BCQ' + assert bcq.__name__ == "BCQ" def test_run_online_training(self): """Check that error occurs when calling online training.""" @@ -94,23 +94,24 @@ def test_latest_iteration_state(self): dummy_env = E.DummyContinuous() bcq = A.BCQ(dummy_env) - bcq._encoder_trainer_state = {'encoder_loss': 0.} - bcq._q_function_trainer_state = {'q_loss': 1., 'td_errors': np.array([0., 1.])} - bcq._perturbator_trainer_state = {'perturbator_loss': 2.} + bcq._encoder_trainer_state = {"encoder_loss": 0.0} + bcq._q_function_trainer_state = {"q_loss": 1.0, "td_errors": np.array([0.0, 1.0])} + bcq._perturbator_trainer_state = {"perturbator_loss": 2.0} latest_iteration_state = bcq.latest_iteration_state - assert 'encoder_loss' in latest_iteration_state['scalar'] - assert 'q_loss' in latest_iteration_state['scalar'] - assert 'perturbator_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['encoder_loss'] == 0. - assert latest_iteration_state['scalar']['q_loss'] == 1. - assert latest_iteration_state['scalar']['perturbator_loss'] == 2. 
- assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "encoder_loss" in latest_iteration_state["scalar"] + assert "q_loss" in latest_iteration_state["scalar"] + assert "perturbator_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["encoder_loss"] == 0.0 + assert latest_iteration_state["scalar"]["q_loss"] == 1.0 + assert latest_iteration_state["scalar"]["perturbator_loss"] == 2.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_bear.py b/tests/algorithms/test_bear.py index afbf079f..eeae3abc 100644 --- a/tests/algorithms/test_bear.py +++ b/tests/algorithms/test_bear.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ def test_algorithm_name(self): dummy_env = E.DummyContinuous() bear = A.BEAR(dummy_env) - assert bear.__name__ == 'BEAR' + assert bear.__name__ == "BEAR" def test_discrete_action_env_unsupported(self): """Check that error occurs when training on discrete action env.""" @@ -94,23 +94,24 @@ def test_latest_iteration_state(self): dummy_env = E.DummyContinuous() bear = A.BEAR(dummy_env) - bear._encoder_trainer_state = {'encoder_loss': 0.} - bear._q_function_trainer_state = {'q_loss': 1., 'td_errors': np.array([0., 1.])} - bear._policy_trainer_state = {'pi_loss': 2.} + bear._encoder_trainer_state = {"encoder_loss": 0.0} + bear._q_function_trainer_state = {"q_loss": 1.0, "td_errors": np.array([0.0, 1.0])} + bear._policy_trainer_state = {"pi_loss": 2.0} latest_iteration_state = bear.latest_iteration_state - assert 'encoder_loss' in latest_iteration_state['scalar'] - assert 'q_loss' in latest_iteration_state['scalar'] - assert 'pi_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['encoder_loss'] == 0. - assert latest_iteration_state['scalar']['q_loss'] == 1. - assert latest_iteration_state['scalar']['pi_loss'] == 2. - assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "encoder_loss" in latest_iteration_state["scalar"] + assert "q_loss" in latest_iteration_state["scalar"] + assert "pi_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["encoder_loss"] == 0.0 + assert latest_iteration_state["scalar"]["q_loss"] == 1.0 + assert latest_iteration_state["scalar"]["pi_loss"] == 2.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_categorical_ddqn.py b/tests/algorithms/test_categorical_ddqn.py index d263de23..b56c3a12 100644 --- a/tests/algorithms/test_categorical_ddqn.py +++ b/tests/algorithms/test_categorical_ddqn.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. 
+# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ def test_algorithm_name(self): dummy_env = E.DummyDiscreteImg() categorical_dqn = A.CategoricalDDQN(dummy_env) - assert categorical_dqn.__name__ == 'CategoricalDDQN' + assert categorical_dqn.__name__ == "CategoricalDDQN" def test_continuous_action_env_unsupported(self): """Check that error occurs when training on continuous action env.""" @@ -70,7 +70,7 @@ def test_compute_eval_action(self): state = np.float32(state) action = categorical_dqn.compute_eval_action(state) - assert action.shape == (1, ) + assert action.shape == (1,) def test_latest_iteration_state(self): """Check that latest iteration state has the keys and values we @@ -79,17 +79,18 @@ def test_latest_iteration_state(self): dummy_env = E.DummyDiscreteImg() categorical_dqn = A.CategoricalDDQN(dummy_env) - categorical_dqn._model_trainer_state = {'cross_entropy_loss': 0., 'td_errors': np.array([0., 1.])} + categorical_dqn._model_trainer_state = {"cross_entropy_loss": 0.0, "td_errors": np.array([0.0, 1.0])} latest_iteration_state = categorical_dqn.latest_iteration_state - assert 'cross_entropy_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['cross_entropy_loss'] == 0. - assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "cross_entropy_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["cross_entropy_loss"] == 0.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_categorical_dqn.py b/tests/algorithms/test_categorical_dqn.py index 89d4eda8..d632d23a 100644 --- a/tests/algorithms/test_categorical_dqn.py +++ b/tests/algorithms/test_categorical_dqn.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -58,8 +58,7 @@ def all_probs(self, s: nn.Variable) -> nn.Variable: h = self._h h = NF.relu(x=h) with nn.parameter_scope("affine2"): - h = NPF.affine( - h, n_outmaps=self._n_action * self._n_atom) + h = NPF.affine(h, n_outmaps=self._n_action * self._n_atom) h = NF.reshape(h, (-1, self._n_action, self._n_atom)) assert h.shape == (batch_size, self._n_action, self._n_atom) return NF.softmax(h, axis=2) @@ -69,14 +68,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -86,8 +85,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -108,7 +107,7 @@ def test_algorithm_name(self): dummy_env = E.DummyDiscreteImg() categorical_dqn = A.CategoricalDQN(dummy_env) - assert categorical_dqn.__name__ == 'CategoricalDQN' + assert categorical_dqn.__name__ == "CategoricalDQN" def test_continuous_action_env_unsupported(self): """Check that error occurs when training on continuous action env.""" @@ -145,6 +144,7 @@ def test_run_online_training_multistep(self): def test_run_online_rnn_training(self): """Check that no error occurs when calling online training with RNN model.""" + class RNNModelBuilder(ModelBuilder[ValueDistributionFunction]): def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): n_action = env_info.action_dim @@ -152,6 +152,7 @@ def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): v_min = algorithm_config.v_min v_max = algorithm_config.v_max return RNNValueDistributionFunction(scope_name, n_action, n_atom, v_min, v_max) + dummy_env = E.DummyDiscreteImg() config = A.CategoricalDQNConfig() config.num_steps = 2 @@ -185,7 +186,7 @@ def test_compute_eval_action(self): state = np.float32(state) action = categorical_dqn.compute_eval_action(state) - assert action.shape == (1, ) + assert action.shape == (1,) def test_latest_iteration_state(self): """Check that latest iteration state has the keys and values we @@ -194,17 +195,18 @@ def test_latest_iteration_state(self): dummy_env = E.DummyDiscreteImg() categorical_dqn = A.CategoricalDQN(dummy_env) - categorical_dqn._model_trainer_state = {'cross_entropy_loss': 0., 'td_errors': np.array([0., 1.])} + categorical_dqn._model_trainer_state = {"cross_entropy_loss": 0.0, "td_errors": np.array([0.0, 1.0])} latest_iteration_state = categorical_dqn.latest_iteration_state - assert 'cross_entropy_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['cross_entropy_loss'] == 0. 
- assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "cross_entropy_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["cross_entropy_loss"] == 0.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_common_utils.py b/tests/algorithms/test_common_utils.py index a403f174..330b54d9 100644 --- a/tests/algorithms/test_common_utils.py +++ b/tests/algorithms/test_common_utils.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,13 +17,24 @@ import pytest import nnabla as nn -from nnabla_rl.algorithms.common_utils import (_DeterministicPolicyActionSelector, _InfluenceMetricsEvaluator, - _StatePreprocessedDeterministicPolicy, - _StatePreprocessedStochasticPolicy, _StatePreprocessedVFunction, - compute_average_v_target_and_advantage, compute_v_target_and_advantage, - has_batch_dimension) -from nnabla_rl.environments.dummy import (DummyContinuous, DummyContinuousActionGoalEnv, DummyDiscrete, - DummyFactoredContinuous, DummyTupleContinuous, DummyTupleDiscrete) +from nnabla_rl.algorithms.common_utils import ( + _DeterministicPolicyActionSelector, + _InfluenceMetricsEvaluator, + _StatePreprocessedDeterministicPolicy, + _StatePreprocessedStochasticPolicy, + _StatePreprocessedVFunction, + compute_average_v_target_and_advantage, + compute_v_target_and_advantage, + has_batch_dimension, +) +from nnabla_rl.environments.dummy import ( + DummyContinuous, + DummyContinuousActionGoalEnv, + DummyDiscrete, + DummyFactoredContinuous, + DummyTupleContinuous, + DummyTupleDiscrete, +) from nnabla_rl.environments.environment_info import EnvironmentInfo from nnabla_rl.models import VFunction @@ -35,13 +46,13 @@ def __init__(self): def v(self, s): with nn.parameter_scope(self.scope_name): if isinstance(s, tuple): - h = s[0] * 2. + s[1] * 2. + h = s[0] * 2.0 + s[1] * 2.0 else: - h = s * 2. 
+ h = s * 2.0 return h -class TestCommonUtils(): +class TestCommonUtils: def setup_method(self, method): nn.clear_parameters() @@ -49,13 +60,47 @@ def _collect_dummy_experience(self, num_episodes=1, episode_length=3, tupled_sta experience = [] for _ in range(num_episodes): for i in range(episode_length): - s_current = (np.ones(1, ), np.ones(1, )) if tupled_state else np.ones(1, ) - a = np.ones(1, ) - s_next = (np.ones(1, ), np.ones(1, )) if tupled_state else np.ones(1, ) - r = np.ones(1, ) - non_terminal = np.ones(1, ) - if i == episode_length-1: - non_terminal = np.zeros(1, ) + s_current = ( + ( + np.ones( + 1, + ), + np.ones( + 1, + ), + ) + if tupled_state + else np.ones( + 1, + ) + ) + a = np.ones( + 1, + ) + s_next = ( + ( + np.ones( + 1, + ), + np.ones( + 1, + ), + ) + if tupled_state + else np.ones( + 1, + ) + ) + r = np.ones( + 1, + ) + non_terminal = np.ones( + 1, + ) + if i == episode_length - 1: + non_terminal = np.zeros( + 1, + ) experience.append((s_current, a, r, non_terminal, s_next)) return experience @@ -107,50 +152,55 @@ def test_has_batch_dimension_non_tupled_discrete_state(self): assert has_batch_dimension(batched_state, env_info) assert not has_batch_dimension(non_batched_state, env_info) - @pytest.mark.parametrize("gamma, lmb, expected_adv, expected_vtarg, tupled_state", - [[1., 0., np.array([[1.], [1.], [-1.]]), np.array([[3.], [3.], [1.]]), False], - [1., 1., np.array([[1.], [0.], [-1.]]), np.array([[3.], [2.], [1.]]), False], - [0.9, 0.7, np.array([[0.9071], [0.17], [-1.]]), - np.array([[2.9071], [2.17], [1.]]), False], - [1., 0., np.array([[1.], [1.], [-3.]]), np.array([[5.], [5.], [1.]]), True], - [1., 1., np.array([[-1.], [-2.], [-3.]]), np.array([[3.], [2.], [1.]]), True], - ]) + @pytest.mark.parametrize( + "gamma, lmb, expected_adv, expected_vtarg, tupled_state", + [ + [1.0, 0.0, np.array([[1.0], [1.0], [-1.0]]), np.array([[3.0], [3.0], [1.0]]), False], + [1.0, 1.0, np.array([[1.0], [0.0], [-1.0]]), np.array([[3.0], [2.0], [1.0]]), False], + [0.9, 0.7, np.array([[0.9071], [0.17], [-1.0]]), np.array([[2.9071], [2.17], [1.0]]), False], + [1.0, 0.0, np.array([[1.0], [1.0], [-3.0]]), np.array([[5.0], [5.0], [1.0]]), True], + [1.0, 1.0, np.array([[-1.0], [-2.0], [-3.0]]), np.array([[3.0], [2.0], [1.0]]), True], + ], + ) def test_compute_v_target_and_advantage(self, gamma, lmb, expected_adv, expected_vtarg, tupled_state): dummy_v_function = DummyVFunction() dummy_experience = self._collect_dummy_experience(tupled_state=tupled_state) - actual_vtarg, actual_adv = compute_v_target_and_advantage( - dummy_v_function, dummy_experience, gamma, lmb) + actual_vtarg, actual_adv = compute_v_target_and_advantage(dummy_v_function, dummy_experience, gamma, lmb) assert np.allclose(actual_adv, expected_adv) assert np.allclose(actual_vtarg, expected_vtarg) - @pytest.mark.parametrize("lmb, expected_adv, expected_vtarg, tupled_state", - [[0., np.array([[0.], [0.], [-2.]]), np.array([[2.], [2.], [0.]]), False], - [1., np.array([[-2.], [-2.], [-2.]]), np.array([[0.], [0.], [0.]]), False], - [0.7, np.array([[-0.98], [-1.4], [-2.]]), np.array([[1.02], [0.6], [0.]]), False], - [0., np.array([[0.], [0.], [-4.]]), np.array([[4.], [4.], [0.]]), True], - [1., np.array([[-4.], [-4.], [-4.]]), np.array([[0.], [0.], [0.]]), True], - ]) + @pytest.mark.parametrize( + "lmb, expected_adv, expected_vtarg, tupled_state", + [ + [0.0, np.array([[0.0], [0.0], [-2.0]]), np.array([[2.0], [2.0], [0.0]]), False], + [1.0, np.array([[-2.0], [-2.0], [-2.0]]), np.array([[0.0], [0.0], [0.0]]), False], + [0.7, 
np.array([[-0.98], [-1.4], [-2.0]]), np.array([[1.02], [0.6], [0.0]]), False], + [0.0, np.array([[0.0], [0.0], [-4.0]]), np.array([[4.0], [4.0], [0.0]]), True], + [1.0, np.array([[-4.0], [-4.0], [-4.0]]), np.array([[0.0], [0.0], [0.0]]), True], + ], + ) def test_compute_average_v_target_and_advantage(self, lmb, expected_adv, expected_vtarg, tupled_state): dummy_v_function = DummyVFunction() dummy_experience = self._collect_dummy_experience(tupled_state=tupled_state) - actual_vtarg, actual_adv = compute_average_v_target_and_advantage( - dummy_v_function, dummy_experience, lmb) + actual_vtarg, actual_adv = compute_average_v_target_and_advantage(dummy_v_function, dummy_experience, lmb) assert np.allclose(actual_adv, expected_adv) assert np.allclose(actual_vtarg, expected_vtarg) def test_state_preprocessed_v_function(self): - state_shape = (5, ) + state_shape = (5,) from nnabla_rl.models import TRPOVFunction - v_scope_name = 'old_v' + + v_scope_name = "old_v" v_function = TRPOVFunction(v_scope_name) import nnabla_rl.preprocessors as RP - preprocessor_scope_name = 'test_preprocessor' + + preprocessor_scope_name = "test_preprocessor" preprocessor = RP.RunningMeanNormalizer(preprocessor_scope_name, shape=state_shape) v_function_old = _StatePreprocessedVFunction(v_function=v_function, preprocessor=preprocessor) @@ -158,22 +208,24 @@ def test_state_preprocessed_v_function(self): s = nn.Variable.from_numpy_array(np.empty(shape=(1, *state_shape))) _ = v_function_old.v(s) - v_new_scope_name = 'new_v' + v_new_scope_name = "new_v" v_function_new = v_function_old.deepcopy(v_new_scope_name) assert v_function_old.scope_name != v_function_new.scope_name assert v_function_old._preprocessor.scope_name == v_function_new._preprocessor.scope_name def test_state_preprocessed_stochastic_policy(self): - state_shape = (5, ) + state_shape = (5,) action_dim = 10 from nnabla_rl.models import TRPOPolicy - pi_scope_name = 'old_pi' + + pi_scope_name = "old_pi" pi = TRPOPolicy(pi_scope_name, action_dim=action_dim) import nnabla_rl.preprocessors as RP - preprocessor_scope_name = 'test_preprocessor' + + preprocessor_scope_name = "test_preprocessor" preprocessor = RP.RunningMeanNormalizer(preprocessor_scope_name, shape=state_shape) pi_old = _StatePreprocessedStochasticPolicy(policy=pi, preprocessor=preprocessor) @@ -181,22 +233,24 @@ def test_state_preprocessed_stochastic_policy(self): s = nn.Variable.from_numpy_array(np.empty(shape=(1, *state_shape))) _ = pi_old.pi(s) - pi_new_scope_name = 'new_pi' + pi_new_scope_name = "new_pi" pi_new = pi_old.deepcopy(pi_new_scope_name) assert pi_old.scope_name != pi_new.scope_name assert pi_old._preprocessor.scope_name == pi_new._preprocessor.scope_name def test_state_preprocessed_deterministic_policy(self): - state_shape = (5, ) + state_shape = (5,) action_dim = 10 from nnabla_rl.models import TD3Policy - pi_scope_name = 'old_pi' + + pi_scope_name = "old_pi" pi = TD3Policy(pi_scope_name, action_dim=action_dim, max_action_value=1.0) import nnabla_rl.preprocessors as RP - preprocessor_scope_name = 'test_preprocessor' + + preprocessor_scope_name = "test_preprocessor" preprocessor = RP.RunningMeanNormalizer(preprocessor_scope_name, shape=state_shape) pi_old = _StatePreprocessedDeterministicPolicy(policy=pi, preprocessor=preprocessor) @@ -204,7 +258,7 @@ def test_state_preprocessed_deterministic_policy(self): s = nn.Variable.from_numpy_array(np.empty(shape=(1, *state_shape))) _ = pi_old.pi(s) - pi_new_scope_name = 'new_pi' + pi_new_scope_name = "new_pi" pi_new = 
pi_old.deepcopy(pi_new_scope_name) assert pi_old.scope_name != pi_new.scope_name @@ -220,7 +274,8 @@ def test_action_selector_tupled_state(self): action_dim = env_info.action_dim from nnabla_rl.models import HERPolicy - pi_scope_name = 'pi' + + pi_scope_name = "pi" pi = HERPolicy(pi_scope_name, action_dim=action_dim, max_action_value=1.0) selector = _DeterministicPolicyActionSelector(env_info, pi) @@ -245,7 +300,8 @@ def test_action_selector_non_tupled_state(self): action_dim = env_info.action_dim from nnabla_rl.models import TD3Policy - pi_scope_name = 'pi' + + pi_scope_name = "pi" pi = TD3Policy(pi_scope_name, action_dim=action_dim, max_action_value=1.0) selector = _DeterministicPolicyActionSelector(env_info, pi) @@ -270,6 +326,7 @@ def test_influence_metrics_evaluator(self): env_info = EnvironmentInfo.from_env(env) from nnabla_rl.models import SACDQFunction + q_scope_name = "q" q_function = SACDQFunction(q_scope_name, num_factors=num_factors) diff --git a/tests/algorithms/test_ddp.py b/tests/algorithms/test_ddp.py index b9975e5e..96cd642a 100644 --- a/tests/algorithms/test_ddp.py +++ b/tests/algorithms/test_ddp.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -32,16 +32,17 @@ def __init__(self, dt=0.01): # input changes the velocity self._B = np.array([[0, 0], [0, dt]]) - def next_state(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, Dict[str, Any]]: + def next_state( + self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False + ) -> Tuple[np.ndarray, Dict[str, Any]]: return self._A.dot(x) + self._B.dot(u), {} - def gradient(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, np.ndarray]: + def gradient(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) -> Tuple[np.ndarray, np.ndarray]: return self._A, self._B - def hessian(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + def hessian( + self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: state_dim = self.state_dim() action_dim = self.action_dim() Fxx = np.zeros(shape=(state_dim, state_dim, state_dim)) @@ -71,7 +72,7 @@ def evaluate( return x.T.dot(self._Q).dot(x) else: # Assuming that target state is zero - return x.T.dot(self._Q).dot(x) + 2.0*x.T.dot(self._F).dot(u) + u.T.dot(self._R).dot(u) + return x.T.dot(self._Q).dot(x) + 2.0 * x.T.dot(self._F).dot(u) + u.T.dot(self._R).dot(u) def gradient( self, x: np.ndarray, u: Optional[np.ndarray], t: int, final_state: bool = False, batched: bool = False @@ -92,18 +93,18 @@ def hessian( class TestDDP(object): def test_algorithm_name(self): - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) dynamics = LinearDynamics() cost_function = QuadraticCostFunction() ddp = A.DDP(env, dynamics=dynamics, cost_function=cost_function) - assert ddp.__name__ == 'DDP' + assert ddp.__name__ == "DDP" def test_continuous_action_env_supported(self): """Check that no error occurs when training on continuous action env.""" - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) 
dynamics = LinearDynamics() cost_function = QuadraticCostFunction() @@ -119,7 +120,7 @@ def test_discrete_action_env_not_supported(self): A.DDP(env, dynamics=dynamics, cost_function=cost_function) def test_compute_eval_action(self): - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) dynamics = LinearDynamics() cost_function = QuadraticCostFunction() @@ -138,7 +139,7 @@ def test_compute_eval_action(self): np.testing.assert_almost_equal(lqr_action, ddp_action, decimal=3) def test_compute_trajectory(self): - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) dynamics = LinearDynamics() cost_function = QuadraticCostFunction() @@ -155,13 +156,13 @@ def test_compute_trajectory(self): lqr_trajectory, _ = lqr.compute_trajectory(initial_trajectory) ddp_trajectory, _ = ddp.compute_trajectory(initial_trajectory) - for (lqr_state, ddp_state) in zip(lqr_trajectory[:-1], ddp_trajectory[:-1]): + for lqr_state, ddp_state in zip(lqr_trajectory[:-1], ddp_trajectory[:-1]): np.testing.assert_almost_equal(lqr_state[0], ddp_state[0], decimal=3) np.testing.assert_almost_equal(lqr_state[1], ddp_state[1], decimal=3) def test_run_online_training(self): """Check that error occurs when calling online training.""" - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) dynamics = LinearDynamics() cost_function = QuadraticCostFunction() @@ -171,7 +172,7 @@ def test_run_online_training(self): def test_run_offline_training(self): """Check that error occurs when calling offline training.""" - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) dynamics = LinearDynamics() cost_function = QuadraticCostFunction() diff --git a/tests/algorithms/test_ddpg.py b/tests/algorithms/test_ddpg.py index faa93f62..834f92f2 100644 --- a/tests/algorithms/test_ddpg.py +++ b/tests/algorithms/test_ddpg.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -64,14 +64,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -81,8 +81,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -129,14 +129,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -146,8 +146,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -168,7 +168,7 @@ def test_algorithm_name(self): dummy_env = E.DummyContinuous() ddpg = A.DDPG(dummy_env) - assert ddpg.__name__ == 'DDPG' + assert ddpg.__name__ == "DDPG" def test_discrete_action_env_unsupported(self): """Check that error occurs when training on discrete action env.""" @@ -189,6 +189,7 @@ def test_run_online_training(self): def test_run_online_rnn_training(self): """Check that no error occurs when calling online training with RNN model.""" + class RNNActorBuilder(ModelBuilder[DeterministicPolicy]): def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): return RNNActorFunction(scope_name, action_dim=env_info.action_dim) @@ -238,20 +239,21 @@ def test_latest_iteration_state(self): dummy_env = E.DummyContinuous() ddpg = A.DDPG(dummy_env) - ddpg._q_function_trainer_state = {'q_loss': 0., 'td_errors': np.array([0., 1.])} - ddpg._policy_trainer_state = {'pi_loss': 1.} + ddpg._q_function_trainer_state = {"q_loss": 0.0, "td_errors": np.array([0.0, 1.0])} + ddpg._policy_trainer_state = {"pi_loss": 1.0} latest_iteration_state = ddpg.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert 'pi_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['q_loss'] == 0. - assert latest_iteration_state['scalar']['pi_loss'] == 1. 
- assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "q_loss" in latest_iteration_state["scalar"] + assert "pi_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 + assert latest_iteration_state["scalar"]["pi_loss"] == 1.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_ddqn.py b/tests/algorithms/test_ddqn.py index 92745651..8ef14de6 100644 --- a/tests/algorithms/test_ddqn.py +++ b/tests/algorithms/test_ddqn.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ def test_algorithm_name(self): dummy_env = E.DummyDiscreteImg() ddqn = A.DDQN(dummy_env) - assert ddqn.__name__ == 'DDQN' + assert ddqn.__name__ == "DDQN" def test_continuous_action_env_unsupported(self): """Check that error occurs when training on continuous action env.""" @@ -107,17 +107,18 @@ def test_latest_iteration_state(self): dummy_env = E.DummyDiscreteImg() ddqn = A.DDQN(dummy_env) - ddqn._q_function_trainer_state = {'q_loss': 0., 'td_errors': np.array([0., 1.])} + ddqn._q_function_trainer_state = {"q_loss": 0.0, "td_errors": np.array([0.0, 1.0])} latest_iteration_state = ddqn.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['q_loss'] == 0. - assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "q_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_decision_transformer.py b/tests/algorithms/test_decision_transformer.py index 9971bc04..2ba732f6 100644 --- a/tests/algorithms/test_decision_transformer.py +++ b/tests/algorithms/test_decision_transformer.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -29,7 +29,7 @@ def test_algorithm_name(self): dummy_env = E.DummyDiscreteImg(max_episode_steps=10) decision_transformer = A.DecisionTransformer(dummy_env) - assert decision_transformer.__name__ == 'DecisionTransformer' + assert decision_transformer.__name__ == "DecisionTransformer" def test_run_online_training(self): """Check that error occurs when calling online training.""" @@ -53,8 +53,9 @@ def test_run_offline_training(self): for _ in range(trajectory_num): trajectory = generate_dummy_trajectory(dummy_env, trajectory_length) # Add info required by decision transformer - trajectory = tuple((s, a, r, done, s_next, {'rtg': 1, 'timesteps': 1}) - for (s, a, r, done, s_next, *_) in trajectory) + trajectory = tuple( + (s, a, r, done, s_next, {"rtg": 1, "timesteps": 1}) for (s, a, r, done, s_next, *_) in trajectory + ) buffer.append_trajectory(trajectory) batch_size = 3 @@ -73,12 +74,12 @@ def test_compute_eval_action(self): state = dummy_env.reset() state = np.float32(state) - extra_info = {'reward': 0.0} + extra_info = {"reward": 0.0} action = decision_transformer.compute_eval_action(state, extra_info=extra_info, begin_of_episode=True) - assert action.shape == (1, ) + assert action.shape == (1,) action = decision_transformer.compute_eval_action(state, extra_info=extra_info, begin_of_episode=False) - assert action.shape == (1, ) + assert action.shape == (1,) def test_parameter_range(self): with pytest.raises(ValueError): @@ -115,15 +116,16 @@ def test_latest_iteration_state(self): config = A.DecisionTransformerConfig(batch_size=batch_size, max_timesteps=trajectory_length, context_length=10) decision_transformer = A.DecisionTransformer(dummy_env, config=config) - decision_transformer._decision_transformer_trainer_state = {'loss': 0.} + decision_transformer._decision_transformer_trainer_state = {"loss": 0.0} latest_iteration_state = decision_transformer.latest_iteration_state - assert 'loss' in latest_iteration_state['scalar'] - assert latest_iteration_state['scalar']['loss'] == 0. + assert "loss" in latest_iteration_state["scalar"] + assert latest_iteration_state["scalar"]["loss"] == 0.0 if __name__ == "__main__": from testing_utils import generate_dummy_trajectory + pytest.main() else: from ..testing_utils import generate_dummy_trajectory diff --git a/tests/algorithms/test_demme_sac.py b/tests/algorithms/test_demme_sac.py index 8de38376..3c78be88 100644 --- a/tests/algorithms/test_demme_sac.py +++ b/tests/algorithms/test_demme_sac.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -38,7 +38,7 @@ class RNNPolicyFunction(StochasticPolicy): def __init__(self, scope_name: str, action_dim: int): super(RNNPolicyFunction, self).__init__(scope_name) self._action_dim = action_dim - self._lstm_state_size = action_dim*2 + self._lstm_state_size = action_dim * 2 self._h = None self._c = None @@ -70,14 +70,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -87,8 +87,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -135,14 +135,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -152,8 +152,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -199,14 +199,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -216,8 +216,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -238,7 +238,7 @@ def 
test_algorithm_name(self): dummy_env = E.DummyContinuous() sac = A.DEMMESAC(dummy_env) - assert sac.__name__ == 'DEMMESAC' + assert sac.__name__ == "DEMMESAC" def test_discrete_action_env_unsupported(self): """Check that error occurs when training on discrete action env.""" @@ -256,6 +256,7 @@ def test_run_online_training(self): def test_run_online_rnn_training(self): """Check that no error occurs when calling online training with RNN model.""" + class RNNPolicyBuilder(ModelBuilder[StochasticPolicy]): def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): return RNNPolicyFunction(scope_name, action_dim=env_info.action_dim) @@ -286,13 +287,16 @@ def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): config.v_re_burn_in_steps = 2 config.start_timesteps = 7 config.batch_size = 2 - sac = A.DEMMESAC(dummy_env, config=config, - pi_t_builder=RNNPolicyBuilder(), - pi_e_builder=RNNPolicyBuilder(), - q_rr_function_builder=RNNQFunctionBuilder(), - q_re_function_builder=RNNQFunctionBuilder(), - v_rr_function_builder=RNNVFunctionBuilder(), - v_re_function_builder=RNNVFunctionBuilder()) + sac = A.DEMMESAC( + dummy_env, + config=config, + pi_t_builder=RNNPolicyBuilder(), + pi_e_builder=RNNPolicyBuilder(), + q_rr_function_builder=RNNQFunctionBuilder(), + q_re_function_builder=RNNQFunctionBuilder(), + v_rr_function_builder=RNNVFunctionBuilder(), + v_re_function_builder=RNNVFunctionBuilder(), + ) sac.train_online(dummy_env, total_iterations=10) @@ -323,10 +327,8 @@ def test_target_network_initialization(self): sac = A.DEMMESAC(dummy_env) # Should be initialized to same parameters - assert self._has_same_parameters( - sac._v_rr.get_parameters(), sac._target_v_rr.get_parameters()) - assert self._has_same_parameters( - sac._v_re.get_parameters(), sac._target_v_re.get_parameters()) + assert self._has_same_parameters(sac._v_rr.get_parameters(), sac._target_v_rr.get_parameters()) + assert self._has_same_parameters(sac._v_re.get_parameters(), sac._target_v_re.get_parameters()) def test_parameter_range(self): with pytest.raises(ValueError): @@ -359,28 +361,29 @@ def test_latest_iteration_state(self): dummy_env = E.DummyContinuous() sac = A.DEMMESAC(dummy_env) - sac._q_rr_trainer_state = {'q_loss': 0., 'td_errors': np.array([0., 1.])} - sac._q_re_trainer_state = {'q_loss': 1., 'td_errors': np.array([1., 2.])} - sac._v_rr_trainer_state = {'v_loss': 2.} - sac._v_re_trainer_state = {'v_loss': 3.} - sac._pi_t_trainer_state = {'pi_loss': 4.} - sac._pi_e_trainer_state = {'pi_loss': 5.} + sac._q_rr_trainer_state = {"q_loss": 0.0, "td_errors": np.array([0.0, 1.0])} + sac._q_re_trainer_state = {"q_loss": 1.0, "td_errors": np.array([1.0, 2.0])} + sac._v_rr_trainer_state = {"v_loss": 2.0} + sac._v_re_trainer_state = {"v_loss": 3.0} + sac._pi_t_trainer_state = {"pi_loss": 4.0} + sac._pi_e_trainer_state = {"pi_loss": 5.0} latest_iteration_state = sac.latest_iteration_state - for loss in ['q_rr_loss', 'q_re_loss', 'pi_t_loss', 'pi_e_loss', 'v_rr_loss', 'v_re_loss']: - assert loss in latest_iteration_state['scalar'] - assert latest_iteration_state['scalar']['q_rr_loss'] == 0. - assert latest_iteration_state['scalar']['q_re_loss'] == 1. - assert latest_iteration_state['scalar']['v_rr_loss'] == 2. - assert latest_iteration_state['scalar']['v_re_loss'] == 3. - assert latest_iteration_state['scalar']['pi_t_loss'] == 4. - assert latest_iteration_state['scalar']['pi_e_loss'] == 5. 
- assert np.allclose(latest_iteration_state['histogram']['q_rr_td_errors'], np.array([0., 1.])) - assert np.allclose(latest_iteration_state['histogram']['q_re_td_errors'], np.array([1., 2.])) + for loss in ["q_rr_loss", "q_re_loss", "pi_t_loss", "pi_e_loss", "v_rr_loss", "v_re_loss"]: + assert loss in latest_iteration_state["scalar"] + assert latest_iteration_state["scalar"]["q_rr_loss"] == 0.0 + assert latest_iteration_state["scalar"]["q_re_loss"] == 1.0 + assert latest_iteration_state["scalar"]["v_rr_loss"] == 2.0 + assert latest_iteration_state["scalar"]["v_re_loss"] == 3.0 + assert latest_iteration_state["scalar"]["pi_t_loss"] == 4.0 + assert latest_iteration_state["scalar"]["pi_e_loss"] == 5.0 + assert np.allclose(latest_iteration_state["histogram"]["q_rr_td_errors"], np.array([0.0, 1.0])) + assert np.allclose(latest_iteration_state["histogram"]["q_re_td_errors"], np.array([1.0, 2.0])) if __name__ == "__main__": from tests.testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_dqn.py b/tests/algorithms/test_dqn.py index 35b33a16..194951a7 100644 --- a/tests/algorithms/test_dqn.py +++ b/tests/algorithms/test_dqn.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ def test_algorithm_name(self): dummy_env = E.DummyDiscreteImg() dqn = A.DQN(dummy_env) - assert dqn.__name__ == 'DQN' + assert dqn.__name__ == "DQN" def test_continuous_action_env_unsupported(self): """Check that error occurs when training on continuous action env.""" @@ -70,9 +70,11 @@ def test_run_online_training_multistep(self): def test_run_online_rnn_training(self): """Check that no error occurs when calling online training with RNN model.""" + class RNNModelBuilder(ModelBuilder[QFunction]): def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): return DRQNQFunction(scope_name, env_info.action_dim) + dummy_env = E.DummyDiscreteImg() config = A.DQNConfig() config.num_steps = 2 @@ -143,17 +145,18 @@ def test_latest_iteration_state(self): dummy_env = E.DummyDiscreteImg() dqn = A.DQN(dummy_env) - dqn._q_function_trainer_state = {'q_loss': 0., 'td_errors': np.array([0., 1.])} + dqn._q_function_trainer_state = {"q_loss": 0.0, "td_errors": np.array([0.0, 1.0])} latest_iteration_state = dqn.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['q_loss'] == 0. - assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "q_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_drqn.py b/tests/algorithms/test_drqn.py index 1e57aabc..a8daa7c8 100644 --- a/tests/algorithms/test_drqn.py +++ b/tests/algorithms/test_drqn.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. 
+# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ def test_algorithm_name(self): dummy_env = E.DummyDiscreteImg() drqn = A.DRQN(dummy_env) - assert drqn.__name__ == 'DRQN' + assert drqn.__name__ == "DRQN" def test_continuous_action_env_unsupported(self): """Check that error occurs when training on continuous action env.""" @@ -109,17 +109,18 @@ def test_latest_iteration_state(self): dummy_env = E.DummyDiscreteImg() drqn = A.DRQN(dummy_env) - drqn._q_function_trainer_state = {'q_loss': 0., 'td_errors': np.array([0., 1.])} + drqn._q_function_trainer_state = {"q_loss": 0.0, "td_errors": np.array([0.0, 1.0])} latest_iteration_state = drqn.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['q_loss'] == 0. - assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "q_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_dummy.py b/tests/algorithms/test_dummy.py index 6b72c1d6..cce8cfff 100644 --- a/tests/algorithms/test_dummy.py +++ b/tests/algorithms/test_dummy.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -29,7 +29,7 @@ def test_algorithm_name(self): dummy_env = E.DummyDiscrete() dummy = A.Dummy(dummy_env) - assert dummy.__name__ == 'Dummy' + assert dummy.__name__ == "Dummy" def test_run_online_training(self): """Check that no error occurs when calling online training.""" @@ -44,19 +44,16 @@ def test_run_offline_training(self): dummy = A.Dummy(dummy_env) experience_num = 100 - fake_states = np.empty(shape=(experience_num, ) + - dummy_env.observation_space.shape) - fake_actions = np.empty(shape=(experience_num, ) + - dummy_env.action_space.shape) + fake_states = np.empty(shape=(experience_num,) + dummy_env.observation_space.shape) + fake_actions = np.empty(shape=(experience_num,) + dummy_env.action_space.shape) fake_rewards = np.empty(shape=(experience_num, 1)) fake_non_terminals = np.empty(shape=(experience_num, 1)) - fake_next_states = np.empty(shape=(experience_num, ) + - dummy_env.observation_space.shape) - fake_next_actions = np.empty(shape=(experience_num, ) + - dummy_env.action_space.shape) + fake_next_states = np.empty(shape=(experience_num,) + dummy_env.observation_space.shape) + fake_next_actions = np.empty(shape=(experience_num,) + dummy_env.action_space.shape) - fake_experiences = zip(fake_states, fake_actions, fake_rewards, - fake_non_terminals, fake_next_states, fake_next_actions) + fake_experiences = zip( + fake_states, fake_actions, fake_rewards, fake_non_terminals, fake_next_states, fake_next_actions + ) dummy.train_offline(fake_experiences, total_iterations=10) def test_compute_eval_action(self): diff --git a/tests/algorithms/test_gail.py b/tests/algorithms/test_gail.py index 1bc286dd..e6b5bbd8 100644 --- a/tests/algorithms/test_gail.py +++ b/tests/algorithms/test_gail.py @@ -1,5 +1,5 @@ # Copyright 2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -22,7 +22,7 @@ from nnabla_rl.replay_buffer import ReplayBuffer -class TestGAIL(): +class TestGAIL: def setup_method(self): nn.clear_parameters() @@ -37,7 +37,7 @@ def test_algorithm_name(self): dummy_buffer = self._create_dummy_buffer(dummy_env) gail = A.GAIL(dummy_env, dummy_buffer) - assert gail.__name__ == 'GAIL' + assert gail.__name__ == "GAIL" def test_discrete_action_env_unsupported(self): """Check that error occurs when training on discrete action env.""" @@ -54,12 +54,14 @@ def test_run_online_training(self): dummy_env = EpisodicEnv(dummy_env, min_episode_length=3) dummy_buffer = self._create_dummy_buffer(dummy_env, batch_size=15) - config = A.GAILConfig(num_steps_per_iteration=5, - pi_batch_size=5, - vf_batch_size=2, - discriminator_batch_size=2, - sigma_kl_divergence_constraint=10.0, - maximum_backtrack_numbers=50) + config = A.GAILConfig( + num_steps_per_iteration=5, + pi_batch_size=5, + vf_batch_size=2, + discriminator_batch_size=2, + sigma_kl_divergence_constraint=10.0, + maximum_backtrack_numbers=50, + ) gail = A.GAIL(dummy_env, dummy_buffer, config=config) gail.train_online(dummy_env, total_iterations=5) @@ -121,18 +123,19 @@ def test_latest_iteration_state(self): dummy_buffer = self._create_dummy_buffer(dummy_env) gail = A.GAIL(dummy_env, dummy_buffer) - gail._v_function_trainer_state = {'v_loss': 0.} - gail._discriminator_trainer_state = {'reward_loss': 1.} + gail._v_function_trainer_state = {"v_loss": 0.0} + gail._discriminator_trainer_state = {"reward_loss": 1.0} latest_iteration_state = gail.latest_iteration_state - assert 'v_loss' in latest_iteration_state['scalar'] - assert 'reward_loss' in latest_iteration_state['scalar'] - assert latest_iteration_state['scalar']['v_loss'] == 0. - assert latest_iteration_state['scalar']['reward_loss'] == 1. + assert "v_loss" in latest_iteration_state["scalar"] + assert "reward_loss" in latest_iteration_state["scalar"] + assert latest_iteration_state["scalar"]["v_loss"] == 0.0 + assert latest_iteration_state["scalar"]["reward_loss"] == 1.0 if __name__ == "__main__": from testing_utils import EpisodicEnv, generate_dummy_experiences + pytest.main() else: from ..testing_utils import EpisodicEnv, generate_dummy_experiences diff --git a/tests/algorithms/test_her.py b/tests/algorithms/test_her.py index cca0bc1a..759c656d 100644 --- a/tests/algorithms/test_her.py +++ b/tests/algorithms/test_her.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -70,14 +70,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -87,8 +87,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -136,14 +136,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -153,8 +153,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -176,7 +176,7 @@ def test_algorithm_name(self): dummy_env = GoalConditionedTupleObservationEnv(dummy_env) her = A.HER(dummy_env) - assert her.__name__ == 'HER' + assert her.__name__ == "HER" def test_not_goal_conditioned_env_unsupported(self): """Check that error occurs when training on not goal-conditioned @@ -210,6 +210,7 @@ def test_run_online_training(self): def test_run_online_rnn_training(self): """Check that no error occurs when calling online training with RNN model.""" + class RNNActorBuilder(ModelBuilder[DeterministicPolicy]): def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): return RNNActorFunction(scope_name, action_dim=env_info.action_dim) @@ -262,20 +263,21 @@ def test_latest_iteration_state(self): dummy_env = GoalConditionedTupleObservationEnv(dummy_env) her = A.HER(dummy_env) - her._q_function_trainer_state = {'q_loss': 0., 'td_errors': np.array([0., 1.])} - her._policy_trainer_state = {'pi_loss': 1.} + her._q_function_trainer_state = {"q_loss": 0.0, "td_errors": np.array([0.0, 1.0])} + her._policy_trainer_state = {"pi_loss": 1.0} latest_iteration_state = her.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert 'pi_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['q_loss'] == 0. 
- assert latest_iteration_state['scalar']['pi_loss'] == 1. - assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "q_loss" in latest_iteration_state["scalar"] + assert "pi_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 + assert latest_iteration_state["scalar"]["pi_loss"] == 1.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_hyar.py b/tests/algorithms/test_hyar.py index 6b3d571e..a04e2e85 100644 --- a/tests/algorithms/test_hyar.py +++ b/tests/algorithms/test_hyar.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ def test_algorithm_name(self): dummy_env = E.DummyHybridEnv() hyar = A.HyAR(dummy_env) - assert hyar.__name__ == 'HyAR' + assert hyar.__name__ == "HyAR" def test_discrete_action_env_unsupported(self): """Check that error occurs when training on discrete action env.""" @@ -46,9 +46,7 @@ def test_continuous_action_env_unsupported(self): def test_run_online_training(self): """Check that no error occurs when calling online training.""" dummy_env = E.DummyHybridEnv(max_episode_steps=10) - config = A.HyARConfig(start_timesteps=1, - vae_pretrain_episodes=1, - vae_pretrain_times=1) + config = A.HyARConfig(start_timesteps=1, vae_pretrain_episodes=1, vae_pretrain_times=1) hyar = A.HyAR(dummy_env, config=config) hyar.train_online(dummy_env, total_iterations=10) @@ -57,9 +55,7 @@ def test_run_offline_training(self): """Check that no error occurs when calling offline training.""" batch_size = 5 dummy_env = E.DummyHybridEnv(max_episode_steps=10) - config = A.HyARConfig(start_timesteps=1, - vae_pretrain_episodes=1, - vae_pretrain_times=1) + config = A.HyARConfig(start_timesteps=1, vae_pretrain_episodes=1, vae_pretrain_times=1) hyar = A.HyAR(dummy_env, config=config) experiences = generate_dummy_experiences(dummy_env, batch_size) @@ -117,29 +113,30 @@ def test_latest_iteration_state(self): dummy_env = E.DummyHybridEnv() hyar = A.HyAR(dummy_env) - hyar._q_function_trainer_state = {'q_loss': 0., 'td_errors': np.array([0., 1.])} - hyar._policy_trainer_state = {'pi_loss': 1.} - hyar._vae_trainer_state = {'encoder_loss': 1., 'kl_loss': 2., 'reconstruction_loss': 3., 'dyn_loss': 4.} + hyar._q_function_trainer_state = {"q_loss": 0.0, "td_errors": np.array([0.0, 1.0])} + hyar._policy_trainer_state = {"pi_loss": 1.0} + hyar._vae_trainer_state = {"encoder_loss": 1.0, "kl_loss": 2.0, "reconstruction_loss": 3.0, "dyn_loss": 4.0} latest_iteration_state = hyar.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert 'pi_loss' in latest_iteration_state['scalar'] - assert 'encoder_loss' in latest_iteration_state['scalar'] - assert 'kl_loss' in latest_iteration_state['scalar'] - assert 'reconstruction_loss' in latest_iteration_state['scalar'] - assert 'dyn_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['q_loss'] == 0. - assert latest_iteration_state['scalar']['pi_loss'] == 1. 
- assert latest_iteration_state['scalar']['encoder_loss'] == 1. - assert latest_iteration_state['scalar']['kl_loss'] == 2. - assert latest_iteration_state['scalar']['reconstruction_loss'] == 3. - assert latest_iteration_state['scalar']['dyn_loss'] == 4. - assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "q_loss" in latest_iteration_state["scalar"] + assert "pi_loss" in latest_iteration_state["scalar"] + assert "encoder_loss" in latest_iteration_state["scalar"] + assert "kl_loss" in latest_iteration_state["scalar"] + assert "reconstruction_loss" in latest_iteration_state["scalar"] + assert "dyn_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 + assert latest_iteration_state["scalar"]["pi_loss"] == 1.0 + assert latest_iteration_state["scalar"]["encoder_loss"] == 1.0 + assert latest_iteration_state["scalar"]["kl_loss"] == 2.0 + assert latest_iteration_state["scalar"]["reconstruction_loss"] == 3.0 + assert latest_iteration_state["scalar"]["dyn_loss"] == 4.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_icml2015_trpo.py b/tests/algorithms/test_icml2015_trpo.py index 8b79cccb..96d7ed37 100644 --- a/tests/algorithms/test_icml2015_trpo.py +++ b/tests/algorithms/test_icml2015_trpo.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -29,17 +29,19 @@ def test_algorithm_name(self): dummy_env = E.DummyDiscreteImg() trpo = A.ICML2015TRPO(dummy_env) - assert trpo.__name__ == 'ICML2015TRPO' + assert trpo.__name__ == "ICML2015TRPO" def test_run_online_training(self): """Check that no error occurs when calling online training.""" dummy_env = E.DummyDiscreteImg() dummy_env = EpisodicEnv(dummy_env, min_episode_length=3) - config = A.ICML2015TRPOConfig(batch_size=5, - gpu_batch_size=2, - num_steps_per_iteration=5, - sigma_kl_divergence_constraint=10.0, - maximum_backtrack_numbers=2) + config = A.ICML2015TRPOConfig( + batch_size=5, + gpu_batch_size=2, + num_steps_per_iteration=5, + sigma_kl_divergence_constraint=10.0, + maximum_backtrack_numbers=2, + ) trpo = A.ICML2015TRPO(dummy_env, config=config) trpo.train_online(dummy_env, total_iterations=1) @@ -79,20 +81,18 @@ def test_compute_accumulated_reward(self): gamma = 0.99 episode_length = 3 reward_sequence = np.arange(episode_length) - gamma_seq = np.array( - [gamma**i for i in range(episode_length)]) + gamma_seq = np.array([gamma**i for i in range(episode_length)]) gamma_seqs = np.zeros((episode_length, episode_length)) gamma_seqs[0] = gamma_seq for i in range(1, episode_length): gamma_seqs[i, i:] = gamma_seq[:-i] - expect = np.sum(reward_sequence*gamma_seqs, axis=1) + expect = np.sum(reward_sequence * gamma_seqs, axis=1) dummy_envinfo = E.DummyContinuous() icml2015_trpo = A.ICML2015TRPO(dummy_envinfo) - accumulated_reward = icml2015_trpo._compute_accumulated_reward( - reward_sequence, gamma) + accumulated_reward = icml2015_trpo._compute_accumulated_reward(reward_sequence, gamma) assert expect == pytest.approx(accumulated_reward.flatten()) @@ -110,6 +110,7 @@ def test_compute_accumulated_reward_raise_value_error(self): if __name__ == "__main__": from testing_utils import EpisodicEnv + pytest.main() else: from ..testing_utils import EpisodicEnv diff --git a/tests/algorithms/test_icml2018_sac.py b/tests/algorithms/test_icml2018_sac.py index f123ade3..91fd69c7 100644 --- a/tests/algorithms/test_icml2018_sac.py +++ b/tests/algorithms/test_icml2018_sac.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
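A note on the test_compute_accumulated_reward hunk above: gamma_seqs is built so that row i holds the discount factors gamma**0, gamma**1, ... shifted to start at column i, which makes np.sum(reward_sequence * gamma_seqs, axis=1) the discounted return from every timestep of the episode. Below is a minimal sketch of the same expectation using the test's own setup (episode_length=3, gamma=0.99, rewards=np.arange(3)); the spelled-out sums are illustrative and not part of the diff:

import numpy as np

gamma = 0.99
rewards = np.arange(3)  # [0, 1, 2]
# discounted return from each timestep t: sum over k of gamma**k * rewards[t + k]
expected = np.array([
    0 + gamma * 1 + gamma**2 * 2,  # t=0 -> 2.9502
    1 + gamma * 2,                 # t=1 -> 2.98
    2.0,                           # t=2
])
print(expected)  # approximately [2.9502, 2.98, 2.0]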
@@ -39,7 +39,7 @@ class RNNPolicyFunction(StochasticPolicy): def __init__(self, scope_name: str, action_dim: int): super(RNNPolicyFunction, self).__init__(scope_name) self._action_dim = action_dim - self._lstm_state_size = action_dim*2 + self._lstm_state_size = action_dim * 2 self._h = None self._c = None @@ -71,14 +71,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -88,8 +88,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -136,14 +136,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -153,8 +153,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -200,14 +200,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -217,8 +217,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -239,7 +239,7 @@ def 
test_algorithm_name(self): dummy_env = E.DummyContinuous() sac = A.ICML2018SAC(dummy_env) - assert sac.__name__ == 'ICML2018SAC' + assert sac.__name__ == "ICML2018SAC" def test_discrete_action_env_unsupported(self): """Check that error occurs when training on discrete action env.""" @@ -257,6 +257,7 @@ def test_run_online_training(self): def test_run_online_rnn_training(self): """Check that no error occurs when calling online training with RNN model.""" + class RNNPolicyBuilder(ModelBuilder[StochasticPolicy]): def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): return RNNPolicyFunction(scope_name, action_dim=env_info.action_dim) @@ -280,10 +281,13 @@ def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): config.num_steps = 2 config.start_timesteps = 7 config.batch_size = 2 - sac = A.ICML2018SAC(dummy_env, config=config, - policy_builder=RNNPolicyBuilder(), - q_function_builder=RNNQFunctionBuilder(), - v_function_builder=RNNVFunctionBuilder()) + sac = A.ICML2018SAC( + dummy_env, + config=config, + policy_builder=RNNPolicyBuilder(), + q_function_builder=RNNQFunctionBuilder(), + v_function_builder=RNNVFunctionBuilder(), + ) sac.train_online(dummy_env, total_iterations=10) @@ -314,8 +318,7 @@ def test_target_network_initialization(self): sac = A.ICML2018SAC(dummy_env) # Should be initialized to same parameters - assert self._has_same_parameters( - sac._v.get_parameters(), sac._target_v.get_parameters()) + assert self._has_same_parameters(sac._v.get_parameters(), sac._target_v.get_parameters()) def test_parameter_range(self): with pytest.raises(ValueError): @@ -348,23 +351,24 @@ def test_latest_iteration_state(self): dummy_env = E.DummyContinuous() sac = A.ICML2018SAC(dummy_env) - sac._q_function_trainer_state = {'q_loss': 0., 'td_errors': np.array([0., 1.])} - sac._policy_trainer_state = {'pi_loss': 2.} - sac._v_function_trainer_state = {'v_loss': 1.} + sac._q_function_trainer_state = {"q_loss": 0.0, "td_errors": np.array([0.0, 1.0])} + sac._policy_trainer_state = {"pi_loss": 2.0} + sac._v_function_trainer_state = {"v_loss": 1.0} latest_iteration_state = sac.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert 'pi_loss' in latest_iteration_state['scalar'] - assert 'v_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['q_loss'] == 0. - assert latest_iteration_state['scalar']['v_loss'] == 1. - assert latest_iteration_state['scalar']['pi_loss'] == 2. 
- assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "q_loss" in latest_iteration_state["scalar"] + assert "pi_loss" in latest_iteration_state["scalar"] + assert "v_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 + assert latest_iteration_state["scalar"]["v_loss"] == 1.0 + assert latest_iteration_state["scalar"]["pi_loss"] == 2.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from tests.testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_icra2018_qtopt.py b/tests/algorithms/test_icra2018_qtopt.py index ec9fef7e..0721e433 100644 --- a/tests/algorithms/test_icra2018_qtopt.py +++ b/tests/algorithms/test_icra2018_qtopt.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ class DummyQFunction(ContinuousQFunction): def __init__(self, action_high: np.ndarray, action_low: np.ndarray): - super(DummyQFunction, self).__init__('dummy') + super(DummyQFunction, self).__init__("dummy") self._random_sample_size = 16 self._action_high = action_high self._action_low = action_low @@ -37,28 +37,28 @@ def q(self, s: nn.Variable, a: nn.Variable) -> nn.Variable: batch_size = s.shape[0] with nn.parameter_scope(self.scope_name): - with nn.parameter_scope('state_conv1'): + with nn.parameter_scope("state_conv1"): h = NF.relu(NPF.convolution(s, 32, (3, 3), stride=(2, 2))) - with nn.parameter_scope('state_conv2'): + with nn.parameter_scope("state_conv2"): h = NF.relu(NPF.convolution(h, 32, (3, 3), stride=(2, 2))) - with nn.parameter_scope('state_conv3'): + with nn.parameter_scope("state_conv3"): encoded_state = NF.relu(NPF.convolution(h, 32, (3, 3), stride=(2, 2))) - with nn.parameter_scope('action_affine1'): + with nn.parameter_scope("action_affine1"): encoded_action = NF.relu(NPF.affine(a, 32)) encoded_action = NF.reshape(encoded_action, (batch_size, 32, 1, 1)) h = encoded_state + encoded_action - with nn.parameter_scope('affine1'): + with nn.parameter_scope("affine1"): h = NF.relu(NPF.affine(h, 32)) - with nn.parameter_scope('affine2'): + with nn.parameter_scope("affine2"): h = NF.relu(NPF.affine(h, 32)) - with nn.parameter_scope('affine3'): + with nn.parameter_scope("affine3"): q_value = NPF.affine(h, 1) return q_value @@ -73,7 +73,7 @@ def argmax_q(self, s: nn.Variable) -> nn.Variable: def objective_function(a): batch_size, sample_size, action_dim = a.shape - a = a.reshape((batch_size*sample_size, action_dim)) + a = a.reshape((batch_size * sample_size, action_dim)) q_value = self.q(tiled_s, a) q_value = q_value.reshape((batch_size, sample_size, 1)) return q_value @@ -81,18 +81,24 @@ def objective_function(a): upper_bound = np.tile(self._action_high, (batch_size, 1)) lower_bound = np.tile(self._action_low, (batch_size, 1)) optimized_action = RF.random_shooting_method( - objective_function, - upper_bound=upper_bound, - lower_bound=lower_bound, - sample_size=self._random_sample_size + objective_function, upper_bound=upper_bound, lower_bound=lower_bound, sample_size=self._random_sample_size ) return optimized_action def _tile_state(self, s, tile_size): - tile_reps = 
[tile_size, ] + [1, ] * len(s.shape) + tile_reps = [ + tile_size, + ] + [ + 1, + ] * len(s.shape) s = NF.tile(s, tile_reps) - transpose_reps = [1, 0, ] + list(range(len(s.shape)))[2:] + transpose_reps = [ + 1, + 0, + ] + list( + range(len(s.shape)) + )[2:] s = NF.transpose(s, transpose_reps) s = NF.reshape(s, (-1, *s.shape[2:])) return s @@ -111,7 +117,7 @@ def test_algorithm_name(self): dummy_env = E.DummyContinuousImg(image_shape=(3, 64, 64)) qtopt = A.ICRA2018QtOpt(dummy_env, q_func_builder=DummyQFunctionBuilder()) - assert qtopt.__name__ == 'ICRA2018QtOpt' + assert qtopt.__name__ == "ICRA2018QtOpt" def test_discrete_action_env_unsupported(self): """Check that error occurs when training on discrete action env.""" @@ -214,17 +220,18 @@ def test_latest_iteration_state(self): dummy_env = E.DummyContinuousImg() dqn = A.ICRA2018QtOpt(dummy_env, q_func_builder=DummyQFunctionBuilder()) - dqn._q_function_trainer_state = {'q_loss': 0., 'td_errors': np.array([0., 1.])} + dqn._q_function_trainer_state = {"q_loss": 0.0, "td_errors": np.array([0.0, 1.0])} latest_iteration_state = dqn.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['q_loss'] == 0. - assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "q_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_ilqr.py b/tests/algorithms/test_ilqr.py index 4e1120fb..edc962df 100644 --- a/tests/algorithms/test_ilqr.py +++ b/tests/algorithms/test_ilqr.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
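A note on the _tile_state hunk above: the unusually tall lists are black's "magic trailing comma" at work. When a collection already ends with a trailing comma, black keeps it exploded with one element per line, so [tile_size, ] and [1, 0, ] become multi-line literals instead of being joined back onto one line. The sketch below shows what the same statements would look like without the trailing commas (a stylistic alternative, not a change made in this diff):

tile_reps = [tile_size] + [1] * len(s.shape)              # fits on one line once the trailing comma is gone
transpose_reps = [1, 0] + list(range(len(s.shape)))[2:]   # likewise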
@@ -32,16 +32,17 @@ def __init__(self, dt=0.01): # input changes the velocity self._B = np.array([[0, 0], [0, dt]]) - def next_state(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, Dict[str, Any]]: + def next_state( + self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False + ) -> Tuple[np.ndarray, Dict[str, Any]]: return self._A.dot(x) + self._B.dot(u), {} - def gradient(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, np.ndarray]: + def gradient(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) -> Tuple[np.ndarray, np.ndarray]: return self._A, self._B - def hessian(self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False) \ - -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + def hessian( + self, x: np.ndarray, u: np.ndarray, t: int, batched: bool = False + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: raise NotImplementedError def state_dim(self) -> int: @@ -65,7 +66,7 @@ def evaluate( return x.T.dot(self._Q).dot(x) else: # Assuming that target state is zero - return x.T.dot(self._Q).dot(x) + 2.0*x.T.dot(self._F).dot(u) + u.T.dot(self._R).dot(u) + return x.T.dot(self._Q).dot(x) + 2.0 * x.T.dot(self._F).dot(u) + u.T.dot(self._R).dot(u) def gradient( self, x: np.ndarray, u: Optional[np.ndarray], t: int, final_state: bool = False, batched: bool = False @@ -86,18 +87,18 @@ def hessian( class TestiLQR(object): def test_algorithm_name(self): - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) dynamics = LinearDynamics() cost_function = QuadraticCostFunction() ilqr = A.iLQR(env, dynamics=dynamics, cost_function=cost_function) - assert ilqr.__name__ == 'iLQR' + assert ilqr.__name__ == "iLQR" def test_continuous_action_env_supported(self): """Check that no error occurs when training on continuous action env.""" - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) dynamics = LinearDynamics() cost_function = QuadraticCostFunction() @@ -113,7 +114,7 @@ def test_discrete_action_env_not_supported(self): A.iLQR(env, dynamics=dynamics, cost_function=cost_function) def test_compute_eval_action(self): - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) dynamics = LinearDynamics() cost_function = QuadraticCostFunction() @@ -132,7 +133,7 @@ def test_compute_eval_action(self): np.testing.assert_almost_equal(lqr_action, ilqr_action, decimal=3) def test_compute_trajectory(self): - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) dynamics = LinearDynamics() cost_function = QuadraticCostFunction() @@ -149,13 +150,13 @@ def test_compute_trajectory(self): lqr_trajectory, _ = lqr.compute_trajectory(initial_trajectory) ilqr_trajectory, _ = ilqr.compute_trajectory(initial_trajectory) - for (lqr_state, ilqr_state) in zip(lqr_trajectory[:-1], ilqr_trajectory[:-1]): + for lqr_state, ilqr_state in zip(lqr_trajectory[:-1], ilqr_trajectory[:-1]): np.testing.assert_almost_equal(lqr_state[0], ilqr_state[0], decimal=3) np.testing.assert_almost_equal(lqr_state[1], ilqr_state[1], decimal=3) def test_run_online_training(self): """Check that error occurs when calling online training.""" - env = E.DummyContinuous(observation_shape=(2, ), 
action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) dynamics = LinearDynamics() cost_function = QuadraticCostFunction() @@ -165,7 +166,7 @@ def test_run_online_training(self): def test_run_offline_training(self): """Check that error occurs when calling offline training.""" - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) dynamics = LinearDynamics() cost_function = QuadraticCostFunction() diff --git a/tests/algorithms/test_iqn.py b/tests/algorithms/test_iqn.py index a5d09e6f..063fe1fa 100644 --- a/tests/algorithms/test_iqn.py +++ b/tests/algorithms/test_iqn.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,12 +31,14 @@ class RNNStateActionQuantileFunction(IQNQuantileFunction): - def __init__(self, - scope_name: str, - n_action: int, - embedding_dim: int, - K: int, - risk_measure_function: Callable[[nn.Variable], nn.Variable]): + def __init__( + self, + scope_name: str, + n_action: int, + embedding_dim: int, + K: int, + risk_measure_function: Callable[[nn.Variable], nn.Variable], + ): super().__init__(scope_name, n_action, embedding_dim, K, risk_measure_function) self._h = None self._c = None @@ -62,7 +64,8 @@ def all_quantile_values(self, s: nn.Variable, tau: nn.Variable) -> nn.Variable: cell_state = RF.expand_dims(self._c, axis=1) cell_state = NF.broadcast(cell_state, shape=(self._c.shape[0], tau.shape[-1], self._c.shape[-1])) hidden_state, cell_state = RPF.lstm_cell( - h, hidden_state, cell_state, self._lstm_state_size, base_axis=2) + h, hidden_state, cell_state, self._lstm_state_size, base_axis=2 + ) h = hidden_state # Save only the state of first sample for the next timestep self._h, *_ = NF.split(hidden_state, axis=1) @@ -78,14 +81,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -95,8 +98,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -117,7 +120,7 @@ def test_algorithm_name(self): dummy_env = E.DummyDiscreteImg() iqn = A.IQN(dummy_env) - assert iqn.__name__ == 'IQN' + assert iqn.__name__ == "IQN" def test_continuous_action_env_unsupported(self): """Check that error occurs when training on continuous action env.""" @@ -154,14 +157,18 @@ def test_run_online_training_multistep(self): def test_run_online_rnn_training(self): """Check that no error occurs when calling 
online training with RNN model.""" + class RNNModelBuilder(ModelBuilder[StateActionQuantileFunction]): def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): - risk_measure_function = kwargs['risk_measure_function'] - return RNNStateActionQuantileFunction(scope_name, - env_info.action_dim, - algorithm_config.embedding_dim, - K=algorithm_config.K, - risk_measure_function=risk_measure_function) + risk_measure_function = kwargs["risk_measure_function"] + return RNNStateActionQuantileFunction( + scope_name, + env_info.action_dim, + algorithm_config.embedding_dim, + K=algorithm_config.K, + risk_measure_function=risk_measure_function, + ) + dummy_env = E.DummyDiscreteImg() config = A.IQNConfig() config.num_steps = 2 @@ -198,7 +205,7 @@ def test_compute_eval_action(self): state = np.float32(state) action = iqn.compute_eval_action(state) - assert action.shape == (1, ) + assert action.shape == (1,) def test_parameter_range(self): with pytest.raises(ValueError): @@ -239,15 +246,16 @@ def test_latest_iteration_state(self): dummy_env = E.DummyDiscreteImg() iqn = A.IQN(dummy_env) - iqn._quantile_function_trainer_state = {'q_loss': 0.} + iqn._quantile_function_trainer_state = {"q_loss": 0.0} latest_iteration_state = iqn.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert latest_iteration_state['scalar']['q_loss'] == 0. + assert "q_loss" in latest_iteration_state["scalar"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_lqr.py b/tests/algorithms/test_lqr.py index af631269..79ea245b 100644 --- a/tests/algorithms/test_lqr.py +++ b/tests/algorithms/test_lqr.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -74,18 +74,18 @@ def hessian( class TestLQR(object): def test_algorithm_name(self): - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) dynamics = LinearDynamics() cost_function = QuadraticCostFunction() lqr = A.LQR(env, dynamics=dynamics, cost_function=cost_function) - assert lqr.__name__ == 'LQR' + assert lqr.__name__ == "LQR" def test_continuous_action_env_supported(self): """Check that no error occurs when training on continuous action env.""" - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) dynamics = LinearDynamics() cost_function = QuadraticCostFunction() @@ -101,7 +101,7 @@ def test_discrete_action_env_not_supported(self): A.LQR(env, dynamics=dynamics, cost_function=cost_function) def test_compute_eval_action(self): - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) dynamics = LinearDynamics() cost_function = QuadraticCostFunction() @@ -116,7 +116,7 @@ def test_compute_eval_action(self): assert lqr_action.shape == (*env.action_space.shape, 1) def test_compute_trajectory(self): - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) dynamics = LinearDynamics() cost_function = QuadraticCostFunction() @@ -136,7 +136,7 @@ def test_compute_trajectory(self): def test_run_online_training(self): """Check that error occurs when calling online training.""" - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) dynamics = LinearDynamics() cost_function = QuadraticCostFunction() @@ -146,7 +146,7 @@ def test_run_online_training(self): def test_run_offline_training(self): """Check that error occurs when calling offline training.""" - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) dynamics = LinearDynamics() cost_function = QuadraticCostFunction() diff --git a/tests/algorithms/test_mme_sac.py b/tests/algorithms/test_mme_sac.py index 9dbffc33..24f3ffd3 100644 --- a/tests/algorithms/test_mme_sac.py +++ b/tests/algorithms/test_mme_sac.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -38,7 +38,7 @@ class RNNPolicyFunction(StochasticPolicy): def __init__(self, scope_name: str, action_dim: int): super(RNNPolicyFunction, self).__init__(scope_name) self._action_dim = action_dim - self._lstm_state_size = action_dim*2 + self._lstm_state_size = action_dim * 2 self._h = None self._c = None @@ -70,14 +70,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -87,8 +87,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -135,14 +135,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -152,8 +152,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -199,14 +199,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -216,8 +216,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -238,7 +238,7 @@ def 
test_algorithm_name(self): dummy_env = E.DummyContinuous() sac = A.MMESAC(dummy_env) - assert sac.__name__ == 'MMESAC' + assert sac.__name__ == "MMESAC" def test_discrete_action_env_unsupported(self): """Check that error occurs when training on discrete action env.""" @@ -256,6 +256,7 @@ def test_run_online_training(self): def test_run_online_rnn_training(self): """Check that no error occurs when calling online training with RNN model.""" + class RNNPolicyBuilder(ModelBuilder[StochasticPolicy]): def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): return RNNPolicyFunction(scope_name, action_dim=env_info.action_dim) @@ -277,10 +278,13 @@ def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): config.critic_burn_in_steps = 2 config.start_timesteps = 7 config.batch_size = 2 - sac = A.MMESAC(dummy_env, config=config, - policy_builder=RNNPolicyBuilder(), - q_function_builder=RNNQFunctionBuilder(), - v_function_builder=RNNVFunctionBuilder()) + sac = A.MMESAC( + dummy_env, + config=config, + policy_builder=RNNPolicyBuilder(), + q_function_builder=RNNQFunctionBuilder(), + v_function_builder=RNNVFunctionBuilder(), + ) sac.train_online(dummy_env, total_iterations=10) @@ -311,8 +315,7 @@ def test_target_network_initialization(self): sac = A.MMESAC(dummy_env) # Should be initialized to same parameters - assert self._has_same_parameters( - sac._v.get_parameters(), sac._target_v.get_parameters()) + assert self._has_same_parameters(sac._v.get_parameters(), sac._target_v.get_parameters()) def test_parameter_range(self): with pytest.raises(ValueError): @@ -345,23 +348,24 @@ def test_latest_iteration_state(self): dummy_env = E.DummyContinuous() sac = A.MMESAC(dummy_env) - sac._q_function_trainer_state = {'q_loss': 0., 'td_errors': np.array([0., 1.])} - sac._policy_trainer_state = {'pi_loss': 2.} - sac._v_function_trainer_state = {'v_loss': 1.} + sac._q_function_trainer_state = {"q_loss": 0.0, "td_errors": np.array([0.0, 1.0])} + sac._policy_trainer_state = {"pi_loss": 2.0} + sac._v_function_trainer_state = {"v_loss": 1.0} latest_iteration_state = sac.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert 'pi_loss' in latest_iteration_state['scalar'] - assert 'v_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['q_loss'] == 0. - assert latest_iteration_state['scalar']['v_loss'] == 1. - assert latest_iteration_state['scalar']['pi_loss'] == 2. - assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "q_loss" in latest_iteration_state["scalar"] + assert "pi_loss" in latest_iteration_state["scalar"] + assert "v_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 + assert latest_iteration_state["scalar"]["v_loss"] == 1.0 + assert latest_iteration_state["scalar"]["pi_loss"] == 2.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from tests.testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_mppi.py b/tests/algorithms/test_mppi.py index 171775b4..bcfe27f5 100644 --- a/tests/algorithms/test_mppi.py +++ b/tests/algorithms/test_mppi.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. 
+# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -73,17 +73,17 @@ def hessian( class TestMPPI(object): def test_algorithm_name(self): - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) cost_function = QuadraticCostFunction() mppi = A.MPPI(env, cost_function=cost_function) - assert mppi.__name__ == 'MPPI' + assert mppi.__name__ == "MPPI" def test_continuous_action_env_supported(self): """Check that no error occurs when training on continuous action env.""" - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) cost_function = QuadraticCostFunction() A.MPPI(env, cost_function=cost_function) @@ -97,7 +97,7 @@ def test_discrete_action_env_not_supported(self): A.MPPI(env, cost_function=cost_function) def test_compute_eval_action(self): - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) cost_function = QuadraticCostFunction() T = 20 @@ -108,20 +108,22 @@ def test_compute_eval_action(self): x0 = np.array([[2.0], [0.0]]) mppi_action = mppi.compute_eval_action(x0) - assert mppi_action.shape == (*env.action_space.shape, ) + assert mppi_action.shape == (*env.action_space.shape,) def test_compute_trajectory(self): - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, )) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,)) dynamics = LinearDynamics() cost_function = QuadraticCostFunction() T = 100 covariance = np.eye(N=2) * 0.3 - config = A.MPPIConfig(T=T, use_known_dynamics=True, dt=0.2, covariance=covariance,) - mppi = A.MPPI(env, - known_dynamics=dynamics, - cost_function=cost_function, - config=config) + config = A.MPPIConfig( + T=T, + use_known_dynamics=True, + dt=0.2, + covariance=covariance, + ) + mppi = A.MPPI(env, known_dynamics=dynamics, cost_function=cost_function, config=config) # initial pose x0 = np.array([[2.5], [0.0]]) @@ -140,7 +142,7 @@ def test_compute_trajectory(self): assert np.abs(pos) < 0.5 def test_run_online_training(self): - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, ), max_episode_steps=5) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,), max_episode_steps=5) cost_function = QuadraticCostFunction() config = A.MPPIConfig(batch_size=2, training_iterations=2) @@ -149,7 +151,7 @@ def test_run_online_training(self): def test_run_offline_training(self): """Check that error occurs when calling offline training.""" - env = E.DummyContinuous(observation_shape=(2, ), action_shape=(2, ), max_episode_steps=5) + env = E.DummyContinuous(observation_shape=(2,), action_shape=(2,), max_episode_steps=5) cost_function = QuadraticCostFunction() with pytest.raises(Exception): diff --git a/tests/algorithms/test_munchausen_dqn.py b/tests/algorithms/test_munchausen_dqn.py index 611a6c26..508e370d 100644 --- a/tests/algorithms/test_munchausen_dqn.py +++ b/tests/algorithms/test_munchausen_dqn.py @@ -1,5 +1,5 @@ # Copyright 2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -30,7 +30,7 @@ def test_algorithm_name(self): dummy_env = E.DummyDiscreteImg() dqn = A.MunchausenDQN(dummy_env) - assert dqn.__name__ == 'MunchausenDQN' + assert dqn.__name__ == "MunchausenDQN" def test_continuous_action_env_unsupported(self): """Check that error occurs when training on continuous action env.""" @@ -121,17 +121,18 @@ def test_latest_iteration_state(self): dummy_env = E.DummyDiscreteImg() m_dqn = A.MunchausenDQN(dummy_env) - m_dqn._q_function_trainer_state = {'q_loss': 0., 'td_errors': np.array([0., 1.])} + m_dqn._q_function_trainer_state = {"q_loss": 0.0, "td_errors": np.array([0.0, 1.0])} latest_iteration_state = m_dqn.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['q_loss'] == 0. - assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "q_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_munchausen_iqn.py b/tests/algorithms/test_munchausen_iqn.py index f2fb63ef..e7768796 100644 --- a/tests/algorithms/test_munchausen_iqn.py +++ b/tests/algorithms/test_munchausen_iqn.py @@ -1,5 +1,5 @@ # Copyright 2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,7 +30,7 @@ def test_algorithm_name(self): dummy_env = E.DummyDiscreteImg() m_iqn = A.MunchausenIQN(dummy_env) - assert m_iqn.__name__ == 'MunchausenIQN' + assert m_iqn.__name__ == "MunchausenIQN" def test_continuous_action_env_unsupported(self): """Check that error occurs when training on continuous action env.""" @@ -87,7 +87,7 @@ def test_compute_eval_action(self): state = np.float32(state) action = m_iqn.compute_eval_action(state) - assert action.shape == (1, ) + assert action.shape == (1,) def test_parameter_range(self): with pytest.raises(ValueError): @@ -130,15 +130,16 @@ def test_latest_iteration_state(self): dummy_env = E.DummyDiscreteImg() m_iqn = A.MunchausenIQN(dummy_env) - m_iqn._quantile_function_trainer_state = {'q_loss': 0.} + m_iqn._quantile_function_trainer_state = {"q_loss": 0.0} latest_iteration_state = m_iqn.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert latest_iteration_state['scalar']['q_loss'] == 0. 
+ assert "q_loss" in latest_iteration_state["scalar"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_ppo.py b/tests/algorithms/test_ppo.py index 1ca82548..e88b9f27 100644 --- a/tests/algorithms/test_ppo.py +++ b/tests/algorithms/test_ppo.py @@ -36,8 +36,10 @@ def __init__(self, scope_name: str, action_dim: int): def pi(self, s: nn.Variable): s, *_ = s - return Gaussian(mean=nn.Variable.from_numpy_array(np.zeros(shape=(s.shape[0], self._action_dim))), - ln_var=nn.Variable.from_numpy_array(np.zeros(shape=(s.shape[0], self._action_dim)))) + return Gaussian( + mean=nn.Variable.from_numpy_array(np.zeros(shape=(s.shape[0], self._action_dim))), + ln_var=nn.Variable.from_numpy_array(np.zeros(shape=(s.shape[0], self._action_dim))), + ) class TupleStateActorBuilder(ModelBuilder[VFunction]): @@ -68,7 +70,7 @@ def test_algorithm_name(self): dummy_env = E.DummyDiscreteImg() ppo = A.PPO(dummy_env) - assert ppo.__name__ == 'PPO' + assert ppo.__name__ == "PPO" def test_run_online_discrete_env_training(self): """Check that no error occurs when calling online training (discrete @@ -80,7 +82,7 @@ def test_run_online_discrete_env_training(self): config = A.PPOConfig(batch_size=5, actor_timesteps=actor_timesteps, actor_num=actor_num) ppo = A.PPO(dummy_env, config=config) - ppo.train_online(dummy_env, total_iterations=actor_timesteps*actor_num) + ppo.train_online(dummy_env, total_iterations=actor_timesteps * actor_num) def test_run_online_continuous_env_training(self): """Check that no error occurs when calling online training (continuous @@ -102,11 +104,14 @@ def test_run_online_tuple_state_env_training(self): actor_timesteps = 10 actor_num = 2 config = A.PPOConfig(batch_size=5, actor_timesteps=actor_timesteps, actor_num=actor_num, preprocess_state=False) - ppo = A.PPO(dummy_env, config=config, - v_function_builder=TupleStateVFunctionBuilder(), - policy_builder=TupleStateActorBuilder()) + ppo = A.PPO( + dummy_env, + config=config, + v_function_builder=TupleStateVFunctionBuilder(), + policy_builder=TupleStateActorBuilder(), + ) - ppo.train_online(dummy_env, total_iterations=actor_timesteps*actor_num) + ppo.train_online(dummy_env, total_iterations=actor_timesteps * actor_num) def test_run_online_discrete_single_actor(self): """Check that no error occurs when calling online training (discrete @@ -118,7 +123,7 @@ def test_run_online_discrete_single_actor(self): config = A.PPOConfig(batch_size=5, actor_timesteps=actor_timesteps, actor_num=actor_num) ppo = A.PPO(dummy_env, config=config) - ppo.train_online(dummy_env, total_iterations=actor_timesteps*actor_num) + ppo.train_online(dummy_env, total_iterations=actor_timesteps * actor_num) def test_run_online_continuous_single_actor(self): """Check that no error occurs when calling online training (continuous @@ -161,14 +166,14 @@ def test_latest_iteration_state(self): dummy_env = E.DummyContinuous() ppo = A.PPO(dummy_env) - ppo._policy_trainer_state = {'pi_loss': 0.} - ppo._v_function_trainer_state = {'v_loss': 1.} + ppo._policy_trainer_state = {"pi_loss": 0.0} + ppo._v_function_trainer_state = {"v_loss": 1.0} latest_iteration_state = ppo.latest_iteration_state - assert 'pi_loss' in latest_iteration_state['scalar'] - assert 'v_loss' in latest_iteration_state['scalar'] - assert latest_iteration_state['scalar']['pi_loss'] == 0. 
- assert latest_iteration_state['scalar']['v_loss'] == 1. + assert "pi_loss" in latest_iteration_state["scalar"] + assert "v_loss" in latest_iteration_state["scalar"] + assert latest_iteration_state["scalar"]["pi_loss"] == 0.0 + assert latest_iteration_state["scalar"]["v_loss"] == 1.0 def test_copy_np_array_to_mp_array(self): shape = (10, 9, 8, 7) diff --git a/tests/algorithms/test_qrdqn.py b/tests/algorithms/test_qrdqn.py index dfecbf1b..03060c69 100644 --- a/tests/algorithms/test_qrdqn.py +++ b/tests/algorithms/test_qrdqn.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -59,8 +59,7 @@ def all_quantiles(self, s: nn.Variable) -> nn.Variable: h = NF.relu(x=h) with nn.parameter_scope("affine2"): h = NPF.affine(h, n_outmaps=self._n_action * self._n_quantile) - quantiles = NF.reshape( - h, (-1, self._n_action, self._n_quantile)) + quantiles = NF.reshape(h, (-1, self._n_action, self._n_quantile)) assert quantiles.shape == (batch_size, self._n_action, self._n_quantile) return quantiles @@ -69,14 +68,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -86,8 +85,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -108,7 +107,7 @@ def test_algorithm_name(self): dummy_env = E.DummyDiscreteImg() qrdqn = A.QRDQN(dummy_env) - assert qrdqn.__name__ == 'QRDQN' + assert qrdqn.__name__ == "QRDQN" def test_continuous_action_env_unsupported(self): """Check that error occurs when training on continuous action env.""" @@ -145,11 +144,13 @@ def test_run_online_training_multistep(self): def test_run_online_rnn_training(self): """Check that no error occurs when calling online training with RNN model.""" + class RNNModelBuilder(ModelBuilder[QuantileDistributionFunction]): def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): n_action = env_info.action_dim n_quantile = algorithm_config.num_quantiles return RNNQuantileDistributionFunction(scope_name, n_action, n_quantile) + dummy_env = E.DummyDiscreteImg() config = A.QRDQNConfig() config.num_steps = 2 @@ -187,7 +188,7 @@ def test_compute_eval_action(self): state = np.float32(state) action = qrdqn.compute_eval_action(state) - assert action.shape == (1, ) + assert action.shape == (1,) def test_parameter_range(self): with pytest.raises(ValueError): @@ -222,15 +223,16 @@ def test_latest_iteration_state(self): dummy_env = E.DummyDiscreteImg() qrdqn = 
A.QRDQN(dummy_env) - qrdqn._quantile_dist_trainer_state = {'q_loss': 0.} + qrdqn._quantile_dist_trainer_state = {"q_loss": 0.0} latest_iteration_state = qrdqn.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert latest_iteration_state['scalar']['q_loss'] == 0. + assert "q_loss" in latest_iteration_state["scalar"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_qrsac.py b/tests/algorithms/test_qrsac.py index ea41a980..b0123527 100644 --- a/tests/algorithms/test_qrsac.py +++ b/tests/algorithms/test_qrsac.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ class RNNActorFunction(StochasticPolicy): def __init__(self, scope_name: str, action_dim: int): super(RNNActorFunction, self).__init__(scope_name) self._action_dim = action_dim - self._lstm_state_size = action_dim*2 + self._lstm_state_size = action_dim * 2 self._h = None self._c = None @@ -70,14 +70,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -87,8 +87,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -135,14 +135,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -152,8 +152,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -174,7 +174,7 @@ def test_algorithm_name(self): dummy_env = E.DummyContinuous() qrsac = 
A.QRSAC(dummy_env) - assert qrsac.__name__ == 'QRSAC' + assert qrsac.__name__ == "QRSAC" def test_discrete_action_env_unsupported(self): """Check that error occurs when training on discrete action env.""" @@ -204,6 +204,7 @@ def test_run_offline_training(self): def test_run_online_rnn_training(self): """Check that no error occurs when calling online training with RNN model.""" + class RNNActorBuilder(ModelBuilder[StochasticPolicy]): def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): return RNNActorFunction(scope_name, action_dim=env_info.action_dim) @@ -220,8 +221,9 @@ def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): config.critic_burn_in_steps = 2 config.start_timesteps = 7 config.batch_size = 2 - qrsac = A.QRSAC(dummy_env, config=config, quantile_function_builder=RNNCriticBuilder(), - policy_builder=RNNActorBuilder()) + qrsac = A.QRSAC( + dummy_env, config=config, quantile_function_builder=RNNCriticBuilder(), policy_builder=RNNActorBuilder() + ) qrsac.train_online(dummy_env, total_iterations=10) @@ -264,18 +266,19 @@ def test_latest_iteration_state(self): dummy_env = E.DummyContinuous() qrsac = A.QRSAC(dummy_env) - qrsac._quantile_function_trainer_state = {'q_loss': 0.} - qrsac._policy_trainer_state = {'pi_loss': 1.} + qrsac._quantile_function_trainer_state = {"q_loss": 0.0} + qrsac._policy_trainer_state = {"pi_loss": 1.0} latest_iteration_state = qrsac.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert 'pi_loss' in latest_iteration_state['scalar'] - assert latest_iteration_state['scalar']['q_loss'] == 0. - assert latest_iteration_state['scalar']['pi_loss'] == 1. + assert "q_loss" in latest_iteration_state["scalar"] + assert "pi_loss" in latest_iteration_state["scalar"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 + assert latest_iteration_state["scalar"]["pi_loss"] == 1.0 if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_rainbow.py b/tests/algorithms/test_rainbow.py index 68b0ecfa..70e3587a 100644 --- a/tests/algorithms/test_rainbow.py +++ b/tests/algorithms/test_rainbow.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -31,7 +31,7 @@ def test_algorithm_name(self): config.replay_buffer_size = 10 rainbow = A.Rainbow(dummy_env, config=config) - assert rainbow.__name__ == 'Rainbow' + assert rainbow.__name__ == "Rainbow" def test_continuous_action_env_unsupported(self): """Check that error occurs when training on continuous action env.""" @@ -79,7 +79,7 @@ def test_compute_eval_action(self): state = np.float32(state) action = rainbow.compute_eval_action(state) - assert action.shape == (1, ) + assert action.shape == (1,) def test_latest_iteration_state(self): """Check that latest iteration state has the keys and values we @@ -90,17 +90,18 @@ def test_latest_iteration_state(self): config.replay_buffer_size = 3 rainbow = A.Rainbow(dummy_env, config=config) - rainbow._model_trainer_state = {'cross_entropy_loss': 0., 'td_errors': np.array([0., 1.])} + rainbow._model_trainer_state = {"cross_entropy_loss": 0.0, "td_errors": np.array([0.0, 1.0])} latest_iteration_state = rainbow.latest_iteration_state - assert 'cross_entropy_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['cross_entropy_loss'] == 0. - assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "cross_entropy_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["cross_entropy_loss"] == 0.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_redq.py b/tests/algorithms/test_redq.py index 736900c4..c7c31dee 100644 --- a/tests/algorithms/test_redq.py +++ b/tests/algorithms/test_redq.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -38,7 +38,7 @@ class RNNActorFunction(StochasticPolicy): def __init__(self, scope_name: str, action_dim: int): super(RNNActorFunction, self).__init__(scope_name) self._action_dim = action_dim - self._lstm_state_size = action_dim*2 + self._lstm_state_size = action_dim * 2 self._h = None self._c = None @@ -70,14 +70,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -87,8 +87,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -135,14 +135,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -152,8 +152,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -174,7 +174,7 @@ def test_algorithm_name(self): dummy_env = E.DummyContinuous() redq = A.REDQ(dummy_env) - assert redq.__name__ == 'REDQ' + assert redq.__name__ == "REDQ" def test_discrete_action_env_unsupported(self): """Check that error occurs when training on discrete action env.""" @@ -205,6 +205,7 @@ def test_run_offline_training(self): def test_run_online_rnn_training(self): """Check that no error occurs when calling online training with RNN model.""" + class RNNActorBuilder(ModelBuilder[StochasticPolicy]): def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): return RNNActorFunction(scope_name, action_dim=env_info.action_dim) @@ -266,20 +267,21 @@ def test_latest_iteration_state(self): dummy_env = E.DummyContinuous() redq = A.REDQ(dummy_env) - redq._q_function_trainer_state = {'q_loss': 0., 'td_errors': np.array([0., 1.])} - redq._policy_trainer_state = {'pi_loss': 1.} + redq._q_function_trainer_state = {"q_loss": 0.0, "td_errors": np.array([0.0, 1.0])} + redq._policy_trainer_state = {"pi_loss": 1.0} latest_iteration_state = 
redq.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert 'pi_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['q_loss'] == 0. - assert latest_iteration_state['scalar']['pi_loss'] == 1. - assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "q_loss" in latest_iteration_state["scalar"] + assert "pi_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 + assert latest_iteration_state["scalar"]["pi_loss"] == 1.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_reinforce.py b/tests/algorithms/test_reinforce.py index b9589c46..6f516624 100644 --- a/tests/algorithms/test_reinforce.py +++ b/tests/algorithms/test_reinforce.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ def test_algorithm_name(self): dummy_env = E.DummyDiscrete() reinforce = A.REINFORCE(dummy_env) - assert reinforce.__name__ == 'REINFORCE' + assert reinforce.__name__ == "REINFORCE" def test_run_online_training(self): """Check that no error occurs when calling online training.""" @@ -62,15 +62,16 @@ def test_latest_iteration_state(self): dummy_env = E.DummyContinuous() reinforce = A.REINFORCE(dummy_env) - reinforce._policy_trainer_state = {'pi_loss': 0.} + reinforce._policy_trainer_state = {"pi_loss": 0.0} latest_iteration_state = reinforce.latest_iteration_state - assert 'pi_loss' in latest_iteration_state['scalar'] - assert latest_iteration_state['scalar']['pi_loss'] == 0. + assert "pi_loss" in latest_iteration_state["scalar"] + assert latest_iteration_state["scalar"]["pi_loss"] == 0.0 if __name__ == "__main__": from testing_utils import EpisodicEnv + pytest.main() else: from ..testing_utils import EpisodicEnv diff --git a/tests/algorithms/test_sac.py b/tests/algorithms/test_sac.py index 61c1bc2f..d440974e 100644 --- a/tests/algorithms/test_sac.py +++ b/tests/algorithms/test_sac.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -39,7 +39,7 @@ class RNNActorFunction(StochasticPolicy): def __init__(self, scope_name: str, action_dim: int): super(RNNActorFunction, self).__init__(scope_name) self._action_dim = action_dim - self._lstm_state_size = action_dim*2 + self._lstm_state_size = action_dim * 2 self._h = None self._c = None @@ -71,14 +71,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -88,8 +88,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -136,14 +136,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -153,8 +153,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -175,7 +175,7 @@ def test_algorithm_name(self): dummy_env = E.DummyContinuous() sac = A.SAC(dummy_env) - assert sac.__name__ == 'SAC' + assert sac.__name__ == "SAC" def test_discrete_action_env_unsupported(self): """Check that error occurs when training on discrete action env.""" @@ -205,6 +205,7 @@ def test_run_offline_training(self): def test_run_online_rnn_training(self): """Check that no error occurs when calling online training with RNN model.""" + class RNNActorBuilder(ModelBuilder[StochasticPolicy]): def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): return RNNActorFunction(scope_name, action_dim=env_info.action_dim) @@ -260,20 +261,21 @@ def test_latest_iteration_state(self): dummy_env = E.DummyContinuous() sac = A.SAC(dummy_env) - sac._q_function_trainer_state = {'q_loss': 0., 'td_errors': np.array([0., 1.])} - sac._policy_trainer_state = {'pi_loss': 1.} + sac._q_function_trainer_state = {"q_loss": 0.0, "td_errors": np.array([0.0, 1.0])} + sac._policy_trainer_state = {"pi_loss": 1.0} latest_iteration_state = sac.latest_iteration_state - assert 
'q_loss' in latest_iteration_state['scalar'] - assert 'pi_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['q_loss'] == 0. - assert latest_iteration_state['scalar']['pi_loss'] == 1. - assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "q_loss" in latest_iteration_state["scalar"] + assert "pi_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 + assert latest_iteration_state["scalar"]["pi_loss"] == 1.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_sacd.py b/tests/algorithms/test_sacd.py index 929a0ae4..f3aa3b86 100644 --- a/tests/algorithms/test_sacd.py +++ b/tests/algorithms/test_sacd.py @@ -1,4 +1,4 @@ -# Copyright 2022,2023 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ class RNNActorFunction(StochasticPolicy): def __init__(self, scope_name: str, action_dim: int): super(RNNActorFunction, self).__init__(scope_name) self._action_dim = action_dim - self._lstm_state_size = action_dim*2 + self._lstm_state_size = action_dim * 2 self._h = None self._c = None @@ -70,14 +70,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -87,8 +87,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -138,14 +138,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -155,8 +155,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - 
self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -181,7 +181,7 @@ def test_algorithm_name(self): dummy_env = E.DummyFactoredContinuous(reward_dimension=2) sacd = A.SACD(dummy_env) - assert sacd.__name__ == 'SACD' + assert sacd.__name__ == "SACD" def test_discrete_action_env_unsupported(self): """Check that error occurs when training on discrete action env.""" @@ -211,6 +211,7 @@ def test_run_offline_training(self): def test_run_online_rnn_training(self): """Check that no error occurs when calling online training with RNN model.""" + class RNNActorBuilder(ModelBuilder[StochasticPolicy]): def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): return RNNActorFunction(scope_name, action_dim=env_info.action_dim) @@ -266,20 +267,21 @@ def test_latest_iteration_state(self): dummy_env = E.DummyFactoredContinuous(reward_dimension=2) sacd = A.SACD(dummy_env, config=A.SACDConfig(reward_dimension=2)) - sacd._q_function_trainer_state = {'q_loss': 0., 'td_errors': np.array([0., 1.])} - sacd._policy_trainer_state = {'pi_loss': 1.} + sacd._q_function_trainer_state = {"q_loss": 0.0, "td_errors": np.array([0.0, 1.0])} + sacd._policy_trainer_state = {"pi_loss": 1.0} latest_iteration_state = sacd.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert 'pi_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['q_loss'] == 0. - assert latest_iteration_state['scalar']['pi_loss'] == 1. - assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "q_loss" in latest_iteration_state["scalar"] + assert "pi_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 + assert latest_iteration_state["scalar"]["pi_loss"] == 1.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_srsac.py b/tests/algorithms/test_srsac.py index f812d30b..4a6abbd0 100644 --- a/tests/algorithms/test_srsac.py +++ b/tests/algorithms/test_srsac.py @@ -38,7 +38,7 @@ class RNNActorFunction(StochasticPolicy): def __init__(self, scope_name: str, action_dim: int): super(RNNActorFunction, self).__init__(scope_name) self._action_dim = action_dim - self._lstm_state_size = action_dim*2 + self._lstm_state_size = action_dim * 2 self._h = None self._c = None @@ -70,14 +70,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -87,8 +87,8 @@ def 
set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -135,14 +135,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -152,8 +152,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -174,7 +174,7 @@ def test_algorithm_name(self): dummy_env = E.DummyContinuous() srsac = A.SRSAC(dummy_env) - assert srsac.__name__ == 'SRSAC' + assert srsac.__name__ == "SRSAC" def test_discrete_action_env_unsupported(self): """Check that error occurs when training on discrete action env.""" @@ -204,6 +204,7 @@ def test_run_offline_training(self): def test_run_online_rnn_training(self): """Check that no error occurs when calling online training with RNN model.""" + class RNNActorBuilder(ModelBuilder[StochasticPolicy]): def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): return RNNActorFunction(scope_name, action_dim=env_info.action_dim) @@ -220,8 +221,9 @@ def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): config.critic_burn_in_steps = 2 config.start_timesteps = 7 config.batch_size = 2 - srsac = A.SRSAC(dummy_env, config=config, q_function_builder=RNNCriticBuilder(), - policy_builder=RNNActorBuilder()) + srsac = A.SRSAC( + dummy_env, config=config, q_function_builder=RNNCriticBuilder(), policy_builder=RNNActorBuilder() + ) srsac.train_online(dummy_env, total_iterations=10) @@ -264,19 +266,19 @@ def test_latest_iteration_state(self): dummy_env = E.DummyContinuous() srsac = A.SRSAC(dummy_env) - srsac._q_function_trainer_state = {'q_loss': 0., 'td_errors': np.array([0., 1.])} - srsac._policy_trainer_state = {'pi_loss': 1.} + srsac._q_function_trainer_state = {"q_loss": 0.0, "td_errors": np.array([0.0, 1.0])} + srsac._policy_trainer_state = {"pi_loss": 1.0} latest_iteration_state = srsac.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert 'pi_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['q_loss'] == 0. - assert latest_iteration_state['scalar']['pi_loss'] == 1. 
- assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "q_loss" in latest_iteration_state["scalar"] + assert "pi_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 + assert latest_iteration_state["scalar"]["pi_loss"] == 1.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) -class TestEfficientSRSAC(): +class TestEfficientSRSAC: def setup_method(self, method): nn.clear_parameters() @@ -284,7 +286,7 @@ def test_algorithm_name(self): dummy_env = E.DummyContinuous() srsac = A.EfficientSRSAC(dummy_env) - assert srsac.__name__ == 'EfficientSRSAC' + assert srsac.__name__ == "EfficientSRSAC" def test_discrete_action_env_unsupported(self): """Check that error occurs when training on discrete action env.""" @@ -314,6 +316,7 @@ def test_run_offline_training(self): def test_run_online_rnn_training(self): """Check that no error occurs when calling online training with RNN model.""" + class RNNActorBuilder(ModelBuilder[StochasticPolicy]): def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): return RNNActorFunction(scope_name, action_dim=env_info.action_dim) @@ -333,8 +336,9 @@ def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): config.batch_size = 2 with pytest.raises(RuntimeError): - srsac = A.EfficientSRSAC(dummy_env, config=config, q_function_builder=RNNCriticBuilder(), - policy_builder=RNNActorBuilder()) + srsac = A.EfficientSRSAC( + dummy_env, config=config, q_function_builder=RNNCriticBuilder(), policy_builder=RNNActorBuilder() + ) srsac.train_online(dummy_env, total_iterations=10) def test_compute_eval_action(self): @@ -372,19 +376,20 @@ def test_latest_iteration_state(self): dummy_env = E.DummyContinuous() srsac = A.EfficientSRSAC(dummy_env) - srsac._actor_critic_trainer_state = {'q_loss': 0., 'td_errors': np.array([0., 1.]), 'pi_loss': 1.} + srsac._actor_critic_trainer_state = {"q_loss": 0.0, "td_errors": np.array([0.0, 1.0]), "pi_loss": 1.0} latest_iteration_state = srsac.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert 'pi_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['q_loss'] == 0. - assert latest_iteration_state['scalar']['pi_loss'] == 1. - assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "q_loss" in latest_iteration_state["scalar"] + assert "pi_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 + assert latest_iteration_state["scalar"]["pi_loss"] == 1.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_td3.py b/tests/algorithms/test_td3.py index 078490e5..e6d5f931 100644 --- a/tests/algorithms/test_td3.py +++ b/tests/algorithms/test_td3.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -64,14 +64,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -81,8 +81,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -129,14 +129,14 @@ def is_recurrent(self) -> bool: def internal_state_shapes(self) -> Dict[str, Tuple[int, ...]]: shapes: Dict[str, nn.Variable] = {} - shapes['lstm_hidden'] = (self._lstm_state_size, ) - shapes['lstm_cell'] = (self._lstm_state_size, ) + shapes["lstm_hidden"] = (self._lstm_state_size,) + shapes["lstm_cell"] = (self._lstm_state_size,) return shapes def get_internal_states(self) -> Dict[str, nn.Variable]: states: Dict[str, nn.Variable] = {} - states['lstm_hidden'] = self._h - states['lstm_cell'] = self._c + states["lstm_hidden"] = self._h + states["lstm_cell"] = self._c return states def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): @@ -146,8 +146,8 @@ def set_internal_states(self, states: Optional[Dict[str, nn.Variable]] = None): if self._c is not None: self._c.data.zero() else: - self._h = states['lstm_hidden'] - self._c = states['lstm_cell'] + self._h = states["lstm_hidden"] + self._c = states["lstm_cell"] def _create_internal_states(self, batch_size): self._h = nn.Variable((batch_size, self._lstm_state_size)) @@ -168,7 +168,7 @@ def test_algorithm_name(self): dummy_env = E.DummyContinuous() td3 = A.TD3(dummy_env) - assert td3.__name__ == 'TD3' + assert td3.__name__ == "TD3" def test_discrete_action_env_unsupported(self): """Check that error occurs when training on discrete action env.""" @@ -188,6 +188,7 @@ def test_run_online_training(self): def test_run_online_rnn_training(self): """Check that no error occurs when calling online training with RNN model.""" + class RNNActorBuilder(ModelBuilder[DeterministicPolicy]): def build_model(self, scope_name: str, env_info, algorithm_config, **kwargs): return RNNActorFunction(scope_name, action_dim=env_info.action_dim) @@ -263,20 +264,21 @@ def test_latest_iteration_state(self): dummy_env = E.DummyContinuous() td3 = A.TD3(dummy_env) - td3._q_function_trainer_state = {'q_loss': 0., 'td_errors': np.array([0., 1.])} - td3._policy_trainer_state = {'pi_loss': 1.} + td3._q_function_trainer_state = {"q_loss": 0.0, "td_errors": np.array([0.0, 1.0])} + td3._policy_trainer_state = {"pi_loss": 1.0} latest_iteration_state = td3.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert 'pi_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert 
latest_iteration_state['scalar']['q_loss'] == 0. - assert latest_iteration_state['scalar']['pi_loss'] == 1. - assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "q_loss" in latest_iteration_state["scalar"] + assert "pi_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["q_loss"] == 0.0 + assert latest_iteration_state["scalar"]["pi_loss"] == 1.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/algorithms/test_trpo.py b/tests/algorithms/test_trpo.py index 7dc9fa32..8a824717 100644 --- a/tests/algorithms/test_trpo.py +++ b/tests/algorithms/test_trpo.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ import nnabla_rl.environments as E -class TestTRPO(): +class TestTRPO: def setup_method(self): nn.clear_parameters() @@ -29,7 +29,7 @@ def test_algorithm_name(self): dummy_env = E.DummyContinuous() trpo = A.TRPO(dummy_env) - assert trpo.__name__ == 'TRPO' + assert trpo.__name__ == "TRPO" def test_discrete_action_env_unsupported(self): """Check that error occurs when training on discrete action env.""" @@ -42,12 +42,14 @@ def test_run_online_training(self): dummy_env = E.DummyContinuous() dummy_env = EpisodicEnv(dummy_env, min_episode_length=3) - config = A.TRPOConfig(num_steps_per_iteration=5, - gpu_batch_size=5, - pi_batch_size=5, - vf_batch_size=2, - sigma_kl_divergence_constraint=10.0, - maximum_backtrack_numbers=50) + config = A.TRPOConfig( + num_steps_per_iteration=5, + gpu_batch_size=5, + pi_batch_size=5, + vf_batch_size=2, + sigma_kl_divergence_constraint=10.0, + maximum_backtrack_numbers=50, + ) trpo = A.TRPO(dummy_env, config=config) trpo.train_online(dummy_env, total_iterations=5) @@ -97,15 +99,16 @@ def test_latest_iteration_state(self): dummy_env = E.DummyContinuous() trpo = A.TRPO(dummy_env) - trpo._v_function_trainer_state = {'v_loss': 0.} + trpo._v_function_trainer_state = {"v_loss": 0.0} latest_iteration_state = trpo.latest_iteration_state - assert 'v_loss' in latest_iteration_state['scalar'] - assert latest_iteration_state['scalar']['v_loss'] == 0. + assert "v_loss" in latest_iteration_state["scalar"] + assert latest_iteration_state["scalar"]["v_loss"] == 0.0 if __name__ == "__main__": from testing_utils import EpisodicEnv + pytest.main() else: from ..testing_utils import EpisodicEnv diff --git a/tests/algorithms/test_xql.py b/tests/algorithms/test_xql.py index 0e12bfbc..eeaaea2a 100644 --- a/tests/algorithms/test_xql.py +++ b/tests/algorithms/test_xql.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -29,7 +29,7 @@ def test_algorithm_name(self): dummy_env = E.DummyContinuous() xql = A.XQL(dummy_env) - assert xql.__name__ == 'XQL' + assert xql.__name__ == "XQL" def test_discrete_action_env_unsupported(self): """Check that error occurs when training on discrete action env.""" @@ -94,23 +94,24 @@ def test_latest_iteration_state(self): dummy_env = E.DummyContinuous() xql = A.XQL(dummy_env) - xql._q_function_trainer_state = {'q_loss': 1., 'td_errors': np.array([0., 1.])} - xql._v_function_trainer_state = {'v_loss': 1.} - xql._policy_trainer_state = {'pi_loss': 2.} + xql._q_function_trainer_state = {"q_loss": 1.0, "td_errors": np.array([0.0, 1.0])} + xql._v_function_trainer_state = {"v_loss": 1.0} + xql._policy_trainer_state = {"pi_loss": 2.0} latest_iteration_state = xql.latest_iteration_state - assert 'q_loss' in latest_iteration_state['scalar'] - assert 'v_loss' in latest_iteration_state['scalar'] - assert 'pi_loss' in latest_iteration_state['scalar'] - assert 'td_errors' in latest_iteration_state['histogram'] - assert latest_iteration_state['scalar']['q_loss'] == 1. - assert latest_iteration_state['scalar']['v_loss'] == 1. - assert latest_iteration_state['scalar']['pi_loss'] == 2. - assert np.allclose(latest_iteration_state['histogram']['td_errors'], np.array([0., 1.])) + assert "q_loss" in latest_iteration_state["scalar"] + assert "v_loss" in latest_iteration_state["scalar"] + assert "pi_loss" in latest_iteration_state["scalar"] + assert "td_errors" in latest_iteration_state["histogram"] + assert latest_iteration_state["scalar"]["q_loss"] == 1.0 + assert latest_iteration_state["scalar"]["v_loss"] == 1.0 + assert latest_iteration_state["scalar"]["pi_loss"] == 2.0 + assert np.allclose(latest_iteration_state["histogram"]["td_errors"], np.array([0.0, 1.0])) if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/distributions/test_bernoulli.py b/tests/distributions/test_bernoulli.py index ba116587..10ef4be4 100644 --- a/tests/distributions/test_bernoulli.py +++ b/tests/distributions/test_bernoulli.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -35,8 +35,7 @@ def test_sample(self): assert np.all(sampled.d == np.array([[1], [0], [0], [1]])) def test_sample_multi_dimensional(self): - z = np.array([[[1000.0], [-1000.0], [-1000.0], [1000.0]], - [[1000.0], [-1000.0], [1000.0], [-1000.0]]]) + z = np.array([[[1000.0], [-1000.0], [-1000.0], [1000.0]], [[1000.0], [-1000.0], [1000.0], [-1000.0]]]) assert z.shape == (2, 4, 1) batch_size = z.shape[0] category_size = z.shape[1] @@ -58,7 +57,7 @@ def test_log_prob(self): assert actual.shape == (batch_size, 1) p = self._sigmoid(z) - expected = np.where(classes == 1, np.log(p), np.log(1-p)) + expected = np.where(classes == 1, np.log(p), np.log(1 - p)) assert actual.shape == expected.shape assert np.allclose(actual.d, expected) @@ -74,7 +73,7 @@ def test_log_prob_multi_dimensional(self): assert actual.shape == (batch_size, category_num, 1) p = self._sigmoid(z) - expected = np.where(classes == 1, np.log(p), np.log(1-p)) + expected = np.where(classes == 1, np.log(p), np.log(1 - p)) assert actual.shape == expected.shape assert np.allclose(actual.d, expected) @@ -87,7 +86,7 @@ def test_entropy(self): assert actual.shape == (batch_size, 1) p = self._sigmoid(z) - probabilities = np.concatenate((p, 1-p), axis=-1) + probabilities = np.concatenate((p, 1 - p), axis=-1) expected = -np.sum(np.log(probabilities) * probabilities, axis=1, keepdims=True) assert actual.shape == expected.shape @@ -103,7 +102,7 @@ def test_entropy_multi_dimensional(self): assert actual.shape == (batch_size, category_num, 1) p = self._sigmoid(z) - probabilities = np.concatenate((p, 1-p), axis=-1) + probabilities = np.concatenate((p, 1 - p), axis=-1) expected = -np.sum(np.log(probabilities) * probabilities, axis=len(z.shape) - 1, keepdims=True) assert actual.shape == expected.shape @@ -113,12 +112,12 @@ def test_kl_divergence(self): batch_size = 10 z_p = np.random.normal(size=(batch_size, 1)) z_p_p = self._sigmoid(z_p) - z_p_dist = np.concatenate((z_p_p, 1-z_p_p), axis=-1) + z_p_dist = np.concatenate((z_p_p, 1 - z_p_p), axis=-1) distribution_p = D.Bernoulli(z=z_p) z_q = np.random.normal(size=(batch_size, 1)) z_q_p = self._sigmoid(z_q) - z_q_dist = np.concatenate((z_q_p, 1-z_q_p), axis=-1) + z_q_dist = np.concatenate((z_q_p, 1 - z_q_p), axis=-1) distribution_q = D.Bernoulli(z=z_q) actual = distribution_p.kl_divergence(distribution_q) @@ -135,12 +134,12 @@ def test_kl_divergence_multi_dimensional(self): category_num = 3 z_p = np.random.normal(size=(batch_size, category_num, 1)) z_p_p = self._sigmoid(z_p) - z_p_dist = np.concatenate((z_p_p, 1-z_p_p), axis=-1) + z_p_dist = np.concatenate((z_p_p, 1 - z_p_p), axis=-1) distribution_p = D.Bernoulli(z=z_p) z_q = np.random.normal(size=(batch_size, category_num, 1)) z_q_p = self._sigmoid(z_q) - z_q_dist = np.concatenate((z_q_p, 1-z_q_p), axis=-1) + z_q_dist = np.concatenate((z_q_p, 1 - z_q_p), axis=-1) distribution_q = D.Bernoulli(z=z_q) actual = distribution_p.kl_divergence(distribution_q) diff --git a/tests/distributions/test_common_utils.py b/tests/distributions/test_common_utils.py index ba48796a..eeaa5ff5 100644 --- a/tests/distributions/test_common_utils.py +++ b/tests/distributions/test_common_utils.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -32,10 +32,12 @@ def test_gaussian_log_prob(self): ln_var = np.random.randn(1, 4) * 5.0 var = np.exp(ln_var) - actual = common_utils.gaussian_log_prob(nn.Variable.from_numpy_array(x), - nn.Variable.from_numpy_array(mean), - nn.Variable.from_numpy_array(var), - nn.Variable.from_numpy_array(ln_var)) + actual = common_utils.gaussian_log_prob( + nn.Variable.from_numpy_array(x), + nn.Variable.from_numpy_array(mean), + nn.Variable.from_numpy_array(var), + nn.Variable.from_numpy_array(ln_var), + ) actual.forward() actual = actual.d diff --git a/tests/distributions/test_gaussian.py b/tests/distributions/test_gaussian.py index 13c06bd7..24abe1c7 100644 --- a/tests/distributions/test_gaussian.py +++ b/tests/distributions/test_gaussian.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,14 +23,14 @@ from nnabla_rl.distributions.gaussian import NnablaGaussian, NumpyGaussian -class TestGaussian(): +class TestGaussian: def _generate_dummy_mean_var(self): batch_size = 10 output_dim = 10 input_shape = (batch_size, output_dim) mean = np.zeros(shape=input_shape) - sigma = np.ones(shape=input_shape) * 5. - ln_var = np.log(sigma) * 2. + sigma = np.ones(shape=input_shape) * 5.0 + ln_var = np.log(sigma) * 2.0 return mean, ln_var def test_nnabla_constructor(self): @@ -73,12 +73,13 @@ def test_sample(self): input_shape = (batch_size, output_dim) mean = np.zeros(shape=input_shape) - sigma = np.ones(shape=input_shape) * 5. - ln_var = np.log(sigma) * 2. + sigma = np.ones(shape=input_shape) * 5.0 + ln_var = np.log(sigma) * 2.0 - with mock.patch('nnabla_rl.functions.sample_gaussian') as mock_sample_gaussian: - distribution = NnablaGaussian(mean=nn.Variable.from_numpy_array(mean), - ln_var=nn.Variable.from_numpy_array(ln_var)) + with mock.patch("nnabla_rl.functions.sample_gaussian") as mock_sample_gaussian: + distribution = NnablaGaussian( + mean=nn.Variable.from_numpy_array(mean), ln_var=nn.Variable.from_numpy_array(ln_var) + ) noise_clip = None sampled = distribution.sample(noise_clip=noise_clip) sampled.forward() @@ -101,8 +102,9 @@ def test_sample_and_compute_log_prob(self, mean, var): ln_var = np.ones(shape=input_shape) * np.log(var) var = np.exp(ln_var) - distribution = NnablaGaussian(mean=nn.Variable.from_numpy_array(mu), - ln_var=nn.Variable.from_numpy_array(ln_var)) + distribution = NnablaGaussian( + mean=nn.Variable.from_numpy_array(mu), ln_var=nn.Variable.from_numpy_array(ln_var) + ) sample, log_prob = distribution.sample_and_compute_log_prob() @@ -112,9 +114,7 @@ def test_sample_and_compute_log_prob(self, mean, var): nn.forward_all([sample, log_prob]) x = sample.d - gaussian_log_prob = -0.5 * np.log(2.0 * np.pi) \ - - 0.5 * ln_var \ - - (x - mu) ** 2 / (2.0 * var) + gaussian_log_prob = -0.5 * np.log(2.0 * np.pi) - 0.5 * ln_var - (x - mu) ** 2 / (2.0 * var) expected = np.sum(gaussian_log_prob, axis=-1, keepdims=True) actual = log_prob.d @@ -127,23 +127,22 @@ def test_sample_multiple(self): input_shape = (batch_size, output_dim) mean = np.zeros(shape=input_shape) - sigma = np.ones(shape=input_shape) * 5. - ln_var = np.log(sigma) * 2. 
+ sigma = np.ones(shape=input_shape) * 5.0 + ln_var = np.log(sigma) * 2.0 - with mock.patch('nnabla_rl.functions.sample_gaussian_multiple') as mock_sample_multiple_gaussian: - distribution = NnablaGaussian(mean=nn.Variable.from_numpy_array(mean), - ln_var=nn.Variable.from_numpy_array(ln_var)) + with mock.patch("nnabla_rl.functions.sample_gaussian_multiple") as mock_sample_multiple_gaussian: + distribution = NnablaGaussian( + mean=nn.Variable.from_numpy_array(mean), ln_var=nn.Variable.from_numpy_array(ln_var) + ) noise_clip = None num_samples = 10 - sampled = distribution.sample_multiple( - num_samples, noise_clip=noise_clip) + sampled = distribution.sample_multiple(num_samples, noise_clip=noise_clip) sampled.forward() assert mock_sample_multiple_gaussian.call_count == 1 args, kwargs = mock_sample_multiple_gaussian.call_args - assert args == (distribution._mean, - distribution._ln_var, num_samples) + assert args == (distribution._mean, distribution._ln_var, num_samples) assert kwargs == {"noise_clip": noise_clip} @pytest.mark.parametrize("mean", np.arange(start=-1.0, stop=1.0, step=0.25)) @@ -157,11 +156,11 @@ def test_sample_multiple_and_compute_log_prob(self, mean, var): mu = np.ones(shape=input_shape) * mean ln_var = np.ones(shape=input_shape) * np.log(var) - distribution = NnablaGaussian(mean=nn.Variable.from_numpy_array(mu), - ln_var=nn.Variable.from_numpy_array(ln_var)) + distribution = NnablaGaussian( + mean=nn.Variable.from_numpy_array(mu), ln_var=nn.Variable.from_numpy_array(ln_var) + ) num_samples = 10 - samples, log_probs = distribution.sample_multiple_and_compute_log_prob( - num_samples=num_samples) + samples, log_probs = distribution.sample_multiple_and_compute_log_prob(num_samples=num_samples) # FIXME: if you enable clear_no_need_grad seems to compute something different # Do NOT use forward_all and no_need_grad flag at same time # nnabla's bug? @@ -169,22 +168,17 @@ def test_sample_multiple_and_compute_log_prob(self, mean, var): x = samples.d[:, 0, :] assert x.shape == (batch_size, output_dim) - gaussian_log_prob = -0.5 * np.log(2.0 * np.pi) \ - - 0.5 * ln_var \ - - (x - mu) ** 2 / (2.0 * var) + gaussian_log_prob = -0.5 * np.log(2.0 * np.pi) - 0.5 * ln_var - (x - mu) ** 2 / (2.0 * var) expected = np.sum(gaussian_log_prob, axis=-1, keepdims=True) actual = log_probs.d assert expected.shape == (batch_size, 1) assert np.allclose(expected, actual[:, 0, :]) mu = np.reshape(mu, newshape=(batch_size, 1, output_dim)) - ln_var = np.reshape(ln_var, newshape=( - batch_size, 1, output_dim)) + ln_var = np.reshape(ln_var, newshape=(batch_size, 1, output_dim)) x = samples.d - gaussian_log_prob = -0.5 * np.log(2.0 * np.pi) \ - - 0.5 * ln_var \ - - (x - mu) ** 2 / (2.0 * var) + gaussian_log_prob = -0.5 * np.log(2.0 * np.pi) - 0.5 * ln_var - (x - mu) ** 2 / (2.0 * var) expected = np.sum(gaussian_log_prob, axis=-1, keepdims=True) assert expected.shape == actual.shape @@ -196,14 +190,14 @@ def test_sample_multiple_and_compute_log_prob_shape(self): input_shape = (batch_size, output_dim) mean = np.zeros(shape=input_shape) - sigma = np.ones(shape=input_shape) * 5. - ln_var = np.log(sigma) * 2. 
+ sigma = np.ones(shape=input_shape) * 5.0 + ln_var = np.log(sigma) * 2.0 - distribution = NnablaGaussian(mean=nn.Variable.from_numpy_array(mean), - ln_var=nn.Variable.from_numpy_array(ln_var)) + distribution = NnablaGaussian( + mean=nn.Variable.from_numpy_array(mean), ln_var=nn.Variable.from_numpy_array(ln_var) + ) num_samples = 10 - samples, log_probs = distribution.sample_multiple_and_compute_log_prob( - num_samples=num_samples) + samples, log_probs = distribution.sample_multiple_and_compute_log_prob(num_samples=num_samples) nn.forward_all([samples, log_probs]) assert samples.shape == (batch_size, num_samples, output_dim) @@ -215,25 +209,23 @@ def test_log_prob(self): input_shape = (batch_size, output_dim) mean = np.zeros(shape=input_shape) - sigma = np.ones(shape=input_shape) * 5. - ln_var = np.log(sigma) * 2. - dummy_input = nn.Variable.from_numpy_array( - np.random.randn(batch_size, output_dim)) - - with mock.patch('nnabla_rl.distributions.common_utils.gaussian_log_prob', - return_value=nn.Variable.from_numpy_array(np.empty(shape=input_shape))) \ - as mock_gaussian_log_prob: - distribution = NnablaGaussian(mean=nn.Variable.from_numpy_array(mean), - ln_var=nn.Variable.from_numpy_array(ln_var)) + sigma = np.ones(shape=input_shape) * 5.0 + ln_var = np.log(sigma) * 2.0 + dummy_input = nn.Variable.from_numpy_array(np.random.randn(batch_size, output_dim)) + + with mock.patch( + "nnabla_rl.distributions.common_utils.gaussian_log_prob", + return_value=nn.Variable.from_numpy_array(np.empty(shape=input_shape)), + ) as mock_gaussian_log_prob: + distribution = NnablaGaussian( + mean=nn.Variable.from_numpy_array(mean), ln_var=nn.Variable.from_numpy_array(ln_var) + ) distribution.log_prob(dummy_input) assert mock_gaussian_log_prob.call_count == 1 args, _ = mock_gaussian_log_prob.call_args - assert args == (dummy_input, - distribution._mean, - distribution._var, - distribution._ln_var) + assert args == (dummy_input, distribution._mean, distribution._var, distribution._ln_var) @pytest.mark.parametrize("batch_size", range(1, 10)) @pytest.mark.parametrize("output_dim", range(1, 10)) @@ -242,7 +234,7 @@ def test_entropy(self, batch_size, output_dim): mean = np.zeros(shape=input_shape) sigma = np.ones(shape=input_shape) - ln_var = np.log(sigma) * 2. + ln_var = np.log(sigma) * 2.0 distribution = NnablaGaussian(nn.Variable.from_numpy_array(mean), nn.Variable.from_numpy_array(ln_var)) actual = distribution.entropy() @@ -261,8 +253,8 @@ def test_kl_divergence(self): input_shape = (batch_size, output_dim) mean = np.zeros(shape=input_shape) - sigma = np.ones(shape=input_shape) * 5. - ln_var = np.log(sigma) * 2. 
+ sigma = np.ones(shape=input_shape) * 5.0 + ln_var = np.log(sigma) * 2.0 distribution_p = NnablaGaussian(nn.Variable.from_numpy_array(mean), nn.Variable.from_numpy_array(ln_var)) distribution_q = NnablaGaussian(nn.Variable.from_numpy_array(mean), nn.Variable.from_numpy_array(ln_var)) @@ -283,10 +275,10 @@ def _gaussian_differential_entropy(self, covariance_matrix): return 0.5 * np.log(np.power(2.0 * np.pi * np.e, covariance_matrix.shape[0]) * determinant) -class TestNumpyGaussian(): - def _generate_dummy_mean_var(self, scale=5.): +class TestNumpyGaussian: + def _generate_dummy_mean_var(self, scale=5.0): gaussian_dim = 10 - mean_shape = (gaussian_dim, ) + mean_shape = (gaussian_dim,) mean = np.random.normal(size=mean_shape) sigma = np.diag(np.ones(shape=mean_shape) * scale) sigma_inv = np.diag(1.0 / np.diag(sigma)) @@ -337,10 +329,10 @@ def test_numpy_kl_divergence_identical_distribution(self): distribution_q = NumpyGaussian(mean, np.log(sigma)) actual = distribution_p.kl_divergence(distribution_q) - expected = np.zeros((1, )) + expected = np.zeros((1,)) assert expected == pytest.approx(actual, abs=1e-5) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main() diff --git a/tests/distributions/test_gmm.py b/tests/distributions/test_gmm.py index 0f152712..8a71b9e4 100644 --- a/tests/distributions/test_gmm.py +++ b/tests/distributions/test_gmm.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,8 +19,14 @@ from scipy import stats import nnabla as nn -from nnabla_rl.distributions.gmm import (GMM, NumpyGMM, compute_mean_and_covariance, compute_mixing_coefficient, - compute_responsibility, inference_data_mean_and_covariance) +from nnabla_rl.distributions.gmm import ( + GMM, + NumpyGMM, + compute_mean_and_covariance, + compute_mixing_coefficient, + compute_responsibility, + inference_data_mean_and_covariance, +) def _generate_dummy_params(): @@ -32,7 +38,7 @@ def _generate_dummy_params(): return num_classes, dims, means, covariances, mixing_coefficients -class TestGMM(): +class TestGMM: def test_nnabla_constructor(self): _, _, means, covariances, mixing_coefficients = _generate_dummy_params() nnabla_means = nn.Variable.from_numpy_array(means) @@ -47,12 +53,12 @@ def test_numpy_constructor(self): assert isinstance(distribution._delegate, NumpyGMM) -class TestNumpyGMM(): +class TestNumpyGMM: def test_sample(self): _, dims, means, covariances, mixing_coefficients = _generate_dummy_params() gmm = NumpyGMM(means, covariances, mixing_coefficients) sample = gmm.sample() - assert sample.shape == (dims, ) + assert sample.shape == (dims,) with pytest.raises(NotImplementedError): gmm.sample(noise_clip=np.ones(means.shape[1:])) @@ -88,9 +94,7 @@ def test_numpy_compute_responsibility(self): def test_numpy_compute_mixing_coefficient(self): responsibility = np.random.randn(3, 1) + 2.0 - with mock.patch( - 'nnabla_rl.distributions.gmm.logsumexp', return_value=1.0 - ) as mock_logsumexp: + with mock.patch("nnabla_rl.distributions.gmm.logsumexp", return_value=1.0) as mock_logsumexp: result = compute_mixing_coefficient(responsibility) mock_logsumexp.assert_called_once() assert result == np.exp(1 - np.log(responsibility.shape[0])) @@ -110,13 +114,13 @@ def test_numpy_inference_data_mean_and_covariance(self): gmm = NumpyGMM(means, covariances, mixing_coefficients) with mock.patch( - 
'nnabla_rl.distributions.gmm.compute_responsibility', return_value=(None, None) + "nnabla_rl.distributions.gmm.compute_responsibility", return_value=(None, None) ) as mock_compute_responsibility: with mock.patch( - 'nnabla_rl.distributions.gmm.compute_mixing_coefficient', return_value=None + "nnabla_rl.distributions.gmm.compute_mixing_coefficient", return_value=None ) as mock_compute_mixing_coefficient: with mock.patch( - 'nnabla_rl.distributions.gmm.compute_mean_and_covariance', + "nnabla_rl.distributions.gmm.compute_mean_and_covariance", return_value=(None, None), ) as mock_compute_mean_and_var: @@ -126,5 +130,5 @@ def test_numpy_inference_data_mean_and_covariance(self): mock_compute_mean_and_var.assert_called_once() -if __name__ == '__main__': +if __name__ == "__main__": pytest.main() diff --git a/tests/distributions/test_one_hot_softmax.py b/tests/distributions/test_one_hot_softmax.py index 5f9895f1..8e1012ab 100644 --- a/tests/distributions/test_one_hot_softmax.py +++ b/tests/distributions/test_one_hot_softmax.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,10 +27,7 @@ def setup_method(self, method): nn.clear_parameters() def test_sample(self): - z = np.array([[0, 0, 1000, 0], - [0, 1000, 0, 0], - [1000, 0, 0, 0], - [0, 0, 0, 1000]]) + z = np.array([[0, 0, 1000, 0], [0, 1000, 0, 0], [1000, 0, 0, 0], [0, 0, 0, 1000]]) batch_size = z.shape[0] distribution = D.OneHotSoftmax(z=z) @@ -41,14 +38,12 @@ def test_sample(self): assert np.all(sampled.d == np.array([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]])) def test_sample_multi_dimensional(self): - z = np.array([[[1000, 0, 0, 0], - [0, 0, 0, 1000], - [0, 1000, 0, 0], - [0, 0, 1000, 0]], - [[0, 0, 1000, 0], - [0, 1000, 0, 0], - [1000, 0, 0, 0], - [0, 0, 0, 1000]]]) + z = np.array( + [ + [[1000, 0, 0, 0], [0, 0, 0, 1000], [0, 1000, 0, 0], [0, 0, 1000, 0]], + [[0, 0, 1000, 0], [0, 1000, 0, 0], [1000, 0, 0, 0], [0, 0, 0, 1000]], + ] + ) assert z.shape == (2, 4, 4) batch_size = z.shape[0] category_size = z.shape[1] @@ -57,18 +52,23 @@ def test_sample_multi_dimensional(self): sampled.forward() assert sampled.shape == (batch_size, category_size, 4) - assert np.all(sampled.d == np.array([[[1, 0, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 0, 1, 0]], - [[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]]])) + assert np.all( + sampled.d + == np.array( + [ + [[1, 0, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 0, 1, 0]], + [[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]], + ] + ) + ) def test_choose_probable(self): - z = np.array([[[1.0, 2.0, 3.0, 4.0], - [2.0, 3.0, 1.0, -1.0], - [-5.1, 11.2, 0.8, 0.7], - [-3.0, -2.1, -7.6, -5.4]], - [[0, 0, 1000, 0], - [0, 1000, 0, 0], - [1000, 0, 0, 0], - [0, 0, 0, 1000]]]) + z = np.array( + [ + [[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 1.0, -1.0], [-5.1, 11.2, 0.8, 0.7], [-3.0, -2.1, -7.6, -5.4]], + [[0, 0, 1000, 0], [0, 1000, 0, 0], [1000, 0, 0, 0], [0, 0, 0, 1000]], + ] + ) assert z.shape == (2, 4, 4) batch_size = z.shape[0] category_size = z.shape[1] @@ -77,8 +77,15 @@ def test_choose_probable(self): probable.forward() assert probable.shape == (batch_size, category_size, 4) - assert np.all(probable.d == np.array([[[0, 0, 0, 1], [0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0]], - [[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]]])) + assert np.all( + probable.d + == np.array( + [ + [[0, 0, 0, 1], [0, 1, 0, 
0], [0, 1, 0, 0], [0, 1, 0, 0]], + [[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]], + ] + ) + ) def test_backprop(self): class TestModel(StochasticPolicy): @@ -86,7 +93,8 @@ def pi(self, s: nn.Variable): with nn.parameter_scope(self.scope_name): z = NPF.affine(s, n_outmaps=5) return D.OneHotSoftmax(z=z) - model = TestModel('test') + + model = TestModel("test") batch_size = 5 data_dim = 10 diff --git a/tests/distributions/test_softmax.py b/tests/distributions/test_softmax.py index 4f8cdcd7..db5c1bb7 100644 --- a/tests/distributions/test_softmax.py +++ b/tests/distributions/test_softmax.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,10 +25,7 @@ def setup_method(self, method): nn.clear_parameters() def test_sample(self): - z = np.array([[0, 0, 1000, 0], - [0, 1000, 0, 0], - [1000, 0, 0, 0], - [0, 0, 0, 1000]]) + z = np.array([[0, 0, 1000, 0], [0, 1000, 0, 0], [1000, 0, 0, 0], [0, 0, 0, 1000]]) batch_size = z.shape[0] distribution = D.Softmax(z=z) @@ -39,14 +36,12 @@ def test_sample(self): assert np.all(sampled.d == np.array([[2], [1], [0], [3]])) def test_sample_multi_dimensional(self): - z = np.array([[[1000, 0, 0, 0], - [0, 0, 0, 1000], - [0, 1000, 0, 0], - [0, 0, 1000, 0]], - [[0, 0, 1000, 0], - [0, 1000, 0, 0], - [1000, 0, 0, 0], - [0, 0, 0, 1000]]]) + z = np.array( + [ + [[1000, 0, 0, 0], [0, 0, 0, 1000], [0, 1000, 0, 0], [0, 0, 1000, 0]], + [[0, 0, 1000, 0], [0, 1000, 0, 0], [1000, 0, 0, 0], [0, 0, 0, 1000]], + ] + ) assert z.shape == (2, 4, 4) batch_size = z.shape[0] category_size = z.shape[1] @@ -58,28 +53,23 @@ def test_sample_multi_dimensional(self): assert np.all(sampled.d == np.array([[[0], [3], [1], [2]], [[2], [1], [0], [3]]])) def test_choose_probable(self): - z = np.array([[1.0, 2.0, 3.0, 4.0], - [2.0, 3.0, 1.0, -1.0], - [-5.1, 11.2, 0.8, 0.7], - [-3.0, -2.1, -7.6, -5.4]]) + z = np.array([[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 1.0, -1.0], [-5.1, 11.2, 0.8, 0.7], [-3.0, -2.1, -7.6, -5.4]]) assert z.shape == (4, 4) batch_size = z.shape[0] distribution = D.Softmax(z=z) probable = distribution.choose_probable() probable.forward() - assert probable.shape == (batch_size, ) + assert probable.shape == (batch_size,) assert np.all(probable.d == np.array([3, 1, 1, 1])) def test_choose_probable_multi_dimensional(self): - z = np.array([[[1.0, 2.0, 3.0, 4.0], - [2.0, 3.0, 1.0, -1.0], - [-5.1, 11.2, 0.8, 0.7], - [-3.0, -2.1, -7.6, -5.4]], - [[0.9, 0.8, -7.1, 3.2], - [0.1, 0.2, 0.3, 0.4], - [-1.2, -2.7, -3.1, -4.3], - [1.0, 1.2, 1.3, -4.3]]]) + z = np.array( + [ + [[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 1.0, -1.0], [-5.1, 11.2, 0.8, 0.7], [-3.0, -2.1, -7.6, -5.4]], + [[0.9, 0.8, -7.1, 3.2], [0.1, 0.2, 0.3, 0.4], [-1.2, -2.7, -3.1, -4.3], [1.0, 1.2, 1.3, -4.3]], + ] + ) assert z.shape == (2, 4, 4) batch_size = z.shape[0] category_size = z.shape[1] @@ -103,7 +93,7 @@ def test_log_prob(self): probabilities = np.exp(z) / np.sum(np.exp(z), axis=1, keepdims=True) log_probabilities = np.log(probabilities) - indices = np.reshape(actions, newshape=(batch_size, )) + indices = np.reshape(actions, newshape=(batch_size,)) one_hot_action = self._to_one_hot_action(indices, action_num=action_num) expected = np.sum(log_probabilities * one_hot_action, axis=1, keepdims=True) @@ -125,7 +115,13 @@ def test_log_prob_multi_dimensional(self): probabilities = 
np.exp(z) / np.sum(np.exp(z), axis=len(z.shape) - 1, keepdims=True) log_probabilities = np.log(probabilities) - indices = np.reshape(actions, newshape=(batch_size, category_num, )) + indices = np.reshape( + actions, + newshape=( + batch_size, + category_num, + ), + ) one_hot_action = self._to_one_hot_action(indices, action_num=action_num) expected = np.sum(log_probabilities * one_hot_action, axis=len(z.shape) - 1, keepdims=True) @@ -178,8 +174,9 @@ def test_kl_divergence(self): assert actual.shape == (batch_size, 1) - expected = z_p_dist[0, 0] * np.log(z_p_dist[0, 0] / z_q_dist[0, 0]) + \ - z_p_dist[0, 1] * np.log(z_p_dist[0, 1] / z_q_dist[0, 1]) + expected = z_p_dist[0, 0] * np.log(z_p_dist[0, 0] / z_q_dist[0, 0]) + z_p_dist[0, 1] * np.log( + z_p_dist[0, 1] / z_q_dist[0, 1] + ) assert expected == pytest.approx(actual.d.flatten()[0], abs=1e-5) @@ -200,8 +197,9 @@ def test_kl_divergence_multi_dimensional(self): assert actual.shape == (batch_size, category_num, 1) for category in range(category_num): - expected = z_p_dist[0, category, 0] * np.log(z_p_dist[0, category, 0] / z_q_dist[0, category, 0]) + \ - z_p_dist[0, category, 1] * np.log(z_p_dist[0, category, 1] / z_q_dist[0, category, 1]) + expected = z_p_dist[0, category, 0] * np.log( + z_p_dist[0, category, 0] / z_q_dist[0, category, 0] + ) + z_p_dist[0, category, 1] * np.log(z_p_dist[0, category, 1] / z_q_dist[0, category, 1]) assert expected == pytest.approx(actual.d.flatten()[category], abs=1e-5) def _to_one_hot_action(self, a, action_num): diff --git a/tests/distributions/test_squashed_gaussian.py b/tests/distributions/test_squashed_gaussian.py index e74baafa..dd2fcafe 100644 --- a/tests/distributions/test_squashed_gaussian.py +++ b/tests/distributions/test_squashed_gaussian.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -81,12 +81,9 @@ def test_log_prob(self, x, mean, var): ln_var = np.array(np.log(var)).reshape((1, 1)) distribution = D.SquashedGaussian(mean=mean, ln_var=ln_var) ln_var = np.log(var) - gaussian_log_prob = -0.5 * \ - np.log(2.0 * np.pi) - 0.5 * ln_var - \ - (x - mean) ** 2 / (2.0 * var) + gaussian_log_prob = -0.5 * np.log(2.0 * np.pi) - 0.5 * ln_var - (x - mean) ** 2 / (2.0 * var) log_det_jacobian = np.log(1 - np.tanh(x) ** 2) - expected = np.sum(gaussian_log_prob - - log_det_jacobian, axis=-1, keepdims=True) + expected = np.sum(gaussian_log_prob - log_det_jacobian, axis=-1, keepdims=True) x_var = nn.Variable((1, 1)) x_var.d = np.tanh(x) @@ -110,9 +107,7 @@ def test_sample_and_compute_log_prob(self, mean, var): nn.forward_all([sample, actual]) x = np.arctanh(sample.data.data, dtype=np.float64) - gaussian_log_prob = -0.5 * \ - np.log(2.0 * np.pi) - 0.5 * ln_var - \ - (x - mean) ** 2 / (2.0 * var) + gaussian_log_prob = -0.5 * np.log(2.0 * np.pi) - 0.5 * ln_var - (x - mean) ** 2 / (2.0 * var) log_det_jacobian = np.log(1 - np.tanh(x) ** 2) expected = np.sum(gaussian_log_prob - log_det_jacobian, axis=-1, keepdims=True) @@ -147,8 +142,7 @@ def test_sample_multiple_and_compute_log_prob(self, mean, var): distribution = D.SquashedGaussian(mean=mu, ln_var=ln_var) num_samples = 10 - samples, log_probs = distribution.sample_multiple_and_compute_log_prob( - num_samples=num_samples) + samples, log_probs = distribution.sample_multiple_and_compute_log_prob(num_samples=num_samples) # FIXME: if you enable clear_no_need_grad seems to compute something different # Do NOT use forward_all and no_need_grad flag at same time # nnabla's bug? @@ -160,12 +154,9 @@ def test_sample_multiple_and_compute_log_prob(self, mean, var): # Check the first sample independently x = np.arctanh(samples.d[:, 0, :], dtype=np.float64) assert x.shape == (batch_size, output_dim) - gaussian_log_prob = -0.5 * np.log(2.0 * np.pi) - 0.5 * ln_var - \ - (x - mu) ** 2 / (2.0 * var) + gaussian_log_prob = -0.5 * np.log(2.0 * np.pi) - 0.5 * ln_var - (x - mu) ** 2 / (2.0 * var) log_det_jacobian = np.log(1 - np.tanh(x) ** 2) - expected = np.sum(gaussian_log_prob - log_det_jacobian, - axis=-1, - keepdims=True) + expected = np.sum(gaussian_log_prob - log_det_jacobian, axis=-1, keepdims=True) actual = log_probs.d assert expected.shape == (batch_size, 1) assert np.allclose(expected, actual[:, 0, :], atol=1e-3, rtol=1e-2) @@ -175,12 +166,9 @@ def test_sample_multiple_and_compute_log_prob(self, mean, var): ln_var = np.reshape(ln_var, newshape=(batch_size, 1, output_dim)) x = np.arctanh(samples.d, dtype=np.float64) - gaussian_log_prob = -0.5 * np.log(2.0 * np.pi) - 0.5 * ln_var - \ - (x - mu) ** 2 / (2.0 * var) + gaussian_log_prob = -0.5 * np.log(2.0 * np.pi) - 0.5 * ln_var - (x - mu) ** 2 / (2.0 * var) log_det_jacobian = np.log(1 - np.tanh(x) ** 2) - expected = np.sum(gaussian_log_prob - log_det_jacobian, - axis=-1, - keepdims=True) + expected = np.sum(gaussian_log_prob - log_det_jacobian, axis=-1, keepdims=True) actual = log_probs.d assert np.allclose(expected, actual, atol=1e-3, rtol=1e-2) @@ -190,13 +178,12 @@ def test_sample_multiple_and_compute_log_prob_shape(self): input_shape = (batch_size, output_dim) mean = np.zeros(shape=input_shape) - sigma = np.ones(shape=input_shape) * 5. - ln_var = np.log(sigma) * 2. 
+        sigma = np.ones(shape=input_shape) * 5.0
+        ln_var = np.log(sigma) * 2.0
         distribution = D.SquashedGaussian(mean=mean, ln_var=ln_var)
         num_samples = 10
-        samples, log_probs = distribution.sample_multiple_and_compute_log_prob(
-            num_samples=num_samples)
+        samples, log_probs = distribution.sample_multiple_and_compute_log_prob(num_samples=num_samples)
         nn.forward_all([samples, log_probs])
         assert samples.shape == (batch_size, num_samples, output_dim)
@@ -214,17 +201,13 @@ def test_log_prob_internal(self, x, mean, var):
         distribution = D.SquashedGaussian(mean=dummy_mean, ln_var=dummy_ln_var)
         ln_var = np.log(var)
-        gaussian_log_prob = -0.5 * \
-            np.log(2.0 * np.pi) - 0.5 * ln_var - \
-            (x - mean) ** 2 / (2.0 * var)
+        gaussian_log_prob = -0.5 * np.log(2.0 * np.pi) - 0.5 * ln_var - (x - mean) ** 2 / (2.0 * var)
         log_det_jacobian = np.log(1 - np.tanh(x) ** 2)
-        expected = np.sum(gaussian_log_prob -
-                          log_det_jacobian, axis=-1, keepdims=True)
+        expected = np.sum(gaussian_log_prob - log_det_jacobian, axis=-1, keepdims=True)
         x_var = nn.Variable((1, 1))
         x_var.d = x
-        actual = distribution._log_prob_internal(
-            x_var, mean=mean, var=var, ln_var=ln_var)
+        actual = distribution._log_prob_internal(x_var, mean=mean, var=var, ln_var=ln_var)
         actual.forward(clear_no_need_grad=True)
         actual = actual.d
         assert np.isclose(expected, actual)
diff --git a/tests/environment_explorers/test_epsilon_greedy.py b/tests/environment_explorers/test_epsilon_greedy.py
index 14baea59..7a0d4b65 100644
--- a/tests/environment_explorers/test_epsilon_greedy.py
+++ b/tests/environment_explorers/test_epsilon_greedy.py
@@ -18,9 +18,11 @@ import numpy as np
 import pytest
-from nnabla_rl.environment_explorers.epsilon_greedy_explorer import (LinearDecayEpsilonGreedyExplorer,
-                                                                     LinearDecayEpsilonGreedyExplorerConfig,
-                                                                     epsilon_greedy_action_selection)
+from nnabla_rl.environment_explorers.epsilon_greedy_explorer import (
+    LinearDecayEpsilonGreedyExplorer,
+    LinearDecayEpsilonGreedyExplorerConfig,
+    epsilon_greedy_action_selection,
+)
 class TestEpsilonGreedyActionStrategy(object):
@@ -28,12 +30,10 @@ def test_epsilon_greedy_action_selection_always_greedy(self):
         greedy_selector_mock = mock.MagicMock(return_value=(1, {}))
         random_selector_mock = mock.MagicMock(return_value=(2, {}))
-        state = 'test'
+        state = "test"
         (should_be_greedy, _), is_greedy = epsilon_greedy_action_selection(
-            state,
-            greedy_selector_mock,
-            random_selector_mock,
-            epsilon=0.0)
+            state, greedy_selector_mock, random_selector_mock, epsilon=0.0
+        )
         assert should_be_greedy == 1
         assert is_greedy is True
         greedy_selector_mock.assert_called_once()
@@ -42,12 +42,10 @@ def test_epsilon_greedy_action_selection_always_random(self):
         greedy_selector_mock = mock.MagicMock(return_value=(1, {}))
         random_selector_mock = mock.MagicMock(return_value=(2, {}))
-        state = 'test'
+        state = "test"
         (should_be_random, _), is_greedy = epsilon_greedy_action_selection(
-            state,
-            greedy_selector_mock,
-            random_selector_mock,
-            epsilon=1.0)
+            state, greedy_selector_mock, random_selector_mock, epsilon=1.0
+        )
         assert should_be_random == 2
         assert is_greedy is False
         random_selector_mock.assert_called_once()
@@ -56,12 +54,10 @@ def test_epsilon_greedy_action_selection(self):
         greedy_selector_mock = mock.MagicMock(return_value=(1, {}))
         random_selector_mock = mock.MagicMock(return_value=(2, {}))
-        state = 'test'
+        state = "test"
         (action, _), is_greedy = epsilon_greedy_action_selection(
-            state,
-            greedy_selector_mock,
-            random_selector_mock,
-            epsilon=0.5)
+            state, greedy_selector_mock, random_selector_mock, epsilon=0.5
+        )
         if is_greedy:
             assert action == 1
             greedy_selector_mock.assert_called_once()
@@ -75,17 +71,15 @@ def test_compute_epsilon(self):
         max_explore_steps = 100
         greedy_selector_mock = mock.MagicMock(return_value=(1, {}))
         random_selector_mock = mock.MagicMock(return_value=(2, {}))
-        config = LinearDecayEpsilonGreedyExplorerConfig(initial_epsilon=initial_epsilon,
-                                                        final_epsilon=final_epsilon,
-                                                        max_explore_steps=max_explore_steps)
-        explorer = LinearDecayEpsilonGreedyExplorer(greedy_selector_mock,
-                                                    random_selector_mock,
-                                                    env_info=None,
-                                                    config=config)
+        config = LinearDecayEpsilonGreedyExplorerConfig(
+            initial_epsilon=initial_epsilon, final_epsilon=final_epsilon, max_explore_steps=max_explore_steps
+        )
+        explorer = LinearDecayEpsilonGreedyExplorer(
+            greedy_selector_mock, random_selector_mock, env_info=None, config=config
+        )
         def expected_epsilon(step):
-            epsilon = initial_epsilon - \
-                (initial_epsilon - final_epsilon) / max_explore_steps * step
+            epsilon = initial_epsilon - (initial_epsilon - final_epsilon) / max_explore_steps * step
             return max(epsilon, final_epsilon)
         assert np.isclose(explorer._compute_epsilon(1), expected_epsilon(1))
@@ -94,20 +88,23 @@ def expected_epsilon(step):
         assert np.isclose(explorer._compute_epsilon(100), expected_epsilon(100))
         assert explorer._compute_epsilon(200) == final_epsilon
-    def test_return_explorer_info(self, ):
+    def test_return_explorer_info(
+        self,
+    ):
         initial_epsilon = 1.0
         final_epsilon = 0.0
         max_explore_steps = 100
         greedy_selector_mock = mock.MagicMock(return_value=(1, {}))
         random_selector_mock = mock.MagicMock(return_value=(2, {}))
-        config = LinearDecayEpsilonGreedyExplorerConfig(initial_epsilon=initial_epsilon,
-                                                        final_epsilon=final_epsilon,
-                                                        max_explore_steps=max_explore_steps,
-                                                        append_explorer_info=True)
-        explorer = LinearDecayEpsilonGreedyExplorer(greedy_selector_mock,
-                                                    random_selector_mock,
-                                                    env_info=None,
-                                                    config=config)
+        config = LinearDecayEpsilonGreedyExplorerConfig(
+            initial_epsilon=initial_epsilon,
+            final_epsilon=final_epsilon,
+            max_explore_steps=max_explore_steps,
+            append_explorer_info=True,
+        )
+        explorer = LinearDecayEpsilonGreedyExplorer(
+            greedy_selector_mock, random_selector_mock, env_info=None, config=config
+        )
         action, action_info = explorer.action(50, np.random.rand(5))
@@ -124,5 +121,5 @@ def test_return_explorer_info(self, ):
         assert action_info["explore_rate"] == 0.5
-if __name__ == '__main__':
+if __name__ == "__main__":
     pytest.main()
diff --git a/tests/environment_explorers/test_gaussian.py b/tests/environment_explorers/test_gaussian.py
index 217041a8..6ef43cf7 100644
--- a/tests/environment_explorers/test_gaussian.py
+++ b/tests/environment_explorers/test_gaussian.py
@@ -1,5 +1,5 @@
 # Copyright 2020,2021 Sony Corporation.
-# Copyright 2021,2022,2023 Sony Group Corporation.
+# Copyright 2021,2022,2023,2024 Sony Group Corporation.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,51 +20,38 @@ class TestRandomGaussianActionStrategy(object): - @pytest.mark.parametrize('clip_low', np.arange(start=-1.0, stop=0.0, step=0.25)) + @pytest.mark.parametrize("clip_low", np.arange(start=-1.0, stop=0.0, step=0.25)) @pytest.mark.parametrize("clip_high", np.arange(start=0.0, stop=1.0, step=0.25)) @pytest.mark.parametrize("sigma", np.arange(start=0.01, stop=5.0, step=1.0)) def test_random_gaussian_action_selection(self, clip_low, clip_high, sigma): - def policy_action_selector(state, *, begin_of_episode=False): - return np.zeros(shape=state.shape), {'test': 'success'} - config = GaussianExplorerConfig( - action_clip_low=clip_low, - action_clip_high=clip_high, - sigma=sigma - ) - explorer = GaussianExplorer( - env_info=None, - policy_action_selector=policy_action_selector, - config=config - ) + def policy_action_selector(state, *, begin_of_episode=False): + return np.zeros(shape=state.shape), {"test": "success"} + + config = GaussianExplorerConfig(action_clip_low=clip_low, action_clip_high=clip_high, sigma=sigma) + explorer = GaussianExplorer(env_info=None, policy_action_selector=policy_action_selector, config=config) steps = 1 state = np.empty(shape=(1, 4)) action, info = explorer.action(steps, state) assert np.all(clip_low <= action) and np.all(action <= clip_high) - assert info['test'] == 'success' + assert info["test"] == "success" @pytest.mark.parametrize("sigma", np.arange(start=0.01, stop=5.0, step=1.0)) def test_random_gaussian_without_clipping(self, sigma): - def policy_action_selector(state, *, begin_of_episode=False): - return np.zeros(shape=state.shape), {'test': 'success'} + def policy_action_selector(state, *, begin_of_episode=False): + return np.zeros(shape=state.shape), {"test": "success"} - config = GaussianExplorerConfig( - sigma=sigma - ) - explorer = GaussianExplorer( - env_info=None, - policy_action_selector=policy_action_selector, - config=config - ) + config = GaussianExplorerConfig(sigma=sigma) + explorer = GaussianExplorer(env_info=None, policy_action_selector=policy_action_selector, config=config) steps = 1 state = np.empty(shape=(1, 4)) action, info = explorer.action(steps, state) assert not np.allclose(action, np.zeros(action.shape)) - assert info['test'] == 'success' + assert info["test"] == "success" -if __name__ == '__main__': +if __name__ == "__main__": pytest.main() diff --git a/tests/environments/test_amp_env.py b/tests/environments/test_amp_env.py index e32dbdf0..b035e647 100644 --- a/tests/environments/test_amp_env.py +++ b/tests/environments/test_amp_env.py @@ -39,7 +39,7 @@ def expert_experience(self, state, reward, done, info): return self._dummy_expert_experience -class TestAMPEnv(): +class TestAMPEnv: def test_step(self): state = np.array([1.1]) reward = 5.0 @@ -55,8 +55,14 @@ def test_step(self): expert_reward = 0.5 expert_non_terminal = 1.0 expert_info = {} - dummy_expert_experience = (expert_state, expert_action, expert_reward, - expert_non_terminal, expert_next_state, expert_info) + dummy_expert_experience = ( + expert_state, + expert_action, + expert_reward, + expert_non_terminal, + expert_next_state, + expert_info, + ) env = DummyAMPEnv(dummy_transition, dummy_task_result, dummy_is_valid_episode, dummy_expert_experience) env.reset() @@ -101,7 +107,7 @@ def expert_experience(self, state, reward, done, info): return self._dummy_expert_experience -class TestAMPGoalEnv(): +class TestAMPGoalEnv: def test_step(self): state = np.array([1.1]) reward = 5.0 @@ -117,8 +123,14 @@ def test_step(self): expert_reward = 0.5 expert_non_terminal 
= 1.0 expert_info = {} - dummy_expert_experience = (expert_state, expert_action, expert_reward, - expert_non_terminal, expert_next_state, expert_info) + dummy_expert_experience = ( + expert_state, + expert_action, + expert_reward, + expert_non_terminal, + expert_next_state, + expert_info, + ) env = DummyAMPEnv(dummy_transition, dummy_task_result, dummy_is_valid_episode, dummy_expert_experience) env.reset() diff --git a/tests/environments/test_env_info.py b/tests/environments/test_env_info.py index 71477754..4506a03d 100644 --- a/tests/environments/test_env_info.py +++ b/tests/environments/test_env_info.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,13 +22,13 @@ class TestEnvInfo(object): - @pytest.mark.parametrize("max_episode_steps", [None, 100, 10000, float('inf')]) + @pytest.mark.parametrize("max_episode_steps", [None, 100, 10000, float("inf")]) def test_spec_max_episode_steps(self, max_episode_steps): dummy_env = E.DummyContinuous(max_episode_steps=max_episode_steps) env_info = EnvironmentInfo.from_env(dummy_env) if max_episode_steps is None: - assert env_info.max_episode_steps == float('inf') + assert env_info.max_episode_steps == float("inf") else: assert env_info.max_episode_steps == max_episode_steps @@ -102,7 +102,7 @@ def test_action_shape_discrete(self): dummy_env = E.DummyDiscreteImg() env_info = EnvironmentInfo.from_env(dummy_env) - assert env_info.action_shape == (1, ) + assert env_info.action_shape == (1,) def test_action_shape_tuple_continuous(self): dummy_env = E.DummyTupleContinuous() @@ -114,14 +114,15 @@ def test_action_shape_tuple_discrete(self): dummy_env = E.DummyTupleDiscrete() env_info = EnvironmentInfo.from_env(dummy_env) - assert env_info.action_shape == ((1, ), (1, )) + assert env_info.action_shape == ((1,), (1,)) def test_action_shape_tuple_mixed(self): dummy_env = E.DummyTupleMixed() env_info = EnvironmentInfo.from_env(dummy_env) - assert env_info.action_shape == tuple(space.shape if isinstance( - space, gym.spaces.Box) else (1, ) for space in dummy_env.action_space) + assert env_info.action_shape == tuple( + space.shape if isinstance(space, gym.spaces.Box) else (1,) for space in dummy_env.action_space + ) def test_action_dim_continuous(self): dummy_env = E.DummyContinuous() @@ -151,14 +152,15 @@ def test_action_dim_tuple_mixed(self): dummy_env = E.DummyTupleMixed() env_info = EnvironmentInfo.from_env(dummy_env) - assert env_info.action_dim == tuple(np.prod(space.shape) if isinstance( - space, gym.spaces.Box) else space.n for space in env_info.action_space) + assert env_info.action_dim == tuple( + np.prod(space.shape) if isinstance(space, gym.spaces.Box) else space.n for space in env_info.action_space + ) def test_state_shape_discrete(self): dummy_env = E.DummyDiscrete() env_info = EnvironmentInfo.from_env(dummy_env) - assert env_info.state_shape == (1, ) + assert env_info.state_shape == (1,) def test_state_shape_continuous(self): dummy_env = E.DummyContinuous() @@ -176,7 +178,7 @@ def test_state_shape_tuple_discrete(self): dummy_env = E.DummyTupleDiscrete() env_info = EnvironmentInfo.from_env(dummy_env) - assert env_info.state_shape == ((1, ), (1, )) + assert env_info.state_shape == ((1,), (1,)) def test_state_shape_tuple_continuous(self): dummy_env = E.DummyTupleContinuous() @@ -188,8 +190,9 @@ def 
test_state_shape_tuple_mixed(self): dummy_env = E.DummyTupleMixed() env_info = EnvironmentInfo.from_env(dummy_env) - assert env_info.state_shape == tuple(space.shape if isinstance( - space, gym.spaces.Box) else (1, ) for space in dummy_env.observation_space) + assert env_info.state_shape == tuple( + space.shape if isinstance(space, gym.spaces.Box) else (1,) for space in dummy_env.observation_space + ) def test_state_dim_discrete(self): dummy_env = E.DummyDiscrete() @@ -225,8 +228,10 @@ def test_state_dim_tuple_mixed(self): dummy_env = E.DummyTupleMixed() env_info = EnvironmentInfo.from_env(dummy_env) - assert env_info.state_dim == tuple(np.prod(space.shape) if isinstance( - space, gym.spaces.Box) else space.n for space in env_info.observation_space) + assert env_info.state_dim == tuple( + np.prod(space.shape) if isinstance(space, gym.spaces.Box) else space.n + for space in env_info.observation_space + ) if __name__ == "__main__": diff --git a/tests/environments/test_gym_utils.py b/tests/environments/test_gym_utils.py index 64370581..7daf663e 100644 --- a/tests/environments/test_gym_utils.py +++ b/tests/environments/test_gym_utils.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,35 +17,41 @@ import pytest import nnabla_rl.environments as E -from nnabla_rl.environments.gym_utils import (extract_max_episode_steps, get_space_dim, get_space_shape, - is_same_space_type) +from nnabla_rl.environments.gym_utils import ( + extract_max_episode_steps, + get_space_dim, + get_space_shape, + is_same_space_type, +) -class TestGymUtils(): +class TestGymUtils: def test_get_tuple_space_shape(self): - tuple_space = gym.spaces.Tuple((gym.spaces.Box(low=0.0, high=1.0, shape=(2, )), - gym.spaces.Box(low=0.0, high=1.0, shape=(3, )))) + tuple_space = gym.spaces.Tuple( + (gym.spaces.Box(low=0.0, high=1.0, shape=(2,)), gym.spaces.Box(low=0.0, high=1.0, shape=(3,))) + ) with pytest.raises(ValueError): get_space_shape(tuple_space) def test_get_box_space_shape(self): - shape = (5, ) + shape = (5,) box_space_shape = gym.spaces.Box(low=0.0, high=1.0, shape=shape) actual_shape = get_space_shape(box_space_shape) assert actual_shape == shape def test_get_discrete_space_shape(self): - shape = (1, ) + shape = (1,) discrete_space_shape = gym.spaces.Discrete(4) actual_shape = get_space_shape(discrete_space_shape) assert actual_shape == shape def test_get_tuple_space_dim(self): - tuple_space = gym.spaces.Tuple((gym.spaces.Box(low=0.0, high=1.0, shape=(2, )), - gym.spaces.Box(low=0.0, high=1.0, shape=(3, )))) + tuple_space = gym.spaces.Tuple( + (gym.spaces.Box(low=0.0, high=1.0, shape=(2,)), gym.spaces.Box(low=0.0, high=1.0, shape=(3,))) + ) with pytest.raises(ValueError): get_space_dim(tuple_space) @@ -62,7 +68,7 @@ def test_get_discrete_space_dim(self): discrete_space_shape = gym.spaces.Discrete(dim) actual_dim = get_space_shape(discrete_space_shape) - assert actual_dim == (1, ) + assert actual_dim == (1,) def test_extract_None_max_episode_steps(self): env = E.DummyContinuous(max_episode_steps=None) @@ -78,7 +84,7 @@ def test_extract_max_episode_steps(self): assert actual_max_episode_steps == max_episode_steps def test_is_same_space_type_box(self): - box_space = gym.spaces.Box(low=0.0, high=1.0, shape=(4, )) + box_space = gym.spaces.Box(low=0.0, high=1.0, shape=(4,)) tuple_box_space = gym.spaces.Tuple((box_space, 
box_space)) assert is_same_space_type(tuple_box_space, gym.spaces.Box) @@ -90,7 +96,7 @@ def test_is_same_space_type_discrete(self): assert is_same_space_type(tuple_discrete_space, gym.spaces.Discrete) def test_is_same_space_type_mixed(self): - box_space = gym.spaces.Box(low=0.0, high=1.0, shape=(4, )) + box_space = gym.spaces.Box(low=0.0, high=1.0, shape=(4,)) discrete_space = gym.spaces.Discrete(4) tuple_mixed_space = gym.spaces.Tuple((box_space, discrete_space)) diff --git a/tests/environments/wrappers/test_atari.py b/tests/environments/wrappers/test_atari.py index a8606889..7a930acd 100644 --- a/tests/environments/wrappers/test_atari.py +++ b/tests/environments/wrappers/test_atari.py @@ -1,5 +1,4 @@ - -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/environments/wrappers/test_common.py b/tests/environments/wrappers/test_common.py index 5cd0fc23..9b2f202c 100644 --- a/tests/environments/wrappers/test_common.py +++ b/tests/environments/wrappers/test_common.py @@ -24,7 +24,7 @@ class DummyNestedTupleStateEnv(gym.Env): def __init__(self, observation_space) -> None: super().__init__() - self.action_space = gym.spaces.Box(low=0.0, high=1.0, shape=(4, )) + self.action_space = gym.spaces.Box(low=0.0, high=1.0, shape=(4,)) self.observation_space = observation_space def reset(self): @@ -132,11 +132,11 @@ def test_timestep_as_state_env_discrete(self): assert next_state[1] == 2 def test_flatten_nested_tuple_state(self): - box_space_list = [gym.spaces.Box(low=0.0, high=1.0, shape=(i, )) for i in range(5)] + box_space_list = [gym.spaces.Box(low=0.0, high=1.0, shape=(i,)) for i in range(5)] observation_space = gym.spaces.Tuple( - [gym.spaces.Tuple(box_space_list[0:3]), - gym.spaces.Tuple(box_space_list[3:])]) + [gym.spaces.Tuple(box_space_list[0:3]), gym.spaces.Tuple(box_space_list[3:])] + ) env = DummyNestedTupleStateEnv(observation_space) env = FlattenNestedTupleStateWrapper(env) diff --git a/tests/hooks/test_computational_graph_hook.py b/tests/hooks/test_computational_graph_hook.py index d68a23f6..234a0c54 100644 --- a/tests/hooks/test_computational_graph_hook.py +++ b/tests/hooks/test_computational_graph_hook.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ from nnabla_rl.hooks import TrainingGraphHook -class TestComputationalGraphHook(): +class TestComputationalGraphHook: def setup_method(self, method): nn.clear_parameters() diff --git a/tests/hooks/test_evaluation_hook.py b/tests/hooks/test_evaluation_hook.py index e22b925c..8a6e32bb 100644 --- a/tests/hooks/test_evaluation_hook.py +++ b/tests/hooks/test_evaluation_hook.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -20,7 +20,7 @@ from nnabla_rl.hooks import EvaluationHook -class TestEvaluationHook(): +class TestEvaluationHook: def test_call(self): dummy_env = E.DummyContinuous() @@ -31,9 +31,7 @@ def test_call(self): mock_writer = mock.MagicMock() - hook = EvaluationHook(dummy_env, - evaluator=mock_evaluator, - writer=mock_writer) + hook = EvaluationHook(dummy_env, evaluator=mock_evaluator, writer=mock_writer) hook(dummy_algorithm) diff --git a/tests/hooks/test_iteration_num_hook.py b/tests/hooks/test_iteration_num_hook.py index 11da3cbe..cf86ec04 100644 --- a/tests/hooks/test_iteration_num_hook.py +++ b/tests/hooks/test_iteration_num_hook.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,13 +21,13 @@ from nnabla_rl.logger import logger -class TestIterationStateHook(): +class TestIterationStateHook: def test_call(self): dummy_algorithm = mock.MagicMock() hook = IterationNumHook(timing=1) - with mock.patch.object(logger, 'info') as mock_logger: + with mock.patch.object(logger, "info") as mock_logger: for i in range(10): dummy_algorithm.iteration_num = i hook(dummy_algorithm) diff --git a/tests/hooks/test_iteration_state_hook.py b/tests/hooks/test_iteration_state_hook.py index cd4623ff..7fd87c3a 100644 --- a/tests/hooks/test_iteration_state_hook.py +++ b/tests/hooks/test_iteration_state_hook.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,14 +19,14 @@ from nnabla_rl.writer import Writer -class TestIterationStateHook(): +class TestIterationStateHook: def test_call(self): dummy_algorithm = mock.MagicMock() test_latest_iteration_state = {} - test_latest_iteration_state['scalar'] = {} - test_latest_iteration_state['histogram'] = {} - test_latest_iteration_state['image'] = {} + test_latest_iteration_state["scalar"] = {} + test_latest_iteration_state["histogram"] = {} + test_latest_iteration_state["image"] = {} dummy_algorithm.iteration_num = 1 dummy_algorithm.latest_iteration_state = test_latest_iteration_state @@ -40,9 +40,10 @@ def test_call(self): hook(dummy_algorithm) - writer.write_scalar.assert_called_once_with(dummy_algorithm.iteration_num, - test_latest_iteration_state['scalar']) - writer.write_histogram.assert_called_once_with(dummy_algorithm.iteration_num, - test_latest_iteration_state['histogram']) - writer.write_image.assert_called_once_with(dummy_algorithm.iteration_num, - test_latest_iteration_state['image']) + writer.write_scalar.assert_called_once_with( + dummy_algorithm.iteration_num, test_latest_iteration_state["scalar"] + ) + writer.write_histogram.assert_called_once_with( + dummy_algorithm.iteration_num, test_latest_iteration_state["histogram"] + ) + writer.write_image.assert_called_once_with(dummy_algorithm.iteration_num, test_latest_iteration_state["image"]) diff --git a/tests/model_trainers/test_model_trainer.py b/tests/model_trainers/test_model_trainer.py index ce13dcca..bca99394 100644 --- a/tests/model_trainers/test_model_trainer.py +++ b/tests/model_trainers/test_model_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. 
+# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,20 +31,20 @@ def __call__(self, s): class EmptyRnnModel(EmptyModel): def __init__(self, scope_name: str): super().__init__(scope_name) - self._internal_state_shape = (10, ) + self._internal_state_shape = (10,) self._fake_internal_state = None def is_recurrent(self) -> bool: return True def internal_state_shapes(self): - return {'fake': self._internal_state_shape} + return {"fake": self._internal_state_shape} def set_internal_states(self, states): - self._fake_internal_state = states['fake'] + self._fake_internal_state = states["fake"] def get_internal_states(self): - return {'fake': self._fake_internal_state} + return {"fake": self._fake_internal_state} def __call__(self, s): self._fake_internal_state = self._fake_internal_state * 2 @@ -56,16 +56,16 @@ def setup_method(self, method): nn.clear_parameters() def test_assert_no_duplicate_model_without_duplicates(self): - models = [EmptyModel('model1'), EmptyModel('model2'), EmptyModel('model3')] + models = [EmptyModel("model1"), EmptyModel("model2"), EmptyModel("model3")] ModelTrainer._assert_no_duplicate_model(models) def test_assert_no_duplicate_model_with_duplicates(self): - duplicate_models = [EmptyModel('model1'), EmptyModel('model2'), EmptyModel('model1'), EmptyModel('model3')] + duplicate_models = [EmptyModel("model1"), EmptyModel("model2"), EmptyModel("model1"), EmptyModel("model3")] with pytest.raises(AssertionError): ModelTrainer._assert_no_duplicate_model(duplicate_models) def test_rnn_support_with_reset_on_terminal(self): - scope_name = 'test' + scope_name = "test" test_model = EmptyRnnModel(scope_name) batch_size = 3 num_steps = 3 @@ -74,10 +74,12 @@ def test_rnn_support_with_reset_on_terminal(self): for _ in range(num_steps): train_rnn_states = self._create_fake_internal_states(batch_size, test_model) non_terminal = self._create_fake_non_terminals(batch_size) - training_variables = TrainingVariables(batch_size, - non_terminal=non_terminal, - rnn_states=train_rnn_states, - next_step_variables=training_variables) + training_variables = TrainingVariables( + batch_size, + non_terminal=non_terminal, + rnn_states=train_rnn_states, + next_step_variables=training_variables, + ) prev_rnn_states = {} config = TrainerConfig(unroll_steps=num_steps, reset_on_terminal=True) @@ -95,8 +97,9 @@ def test_rnn_support_with_reset_on_terminal(self): prev_non_terminal = variables.prev_step_variables.non_terminal expected_states = {} for key in prev_states.keys(): - expected_state = prev_states[key] * prev_non_terminal + \ - (1 - prev_non_terminal) * train_states[key] + expected_state = ( + prev_states[key] * prev_non_terminal + (1 - prev_non_terminal) * train_states[key] + ) expected_states[key] = expected_state self._assert_have_same_states(actual_states, expected_states) @@ -108,7 +111,7 @@ def test_rnn_support_with_reset_on_terminal(self): self._assert_have_same_states(actual_states, expected_states) def test_rnn_support_without_reset_on_terminal(self): - scope_name = 'test' + scope_name = "test" test_model = EmptyRnnModel(scope_name) batch_size = 3 num_steps = 3 @@ -117,10 +120,12 @@ def test_rnn_support_without_reset_on_terminal(self): for _ in range(num_steps): train_rnn_states = self._create_fake_internal_states(batch_size, test_model) non_terminal = self._create_fake_non_terminals(batch_size) - training_variables = TrainingVariables(batch_size, - 
non_terminal=non_terminal, - rnn_states=train_rnn_states, - next_step_variables=training_variables) + training_variables = TrainingVariables( + batch_size, + non_terminal=non_terminal, + rnn_states=train_rnn_states, + next_step_variables=training_variables, + ) prev_rnn_states = {} config = TrainerConfig(unroll_steps=num_steps, reset_on_terminal=False) @@ -148,7 +153,7 @@ def test_rnn_support_without_reset_on_terminal(self): self._assert_have_same_states(actual_states, expected_states) def test_rnn_support_with_non_rnn_model(self): - scope_name = 'test' + scope_name = "test" test_model = EmptyModel(scope_name) batch_size = 3 num_steps = 3 @@ -156,9 +161,9 @@ def test_rnn_support_with_non_rnn_model(self): training_variables = None for _ in range(num_steps): non_terminal = self._create_fake_non_terminals(batch_size) - training_variables = TrainingVariables(batch_size, - non_terminal=non_terminal, - next_step_variables=training_variables) + training_variables = TrainingVariables( + batch_size, non_terminal=non_terminal, next_step_variables=training_variables + ) train_rnn_states = {} prev_rnn_states = {} config = TrainerConfig(num_steps) diff --git a/tests/model_trainers/test_policy_trainers.py b/tests/model_trainers/test_policy_trainers.py index 7cbb3b08..a05dc4e4 100644 --- a/tests/model_trainers/test_policy_trainers.py +++ b/tests/model_trainers/test_policy_trainers.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,9 +29,11 @@ from nnabla_rl.model_trainers.model_trainer import LossIntegration from nnabla_rl.model_trainers.policy.dpg_policy_trainer import DPGPolicyTrainer from nnabla_rl.model_trainers.policy.soft_policy_trainer import AdjustableTemperature, SoftPolicyTrainer -from nnabla_rl.model_trainers.policy.trpo_policy_trainer import (_concat_network_params_in_ndarray, - _hessian_vector_product, - _update_network_params_by_flat_params) +from nnabla_rl.model_trainers.policy.trpo_policy_trainer import ( + _concat_network_params_in_ndarray, + _hessian_vector_product, + _update_network_params_by_flat_params, +) from nnabla_rl.models import TD3Policy, TD3QFunction from nnabla_rl.models.mujoco.policies import SACPolicy from nnabla_rl.models.mujoco.q_functions import SACQFunction @@ -43,20 +45,20 @@ class DeterministicRnnPolicy(DeterministicPolicy): def __init__(self, scope_name: str): super().__init__(scope_name) - self._internal_state_shape = (10, ) + self._internal_state_shape = (10,) self._fake_internal_state = None def is_recurrent(self) -> bool: return True def internal_state_shapes(self): - return {'fake': self._internal_state_shape} + return {"fake": self._internal_state_shape} def set_internal_states(self, states): - self._fake_internal_state = states['fake'] + self._fake_internal_state = states["fake"] def get_internal_states(self): - return {'fake': self._fake_internal_state} + return {"fake": self._fake_internal_state} def pi(self, s): self._fake_internal_state = self._fake_internal_state * 2 @@ -66,25 +68,26 @@ def pi(self, s): class StochasticRnnPolicy(StochasticPolicy): def __init__(self, scope_name: str): super().__init__(scope_name) - self._internal_state_shape = (10, ) + self._internal_state_shape = (10,) self._fake_internal_state = None def is_recurrent(self) -> bool: return True def internal_state_shapes(self): - return 
{'fake': self._internal_state_shape} + return {"fake": self._internal_state_shape} def set_internal_states(self, states): - self._fake_internal_state = states['fake'] + self._fake_internal_state = states["fake"] def get_internal_states(self): - return {'fake': self._fake_internal_state} + return {"fake": self._fake_internal_state} def pi(self, s): self._fake_internal_state = self._fake_internal_state * 2 - return Gaussian(mean=nn.Variable.from_numpy_array(np.zeros(s.shape)), - ln_var=nn.Variable.from_numpy_array(np.ones(s.shape))) + return Gaussian( + mean=nn.Variable.from_numpy_array(np.zeros(s.shape)), ln_var=nn.Variable.from_numpy_array(np.ones(s.shape)) + ) class TrainerTest(metaclass=ABCMeta): @@ -97,20 +100,19 @@ class TestBEARPolicyTrainer(TrainerTest): def test_compute_gaussian_mmd(self): def gaussian_kernel(x): return x**2 - samples1 = np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], - [[2, 2, 2], [2, 2, 2], [3, 3, 3]]], dtype=np.float32) - samples2 = np.array([[[0, 0, 0], [1, 1, 1]], - [[1, 2, 3], [1, 1, 1]]], dtype=np.float32) + + samples1 = np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[2, 2, 2], [2, 2, 2], [3, 3, 3]]], dtype=np.float32) + samples2 = np.array([[[0, 0, 0], [1, 1, 1]], [[1, 2, 3], [1, 1, 1]]], dtype=np.float32) samples1_var = nn.Variable(samples1.shape) samples1_var.d = samples1 samples2_var = nn.Variable(samples2.shape) samples2_var.d = samples2 actual_mmd = MT.policy_trainers.bear_policy_trainer._compute_gaussian_mmd( - samples1=samples1_var, samples2=samples2_var, sigma=20.0) + samples1=samples1_var, samples2=samples2_var, sigma=20.0 + ) actual_mmd.forward() - expected_mmd = self._compute_mmd( - samples1, samples2, sigma=20.0, kernel=gaussian_kernel) + expected_mmd = self._compute_mmd(samples1, samples2, sigma=20.0, kernel=gaussian_kernel) assert actual_mmd.shape == (samples1.shape[0], 1) assert np.all(np.isclose(actual_mmd.d, expected_mmd)) @@ -118,20 +120,19 @@ def gaussian_kernel(x): def test_compute_laplacian_mmd(self): def laplacian_kernel(x): return np.abs(x) - samples1 = np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], - [[2, 2, 2], [2, 2, 2], [3, 3, 3]]], dtype=np.float32) - samples2 = np.array([[[0, 0, 0], [1, 1, 1]], - [[1, 2, 3], [1, 1, 1]]], dtype=np.float32) + + samples1 = np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[2, 2, 2], [2, 2, 2], [3, 3, 3]]], dtype=np.float32) + samples2 = np.array([[[0, 0, 0], [1, 1, 1]], [[1, 2, 3], [1, 1, 1]]], dtype=np.float32) samples1_var = nn.Variable(samples1.shape) samples1_var.d = samples1 samples2_var = nn.Variable(samples2.shape) samples2_var.d = samples2 actual_mmd = MT.policy_trainers.bear_policy_trainer._compute_laplacian_mmd( - samples1=samples1_var, samples2=samples2_var, sigma=20.0) + samples1=samples1_var, samples2=samples2_var, sigma=20.0 + ) actual_mmd.forward() - expected_mmd = self._compute_mmd( - samples1, samples2, sigma=20.0, kernel=laplacian_kernel) + expected_mmd = self._compute_mmd(samples1, samples2, sigma=20.0, kernel=laplacian_kernel) assert actual_mmd.shape == (samples1.shape[0], 1) assert np.all(np.isclose(actual_mmd.d, expected_mmd)) @@ -142,9 +143,7 @@ def _compute_mmd(self, samples1, samples2, sigma, kernel): diff_yy = self._compute_kernel_sum(samples2, samples2, sigma, kernel) n = samples1.shape[1] m = samples2.shape[1] - mmd = (diff_xx / (n*n) - - 2.0 * diff_xy / (n*m) + - diff_yy / (m*m)) + mmd = diff_xx / (n * n) - 2.0 * diff_xy / (n * m) + diff_yy / (m * m) mmd = np.sqrt(mmd + 1e-6) return mmd @@ -157,8 +156,8 @@ def _compute_kernel_sum(self, a, b, sigma, kernel): diff = 0.0 for k in 
range(a.shape[2]): # print(f'samples[{i}] - samples[{j}]={samples1[i]-samples1[j]}') - diff += kernel(a[index][i][k]-b[index][j][k]) - kernel_sum += np.exp(-diff/(2.0*sigma)) + diff += kernel(a[index][i][k] - b[index][j][k]) + kernel_sum += np.exp(-diff / (2.0 * sigma)) sums.append(kernel_sum) return np.reshape(np.array(sums), newshape=(len(sums), 1)) @@ -166,11 +165,11 @@ def _compute_kernel_sum(self, a, b, sigma, kernel): class TestComputeHessianVectorProduct(TrainerTest): def test_compute_hessian_vector_product_by_hand(self): state = nn.Variable((1, 2)) - output = NPF.affine(state, 1, w_init=NI.ConstantInitializer(value=1.), with_bias=False) + output = NPF.affine(state, 1, w_init=NI.ConstantInitializer(value=1.0), with_bias=False) loss = NF.sum(output**2) grads = nn.grad([loss], nn.get_parameters().values()) - flat_grads = grads[0].reshape((-1, )) + flat_grads = grads[0].reshape((-1,)) flat_grads.need_grad = True def compute_Ax(vec): @@ -180,14 +179,13 @@ def compute_Ax(vec): state.d = state_array flat_grads.forward() - actual = conjugate_gradient( - compute_Ax, flat_grads.d, max_iterations=1000) + actual = conjugate_gradient(compute_Ax, flat_grads.d, max_iterations=1000) H = np.array( - [[2*state_array[0, 0]**2, - 2*state_array[0, 0]*state_array[0, 1]], - [2*state_array[0, 0]*state_array[0, 1], - 2*state_array[0, 1]**2]] + [ + [2 * state_array[0, 0] ** 2, 2 * state_array[0, 0] * state_array[0, 1]], + [2 * state_array[0, 0] * state_array[0, 1], 2 * state_array[0, 1] ** 2], + ] ) expected = np.matmul(np.linalg.pinv(H), flat_grads.d.reshape(-1, 1)) @@ -195,8 +193,9 @@ def compute_Ax(vec): def test_compute_hessian_vector_product_by_hessian(self): state = nn.Variable((1, 2)) - output = NPF.affine(state, 1, w_init=NI.ConstantInitializer( - value=1.), b_init=NI.ConstantInitializer(value=1.)) + output = NPF.affine( + state, 1, w_init=NI.ConstantInitializer(value=1.0), b_init=NI.ConstantInitializer(value=1.0) + ) loss = NF.sum(output**2) grads = nn.grad([loss], nn.get_parameters().values()) @@ -210,13 +209,11 @@ def compute_Ax(vec): state.d = state_array flat_grads.forward() - actual = conjugate_gradient( - compute_Ax, flat_grads.d, max_iterations=1000) + actual = conjugate_gradient(compute_Ax, flat_grads.d, max_iterations=1000) hessian = compute_hessian(loss, nn.get_parameters().values()) - expected = np.matmul(np.linalg.pinv(hessian), - flat_grads.d.reshape(-1, 1)) + expected = np.matmul(np.linalg.pinv(hessian), flat_grads.d.reshape(-1, 1)) assert expected == pytest.approx(actual.reshape(-1, 1), abs=1e-5) @@ -259,34 +256,34 @@ class TestDPGPolicyTrainer(TrainerTest): def setup_method(self, method): nn.clear_parameters() - @pytest.mark.parametrize('unroll_steps', [1, 2]) - @pytest.mark.parametrize('burn_in_steps', [0, 1, 2]) - @pytest.mark.parametrize('loss_integration', [LossIntegration.LAST_TIMESTEP_ONLY, LossIntegration.ALL_TIMESTEPS]) + @pytest.mark.parametrize("unroll_steps", [1, 2]) + @pytest.mark.parametrize("burn_in_steps", [0, 1, 2]) + @pytest.mark.parametrize("loss_integration", [LossIntegration.LAST_TIMESTEP_ONLY, LossIntegration.ALL_TIMESTEPS]) def test_with_non_rnn_model(self, unroll_steps, burn_in_steps, loss_integration): env_info = EnvironmentInfo.from_env(DummyContinuous()) - policy = TD3Policy('stub_pi', action_dim=env_info.action_dim, max_action_value=1) - train_q = TD3QFunction('stub_q', None) + policy = TD3Policy("stub_pi", action_dim=env_info.action_dim, max_action_value=1) + train_q = TD3QFunction("stub_q", None) # Using DQN Q trainer as representative trainer - config = 
MT.policy_trainers.DPGPolicyTrainerConfig(unroll_steps=unroll_steps, - burn_in_steps=burn_in_steps, - loss_integration=loss_integration) + config = MT.policy_trainers.DPGPolicyTrainerConfig( + unroll_steps=unroll_steps, burn_in_steps=burn_in_steps, loss_integration=loss_integration + ) DPGPolicyTrainer(models=policy, q_function=train_q, solvers={}, env_info=env_info, config=config) # pass: If no ecror occurs - @pytest.mark.parametrize('unroll_steps', [1, 2]) - @pytest.mark.parametrize('burn_in_steps', [0, 1, 2]) - @pytest.mark.parametrize('loss_integration', [LossIntegration.LAST_TIMESTEP_ONLY, LossIntegration.ALL_TIMESTEPS]) + @pytest.mark.parametrize("unroll_steps", [1, 2]) + @pytest.mark.parametrize("burn_in_steps", [0, 1, 2]) + @pytest.mark.parametrize("loss_integration", [LossIntegration.LAST_TIMESTEP_ONLY, LossIntegration.ALL_TIMESTEPS]) def test_with_rnn_model(self, unroll_steps, burn_in_steps, loss_integration): env_info = EnvironmentInfo.from_env(DummyContinuous()) - policy = DeterministicRnnPolicy('stub_pi') - train_q = TD3QFunction('stub_q', None) + policy = DeterministicRnnPolicy("stub_pi") + train_q = TD3QFunction("stub_q", None) # Using DQN Q trainer as representative trainer - config = MT.policy_trainers.DPGPolicyTrainerConfig(unroll_steps=unroll_steps, - burn_in_steps=burn_in_steps, - loss_integration=loss_integration) + config = MT.policy_trainers.DPGPolicyTrainerConfig( + unroll_steps=unroll_steps, burn_in_steps=burn_in_steps, loss_integration=loss_integration + ) DPGPolicyTrainer(models=policy, q_function=train_q, solvers={}, env_info=env_info, config=config) # pass: If no ecror occurs @@ -296,48 +293,58 @@ class TestSoftPolicyTrainer(TrainerTest): def setup_method(self, method): nn.clear_parameters() - @pytest.mark.parametrize('unroll_steps', [1, 2]) - @pytest.mark.parametrize('burn_in_steps', [0, 1, 2]) - @pytest.mark.parametrize('loss_integration', [LossIntegration.LAST_TIMESTEP_ONLY, LossIntegration.ALL_TIMESTEPS]) + @pytest.mark.parametrize("unroll_steps", [1, 2]) + @pytest.mark.parametrize("burn_in_steps", [0, 1, 2]) + @pytest.mark.parametrize("loss_integration", [LossIntegration.LAST_TIMESTEP_ONLY, LossIntegration.ALL_TIMESTEPS]) def test_with_non_rnn_model(self, unroll_steps, burn_in_steps, loss_integration): env_info = EnvironmentInfo.from_env(DummyContinuous()) - policy = SACPolicy('stub_pi', action_dim=env_info.action_dim) - train_q1 = SACQFunction('stub_q1', None) - train_q2 = SACQFunction('stub_q2', None) + policy = SACPolicy("stub_pi", action_dim=env_info.action_dim) + train_q1 = SACQFunction("stub_q1", None) + train_q2 = SACQFunction("stub_q2", None) # Using DQN Q trainer as representative trainer - config = MT.policy_trainers.SoftPolicyTrainerConfig(unroll_steps=unroll_steps, - burn_in_steps=burn_in_steps, - loss_integration=loss_integration, - fixed_temperature=True) - SoftPolicyTrainer(policy, - solvers={}, - q_functions=[train_q1, train_q2], - temperature=AdjustableTemperature('stub_t'), - temperature_solver=None, - env_info=env_info, config=config) + config = MT.policy_trainers.SoftPolicyTrainerConfig( + unroll_steps=unroll_steps, + burn_in_steps=burn_in_steps, + loss_integration=loss_integration, + fixed_temperature=True, + ) + SoftPolicyTrainer( + policy, + solvers={}, + q_functions=[train_q1, train_q2], + temperature=AdjustableTemperature("stub_t"), + temperature_solver=None, + env_info=env_info, + config=config, + ) # pass: If no ecror occurs - @pytest.mark.parametrize('unroll_steps', [1, 2]) - @pytest.mark.parametrize('burn_in_steps', [0, 1, 2]) 
- @pytest.mark.parametrize('loss_integration', [LossIntegration.LAST_TIMESTEP_ONLY, LossIntegration.ALL_TIMESTEPS]) + @pytest.mark.parametrize("unroll_steps", [1, 2]) + @pytest.mark.parametrize("burn_in_steps", [0, 1, 2]) + @pytest.mark.parametrize("loss_integration", [LossIntegration.LAST_TIMESTEP_ONLY, LossIntegration.ALL_TIMESTEPS]) def test_with_rnn_model(self, unroll_steps, burn_in_steps, loss_integration): env_info = EnvironmentInfo.from_env(DummyContinuous()) - policy = StochasticRnnPolicy('stub_pi') - train_q1 = SACQFunction('stub_q1', None) - train_q2 = SACQFunction('stub_q2', None) - config = MT.policy_trainers.SoftPolicyTrainerConfig(unroll_steps=unroll_steps, - burn_in_steps=burn_in_steps, - loss_integration=loss_integration, - fixed_temperature=True) - SoftPolicyTrainer(policy, - solvers={}, - q_functions=[train_q1, train_q2], - temperature=AdjustableTemperature('stub_t'), - temperature_solver=None, - env_info=env_info, config=config) + policy = StochasticRnnPolicy("stub_pi") + train_q1 = SACQFunction("stub_q1", None) + train_q2 = SACQFunction("stub_q2", None) + config = MT.policy_trainers.SoftPolicyTrainerConfig( + unroll_steps=unroll_steps, + burn_in_steps=burn_in_steps, + loss_integration=loss_integration, + fixed_temperature=True, + ) + SoftPolicyTrainer( + policy, + solvers={}, + q_functions=[train_q1, train_q2], + temperature=AdjustableTemperature("stub_t"), + temperature_solver=None, + env_info=env_info, + config=config, + ) # pass: If no ecror occurs @@ -345,8 +352,7 @@ def test_with_rnn_model(self, unroll_steps, burn_in_steps, loss_integration): class TestAdjustableTemperature(TrainerTest): def test_initial_temperature(self): initial_value = 5.0 - temperature = AdjustableTemperature( - scope_name='test', initial_value=initial_value) + temperature = AdjustableTemperature(scope_name="test", initial_value=initial_value) actual_value = temperature() actual_value.forward(clear_no_need_grad=True) @@ -354,7 +360,7 @@ def test_initial_temperature(self): # Create tempearture with random initial value nn.clear_parameters() - temperature = AdjustableTemperature(scope_name='test') + temperature = AdjustableTemperature(scope_name="test") actual_value = temperature() actual_value.forward(clear_no_need_grad=True) @@ -362,14 +368,13 @@ def test_initial_temperature(self): def test_temperature_is_adjustable(self): initial_value = 5.0 - temperature = AdjustableTemperature( - scope_name='test', initial_value=initial_value) + temperature = AdjustableTemperature(scope_name="test", initial_value=initial_value) solver = nn.solvers.Adam(alpha=1.0) solver.set_parameters(temperature.get_parameters()) value = temperature() - loss = 0.5 * NF.mean(value ** 2) + loss = 0.5 * NF.mean(value**2) loss.forward() solver.zero_grad() diff --git a/tests/model_trainers/test_q_value_trainers.py b/tests/model_trainers/test_q_value_trainers.py index 94f45ea0..e482d3f6 100644 --- a/tests/model_trainers/test_q_value_trainers.py +++ b/tests/model_trainers/test_q_value_trainers.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -57,42 +57,46 @@ class TestSquaredTDQFunctionTrainer(object): def setup_method(self, method): nn.clear_parameters() - @pytest.mark.parametrize('num_steps', [1, 2]) - @pytest.mark.parametrize('unroll_steps', [1, 2, 3]) - @pytest.mark.parametrize('burn_in_steps', [0, 1, 2]) - @pytest.mark.parametrize('loss_integration', [LossIntegration.LAST_TIMESTEP_ONLY, LossIntegration.ALL_TIMESTEPS]) + @pytest.mark.parametrize("num_steps", [1, 2]) + @pytest.mark.parametrize("unroll_steps", [1, 2, 3]) + @pytest.mark.parametrize("burn_in_steps", [0, 1, 2]) + @pytest.mark.parametrize("loss_integration", [LossIntegration.LAST_TIMESTEP_ONLY, LossIntegration.ALL_TIMESTEPS]) def test_with_non_rnn_model(self, num_steps, unroll_steps, burn_in_steps, loss_integration): env_info = EnvironmentInfo.from_env(DummyDiscreteImg()) env_info = EnvironmentInfo.from_env(DummyDiscreteImg()) - train_q = DQNQFunction('stub', n_action=env_info.action_dim) - target_q = train_q.deepcopy('stub2') + train_q = DQNQFunction("stub", n_action=env_info.action_dim) + target_q = train_q.deepcopy("stub2") # Using DQN Q trainer as representative trainer - config = MT.q_value_trainers.DQNQTrainerConfig(unroll_steps=unroll_steps, - burn_in_steps=burn_in_steps, - num_steps=num_steps, - loss_integration=loss_integration) + config = MT.q_value_trainers.DQNQTrainerConfig( + unroll_steps=unroll_steps, + burn_in_steps=burn_in_steps, + num_steps=num_steps, + loss_integration=loss_integration, + ) DQNQTrainer(train_functions=train_q, solvers={}, target_function=target_q, env_info=env_info, config=config) # pass: If no ecror occurs - @pytest.mark.parametrize('num_steps', [1, 2]) - @pytest.mark.parametrize('unroll_steps', [1, 2, 3]) - @pytest.mark.parametrize('burn_in_steps', [0, 1, 2]) - @pytest.mark.parametrize('loss_integration', [LossIntegration.LAST_TIMESTEP_ONLY, LossIntegration.ALL_TIMESTEPS]) + @pytest.mark.parametrize("num_steps", [1, 2]) + @pytest.mark.parametrize("unroll_steps", [1, 2, 3]) + @pytest.mark.parametrize("burn_in_steps", [0, 1, 2]) + @pytest.mark.parametrize("loss_integration", [LossIntegration.LAST_TIMESTEP_ONLY, LossIntegration.ALL_TIMESTEPS]) def test_with_rnn_model(self, num_steps, unroll_steps, burn_in_steps, loss_integration): env_info = EnvironmentInfo.from_env(DummyDiscreteImg()) env_info = EnvironmentInfo.from_env(DummyDiscreteImg()) - train_q = DRQNQFunction('stub', n_action=env_info.action_dim) - target_q = train_q.deepcopy('stub2') + train_q = DRQNQFunction("stub", n_action=env_info.action_dim) + target_q = train_q.deepcopy("stub2") # Using DQN Q trainer as representative trainer - config = MT.q_value_trainers.DQNQTrainerConfig(unroll_steps=unroll_steps, - burn_in_steps=burn_in_steps, - num_steps=num_steps, - loss_integration=loss_integration) + config = MT.q_value_trainers.DQNQTrainerConfig( + unroll_steps=unroll_steps, + burn_in_steps=burn_in_steps, + num_steps=num_steps, + loss_integration=loss_integration, + ) DQNQTrainer(train_functions=train_q, solvers={}, target_function=target_q, env_info=env_info, config=config) # pass: If no ecror occurs @@ -126,7 +130,7 @@ def test_n_step_setup_batch(self): num_steps = 5 env_info = EnvironmentInfo.from_env(DummyDiscreteImg()) - train_q = DQNQFunction('stub', n_action=env_info.action_dim) + train_q = DQNQFunction("stub", n_action=env_info.action_dim) config = MT.q_value_trainers.multi_step_trainer.MultiStepTrainerConfig(num_steps=num_steps) trainer = MultiStepTrainerForTest(models=train_q, solvers={}, env_info=env_info, config=config) @@ -157,9 +161,10 @@ def 
test_rnn_n_step_setup_batch(self): unroll_steps = 3 env_info = EnvironmentInfo.from_env(DummyDiscreteImg()) - train_q = DRQNQFunction('stub', n_action=env_info.action_dim) - config = MT.q_value_trainers.multi_step_trainer.MultiStepTrainerConfig(num_steps=num_steps, - unroll_steps=unroll_steps) + train_q = DRQNQFunction("stub", n_action=env_info.action_dim) + config = MT.q_value_trainers.multi_step_trainer.MultiStepTrainerConfig( + num_steps=num_steps, unroll_steps=unroll_steps + ) trainer = MultiStepTrainerForTest(models=train_q, solvers={}, env_info=env_info, config=config) batch = _generate_batch(batch_size, num_steps + unroll_steps - 1, env_info) @@ -186,10 +191,10 @@ def test_rnn_with_burnin_n_step_setup_batch(self): burn_in_steps = 2 env_info = EnvironmentInfo.from_env(DummyDiscreteImg()) - train_q = DRQNQFunction('stub', n_action=env_info.action_dim) - config = MT.q_value_trainers.multi_step_trainer.MultiStepTrainerConfig(num_steps=num_steps, - unroll_steps=unroll_steps, - burn_in_steps=burn_in_steps) + train_q = DRQNQFunction("stub", n_action=env_info.action_dim) + config = MT.q_value_trainers.multi_step_trainer.MultiStepTrainerConfig( + num_steps=num_steps, unroll_steps=unroll_steps, burn_in_steps=burn_in_steps + ) trainer = MultiStepTrainerForTest(models=train_q, solvers={}, env_info=env_info, config=config) batch = _generate_batch(batch_size, num_steps + unroll_steps + burn_in_steps - 1, env_info) @@ -227,16 +232,18 @@ def _expected_batch(self, training_batch: TrainingBatch, num_steps) -> TrainingB next_batch = next_batch.next_step_batch - return TrainingBatch(batch_size=training_batch.batch_size, - s_current=training_batch.s_current, - a_current=training_batch.a_current, - reward=n_step_reward, - gamma=n_step_gamma, - non_terminal=n_step_non_terminal, - s_next=n_step_state, - weight=training_batch.weight, - extra=training_batch.extra, - next_step_batch=None) + return TrainingBatch( + batch_size=training_batch.batch_size, + s_current=training_batch.s_current, + a_current=training_batch.a_current, + reward=n_step_reward, + gamma=n_step_gamma, + non_terminal=n_step_non_terminal, + s_next=n_step_state, + weight=training_batch.weight, + extra=training_batch.extra, + next_step_batch=None, + ) def _generate_batch(batch_size, num_steps, env_info) -> TrainingBatch: @@ -247,21 +254,23 @@ def _generate_batch(batch_size, num_steps, env_info) -> TrainingBatch: tail_batch: Optional[TrainingBatch] = None s_current = np.random.normal(size=(batch_size, state_dim)) for _ in range(num_steps): - a_current = np.random.randint(action_num, size=(batch_size, 1)).astype('float32') + a_current = np.random.randint(action_num, size=(batch_size, 1)).astype("float32") reward = np.random.normal(size=(batch_size, 1)) gamma = 0.99 - non_terminal = np.random.randint(2, size=(batch_size, 1)).astype('float32') + non_terminal = np.random.randint(2, size=(batch_size, 1)).astype("float32") s_next = np.random.normal(size=(batch_size, state_dim)) weight = np.random.normal(size=(batch_size, 1)) - batch = TrainingBatch(batch_size=batch_size, - s_current=s_current, - a_current=a_current, - reward=reward, - gamma=gamma, - non_terminal=non_terminal, - s_next=s_next, - weight=weight) + batch = TrainingBatch( + batch_size=batch_size, + s_current=s_current, + a_current=a_current, + reward=reward, + gamma=gamma, + non_terminal=non_terminal, + s_next=s_next, + weight=weight, + ) if head_batch is None: head_batch = batch if tail_batch is None: diff --git a/tests/models/atari/test_distributional_functions.py 
b/tests/models/atari/test_distributional_functions.py index 8668f044..627dc63a 100644 --- a/tests/models/atari/test_distributional_functions.py +++ b/tests/models/atari/test_distributional_functions.py @@ -1,4 +1,4 @@ -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,8 +16,11 @@ import pytest import nnabla as nn -from nnabla_rl.models.atari.distributional_functions import (C51ValueDistributionFunction, IQNQuantileFunction, - QRDQNQuantileDistributionFunction) +from nnabla_rl.models.atari.distributional_functions import ( + C51ValueDistributionFunction, + IQNQuantileFunction, + QRDQNQuantileDistributionFunction, +) def risk_measure_function(tau): @@ -33,11 +36,9 @@ def test_scope_name(self): scope_name = "test" v_min = 0 v_max = 10 - model = C51ValueDistributionFunction(scope_name=scope_name, - n_action=n_action, - n_atom=n_atom, - v_max=v_max, - v_min=v_min) + model = C51ValueDistributionFunction( + scope_name=scope_name, n_action=n_action, n_atom=n_atom, v_max=v_max, v_min=v_min + ) assert scope_name == model.scope_name @@ -50,11 +51,9 @@ def test_get_parameters(self): scope_name = "test" v_min = 0 v_max = 10 - model = C51ValueDistributionFunction(scope_name=scope_name, - n_action=n_action, - n_atom=n_atom, - v_max=v_max, - v_min=v_min) + model = C51ValueDistributionFunction( + scope_name=scope_name, n_action=n_action, n_atom=n_atom, v_max=v_max, v_min=v_min + ) # Fake input to initialize parameters input_state = nn.Variable.from_numpy_array(np.random.rand(1, *state_shape)) @@ -70,11 +69,9 @@ def test_probabilities(self): scope_name = "test" v_min = 0 v_max = 10 - model = C51ValueDistributionFunction(scope_name=scope_name, - n_action=n_action, - n_atom=n_atom, - v_max=v_max, - v_min=v_min) + model = C51ValueDistributionFunction( + scope_name=scope_name, n_action=n_action, n_atom=n_atom, v_max=v_max, v_min=v_min + ) input_state = nn.Variable.from_numpy_array(np.random.rand(1, *state_shape)) val = model.all_probs(input_state) @@ -91,9 +88,7 @@ def test_scope_name(self): n_action = 4 n_quantile = 10 scope_name = "test" - model = QRDQNQuantileDistributionFunction(scope_name=scope_name, - n_action=n_action, - n_quantile=n_quantile) + model = QRDQNQuantileDistributionFunction(scope_name=scope_name, n_action=n_action, n_quantile=n_quantile) assert scope_name == model.scope_name @@ -104,9 +99,7 @@ def test_get_parameters(self): n_action = 4 n_quantile = 10 scope_name = "test" - model = QRDQNQuantileDistributionFunction(scope_name=scope_name, - n_action=n_action, - n_quantile=n_quantile) + model = QRDQNQuantileDistributionFunction(scope_name=scope_name, n_action=n_action, n_quantile=n_quantile) # Fake input to initialize parameters input_state = nn.Variable.from_numpy_array(np.random.rand(1, *state_shape)) model.all_quantiles(input_state) @@ -119,9 +112,7 @@ def test_quantiles(self): n_action = 4 n_quantile = 10 scope_name = "test" - model = QRDQNQuantileDistributionFunction(scope_name=scope_name, - n_action=n_action, - n_quantile=n_quantile) + model = QRDQNQuantileDistributionFunction(scope_name=scope_name, n_action=n_action, n_quantile=n_quantile) input_state = nn.Variable.from_numpy_array(np.random.rand(1, *state_shape)) val = model.all_quantiles(input_state) @@ -138,11 +129,13 @@ def test_scope_name(self): embedding_dim = 10 scope_name = "test" K = 10 - model = IQNQuantileFunction(scope_name=scope_name, - 
n_action=n_action, - embedding_dim=embedding_dim, - K=K, - risk_measure_function=risk_measure_function) + model = IQNQuantileFunction( + scope_name=scope_name, + n_action=n_action, + embedding_dim=embedding_dim, + K=K, + risk_measure_function=risk_measure_function, + ) assert scope_name == model.scope_name @@ -154,11 +147,13 @@ def test_get_parameters(self): embedding_dim = 10 scope_name = "test" K = 10 - model = IQNQuantileFunction(scope_name=scope_name, - n_action=n_action, - embedding_dim=embedding_dim, - K=K, - risk_measure_function=risk_measure_function) + model = IQNQuantileFunction( + scope_name=scope_name, + n_action=n_action, + embedding_dim=embedding_dim, + K=K, + risk_measure_function=risk_measure_function, + ) # Fake input to initialize parameters input_state = nn.Variable.from_numpy_array(np.random.rand(1, *state_shape)) tau = nn.Variable.from_numpy_array(np.random.rand(1, 10)) @@ -173,11 +168,13 @@ def test_quantile_values(self): embedding_dim = 64 scope_name = "test" K = 10 - model = IQNQuantileFunction(scope_name=scope_name, - n_action=n_action, - embedding_dim=embedding_dim, - K=K, - risk_measure_function=risk_measure_function) + model = IQNQuantileFunction( + scope_name=scope_name, + n_action=n_action, + embedding_dim=embedding_dim, + K=K, + risk_measure_function=risk_measure_function, + ) # Initialize parameters n_sample = 5 @@ -196,11 +193,13 @@ def test_encode(self): embedding_dim = 64 scope_name = "test" K = 10 - model = IQNQuantileFunction(scope_name=scope_name, - n_action=n_action, - embedding_dim=embedding_dim, - K=K, - risk_measure_function=risk_measure_function) + model = IQNQuantileFunction( + scope_name=scope_name, + n_action=n_action, + embedding_dim=embedding_dim, + K=K, + risk_measure_function=risk_measure_function, + ) # Initialize parameters n_sample = 5 @@ -211,7 +210,7 @@ def test_encode(self): assert encoded.shape == (1, n_sample, 3136) assert np.alltrue(encoded[:, 1:, :] == encoded[:, 0, :]) - print('encoded: ', encoded) + print("encoded: ", encoded) def test_compute_embeddings(self): nn.clear_parameters() @@ -220,11 +219,13 @@ def test_compute_embeddings(self): embedding_dim = 64 scope_name = "test" K = 10 - model = IQNQuantileFunction(scope_name=scope_name, - n_action=n_action, - embedding_dim=embedding_dim, - K=K, - risk_measure_function=risk_measure_function) + model = IQNQuantileFunction( + scope_name=scope_name, + n_action=n_action, + embedding_dim=embedding_dim, + K=K, + risk_measure_function=risk_measure_function, + ) # Initialize parameters n_sample = 5 @@ -235,7 +236,7 @@ def test_compute_embeddings(self): params = model.get_parameters() for key, param in params.items(): - if 'embedding' not in key: + if "embedding" not in key: continue param.d = np.ones(param.shape) embedding.forward() @@ -244,8 +245,7 @@ def test_compute_embeddings(self): expected = [] for t in tau[0]: for _ in range(encode_dim): - embedding = np.sum( - [np.cos(np.pi * (i + 1) * t) for i in range(embedding_dim)]) + embedding = np.sum([np.cos(np.pi * (i + 1) * t) for i in range(embedding_dim)]) embedding += 1 # Add bias embedding = np.maximum(0.0, embedding) expected.append(embedding) diff --git a/tests/models/atari/test_q_functions.py b/tests/models/atari/test_q_functions.py index 2028b069..b01143e3 100644 --- a/tests/models/atari/test_q_functions.py +++ b/tests/models/atari/test_q_functions.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -54,8 +54,7 @@ def test_q(self): scope_name = "test" model = DQNQFunction(scope_name=scope_name, n_action=n_action) - input_state = nn.Variable.from_numpy_array( - np.random.rand(1, *state_shape)) + input_state = nn.Variable.from_numpy_array(np.random.rand(1, *state_shape)) input_action = nn.Variable.from_numpy_array(np.ones((1, 1))) outputs = nn.Variable.from_numpy_array(np.random.rand(1, n_action)) diff --git a/tests/models/mujoco/test_policies.py b/tests/models/mujoco/test_policies.py index c5fbcf55..a976249b 100644 --- a/tests/models/mujoco/test_policies.py +++ b/tests/models/mujoco/test_policies.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,9 +27,7 @@ def setup_method(self, method): def test_scope_name(self): action_dim = 5 scope_name = "test" - model = TD3Policy(scope_name=scope_name, - action_dim=action_dim, - max_action_value=1.0) + model = TD3Policy(scope_name=scope_name, action_dim=action_dim, max_action_value=1.0) assert scope_name == model.scope_name @@ -37,12 +35,9 @@ def test_get_parameters(self): state_dim = 5 action_dim = 5 scope_name = "test" - model = TD3Policy(scope_name=scope_name, - action_dim=action_dim, - max_action_value=1.0) + model = TD3Policy(scope_name=scope_name, action_dim=action_dim, max_action_value=1.0) # Fake input to initialize parameters - input_state = nn.Variable.from_numpy_array( - np.random.rand(1, state_dim)) + input_state = nn.Variable.from_numpy_array(np.random.rand(1, state_dim)) model.pi(input_state) assert len(model.get_parameters()) == 6 @@ -55,8 +50,7 @@ def setup_method(self, method): def test_scope_name(self): action_dim = 5 scope_name = "test" - model = SACPolicy(scope_name=scope_name, - action_dim=action_dim) + model = SACPolicy(scope_name=scope_name, action_dim=action_dim) assert scope_name == model.scope_name @@ -64,12 +58,10 @@ def test_get_parameters(self): state_dim = 5 action_dim = 5 scope_name = "test" - model = SACPolicy(scope_name=scope_name, - action_dim=action_dim) + model = SACPolicy(scope_name=scope_name, action_dim=action_dim) # Fake input to initialize parameters - input_state = nn.Variable.from_numpy_array( - np.random.rand(1, state_dim)) + input_state = nn.Variable.from_numpy_array(np.random.rand(1, state_dim)) model.pi(input_state) assert len(model.get_parameters()) == 6 diff --git a/tests/models/mujoco/test_q_functions.py b/tests/models/mujoco/test_q_functions.py index d510796e..1528f8e8 100644 --- a/tests/models/mujoco/test_q_functions.py +++ b/tests/models/mujoco/test_q_functions.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -38,8 +38,7 @@ def test_get_parameters(self): model = TD3QFunction(scope_name=scope_name) # Fake input to initialize parameters - input_state = nn.Variable.from_numpy_array( - np.random.rand(1, state_dim)) + input_state = nn.Variable.from_numpy_array(np.random.rand(1, state_dim)) input_action = nn.Variable.from_numpy_array(np.ones((1, action_dim))) model.q(input_state, input_action) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index e8e51186..c3233d7f 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -57,19 +57,19 @@ def test_get_parameters(self): assert len(model.get_parameters()) == 0 def test_deepcopy_without_model_initialization(self): - scope_name = 'src' + scope_name = "src" input_dim = 5 x = nn.Variable.from_numpy_array(np.empty(shape=(1, input_dim))) model = self._create_model_from_input(scope_name=scope_name, x=x) - new_scope_name = 'copied' + new_scope_name = "copied" copied = model.deepcopy(new_scope_name) assert type(copied) is type(model) assert len(model.get_parameters()) == 0 assert len(copied.get_parameters()) == 0 def test_deepcopy_model_is_same(self): - scope_name = 'src' + scope_name = "src" input_dim = 5 x = nn.Variable.from_numpy_array(np.ones(shape=(1, input_dim))) model = self._create_model_from_input(scope_name=scope_name, x=x) @@ -77,12 +77,12 @@ def test_deepcopy_model_is_same(self): # Call once to create params model(x) - new_scope_name = 'copied' + new_scope_name = "copied" copied = model.deepcopy(new_scope_name) assert type(copied) is type(model) def test_deepcopy_model_parameters_are_not_shared(self): - scope_name = 'src' + scope_name = "src" input_dim = 5 x = nn.Variable.from_numpy_array(np.ones(shape=(1, input_dim))) model = self._create_model_from_input(scope_name=scope_name, x=x) @@ -90,7 +90,7 @@ def test_deepcopy_model_parameters_are_not_shared(self): # Call once to create params model(x) - new_scope_name = 'copied' + new_scope_name = "copied" copied = model.deepcopy(new_scope_name) for src_value in model.get_parameters().values(): @@ -98,7 +98,7 @@ def test_deepcopy_model_parameters_are_not_shared(self): assert src_value is not dst_value def test_deepcopy_model_has_same_param_num(self): - scope_name = 'src' + scope_name = "src" input_dim = 5 x = nn.Variable.from_numpy_array(np.ones(shape=(1, input_dim))) model = self._create_model_from_input(scope_name=scope_name, x=x) @@ -106,7 +106,7 @@ def test_deepcopy_model_has_same_param_num(self): # Call once to create params expected = model(x) - new_scope_name = 'copied' + new_scope_name = "copied" copied = model.deepcopy(new_scope_name) assert len(copied.get_parameters()) == len(model.get_parameters()) @@ -117,7 +117,7 @@ def test_deepcopy_model_has_same_param_num(self): assert np.allclose(expected.d, actual.d) def test_deepcopy_same_scope_name_not_allowed(self): - scope_name = 'src' + scope_name = "src" input_dim = 5 x = nn.Variable.from_numpy_array(np.empty(shape=(1, input_dim))) model = self._create_model_from_input(scope_name=scope_name, x=x) @@ -129,7 +129,7 @@ def test_deepcopy_same_scope_name_not_allowed(self): model.deepcopy(scope_name) def test_deepcopy_cannot_create_with_existing_scope_name(self): - scope_name = 'src' + scope_name = "src" input_dim = 5 x = 
nn.Variable.from_numpy_array(np.empty(shape=(1, input_dim))) model = self._create_model_from_input(scope_name=scope_name, x=x) @@ -137,7 +137,7 @@ def test_deepcopy_cannot_create_with_existing_scope_name(self): # Call once to create params model(x) - new_scope_name = 'new' + new_scope_name = "new" model.deepcopy(new_scope_name) # Can not create with same scope twice @@ -145,7 +145,7 @@ def test_deepcopy_cannot_create_with_existing_scope_name(self): model.deepcopy(new_scope_name) def test_shallowcopy(self): - scope_name = 'src' + scope_name = "src" input_dim = 5 x = nn.Variable.from_numpy_array(np.empty(shape=(1, input_dim))) model = self._create_model_from_input(scope_name=scope_name, x=x) diff --git a/tests/numpy_model_trainers/distribution_parameters/test_gmm_parameter_trainer.py b/tests/numpy_model_trainers/distribution_parameters/test_gmm_parameter_trainer.py index 2b203a18..7fea1d2a 100644 --- a/tests/numpy_model_trainers/distribution_parameters/test_gmm_parameter_trainer.py +++ b/tests/numpy_model_trainers/distribution_parameters/test_gmm_parameter_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,8 +15,10 @@ import numpy as np import pytest -from nnabla_rl.numpy_model_trainers.distribution_parameters.gmm_parameter_trainer import (GMMParameterTrainer, - GMMParameterTrainerConfig) +from nnabla_rl.numpy_model_trainers.distribution_parameters.gmm_parameter_trainer import ( + GMMParameterTrainer, + GMMParameterTrainerConfig, +) from nnabla_rl.numpy_models.distribution_parameters.gmm_parameter import GMMParameter @@ -58,7 +60,8 @@ def test_update(self): data = _sample_gaussian(means, covariances, 5000) trainer = GMMParameterTrainer( parameter=GMMParameter.from_data(data, num_classes=3), - config=GMMParameterTrainerConfig(num_iterations_per_update=1000, threshold=1e-12)) + config=GMMParameterTrainerConfig(num_iterations_per_update=1000, threshold=1e-12), + ) trainer.update(data) order = np.argsort(trainer._parameter._means[:, 0]) @@ -69,8 +72,12 @@ def test_update(self): expected_covariances = covariances[order] matched = True - for actual_mean, actual_covariance, expected_mean, expected_covariance, in zip( - actual_means, actual_covariances, expected_means, expected_covariances): + for ( + actual_mean, + actual_covariance, + expected_mean, + expected_covariance, + ) in zip(actual_means, actual_covariances, expected_means, expected_covariances): matched &= np.allclose(actual_mean, expected_mean, atol=1e-1) matched &= np.allclose(actual_covariance, expected_covariance, atol=1e-1) if matched: diff --git a/tests/preprocessors/test_running_mean_normalizer.py b/tests/preprocessors/test_running_mean_normalizer.py index 53b322f8..2de444f8 100644 --- a/tests/preprocessors/test_running_mean_normalizer.py +++ b/tests/preprocessors/test_running_mean_normalizer.py @@ -21,18 +21,21 @@ from nnabla_rl.preprocessors.running_mean_normalizer import RunningMeanNormalizer -class TestRunningMeanNormalizer(): +class TestRunningMeanNormalizer: def setup_method(self, method): nn.clear_parameters() np.random.seed(0) - @pytest.mark.parametrize("x1, x2, x3", - [(np.random.randn(1, 3), np.random.randn(1, 3), np.random.randn(1, 3)), - (np.random.randn(1, 2), np.random.randn(2, 2), np.random.randn(3, 2))]) + @pytest.mark.parametrize( + "x1, x2, x3", + [ + (np.random.randn(1, 3), np.random.randn(1, 3), 
np.random.randn(1, 3)), + (np.random.randn(1, 2), np.random.randn(2, 2), np.random.randn(3, 2)), + ], + ) def test_update(self, x1, x2, x3): state_dim = x1.shape[1] - normalizer = RunningMeanNormalizer( - scope_name="test", shape=(state_dim, ), epsilon=0.0) + normalizer = RunningMeanNormalizer(scope_name="test", shape=(state_dim,), epsilon=0.0) normalizer.update(x1) normalizer.update(x2) @@ -45,13 +48,16 @@ def test_update(self, x1, x2, x3): assert np.allclose(expected_mean, normalizer._mean.d, atol=1e-4) assert np.allclose(expected_var, normalizer._var.d, atol=1e-4) - @pytest.mark.parametrize("mean, var, s_batch", - [(np.ones((1, 3)), np.ones((1, 3))*0.2, np.random.randn(1, 3)), - (np.ones((1, 2))*0.5, np.ones((1, 2))*0.1, np.random.randn(3, 2))]) + @pytest.mark.parametrize( + "mean, var, s_batch", + [ + (np.ones((1, 3)), np.ones((1, 3)) * 0.2, np.random.randn(1, 3)), + (np.ones((1, 2)) * 0.5, np.ones((1, 2)) * 0.1, np.random.randn(3, 2)), + ], + ) def test_filter(self, mean, var, s_batch): state_dim = s_batch.shape[1] - normalizer = RunningMeanNormalizer( - scope_name="test", shape=(state_dim, ), epsilon=0.0) + normalizer = RunningMeanNormalizer(scope_name="test", shape=(state_dim,), epsilon=0.0) normalizer._mean.d = mean normalizer._var.d = var @@ -69,14 +75,19 @@ def test_filter(self, mean, var, s_batch): def test_invalid_value_clip(self): with pytest.raises(ValueError): - RunningMeanNormalizer("test", (1, 1), value_clip=[5., -5.]) + RunningMeanNormalizer("test", (1, 1), value_clip=[5.0, -5.0]) def test_numpy_initializer(self): - shape = (6, ) + shape = (6,) mean_initializer = np.random.rand(6) var_initializer = np.random.rand(6) - normalizer = RunningMeanNormalizer(scope_name="test", shape=shape, epsilon=0.0, - mean_initializer=mean_initializer, var_initializer=var_initializer) + normalizer = RunningMeanNormalizer( + scope_name="test", + shape=shape, + epsilon=0.0, + mean_initializer=mean_initializer, + var_initializer=var_initializer, + ) # dummy process output = normalizer.process(nn.Variable.from_numpy_array(np.random.rand(1, 6))) @@ -89,11 +100,16 @@ def test_numpy_initializer(self): assert np.allclose(actual_params["count"].d, np.ones((1, 1)) * 1e-4) def test_nnabla_initializer(self): - shape = (6, ) + shape = (6,) mean_initializer = NI.ConstantInitializer(5.0) var_initializer = NI.ConstantInitializer(6.0) - normalizer = RunningMeanNormalizer(scope_name="test", shape=shape, epsilon=0.0, - mean_initializer=mean_initializer, var_initializer=var_initializer) + normalizer = RunningMeanNormalizer( + scope_name="test", + shape=shape, + epsilon=0.0, + mean_initializer=mean_initializer, + var_initializer=var_initializer, + ) # dummy process output = normalizer.process(nn.Variable.from_numpy_array(np.random.rand(1, 6))) @@ -106,17 +122,27 @@ def test_nnabla_initializer(self): assert np.allclose(actual_params["count"].d, np.ones((1, 1)) * 1e-4) def test_numpy_initializer_with_invalid_mean_initializer_shape(self): - shape = (6, ) + shape = (6,) mean_initializer = np.random.rand(4) var_initializer = np.random.rand(6) with pytest.raises(AssertionError): - RunningMeanNormalizer(scope_name="test", shape=shape, epsilon=0.0, - mean_initializer=mean_initializer, var_initializer=var_initializer) + RunningMeanNormalizer( + scope_name="test", + shape=shape, + epsilon=0.0, + mean_initializer=mean_initializer, + var_initializer=var_initializer, + ) def test_numpy_initializer_with_invalid_var_initializer_shape(self): - shape = (6, ) + shape = (6,) mean_initializer = np.random.rand(6) var_initializer = 
np.random.rand(4) with pytest.raises(AssertionError): - RunningMeanNormalizer(scope_name="test", shape=shape, epsilon=0.0, - mean_initializer=mean_initializer, var_initializer=var_initializer) + RunningMeanNormalizer( + scope_name="test", + shape=shape, + epsilon=0.0, + mean_initializer=mean_initializer, + var_initializer=var_initializer, + ) diff --git a/tests/replay_buffers/test_buffer_iterator.py b/tests/replay_buffers/test_buffer_iterator.py index d229f769..8e3d6beb 100644 --- a/tests/replay_buffers/test_buffer_iterator.py +++ b/tests/replay_buffers/test_buffer_iterator.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -28,8 +28,7 @@ def test_buffer_iterator_shuffle_no_repeat(self): buffer.append_all(dummy_examples) batch_size = 30 - iterator = BufferIterator( - buffer=buffer, batch_size=batch_size, shuffle=True, repeat=False) + iterator = BufferIterator(buffer=buffer, batch_size=batch_size, shuffle=True, repeat=False) for _ in range(buffer_size // batch_size): batch, *_ = iterator.next() @@ -49,8 +48,7 @@ def test_buffer_iterator_shuffle_with_repeat(self): buffer.append_all(dummy_examples) batch_size = 30 - iterator = BufferIterator( - buffer=buffer, batch_size=batch_size, shuffle=True, repeat=True) + iterator = BufferIterator(buffer=buffer, batch_size=batch_size, shuffle=True, repeat=True) for _ in range(buffer_size // batch_size): batch, *_ = iterator.next() @@ -74,8 +72,7 @@ def test_buffer_iterator_is_iterable(self): buffer.append_all(dummy_examples) batch_size = 30 - iterator = BufferIterator( - buffer=buffer, batch_size=batch_size, shuffle=True, repeat=True) + iterator = BufferIterator(buffer=buffer, batch_size=batch_size, shuffle=True, repeat=True) for experience, *_ in iterator: assert len(experience) == batch_size diff --git a/tests/replay_buffers/test_decorable_replay_buffer.py b/tests/replay_buffers/test_decorable_replay_buffer.py index c23b7ebc..846b91fd 100644 --- a/tests/replay_buffers/test_decorable_replay_buffer.py +++ b/tests/replay_buffers/test_decorable_replay_buffer.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,12 +25,10 @@ def decor_fun(self, experience): pass def test_getitem(self): - mock_decor_fun = create_autospec( - self.decor_fun, return_value=(1, 2, 3, 4, 5)) + mock_decor_fun = create_autospec(self.decor_fun, return_value=(1, 2, 3, 4, 5)) capacity = 10 - buffer = DecorableReplayBuffer(capacity=capacity, - decor_fun=mock_decor_fun) + buffer = DecorableReplayBuffer(capacity=capacity, decor_fun=mock_decor_fun) append_num = 10 for i in range(append_num): diff --git a/tests/replay_buffers/test_hindsight_replay_buffer.py b/tests/replay_buffers/test_hindsight_replay_buffer.py index 47e0225c..e740e46f 100644 --- a/tests/replay_buffers/test_hindsight_replay_buffer.py +++ b/tests/replay_buffers/test_hindsight_replay_buffer.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -30,8 +30,7 @@ def setup_method(self, method): def test_unsupported_env(self): dummy_env = DummyContinuousActionGoalEnv(max_episode_steps=max_episode_steps) - buffer = HindsightReplayBuffer(reward_function=dummy_env.compute_reward, - capacity=100) + buffer = HindsightReplayBuffer(reward_function=dummy_env.compute_reward, capacity=100) experiences = self._generate_experiences(dummy_env) with pytest.raises(RuntimeError): @@ -40,8 +39,7 @@ def test_unsupported_env(self): def test_extract_end_index_of_episode(self): dummy_env = DummyContinuousActionGoalEnv(max_episode_steps=max_episode_steps) dummy_env = GoalConditionedTupleObservationEnv(dummy_env) - buffer = HindsightReplayBuffer(reward_function=dummy_env.compute_reward, - capacity=100) + buffer = HindsightReplayBuffer(reward_function=dummy_env.compute_reward, capacity=100) for _ in range(num_episode): experiences = self._generate_experiences(dummy_env) buffer.append_all(experiences) @@ -52,26 +50,24 @@ def test_extract_end_index_of_episode(self): gt_end_index_of_episode = (i + 1) * max_episode_steps - 1 assert end_index_of_episode == gt_end_index_of_episode - @pytest.mark.parametrize('index', [np.random.randint(num_experiences) for _ in range(10)]) + @pytest.mark.parametrize("index", [np.random.randint(num_experiences) for _ in range(10)]) def test_select_future_index(self, index): dummy_env = DummyContinuousActionGoalEnv(max_episode_steps=max_episode_steps) dummy_env = GoalConditionedTupleObservationEnv(dummy_env) - buffer = HindsightReplayBuffer(reward_function=dummy_env.compute_reward, - capacity=100) + buffer = HindsightReplayBuffer(reward_function=dummy_env.compute_reward, capacity=100) for _ in range(num_episode): experiences = self._generate_experiences(dummy_env) buffer.append_all(experiences) end_index_of_episode = self._extract_end_index_of_episode(buffer, index) future_index = buffer._select_future_index(index, end_index_of_episode) - assert ((index <= future_index) and (future_index <= end_index_of_episode)) + assert (index <= future_index) and (future_index <= end_index_of_episode) - @pytest.mark.parametrize('index', [np.random.randint(num_experiences) for _ in range(10)]) + @pytest.mark.parametrize("index", [np.random.randint(num_experiences) for _ in range(10)]) def test_replace_goal(self, index): dummy_env = DummyContinuousActionGoalEnv(max_episode_steps=max_episode_steps) dummy_env = GoalConditionedTupleObservationEnv(dummy_env) - buffer = HindsightReplayBuffer(reward_function=dummy_env.compute_reward, - capacity=100) + buffer = HindsightReplayBuffer(reward_function=dummy_env.compute_reward, capacity=100) for _ in range(num_episode): experiences = self._generate_experiences(dummy_env) buffer.append_all(experiences) @@ -85,12 +81,11 @@ def test_replace_goal(self, index): assert np.allclose(new_experience[0][1], future_experience[4][2]) assert np.allclose(new_experience[4][1], future_experience[4][2]) - @pytest.mark.parametrize('index', [np.random.randint(num_experiences) for _ in range(10)]) + @pytest.mark.parametrize("index", [np.random.randint(num_experiences) for _ in range(10)]) def test_replace_goal_with_same_index(self, index): dummy_env = DummyContinuousActionGoalEnv(max_episode_steps=max_episode_steps) dummy_env = GoalConditionedTupleObservationEnv(dummy_env) - buffer = HindsightReplayBuffer(reward_function=dummy_env.compute_reward, - capacity=100) + buffer = HindsightReplayBuffer(reward_function=dummy_env.compute_reward, capacity=100) for _ in range(num_episode): experiences = 
self._generate_experiences(dummy_env) buffer.append_all(experiences) @@ -118,8 +113,8 @@ def _generate_experiences(self, env): def _extract_end_index_of_episode(self, buffer, item_index): _, _, _, _, _, info = buffer[item_index] - index_in_episode = info['index_in_episode'] - episode_end_index = int(info['episode_end_index']) + index_in_episode = info["index_in_episode"] + episode_end_index = int(info["episode_end_index"]) distance_to_end = episode_end_index - index_in_episode return distance_to_end + item_index diff --git a/tests/replay_buffers/test_memory_efficient_atari_buffer.py b/tests/replay_buffers/test_memory_efficient_atari_buffer.py index 0909b817..2c412f57 100644 --- a/tests/replay_buffers/test_memory_efficient_atari_buffer.py +++ b/tests/replay_buffers/test_memory_efficient_atari_buffer.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,10 +19,12 @@ from nnabla_rl.environments.dummy import DummyAtariEnv from nnabla_rl.environments.wrappers.atari import MaxAndSkipEnv, NoopResetEnv -from nnabla_rl.replay_buffers.memory_efficient_atari_buffer import (MemoryEfficientAtariBuffer, - MemoryEfficientAtariTrajectoryBuffer, - ProportionalPrioritizedAtariBuffer, - RankBasedPrioritizedAtariBuffer) +from nnabla_rl.replay_buffers.memory_efficient_atari_buffer import ( + MemoryEfficientAtariBuffer, + MemoryEfficientAtariTrajectoryBuffer, + ProportionalPrioritizedAtariBuffer, + RankBasedPrioritizedAtariBuffer, +) from nnabla_rl.utils.reproductions import build_atari_env @@ -34,13 +36,11 @@ def test_append_float(self): buffer = MemoryEfficientAtariBuffer(capacity=capacity) buffer.append(experience) - s, _, _, _, s_next, *_ = buffer._buffer[len(buffer._buffer)-1] + s, _, _, _, s_next, *_ = buffer._buffer[len(buffer._buffer) - 1] assert s.dtype == np.uint8 assert s_next.dtype == np.uint8 - assert np.alltrue( - (experience[0][-1] * 255.0).astype(np.uint8) == s) - assert np.alltrue( - (experience[4][-1] * 255.0).astype(np.uint8) == s_next) + assert np.alltrue((experience[0][-1] * 255.0).astype(np.uint8) == s) + assert np.alltrue((experience[4][-1] * 255.0).astype(np.uint8) == s_next) def test_unstacked_frame(self): experiences = _generate_atari_experience_mock(num_mocks=10, frame_stack=False) @@ -78,8 +78,7 @@ def test_getitem(self): def test_full_buffer_getitem(self): capacity = 10 - experiences = _generate_atari_experience_mock( - num_mocks=(capacity + 5)) + experiences = _generate_atari_experience_mock(num_mocks=(capacity + 5)) buffer = MemoryEfficientAtariBuffer(capacity=capacity) for i in range(capacity): @@ -222,13 +221,11 @@ def test_append_float(self): buffer = ProportionalPrioritizedAtariBuffer(capacity=capacity) buffer.append(experience) - s, _, _, _, s_next, *_ = buffer._buffer[len(buffer._buffer)-1] + s, _, _, _, s_next, *_ = buffer._buffer[len(buffer._buffer) - 1] assert s.dtype == np.uint8 assert s_next.dtype == np.uint8 - assert np.alltrue( - (experience[0][-1] * 255.0).astype(np.uint8) == s) - assert np.alltrue( - (experience[4][-1] * 255.0).astype(np.uint8) == s_next) + assert np.alltrue((experience[0][-1] * 255.0).astype(np.uint8) == s) + assert np.alltrue((experience[4][-1] * 255.0).astype(np.uint8) == s_next) def test_unstacked_frame(self): experiences = _generate_atari_experience_mock(num_mocks=10, frame_stack=False) @@ 
-266,8 +263,7 @@ def test_getitem(self): def test_full_buffer_getitem(self): capacity = 10 - experiences = _generate_atari_experience_mock( - num_mocks=(capacity + 5)) + experiences = _generate_atari_experience_mock(num_mocks=(capacity + 5)) buffer = ProportionalPrioritizedAtariBuffer(capacity=capacity) for i in range(capacity): @@ -333,13 +329,11 @@ def test_append_float(self): buffer = RankBasedPrioritizedAtariBuffer(capacity=capacity) buffer.append(experience) - s, _, _, _, s_next, *_ = buffer._buffer[len(buffer._buffer)-1] + s, _, _, _, s_next, *_ = buffer._buffer[len(buffer._buffer) - 1] assert s.dtype == np.uint8 assert s_next.dtype == np.uint8 - assert np.alltrue( - (experience[0][-1] * 255.0).astype(np.uint8) == s) - assert np.alltrue( - (experience[4][-1] * 255.0).astype(np.uint8) == s_next) + assert np.alltrue((experience[0][-1] * 255.0).astype(np.uint8) == s) + assert np.alltrue((experience[4][-1] * 255.0).astype(np.uint8) == s_next) def test_unstacked_frame(self): experiences = _generate_atari_experience_mock(num_mocks=10, frame_stack=False) @@ -377,8 +371,7 @@ def test_getitem(self): def test_full_buffer_getitem(self): capacity = 10 - experiences = _generate_atari_experience_mock( - num_mocks=(capacity + 5)) + experiences = _generate_atari_experience_mock(num_mocks=(capacity + 5)) buffer = RankBasedPrioritizedAtariBuffer(capacity=capacity) for i in range(capacity): diff --git a/tests/replay_buffers/test_prioritized_replay_buffer.py b/tests/replay_buffers/test_prioritized_replay_buffer.py index a0558570..832ae263 100644 --- a/tests/replay_buffers/test_prioritized_replay_buffer.py +++ b/tests/replay_buffers/test_prioritized_replay_buffer.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -19,10 +19,15 @@ import numpy as np import pytest -from nnabla_rl.replay_buffers.prioritized_replay_buffer import (MaxHeap, MaxHeapDataHolder, MinTree, - ProportionalPrioritizedReplayBuffer, - RankBasedPrioritizedReplayBuffer, SumTree, - SumTreeDataHolder) +from nnabla_rl.replay_buffers.prioritized_replay_buffer import ( + MaxHeap, + MaxHeapDataHolder, + MinTree, + ProportionalPrioritizedReplayBuffer, + RankBasedPrioritizedReplayBuffer, + SumTree, + SumTreeDataHolder, +) class TestMinTree(object): @@ -297,7 +302,8 @@ def test_betasteps(self, beta, betasteps): @pytest.mark.parametrize("normalization_method", ["batch_max", "buffer_max"]) def test_sample_one_experience(self, beta, normalization_method): buffer = self._generate_buffer_with_experiences( - experience_num=100, beta=beta, normalization_method=normalization_method) + experience_num=100, beta=beta, normalization_method=normalization_method + ) experiences, info = buffer.sample() indices = buffer._last_sampled_indices assert len(experiences) == 1 @@ -315,7 +321,8 @@ def test_sample_one_experience(self, beta, normalization_method): @pytest.mark.parametrize("normalization_method", ["batch_max", "buffer_max"]) def test_sample_multiple_experiences(self, beta, normalization_method): buffer = self._generate_buffer_with_experiences( - experience_num=100, beta=beta, normalization_method=normalization_method) + experience_num=100, beta=beta, normalization_method=normalization_method + ) num_samples = 10 experiences, info = buffer.sample(num_samples=num_samples) indices = buffer._last_sampled_indices @@ -335,10 +342,15 @@ def test_sample_multiple_experiences(self, beta, normalization_method): @pytest.mark.parametrize("normalization_method", ["batch_max", "buffer_max"]) def test_sample_multiple_step_experience(self, beta, num_steps, normalization_method): buffer = self._generate_buffer_with_experiences( - experience_num=100, beta=beta, normalization_method=normalization_method) + experience_num=100, beta=beta, normalization_method=normalization_method + ) experiences_tuple, info = buffer.sample(num_steps=num_steps) if num_steps == 1: - experiences_tuple = tuple([experiences_tuple, ]) + experiences_tuple = tuple( + [ + experiences_tuple, + ] + ) indices = buffer._last_sampled_indices assert len(experiences_tuple) == num_steps assert "weights" in info @@ -361,7 +373,8 @@ def test_sample_from_insufficient_size_buffer(self): @pytest.mark.parametrize("normalization_method", ["batch_max", "buffer_max"]) def test_sample_indices(self, beta, normalization_method): buffer = self._generate_buffer_with_experiences( - experience_num=100, beta=beta, normalization_method=normalization_method) + experience_num=100, beta=beta, normalization_method=normalization_method + ) indices = [1, 67, 50, 4, 99] experiences, info = buffer.sample_indices(indices) @@ -484,8 +497,8 @@ def test_error_preprocessing(self, error_clip): assert all(max_error >= processed) def _generate_experience_mock(self): - state_shape = (5, ) - action_shape = (10, ) + state_shape = (5,) + action_shape = (10,) state = np.empty(shape=state_shape) action = np.empty(shape=action_shape) @@ -496,12 +509,12 @@ def _generate_experience_mock(self): return (state, action, reward, non_terminal, next_state, next_action) - def _generate_buffer_with_experiences(self, - experience_num, beta=1.0, - betasteps=1, - normalization_method="batch_max"): + def _generate_buffer_with_experiences( + self, experience_num, beta=1.0, betasteps=1, normalization_method="batch_max" + ): buffer = 
ProportionalPrioritizedReplayBuffer( - capacity=experience_num, beta=beta, betasteps=betasteps, normalization_method=normalization_method) + capacity=experience_num, beta=beta, betasteps=betasteps, normalization_method=normalization_method + ) for _ in range(experience_num): experience = _generate_experience_mock() buffer.append(experience) @@ -510,8 +523,9 @@ def _generate_buffer_with_experiences(self, def _compute_weight(self, buffer, index, alpha, beta): priority = buffer._buffer.get_priority(index) if buffer._normalization_method == "batch_max": - min_priority = np.min(np.array([buffer._buffer.get_priority(index) - for index in buffer._last_sampled_indices])) + min_priority = np.min( + np.array([buffer._buffer.get_priority(index) for index in buffer._last_sampled_indices]) + ) elif buffer._normalization_method == "buffer_max": min_priority = buffer._buffer.min_priority() else: @@ -596,11 +610,14 @@ def test_sample_multiple_experiences(self, beta): @pytest.mark.parametrize("beta", [np.random.uniform(low=0.0, high=1.0) for _ in range(1, 10)]) @pytest.mark.parametrize("num_steps", range(1, 5)) def test_sample_multiple_step_experience(self, beta, num_steps): - buffer = self._generate_buffer_with_experiences(experience_num=100, - beta=beta) + buffer = self._generate_buffer_with_experiences(experience_num=100, beta=beta) experiences_tuple, info = buffer.sample(num_steps=num_steps) if num_steps == 1: - experiences_tuple = tuple([experiences_tuple, ]) + experiences_tuple = tuple( + [ + experiences_tuple, + ] + ) indices = buffer._last_sampled_indices assert len(experiences_tuple) == num_steps assert "weights" in info @@ -708,9 +725,11 @@ def test_sort_interval(self): buffer._last_sampled_indices = [i] buffer.update_priorities(errors=[np.random.randint(100)]) if (i + 1) % sort_interval == 0: - sorted_heap = sorted(buffer._buffer._max_heap._heap, - key=lambda item: -math.inf if item is None else item[1], - reverse=True) + sorted_heap = sorted( + buffer._buffer._max_heap._heap, + key=lambda item: -math.inf if item is None else item[1], + reverse=True, + ) assert np.alltrue(buffer._buffer._max_heap._heap == sorted_heap) @pytest.mark.parametrize("error_clip", [(-np.random.uniform(), np.random.uniform()) for _ in range(1, 10)]) @@ -741,8 +760,8 @@ def _compute_weight(self, buffer, index, alpha, beta): def _generate_experience_mock(): - state_shape = (5, ) - action_shape = (10, ) + state_shape = (5,) + action_shape = (10,) state = np.empty(shape=state_shape) action = np.empty(shape=action_shape) diff --git a/tests/replay_buffers/test_replacement_sampling_replay_buffer.py b/tests/replay_buffers/test_replacement_sampling_replay_buffer.py index 27cba60a..d0b022d2 100644 --- a/tests/replay_buffers/test_replacement_sampling_replay_buffer.py +++ b/tests/replay_buffers/test_replacement_sampling_replay_buffer.py @@ -1,5 +1,5 @@ # Copyright 2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,15 +18,15 @@ from nnabla_rl.replay_buffers.replacement_sampling_replay_buffer import ReplacementSamplingReplayBuffer -class TestReplacementSamplingReplayBuffer(): +class TestReplacementSamplingReplayBuffer: def test_sample_from_insufficient_size_buffer(self): buffer = self._generate_buffer_with_experiences(experience_num=10) samples, _ = buffer.sample(num_samples=100) assert len(samples) == 100 def _generate_experience_mock(self): - state_shape = (5, ) - action_shape = (10, ) + state_shape = (5,) + action_shape = (10,) state = np.empty(shape=state_shape) action = np.empty(shape=action_shape) diff --git a/tests/replay_buffers/test_trajectory_replay_buffer.py b/tests/replay_buffers/test_trajectory_replay_buffer.py index a51ca826..bfc69b86 100644 --- a/tests/replay_buffers/test_trajectory_replay_buffer.py +++ b/tests/replay_buffers/test_trajectory_replay_buffer.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,19 +19,21 @@ from nnabla_rl.replay_buffers.trajectory_replay_buffer import TrajectoryReplayBuffer -class TestTrajectoryReplayBuffer(): +class TestTrajectoryReplayBuffer: def test_len(self): trajectory_num = 10 trajectory_length = 5 - buffer = self._generate_buffer_with_trajectories(trajectory_num=trajectory_num, - trajectory_length=trajectory_length) + buffer = self._generate_buffer_with_trajectories( + trajectory_num=trajectory_num, trajectory_length=trajectory_length + ) assert len(buffer) == trajectory_num * trajectory_length def test_trajectory_num(self): trajectory_num = 10 trajectory_length = 5 - buffer = self._generate_buffer_with_trajectories(trajectory_num=trajectory_num, - trajectory_length=trajectory_length) + buffer = self._generate_buffer_with_trajectories( + trajectory_num=trajectory_num, trajectory_length=trajectory_length + ) assert buffer.trajectory_num == trajectory_num def test_sample_from_insufficient_size_buffer(self): @@ -71,7 +73,7 @@ def test_sample_indices_portion(self, portion_length): index = index - 10 * trajectory_index start_index = min(index, len(trajectory) - portion_length) - expected = trajectory[start_index:start_index+portion_length] + expected = trajectory[start_index : start_index + portion_length] actual = trajectories[i] assert len(expected) == len(actual) for expected_element, actual_element in zip(expected, actual): @@ -104,13 +106,13 @@ def test_sample_indices(self): samples_from_trajectory_buffer, _ = trajectory_buffer.sample_indices(indices) samples_from_conventional_buffer, _ = conventional_buffer.sample_indices(indices) - for (actual_sample, expected_sample) in zip(samples_from_trajectory_buffer, samples_from_conventional_buffer): + for actual_sample, expected_sample in zip(samples_from_trajectory_buffer, samples_from_conventional_buffer): for actual_item, expected_item in zip(actual_sample, expected_sample): np.testing.assert_almost_equal(actual_item, expected_item) def _generate_experience_mock(self): - state_shape = (5, ) - action_shape = (10, ) + state_shape = (5,) + action_shape = (10,) state = np.empty(shape=state_shape) action = np.empty(shape=action_shape) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 2c45b9d0..e50b2085 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022 Sony Group Corporation. 
+# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,8 +26,8 @@ class TestAlgorithm(object): @patch.multiple(Algorithm, __abstractmethods__=set()) - @patch('nnabla_rl.algorithm.Algorithm.is_supported_env', lambda self, env_info: True) - @patch('nnabla_rl.algorithm.Algorithm._has_rnn_models', lambda self: False) + @patch("nnabla_rl.algorithm.Algorithm.is_supported_env", lambda self, env_info: True) + @patch("nnabla_rl.algorithm.Algorithm._has_rnn_models", lambda self: False) def test_resume_online_training(self): env = E.DummyContinuous() algorithm = Algorithm(env) @@ -41,8 +41,8 @@ def test_resume_online_training(self): assert algorithm._run_online_training_iteration.call_count == total_iterations * 2 @patch.multiple(Algorithm, __abstractmethods__=set()) - @patch('nnabla_rl.algorithm.Algorithm.is_supported_env', lambda self, env_info: True) - @patch('nnabla_rl.algorithm.Algorithm._has_rnn_models', lambda self: False) + @patch("nnabla_rl.algorithm.Algorithm.is_supported_env", lambda self, env_info: True) + @patch("nnabla_rl.algorithm.Algorithm._has_rnn_models", lambda self: False) def test_resume_offline_training(self): env = E.DummyContinuous() algorithm = Algorithm(env) @@ -58,9 +58,9 @@ def test_resume_offline_training(self): assert algorithm._run_offline_training_iteration.call_count == total_iterations * 2 @patch.multiple(Algorithm, __abstractmethods__=set()) - @patch('nnabla_rl.algorithm.Algorithm.is_supported_env', lambda self, env_info: True) - @patch('nnabla_rl.algorithm.Algorithm.is_rnn_supported', lambda self: False) - @patch('nnabla_rl.algorithm.Algorithm._has_rnn_models', lambda self: True) + @patch("nnabla_rl.algorithm.Algorithm.is_supported_env", lambda self, env_info: True) + @patch("nnabla_rl.algorithm.Algorithm.is_rnn_supported", lambda self: False) + @patch("nnabla_rl.algorithm.Algorithm._has_rnn_models", lambda self: True) def test_rnn_unsupported_algorithm(self): env = E.DummyContinuous() algorithm = Algorithm(env) @@ -73,9 +73,9 @@ def test_rnn_unsupported_algorithm(self): algorithm.train(buffer, total_iterations=total_iterations) @patch.multiple(Algorithm, __abstractmethods__=set()) - @patch('nnabla_rl.algorithm.Algorithm.is_supported_env', lambda self, env_info: True) - @patch('nnabla_rl.algorithm.Algorithm.is_rnn_supported', lambda self: True) - @patch('nnabla_rl.algorithm.Algorithm._has_rnn_models', lambda self: True) + @patch("nnabla_rl.algorithm.Algorithm.is_supported_env", lambda self, env_info: True) + @patch("nnabla_rl.algorithm.Algorithm.is_rnn_supported", lambda self: True) + @patch("nnabla_rl.algorithm.Algorithm._has_rnn_models", lambda self: True) def test_rnn_supported_algorithm(self): env = E.DummyContinuous() algorithm = Algorithm(env) diff --git a/tests/test_environment_explorer.py b/tests/test_environment_explorer.py index b69c53c6..f10e00f2 100644 --- a/tests/test_environment_explorer.py +++ b/tests/test_environment_explorer.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -17,8 +17,13 @@ import pytest from nnabla_rl.environment_explorer import _is_end_of_episode, _sample_action -from nnabla_rl.environments.dummy import (DummyContinuous, DummyDiscrete, DummyTupleActionContinuous, - DummyTupleActionDiscrete, DummyTupleMixed) +from nnabla_rl.environments.dummy import ( + DummyContinuous, + DummyDiscrete, + DummyTupleActionContinuous, + DummyTupleActionDiscrete, + DummyTupleMixed, +) from nnabla_rl.environments.environment_info import EnvironmentInfo @@ -41,8 +46,16 @@ def test_is_end_of_episode(self, done, timelimit, timelimit_as_terminal): else: raise RuntimeError - @pytest.mark.parametrize("env", [DummyContinuous(), DummyDiscrete(), - DummyTupleActionContinuous(), DummyTupleActionDiscrete(), DummyTupleMixed()]) + @pytest.mark.parametrize( + "env", + [ + DummyContinuous(), + DummyDiscrete(), + DummyTupleActionContinuous(), + DummyTupleActionDiscrete(), + DummyTupleMixed(), + ], + ) def test_sample_action(self, env): env_info = EnvironmentInfo.from_env(env) action, *_ = _sample_action(env, env_info) @@ -50,15 +63,15 @@ def test_sample_action(self, env): if env_info.is_tuple_action_env(): for a, space in zip(action, env_info.action_space): if isinstance(space, gym.spaces.Discrete): - assert a.shape == (1, ) + assert a.shape == (1,) else: assert a.shape == space.shape else: if isinstance(env_info.action_space, gym.spaces.Discrete): - assert action.shape == (1, ) + assert action.shape == (1,) else: - assert action.shape == env_info.action_space.shape or action.shape == (1, ) + assert action.shape == env_info.action_space.shape or action.shape == (1,) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main() diff --git a/tests/test_functions.py b/tests/test_functions.py index bb8e3786..0ac16f35 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -37,8 +37,7 @@ def test_sample_gaussian(self): ln_sigma_var = nn.Variable(input_shape) ln_sigma_var.d = ln_sigma - sampled_value = RF.sample_gaussian( - mean=mean_var, ln_var=(ln_sigma_var * 2.0)) + sampled_value = RF.sample_gaussian(mean=mean_var, ln_var=(ln_sigma_var * 2.0)) assert sampled_value.shape == (batch_size, output_dim) def test_sample_gaussian_wrong_parameter_shape(self): @@ -56,8 +55,7 @@ def test_sample_gaussian_wrong_parameter_shape(self): ln_sigma_var.d = ln_sigma with pytest.raises(ValueError): - RF.sample_gaussian( - mean=mean_var, ln_var=(ln_sigma_var * 2.0)) + RF.sample_gaussian(mean=mean_var, ln_var=(ln_sigma_var * 2.0)) def test_expand_dims(self): batch_size = 4 @@ -121,7 +119,7 @@ def test_sqrt(self): num_samples = 100 batch_num = 100 # exp to enforce positive value - data = np.exp(np.random.normal(size=(num_samples, batch_num, 1))) + data = np.exp(np.random.normal(size=(num_samples, batch_num, 1))) data = np.float32(data) data_var = nn.Variable(data.shape) data_var.d = data @@ -136,7 +134,7 @@ def test_std(self): # stddev computation num_samples = 100 batch_num = 100 - data = np.random.normal(size=(num_samples, batch_num, 1)) + data = np.random.normal(size=(num_samples, batch_num, 1)) data = np.float32(data) data_var = nn.Variable(data.shape) data_var.d = data @@ -157,7 +155,7 @@ def test_std(self): def test_argmax(self): num_samples = 100 batch_num = 100 - data = np.random.normal(size=(num_samples, batch_num, 1)) + data = np.random.normal(size=(num_samples, batch_num, 1)) data = np.float32(data) data_var = nn.Variable(data.shape) data_var.d = data @@ -172,7 +170,7 @@ def test_argmax(self): def test_argmin(self): num_samples = 100 batch_num = 100 - data = 
np.random.normal(size=(num_samples, batch_num, 1)) + data = np.random.normal(size=(num_samples, batch_num, 1)) data = np.float32(data) data_var = nn.Variable(data.shape) data_var.d = data @@ -192,7 +190,7 @@ def test_quantile_huber_loss(self): def huber_loss(x0, x1, kappa): diff = x0 - x1 flag = (np.abs(diff) < kappa).astype(np.float32) - return (flag) * 0.5 * (diff ** 2.0) + (1.0 - flag) * kappa * (np.abs(diff) - 0.5 * kappa) + return (flag) * 0.5 * (diff**2.0) + (1.0 - flag) * kappa * (np.abs(diff) - 0.5 * kappa) def quantile_huber_loss(x0, x1, kappa, tau): u = x0 - x1 @@ -234,94 +232,99 @@ def test_mean_squared_error(self): actual = RF.mean_squared_error(x0_var, x1_var) actual.forward(clear_buffer=True) - expected = np.mean((x0 - x1)**2) + expected = np.mean((x0 - x1) ** 2) assert actual.shape == expected.shape assert np.all(np.isclose(actual.d, expected)) def test_gaussian_cross_entropy_method(self): def objective_function(x): - return -((x - 3.)**2) + return -((x - 3.0) ** 2) batch_size = 1 var_size = 1 init_mean = nn.Variable.from_numpy_array(np.zeros((batch_size, var_size))) - init_var = nn.Variable.from_numpy_array(np.ones((batch_size, var_size))*4.) + init_var = nn.Variable.from_numpy_array(np.ones((batch_size, var_size)) * 4.0) optimal_mean, optimal_top = RF.gaussian_cross_entropy_method( - objective_function, init_mean, init_var, alpha=0., num_iterations=10) + objective_function, init_mean, init_var, alpha=0.0, num_iterations=10 + ) nn.forward_all([optimal_mean, optimal_top]) - assert np.allclose(optimal_mean.d, np.array([[3.]]), atol=1e-2) - assert np.allclose(optimal_top.d, np.array([[3.]]), atol=1e-2) + assert np.allclose(optimal_mean.d, np.array([[3.0]]), atol=1e-2) + assert np.allclose(optimal_top.d, np.array([[3.0]]), atol=1e-2) def test_gaussian_cross_entropy_method_with_complicated_objective_function(self): def dummy_q_function(s, a): - return -((a - s)**2) + return -((a - s) ** 2) batch_size = 5 pop_size = 500 state_size = 1 action_size = 1 - s = np.arange(batch_size*state_size).reshape(batch_size, state_size) + s = np.arange(batch_size * state_size).reshape(batch_size, state_size) s = np.tile(s, (pop_size, 1, 1)) s = np.transpose(s, (1, 0, 2)) s_var = nn.Variable.from_numpy_array(s.reshape(batch_size, pop_size, state_size)) - def objective_function(x): return dummy_q_function(s_var, x) + + def objective_function(x): + return dummy_q_function(s_var, x) init_mean = nn.Variable.from_numpy_array(np.zeros((batch_size, action_size))) - init_var = nn.Variable.from_numpy_array(np.ones((batch_size, action_size))*4) + init_var = nn.Variable.from_numpy_array(np.ones((batch_size, action_size)) * 4) optimal_mean, optimal_top = RF.gaussian_cross_entropy_method( - objective_function, init_mean, init_var, pop_size, alpha=0., num_iterations=10) + objective_function, init_mean, init_var, pop_size, alpha=0.0, num_iterations=10 + ) nn.forward_all([optimal_mean, optimal_top]) - assert np.allclose(optimal_mean.d, np.array([[0.], [1.], [2.], [3.], [4.]]), atol=1e-2) - assert np.allclose(optimal_top.d, np.array([[0.], [1.], [2.], [3.], [4.]]), atol=1e-2) + assert np.allclose(optimal_mean.d, np.array([[0.0], [1.0], [2.0], [3.0], [4.0]]), atol=1e-2) + assert np.allclose(optimal_top.d, np.array([[0.0], [1.0], [2.0], [3.0], [4.0]]), atol=1e-2) def test_random_shooting_method(self): def objective_function(x): - return -((x - 3.)**2) + return -((x - 3.0) ** 2) batch_size = 1 var_size = 1 upper_bound = np.ones((batch_size, var_size)) * 3.5 lower_bound = np.ones((batch_size, var_size)) * 2.5 - 
optimal_top = RF.random_shooting_method( - objective_function, upper_bound, lower_bound) + optimal_top = RF.random_shooting_method(objective_function, upper_bound, lower_bound) nn.forward_all([optimal_top]) - assert np.allclose(optimal_top.d, np.array([[3.]]), atol=1e-1) + assert np.allclose(optimal_top.d, np.array([[3.0]]), atol=1e-1) def test_random_shooting_method_with_complicated_objective_function(self): def dummy_q_function(s, a): - return -((a - s)**2) + return -((a - s) ** 2) batch_size = 5 sample_size = 500 state_size = 1 action_size = 1 - s = np.arange(batch_size*state_size).reshape(batch_size, state_size) + s = np.arange(batch_size * state_size).reshape(batch_size, state_size) s = np.tile(s, (sample_size, 1, 1)) s = np.transpose(s, (1, 0, 2)) s_var = nn.Variable.from_numpy_array(s.reshape(batch_size, sample_size, state_size)) - def objective_function(x): return dummy_q_function(s_var, x) - upper_bound = np.ones((batch_size, action_size))*5 + def objective_function(x): + return dummy_q_function(s_var, x) + + upper_bound = np.ones((batch_size, action_size)) * 5 lower_bound = np.zeros((batch_size, action_size)) optimal_top = RF.random_shooting_method(objective_function, upper_bound, lower_bound) nn.forward_all([optimal_top]) - assert np.allclose(optimal_top.d, np.array([[0.], [1.], [2.], [3.], [4.]]), atol=1e-1) + assert np.allclose(optimal_top.d, np.array([[0.0], [1.0], [2.0], [3.0], [4.0]]), atol=1e-1) def test_random_shooting_method_with_invalid_bounds(self): def objective_function(x): - return -((x - 3.)**2) + return -((x - 3.0) ** 2) batch_size = 1 var_size = 1 @@ -338,8 +341,9 @@ def objective_function(x): def test_triangular_matrix(self, batch_size, diag_size, upper): non_diag_size = diag_size * (diag_size - 1) // 2 diagonal = nn.Variable.from_numpy_array(np.random.normal(size=(batch_size, diag_size)).astype(np.float32)) - non_diagonal = nn.Variable.from_numpy_array(np.random.normal( - size=(batch_size, non_diag_size)).astype(np.float32)) + non_diagonal = nn.Variable.from_numpy_array( + np.random.normal(size=(batch_size, non_diag_size)).astype(np.float32) + ) triangular_matrix = RF.triangular_matrix(diagonal, non_diagonal, upper) triangular_matrix.forward() @@ -385,7 +389,7 @@ def test_triangular_matrix_create_diagonal_matrix(self, batch_size, diag_size, u assert value == 0 @pytest.mark.parametrize("batch_size", [i for i in range(1, 4)]) - @pytest.mark.parametrize("shape", [(1, ), (1, 2), (2, 3), (4, 5, 6)]) + @pytest.mark.parametrize("shape", [(1,), (1, 2), (2, 3), (4, 5, 6)]) def test_batch_flatten(self, batch_size, shape): x = nn.Variable.from_numpy_array(np.random.normal(size=(batch_size, *shape)).astype(np.float32)) flattened_x = RF.batch_flatten(x) @@ -406,12 +410,12 @@ def pi(self, s: nn.Variable): z = NPF.affine(s, n_outmaps=5) return D.Gaussian(z, nn.Variable.from_numpy_array(np.zeros(z.shape))) - test_model = TestModel('test') + test_model = TestModel("test") for parameter in test_model.get_parameters().values(): parameter.grad.zero() batch_size = 3 - shape = (10, ) + shape = (10,) s = nn.Variable.from_numpy_array(np.random.normal(size=(batch_size, *shape)).astype(np.float32)) distribution = test_model.pi(s) @@ -513,14 +517,15 @@ def test_swapaxes(self, axis1, axis2): assert np.allclose(original.d, reswapped.d) - @pytest.mark.parametrize("x, mean, std, value_clip, expected", - [ - (np.array([2.0]), np.array([1.0]), np.array([0.2]), None, np.array([5.0])), - (np.array([2.0]), np.array([1.0]), np.array([0.2]), (-1.5, 1.5), np.array([1.5])), - (np.array([-2.0]), 
np.array([1.0]), np.array([0.2]), (-1.5, 1.5), np.array([-1.5])), - (np.array([[2.0], [1.0]]), np.array([[1.0]]), np.array([[0.2]]), None, - np.array([[5.0], [0.0]])), - ]) + @pytest.mark.parametrize( + "x, mean, std, value_clip, expected", + [ + (np.array([2.0]), np.array([1.0]), np.array([0.2]), None, np.array([5.0])), + (np.array([2.0]), np.array([1.0]), np.array([0.2]), (-1.5, 1.5), np.array([1.5])), + (np.array([-2.0]), np.array([1.0]), np.array([0.2]), (-1.5, 1.5), np.array([-1.5])), + (np.array([[2.0], [1.0]]), np.array([[1.0]]), np.array([[0.2]]), None, np.array([[5.0], [0.0]])), + ], + ) def test_normalize(self, x, expected, mean, std, value_clip): x_var = nn.Variable.from_numpy_array(x) mean_var = nn.Variable.from_numpy_array(mean) @@ -531,14 +536,15 @@ def test_normalize(self, x, expected, mean, std, value_clip): assert np.allclose(actual_var.d, expected) - @pytest.mark.parametrize("x, mean, std, value_clip, expected", - [ - (np.array([2.0]), np.array([1.0]), np.array([0.2]), None, np.array([1.4])), - (np.array([2.0]), np.array([1.0]), np.array([0.2]), (-1.0, 1.0), np.array([1.0])), - (np.array([-2.0]), np.array([-1.0]), np.array([0.2]), (-1.0, 1.0), np.array([-1.0])), - (np.array([[2.0], [1.0]]), np.array([[1.0]]), np.array([[0.2]]), None, - np.array([[1.4], [1.2]])), - ]) + @pytest.mark.parametrize( + "x, mean, std, value_clip, expected", + [ + (np.array([2.0]), np.array([1.0]), np.array([0.2]), None, np.array([1.4])), + (np.array([2.0]), np.array([1.0]), np.array([0.2]), (-1.0, 1.0), np.array([1.0])), + (np.array([-2.0]), np.array([-1.0]), np.array([0.2]), (-1.0, 1.0), np.array([-1.0])), + (np.array([[2.0], [1.0]]), np.array([[1.0]]), np.array([[0.2]]), None, np.array([[1.4], [1.2]])), + ], + ) def test_unnormalize(self, x, expected, mean, std, value_clip): x_var = nn.Variable.from_numpy_array(x) mean_var = nn.Variable.from_numpy_array(mean) @@ -549,15 +555,17 @@ def test_unnormalize(self, x, expected, mean, std, value_clip): assert np.allclose(actual_var.d, expected) - @pytest.mark.parametrize("var, epsilon, mode_for_floating_point_error, expected", - [ - (np.array([3.0]), 1.0, "add", np.array([2.0])), - (np.array([4.0]), 0.01, "max", np.array([2.0])), - (np.array([0.4]), 1.0, "max", np.array([1.0])), - (np.array([[3.0], [8.0]]), 1.0, "add", np.array([[2.0], [3.0]])), - (np.array([[4.0], [9.0]]), 0.01, "max", np.array([[2.0], [3.0]])), - (np.array([[0.4], [0.9]]), 1.0, "max", np.array([[1.0], [1.0]])), - ]) + @pytest.mark.parametrize( + "var, epsilon, mode_for_floating_point_error, expected", + [ + (np.array([3.0]), 1.0, "add", np.array([2.0])), + (np.array([4.0]), 0.01, "max", np.array([2.0])), + (np.array([0.4]), 1.0, "max", np.array([1.0])), + (np.array([[3.0], [8.0]]), 1.0, "add", np.array([[2.0], [3.0]])), + (np.array([[4.0], [9.0]]), 0.01, "max", np.array([[2.0], [3.0]])), + (np.array([[0.4], [0.9]]), 1.0, "max", np.array([[1.0], [1.0]])), + ], + ) def test_compute_std(self, var, epsilon, mode_for_floating_point_error, expected): variance_variable = nn.Variable.from_numpy_array(var) actual_var = RF.compute_std(variance_variable, epsilon, mode_for_floating_point_error) diff --git a/tests/test_initializers.py b/tests/test_initializers.py index e3decabc..4108d059 100644 --- a/tests/test_initializers.py +++ b/tests/test_initializers.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -28,19 +28,17 @@ def test_he_normal(self): kernel = (10, 10) factor = 10.0 - with mock.patch('nnabla_rl.initializers.calc_normal_std_he_forward', return_value=True) as mock_calc: + with mock.patch("nnabla_rl.initializers.calc_normal_std_he_forward", return_value=True) as mock_calc: mock_calc.return_value = 10 - initializer = RI.HeNormal( - inmaps, outmaps, kernel, factor, mode='fan_in') + initializer = RI.HeNormal(inmaps, outmaps, kernel, factor, mode="fan_in") initializer(shape=(5, 5)) mock_calc.assert_called_once_with(inmaps, outmaps, kernel, factor) - with mock.patch('nnabla_rl.initializers.calc_normal_std_he_backward', return_value=True) as mock_calc: + with mock.patch("nnabla_rl.initializers.calc_normal_std_he_backward", return_value=True) as mock_calc: mock_calc.return_value = 10 - initializer = RI.HeNormal( - inmaps, outmaps, kernel, factor, mode='fan_out') + initializer = RI.HeNormal(inmaps, outmaps, kernel, factor, mode="fan_out") initializer(shape=(5, 5)) mock_calc.assert_called_once_with(inmaps, outmaps, kernel, factor) @@ -50,11 +48,10 @@ def test_lecun_normal(self): kernel = (10, 10) factor = 10.0 - with mock.patch('nnabla_rl.initializers.calc_normal_std_he_forward', return_value=True) as mock_calc: + with mock.patch("nnabla_rl.initializers.calc_normal_std_he_forward", return_value=True) as mock_calc: mock_calc.return_value = 10 - initializer = RI.LeCunNormal( - inmaps, outmaps, kernel, factor, mode='fan_in') + initializer = RI.LeCunNormal(inmaps, outmaps, kernel, factor, mode="fan_in") initializer(shape=(5, 5)) mock_calc.assert_called_once_with(inmaps, outmaps, kernel, factor) @@ -64,19 +61,17 @@ def test_he_uniform(self): kernel = (10, 10) factor = 10.0 - with mock.patch('nnabla_rl.initializers.calc_uniform_lim_he_forward', return_value=True) as mock_calc: + with mock.patch("nnabla_rl.initializers.calc_uniform_lim_he_forward", return_value=True) as mock_calc: mock_calc.return_value = 10 - initializer = RI.HeUniform( - inmaps, outmaps, kernel, factor, mode='fan_in') + initializer = RI.HeUniform(inmaps, outmaps, kernel, factor, mode="fan_in") initializer(shape=(5, 5)) mock_calc.assert_called_once_with(inmaps, outmaps, kernel, factor) - with mock.patch('nnabla_rl.initializers.calc_uniform_lim_he_backward', return_value=True) as mock_calc: + with mock.patch("nnabla_rl.initializers.calc_uniform_lim_he_backward", return_value=True) as mock_calc: mock_calc.return_value = 10 - initializer = RI.HeUniform( - inmaps, outmaps, kernel, factor, mode='fan_out') + initializer = RI.HeUniform(inmaps, outmaps, kernel, factor, mode="fan_out") initializer(shape=(5, 5)) mock_calc.assert_called_once_with(inmaps, outmaps, kernel, factor) @@ -87,7 +82,7 @@ def test_he_normal_unknown_mode(self): factor = 10.0 with pytest.raises(NotImplementedError): - RI.HeNormal(inmaps, outmaps, kernel, factor, mode='fan_unknown') + RI.HeNormal(inmaps, outmaps, kernel, factor, mode="fan_unknown") def test_he_uniform_unknown_mode(self): inmaps = 10 @@ -96,7 +91,7 @@ def test_he_uniform_unknown_mode(self): factor = 10.0 with pytest.raises(NotImplementedError): - RI.HeUniform(inmaps, outmaps, kernel, factor, mode='fan_unknown') + RI.HeUniform(inmaps, outmaps, kernel, factor, mode="fan_unknown") def test_he_uniform_with_rng(self): inmaps = 10 @@ -128,41 +123,37 @@ def test_lecun_normal_unknown_mode(self): factor = 10.0 with pytest.raises(NotImplementedError): - RI.LeCunNormal(inmaps, 
outmaps, kernel, factor, mode='fan_out') + RI.LeCunNormal(inmaps, outmaps, kernel, factor, mode="fan_out") with pytest.raises(NotImplementedError): - RI.LeCunNormal(inmaps, outmaps, kernel, factor, mode='fan_unknown') + RI.LeCunNormal(inmaps, outmaps, kernel, factor, mode="fan_unknown") - @pytest.mark.parametrize("inmap, outmap, kernel, factor", [(3*i, 5*i, (i, i), 0.5 * i) for i in range(1, 10)]) + @pytest.mark.parametrize("inmap, outmap, kernel, factor", [(3 * i, 5 * i, (i, i), 0.5 * i) for i in range(1, 10)]) def test_calc_normal_std_he_forward(self, inmap, outmap, kernel, factor): n = inmap * kernel[0] * kernel[1] expected = np.sqrt(factor / n) - actual = RI.calc_normal_std_he_forward( - inmaps=inmap, outmaps=outmap, kernel=kernel, factor=factor) + actual = RI.calc_normal_std_he_forward(inmaps=inmap, outmaps=outmap, kernel=kernel, factor=factor) np.testing.assert_almost_equal(actual, expected) - @pytest.mark.parametrize("inmap, outmap, kernel, factor", [(3*i, 5*i, (i, i), 0.5 * i) for i in range(1, 10)]) + @pytest.mark.parametrize("inmap, outmap, kernel, factor", [(3 * i, 5 * i, (i, i), 0.5 * i) for i in range(1, 10)]) def test_calc_normal_std_he_backward(self, inmap, outmap, kernel, factor): n = outmap * kernel[0] * kernel[1] expected = np.sqrt(factor / n) - actual = RI.calc_normal_std_he_backward( - inmaps=inmap, outmaps=outmap, kernel=kernel, factor=factor) + actual = RI.calc_normal_std_he_backward(inmaps=inmap, outmaps=outmap, kernel=kernel, factor=factor) np.testing.assert_almost_equal(actual, expected) - @pytest.mark.parametrize("inmap, outmap, kernel, factor", [(3*i, 5*i, (i, i), 0.5 * i) for i in range(1, 10)]) + @pytest.mark.parametrize("inmap, outmap, kernel, factor", [(3 * i, 5 * i, (i, i), 0.5 * i) for i in range(1, 10)]) def test_calc_uniform_lim_he_forward(self, inmap, outmap, kernel, factor): n = inmap * kernel[0] * kernel[1] expected = np.sqrt((3.0 * factor) / n) - actual = RI.calc_uniform_lim_he_forward( - inmaps=inmap, outmaps=outmap, kernel=kernel, factor=factor) + actual = RI.calc_uniform_lim_he_forward(inmaps=inmap, outmaps=outmap, kernel=kernel, factor=factor) np.testing.assert_almost_equal(actual, expected) - @pytest.mark.parametrize("inmap, outmap, kernel, factor", [(3*i, 5*i, (i, i), 0.5 * i) for i in range(1, 10)]) + @pytest.mark.parametrize("inmap, outmap, kernel, factor", [(3 * i, 5 * i, (i, i), 0.5 * i) for i in range(1, 10)]) def test_calc_uniform_lim_he_backward(self, inmap, outmap, kernel, factor): n = outmap * kernel[0] * kernel[1] expected = np.sqrt((3.0 * factor) / n) - actual = RI.calc_uniform_lim_he_backward( - inmaps=inmap, outmaps=outmap, kernel=kernel, factor=factor) + actual = RI.calc_uniform_lim_he_backward(inmaps=inmap, outmaps=outmap, kernel=kernel, factor=factor) np.testing.assert_almost_equal(actual, expected) @pytest.mark.parametrize("std", [0.5 * i for i in range(1, 10)]) diff --git a/tests/test_parametric_functions.py b/tests/test_parametric_functions.py index 60605ad9..fe0a3fb2 100644 --- a/tests/test_parametric_functions.py +++ b/tests/test_parametric_functions.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
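The parametrized initializer tests above encode the closed-form values the helpers are expected to return: with n = inmaps * kernel[0] * kernel[1] in fan-in ("forward") mode and n = outmaps * kernel[0] * kernel[1] in fan-out ("backward") mode, the normal std is sqrt(factor / n) and the uniform limit is sqrt(3 * factor / n). A minimal sketch that restates those formulas against the exact calls made in the tests (RI is assumed to alias nnabla_rl.initializers, as in this test module):

import numpy as np
import nnabla_rl.initializers as RI  # alias assumed to match the tests

inmaps, outmaps, kernel, factor = 3, 5, (2, 2), 0.5

# fan-in ("forward") mode: n = inmaps * kernel_h * kernel_w
n_in = inmaps * kernel[0] * kernel[1]
assert np.isclose(RI.calc_normal_std_he_forward(inmaps=inmaps, outmaps=outmaps, kernel=kernel, factor=factor),
                  np.sqrt(factor / n_in))
assert np.isclose(RI.calc_uniform_lim_he_forward(inmaps=inmaps, outmaps=outmaps, kernel=kernel, factor=factor),
                  np.sqrt(3.0 * factor / n_in))

# fan-out ("backward") mode: n = outmaps * kernel_h * kernel_w
n_out = outmaps * kernel[0] * kernel[1]
assert np.isclose(RI.calc_normal_std_he_backward(inmaps=inmaps, outmaps=outmaps, kernel=kernel, factor=factor),
                  np.sqrt(factor / n_out))
assert np.isclose(RI.calc_uniform_lim_he_backward(inmaps=inmaps, outmaps=outmaps, kernel=kernel, factor=factor),
                  np.sqrt(3.0 * factor / n_out))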
@@ -28,19 +28,19 @@ def setup_method(self, method): def test_noisy_net_forward(self): nn.seed(0) x = nn.Variable.from_numpy_array(np.random.normal(size=(5, 5))) - with nn.parameter_scope('noisy1'): + with nn.parameter_scope("noisy1"): y1 = RPF.noisy_net(x, n_outmap=10, seed=1) y1_params = nn.get_parameters() assert y1.shape == (5, 10) nn.seed(0) - with nn.parameter_scope('noisy2'): + with nn.parameter_scope("noisy2"): y2 = RPF.noisy_net(x, n_outmap=10, seed=1) y2_params = nn.get_parameters() assert y1.shape == y2.shape assert y1_params.keys() == y2_params.keys() - for param_name in ['W', 'noisy_W', 'b', 'noisy_b']: + for param_name in ["W", "noisy_W", "b", "noisy_b"]: assert param_name in y1_params.keys() nn.forward_all([y1, y2]) @@ -51,19 +51,19 @@ def test_noisy_net_forward(self): def test_noisy_net_rng(self): rng = np.random.RandomState(seed=0) x = nn.Variable.from_numpy_array(np.random.normal(size=(5, 5))) - with nn.parameter_scope('noisy1'): + with nn.parameter_scope("noisy1"): y1 = RPF.noisy_net(x, n_outmap=10, seed=1, rng=rng) y1_params = nn.get_parameters() assert y1.shape == (5, 10) rng = np.random.RandomState(seed=0) - with nn.parameter_scope('noisy2'): + with nn.parameter_scope("noisy2"): y2 = RPF.noisy_net(x, n_outmap=10, seed=1, rng=rng) y2_params = nn.get_parameters() assert y1.shape == y2.shape assert y1_params.keys() == y2_params.keys() - for param_name in ['W', 'noisy_W', 'b', 'noisy_b']: + for param_name in ["W", "noisy_W", "b", "noisy_b"]: assert param_name in y1_params.keys() nn.forward_all([y1, y2]) @@ -74,13 +74,13 @@ def test_noisy_net_rng(self): def test_noisy_net_backward(self): nn.seed(0) x = nn.Variable.from_numpy_array(np.random.normal(size=(5, 5))) - with nn.parameter_scope('noisy1'): + with nn.parameter_scope("noisy1"): y1 = RPF.noisy_net(x, n_outmap=10, seed=1) y1_params = nn.get_parameters() assert y1.shape == (5, 10) nn.seed(0) - with nn.parameter_scope('noisy2'): + with nn.parameter_scope("noisy2"): y2 = RPF.noisy_net(x, n_outmap=10, seed=1) y2_params = nn.get_parameters() assert y1.shape == y2.shape @@ -102,65 +102,74 @@ def test_noisy_net_backward(self): def test_noisy_net_base_axis(self): x = nn.Variable.from_numpy_array(np.random.normal(size=(5, 5, 5))) - with nn.parameter_scope('noisy1'): + with nn.parameter_scope("noisy1"): y1 = RPF.noisy_net(x, n_outmap=10, seed=1, base_axis=2) assert y1.shape == (5, 5, 10) def test_noisy_net_without_deterministic_bias(self): x = nn.Variable.from_numpy_array(np.random.normal(size=(5, 5))) - with nn.parameter_scope('noisy1'): + with nn.parameter_scope("noisy1"): y1 = RPF.noisy_net(x, n_outmap=10, seed=1, with_bias=False) params = nn.get_parameters() assert y1.shape == (5, 10) - assert 'b' not in params.keys() - assert 'noisy_b' in params.keys() + assert "b" not in params.keys() + assert "noisy_b" in params.keys() def test_noisy_net_without_noisy_bias(self): x = nn.Variable.from_numpy_array(np.random.normal(size=(5, 5))) - with nn.parameter_scope('noisy1'): + with nn.parameter_scope("noisy1"): y1 = RPF.noisy_net(x, n_outmap=10, seed=1, with_noisy_bias=False) params = nn.get_parameters() assert y1.shape == (5, 10) - assert 'b' in params.keys() - assert 'noisy_b' not in params.keys() + assert "b" in params.keys() + assert "noisy_b" not in params.keys() def test_noisy_net_without_bias(self): x = nn.Variable.from_numpy_array(np.random.normal(size=(5, 5))) - with nn.parameter_scope('noisy1'): + with nn.parameter_scope("noisy1"): y1 = RPF.noisy_net(x, n_outmap=10, seed=1, with_bias=False, with_noisy_bias=False) params = 
nn.get_parameters() assert y1.shape == (5, 10) - assert 'b' not in params.keys() - assert 'noisy_b' not in params.keys() - - @pytest.mark.parametrize("x_c1, x_c2, x_c3, expected", - [(np.array([[np.log(2), 0., 0.], [0., 0., 0.], [0., 0., np.log(2)]]), - np.array([[0., np.log(2), 0.], [0., 0., 0.], [0., np.log(2), 0.]]), - np.array([[0., 0., 0.], [np.log(2), 0., np.log(2)], [0., 0., 0.]]), - np.array([[0., 0.], [0., 0.], [0., 0.]])), - (np.array([[0., np.log(2), 0.], [0., 0., 0.], [0., 0., 0.]]), - np.array([[0., 0., 0.], [0., 0., 0.], [0., 0., np.log(2)]]), - np.array([[0., 0., 0.], [np.log(2), 0., 0.], [0., 0., 0.]]), - np.array([[0., -0.1], [0.1, 0.1], [-0.1, 0.]])), - (np.array([[0., np.log(4), 0.], [0., np.log(4), 0.], [0., 0., 0.]]), - np.array([[0., 0., np.log(4)], [0., 0., 0.], [0., 0., np.log(4)]]), - np.array([[0., 0., 0.], [0., 0., np.log(4)], [0., 0., 0.]]), - np.array([[0., -0.2], [0.4, 0.], [0.25, 0.]])) - ]) + assert "b" not in params.keys() + assert "noisy_b" not in params.keys() + + @pytest.mark.parametrize( + "x_c1, x_c2, x_c3, expected", + [ + ( + np.array([[np.log(2), 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, np.log(2)]]), + np.array([[0.0, np.log(2), 0.0], [0.0, 0.0, 0.0], [0.0, np.log(2), 0.0]]), + np.array([[0.0, 0.0, 0.0], [np.log(2), 0.0, np.log(2)], [0.0, 0.0, 0.0]]), + np.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]), + ), + ( + np.array([[0.0, np.log(2), 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), + np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, np.log(2)]]), + np.array([[0.0, 0.0, 0.0], [np.log(2), 0.0, 0.0], [0.0, 0.0, 0.0]]), + np.array([[0.0, -0.1], [0.1, 0.1], [-0.1, 0.0]]), + ), + ( + np.array([[0.0, np.log(4), 0.0], [0.0, np.log(4), 0.0], [0.0, 0.0, 0.0]]), + np.array([[0.0, 0.0, np.log(4)], [0.0, 0.0, 0.0], [0.0, 0.0, np.log(4)]]), + np.array([[0.0, 0.0, 0.0], [0.0, 0.0, np.log(4)], [0.0, 0.0, 0.0]]), + np.array([[0.0, -0.2], [0.4, 0.0], [0.25, 0.0]]), + ), + ], + ) def test_spatial_softmax_forward(self, x_c1, x_c2, x_c3, expected): batch_size = 1 channel = 3 x = np.stack([x_c1, x_c2, x_c3], axis=0)[np.newaxis, :, :, :] - expected = expected.reshape((-1, channel*2)) + expected = expected.reshape((-1, channel * 2)) - with nn.parameter_scope('spatial_softmax1'): + with nn.parameter_scope("spatial_softmax1"): x = nn.Variable.from_numpy_array(x) - y1 = RPF.spatial_softmax(x, alpha_init=1.) + y1 = RPF.spatial_softmax(x, alpha_init=1.0) y1_params = nn.get_parameters() - assert y1.shape == (batch_size, channel*2) + assert y1.shape == (batch_size, channel * 2) - for param_name in ['alpha']: + for param_name in ["alpha"]: assert param_name in y1_params.keys() nn.forward_all([y1]) @@ -172,14 +181,14 @@ def test_spatial_softmax_backward(self): height = 5 width = 5 - with nn.parameter_scope('spatial_softmax1'): + with nn.parameter_scope("spatial_softmax1"): x = nn.Variable.from_numpy_array(np.random.normal(size=(batch_size, channel, height, width))) - y1 = RPF.spatial_softmax(x, alpha_init=1.) + y1 = RPF.spatial_softmax(x, alpha_init=1.0) y1_params = nn.get_parameters() - assert y1.shape == (batch_size, channel*2) + assert y1.shape == (batch_size, channel * 2) - with nn.parameter_scope('spatial_softmax2'): - y2 = RPF.spatial_softmax(x, alpha_init=1.) 
+ with nn.parameter_scope("spatial_softmax2"): + y2 = RPF.spatial_softmax(x, alpha_init=1.0) y2_params = nn.get_parameters() assert y1.shape == y2.shape @@ -203,7 +212,7 @@ def test_spatial_softmax_with_fixed_alpha(self): height = 5 width = 5 - with nn.parameter_scope('spatial_softmax1'): + with nn.parameter_scope("spatial_softmax1"): x = nn.Variable.from_numpy_array(np.random.normal(size=(batch_size, channel, height, width))) y1 = RPF.spatial_softmax(x, alpha_init=1.5, fix_alpha=True) y1_params = nn.get_parameters() @@ -219,16 +228,16 @@ def test_spatial_softmax_with_fixed_alpha(self): assert np.allclose(y1_params["alpha"].d, np.array([[1.5]])) @pytest.mark.parametrize("batch_size", [i for i in range(1, 3)]) - @pytest.mark.parametrize("shape", [(1, ), (5, ), (1, 5), (5, 5)]) + @pytest.mark.parametrize("shape", [(1,), (5,), (1, 5), (5, 5)]) def test_lstm_cell_in_out_shape(self, batch_size, shape): out_size = 10 input_shape = (batch_size, *shape) x = nn.Variable.from_numpy_array(np.random.normal(size=input_shape)) - hidden_shape = (batch_size, *(shape[0:-1] + (out_size, ))) + hidden_shape = (batch_size, *(shape[0:-1] + (out_size,))) h = nn.Variable.from_numpy_array(np.random.normal(size=hidden_shape)) - cell_shape = (batch_size, *(shape[0:-1] + (out_size, ))) + cell_shape = (batch_size, *(shape[0:-1] + (out_size,))) c = nn.Variable.from_numpy_array(np.random.normal(size=cell_shape)) - with nn.parameter_scope('test'): + with nn.parameter_scope("test"): y, c = RPF.lstm_cell(x, h, c, state_size=out_size, base_axis=(len(shape))) assert y.shape == hidden_shape assert c.shape == cell_shape diff --git a/tests/test_replay_buffer.py b/tests/test_replay_buffer.py index 4c23dcf0..ccd6a971 100644 --- a/tests/test_replay_buffer.py +++ b/tests/test_replay_buffer.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -144,8 +144,8 @@ def test_buffer_len(self): assert len(buffer) == 10 def _generate_experience_mock(self): - state_shape = (5, ) - action_shape = (10, ) + state_shape = (5,) + action_shape = (10,) state = np.empty(shape=state_shape) action = np.empty(shape=action_shape) diff --git a/tests/test_typing.py b/tests/test_typing.py index 559c65bc..916f231f 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -1,4 +1,4 @@ -# Copyright 2023 Sony Group Corporation. +# Copyright 2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
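The test_parametric_functions.py hunks above exercise RPF.noisy_net: it builds a linear layer with deterministic parameters W/b plus noise parameters noisy_W/noisy_b, and with_bias / with_noisy_bias control which bias terms are created. A minimal usage sketch mirroring those calls; it assumes RPF aliases nnabla_rl's parametric-functions module as in the test file:

import numpy as np
import nnabla as nn
import nnabla_rl.parametric_functions as RPF  # assumed import behind the RPF alias used in the tests

x = nn.Variable.from_numpy_array(np.random.normal(size=(5, 5)))

with nn.parameter_scope("noisy_example"):
    # noisy linear layer: deterministic W/b plus learned noise parameters noisy_W/noisy_b
    y = RPF.noisy_net(x, n_outmap=10, seed=1)
    params = nn.get_parameters()

assert y.shape == (5, 10)
for name in ("W", "noisy_W", "b", "noisy_b"):
    assert name in params

nn.forward_all([y])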
@@ -30,7 +30,7 @@ from nnabla_rl.typing import accepted_shapes -class TestTyping(): +class TestTyping: def test_accepted_shapes_call_with_args(self): @accepted_shapes(x=(3, 5), u=(2, 4)) def dummy_function(x, u): @@ -105,6 +105,7 @@ def dummy_function(x, u=np.ones((2, 1)), batched=True): def test_accepted_shapes_decorator_has_invalid_args(self): with pytest.raises(TypeError): + @accepted_shapes((4, 3), u=(2, 1)) def dummy_function(x, u=np.ones((2, 1)), batched=True): pass @@ -153,5 +154,5 @@ def dummy_function(x=np.zeros((3, 5)), u=np.zeros((2, 4))): dummy_function(u=np.zeros((2, 4)), x=np.zeros((3, 5))) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main() diff --git a/tests/testing_utils.py b/tests/testing_utils.py index a81ae670..51efe6b3 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -41,7 +41,7 @@ def generate_dummy_experiences(env, experience_num): state = env.reset() if isinstance(env.action_space, gym.spaces.Discrete): action = env.action_space.sample() - action = np.reshape(action, newshape=(1, )) + action = np.reshape(action, newshape=(1,)) else: action = env.action_space.sample() next_state, reward, done, info = env.step(action) @@ -58,7 +58,7 @@ def generate_dummy_trajectory(env, trajectory_length): state = env.reset() if isinstance(env.action_space, gym.spaces.Discrete): action = env.action_space.sample() - action = np.reshape(action, newshape=(1, )) + action = np.reshape(action, newshape=(1,)) else: action = env.action_space.sample() next_state, reward, done, info = env.step(action) diff --git a/tests/utils/test_copy.py b/tests/utils/test_copy.py index e8634c6e..15b9643c 100644 --- a/tests/utils/test_copy.py +++ b/tests/utils/test_copy.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
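The test_typing.py hunks above show the intended use of the accepted_shapes decorator: each named argument is declared with its expected array shape, and calls whose arrays do not match are rejected (the mismatch-handling details are not visible in these hunks). A small sketch based only on the decorations and calls that appear in the tests:

import numpy as np
from nnabla_rl.typing import accepted_shapes

@accepted_shapes(x=(3, 5), u=(2, 4))
def dummy_function(x, u):
    # shapes of x and u are validated against the decorator's declaration
    return x.sum() + u.sum()

# both positional and keyword calls with matching shapes are accepted, as in the tests
dummy_function(np.zeros((3, 5)), np.zeros((2, 4)))
dummy_function(u=np.zeros((2, 4)), x=np.zeros((3, 5)))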
@@ -24,8 +24,7 @@ class DummyNetwork(Model): - def __init__(self, scope_name, - weight_initializer=None, bias_initialzier=None): + def __init__(self, scope_name, weight_initializer=None, bias_initialzier=None): self._scope_name = scope_name self._weight_initializer = weight_initializer self._bias_initialzier = bias_initialzier @@ -36,9 +35,7 @@ def __init__(self, scope_name, def __call__(self, dummy_variable): with nn.parameter_scope(self.scope_name): - h = NPF.affine(dummy_variable, 1, - w_init=self._weight_initializer, - b_init=self._bias_initialzier) + h = NPF.affine(dummy_variable, 1, w_init=self._weight_initializer, b_init=self._bias_initialzier) return h @@ -46,55 +43,37 @@ class TestCopy(object): def test_copy_network_parameters(self): nn.clear_parameters() - base = DummyNetwork('base', - NI.ConstantInitializer(1), - NI.ConstantInitializer(1)) - target = DummyNetwork('target', - NI.ConstantInitializer(2), - NI.ConstantInitializer(2)) + base = DummyNetwork("base", NI.ConstantInitializer(1), NI.ConstantInitializer(1)) + target = DummyNetwork("target", NI.ConstantInitializer(2), NI.ConstantInitializer(2)) - copy_network_parameters(base.get_parameters(), - target.get_parameters(), - tau=1.0) + copy_network_parameters(base.get_parameters(), target.get_parameters(), tau=1.0) - assert self._has_same_parameters(base.get_parameters(), - target.get_parameters()) + assert self._has_same_parameters(base.get_parameters(), target.get_parameters()) def test_softcopy_network_parameters(self): nn.clear_parameters() - base = DummyNetwork('base', - NI.ConstantInitializer(1), - NI.ConstantInitializer(1)) + base = DummyNetwork("base", NI.ConstantInitializer(1), NI.ConstantInitializer(1)) weight_initializer = NI.ConstantInitializer(2) bias_initializer = NI.ConstantInitializer(2) - target_original = DummyNetwork('target_original', - weight_initializer, - bias_initializer) - target = DummyNetwork('target', - weight_initializer, - bias_initializer) - - copy_network_parameters(base.get_parameters(), - target.get_parameters(), - tau=0.75) - - assert self._has_soft_same_parameters(target.get_parameters(), - base.get_parameters(), - target_original.get_parameters(), - tau=0.75) + target_original = DummyNetwork("target_original", weight_initializer, bias_initializer) + target = DummyNetwork("target", weight_initializer, bias_initializer) + + copy_network_parameters(base.get_parameters(), target.get_parameters(), tau=0.75) + + assert self._has_soft_same_parameters( + target.get_parameters(), base.get_parameters(), target_original.get_parameters(), tau=0.75 + ) def test_softcopy_network_parameters_wrong_tau(self): nn.clear_parameters() - base = DummyNetwork('base') - target = DummyNetwork('target') + base = DummyNetwork("base") + target = DummyNetwork("target") with pytest.raises(ValueError): - copy_network_parameters(base.get_parameters(), - target.get_parameters(), - tau=-0.75) + copy_network_parameters(base.get_parameters(), target.get_parameters(), tau=-0.75) def _has_same_parameters(self, params1, params2): for key in params1.keys(): @@ -104,7 +83,6 @@ def _has_same_parameters(self, params1, params2): def _has_soft_same_parameters(self, merged_params, base_params1, base_params2, tau): for key in merged_params.keys(): - if not np.allclose(merged_params[key].d, - base_params1[key].d * tau + base_params2[key].d * (1 - tau)): + if not np.allclose(merged_params[key].d, base_params1[key].d * tau + base_params2[key].d * (1 - tau)): return False return True diff --git a/tests/utils/test_data.py 
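The test_copy.py hunks check copy_network_parameters in two regimes: tau=1.0 is a hard copy, while 0 < tau < 1 blends the source into the target so that each merged parameter equals base * tau + original_target * (1 - tau), and a negative tau raises ValueError. A plain-numpy sketch of the rule that _has_soft_same_parameters verifies, not the library implementation itself:

import numpy as np

def soft_copy(base_params, target_params, tau):
    # mirrors the blend checked by _has_soft_same_parameters:
    # new_target = tau * base + (1 - tau) * old_target
    if not 0.0 <= tau <= 1.0:
        raise ValueError("tau must lie in [0, 1]")
    return {k: tau * base_params[k] + (1.0 - tau) * target_params[k] for k in target_params}

base = {"affine/W": np.full((2, 1), 1.0), "affine/b": np.full((1,), 1.0)}
target = {"affine/W": np.full((2, 1), 2.0), "affine/b": np.full((1,), 2.0)}

merged = soft_copy(base, target, tau=0.75)
assert np.allclose(merged["affine/W"], 1.0 * 0.75 + 2.0 * 0.25)  # == 1.25

hard = soft_copy(base, target, tau=1.0)
assert np.allclose(hard["affine/W"], base["affine/W"])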
b/tests/utils/test_data.py index 494bb19b..181b7ed8 100644 --- a/tests/utils/test_data.py +++ b/tests/utils/test_data.py @@ -19,12 +19,20 @@ import nnabla as nn import nnabla_rl.environments as E -from nnabla_rl.utils.data import (RingBuffer, add_batch_dimension, compute_std_ndarray, list_of_dict_to_dict_of_list, - marshal_dict_experiences, marshal_experiences, normalize_ndarray, - set_data_to_variable, unnormalize_ndarray) - - -class TestData(): +from nnabla_rl.utils.data import ( + RingBuffer, + add_batch_dimension, + compute_std_ndarray, + list_of_dict_to_dict_of_list, + marshal_dict_experiences, + marshal_experiences, + normalize_ndarray, + set_data_to_variable, + unnormalize_ndarray, +) + + +class TestData: def test_set_data_to_variable(self): variable = nn.Variable((3,)) array = np.random.rand(3) @@ -54,9 +62,9 @@ def test_marshal_experiences(self): dummy_env = E.DummyContinuous() experiences = generate_dummy_experiences(dummy_env, batch_size) state, action, reward, done, next_state, info = marshal_experiences(experiences) - rnn_states = info['rnn_states'] - rnn_dummy_state1 = rnn_states['dummy_scope']['dummy_state1'] - rnn_dummy_state2 = rnn_states['dummy_scope']['dummy_state2'] + rnn_states = info["rnn_states"] + rnn_dummy_state1 = rnn_states["dummy_scope"]["dummy_state1"] + rnn_dummy_state2 = rnn_states["dummy_scope"]["dummy_state2"] assert state.shape == (batch_size, dummy_env.observation_space.shape[0]) assert action.shape == (batch_size, dummy_env.action_space.shape[0]) @@ -71,9 +79,9 @@ def test_marshal_experiences_tuple_continous(self): dummy_env = E.DummyTupleContinuous() experiences = generate_dummy_experiences(dummy_env, batch_size) state, action, reward, done, next_state, info = marshal_experiences(experiences) - rnn_states = info['rnn_states'] - rnn_dummy_state1 = rnn_states['dummy_scope']['dummy_state1'] - rnn_dummy_state2 = rnn_states['dummy_scope']['dummy_state2'] + rnn_states = info["rnn_states"] + rnn_dummy_state1 = rnn_states["dummy_scope"]["dummy_state1"] + rnn_dummy_state2 = rnn_states["dummy_scope"]["dummy_state2"] assert state[0].shape == (batch_size, dummy_env.observation_space[0].shape[0]) assert state[1].shape == (batch_size, dummy_env.observation_space[1].shape[0]) @@ -91,9 +99,9 @@ def test_marshal_experiences_tuple_discrete(self): dummy_env = E.DummyTupleDiscrete() experiences = generate_dummy_experiences(dummy_env, batch_size) state, action, reward, done, next_state, info = marshal_experiences(experiences) - rnn_states = info['rnn_states'] - rnn_dummy_state1 = rnn_states['dummy_scope']['dummy_state1'] - rnn_dummy_state2 = rnn_states['dummy_scope']['dummy_state2'] + rnn_states = info["rnn_states"] + rnn_dummy_state1 = rnn_states["dummy_scope"]["dummy_state1"] + rnn_dummy_state2 = rnn_states["dummy_scope"]["dummy_state2"] assert state[0].shape == (batch_size, 1) assert state[1].shape == (batch_size, 1) @@ -107,12 +115,12 @@ def test_marshal_experiences_tuple_discrete(self): assert rnn_dummy_state2.shape == (batch_size, 1) def test_marshal_dict_experiences(self): - experiences = {'key1': 1, 'key2': 2} - dict_experiences = [{'key_parent': experiences}, {'key_parent': experiences}] + experiences = {"key1": 1, "key2": 2} + dict_experiences = [{"key_parent": experiences}, {"key_parent": experiences}] marshaled_experience = marshal_dict_experiences(dict_experiences) - key1_experiences = marshaled_experience['key_parent']['key1'] - key2_experiences = marshaled_experience['key_parent']['key2'] + key1_experiences = marshaled_experience["key_parent"]["key1"] + 
key2_experiences = marshaled_experience["key_parent"]["key2"] assert key1_experiences.shape == (2, 1) assert key2_experiences.shape == (2, 1) @@ -121,13 +129,13 @@ def test_marshal_dict_experiences(self): np.testing.assert_allclose(np.asarray(key2_experiences), 2) def test_marshal_triple_nested_dict_experiences(self): - experiences = {'key1': 1, 'key2': 2} - nested_experiences = {'nest1': experiences, 'nest2': experiences} - dict_experiences = [{'key_parent': nested_experiences}, {'key_parent': nested_experiences}] + experiences = {"key1": 1, "key2": 2} + nested_experiences = {"nest1": experiences, "nest2": experiences} + dict_experiences = [{"key_parent": nested_experiences}, {"key_parent": nested_experiences}] marshaled_experience = marshal_dict_experiences(dict_experiences) - key1_experiences = marshaled_experience['key_parent']['nest1']['key1'] - key2_experiences = marshaled_experience['key_parent']['nest2']['key2'] + key1_experiences = marshaled_experience["key_parent"]["nest1"]["key1"] + key2_experiences = marshaled_experience["key_parent"]["nest2"]["key2"] assert len(key1_experiences) == 2 assert len(key2_experiences) == 2 @@ -137,31 +145,31 @@ def test_marshal_triple_nested_dict_experiences(self): def test_marashal_dict_experiences_with_inhomogeneous_part(self): installed_numpy_version = parse(np.__version__) - numpy_version1_24 = parse('1.24.0') + numpy_version1_24 = parse("1.24.0") if installed_numpy_version < numpy_version1_24: # no need to test return - experiences = {'key1': 1, 'key2': 2} - inhomgeneous_experiences = {'key1': np.empty(shape=(6, )), 'key2': 2} - dict_experiences = [{'key_parent': experiences}, {'key_parent': inhomgeneous_experiences}] + experiences = {"key1": 1, "key2": 2} + inhomgeneous_experiences = {"key1": np.empty(shape=(6,)), "key2": 2} + dict_experiences = [{"key_parent": experiences}, {"key_parent": inhomgeneous_experiences}] marshaled_experience = marshal_dict_experiences(dict_experiences) - assert 'key1' not in marshaled_experience['key_parent'] + assert "key1" not in marshaled_experience["key_parent"] - key2_experiences = marshaled_experience['key_parent']['key2'] + key2_experiences = marshaled_experience["key_parent"]["key2"] assert key2_experiences.shape == (2, 1) np.testing.assert_allclose(np.asarray(key2_experiences), 2) def test_list_of_dict_to_dict_of_list(self): - list_of_dict = [{'key1': 1, 'key2': 2}, {'key1': 1, 'key2': 2}] + list_of_dict = [{"key1": 1, "key2": 2}, {"key1": 1, "key2": 2}] dict_of_list = list_of_dict_to_dict_of_list(list_of_dict) - key1_list = dict_of_list['key1'] - key2_list = dict_of_list['key2'] + key1_list = dict_of_list["key1"] + key2_list = dict_of_list["key2"] assert len(key1_list) == 2 assert len(key2_list) == 2 @@ -184,39 +192,43 @@ def test_add_batch_dimension_tuple(self): assert actual_array[0].shape == (1, *array1.shape) assert actual_array[1].shape == (1, *array2.shape) - @pytest.mark.parametrize("x, mean, std, value_clip, expected", - [ - (np.array([2.0]), np.array([1.0]), np.array([0.2]), None, np.array([5.0])), - (np.array([2.0]), np.array([1.0]), np.array([0.2]), (-1.5, 1.5), np.array([1.5])), - (np.array([-2.0]), np.array([1.0]), np.array([0.2]), (-1.5, 1.5), np.array([-1.5])), - (np.array([[2.0], [1.0]]), np.array([[1.0]]), np.array([[0.2]]), None, - np.array([[5.0], [0.0]])), - ]) + @pytest.mark.parametrize( + "x, mean, std, value_clip, expected", + [ + (np.array([2.0]), np.array([1.0]), np.array([0.2]), None, np.array([5.0])), + (np.array([2.0]), np.array([1.0]), np.array([0.2]), (-1.5, 1.5), 
np.array([1.5])), + (np.array([-2.0]), np.array([1.0]), np.array([0.2]), (-1.5, 1.5), np.array([-1.5])), + (np.array([[2.0], [1.0]]), np.array([[1.0]]), np.array([[0.2]]), None, np.array([[5.0], [0.0]])), + ], + ) def test_normalize_ndarray(self, x, expected, mean, std, value_clip): actual_var = normalize_ndarray(x, mean, std, value_clip=value_clip) assert np.allclose(actual_var, expected) - @pytest.mark.parametrize("x, mean, std, value_clip, expected", - [ - (np.array([2.0]), np.array([1.0]), np.array([0.2]), None, np.array([1.4])), - (np.array([2.0]), np.array([1.0]), np.array([0.2]), (-1.0, 1.0), np.array([1.0])), - (np.array([-2.0]), np.array([-1.0]), np.array([0.2]), (-1.0, 1.0), np.array([-1.0])), - (np.array([[2.0], [1.0]]), np.array([[1.0]]), np.array([[0.2]]), None, - np.array([[1.4], [1.2]])), - ]) + @pytest.mark.parametrize( + "x, mean, std, value_clip, expected", + [ + (np.array([2.0]), np.array([1.0]), np.array([0.2]), None, np.array([1.4])), + (np.array([2.0]), np.array([1.0]), np.array([0.2]), (-1.0, 1.0), np.array([1.0])), + (np.array([-2.0]), np.array([-1.0]), np.array([0.2]), (-1.0, 1.0), np.array([-1.0])), + (np.array([[2.0], [1.0]]), np.array([[1.0]]), np.array([[0.2]]), None, np.array([[1.4], [1.2]])), + ], + ) def test_unnormalize_ndarray(self, x, expected, mean, std, value_clip): actual_var = unnormalize_ndarray(x, mean, std, value_clip=value_clip) assert np.allclose(actual_var, expected) - @pytest.mark.parametrize("var, epsilon, mode_for_floating_point_error, expected", - [ - (np.array([3.0]), 1.0, "add", np.array([2.0])), - (np.array([4.0]), 0.01, "max", np.array([2.0])), - (np.array([0.4]), 1.0, "max", np.array([1.0])), - (np.array([[3.0], [8.0]]), 1.0, "add", np.array([[2.0], [3.0]])), - (np.array([[4.0], [9.0]]), 0.01, "max", np.array([[2.0], [3.0]])), - (np.array([[0.4], [0.9]]), 1.0, "max", np.array([[1.0], [1.0]])), - ]) + @pytest.mark.parametrize( + "var, epsilon, mode_for_floating_point_error, expected", + [ + (np.array([3.0]), 1.0, "add", np.array([2.0])), + (np.array([4.0]), 0.01, "max", np.array([2.0])), + (np.array([0.4]), 1.0, "max", np.array([1.0])), + (np.array([[3.0], [8.0]]), 1.0, "add", np.array([[2.0], [3.0]])), + (np.array([[4.0], [9.0]]), 0.01, "max", np.array([[2.0], [3.0]])), + (np.array([[0.4], [0.9]]), 1.0, "max", np.array([[1.0], [1.0]])), + ], + ) def test_compute_std_ndarray(self, var, epsilon, mode_for_floating_point_error, expected): actual_var = compute_std_ndarray(var, epsilon, mode_for_floating_point_error) assert np.allclose(actual_var, expected) @@ -268,6 +280,7 @@ def test_buffer_len(self): if __name__ == "__main__": from testing_utils import generate_dummy_experiences + pytest.main() else: from ..testing_utils import generate_dummy_experiences diff --git a/tests/utils/test_debugging.py b/tests/utils/test_debugging.py index f8745804..36834119 100644 --- a/tests/utils/test_debugging.py +++ b/tests/utils/test_debugging.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
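The parametrized cases in test_data.py pin down the arithmetic of the data helpers: normalize_ndarray computes (x - mean) / std and then optionally clips, unnormalize_ndarray computes x * std + mean with the same optional clip, and compute_std_ndarray turns a variance into a std as sqrt(var + epsilon) in "add" mode or sqrt(max(var, epsilon)) in "max" mode, consistent with the expected values above. A plain-numpy restatement of those formulas that reproduces a few of the tests' expectations (not the library code itself):

import numpy as np

def normalize(x, mean, std, value_clip=None):
    y = (x - mean) / std
    return y if value_clip is None else np.clip(y, *value_clip)

def unnormalize(x, mean, std, value_clip=None):
    y = x * std + mean
    return y if value_clip is None else np.clip(y, *value_clip)

def compute_std(var, epsilon, mode):
    return np.sqrt(var + epsilon) if mode == "add" else np.sqrt(np.maximum(var, epsilon))

assert np.allclose(normalize(np.array([2.0]), 1.0, 0.2), [5.0])
assert np.allclose(normalize(np.array([2.0]), 1.0, 0.2, value_clip=(-1.5, 1.5)), [1.5])
assert np.allclose(unnormalize(np.array([2.0]), 1.0, 0.2), [1.4])
assert np.allclose(compute_std(np.array([3.0]), 1.0, "add"), [2.0])
assert np.allclose(compute_std(np.array([0.4]), 1.0, "max"), [1.0])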
@@ -21,9 +21,8 @@ from nnabla_rl.utils.debugging import count_parameter_number -class TestCountParameterNumber(): - @pytest.mark.parametrize("batch_size, state_size, output_size", [ - (5, 3, 2)]) +class TestCountParameterNumber: + @pytest.mark.parametrize("batch_size, state_size, output_size", [(5, 3, 2)]) def test_affine_count(self, batch_size, state_size, output_size): nn.clear_parameters() dummy_input = nn.Variable((batch_size, state_size)) @@ -33,4 +32,4 @@ def test_affine_count(self, batch_size, state_size, output_size): parameter_number = count_parameter_number(nn.get_parameters()) - assert parameter_number == state_size*output_size + output_size + assert parameter_number == state_size * output_size + output_size diff --git a/tests/utils/test_evaluator.py b/tests/utils/test_evaluator.py index 92333a81..56c73acc 100644 --- a/tests/utils/test_evaluator.py +++ b/tests/utils/test_evaluator.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -49,8 +49,7 @@ class TestTimestepEvaluator(object): def test_evaluation(self): num_timesteps = 100 max_episode_length = 10 - dummy_env = E.DummyAtariEnv( - done_at_random=False, max_episode_length=max_episode_length) + dummy_env = E.DummyAtariEnv(done_at_random=False, max_episode_length=max_episode_length) evaluator = TimestepEvaluator(num_timesteps=num_timesteps) dummy_algorithm = A.Dummy(dummy_env) @@ -64,8 +63,7 @@ def test_evaluation(self): def test_timestep_limit(self): num_timesteps = 113 max_episode_length = 10 - dummy_env = E.DummyAtariEnv( - done_at_random=False, max_episode_length=max_episode_length) + dummy_env = E.DummyAtariEnv(done_at_random=False, max_episode_length=max_episode_length) evaluator = TimestepEvaluator(num_timesteps=num_timesteps) dummy_algorithm = A.Dummy(dummy_env) @@ -77,5 +75,5 @@ def test_timestep_limit(self): assert dummy_algorithm.compute_eval_action.call_count == num_timesteps + 1 -if __name__ == '__main__': +if __name__ == "__main__": pytest.main() diff --git a/tests/utils/test_files.py b/tests/utils/test_files.py index d9f47ca1..ddb92661 100644 --- a/tests/utils/test_files.py +++ b/tests/utils/test_files.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
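The count_parameter_number check above relies on the usual affine-layer bookkeeping: a fully connected layer from state_size inputs to output_size outputs has state_size * output_size weights plus output_size biases, so the parametrized case (state_size=3, output_size=2) gives 3 * 2 + 2 = 8 parameters. A short sketch of the same count done directly on nnabla parameters, using the standard NPF.affine call seen throughout these tests:

import numpy as np
import nnabla as nn
import nnabla.parametric_functions as NPF

nn.clear_parameters()
x = nn.Variable((5, 3))          # batch_size=5, state_size=3
with nn.parameter_scope("affine_example"):
    y = NPF.affine(x, 2)         # output_size=2

total = sum(int(np.prod(p.shape)) for p in nn.get_parameters().values())
assert total == 3 * 2 + 2        # 8 parameters: W is (3, 2), b is (2,)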
@@ -25,18 +25,17 @@ class TestFiles(object): def test_file_exists(self): - with mock.patch('os.path.exists', return_value=True) as _: + with mock.patch("os.path.exists", return_value=True) as _: test_file = "test" assert files.file_exists(test_file) is True def test_file_does_not_exist(self): - with mock.patch('os.path.exists', return_value=False) as _: + with mock.patch("os.path.exists", return_value=False) as _: test_file = "test" assert files.file_exists(test_file) is False def test_create_dir_if_not_exist(self): - with mock.patch('os.path.exists', return_value=False) as mock_exists, \ - mock.patch('os.makedirs') as mock_mkdirs: + with mock.patch("os.path.exists", return_value=False) as mock_exists, mock.patch("os.makedirs") as mock_mkdirs: test_file = "test" files.create_dir_if_not_exist(test_file) @@ -44,9 +43,8 @@ def test_create_dir_if_not_exist(self): mock_mkdirs.assert_called_once() def test_create_dir_when_exists(self): - with mock.patch('os.path.exists', return_value=True) as mock_exists, \ - mock.patch('os.makedirs') as mock_mkdirs: - with mock.patch('os.path.isdir', return_value=True): + with mock.patch("os.path.exists", return_value=True) as mock_exists, mock.patch("os.makedirs") as mock_mkdirs: + with mock.patch("os.path.isdir", return_value=True): test_file = "test" files.create_dir_if_not_exist(test_file) @@ -54,9 +52,8 @@ def test_create_dir_when_exists(self): mock_mkdirs.assert_not_called() def test_create_dir_when_target_is_not_directory(self): - with mock.patch('os.path.exists', return_value=True) as mock_exists, \ - mock.patch('os.makedirs') as mock_mkdirs: - with mock.patch('os.path.isdir', return_value=False): + with mock.patch("os.path.exists", return_value=True) as mock_exists, mock.patch("os.makedirs") as mock_mkdirs: + with mock.patch("os.path.isdir", return_value=False): with pytest.raises(RuntimeError): test_file = "test" files.create_dir_if_not_exist(test_file) @@ -66,8 +63,8 @@ def test_create_dir_when_target_is_not_directory(self): def test_read_write_text_to_file(self): with tempfile.TemporaryDirectory() as tempdir: - target_path = os.path.join(tempdir, 'test.txt') - time_format = '%Y-%m-%d-%H%M%S.%f' + target_path = os.path.join(tempdir, "test.txt") + time_format = "%Y-%m-%d-%H%M%S.%f" test_text = datetime.datetime.now().strftime(time_format) files.write_text_to_file(target_path, test_text) @@ -77,5 +74,5 @@ def test_read_write_text_to_file(self): assert read_text == test_text -if __name__ == '__main__': +if __name__ == "__main__": pytest.main() diff --git a/tests/utils/test_matrices.py b/tests/utils/test_matrices.py index 19c392b3..7550e877 100644 --- a/tests/utils/test_matrices.py +++ b/tests/utils/test_matrices.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,25 +24,26 @@ from nnabla_rl.utils.matrices import compute_hessian -class TestComputeHessian(): +class TestComputeHessian: def setup_method(self, method): nn.clear_parameters() def test_compute_hessian(self): - x = get_parameter_or_create("x", shape=(1, )) - y = get_parameter_or_create("y", shape=(1, )) - loss = x**3 + 2.*x*y + y**2 - x + x = get_parameter_or_create("x", shape=(1,)) + y = get_parameter_or_create("y", shape=(1,)) + loss = x**3 + 2.0 * x * y + y**2 - x - x.d = 2. - y.d = 3. 
+ x.d = 2.0 + y.d = 3.0 actual = compute_hessian(loss, nn.get_parameters().values()) - assert np.array([[12., 2.], [2., 2.]]) == pytest.approx(actual) + assert np.array([[12.0, 2.0], [2.0, 2.0]]) == pytest.approx(actual) def test_compute_network_parameters(self): state = nn.Variable((1, 2)) - output = NPF.affine(state, 1, w_init=NI.ConstantInitializer( - value=1.), b_init=NI.ConstantInitializer(value=1.)) + output = NPF.affine( + state, 1, w_init=NI.ConstantInitializer(value=1.0), b_init=NI.ConstantInitializer(value=1.0) + ) loss = NF.sum(output**2) state_array = np.array([[1.0, 0.5]]) @@ -51,15 +52,11 @@ def test_compute_network_parameters(self): actual = compute_hessian(loss, nn.get_parameters().values()) expected = np.array( - [[2*state_array[0, 0]**2, - 2*state_array[0, 0]*state_array[0, 1], - 2*state_array[0, 0]], - [2*state_array[0, 0]*state_array[0, 1], - 2*state_array[0, 1]**2, - 2*state_array[0, 1]], - [2*state_array[0, 0], - 2*state_array[0, 1], - 2.]] + [ + [2 * state_array[0, 0] ** 2, 2 * state_array[0, 0] * state_array[0, 1], 2 * state_array[0, 0]], + [2 * state_array[0, 0] * state_array[0, 1], 2 * state_array[0, 1] ** 2, 2 * state_array[0, 1]], + [2 * state_array[0, 0], 2 * state_array[0, 1], 2.0], + ] ) assert expected == pytest.approx(actual) diff --git a/tests/utils/test_misc.py b/tests/utils/test_misc.py index 07417a0a..d760f537 100644 --- a/tests/utils/test_misc.py +++ b/tests/utils/test_misc.py @@ -1,4 +1,4 @@ -# Copyright 2021,2022,2023 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ from nnabla_rl.utils.misc import create_attention_mask, create_variable -class TestMisc(): +class TestMisc: def test_create_variable_int(self): batch_size = 3 shape = 5 @@ -36,7 +36,7 @@ def test_create_variable_tuple(self): def test_create_variable_tuples(self): batch_size = 3 - shape = ((6, ), (3, )) + shape = ((6,), (3,)) actual_var = create_variable(batch_size, shape) assert isinstance(actual_var, tuple) diff --git a/tests/utils/test_multiprocess.py b/tests/utils/test_multiprocess.py index 8b82fc84..41bcbfdc 100644 --- a/tests/utils/test_multiprocess.py +++ b/tests/utils/test_multiprocess.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
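The compute_hessian expectation above can be checked by hand: for loss = x^3 + 2xy + y^2 - x the second derivatives are d2/dx2 = 6x, d2/dxdy = 2, and d2/dy2 = 2, so at (x, y) = (2, 3) the Hessian is [[12, 2], [2, 2]], exactly the matrix asserted in the test. A small numpy cross-check of that arithmetic, independent of nnabla:

import numpy as np

def loss(v):
    x, y = v
    return x**3 + 2.0 * x * y + y**2 - x

def analytic_hessian(x, y):
    return np.array([[6.0 * x, 2.0],
                     [2.0, 2.0]])

def numerical_hessian(v, eps=1e-4):
    # central finite differences for all second partial derivatives
    v = np.asarray(v, dtype=float)
    n = len(v)
    H = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            e_i, e_j = np.eye(n)[i] * eps, np.eye(n)[j] * eps
            H[i, j] = (loss(v + e_i + e_j) - loss(v + e_i - e_j)
                       - loss(v - e_i + e_j) + loss(v - e_i - e_j)) / (4 * eps**2)
    return H

assert np.allclose(analytic_hessian(2.0, 3.0), [[12.0, 2.0], [2.0, 2.0]])
assert np.allclose(numerical_hessian([2.0, 3.0]), [[12.0, 2.0], [2.0, 2.0]], atol=1e-3)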
@@ -20,14 +20,25 @@ import nnabla.initializer as NI import nnabla.parametric_functions as NPF from nnabla_rl.models import Model -from nnabla_rl.utils.multiprocess import (copy_mp_arrays_to_params, copy_params_to_mp_arrays, mp_array_from_np_array, - mp_to_np_array, new_mp_arrays_from_params, np_to_mp_array) +from nnabla_rl.utils.multiprocess import ( + copy_mp_arrays_to_params, + copy_params_to_mp_arrays, + mp_array_from_np_array, + mp_to_np_array, + new_mp_arrays_from_params, + np_to_mp_array, +) class TestModel(Model): - def __init__(self, scope_name, input_dim, output_dim, - w_init=NI.ConstantInitializer(100.0), - b_init=NI.ConstantInitializer(100.0)): + def __init__( + self, + scope_name, + input_dim, + output_dim, + w_init=NI.ConstantInitializer(100.0), + b_init=NI.ConstantInitializer(100.0), + ): super(TestModel, self).__init__(scope_name) self._input_dim = input_dim self._output_dim = output_dim @@ -38,12 +49,9 @@ def __call__(self, x): assert x.shape[1] == self._input_dim with nn.parameter_scope(self.scope_name): - h = NPF.affine(x, n_outmaps=256, name="linear1", - w_init=self._w_init, b_init=self._b_init) - h = NPF.affine(h, n_outmaps=256, name="linear2", - w_init=self._w_init, b_init=self._b_init) - h = NPF.affine(h, n_outmaps=self._output_dim, - name="linear3", w_init=self._w_init, b_init=self._b_init) + h = NPF.affine(x, n_outmaps=256, name="linear1", w_init=self._w_init, b_init=self._b_init) + h = NPF.affine(h, n_outmaps=256, name="linear2", w_init=self._w_init, b_init=self._b_init) + h = NPF.affine(h, n_outmaps=self._output_dim, name="linear3", w_init=self._w_init, b_init=self._b_init) return h @@ -60,8 +68,7 @@ def test_mp_to_np_array(self): np_array = np.empty(shape=(10, 9, 8, 7), dtype=np.int64) mp_array = mp_array_from_np_array(np_array) - converted = mp_to_np_array( - mp_array, np_array.shape, dtype=np_array.dtype) + converted = mp_to_np_array(mp_array, np_array.shape, dtype=np_array.dtype) assert converted.shape == np_array.shape assert np.allclose(converted, np_array) @@ -71,13 +78,11 @@ def test_np_to_mp_array(self): mp_array = mp_array_from_np_array(np_array) test_array = np.random.uniform(size=(10, 9, 8, 7)) - before_copying = mp_to_np_array( - mp_array, test_array.shape, dtype=test_array.dtype) + before_copying = mp_to_np_array(mp_array, test_array.shape, dtype=test_array.dtype) assert not np.allclose(before_copying, test_array) mp_array = np_to_mp_array(test_array, mp_array, dtype=test_array.dtype) - after_copying = mp_to_np_array( - mp_array, test_array.shape, dtype=test_array.dtype) + after_copying = mp_to_np_array(mp_array, test_array.shape, dtype=test_array.dtype) assert np.allclose(after_copying, test_array) def test_new_mp_arrays_from_params(self): @@ -90,7 +95,7 @@ def test_new_mp_arrays_from_params(self): for key, value in params.items(): assert key in mp_arrays mp_array = mp_arrays[key] - print('key: ', key) + print("key: ", key) assert len(mp_array) == len(value.d.flatten()) def test_copy_params_to_mp_arrays(self): @@ -128,5 +133,5 @@ def test_copy_mp_arrays_to_params(self): assert np.allclose(value.d, 50.0) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main() diff --git a/tests/utils/test_optimization.py b/tests/utils/test_optimization.py index 7ab97b51..1f44923d 100644 --- a/tests/utils/test_optimization.py +++ b/tests/utils/test_optimization.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. 
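The test_multiprocess.py hunks revolve around the helpers imported at the top of that file: mp_array_from_np_array allocates a shared multiprocessing array from a numpy array, mp_to_np_array views it back as numpy, and np_to_mp_array copies new numpy data into an existing shared array. A minimal round-trip sketch using those imports; the shapes and float32 dtype here are arbitrary illustration choices:

import numpy as np
from nnabla_rl.utils.multiprocess import mp_array_from_np_array, mp_to_np_array, np_to_mp_array

src = np.arange(12, dtype=np.float32).reshape(3, 4)
shared = mp_array_from_np_array(src)                         # shared memory holding a copy of src

roundtrip = mp_to_np_array(shared, src.shape, dtype=src.dtype)
assert np.allclose(roundtrip, src)

update = np.random.uniform(size=(3, 4)).astype(np.float32)
shared = np_to_mp_array(update, shared, dtype=update.dtype)  # overwrite the shared buffer
assert np.allclose(mp_to_np_array(shared, update.shape, dtype=update.dtype), update)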
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,16 +19,18 @@ from nnabla_rl.utils.optimization import conjugate_gradient -class TestOptimization(): +class TestOptimization: def test_conjugate_gradient(self): x_dim = 2 A = np.random.uniform(-3, 3, (x_dim, x_dim)) symmetric_positive_A = np.dot(A, A.T) - def compute_Ax(x): return np.dot(symmetric_positive_A, x) - b = np.random.uniform(-3, 3, (x_dim, )) - optimized_x = conjugate_gradient( - compute_Ax, b, max_iterations=1000) + def compute_Ax(x): + return np.dot(symmetric_positive_A, x) + + b = np.random.uniform(-3, 3, (x_dim,)) + + optimized_x = conjugate_gradient(compute_Ax, b, max_iterations=1000) expected_x = np.dot(np.linalg.inv(symmetric_positive_A), b) diff --git a/tests/utils/test_reproductions.py b/tests/utils/test_reproductions.py index 3c56bbc1..df966417 100644 --- a/tests/utils/test_reproductions.py +++ b/tests/utils/test_reproductions.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ import nnabla_rl.utils.reproductions as reproductions -class TestReproductions(): +class TestReproductions: def test_set_global_seed(self): seed = 0 @@ -58,5 +58,5 @@ def test_set_global_seed(self): assert not np.allclose(random_variable1.d, random_variable3.d) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main() diff --git a/tests/utils/test_serializers.py b/tests/utils/test_serializers.py index c7906475..91a51761 100644 --- a/tests/utils/test_serializers.py +++ b/tests/utils/test_serializers.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,8 +25,8 @@ class TestLoadSnapshot(object): def test_load_snapshot(self): - snapshot_path = pathlib.Path('test_resources/utils/ddpg-snapshot') - env = DummyContinuous(observation_shape=(3, ), action_shape=(1, )) + snapshot_path = pathlib.Path("test_resources/utils/ddpg-snapshot") + env = DummyContinuous(observation_shape=(3,), action_shape=(1,)) ddpg = load_snapshot(snapshot_path, env) assert isinstance(ddpg, A.DDPG) @@ -39,10 +39,10 @@ def test_load_snapshot(self): assert ddpg._config.replay_buffer_size == 1000000 def test_load_snapshot_no_env(self): - snapshot_path = pathlib.Path('test_resources/utils/ddpg-snapshot') + snapshot_path = pathlib.Path("test_resources/utils/ddpg-snapshot") with pytest.raises(RuntimeError): load_snapshot(snapshot_path, {}) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main() diff --git a/tests/utils/test_solver_wrappers.py b/tests/utils/test_solver_wrappers.py index e68314ac..d87a81d5 100644 --- a/tests/utils/test_solver_wrappers.py +++ b/tests/utils/test_solver_wrappers.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
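The conjugate-gradient test above solves A x = b for a random symmetric positive-definite A by passing only a matrix-vector product, then compares against the direct solution. A compact usage sketch mirroring that call; the import path and the conjugate_gradient signature are taken from the hunk, while the small regularization term is an illustration choice to keep the random matrix strictly positive definite:

import numpy as np
from nnabla_rl.utils.optimization import conjugate_gradient

rng = np.random.RandomState(0)
A = rng.uniform(-3, 3, (4, 4))
spd_A = A @ A.T + 1e-3 * np.eye(4)   # symmetric positive definite

def compute_Ax(x):
    return spd_A @ x

b = rng.uniform(-3, 3, (4,))
x_cg = conjugate_gradient(compute_Ax, b, max_iterations=1000)

assert np.allclose(x_cg, np.linalg.solve(spd_A, b), atol=1e-4)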
@@ -19,7 +19,7 @@ from nnabla_rl.utils.solver_wrappers import AutoClipGradByNorm, AutoWeightDecay, SolverWrapper -class TestSolverWrappers(): +class TestSolverWrappers: def test_auto_clip_grad_by_norm(self): norm = 10.0 solver_mock = mock.MagicMock() @@ -51,5 +51,5 @@ def test_auto_weight_decay(self): weight_decay_mock.assert_called_once() -if __name__ == '__main__': +if __name__ == "__main__": pytest.main() diff --git a/tests/writers/test_file_writer.py b/tests/writers/test_file_writer.py index eab5e44b..bfd73288 100644 --- a/tests/writers/test_file_writer.py +++ b/tests/writers/test_file_writer.py @@ -22,45 +22,39 @@ from nnabla_rl.writers.file_writer import FileWriter -class TestFileWriter(): +class TestFileWriter: def test_write_scalar(self): with tempfile.TemporaryDirectory() as tmpdir: test_returns = np.arange(5) test_results = {} - test_results['mean'] = np.mean(test_returns) - test_results['std_dev'] = np.std(test_returns) - test_results['min'] = np.min(test_returns) - test_results['max'] = np.max(test_returns) - test_results['median'] = np.median(test_returns) - - writer = FileWriter( - outdir=tmpdir, file_prefix='evaluation_results') + test_results["mean"] = np.mean(test_returns) + test_results["std_dev"] = np.std(test_returns) + test_results["min"] = np.min(test_returns) + test_results["max"] = np.max(test_returns) + test_results["median"] = np.median(test_returns) + + writer = FileWriter(outdir=tmpdir, file_prefix="evaluation_results") writer.write_scalar(1, test_results) - file_path = \ - os.path.join(tmpdir, 'evaluation_results_scalar.tsv') + file_path = os.path.join(tmpdir, "evaluation_results_scalar.tsv") this_file_dir = os.path.dirname(__file__) - test_file_dir = this_file_dir.replace('tests', 'test_resources') - test_file_path = \ - os.path.join(test_file_dir, 'evaluation_results_scalar.tsv') + test_file_dir = this_file_dir.replace("tests", "test_resources") + test_file_path = os.path.join(test_file_dir, "evaluation_results_scalar.tsv") self._check_same_tsv_file(file_path, test_file_path) def test_write_histogram(self): with tempfile.TemporaryDirectory() as tmpdir: test_returns = np.arange(5) test_results = {} - test_results['returns'] = test_returns + test_results["returns"] = test_returns - writer = FileWriter( - outdir=tmpdir, file_prefix='evaluation_results') + writer = FileWriter(outdir=tmpdir, file_prefix="evaluation_results") writer.write_histogram(1, test_results) - file_path = \ - os.path.join(tmpdir, 'evaluation_results_histogram.tsv') + file_path = os.path.join(tmpdir, "evaluation_results_histogram.tsv") this_file_dir = os.path.dirname(__file__) - test_file_dir = this_file_dir.replace('tests', 'test_resources') - test_file_path = \ - os.path.join(test_file_dir, 'evaluation_results_histogram.tsv') + test_file_dir = this_file_dir.replace("tests", "test_resources") + test_file_path = os.path.join(test_file_dir, "evaluation_results_histogram.tsv") self._check_same_tsv_file(file_path, test_file_path) @pytest.mark.parametrize("format", ["%f", "%.3f", "%.5f"]) @@ -68,26 +62,25 @@ def test_data_formatting(self, format): with tempfile.TemporaryDirectory() as tmpdir: test_returns = np.arange(5) test_results = {} - test_results['mean'] = np.mean(test_returns) - test_results['std_dev'] = np.std(test_returns) - test_results['min'] = np.min(test_returns) - test_results['max'] = np.max(test_returns) - test_results['median'] = np.median(test_returns) + test_results["mean"] = np.mean(test_returns) + test_results["std_dev"] = np.std(test_returns) + test_results["min"] = 
np.min(test_returns) + test_results["max"] = np.max(test_returns) + test_results["median"] = np.median(test_returns) - writer = FileWriter(outdir=tmpdir, file_prefix='actual_results', fmt=format) + writer = FileWriter(outdir=tmpdir, file_prefix="actual_results", fmt=format) writer.write_scalar(1, test_results) - actual_file_path = os.path.join(tmpdir, 'actual_results_scalar.tsv') + actual_file_path = os.path.join(tmpdir, "actual_results_scalar.tsv") this_file_dir = os.path.dirname(__file__) - expected_file_dir = this_file_dir.replace('tests', 'test_resources') - expected_file_path = os.path.join(expected_file_dir, f'evaluation_results_scalar{format}.tsv') + expected_file_dir = this_file_dir.replace("tests", "test_resources") + expected_file_path = os.path.join(expected_file_dir, f"evaluation_results_scalar{format}.tsv") self._check_same_tsv_file(actual_file_path, expected_file_path) def _check_same_tsv_file(self, file_path1, file_path2): # check each line - with open(file_path1, mode='rt') as data_1, \ - open(file_path2, mode='rt') as data_2: + with open(file_path1, mode="rt") as data_1, open(file_path2, mode="rt") as data_2: for d_1, d_2 in zip(data_1, data_2): assert d_1 == d_2 diff --git a/tests/writers/test_monitor_writer.py b/tests/writers/test_monitor_writer.py index 92e3e3e8..5beae77d 100644 --- a/tests/writers/test_monitor_writer.py +++ b/tests/writers/test_monitor_writer.py @@ -1,4 +1,4 @@ -# Copyright 2022 Sony Group Corporation. +# Copyright 2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,37 +20,36 @@ from nnabla_rl.writers.monitor_writer import MonitorWriter -class TestMonitorWriter(): +class TestMonitorWriter: def test_write_scalar(self): with tempfile.TemporaryDirectory() as tmpdir: test_returns = np.arange(5) test_results = {} - test_results['mean'] = np.mean(test_returns) - test_results['std_dev'] = np.std(test_returns) - test_results['min'] = np.min(test_returns) - test_results['max'] = np.max(test_returns) - test_results['median'] = np.median(test_returns) - - writer = MonitorWriter( - outdir=tmpdir, file_prefix='evaluation_results') + test_results["mean"] = np.mean(test_returns) + test_results["std_dev"] = np.std(test_returns) + test_results["min"] = np.min(test_returns) + test_results["max"] = np.max(test_returns) + test_results["median"] = np.median(test_returns) + + writer = MonitorWriter(outdir=tmpdir, file_prefix="evaluation_results") writer.write_scalar(1, test_results) for name in test_results.keys(): - file_name = f'evaluation_results_scalar_{name}.series.txt' + file_name = f"evaluation_results_scalar_{name}.series.txt" file_path = os.path.join(tmpdir, file_name) this_file_dir = os.path.dirname(__file__) - test_file_dir = this_file_dir.replace('tests', 'test_resources') + test_file_dir = this_file_dir.replace("tests", "test_resources") test_file_path = os.path.join(test_file_dir, file_name) self._check_same_txt_file(file_path, test_file_path) def _check_same_txt_file(self, file_path1, file_path2): # check each line - with open(file_path1, mode='rt') as data_1, \ - open(file_path2, mode='rt') as data_2: + with open(file_path1, mode="rt") as data_1, open(file_path2, mode="rt") as data_2: for d_1, d_2 in zip(data_1, data_2): assert d_1 == d_2 if __name__ == "__main__": import pytest + pytest.main() diff --git a/tests/writers/test_writing_distributor.py b/tests/writers/test_writing_distributor.py index 98b12bc6..971add9a 100644 --- 
a/tests/writers/test_writing_distributor.py +++ b/tests/writers/test_writing_distributor.py @@ -1,5 +1,5 @@ # Copyright 2020,2021 Sony Corporation. -# Copyright 2021 Sony Group Corporation. +# Copyright 2021,2022,2023,2024 Sony Group Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,12 +20,12 @@ from nnabla_rl.writers.writing_distributor import WritingDistributor -class TestWritingDistributor(): +class TestWritingDistributor: def test_write_scalar(self): writers = [self._new_mock_writer() for _ in range(10)] distributor = WritingDistributor(writers) - distributor.write_scalar(1, {'test': 100}) + distributor.write_scalar(1, {"test": 100}) for writer in writers: writer.write_scalar.assert_called_once() @@ -34,7 +34,7 @@ def test_write_histogram(self): writers = [self._new_mock_writer() for _ in range(10)] distributor = WritingDistributor(writers) - distributor.write_histogram(1, {'test': [100, 100]}) + distributor.write_histogram(1, {"test": [100, 100]}) for writer in writers: writer.write_histogram.assert_called_once() @@ -44,7 +44,7 @@ def test_write_image(self): distributor = WritingDistributor(writers) image = np.empty(shape=(3, 10, 10)) - distributor.write_image(1, {'test': image}) + distributor.write_image(1, {"test": image}) for writer in writers: writer.write_image.assert_called_once()
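The writer tests above all follow the same pattern: construct a writer with an output directory and file prefix, call write_scalar(step, results_dict), and compare the produced file against a stored fixture. A minimal sketch of that flow for FileWriter, using only the constructor arguments, calls, and output file name that appear in the hunks; it simply prints the written file instead of comparing against a fixture:

import os
import tempfile
import numpy as np
from nnabla_rl.writers.file_writer import FileWriter

returns = np.arange(5)
results = {
    "mean": np.mean(returns),
    "std_dev": np.std(returns),
    "min": np.min(returns),
    "max": np.max(returns),
    "median": np.median(returns),
}

with tempfile.TemporaryDirectory() as tmpdir:
    writer = FileWriter(outdir=tmpdir, file_prefix="evaluation_results")
    writer.write_scalar(1, results)

    # the scalar writer produces <file_prefix>_scalar.tsv inside outdir
    out_path = os.path.join(tmpdir, "evaluation_results_scalar.tsv")
    assert os.path.exists(out_path)
    with open(out_path, mode="rt") as f:
        print(f.read())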