diff --git a/.github/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml
index 3ae16869835..30e01cfc4b5 100644
--- a/.github/unittest/linux/scripts/environment.yml
+++ b/.github/unittest/linux/scripts/environment.yml
@@ -28,6 +28,7 @@ dependencies:
   - mlflow
   - av
   - coverage
-  - ray<2.8.0
+  - ray
   - transformers
   - ninja
+  - timm
diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh
index b5066472907..38235043d3f 100755
--- a/.github/unittest/linux/scripts/run_all.sh
+++ b/.github/unittest/linux/scripts/run_all.sh
@@ -88,7 +88,8 @@ conda deactivate
 conda activate "${env_dir}"

 echo "installing gymnasium"
-pip3 install "gymnasium[atari,ale-py,accept-rom-license]"
+pip3 install "gymnasium"
+pip3 install ale_py
 pip3 install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
 pip3 install mujoco -U
diff --git a/.github/unittest/linux_distributed/scripts/environment.yml b/.github/unittest/linux_distributed/scripts/environment.yml
index 2f5210135fe..6d27071791b 100644
--- a/.github/unittest/linux_distributed/scripts/environment.yml
+++ b/.github/unittest/linux_distributed/scripts/environment.yml
@@ -27,5 +27,5 @@ dependencies:
   - mlflow
   - av
   - coverage
-  - ray<2.8.0
+  - ray
   - virtualenv
diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml
index 9efcbbfa640..d34011e7bdc 100644
--- a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml
+++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml
@@ -24,5 +24,5 @@ dependencies:
   - dm_control -e git+https://github.com/deepmind/dm_control.git@c053360edea6170acfd9c8f65446703307d9d352#egg={dm_control}
   - patchelf
   - pyopengl==3.1.4
-  - ray<2.8.0
+  - ray
   - av
diff --git a/.github/unittest/linux_optdeps/scripts/environment.yml b/.github/unittest/linux_optdeps/scripts/environment.yml
index 7263c14192f..fcc3c3481d0 100644
--- a/.github/unittest/linux_optdeps/scripts/environment.yml
+++ b/.github/unittest/linux_optdeps/scripts/environment.yml
@@ -17,4 +17,4 @@ dependencies:
   - pyyaml
   - scipy
   - coverage
-  - ray<2.8.0
+  - ray
diff --git a/.github/unittest/windows_optdepts/scripts/install.sh b/.github/unittest/windows_optdepts/scripts/install.sh
index 5c425d18a95..f13b83a0be0 100644
--- a/.github/unittest/windows_optdepts/scripts/install.sh
+++ b/.github/unittest/windows_optdepts/scripts/install.sh
@@ -37,6 +37,7 @@ fi
 # submodules
 git submodule sync && git submodule update --init --recursive

+python -m pip install "numpy<2.0"
 printf "Installing PyTorch with %s\n" "${cudatoolkit}"
 if [[ "$TORCH_VERSION" == "nightly" ]]; then
diff --git a/.github/workflows/build-wheels-windows.yml b/.github/workflows/build-wheels-windows.yml
index 1beef7318f4..17991ae2a65 100644
--- a/.github/workflows/build-wheels-windows.yml
+++ b/.github/workflows/build-wheels-windows.yml
@@ -46,4 +46,4 @@ jobs:
       package-name: ${{ matrix.package-name }}
       smoke-test-script: ${{ matrix.smoke-test-script }}
       trigger-event: ${{ github.event_name }}
-      env-var-script: .github/scripts/td_script.sh
+      pre-script: .github/scripts/td_script.sh
diff --git a/.github/workflows/test-linux.yml b/.github/workflows/test-linux.yml
index d2e13eddd63..e8728180c67 100644
--- a/.github/workflows/test-linux.yml
+++ b/.github/workflows/test-linux.yml
@@ -22,7 +22,7 @@ jobs:
   tests-cpu:
     strategy:
       matrix:
-        python_version: ["3.8", "3.9", "3.10", "3.11"]
+        python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
       fail-fast: false
     uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
     with:
@@ -51,7 +51,7 @@ jobs:
   tests-gpu:
     strategy:
       matrix:
-        python_version: ["3.10"]
+        python_version: ["3.11"]
         cuda_arch_version: ["12.1"]
       fail-fast: false
     uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index c7b0eba35c0..11a5bb041a6 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -335,6 +335,19 @@ etc.), but one can not use an arbitrary TorchRL environment, as it is possible w
     ParallelEnv
     EnvCreator

+
+Custom native TorchRL environments
+----------------------------------
+
+TorchRL offers a series of custom built-in environments.
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    PendulumEnv
+    TicTacToeEnv
+
 Multi-agent environments
 ------------------------

@@ -780,6 +793,7 @@ to be able to create this other composition:
     CenterCrop
     ClipTransform
     Compose
+    Crop
     DTypeCastTransform
     DeviceCastTransform
     DiscreteActionProjection
diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst
index b46d789ed15..c73ed5083fd 100644
--- a/docs/source/reference/modules.rst
+++ b/docs/source/reference/modules.rst
@@ -72,9 +72,11 @@ other cases, the action written in the tensordict is simply the network output.
     :toctree: generated/
     :template: rl_template_noinherit.rst

+    AdditiveGaussianModule
     AdditiveGaussianWrapper
     EGreedyModule
     EGreedyWrapper
+    OrnsteinUhlenbeckProcessModule
     OrnsteinUhlenbeckProcessWrapper

 Probabilistic actors
@@ -353,13 +355,18 @@ algorithms, such as DQN, DDPG or Dreamer.
 Multi-agent-specific modules
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-These networks implement models that can be used in
-multi-agent contexts.
+These networks implement models that can be used in multi-agent contexts.
+They use :func:`~torch.vmap` to execute multiple networks all at once on the
+network inputs. Because the parameters are batched, initialization may differ
+from what is usually done with other PyTorch modules, see
+:meth:`~torchrl.modules.MultiAgentNetBase.get_stateful_net`
+for more information.

 .. autosummary::
     :toctree: generated/
     :template: rl_template_noinherit.rst

+    MultiAgentNetBase
     MultiAgentMLP
     MultiAgentConvNet
     QMixer
diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst
index b475885daf0..db0c58409e2 100644
--- a/docs/source/reference/objectives.rst
+++ b/docs/source/reference/objectives.rst
@@ -29,6 +29,35 @@ The main characteristics of TorchRL losses are:

  >>> loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))

+.. note::
+  Initializing parameters in losses can be done via a query to :meth:`~torchrl.objectives.LossModule.get_stateful_net`
+  which will return a stateful version of the network that can be initialized like any other module.
+  If the modification is done in-place, it will propagate to any other module that uses the same parameter
+  set (within and outside of the loss): for instance, modifying the ``actor_network`` parameters from the loss
+  will also modify the actor in the collector.
+  If the parameters are modified out-of-place, :meth:`~torchrl.objectives.LossModule.from_stateful_net` can be
+  used to reset the parameters in the loss to the new value.
+
+torch.vmap and randomness
+-------------------------
+
+TorchRL loss modules have plenty of calls to :func:`~torch.vmap` to amortize the cost of calling multiple similar models
+in a loop, and instead vectorize these operations. `vmap` needs to be told explicitly what to do when random numbers
+need to be generated within the call. To do this, a randomness mode needs to be set and must be one of `"error"` (default,
+errors when dealing with pseudo-random functions), `"same"` (replicates the results across the batch) or `"different"`
+(each element of the batch is treated separately).
+Relying on the default will typically result in an error such as this one:
+
+  >>> RuntimeError: vmap: called random operation while in randomness error mode.
+
+Since the calls to `vmap` are buried within the loss modules, TorchRL
+provides an interface to set that vmap mode from the outside through `loss.vmap_randomness = str_value`, see
+:meth:`~torchrl.objectives.LossModule.vmap_randomness` for more information.
+
+``LossModule.vmap_randomness`` defaults to `"error"` if no random module is detected, and to `"different"` in
+other cases. By default, only a limited number of modules are listed as random, but the list can be extended
+using the :func:`~torchrl.objectives.common.add_random_module` function.
+
 Training value functions
 ------------------------
diff --git a/setup.py b/setup.py
index 95dc0802a4f..73541790e8f 100644
--- a/setup.py
+++ b/setup.py
@@ -274,6 +274,7 @@ def _main(argv):
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
         "Programming Language :: Python :: 3.11",
+        "Programming Language :: Python :: 3.12",
         "License :: OSI Approved :: MIT License",
         "Operating System :: OS Independent",
         "Development Status :: 4 - Beta",
diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py
index a92ee6185c3..1b038d69d15 100644
--- a/sota-implementations/ddpg/ddpg.py
+++ b/sota-implementations/ddpg/ddpg.py
@@ -108,7 +108,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     for _, tensordict in enumerate(collector):
         sampling_time = time.time() - sampling_start
         # Update exploration policy
-        exploration_policy.step(tensordict.numel())
+        exploration_policy[1].step(tensordict.numel())

         # Update weights of the inference policy
         collector.update_policy_weights_()
diff --git a/sota-implementations/ddpg/utils.py b/sota-implementations/ddpg/utils.py
index 45c6da7a342..338081a7e8d 100644
--- a/sota-implementations/ddpg/utils.py
+++ b/sota-implementations/ddpg/utils.py
@@ -6,6 +6,8 @@
 import torch

+from tensordict.nn import TensorDictSequential
+
 from torch import nn, optim
 from torchrl.collectors import SyncDataCollector
 from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
@@ -25,9 +27,9 @@ from torchrl.envs.libs.gym import GymEnv, set_gym_backend
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules import (
-    AdditiveGaussianWrapper,
+    AdditiveGaussianModule,
     MLP,
-    OrnsteinUhlenbeckProcessWrapper,
+    OrnsteinUhlenbeckProcessModule,
     SafeModule,
     SafeSequential,
     TanhModule,
@@ -227,18 +229,24 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
     # Exploration wrappers:
     if cfg.network.noise_type == "ou":
-        actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
+        actor_model_explore = TensorDictSequential(
             model[0],
-            annealing_num_steps=1_000_000,
-        ).to(device)
+            OrnsteinUhlenbeckProcessModule(
+                spec=action_spec,
+                annealing_num_steps=1_000_000,
+            ).to(device),
+        )
     elif cfg.network.noise_type == "gaussian":
-        actor_model_explore = AdditiveGaussianWrapper(
+        actor_model_explore = TensorDictSequential(
             model[0],
-            sigma_end=1.0,
-            sigma_init=1.0,
-            mean=0.0,
-            std=0.1,
-        ).to(device)
+            AdditiveGaussianModule(
+
spec=action_spec, + sigma_end=1.0, + sigma_init=1.0, + mean=0.0, + std=0.1, + ).to(device), + ) else: raise NotImplementedError diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index e521b9df386..f28fac8e675 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -279,7 +279,7 @@ def compile_rssms(module): if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) - policy.step(current_frames) + policy[1].step(current_frames) collector.update_policy_weights_() # Evaluation if (i % eval_iter) == 0: diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 73baa310821..6745b1a079a 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -52,7 +52,7 @@ ) from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type from torchrl.modules import ( - AdditiveGaussianWrapper, + AdditiveGaussianModule, DreamerActor, IndependentNormal, MLP, @@ -266,13 +266,16 @@ def make_dreamer( test_env=test_env, ) # Exploration noise to be added to the actor_realworld - actor_realworld = AdditiveGaussianWrapper( + actor_realworld = TensorDictSequential( actor_realworld, - sigma_init=1.0, - sigma_end=1.0, - annealing_num_steps=1, - mean=0.0, - std=cfg.networks.exploration_noise, + AdditiveGaussianModule( + spec=test_env.action_spec, + sigma_init=1.0, + sigma_end=1.0, + annealing_num_steps=1, + mean=0.0, + std=cfg.networks.exploration_noise, + ), ) # Make Critic diff --git a/sota-implementations/multiagent/maddpg_iddpg.py b/sota-implementations/multiagent/maddpg_iddpg.py index bd44bb0a043..e9de2ac4e14 100644 --- a/sota-implementations/multiagent/maddpg_iddpg.py +++ b/sota-implementations/multiagent/maddpg_iddpg.py @@ -7,7 +7,7 @@ import hydra import torch -from tensordict.nn import TensorDictModule +from tensordict.nn import TensorDictModule, TensorDictSequential from torch import nn from torchrl._utils import logger as torchrl_logger from torchrl.collectors import SyncDataCollector @@ -18,7 +18,7 @@ from torchrl.envs.libs.vmas import VmasEnv from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( - AdditiveGaussianWrapper, + AdditiveGaussianModule, ProbabilisticActor, TanhDelta, ValueOperator, @@ -102,10 +102,13 @@ def train(cfg: "DictConfig"): # noqa: F821 return_log_prob=False, ) - policy_explore = AdditiveGaussianWrapper( + policy_explore = TensorDictSequential( policy, - annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), - action_key=env.action_key, + AdditiveGaussianModule( + spec=env.unbatched_action_spec, + annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)), + action_key=env.action_key, + ), ) # Critic @@ -200,7 +203,7 @@ def train(cfg: "DictConfig"): # noqa: F821 optim.zero_grad() target_net_updater.step() - policy_explore.step(frames=current_frames) # Update exploration annealing + policy_explore[1].step(frames=current_frames) # Update exploration annealing collector.update_policy_weights_() training_time = time.time() - training_start diff --git a/sota-implementations/redq/redq.py b/sota-implementations/redq/redq.py index c6b96db9292..eb802f6773d 100644 --- a/sota-implementations/redq/redq.py +++ b/sota-implementations/redq/redq.py @@ -8,10 +8,11 @@ import hydra import torch.cuda +from tensordict.nn import TensorDictSequential from torchrl.envs import EnvCreator, ParallelEnv from 
torchrl.envs.transforms import RewardScaling, TransformedEnv from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.modules import OrnsteinUhlenbeckProcessWrapper +from torchrl.modules import OrnsteinUhlenbeckProcessModule from torchrl.record import VideoRecorder from torchrl.record.loggers import get_logger from utils import ( @@ -111,12 +112,15 @@ def main(cfg: "DictConfig"): # noqa: F821 if cfg.exploration.ou_exploration: if cfg.exploration.gSDE: raise RuntimeError("gSDE and ou_exploration are incompatible") - actor_model_explore = OrnsteinUhlenbeckProcessWrapper( + actor_model_explore = TensorDictSequential( actor_model_explore, - annealing_num_steps=cfg.exploration.annealing_frames, - sigma=cfg.exploration.ou_sigma, - theta=cfg.exploration.ou_theta, - ).to(device) + OrnsteinUhlenbeckProcessModule( + spec=actor_model_explore.spec, + annealing_num_steps=cfg.exploration.annealing_frames, + sigma=cfg.exploration.ou_sigma, + theta=cfg.exploration.ou_theta, + ).to(device), + ) if device == torch.device("cpu"): # mostly for debugging actor_model_explore.share_memory() diff --git a/sota-implementations/redq/utils.py b/sota-implementations/redq/utils.py index 8a093c8f0ac..dd922372cbb 100644 --- a/sota-implementations/redq/utils.py +++ b/sota-implementations/redq/utils.py @@ -57,7 +57,7 @@ ActorCriticOperator, ActorValueOperator, NoisyLinear, - NormalParamWrapper, + NormalParamExtractor, SafeModule, SafeSequential, ) @@ -483,10 +483,12 @@ def make_redq_model( } if not gSDE: - actor_net = NormalParamWrapper( + actor_net = nn.Sequential( actor_net, - scale_mapping=f"biased_softplus_{default_policy_scale}", - scale_lb=cfg.network.scale_lb, + NormalParamExtractor( + scale_mapping=f"biased_softplus_{default_policy_scale}", + scale_lb=cfg.network.scale_lb, + ), ) actor_module = SafeModule( actor_net, diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py index 5fbc9b032d7..632ee58503d 100644 --- a/sota-implementations/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -109,7 +109,7 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() for tensordict in collector: sampling_time = time.time() - sampling_start - exploration_policy.step(tensordict.numel()) + exploration_policy[1].step(tensordict.numel()) # Update weights of the inference policy collector.update_policy_weights_() diff --git a/sota-implementations/td3/utils.py b/sota-implementations/td3/utils.py index c597ae205a2..60a4d046355 100644 --- a/sota-implementations/td3/utils.py +++ b/sota-implementations/td3/utils.py @@ -7,6 +7,7 @@ from contextlib import nullcontext import torch +from tensordict.nn import TensorDictSequential from torch import nn, optim from torchrl.collectors import SyncDataCollector @@ -27,7 +28,7 @@ from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( - AdditiveGaussianWrapper, + AdditiveGaussianModule, MLP, SafeModule, SafeSequential, @@ -233,14 +234,16 @@ def make_td3_agent(cfg, train_env, eval_env, device): eval_env.close() # Exploration wrappers: - actor_model_explore = AdditiveGaussianWrapper( + actor_model_explore = TensorDictSequential( model[0], - sigma_init=1, - sigma_end=1, - mean=0, - std=0.1, - spec=action_spec, - ).to(device) + AdditiveGaussianModule( + sigma_init=1, + sigma_end=1, + mean=0, + std=0.1, + spec=action_spec, + ).to(device), + ) return model, actor_model_explore diff --git a/sota-implementations/td3_bc/utils.py 
b/sota-implementations/td3_bc/utils.py index 3772eefccde..3dcbd45d30c 100644 --- a/sota-implementations/td3_bc/utils.py +++ b/sota-implementations/td3_bc/utils.py @@ -5,6 +5,7 @@ import functools import torch +from tensordict.nn import TensorDictSequential from torch import nn, optim from torchrl.data.datasets.d4rl import D4RLExperienceReplay @@ -24,7 +25,7 @@ from torchrl.envs.libs.gym import GymEnv, set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( - AdditiveGaussianWrapper, + AdditiveGaussianModule, MLP, SafeModule, SafeSequential, @@ -174,14 +175,16 @@ def make_td3_agent(cfg, train_env, device): del td # Exploration wrappers: - actor_model_explore = AdditiveGaussianWrapper( + actor_model_explore = TensorDictSequential( model[0], - sigma_init=1, - sigma_end=1, - mean=0, - std=0.1, - spec=action_spec, - ).to(device) + AdditiveGaussianModule( + sigma_init=1, + sigma_end=1, + mean=0, + std=0.1, + spec=action_spec, + ).to(device), + ) return model, actor_model_explore diff --git a/test/_utils_internal.py b/test/_utils_internal.py index e43c0ff2ecf..61b0c003f9d 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -56,11 +56,32 @@ def HALFCHEETAH_VERSIONED(): def PONG_VERSIONED(): # load gym + # Gymnasium says that the ale_py behaviour changes from 1.0 + # but with python 3.12 it is already the case with 0.29.1 + try: + import ale_py # noqa + except ImportError: + pass + if gym_backend() is not None: _set_gym_environments() return _PONG_VERSIONED +def BREAKOUT_VERSIONED(): + # load gym + # Gymnasium says that the ale_py behaviour changes from 1.0 + # but with python 3.12 it is already the case with 0.29.1 + try: + import ale_py # noqa + except ImportError: + pass + + if gym_backend() is not None: + _set_gym_environments() + return _BREAKOUT_VERSIONED + + def PENDULUM_VERSIONED(): # load gym if gym_backend() is not None: @@ -69,42 +90,46 @@ def PENDULUM_VERSIONED(): def _set_gym_environments(): - global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED + global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED _CARTPOLE_VERSIONED = None _HALFCHEETAH_VERSIONED = None _PENDULUM_VERSIONED = None _PONG_VERSIONED = None + _BREAKOUT_VERSIONED = None @implement_for("gym", None, "0.21.0") def _set_gym_environments(): # noqa: F811 - global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED + global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED _CARTPOLE_VERSIONED = "CartPole-v0" _HALFCHEETAH_VERSIONED = "HalfCheetah-v2" _PENDULUM_VERSIONED = "Pendulum-v0" _PONG_VERSIONED = "Pong-v4" + _BREAKOUT_VERSIONED = "Breakout-v4" @implement_for("gym", "0.21.0", None) def _set_gym_environments(): # noqa: F811 - global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED + global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED _CARTPOLE_VERSIONED = "CartPole-v1" _HALFCHEETAH_VERSIONED = "HalfCheetah-v4" _PENDULUM_VERSIONED = "Pendulum-v1" _PONG_VERSIONED = "ALE/Pong-v5" + _BREAKOUT_VERSIONED = "ALE/Breakout-v5" @implement_for("gymnasium") def _set_gym_environments(): # noqa: F811 - global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED + global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED 
_CARTPOLE_VERSIONED = "CartPole-v1" _HALFCHEETAH_VERSIONED = "HalfCheetah-v4" _PENDULUM_VERSIONED = "Pendulum-v1" _PONG_VERSIONED = "ALE/Pong-v5" + _BREAKOUT_VERSIONED = "ALE/Breakout-v5" if _has_gym: diff --git a/test/test_collector.py b/test/test_collector.py index 12ec490e7e2..7d7208aead0 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -92,7 +92,7 @@ PARTIAL_MISSING_ERR, RandomPolicy, ) -from torchrl.modules import Actor, OrnsteinUhlenbeckProcessWrapper, SafeModule +from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule # torch.set_default_dtype(torch.double) IS_WINDOWS = sys.platform == "win32" @@ -1291,8 +1291,13 @@ def make_env(): policy_module, in_keys=["observation"], out_keys=["action"] ) copier = TensorDictModule(lambda x: x, in_keys=["observation"], out_keys=[out_key]) - policy = TensorDictSequential(policy, copier) - policy_explore = OrnsteinUhlenbeckProcessWrapper(policy) + policy_explore = TensorDictSequential( + policy, + copier, + OrnsteinUhlenbeckProcessModule( + spec=CompositeSpec({key: None for key in policy.out_keys}) + ), + ) collector_kwargs = { "create_env_fn": make_env, @@ -2472,7 +2477,9 @@ def make_env(): obs_spec = dummy_env.observation_spec["observation"] policy_module = nn.Linear(obs_spec.shape[-1], dummy_env.action_spec.shape[-1]) policy = Actor(policy_module, spec=dummy_env.action_spec) - policy_explore = OrnsteinUhlenbeckProcessWrapper(policy) + policy_explore = TensorDictSequential( + policy, OrnsteinUhlenbeckProcessModule(spec=policy.spec) + ) collector_kwargs = { "create_env_fn": make_env, diff --git a/test/test_cost.py b/test/test_cost.py index 0d7b24f7779..6192e45c113 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -72,11 +72,7 @@ SafeSequential, WorldModelWrapper, ) -from torchrl.modules.distributions.continuous import ( - NormalParamWrapper, - TanhDelta, - TanhNormal, -) +from torchrl.modules.distributions.continuous import TanhDelta, TanhNormal from torchrl.modules.models.model_based import ( DreamerActor, ObsDecoder, @@ -261,6 +257,9 @@ def __init__(self): net = nn.Sequential(*layers).to(device) model = TensorDictModule(net, in_keys=["obs"], out_keys=["action"]) self.convert_to_functional(model, "model", expand_dim=4) + self._make_vmap() + + def _make_vmap(self): self.vmap_model = _vmap_func( self.model, (None, 0), @@ -3460,7 +3459,7 @@ def _create_mock_actor( action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule( net, in_keys=[observation_key], out_keys=["loc", "scale"] ) @@ -4370,7 +4369,7 @@ def _create_mock_actor( ): # Actor action_spec = OneHotDiscreteTensorSpec(action_dim) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"]) actor = ProbabilisticActor( spec=action_spec, @@ -4958,7 +4957,7 @@ def _create_mock_actor( action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule( net, in_keys=[observation_key], out_keys=["loc", "scale"] ) @@ -5653,7 +5652,7 @@ def _create_mock_actor( action_spec = 
BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule( net, in_keys=[observation_key], out_keys=["loc", "scale"] ) @@ -5761,7 +5760,9 @@ def forward(self, obs): class ActorClass(nn.Module): def __init__(self): super().__init__() - self.linear = NormalParamWrapper(nn.Linear(hidden_dim, 2 * action_dim)) + self.linear = nn.Sequential( + nn.Linear(hidden_dim, 2 * action_dim), NormalParamExtractor() + ) def forward(self, hidden): return self.linear(hidden) @@ -6596,7 +6597,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule( net, in_keys=["observation"], out_keys=["loc", "scale"] ) @@ -6855,6 +6856,71 @@ def test_cql( p.grad is None or p.grad.norm() == 0.0 ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + @pytest.mark.parametrize("delay_actor", (True,)) + @pytest.mark.parametrize("delay_qvalue", (True,)) + @pytest.mark.parametrize( + "max_q_backup", + [ + True, + ], + ) + @pytest.mark.parametrize( + "deterministic_backup", + [ + True, + ], + ) + @pytest.mark.parametrize( + "with_lagrange", + [ + True, + ], + ) + @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize("td_est", [None]) + def test_cql_qvalfromlist( + self, + delay_actor, + delay_qvalue, + max_q_backup, + deterministic_backup, + with_lagrange, + device, + td_est, + ): + torch.manual_seed(self.seed) + td = self._create_mock_data_cql(device=device) + + actor = self._create_mock_actor(device=device) + qvalue0 = self._create_mock_qvalue(device=device) + qvalue1 = self._create_mock_qvalue(device=device) + + loss_fn_single = CQLLoss( + actor_network=actor, + qvalue_network=qvalue0, + loss_function="l2", + max_q_backup=max_q_backup, + deterministic_backup=deterministic_backup, + with_lagrange=with_lagrange, + delay_actor=delay_actor, + delay_qvalue=delay_qvalue, + ) + loss_fn_mult = CQLLoss( + actor_network=actor, + qvalue_network=[qvalue0, qvalue1], + loss_function="l2", + max_q_backup=max_q_backup, + deterministic_backup=deterministic_backup, + with_lagrange=with_lagrange, + delay_actor=delay_actor, + delay_qvalue=delay_qvalue, + ) + # Check that all params have the same shape + p2 = dict(loss_fn_mult.named_parameters()) + for key, val in loss_fn_single.named_parameters(): + assert val.shape == p2[key].shape + assert len(dict(loss_fn_single.named_parameters())) == len(p2) + @pytest.mark.parametrize("delay_actor", (True, False)) @pytest.mark.parametrize("delay_qvalue", (True, False)) @pytest.mark.parametrize("max_q_backup", [True]) @@ -7489,7 +7555,7 @@ def _create_mock_actor( action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule( net, in_keys=[observation_key], out_keys=["loc", "scale"] ) @@ -7526,8 +7592,8 @@ def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) base_layer = 
nn.Linear(obs_dim, 5) - net = NormalParamWrapper( - nn.Sequential(base_layer, nn.Linear(5, 2 * action_dim)) + net = nn.Sequential( + base_layer, nn.Linear(5, 2 * action_dim), NormalParamExtractor() ) module = TensorDictModule( net, in_keys=["observation"], out_keys=["loc", "scale"] @@ -8380,7 +8446,7 @@ def _create_mock_actor( action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule( net, in_keys=[observation_key], out_keys=["loc", "scale"] ) @@ -9077,7 +9143,7 @@ def test_reinforce_value_net( batch = 4 gamma = 0.9 value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"]) - net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) module = TensorDictModule( net, in_keys=["observation"], out_keys=["loc", "scale"] ) @@ -9187,7 +9253,7 @@ def test_reinforce_tensordict_keys(self, td_est): n_obs = 3 n_act = 5 value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"]) - net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) module = TensorDictModule( net, in_keys=["observation"], out_keys=["loc", "scale"] ) @@ -9381,7 +9447,7 @@ def test_reinforce_notensordict( n_act = 5 batch = 4 value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=[observation_key]) - net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) module = TensorDictModule( net, in_keys=[observation_key], out_keys=["loc", "scale"] ) @@ -9987,7 +10053,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule( net, in_keys=["observation"], out_keys=["loc", "scale"] ) @@ -10219,7 +10285,7 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule(net, in_keys=["observation"], out_keys=["param"]) actor = ProbabilisticActor( module=module, @@ -10633,7 +10699,7 @@ def _create_mock_actor( action_spec = BoundedTensorSpec( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule( net, in_keys=[observation_key], out_keys=["loc", "scale"] ) @@ -11442,7 +11508,7 @@ def _create_mock_actor( ): # Actor action_spec = OneHotDiscreteTensorSpec(action_dim) - net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim)) + net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) module = TensorDictModule(net, in_keys=[observation_key], out_keys=["logits"]) actor = ProbabilisticActor( spec=action_spec, @@ -14143,7 +14209,7 @@ def test_shared_params(dest, expected_dtype, expected_device): out_keys=["hidden"], ) module_action = 
TensorDictModule( - NormalParamWrapper(torch.nn.Linear(4, 8)), + nn.Sequential(nn.Linear(4, 8), NormalParamExtractor()), in_keys=["hidden"], out_keys=["loc", "scale"], ) @@ -14771,6 +14837,118 @@ def __init__(self, compare_against, expand_dim): for key in ["module.1.bias", "module.1.weight"]: loss_module.module_b_params.flatten_keys()[key].requires_grad + def test_init_params(self): + class MyLoss(LossModule): + module_a: TensorDictModule + module_b: TensorDictModule + module_a_params: TensorDict + module_b_params: TensorDict + target_module_a_params: TensorDict + target_module_b_params: TensorDict + + def __init__(self, expand_dim=2): + super().__init__() + module1 = nn.Linear(3, 4) + module2 = nn.Linear(3, 4) + module3 = nn.Linear(3, 4) + module_a = TensorDictModule( + nn.Sequential(module1, module2), in_keys=["a"], out_keys=["c"] + ) + module_b = TensorDictModule( + nn.Sequential(module1, module3), in_keys=["b"], out_keys=["c"] + ) + self.convert_to_functional(module_a, "module_a") + self.convert_to_functional( + module_b, + "module_b", + compare_against=module_a.parameters(), + expand_dim=expand_dim, + ) + + loss = MyLoss() + + module_a = loss.get_stateful_net("module_a", copy=False) + assert module_a is loss.module_a + + module_a = loss.get_stateful_net("module_a") + assert module_a is not loss.module_a + + def init(mod): + if hasattr(mod, "weight"): + mod.weight.data.zero_() + if hasattr(mod, "bias"): + mod.bias.data.zero_() + + module_a.apply(init) + assert (loss.module_a_params == 0).all() + + def init(mod): + if hasattr(mod, "weight"): + mod.weight = torch.nn.Parameter(mod.weight.data + 1) + if hasattr(mod, "bias"): + mod.bias = torch.nn.Parameter(mod.bias.data + 1) + + module_a.apply(init) + assert (loss.module_a_params == 0).all() + loss.from_stateful_net("module_a", module_a) + assert (loss.module_a_params == 1).all() + + def test_from_module_list(self): + class MyLoss(LossModule): + module_a: TensorDictModule + module_b: TensorDictModule + + module_a_params: TensorDict + module_b_params: TensorDict + + target_module_a_params: TensorDict + target_module_b_params: TensorDict + + def __init__(self, module_a, module_b0, module_b1, expand_dim=2): + super().__init__() + self.convert_to_functional(module_a, "module_a") + self.convert_to_functional( + [module_b0, module_b1], + "module_b", + # This will be ignored + compare_against=module_a.parameters(), + expand_dim=expand_dim, + ) + + module1 = nn.Linear(3, 4) + module2 = nn.Linear(3, 4) + module3a = nn.Linear(3, 4) + module3b = nn.Linear(3, 4) + + module_a = TensorDictModule( + nn.Sequential(module1, module2), in_keys=["a"], out_keys=["c"] + ) + + module_b0 = TensorDictModule( + nn.Sequential(module1, module3a), in_keys=["b"], out_keys=["c"] + ) + module_b1 = TensorDictModule( + nn.Sequential(module1, module3b), in_keys=["b"], out_keys=["c"] + ) + + loss = MyLoss(module_a, module_b0, module_b1) + + # This should be extended + assert not isinstance( + loss.module_b_params["module", "0", "weight"], nn.Parameter + ) + assert loss.module_b_params["module", "0", "weight"].shape[0] == 2 + assert ( + loss.module_b_params["module", "0", "weight"].data.data_ptr() + == loss.module_a_params["module", "0", "weight"].data.data_ptr() + ) + assert isinstance(loss.module_b_params["module", "1", "weight"], nn.Parameter) + assert loss.module_b_params["module", "1", "weight"].shape[0] == 2 + assert ( + loss.module_b_params["module", "1", "weight"].data.data_ptr() + != loss.module_a_params["module", "1", "weight"].data.data_ptr() + ) + def 
test_tensordict_keys(self): """Test configurable tensordict key behavior with derived classes.""" @@ -15128,10 +15306,10 @@ def __init__(self): assert v_p1 == v_p2 assert v_params1 == v_params2 assert v_buffers1 == v_buffers2 - for p in mod.parameters(): - assert isinstance(p, nn.Parameter) - for p in mod.buffers(): - assert isinstance(p, Buffer) + for k, p in mod.named_parameters(): + assert isinstance(p, nn.Parameter), k + for k, p in mod.named_buffers(): + assert isinstance(p, Buffer), k for p in mod.actor_params.values(True, True): assert isinstance(p, (nn.Parameter, Buffer)) for p in mod.value_params.values(True, True): diff --git a/test/test_env.py b/test/test_env.py index adce3e48326..dee03c06e7d 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -79,7 +79,9 @@ EnvBase, EnvCreator, ParallelEnv, + PendulumEnv, SerialEnv, + TicTacToeEnv, ) from torchrl.envs.batched_envs import _stackable from torchrl.envs.gym_like import default_info_dict_reader @@ -311,6 +313,62 @@ def test_rollout_predictability(device): ).all() +# Check that the "terminated" key is filled in automatically if only the "done" +# key is provided in `_step`. +def test_done_key_completion_done(): + class DoneEnv(CountingEnv): + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + self.count += 1 + tensordict = TensorDict( + source={ + "observation": self.count.clone(), + "done": self.count > self.max_steps, + "reward": torch.zeros_like(self.count, dtype=torch.float), + }, + batch_size=self.batch_size, + device=self.device, + ) + return tensordict + + env = DoneEnv(max_steps=torch.tensor([[0], [1]]), batch_size=(2,)) + td = env.reset() + env.rand_action(td) + td = env.step(td) + assert torch.equal(td[("next", "done")], torch.tensor([[True], [False]])) + assert torch.equal(td[("next", "terminated")], torch.tensor([[True], [False]])) + + +# Check that the "done" key is filled in automatically if only the "terminated" +# key is provided in `_step`. 
+def test_done_key_completion_terminated(): + class TerminatedEnv(CountingEnv): + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + self.count += 1 + tensordict = TensorDict( + source={ + "observation": self.count.clone(), + "terminated": self.count > self.max_steps, + "reward": torch.zeros_like(self.count, dtype=torch.float), + }, + batch_size=self.batch_size, + device=self.device, + ) + return tensordict + + env = TerminatedEnv(max_steps=torch.tensor([[0], [1]]), batch_size=(2,)) + td = env.reset() + env.rand_action(td) + td = env.step(td) + assert torch.equal(td[("next", "done")], torch.tensor([[True], [False]])) + assert torch.equal(td[("next", "terminated")], torch.tensor([[True], [False]])) + + @pytest.mark.skipif(not _has_gym, reason="no gym") @pytest.mark.parametrize("env_name", [PENDULUM_VERSIONED]) @pytest.mark.parametrize("frame_skip", [1]) @@ -3251,6 +3309,48 @@ def test_partial_rest(self, batched): assert s["next", "string"] == ["6", "6"] +class TestCustomEnvs: + def test_tictactoe_env(self): + torch.manual_seed(0) + env = TicTacToeEnv() + check_env_specs(env) + for _ in range(10): + r = env.rollout(10) + assert r.shape[-1] < 10 + r = env.rollout(10, tensordict=TensorDict(batch_size=[5])) + assert r.shape[-1] < 10 + r = env.rollout( + 100, tensordict=TensorDict(batch_size=[5]), break_when_any_done=False + ) + assert r.shape == (5, 100) + + def test_tictactoe_env_single(self): + torch.manual_seed(0) + env = TicTacToeEnv(single_player=True) + check_env_specs(env) + for _ in range(10): + r = env.rollout(10) + assert r.shape[-1] < 6 + r = env.rollout(10, tensordict=TensorDict(batch_size=[5])) + assert r.shape[-1] < 6 + r = env.rollout( + 100, tensordict=TensorDict(batch_size=[5]), break_when_any_done=False + ) + assert r.shape == (5, 100) + + def test_pendulum_env(self): + env = PendulumEnv(device=None) + assert env.device is None + env = PendulumEnv(device="cpu") + assert env.device == torch.device("cpu") + check_env_specs(env) + for _ in range(10): + r = env.rollout(10) + assert r.shape == torch.Size((10,)) + r = env.rollout(10, tensordict=TensorDict(batch_size=[5])) + assert r.shape == torch.Size((5, 10)) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_exploration.py b/test/test_exploration.py index f65ea655de2..83ee4bc4220 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -31,10 +31,10 @@ from torchrl.envs.transforms.transforms import gSDENoise, InitTracker, TransformedEnv from torchrl.envs.utils import set_exploration_type from torchrl.modules import SafeModule, SafeSequential -from torchrl.modules.distributions import TanhNormal -from torchrl.modules.distributions.continuous import ( +from torchrl.modules.distributions import ( IndependentNormal, - NormalParamWrapper, + NormalParamExtractor, + TanhNormal, ) from torchrl.modules.models.exploration import LazygSDEModule from torchrl.modules.tensordict_module.actors import ( @@ -44,9 +44,11 @@ ) from torchrl.modules.tensordict_module.exploration import ( _OrnsteinUhlenbeckProcess, + AdditiveGaussianModule, AdditiveGaussianWrapper, EGreedyModule, EGreedyWrapper, + OrnsteinUhlenbeckProcessModule, OrnsteinUhlenbeckProcessWrapper, ) @@ -203,8 +205,8 @@ def test_wrong_action_shape(self, module): @pytest.mark.parametrize("device", get_default_devices()) -class TestOrnsteinUhlenbeckProcessWrapper: - def test_ou(self, device, seed=0): +class 
TestOrnsteinUhlenbeckProcess: + def test_ou_process(self, device, seed=0): torch.manual_seed(seed) td = TensorDict({"action": torch.randn(3) / 10}, batch_size=[], device=device) ou = _OrnsteinUhlenbeckProcess(10.0, mu=2.0, x0=-4, sigma=0.1, sigma_min=0.01) @@ -229,9 +231,14 @@ def test_ou(self, device, seed=0): assert pval_acc > 0.05 assert pval_reg < 0.1 - def test_ou_wrapper(self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0): + @pytest.mark.parametrize("interface", ["module", "wrapper"]) + def test_ou( + self, device, interface, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0 + ): torch.manual_seed(seed) - net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device) + net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to( + device + ) module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) action_spec = BoundedTensorSpec(-torch.ones(d_act), torch.ones(d_act), (d_act,)) policy = ProbabilisticActor( @@ -241,7 +248,13 @@ def test_ou_wrapper(self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed= distribution_class=TanhNormal, default_interaction_type=InteractionType.RANDOM, ).to(device) - exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy) + + if interface == "module": + ou = OrnsteinUhlenbeckProcessModule(spec=action_spec).to(device) + exploratory_policy = TensorDictSequential(policy, ou) + else: + exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy) + ou = exploratory_policy tensordict = TensorDict( batch_size=[batch], @@ -261,13 +274,11 @@ def test_ou_wrapper(self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed= ) tensordict = exploratory_policy(tensordict.clone()) if i == 0: - assert (tensordict[exploratory_policy.ou.steps_key] == 1).all() + assert (tensordict[ou.ou.steps_key] == 1).all() elif i == n_steps // 2 + 1: - assert ( - tensordict[exploratory_policy.ou.steps_key][: batch // 2] == 1 - ).all() + assert (tensordict[ou.ou.steps_key][: batch // 2] == 1).all() else: - assert not (tensordict[exploratory_policy.ou.steps_key] == 1).any() + assert not (tensordict[ou.ou.steps_key] == 1).any() out.append(tensordict.clone()) out_noexp.append(tensordict_noexp.clone()) @@ -284,7 +295,8 @@ def test_ou_wrapper(self, device, d_obs=4, d_act=6, batch=32, n_steps=100, seed= @pytest.mark.parametrize("parallel_spec", [True, False]) @pytest.mark.parametrize("probabilistic", [True, False]) - def test_collector(self, device, parallel_spec, probabilistic, seed=0): + @pytest.mark.parametrize("interface", ["module", "wrapper"]) + def test_collector(self, device, parallel_spec, probabilistic, interface, seed=0): torch.manual_seed(seed) env = SerialEnv( 2, @@ -298,7 +310,9 @@ def test_collector(self, device, parallel_spec, probabilistic, seed=0): action_spec = ContinuousActionVecMockEnv(device=device).action_spec d_act = action_spec.shape[-1] if probabilistic: - net = NormalParamWrapper(nn.LazyLinear(2 * d_act)).to(device) + net = nn.Sequential(nn.LazyLinear(2 * d_act), NormalParamExtractor()).to( + device + ) module = SafeModule( net, in_keys=["observation"], @@ -317,7 +331,12 @@ def test_collector(self, device, parallel_spec, probabilistic, seed=0): net, in_keys=["observation"], out_keys=["action"], spec=action_spec ) - exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy) + if interface == "module": + exploratory_policy = TensorDictSequential( + policy, OrnsteinUhlenbeckProcessModule(spec=action_spec).to(device) + ) + else: + exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy) 
exploratory_policy(env.reset()) collector = SyncDataCollector( create_env_fn=env, @@ -334,12 +353,14 @@ def test_collector(self, device, parallel_spec, probabilistic, seed=0): @pytest.mark.parametrize("nested_obs_action", [True, False]) @pytest.mark.parametrize("nested_done", [True, False]) @pytest.mark.parametrize("is_init_key", ["some"]) + @pytest.mark.parametrize("interface", ["module", "wrapper"]) def test_nested( self, device, nested_obs_action, nested_done, is_init_key, + interface, seed=0, n_envs=2, nested_dim=5, @@ -368,9 +389,20 @@ def test_nested( in_keys=[("data", "states") if nested_obs_action else "observation"], out_keys=[env.action_key], ) - exploratory_policy = OrnsteinUhlenbeckProcessWrapper( - policy, spec=action_spec, action_key=env.action_key, is_init_key=is_init_key - ) + if interface == "module": + exploratory_policy = TensorDictSequential( + policy, + OrnsteinUhlenbeckProcessModule( + spec=action_spec, action_key=env.action_key, is_init_key=is_init_key + ).to(device), + ) + else: + exploratory_policy = OrnsteinUhlenbeckProcessWrapper( + policy, + spec=action_spec, + action_key=env.action_key, + is_init_key=is_init_key, + ) collector = SyncDataCollector( create_env_fn=env, policy=exploratory_policy, @@ -388,43 +420,61 @@ def test_nested( return + def test_no_spec_error(self, device): + with pytest.raises(RuntimeError, match="spec cannot be None."): + OrnsteinUhlenbeckProcessModule(spec=None).to(device) + @pytest.mark.parametrize("device", get_default_devices()) class TestAdditiveGaussian: @pytest.mark.parametrize("spec_origin", ["spec", "policy", None]) + @pytest.mark.parametrize("interface", ["module", "wrapper"]) def test_additivegaussian_sd( self, device, spec_origin, + interface, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0, ): + if interface == "module" and spec_origin != "spec": + pytest.skip("module raises an error if given spec=None") + torch.manual_seed(seed) - net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device) action_spec = BoundedTensorSpec( -torch.ones(d_act, device=device), torch.ones(d_act, device=device), (d_act,), device=device, ) - module = SafeModule( - net, - in_keys=["observation"], - out_keys=["loc", "scale"], - spec=None, - ) - policy = ProbabilisticActor( - spec=CompositeSpec(action=action_spec) if spec_origin is not None else None, - module=module, - in_keys=["loc", "scale"], - distribution_class=TanhNormal, - default_interaction_type=InteractionType.RANDOM, - ).to(device) - given_spec = action_spec if spec_origin == "spec" else None - exploratory_policy = AdditiveGaussianWrapper(policy, spec=given_spec).to(device) + if interface == "module": + exploratory_policy = AdditiveGaussianModule(action_spec).to(device) + else: + net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to( + device + ) + module = SafeModule( + net, + in_keys=["observation"], + out_keys=["loc", "scale"], + spec=None, + ) + policy = ProbabilisticActor( + spec=CompositeSpec(action=action_spec) + if spec_origin is not None + else None, + module=module, + in_keys=["loc", "scale"], + distribution_class=TanhNormal, + default_interaction_type=InteractionType.RANDOM, + ).to(device) + given_spec = action_spec if spec_origin == "spec" else None + exploratory_policy = AdditiveGaussianWrapper(policy, spec=given_spec).to( + device + ) if spec_origin is not None: sigma_init = ( action_spec.project( @@ -442,9 +492,14 @@ def test_additivegaussian_sd( sigma_init = exploratory_policy.sigma_init sigma_end = exploratory_policy.sigma_end if spec_origin is 
None: + class_name = ( + "AdditiveGaussianModule" + if interface == "module" + else "AdditiveGaussianWrapper" + ) with pytest.raises( RuntimeError, - match="the action spec must be provided to AdditiveGaussianWrapper", + match=f"the action spec must be provided to {class_name}", ): exploratory_policy._add_noise(action_spec.rand((100000,)).zero_()) return @@ -466,11 +521,25 @@ def test_additivegaussian_sd( assert abs(noisy_action.std() - sigma_end) < 1e-1 @pytest.mark.parametrize("spec_origin", ["spec", "policy", None]) - def test_additivegaussian_wrapper( - self, device, spec_origin, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0 + @pytest.mark.parametrize("interface", ["module", "wrapper"]) + def test_additivegaussian( + self, + device, + spec_origin, + interface, + d_obs=4, + d_act=6, + batch=32, + n_steps=100, + seed=0, ): + if interface == "module" and spec_origin != "spec": + pytest.skip("module raises an error if given spec=None") + torch.manual_seed(seed) - net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device) + net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to( + device + ) module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) action_spec = BoundedTensorSpec( -torch.ones(d_act, device=device), @@ -486,9 +555,14 @@ def test_additivegaussian_wrapper( default_interaction_type=InteractionType.RANDOM, ).to(device) given_spec = action_spec if spec_origin == "spec" else None - exploratory_policy = AdditiveGaussianWrapper( - policy, spec=given_spec, safe=False - ).to(device) + if interface == "module": + exploratory_policy = TensorDictSequential( + policy, AdditiveGaussianModule(spec=given_spec).to(device) + ) + else: + exploratory_policy = AdditiveGaussianWrapper( + policy, spec=given_spec, safe=False + ).to(device) tensordict = TensorDict( batch_size=[batch], @@ -513,7 +587,8 @@ def test_additivegaussian_wrapper( assert action_spec.is_in(out.get("action")) @pytest.mark.parametrize("parallel_spec", [True, False]) - def test_collector(self, device, parallel_spec, seed=0): + @pytest.mark.parametrize("interface", ["module", "wrapper"]) + def test_collector(self, device, parallel_spec, interface, seed=0): torch.manual_seed(seed) env = SerialEnv( 2, @@ -526,7 +601,7 @@ def test_collector(self, device, parallel_spec, seed=0): else: action_spec = ContinuousActionVecMockEnv(device=device).action_spec d_act = action_spec.shape[-1] - net = NormalParamWrapper(nn.LazyLinear(2 * d_act)).to(device) + net = nn.Sequential(nn.LazyLinear(2 * d_act), NormalParamExtractor()).to(device) module = SafeModule( net, in_keys=["observation"], @@ -539,7 +614,12 @@ def test_collector(self, device, parallel_spec, seed=0): default_interaction_type=InteractionType.RANDOM, spec=action_spec, ).to(device) - exploratory_policy = AdditiveGaussianWrapper(policy, safe=False) + if interface == "module": + exploratory_policy = TensorDictSequential( + policy, AdditiveGaussianModule(spec=action_spec).to(device) + ) + else: + exploratory_policy = AdditiveGaussianWrapper(policy, safe=False) exploratory_policy(env.reset()) collector = SyncDataCollector( create_env_fn=env, @@ -553,6 +633,10 @@ def test_collector(self, device, parallel_spec, seed=0): pass return + def test_no_spec_error(self, device): + with pytest.raises(RuntimeError, match="spec cannot be None."): + AdditiveGaussianModule(spec=None).to(device) + @pytest.mark.parametrize("state_dim", [7]) @pytest.mark.parametrize("action_dim", [5, 11]) @@ -582,7 +666,7 @@ def test_gsde( else: in_keys = ["observation"] model 
= torch.nn.LazyLinear(action_dim * 2, device=device) - wrapper = NormalParamWrapper(model) + wrapper = nn.Sequential(model, NormalParamExtractor()) module = SafeModule(wrapper, in_keys=in_keys, out_keys=["loc", "scale"]) distribution_class = TanhNormal distribution_kwargs = {"low": -bound, "high": bound} diff --git a/test/test_modules.py b/test/test_modules.py index 592464f0a96..11cf11f41e6 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -839,6 +839,65 @@ def test_multiagent_mlp( agent_dim={-2}\)""" assert re.match(pattern, str(mlp), re.DOTALL) + @retry(AssertionError, 5) + @pytest.mark.parametrize("n_agents", [1, 3]) + @pytest.mark.parametrize("share_params", [True, False]) + @pytest.mark.parametrize("centralized", [True, False]) + @pytest.mark.parametrize("n_agent_inputs", [6, None]) + @pytest.mark.parametrize("batch", [(4,), (4, 3), ()]) + def test_multiagent_mlp_init( + self, + n_agents, + centralized, + share_params, + batch, + n_agent_inputs, + n_agent_outputs=2, + ): + torch.manual_seed(1) + mlp = MultiAgentMLP( + n_agent_inputs=n_agent_inputs, + n_agent_outputs=n_agent_outputs, + n_agents=n_agents, + centralized=centralized, + share_params=share_params, + depth=2, + ) + for m in mlp.modules(): + if isinstance(m, nn.Linear): + assert not isinstance(m.weight, nn.Parameter) + assert m.weight.device == torch.device("meta") + break + else: + raise RuntimeError("could not find a Linear module") + if n_agent_inputs is None: + n_agent_inputs = 6 + td = self._get_mock_input_td(n_agents, n_agent_inputs, batch=batch) + obs = td.get(("agents", "observation")) + mlp(obs) + snet = mlp.get_stateful_net() + assert snet is not mlp._empty_net + + def zero_inplace(mod): + if hasattr(mod, "weight"): + mod.weight.data *= 0 + if hasattr(mod, "bias"): + mod.bias.data *= 0 + + snet.apply(zero_inplace) + assert (mlp.params == 0).all() + + def one_outofplace(mod): + if hasattr(mod, "weight"): + mod.weight = nn.Parameter(torch.ones_like(mod.weight.data)) + if hasattr(mod, "bias"): + mod.bias = nn.Parameter(torch.ones_like(mod.bias.data)) + + snet.apply(one_outofplace) + assert (mlp.params == 0).all() + mlp.from_stateful_net(snet) + assert (mlp.params == 1).all() + def test_multiagent_mlp_lazy(self): mlp = MultiAgentMLP( n_agent_inputs=None, diff --git a/test/test_specs.py b/test/test_specs.py index 6b779811f1d..2d597d770f0 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -3013,7 +3013,9 @@ def test_repr(self): space=None, device=cpu, dtype=torch.float32, - domain=continuous), device=cpu, shape=torch.Size([3])), + domain=continuous), + device=cpu, + shape=torch.Size([3])), 1 -> lidar: BoundedTensorSpec( shape=torch.Size([20]), @@ -3031,7 +3033,9 @@ def test_repr(self): high=Tensor(shape=torch.Size([3, 1, 2]), device=cpu, dtype=torch.float32, contiguous=True)), device=cpu, dtype=torch.float32, - domain=continuous), device=cpu, shape=torch.Size([3])), + domain=continuous), + device=cpu, + shape=torch.Size([3])), 2 -> individual_2_obs: CompositeSpec( individual_1_obs_0: UnboundedContinuousTensorSpec( @@ -3039,7 +3043,9 @@ def test_repr(self): space=None, device=cpu, dtype=torch.float32, - domain=continuous), device=cpu, shape=torch.Size([3]))}}, + domain=continuous), + device=cpu, + shape=torch.Size([3]))}}, device=cpu, shape={torch.Size((3,))}, stack_dim={c.stack_dim})""" diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index 6f81a9748bc..38360a464e0 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -27,14 +27,14 @@ ) from 
torchrl.envs.utils import set_exploration_type, step_mdp from torchrl.modules import ( - AdditiveGaussianWrapper, + AdditiveGaussianModule, DecisionTransformerInferenceWrapper, DTActor, GRUModule, LSTMModule, MLP, MultiStepActorWrapper, - NormalParamWrapper, + NormalParamExtractor, OnlineDTActor, ProbabilisticActor, SafeModule, @@ -201,7 +201,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys) in_keys = ["in"] net = SafeModule( - module=NormalParamWrapper(net), + module=nn.Sequential(net, NormalParamExtractor()), spec=None, in_keys=in_keys, out_keys=out_keys, @@ -363,7 +363,7 @@ def test_stateful_probabilistic(self, safe, spec_type, lazy): net1 = nn.Linear(3, 4) dummy_net = nn.Linear(4, 4) net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) + net2 = nn.Sequential(net2, NormalParamExtractor()) if spec_type is None: spec = None @@ -474,11 +474,11 @@ def test_sequential_partial(self, stack): net1 = nn.Linear(3, 4) net2 = nn.Linear(4, 4 * param_multiplier) - net2 = NormalParamWrapper(net2) + net2 = nn.Sequential(net2, NormalParamExtractor()) net2 = SafeModule(net2, in_keys=["b"], out_keys=["loc", "scale"]) net3 = nn.Linear(4, 4 * param_multiplier) - net3 = NormalParamWrapper(net3) + net3 = nn.Sequential(net3, NormalParamExtractor()) net3 = SafeModule(net3, in_keys=["c"], out_keys=["loc", "scale"]) spec = BoundedTensorSpec(-0.1, 0.1, 4) @@ -1363,17 +1363,19 @@ def test_actor_critic_specs(): out_keys=[action_key], ) original_spec = spec.clone() - module = AdditiveGaussianWrapper(policy_module, spec=spec, action_key=action_key) + module = TensorDictSequential( + policy_module, AdditiveGaussianModule(spec=spec, action_key=action_key) + ) value_module = ValueOperator( module=module, in_keys=[("agents", "observation"), action_key], out_keys=[("agents", "state_action_value")], ) assert original_spec == spec - assert module.spec == spec + assert module[1].spec == spec DDPGLoss(actor_network=module, value_network=value_module) assert original_spec == spec - assert module.spec == spec + assert module[1].spec == spec def test_vmapmodule(): diff --git a/test/test_transforms.py b/test/test_transforms.py index fcfd6f08aff..94ec8b2716c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -22,6 +22,7 @@ import torch from _utils_internal import ( # noqa + BREAKOUT_VERSIONED, dtype_fixture, get_default_devices, HALFCHEETAH_VERSIONED, @@ -69,6 +70,7 @@ CenterCrop, ClipTransform, Compose, + Crop, DeviceCastTransform, DiscreteActionProjection, DMControlEnv, @@ -248,7 +250,10 @@ def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -257,7 +262,10 @@ def test_trans_serial_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( @@ -267,7 +275,10 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("batch", [[], [4], [6, 4]]) @@ -572,7 +583,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_serial_trans_env_check(self): def make_env(): @@ -604,7 +618,10 
@@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = ContinuousActionVecMockEnv() @@ -650,7 +667,10 @@ def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -674,7 +694,10 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.skipif(not _has_gym, reason="Test executed on gym") @pytest.mark.parametrize("batched_class", [ParallelEnv, SerialEnv]) @@ -1613,15 +1636,18 @@ def test_single_trans_env_check(self, update_done, max_steps): assert "truncated" not in r.keys() assert ("next", "truncated") not in r.keys(True) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv(ContinuousActionVecMockEnv(), StepCounter(10)) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_serial_trans_env_check(self): def make_env(): @@ -1630,14 +1656,17 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), StepCounter(10) + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), StepCounter(10) ) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), StepCounter(10)) @@ -1907,7 +1936,7 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): ct = CatTensors( in_keys=["observation", "observation_orig"], @@ -1917,11 +1946,14 @@ def make_env(): ) return TransformedEnv(ContinuousActionVecMockEnv(), ct) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): ct = CatTensors( @@ -1934,7 +1966,7 @@ def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), ct) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): ct = CatTensors( in_keys=["observation", "observation_orig"], out_key="observation_out", @@ -1942,11 +1974,14 @@ def test_trans_parallel_env_check(self): del_keys=False, ) - env = TransformedEnv(ParallelEnv(2, ContinuousActionVecMockEnv), ct) + env = TransformedEnv(maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), ct) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize( @@ -2101,6 +2136,213 @@ def test_transform_inverse(self): raise pytest.skip("No inverse for CatTensors") +@pytest.mark.skipif(not _has_tv, reason="no torchvision") +class TestCrop(TransformBase): 
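[Editorial note, not part of the patch] The test class that follows exercises the new ``Crop`` transform added by this patch. As a quick orientation before the parametrized tests, here is a minimal usage sketch in Python; the export of ``Crop`` from ``torchrl.envs`` is taken from the ``__init__.py`` hunk later in this diff, and the constructor arguments mirror the tests below rather than any confirmed documentation:

    # Illustrative sketch only: crop the trailing (H, W) dims of the selected
    # keys down to a 20x20 window, leaving other entries untouched.
    import torch
    from tensordict import TensorDict
    from torchrl.envs import Crop  # export added by this patch

    crop = Crop(w=20, h=20, in_keys=["pixels"])
    td = TensorDict({"pixels": torch.randn(3, 64, 64)}, batch_size=[])
    td = crop(td)
    assert td["pixels"].shape[-2:] == torch.Size([20, 20])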
+ @pytest.mark.parametrize("nchannels", [1, 3]) + @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) + @pytest.mark.parametrize("h", [None, 21]) + @pytest.mark.parametrize( + "keys", [["observation", ("some_other", "nested_key")], ["observation_pixels"]] + ) + @pytest.mark.parametrize("device", get_default_devices()) + def test_transform_no_env(self, keys, h, nchannels, batch, device): + torch.manual_seed(0) + dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device) + crop = Crop(w=20, h=h, in_keys=keys) + if h is None: + h = 20 + td = TensorDict( + { + key: torch.randn(*batch, nchannels, 16, 16, device=device) + for key in keys + }, + batch, + device=device, + ) + td.set("dont touch", dont_touch.clone()) + crop(td) + for key in keys: + assert td.get(key).shape[-2:] == torch.Size([20, h]) + assert (td.get("dont touch") == dont_touch).all() + + if len(keys) == 1: + observation_spec = BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) + observation_spec = crop.transform_observation_spec(observation_spec) + assert observation_spec.shape == torch.Size([nchannels, 20, h]) + else: + observation_spec = CompositeSpec( + {key: BoundedTensorSpec(-1, 1, (nchannels, 16, 16)) for key in keys} + ) + observation_spec = crop.transform_observation_spec(observation_spec) + for key in keys: + assert observation_spec[key].shape == torch.Size([nchannels, 20, h]) + + @pytest.mark.parametrize("nchannels", [3]) + @pytest.mark.parametrize("batch", [[2]]) + @pytest.mark.parametrize("h", [None]) + @pytest.mark.parametrize("keys", [["observation_pixels"]]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_transform_model(self, keys, h, nchannels, batch, device): + torch.manual_seed(0) + dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device) + crop = Crop(w=20, h=h, in_keys=keys) + if h is None: + h = 20 + td = TensorDict( + { + key: torch.randn(*batch, nchannels, 16, 16, device=device) + for key in keys + }, + batch, + device=device, + ) + td.set("dont touch", dont_touch.clone()) + model = nn.Sequential(crop, nn.Identity()) + model(td) + for key in keys: + assert td.get(key).shape[-2:] == torch.Size([20, h]) + assert (td.get("dont touch") == dont_touch).all() + + @pytest.mark.parametrize("nchannels", [3]) + @pytest.mark.parametrize("batch", [[2]]) + @pytest.mark.parametrize("h", [None]) + @pytest.mark.parametrize("keys", [["observation_pixels"]]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_transform_compose(self, keys, h, nchannels, batch, device): + torch.manual_seed(0) + dont_touch = torch.randn(*batch, nchannels, 16, 16, device=device) + crop = Crop(w=20, h=h, in_keys=keys) + if h is None: + h = 20 + td = TensorDict( + { + key: torch.randn(*batch, nchannels, 16, 16, device=device) + for key in keys + }, + batch, + device=device, + ) + td.set("dont touch", dont_touch.clone()) + model = Compose(crop) + tdc = model(td.clone()) + for key in keys: + assert tdc.get(key).shape[-2:] == torch.Size([20, h]) + assert (tdc.get("dont touch") == dont_touch).all() + tdc = model._call(td.clone()) + for key in keys: + assert tdc.get(key).shape[-2:] == torch.Size([20, h]) + assert (tdc.get("dont touch") == dont_touch).all() + + @pytest.mark.parametrize("nchannels", [3]) + @pytest.mark.parametrize("batch", [[2]]) + @pytest.mark.parametrize("h", [None]) + @pytest.mark.parametrize("keys", [["observation_pixels"]]) + @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) + def test_transform_rb( + self, + rbclass, + keys, + h, + nchannels, + batch, 
+ ): + torch.manual_seed(0) + dont_touch = torch.randn( + *batch, + nchannels, + 16, + 16, + ) + crop = Crop(w=20, h=h, in_keys=keys) + if h is None: + h = 20 + td = TensorDict( + { + key: torch.randn( + *batch, + nchannels, + 16, + 16, + ) + for key in keys + }, + batch, + ) + td.set("dont touch", dont_touch.clone()) + rb = rbclass(storage=LazyTensorStorage(10)) + rb.append_transform(crop) + rb.extend(td) + td = rb.sample(10) + for key in keys: + assert td.get(key).shape[-2:] == torch.Size([20, h]) + + def test_single_trans_env_check(self): + keys = ["pixels"] + ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys)) + env = TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct) + check_env_specs(env) + + def test_serial_trans_env_check(self): + keys = ["pixels"] + + def make_env(): + ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys)) + return TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct) + + env = SerialEnv(2, make_env) + check_env_specs(env) + + def test_parallel_trans_env_check(self): + keys = ["pixels"] + + def make_env(): + ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys)) + return TransformedEnv(DiscreteActionConvMockEnvNumpy(), ct) + + env = ParallelEnv(2, make_env) + try: + check_env_specs(env) + finally: + try: + env.close() + except RuntimeError: + pass + + def test_trans_serial_env_check(self): + keys = ["pixels"] + ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys)) + env = TransformedEnv(SerialEnv(2, DiscreteActionConvMockEnvNumpy), ct) + check_env_specs(env) + + def test_trans_parallel_env_check(self): + keys = ["pixels"] + ct = Compose(ToTensorImage(), Crop(w=20, h=20, in_keys=keys)) + env = TransformedEnv(ParallelEnv(2, DiscreteActionConvMockEnvNumpy), ct) + try: + check_env_specs(env) + finally: + try: + env.close() + except RuntimeError: + pass + + @pytest.mark.skipif(not _has_gym, reason="No Gym detected") + @pytest.mark.parametrize("out_key", [None, ["outkey"], [("out", "key")]]) + def test_transform_env(self, out_key): + keys = ["pixels"] + ct = Compose(ToTensorImage(), Crop(out_keys=out_key, w=20, h=20, in_keys=keys)) + env = TransformedEnv(GymEnv(PONG_VERSIONED()), ct) + td = env.reset() + if out_key is None: + assert td["pixels"].shape == torch.Size([3, 20, 20]) + else: + assert td[out_key[0]].shape == torch.Size([3, 20, 20]) + check_env_specs(env) + + def test_transform_inverse(self): + raise pytest.skip("Crop does not have an inverse method.") + + @pytest.mark.skipif(not _has_tv, reason="no torchvision") class TestCenterCrop(TransformBase): @pytest.mark.parametrize("nchannels", [1, 3]) @@ -2143,18 +2385,8 @@ def test_transform_no_env(self, keys, h, nchannels, batch, device): assert observation_spec[key].shape == torch.Size([nchannels, 20, h]) @pytest.mark.parametrize("nchannels", [3]) - @pytest.mark.parametrize( - "batch", - [ - [2], - ], - ) - @pytest.mark.parametrize( - "h", - [ - None, - ], - ) + @pytest.mark.parametrize("batch", [[2]]) + @pytest.mark.parametrize("h", [None]) @pytest.mark.parametrize("keys", [["observation_pixels"]]) @pytest.mark.parametrize("device", get_default_devices()) def test_transform_model(self, keys, h, nchannels, batch, device): @@ -2179,18 +2411,8 @@ def test_transform_model(self, keys, h, nchannels, batch, device): assert (td.get("dont touch") == dont_touch).all() @pytest.mark.parametrize("nchannels", [3]) - @pytest.mark.parametrize( - "batch", - [ - [2], - ], - ) - @pytest.mark.parametrize( - "h", - [ - None, - ], - ) + @pytest.mark.parametrize("batch", [[2]]) + 
@pytest.mark.parametrize("h", [None]) @pytest.mark.parametrize("keys", [["observation_pixels"]]) @pytest.mark.parametrize("device", get_default_devices()) def test_transform_compose(self, keys, h, nchannels, batch, device): @@ -2219,18 +2441,8 @@ def test_transform_compose(self, keys, h, nchannels, batch, device): assert (tdc.get("dont touch") == dont_touch).all() @pytest.mark.parametrize("nchannels", [3]) - @pytest.mark.parametrize( - "batch", - [ - [2], - ], - ) - @pytest.mark.parametrize( - "h", - [ - None, - ], - ) + @pytest.mark.parametrize("batch", [[2]]) + @pytest.mark.parametrize("h", [None]) @pytest.mark.parametrize("keys", [["observation_pixels"]]) @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) def test_transform_rb( @@ -2298,7 +2510,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): keys = ["pixels"] @@ -2306,14 +2521,19 @@ def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, DiscreteActionConvMockEnvNumpy), ct) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): keys = ["pixels"] ct = Compose(ToTensorImage(), CenterCrop(w=20, h=20, in_keys=keys)) - env = TransformedEnv(ParallelEnv(2, DiscreteActionConvMockEnvNumpy), ct) + env = TransformedEnv( + maybe_fork_ParallelEnv(2, DiscreteActionConvMockEnvNumpy), ct + ) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.skipif(not _has_gym, reason="No Gym detected") @pytest.mark.parametrize("out_key", [None, ["outkey"], [("out", "key")]]) @@ -2350,17 +2570,20 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( DiscreteActionConvMockEnvNumpy(), DiscreteActionProjection(7, 10) ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -2369,15 +2592,18 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, DiscreteActionConvMockEnvNumpy), + maybe_fork_ParallelEnv(2, DiscreteActionConvMockEnvNumpy), DiscreteActionProjection(7, 10), ) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("action_key", ["action", ("nested", "stuff")]) def test_transform_no_env(self, action_key): @@ -2594,7 +2820,9 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self, dtype_fixture): # noqa: F811 + def test_parallel_trans_env_check( + self, dtype_fixture, maybe_fork_ParallelEnv # noqa: F811 + ): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(dtype=torch.float64), @@ -2602,10 +2830,13 @@ def make_env(): ) try: - env = ParallelEnv(1, make_env) + env = maybe_fork_ParallelEnv(1, make_env) check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass del env def test_trans_serial_env_check(self, dtype_fixture): # noqa: F811 @@ -2615,15 +2846,22 @@ def test_trans_serial_env_check(self, 
dtype_fixture): # noqa: F811 ) check_env_specs(env) - def test_trans_parallel_env_check(self, dtype_fixture): # noqa: F811 + def test_trans_parallel_env_check( + self, dtype_fixture, maybe_fork_ParallelEnv # noqa: F811 + ): env = TransformedEnv( - ParallelEnv(2, lambda: ContinuousActionVecMockEnv(dtype=torch.float64)), + maybe_fork_ParallelEnv( + 2, lambda: ContinuousActionVecMockEnv(dtype=torch.float64) + ), DoubleToFloat(in_keys=["observation"], in_keys_inv=["action"]), ) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_no_env(self, dtype_fixture): # noqa: F811 t = DoubleToFloat(in_keys=["observation"], in_keys_inv=["action"]) @@ -2763,7 +3001,7 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): t = Compose( CatTensors( @@ -2774,11 +3012,14 @@ def make_env(): env = TransformedEnv(ContinuousActionVecMockEnv(), t) return env - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): t = Compose( @@ -2790,18 +3031,21 @@ def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), t) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): t = Compose( CatTensors( in_keys=["observation"], out_key="observation_copy", del_keys=False ), ExcludeTransform("observation_copy"), ) - env = TransformedEnv(ParallelEnv(2, ContinuousActionVecMockEnv), t) + env = TransformedEnv(maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), t) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_env(self): base_env = TestExcludeTransform.EnvWithManyKeys() @@ -2995,7 +3239,7 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): t = Compose( CatTensors( @@ -3006,11 +3250,14 @@ def make_env(): env = TransformedEnv(ContinuousActionVecMockEnv(), t) return env - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): t = Compose( @@ -3022,18 +3269,21 @@ def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), t) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): t = Compose( CatTensors( in_keys=["observation"], out_key="observation_copy", del_keys=False ), SelectTransform("observation", "observation_orig"), ) - env = TransformedEnv(ParallelEnv(2, ContinuousActionVecMockEnv), t) + env = TransformedEnv(maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), t) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_env(self): base_env = TestExcludeTransform.EnvWithManyKeys() @@ -3192,18 +3442,21 @@ def make_env(): env = SerialEnv(2, make_env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): env = TransformedEnv( 
DiscreteActionConvMockEnvNumpy(), FlattenObservation(-3, -1) ) return env - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -3215,9 +3468,9 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, DiscreteActionConvMockEnvNumpy), + maybe_fork_ParallelEnv(2, DiscreteActionConvMockEnvNumpy), FlattenObservation( -3, -1, @@ -3226,7 +3479,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.skipif(not _has_tv, reason="no torchvision") @pytest.mark.parametrize("nchannels", [1, 3]) @@ -3372,16 +3628,19 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): env = TransformedEnv(ContinuousActionVecMockEnv(), FrameSkipTransform(2)) return env - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -3389,14 +3648,17 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), FrameSkipTransform(2) + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), FrameSkipTransform(2) ) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_no_env(self): t = FrameSkipTransform(2) @@ -3609,7 +3871,7 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): out_keys = None def make_env(): @@ -3618,11 +3880,14 @@ def make_env(): Compose(ToTensorImage(), GrayScale(out_keys=out_keys)), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): out_keys = None @@ -3632,16 +3897,19 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): out_keys = None env = TransformedEnv( - ParallelEnv(2, DiscreteActionConvMockEnvNumpy), + maybe_fork_ParallelEnv(2, DiscreteActionConvMockEnvNumpy), Compose(ToTensorImage(), GrayScale(out_keys=out_keys)), ) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("out_keys", [None, ["stuff"]]) def test_transform_env(self, out_keys): @@ -3709,15 +3977,18 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv(ContinuousActionVecMockEnv(), NoopResetEnv()) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: 
- env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), NoopResetEnv()) @@ -3873,9 +4144,7 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check( - self, - ): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), @@ -3886,11 +4155,14 @@ def make_env(): ), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check( self, @@ -3905,11 +4177,9 @@ def test_trans_serial_env_check( ) check_env_specs(env) - def test_trans_parallel_env_check( - self, - ): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), ObservationNorm( loc=torch.zeros(7), in_keys=["observation"], @@ -3919,7 +4189,10 @@ def test_trans_parallel_env_check( try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("standard_normal", [True, False]) @pytest.mark.parametrize("in_key", ["observation", ("some_other", "observation")]) @@ -4462,18 +4735,21 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( DiscreteActionConvMockEnvNumpy(), Compose(ToTensorImage(), Resize(20, 21, in_keys=["pixels"])), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -4482,15 +4758,18 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, DiscreteActionConvMockEnvNumpy), + maybe_fork_ParallelEnv(2, DiscreteActionConvMockEnvNumpy), Compose(ToTensorImage(), Resize(20, 21, in_keys=["pixels"])), ) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.skipif(not _has_gym, reason="No gym") @pytest.mark.parametrize("out_key", ["pixels", ("agents", "pixels")]) @@ -4539,17 +4818,20 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), RewardClipping(-0.1, 0.1) ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -4557,14 +4839,18 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), RewardClipping(-0.1, 0.1) + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), + RewardClipping(-0.1, 0.1), ) try: check_env_specs(env) finally: - env.close() + 
try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("reward_key", ["reward", ("agents", "reward")]) def test_transform_no_env(self, reward_key): @@ -4676,15 +4962,18 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv(ContinuousActionVecMockEnv(), RewardScaling(0.5, 1.5)) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -4692,14 +4981,18 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), RewardScaling(0.5, 1.5) + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), + RewardScaling(0.5, 1.5), ) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("standard_normal", [True, False]) def test_transform_no_env(self, standard_normal): @@ -4804,20 +5097,23 @@ def make_env(): r = env.rollout(4) assert r["next", "episode_reward"].unique().numel() > 1 - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), Compose(RewardScaling(loc=-1, scale=1), RewardSum()), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) r = env.rollout(4) assert r["next", "episode_reward"].unique().numel() > 1 finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -4828,9 +5124,9 @@ def test_trans_serial_env_check(self): r = env.rollout(4) assert r["next", "episode_reward"].unique().numel() > 1 - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), Compose(RewardScaling(loc=-1, scale=1), RewardSum()), ) try: @@ -4838,7 +5134,10 @@ def test_trans_parallel_env_check(self): r = env.rollout(4) assert r["next", "episode_reward"].unique().numel() > 1 finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("has_in_keys,", [True, False]) @pytest.mark.parametrize("reset_keys,", [None, ["_reset"] * 3]) @@ -5480,18 +5779,21 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), UnsqueezeTransform(-1, in_keys=["observation"]), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -5500,15 +5802,18 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), + 
maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), UnsqueezeTransform(-1, in_keys=["observation"]), ) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("unsqueeze_dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @@ -5786,17 +6091,20 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), self._circular_transform ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -5805,16 +6113,23 @@ def test_trans_serial_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), self._circular_transform + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), + self._circular_transform, ) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("squeeze_dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @@ -5995,7 +6310,7 @@ def make_env(): @pytest.mark.parametrize("mode", ["reduce", "constant"]) @pytest.mark.parametrize("device", get_default_devices()) - def test_parallel_trans_env_check(self, mode, device): + def test_parallel_trans_env_check(self, mode, device, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), @@ -6003,11 +6318,14 @@ def make_env(): device=device, ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("mode", ["reduce", "constant"]) @pytest.mark.parametrize("device", get_default_devices()) @@ -6020,20 +6338,26 @@ def test_trans_serial_env_check(self, mode, device): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("mode", ["reduce", "constant"]) @pytest.mark.parametrize("device", get_default_devices()) - def test_trans_parallel_env_check(self, mode, device): + def test_trans_parallel_env_check(self, mode, device, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, DiscreteActionConvMockEnvNumpy).to(device), + maybe_fork_ParallelEnv(2, DiscreteActionConvMockEnvNumpy).to(device), TargetReturn(target_return=10.0, mode=mode), device=device, ) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.skipif(not _has_gym, reason="Test executed on gym") @pytest.mark.parametrize("batched_class", [SerialEnv, ParallelEnv]) @@ -6231,18 +6555,21 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( DiscreteActionConvMockEnvNumpy(), ToTensorImage(in_keys=["pixels"], out_keys=None), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except 
RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -6251,15 +6578,18 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, DiscreteActionConvMockEnvNumpy), + maybe_fork_ParallelEnv(2, DiscreteActionConvMockEnvNumpy), ToTensorImage(in_keys=["pixels"], out_keys=None), ) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("out_keys", [None, ["stuff"], [("nested", "stuff")]]) @pytest.mark.parametrize("default_dtype", [torch.float32, torch.float64]) @@ -6379,20 +6709,23 @@ def test_transform_compose(self): t(td) assert "mykey" in td.keys() - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([3])), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) assert "mykey" in env.reset().keys() assert ("next", "mykey") in env.rollout(3).keys(True) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_serial_trans_env_check(self): def make_env(): @@ -6407,11 +6740,14 @@ def make_env(): assert "mykey" in env.reset().keys() assert ("next", "mykey") in env.rollout(3).keys(True) finally: - env.close() + try: + env.close() + except RuntimeError: + pass - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), TensorDictPrimer(mykey=UnboundedContinuousTensorSpec([2, 4])), ) try: @@ -6421,7 +6757,10 @@ def test_trans_parallel_env_check(self): assert ("next", "mykey") in r.keys(True) assert r["next", "mykey"].shape == torch.Size([2, 3, 4]) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("spec_shape", [[4], [2, 4]]) def test_trans_serial_env_check(self, spec_shape): @@ -6670,8 +7009,8 @@ def test_serial_trans_env_check(self): ) check_env_specs(env) - def test_parallel_trans_env_check(self): - env = ParallelEnv( + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): + env = maybe_fork_ParallelEnv( 2, lambda: TransformedEnv( ContinuousActionVecMockEnv(), @@ -6684,7 +7023,10 @@ def test_parallel_trans_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -6696,9 +7038,9 @@ def test_trans_serial_env_check(self): ) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, lambda: ContinuousActionVecMockEnv()), + maybe_fork_ParallelEnv(2, lambda: ContinuousActionVecMockEnv()), TimeMaxPool( in_keys=["observation"], T=3, @@ -6707,7 +7049,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.skipif(not _has_gym, reason="Test executed on gym") @pytest.mark.parametrize("batched_class", [ParallelEnv, SerialEnv]) @@ -6858,7 +7203,7 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def 
test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): state_dim = 7 action_dim = 7 @@ -6867,11 +7212,14 @@ def make_env(): gSDENoise(state_dim=state_dim, action_dim=action_dim), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("shape", [(), (2,)]) def test_trans_serial_env_check(self, shape): @@ -6884,19 +7232,25 @@ def test_trans_serial_env_check(self, shape): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): state_dim = 7 action_dim = 7 env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), gSDENoise(state_dim=state_dim, action_dim=action_dim, shape=(2,)), ) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_no_env(self): state_dim = 7 @@ -8733,7 +9087,7 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self, create_copy): + def test_parallel_trans_env_check(self, create_copy, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(), @@ -8744,11 +9098,14 @@ def make_env(): ), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def make_env(): return TransformedEnv( @@ -8762,11 +9119,14 @@ def make_env(): ), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self, create_copy): def make_env(): @@ -8793,12 +9153,12 @@ def make_env(): ) check_env_specs(env) - def test_trans_parallel_env_check(self, create_copy): + def test_trans_parallel_env_check(self, create_copy, maybe_fork_ParallelEnv): def make_env(): return ContinuousActionVecMockEnv() env = TransformedEnv( - ParallelEnv(2, make_env), + maybe_fork_ParallelEnv(2, make_env), RenameTransform( ["observation"], ["stuff"], @@ -8808,9 +9168,12 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass env = TransformedEnv( - ParallelEnv(2, make_env), + maybe_fork_ParallelEnv(2, make_env), RenameTransform( ["observation_orig"], ["stuff"], @@ -8822,7 +9185,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("mode", ["forward", "_call"]) @pytest.mark.parametrize( @@ -9011,17 +9377,20 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): env = CountingBatchedEnv(max_steps=torch.tensor([4, 5]), batch_size=[2]) env = TransformedEnv(env, InitTracker()) return env - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): def make_env(): @@ -9033,19 +9402,25 @@ def make_env(): try: check_env_specs(env) finally: - env.close() 
+ try: + env.close() + except RuntimeError: + pass - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): def make_env(): env = CountingBatchedEnv(max_steps=torch.tensor([4, 5]), batch_size=[2]) return env - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) env = TransformedEnv(env, InitTracker()) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_no_env(self): with pytest.raises(ValueError, match="init_key can only be of type str"): @@ -9304,18 +9679,21 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): out_key = "reward" def make_env(): base_env = self.envclass() return TransformedEnv(base_env, self._make_transform_env(out_key, base_env)) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): out_key = "reward" @@ -9324,16 +9702,22 @@ def test_trans_serial_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): out_key = "reward" - base_env = ParallelEnv(2, self.envclass) + base_env = maybe_fork_ParallelEnv(2, self.envclass) env = TransformedEnv(base_env, self._make_transform_env(out_key, base_env)) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_model(self): actor = self._make_actor() @@ -9467,26 +9851,37 @@ def test_serial_trans_env_check(self): env = SerialEnv(2, lambda: TransformedEnv(self._env_class(), ActionMask())) check_env_specs(env) - def test_parallel_trans_env_check(self): - env = ParallelEnv(2, lambda: TransformedEnv(self._env_class(), ActionMask())) + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): + env = maybe_fork_ParallelEnv( + 2, lambda: TransformedEnv(self._env_class(), ActionMask()) + ) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv(SerialEnv(2, self._env_class), ActionMask()) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass - def test_trans_parallel_env_check(self): - env = TransformedEnv(ParallelEnv(2, self._env_class), ActionMask()) + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): + env = TransformedEnv(maybe_fork_ParallelEnv(2, self._env_class), ActionMask()) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_no_env(self): t = ActionMask() @@ -9603,7 +9998,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("in_keys", ["observation"]) @pytest.mark.parametrize("out_keys", [None, ["obs_device"]]) @@ -9654,7 +10052,10 @@ def make_env(): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_no_env(self): t = DeviceCastTransform("cpu:1", "cpu:0", in_keys=["a"], out_keys=["b"]) @@ -9783,18 +10184,21 @@ def make_env(): assert env.device == torch.device("cpu:1") check_env_specs(env) - 
def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): return TransformedEnv( ContinuousActionVecMockEnv(device="cpu:0"), DeviceCastTransform("cpu:1") ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) assert env.device == torch.device("cpu:1") try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): def make_env(): @@ -9804,16 +10208,21 @@ def make_env(): assert env.device == torch.device("cpu:1") check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): def make_env(): return ContinuousActionVecMockEnv(device="cpu:0") - env = TransformedEnv(ParallelEnv(2, make_env), DeviceCastTransform("cpu:1")) + env = TransformedEnv( + maybe_fork_ParallelEnv(2, make_env), DeviceCastTransform("cpu:1") + ) assert env.device == torch.device("cpu:1") try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_no_env(self): t = DeviceCastTransform("cpu:1", "cpu:0") @@ -9895,8 +10304,8 @@ def test_serial_trans_env_check(self): ) check_env_specs(env) - def test_parallel_trans_env_check(self): - env = ParallelEnv( + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): + env = maybe_fork_ParallelEnv( 2, lambda: TransformedEnv( TestPermuteTransform.envclass(), TestPermuteTransform._get_permute() @@ -9905,7 +10314,10 @@ def test_parallel_trans_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -9915,17 +10327,23 @@ def test_trans_serial_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, TestPermuteTransform.envclass), + maybe_fork_ParallelEnv(2, TestPermuteTransform.envclass), TestPermuteTransform._get_permute(), ) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) def test_transform_compose(self, batch): @@ -10019,22 +10437,22 @@ def test_transform_no_env(self, batch): reason="EndOfLifeTransform can only be tested when Gym is present.", ) class TestEndOfLife(TransformBase): - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): def make(): with set_gym_backend("gymnasium"): - return GymEnv("ALE/Breakout-v5") + return GymEnv(BREAKOUT_VERSIONED()) with pytest.warns(UserWarning, match="The base_env is not a gym env"): with pytest.raises(AttributeError): env = TransformedEnv( - ParallelEnv(2, make), transform=EndOfLifeTransform() + maybe_fork_ParallelEnv(2, make), transform=EndOfLifeTransform() ) check_env_specs(env) def test_trans_serial_env_check(self): def make(): with set_gym_backend("gymnasium"): - return GymEnv("ALE/Breakout-v5") + return GymEnv(BREAKOUT_VERSIONED()) with pytest.warns(UserWarning, match="The base_env is not a gym env"): env = TransformedEnv(SerialEnv(2, make), transform=EndOfLifeTransform()) @@ -10045,7 +10463,7 @@ def make(): def test_single_trans_env_check(self, eol_key, lives_key): with set_gym_backend("gymnasium"): env = TransformedEnv( - GymEnv("ALE/Breakout-v5"), + 
GymEnv(BREAKOUT_VERSIONED()), transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key), ) check_env_specs(env) @@ -10056,7 +10474,7 @@ def test_serial_trans_env_check(self, eol_key, lives_key): def make(): with set_gym_backend("gymnasium"): return TransformedEnv( - GymEnv("ALE/Breakout-v5"), + GymEnv(BREAKOUT_VERSIONED()), transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key), ) @@ -10065,19 +10483,22 @@ def make(): @pytest.mark.parametrize("eol_key", ["eol_key", ("nested", "eol")]) @pytest.mark.parametrize("lives_key", ["lives_key", ("nested", "lives")]) - def test_parallel_trans_env_check(self, eol_key, lives_key): + def test_parallel_trans_env_check(self, eol_key, lives_key, maybe_fork_ParallelEnv): def make(): with set_gym_backend("gymnasium"): return TransformedEnv( - GymEnv("ALE/Breakout-v5"), + GymEnv(BREAKOUT_VERSIONED()), transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key), ) - env = ParallelEnv(2, make) + env = maybe_fork_ParallelEnv(2, make) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_transform_no_env(self): t = EndOfLifeTransform() @@ -10098,7 +10519,7 @@ def test_transform_env(self, eol_key, lives_key): with set_gym_backend("gymnasium"): env = TransformedEnv( - GymEnv("ALE/Breakout-v5"), + GymEnv(BREAKOUT_VERSIONED()), transform=EndOfLifeTransform(eol_key=eol_key, lives_key=lives_key), ) check_env_specs(env) @@ -10465,7 +10886,7 @@ def test_transform_no_env(self): assert data["reward"] == 2 assert self.check_sign_applied(data["reward_sign"]) - def test_parallel_trans_env_check(self): + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): def make_env(): env = ContinuousActionVecMockEnv() return TransformedEnv( @@ -10476,11 +10897,14 @@ def make_env(): ), ) - env = ParallelEnv(2, make_env) + env = maybe_fork_ParallelEnv(2, make_env) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_serial_trans_env_check(self): def make_env(): @@ -10496,9 +10920,9 @@ def make_env(): env = SerialEnv(2, make_env) check_env_specs(env) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - ParallelEnv(2, ContinuousActionVecMockEnv), + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), SignTransform( in_keys=["observation", "reward"], in_keys_inv=["observation_orig"], @@ -10507,7 +10931,10 @@ def test_trans_parallel_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): env = TransformedEnv( @@ -10520,7 +10947,10 @@ def test_trans_serial_env_check(self): try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass class TestRemoveEmptySpecs(TransformBase): @@ -10574,14 +11004,17 @@ def test_serial_trans_env_check(self): env = SerialEnv(2, lambda: TransformedEnv(self.DummyEnv(), RemoveEmptySpecs())) check_env_specs(env) - def test_parallel_trans_env_check(self): - env = ParallelEnv( + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): + env = maybe_fork_ParallelEnv( 2, lambda: TransformedEnv(self.DummyEnv(), RemoveEmptySpecs()) ) try: check_env_specs(env) finally: - env.close() + try: + env.close() + except RuntimeError: + pass def test_trans_serial_env_check(self): with pytest.raises( @@ -10589,11 +11022,13 @@ def test_trans_serial_env_check(self): ): env = TransformedEnv(SerialEnv(2, 
self.DummyEnv), RemoveEmptySpecs()) - def test_trans_parallel_env_check(self): + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): with pytest.raises( RuntimeError, match="The environment passed to ParallelEnv has empty specs" ): - env = TransformedEnv(ParallelEnv(2, self.DummyEnv), RemoveEmptySpecs()) + env = TransformedEnv( + maybe_fork_ParallelEnv(2, self.DummyEnv), RemoveEmptySpecs() + ) def test_transform_no_env(self): td = TensorDict({"a": {"b": {"c": {}}}}, []) @@ -11045,7 +11480,7 @@ def make_env(): ) return env - env = ParallelEnv(2, make_env, mp_start_method="fork") + env = ParallelEnv(2, make_env, mp_start_method=mp_ctx) check_env_specs(env) @pytest.mark.parametrize("categorical", [True, False]) @@ -11058,7 +11493,7 @@ def test_trans_serial_env_check(self, categorical): @pytest.mark.parametrize("categorical", [True, False]) def test_trans_parallel_env_check(self, categorical): env = ParallelEnv( - 2, ContinuousActionVecMockEnv, mp_start_method="fork" + 2, ContinuousActionVecMockEnv, mp_start_method=mp_ctx ).append_transform(ActionDiscretizer(num_intervals=5, categorical=categorical)) check_env_specs(env) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 32294a25edd..be24a06e39c 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1121,7 +1121,10 @@ def _maybe_set_truncated(self, final_rollout): truncated = final_rollout["next", truncated_key] truncated[last_step] = True final_rollout["next", truncated_key] = truncated - final_rollout["next", _replace_last(truncated_key, "done")] = truncated + done = final_rollout["next", _replace_last(truncated_key, "done")] + final_rollout["next", _replace_last(truncated_key, "done")] = ( + done | truncated + ) return final_rollout @torch.no_grad() diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 0006213cd27..7c787b3ccfc 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1326,6 +1326,8 @@ class OneHotDiscreteTensorSpec(TensorSpec): discrete outcomes are sampled from an arbitrary set, whose elements will be mapped in a register to a series of unique one-hot binary vectors). + mask (torch.Tensor or None): mask some of the possible outcomes when a + sample is taken. See :meth:`~.update_mask` for more information. """ @@ -1368,6 +1370,25 @@ def n(self): return self.space.n def update_mask(self, mask): + """Sets a mask to prevent some of the possible outcomes when a sample is taken. + + The mask can also be set during initialization of the spec. + + Args: + mask (torch.Tensor or None): boolean mask. If None, the mask is + disabled. Otherwise, the shape of the mask must be expandable to + the shape of the spec. ``False`` masks an outcome and ``True`` + leaves the outcome unmasked. If all of the possible outcomes are + masked, then an error is raised when a sample is taken. + + Examples: + >>> mask = torch.tensor([True, False, False]) + >>> ts = OneHotDiscreteTensorSpec(3, (2, 3,), dtype=torch.int64, mask=mask) + >>> # All but one of the three possible outcomes are masked + >>> ts.rand() + tensor([[1, 0, 0], + [1, 0, 0]]) + """ if mask is not None: try: mask = mask.expand(self._safe_shape) @@ -2516,6 +2537,8 @@ class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): device (str, int or torch.device, optional): device of the tensors. dtype (str or torch.dtype, optional): dtype of the tensors. + mask (torch.Tensor or None): mask some of the possible outcomes when a + sample is taken. 
See :meth:`~.update_mask` for more information. Examples: >>> ts = MultiOneHotDiscreteTensorSpec((3,2,3)) @@ -2564,6 +2587,28 @@ def __init__( self.update_mask(mask) def update_mask(self, mask): + """Sets a mask to prevent some of the possible outcomes when a sample is taken. + + The mask can also be set during initialization of the spec. + + Args: + mask (torch.Tensor or None): boolean mask. If None, the mask is + disabled. Otherwise, the shape of the mask must be expandable to + the shape of the spec. ``False`` masks an outcome and ``True`` + leaves the outcome unmasked. If all of the possible outcomes are + masked, then an error is raised when a sample is taken. + + Examples: + >>> mask = torch.tensor([True, False, False, + ... True, True]) + >>> ts = MultiOneHotDiscreteTensorSpec((3, 2), (2, 5), dtype=torch.int64, mask=mask) + >>> # All but one of the three possible outcomes for the first + >>> # one-hot group are masked, but neither of the two possible + >>> # outcomes for the second one-hot group are masked. + >>> ts.rand() + tensor([[1, 0, 0, 0, 1], + [1, 0, 0, 1, 0]]) + """ if mask is not None: try: mask = mask.expand(*self._safe_shape) @@ -2900,6 +2945,8 @@ class DiscreteTensorSpec(TensorSpec): shape: (torch.Size, optional): shape of the variable, default is "torch.Size([])". device (str, int or torch.device, optional): device of the tensors. dtype (str or torch.dtype, optional): dtype of the tensors. + mask (torch.Tensor or None): mask some of the possible outcomes when a + sample is taken. See :meth:`~.update_mask` for more information. """ @@ -2933,6 +2980,25 @@ def n(self): return self.space.n def update_mask(self, mask): + """Sets a mask to prevent some of the possible outcomes when a sample is taken. + + The mask can also be set during initialization of the spec. + + Args: + mask (torch.Tensor or None): boolean mask. If None, the mask is + disabled. Otherwise, the shape of the mask must be expandable to + the shape of the equivalent one-hot spec. ``False`` masks an + outcome and ``True`` leaves the outcome unmasked. If all of the + possible outcomes are masked, then an error is raised when a + sample is taken. + + Examples: + >>> mask = torch.tensor([True, False, True]) + >>> ts = DiscreteTensorSpec(3, (10,), dtype=torch.int64, mask=mask) + >>> # One of the three possible outcomes is masked + >>> ts.rand() + tensor([0, 2, 2, 0, 2, 0, 2, 2, 0, 2]) + """ if mask is not None: try: mask = mask.expand(_remove_neg_shapes(*self.shape, self.space.n)) @@ -3315,6 +3381,8 @@ class MultiDiscreteTensorSpec(DiscreteTensorSpec): dtype (str or torch.dtype, optional): dtype of the tensors. remove_singleton (bool, optional): if ``True``, singleton samples (of size [1]) will be squeezed. Defaults to ``True``. + mask (torch.Tensor or None): mask some of the possible outcomes when a + sample is taken. See :meth:`~.update_mask` for more information. Examples: >>> ts = MultiDiscreteTensorSpec((3, 2, 3)) @@ -3361,6 +3429,32 @@ def __init__( self.remove_singleton = remove_singleton def update_mask(self, mask): + """Sets a mask to prevent some of the possible outcomes when a sample is taken. + + The mask can also be set during initialization of the spec. + + Args: + mask (torch.Tensor or None): boolean mask. If None, the mask is + disabled. Otherwise, the shape of the mask must be expandable to + the shape of the equivalent one-hot spec. ``False`` masks an + outcome and ``True`` leaves the outcome unmasked. 
If all of the + possible outcomes are masked, then an error is raised when a + sample is taken. + + Examples: + >>> mask = torch.tensor([False, False, True, + ... True, True]) + >>> ts = MultiDiscreteTensorSpec((3, 2), (5, 2,), dtype=torch.int64, mask=mask) + >>> # All but one of the three possible outcomes for the first + >>> # group are masked, but neither of the two possible + >>> # outcomes for the second group are masked. + >>> ts.rand() + tensor([[2, 1], + [2, 0], + [2, 1], + [2, 1], + [2, 0]]) + """ if mask is not None: try: mask = mask.expand(_remove_neg_shapes(*self.shape[:-1], mask.shape[-1])) @@ -4006,7 +4100,7 @@ def __repr__(self) -> str: indent(f"{k}: {str(item)}", 4 * " ") for k, item in self._specs.items() ] sub_str = ",\n".join(sub_str) - return f"CompositeSpec(\n{sub_str}, device={self._device}, shape={self.shape})" + return f"CompositeSpec(\n{sub_str},\n device={self._device},\n shape={self.shape})" def type_check( self, diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 8475979a3ba..ced185d7e00 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -5,6 +5,7 @@ from .batched_envs import ParallelEnv, SerialEnv from .common import EnvBase, EnvMetaData, make_tensordict +from .custom import PendulumEnv, TicTacToeEnv from .env_creator import EnvCreator, get_env_metadata from .gym_like import default_info_dict_reader, GymLikeEnv from .libs import ( @@ -50,6 +51,7 @@ CenterCrop, ClipTransform, Compose, + Crop, DeviceCastTransform, DiscreteActionProjection, DoubleToFloat, diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index e30de3534d9..b9216b58e86 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -1515,7 +1515,8 @@ def _complete_done( shape = (*leading_dim, *item.shape) if val is not None: if val.shape != shape: - data.set(key, val.reshape(shape)) + val = val.reshape(shape) + data.set(key, val) vals[key] = val if len(vals) < i + 1: @@ -1535,6 +1536,7 @@ def _complete_done( "Cannot infer the value of terminated when only done and truncated are present." ) data.set("terminated", val) + data_keys.add("terminated") elif ( key == "terminated" and val is not None @@ -1542,11 +1544,10 @@ def _complete_done( and "done" not in data_keys ): if "truncated" in data_keys: - done = val | data.get("truncated") - data.set("done", done) - else: - data.set("done", val) - elif val is None: + val = val | data.get("truncated") + data.set("done", val) + data_keys.add("done") + elif val is None and key not in data_keys: # we must keep this here: we only want to fill with 0s if we're sure # done should not be copied to terminated or terminated to done # in this case, just fill with 0s @@ -2354,10 +2355,13 @@ def rollout( break_when_any_done (bool): breaks if any of the done state is True. If False, a reset() is called on the sub-envs that are done. Default is True. return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True. - tensordict (TensorDict, optional): if auto_reset is False, an initial + tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial tensordict must be provided. Rollout will check if this tensordict has done flags and reset the - environment in those dimensions (if needed). This normally should not occur if ``tensordict`` is the - output of a reset, but can occur if ``tensordict`` is the last step of a previous rollout. + environment in those dimensions (if needed). 
+ This normally should not occur if ``tensordict`` is the output of a reset, but can occur + if ``tensordict`` is the last step of a previous rollout. + A ``tensordict`` can also be provided when ``auto_reset=True`` if metadata need to be passed + to the ``reset`` method, such as a batch-size or a device for stateless environments. set_truncated (bool, optional): if ``True``, ``"truncated"`` and ``"done"`` keys will be set to ``True`` after completion of the rollout. If no ``"truncated"`` is found within the ``done_spec``, an exception is raised. @@ -2564,11 +2568,7 @@ def rollout( env_device = self.device if auto_reset: - if tensordict is not None: - raise RuntimeError( - "tensordict cannot be provided when auto_reset is True" - ) - tensordict = self.reset() + tensordict = self.reset(tensordict) elif tensordict is None: raise RuntimeError("tensordict must be provided when auto_reset is False") else: diff --git a/torchrl/envs/custom/__init__.py b/torchrl/envs/custom/__init__.py new file mode 100644 index 00000000000..8649d3d3e97 --- /dev/null +++ b/torchrl/envs/custom/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .pendulum import PendulumEnv +from .tictactoeenv import TicTacToeEnv diff --git a/torchrl/envs/custom/pendulum.py b/torchrl/envs/custom/pendulum.py new file mode 100644 index 00000000000..8253e3df9b7 --- /dev/null +++ b/torchrl/envs/custom/pendulum.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import numpy as np + +import torch +from tensordict import TensorDict, TensorDictBase +from torchrl.data.tensor_specs import ( + BoundedTensorSpec, + CompositeSpec, + UnboundedContinuousTensorSpec, +) +from torchrl.envs.common import EnvBase +from torchrl.envs.utils import make_composite_from_td + + +class PendulumEnv(EnvBase): + """A stateless Pendulum environment. + + See the Pendulum tutorial for more details: :ref:`tutorial `. 
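[Editorial aside, not part of the patch] A sketch of how the relaxed ``rollout(..., tensordict=...)`` signature combines with this stateless, non batch-locked environment (default hyperparameters assumed, apart from gravity):

    from torchrl.envs import PendulumEnv

    env = PendulumEnv()
    # the reset tensordict dictates how many pendulums are simulated side by side
    td0 = PendulumEnv.gen_params(g=9.81, batch_size=[4])
    rollout = env.rollout(50, tensordict=td0)   # auto_reset=True now accepts a tensordict
    print(rollout.shape)                        # torch.Size([4, 50])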
+ + Specs: + CompositeSpec( + output_spec: CompositeSpec( + full_observation_spec: CompositeSpec( + th: BoundedTensorSpec( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True)), + dtype=torch.float32, + domain=continuous), + thdot: BoundedTensorSpec( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True)), + dtype=torch.float32, + domain=continuous), + params: CompositeSpec( + max_speed: UnboundedContinuousTensorSpec( + shape=torch.Size([]), + dtype=torch.int64, + domain=discrete), + max_torque: UnboundedContinuousTensorSpec( + shape=torch.Size([]), + dtype=torch.float32, + domain=continuous), + dt: UnboundedContinuousTensorSpec( + shape=torch.Size([]), + dtype=torch.float32, + domain=continuous), + g: UnboundedContinuousTensorSpec( + shape=torch.Size([]), + dtype=torch.float32, + domain=continuous), + m: UnboundedContinuousTensorSpec( + shape=torch.Size([]), + dtype=torch.float32, + domain=continuous), + l: UnboundedContinuousTensorSpec( + shape=torch.Size([]), + dtype=torch.float32, + domain=continuous), + shape=torch.Size([])), + shape=torch.Size([])), + full_reward_spec: CompositeSpec( + reward: UnboundedContinuousTensorSpec( + shape=torch.Size([1]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([1]), dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([1]), dtype=torch.float32, contiguous=True)), + dtype=torch.float32, + domain=continuous), + shape=torch.Size([])), + full_done_spec: CompositeSpec( + done: DiscreteTensorSpec( + shape=torch.Size([1]), + space=DiscreteBox(n=2), + dtype=torch.bool, + domain=discrete), + terminated: DiscreteTensorSpec( + shape=torch.Size([1]), + space=DiscreteBox(n=2), + dtype=torch.bool, + domain=discrete), + shape=torch.Size([])), + shape=torch.Size([])), + input_spec: CompositeSpec( + full_state_spec: CompositeSpec( + th: BoundedTensorSpec( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True)), + dtype=torch.float32, + domain=continuous), + thdot: BoundedTensorSpec( + shape=torch.Size([]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([]), dtype=torch.float32, contiguous=True)), + dtype=torch.float32, + domain=continuous), + params: CompositeSpec( + max_speed: UnboundedContinuousTensorSpec( + shape=torch.Size([]), + dtype=torch.int64, + domain=discrete), + max_torque: UnboundedContinuousTensorSpec( + shape=torch.Size([]), + dtype=torch.float32, + domain=continuous), + dt: UnboundedContinuousTensorSpec( + shape=torch.Size([]), + dtype=torch.float32, + domain=continuous), + g: UnboundedContinuousTensorSpec( + shape=torch.Size([]), + dtype=torch.float32, + domain=continuous), + m: UnboundedContinuousTensorSpec( + shape=torch.Size([]), + dtype=torch.float32, + domain=continuous), + l: UnboundedContinuousTensorSpec( + shape=torch.Size([]), + dtype=torch.float32, + domain=continuous), + shape=torch.Size([])), + shape=torch.Size([])), + full_action_spec: CompositeSpec( + action: BoundedTensorSpec( + shape=torch.Size([1]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([1]), dtype=torch.float32, contiguous=True), + 
high=Tensor(shape=torch.Size([1]), dtype=torch.float32, contiguous=True)), + dtype=torch.float32, + domain=continuous), + shape=torch.Size([])), + shape=torch.Size([])), + shape=torch.Size([])) + + """ + + DEFAULT_X = np.pi + DEFAULT_Y = 1.0 + + metadata = { + "render_modes": ["human", "rgb_array"], + "render_fps": 30, + } + batch_locked = False + + def __init__(self, td_params=None, seed=None, device=None): + if td_params is None: + td_params = self.gen_params() + + super().__init__(device=device) + self._make_spec(td_params) + if seed is None: + seed = torch.empty((), dtype=torch.int64).random_().item() + self.set_seed(seed) + + @classmethod + def _step(cls, tensordict): + th, thdot = tensordict["th"], tensordict["thdot"] # th := theta + + g_force = tensordict["params", "g"] + mass = tensordict["params", "m"] + length = tensordict["params", "l"] + dt = tensordict["params", "dt"] + u = tensordict["action"].squeeze(-1) + u = u.clamp( + -tensordict["params", "max_torque"], tensordict["params", "max_torque"] + ) + costs = cls.angle_normalize(th) ** 2 + 0.1 * thdot**2 + 0.001 * (u**2) + + new_thdot = ( + thdot + + (3 * g_force / (2 * length) * th.sin() + 3.0 / (mass * length**2) * u) + * dt + ) + new_thdot = new_thdot.clamp( + -tensordict["params", "max_speed"], tensordict["params", "max_speed"] + ) + new_th = th + new_thdot * dt + reward = -costs.view(*tensordict.shape, 1) + done = torch.zeros_like(reward, dtype=torch.bool) + out = TensorDict( + { + "th": new_th, + "thdot": new_thdot, + "params": tensordict["params"], + "reward": reward, + "done": done, + }, + tensordict.shape, + ) + return out + + def _reset(self, tensordict): + batch_size = ( + tensordict.batch_size if tensordict is not None else self.batch_size + ) + if tensordict is None or tensordict.is_empty(): + # if no ``tensordict`` is passed, we generate a single set of hyperparameters + # Otherwise, we assume that the input ``tensordict`` contains all the relevant + # parameters to get started. + tensordict = self.gen_params(batch_size=batch_size) + + high_th = torch.tensor(self.DEFAULT_X, device=self.device) + high_thdot = torch.tensor(self.DEFAULT_Y, device=self.device) + low_th = -high_th + low_thdot = -high_thdot + + # for non batch-locked environments, the input ``tensordict`` shape dictates the number + # of simulators run simultaneously. In other contexts, the initial + # random state's shape will depend upon the environment batch-size instead. + th = ( + torch.rand(tensordict.shape, generator=self.rng, device=self.device) + * (high_th - low_th) + + low_th + ) + thdot = ( + torch.rand(tensordict.shape, generator=self.rng, device=self.device) + * (high_thdot - low_thdot) + + low_thdot + ) + out = TensorDict( + { + "th": th, + "thdot": thdot, + "params": tensordict["params"], + }, + batch_size=batch_size, + ) + return out + + def _make_spec(self, td_params): + # Under the hood, this will populate self.output_spec["observation"] + self.observation_spec = CompositeSpec( + th=BoundedTensorSpec( + low=-torch.pi, + high=torch.pi, + shape=(), + dtype=torch.float32, + ), + thdot=BoundedTensorSpec( + low=-td_params["params", "max_speed"], + high=td_params["params", "max_speed"], + shape=(), + dtype=torch.float32, + ), + # we need to add the ``params`` to the observation specs, as we want + # to pass it at each step during a rollout + params=make_composite_from_td( + td_params["params"], unsqueeze_null_shapes=False + ), + shape=(), + ) + # since the environment is stateless, we expect the previous output as input. 
+ # For this, ``EnvBase`` expects some state_spec to be available + self.state_spec = self.observation_spec.clone() + # action-spec will be automatically wrapped in input_spec when + # `self.action_spec = spec` will be called supported + self.action_spec = BoundedTensorSpec( + low=-td_params["params", "max_torque"], + high=td_params["params", "max_torque"], + shape=(1,), + dtype=torch.float32, + ) + self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1)) + + def make_composite_from_td(td): + # custom function to convert a ``tensordict`` in a similar spec structure + # of unbounded values. + composite = CompositeSpec( + { + key: make_composite_from_td(tensor) + if isinstance(tensor, TensorDictBase) + else UnboundedContinuousTensorSpec( + dtype=tensor.dtype, device=tensor.device, shape=tensor.shape + ) + for key, tensor in td.items() + }, + shape=td.shape, + ) + return composite + + def _set_seed(self, seed: int): + rng = torch.manual_seed(seed) + self.rng = rng + + @staticmethod + def gen_params(g=10.0, batch_size=None) -> TensorDictBase: + """Returns a ``tensordict`` containing the physical parameters such as gravitational force and torque or speed limits.""" + if batch_size is None: + batch_size = [] + td = TensorDict( + { + "params": TensorDict( + { + "max_speed": 8, + "max_torque": 2.0, + "dt": 0.05, + "g": g, + "m": 1.0, + "l": 1.0, + }, + [], + ) + }, + [], + ) + if batch_size: + td = td.expand(batch_size).contiguous() + return td + + @staticmethod + def angle_normalize(x): + return ((x + torch.pi) % (2 * torch.pi)) - torch.pi diff --git a/torchrl/envs/custom/tictactoeenv.py b/torchrl/envs/custom/tictactoeenv.py new file mode 100644 index 00000000000..79ea3b2dfb6 --- /dev/null +++ b/torchrl/envs/custom/tictactoeenv.py @@ -0,0 +1,299 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +from typing import Optional + +import torch +from tensordict import TensorDict, TensorDictBase + +from torchrl.data.tensor_specs import ( + CompositeSpec, + DiscreteTensorSpec, + UnboundedContinuousTensorSpec, + UnboundedDiscreteTensorSpec, +) +from torchrl.envs.common import EnvBase + + +class TicTacToeEnv(EnvBase): + """A Tic-Tac-Toe implementation. + + Keyword Args: + single_player (bool, optional): whether one or two players have to be + accounted for. ``single_player=True`` means that ``"player1"`` is + playing randomly. If ``False`` (default), at each turn, + one of the two players has to play. + device (torch.device, optional): the device where to put the tensors. + Defaults to ``None`` (default device). + + The environment is stateless. To run it across multiple batches, call + + >>> env.reset(TensorDict(batch_size=desired_batch_size)) + + If the ``"mask"`` entry is present, ``rand_action`` takes it into account to + generate the next action. Any policy executed on this env should take this + mask into account, as well as the turn of the player (stored in the ``"turn"`` + output entry). 
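[Editorial aside, not part of the patch] A minimal usage sketch of the batched, stateless pattern described above; ``env.rollout`` works as in the example further below:

    from tensordict import TensorDict
    from torchrl.envs import TicTacToeEnv

    env = TicTacToeEnv(single_player=True)
    td = env.reset(TensorDict(batch_size=[8]))   # eight boards played side by side
    print(td["board"].shape)                     # torch.Size([8, 3, 3])
    td = env.rand_action(td)                     # only legal cells are drawn, thanks to "mask"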
+ + Specs: + CompositeSpec( + output_spec: CompositeSpec( + full_observation_spec: CompositeSpec( + board: DiscreteTensorSpec( + shape=torch.Size([3, 3]), + space=DiscreteBox(n=2), + dtype=torch.int32, + domain=discrete), + turn: DiscreteTensorSpec( + shape=torch.Size([1]), + space=DiscreteBox(n=2), + dtype=torch.int32, + domain=discrete), + mask: DiscreteTensorSpec( + shape=torch.Size([9]), + space=DiscreteBox(n=2), + dtype=torch.bool, + domain=discrete), + shape=torch.Size([])), + full_reward_spec: CompositeSpec( + player0: CompositeSpec( + reward: UnboundedContinuousTensorSpec( + shape=torch.Size([1]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)), + dtype=torch.float32, + domain=continuous), + shape=torch.Size([])), + player1: CompositeSpec( + reward: UnboundedContinuousTensorSpec( + shape=torch.Size([1]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)), + dtype=torch.float32, + domain=continuous), + shape=torch.Size([])), + shape=torch.Size([])), + full_done_spec: CompositeSpec( + done: DiscreteTensorSpec( + shape=torch.Size([1]), + space=DiscreteBox(n=2), + dtype=torch.bool, + domain=discrete), + terminated: DiscreteTensorSpec( + shape=torch.Size([1]), + space=DiscreteBox(n=2), + dtype=torch.bool, + domain=discrete), + truncated: DiscreteTensorSpec( + shape=torch.Size([1]), + space=DiscreteBox(n=2), + dtype=torch.bool, + domain=discrete), + shape=torch.Size([])), + shape=torch.Size([])), + input_spec: CompositeSpec( + full_state_spec: CompositeSpec( + board: DiscreteTensorSpec( + shape=torch.Size([3, 3]), + space=DiscreteBox(n=2), + dtype=torch.int32, + domain=discrete), + turn: DiscreteTensorSpec( + shape=torch.Size([1]), + space=DiscreteBox(n=2), + dtype=torch.int32, + domain=discrete), + mask: DiscreteTensorSpec( + shape=torch.Size([9]), + space=DiscreteBox(n=2), + dtype=torch.bool, + domain=discrete), shape=torch.Size([])), + full_action_spec: CompositeSpec( + action: DiscreteTensorSpec( + shape=torch.Size([1]), + space=DiscreteBox(n=9), + dtype=torch.int64, + domain=discrete), + shape=torch.Size([])), + shape=torch.Size([])), + shape=torch.Size([])) + + To run a dummy rollout, execute the following command: + + Examples: + >>> env = TicTacToeEnv() + >>> env.rollout(10) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.int64, is_shared=False), + board: Tensor(shape=torch.Size([9, 3, 3]), device=cpu, dtype=torch.int32, is_shared=False), + done: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False), + mask: Tensor(shape=torch.Size([9, 9]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + board: Tensor(shape=torch.Size([9, 3, 3]), device=cpu, dtype=torch.int32, is_shared=False), + done: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False), + mask: Tensor(shape=torch.Size([9, 9]), device=cpu, dtype=torch.bool, is_shared=False), + player0: TensorDict( + fields={ + reward: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([9]), + device=None, + is_shared=False), + player1: TensorDict( + fields={ + reward: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + 
batch_size=torch.Size([9]), + device=None, + is_shared=False), + terminated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False), + turn: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.int32, is_shared=False)}, + batch_size=torch.Size([9]), + device=None, + is_shared=False), + terminated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.bool, is_shared=False), + turn: Tensor(shape=torch.Size([9, 1]), device=cpu, dtype=torch.int32, is_shared=False)}, + batch_size=torch.Size([9]), + device=None, + is_shared=False) + + """ + + # batch_locked is set to False since various batch sizes can be provided to the env + batch_locked: bool = False + + def __init__(self, *, single_player: bool = False, device=None): + super().__init__(device=device) + self.single_player = single_player + self.action_spec: UnboundedDiscreteTensorSpec = DiscreteTensorSpec( + n=9, + shape=(), + device=device, + ) + + self.full_observation_spec: CompositeSpec = CompositeSpec( + board=UnboundedContinuousTensorSpec( + shape=(3, 3), dtype=torch.int, device=device + ), + turn=DiscreteTensorSpec( + 2, + shape=(1,), + dtype=torch.int, + device=device, + ), + mask=DiscreteTensorSpec( + 2, + shape=(9,), + dtype=torch.bool, + device=device, + ), + device=device, + ) + self.state_spec: CompositeSpec = self.observation_spec.clone() + + self.reward_spec: UnboundedContinuousTensorSpec = CompositeSpec( + { + ("player0", "reward"): UnboundedContinuousTensorSpec( + shape=(1,), device=device + ), + ("player1", "reward"): UnboundedContinuousTensorSpec( + shape=(1,), device=device + ), + }, + device=device, + ) + + self.full_done_spec: DiscreteTensorSpec = CompositeSpec( + done=DiscreteTensorSpec(2, shape=(1,), dtype=torch.bool, device=device), + device=device, + ) + self.full_done_spec["terminated"] = self.full_done_spec["done"].clone() + self.full_done_spec["truncated"] = self.full_done_spec["done"].clone() + + def _reset(self, reset_td: TensorDict) -> TensorDict: + shape = reset_td.shape if reset_td is not None else () + state = self.state_spec.zero(shape) + state["board"] -= 1 + state["mask"].fill_(True) + return state.update(self.full_done_spec.zero(shape)) + + def _step(self, state: TensorDict) -> TensorDict: + board = state["board"].clone() + turn = state["turn"].clone() + action = state["action"] + board.flatten(-2, -1).scatter_(index=action.unsqueeze(-1), dim=-1, value=1) + wins = self.win(state["board"], action) + + mask = board.flatten(-2, -1) == -1 + done = wins | ~mask.any(-1, keepdim=True) + terminated = done.clone() + + reward_0 = wins & (turn == 0) + reward_1 = wins & (turn == 1) + + state = TensorDict( + { + "done": done, + "terminated": terminated, + ("player0", "reward"): reward_0.float(), + ("player1", "reward"): reward_1.float(), + "board": torch.where(board == -1, board, 1 - board), + "turn": 1 - state["turn"], + "mask": mask, + }, + batch_size=state.batch_size, + ) + if self.single_player: + select = (~done & (turn == 0)).squeeze(-1) + if select.all(): + state_select = state + elif select.any(): + state_select = state[select] + else: + return state + state_select = self._step(self.rand_action(state_select)) + if select.all(): + return state_select + return torch.where(done, state, state_select) + return state + + def _set_seed(self, seed: int | None): + ... 
+ + @staticmethod + def win(board: torch.Tensor, action: torch.Tensor): + row = action // 3 # type: ignore + col = action % 3 # type: ignore + return ( + board[..., row, :].sum() + == 3 | board[..., col].sum() + == 3 | board.diagonal(0, -2, -1).sum() + == 3 | board.flip(-1).diagonal(0, -2, -1).sum() + == 3 + ) + + @staticmethod + def full(board: torch.Tensor) -> bool: + return torch.sym_int(board.abs().sum()) == 9 + + @staticmethod + def get_action_mask(): + pass + + def rand_action(self, tensordict: Optional[TensorDictBase] = None): + mask = tensordict.get("mask") + action_spec = self.action_spec + if tensordict.ndim: + action_spec = action_spec.expand(tensordict.shape) + else: + action_spec = action_spec.clone() + action_spec.update_mask(mask) + tensordict.set(self.action_key, action_spec.rand()) + return tensordict diff --git a/torchrl/envs/libs/brax.py b/torchrl/envs/libs/brax.py index c86ba9a543c..ac4cd71ddad 100644 --- a/torchrl/envs/libs/brax.py +++ b/torchrl/envs/libs/brax.py @@ -80,8 +80,10 @@ class BraxWrapper(_EnvWrapper): Examples: >>> import brax.envs >>> from torchrl.envs import BraxWrapper + >>> import torch + >>> device = "cuda" if torch.cuda.is_available() else "cpu" >>> base_env = brax.envs.get_environment("ant") - >>> env = BraxWrapper(base_env) + >>> env = BraxWrapper(base_env, device=device) >>> env.set_seed(0) >>> td = env.reset() >>> td["action"] = env.action_spec.rand() @@ -111,7 +113,9 @@ class BraxWrapper(_EnvWrapper): and report the execution time for a short rollout: Examples: + >>> import torch >>> from torch.utils.benchmark import Timer + >>> device = "cuda" if torch.cuda.is_available() else "cpu" >>> for batch_size in [4, 16, 128]: ... timer = Timer(''' ... env.rollout(100) @@ -119,7 +123,7 @@ class BraxWrapper(_EnvWrapper): ... setup=f''' ... import brax.envs ... from torchrl.envs import BraxWrapper - ... env = BraxWrapper(brax.envs.get_environment("ant"), batch_size=[{batch_size}]) + ... env = BraxWrapper(brax.envs.get_environment("ant"), batch_size=[{batch_size}], device="{device}") ... env.set_seed(0) ... env.rollout(2) ... ''') @@ -459,7 +463,9 @@ class BraxEnv(BraxWrapper): Examples: >>> from torchrl.envs import BraxEnv - >>> env = BraxEnv("ant") + >>> import torch + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + >>> env = BraxEnv("ant", device=device) >>> env.set_seed(0) >>> td = env.reset() >>> td["action"] = env.action_spec.rand() @@ -489,13 +495,16 @@ class BraxEnv(BraxWrapper): and report the execution time for a short rollout: Examples: + >>> import torch + >>> from torch.utils.benchmark import Timer + >>> device = "cuda" if torch.cuda.is_available() else "cpu" >>> for batch_size in [4, 16, 128]: ... timer = Timer(''' ... env.rollout(100) ... ''', ... setup=f''' ... from torchrl.envs import BraxEnv - ... env = BraxEnv("ant", batch_size=[{batch_size}]) + ... env = BraxEnv("ant", batch_size=[{batch_size}], device="{device}") ... env.set_seed(0) ... env.rollout(2) ... 
''') diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 5204b8a19d8..64a25b94e37 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -20,6 +20,7 @@ CenterCrop, ClipTransform, Compose, + Crop, DeviceCastTransform, DiscreteActionProjection, DoubleToFloat, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 70aef03e041..1a66ee489a6 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1913,6 +1913,73 @@ def _reset( return tensordict_reset +class Crop(ObservationTransform): + """Crops the input image at the specified location and output size. + + Args: + w (int): resulting width + h (int, optional): resulting height. If None, then w is used (square crop). + top (int, optional): top pixel coordinate to start cropping. Default is 0, i.e. top of the image. + left (int, optional): left pixel coordinate to start cropping. Default is 0, i.e. left of the image. + in_keys (sequence of NestedKey, optional): the entries to crop. If none is provided, + ``["pixels"]`` is assumed. + out_keys (sequence of NestedKey, optional): the cropped images keys. If none is + provided, ``in_keys`` is assumed. + + """ + + def __init__( + self, + w: int, + h: int = None, + top: int = 0, + left: int = 0, + in_keys: Sequence[NestedKey] | None = None, + out_keys: Sequence[NestedKey] | None = None, + ): + if in_keys is None: + in_keys = IMAGE_KEYS # default + if out_keys is None: + out_keys = copy(in_keys) + super().__init__(in_keys=in_keys, out_keys=out_keys) + self.w = w + self.h = h if h else w + self.top = top + self.left = left + + def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: + from torchvision.transforms.functional import crop + + observation = crop(observation, self.top, self.left, self.w, self.h) + return observation + + def _reset( + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase + ) -> TensorDictBase: + with _set_missing_tolerance(self, True): + tensordict_reset = self._call(tensordict_reset) + return tensordict_reset + + @_apply_to_composite + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + space = observation_spec.space + if isinstance(space, ContinuousBox): + space.low = self._apply_transform(space.low) + space.high = self._apply_transform(space.high) + observation_spec.shape = space.low.shape + else: + observation_spec.shape = self._apply_transform( + torch.zeros(observation_spec.shape) + ).shape + return observation_spec + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"w={float(self.w):4.4f}, h={float(self.h):4.4f}, top={float(self.top):4.4f}, left={float(self.left):4.4f}, " + ) + + class CenterCrop(ObservationTransform): """Crops the center of an image. @@ -6420,6 +6487,7 @@ class InitTracker(Transform): Args: init_key (NestedKey, optional): the key to be used for the tracker entry. + In case of multiple _reset flags, this key is used as the leaf replacement for each. Examples: >>> from torchrl.envs.libs.gym import GymEnv @@ -6433,11 +6501,12 @@ class InitTracker(Transform): """ - def __init__(self, init_key: NestedKey = "is_init"): + def __init__(self, init_key: str = "is_init"): if not isinstance(init_key, str): - raise ValueError("init_key can only be of type str.") + raise ValueError( + "init_key can only be of type str as it will be the leaf key associated to each reset flag." 
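[Editorial aside, not part of the patch] The new ``Crop`` transform registered above composes like the other observation transforms; a sketch, relying on torchvision and assuming a pixel-based Gym environment is available locally:

    from torchrl.envs import Compose, Crop, GymEnv, ToTensorImage, TransformedEnv

    env = TransformedEnv(
        GymEnv("CartPole-v1", from_pixels=True),
        Compose(ToTensorImage(), Crop(w=160)),   # square 160x160 crop from the top-left corner
    )
    td = env.reset()
    print(td["pixels"].shape)   # torch.Size([3, 160, 160])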
+ ) self.init_key = init_key - self.reset_key = "_reset" super().__init__() def set_container(self, container: Union[Transform, EnvBase]) -> None: diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 61c210acffa..ee7649fabe4 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -869,11 +869,13 @@ def _sort_keys(element): return element -def make_composite_from_td(data): +def make_composite_from_td(data, unsqueeze_null_shapes: bool = True): """Creates a CompositeSpec instance from a tensordict, assuming all values are unbounded. Args: data (tensordict.TensorDict): a tensordict to be mapped onto a CompositeSpec. + unsqueeze_null_shapes (bool, optional): if ``True``, every empty shape will be + unsqueezed to (1,). Defaults to ``True``. Examples: >>> from tensordict import TensorDict @@ -905,7 +907,9 @@ def make_composite_from_td(data): else UnboundedContinuousTensorSpec( dtype=tensor.dtype, device=tensor.device, - shape=tensor.shape if tensor.shape else [1], + shape=tensor.shape + if tensor.shape or not unsqueeze_null_shapes + else [1], ) for key, tensor in data.items() }, diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 4a3c5e716e8..0a06e5844a0 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -11,6 +11,7 @@ IndependentNormal, MaskedCategorical, MaskedOneHotCategorical, + NormalParamExtractor, NormalParamWrapper, OneHotCategorical, ReparamGradientStrategy, @@ -32,6 +33,7 @@ MLP, MultiAgentConvNet, MultiAgentMLP, + MultiAgentNetBase, NoisyLazyLinear, NoisyLinear, ObsDecoder, @@ -51,6 +53,7 @@ ActorCriticOperator, ActorCriticWrapper, ActorValueOperator, + AdditiveGaussianModule, AdditiveGaussianWrapper, DecisionTransformerInferenceWrapper, DistributionalQValueActor, @@ -66,6 +69,7 @@ LSTMCell, LSTMModule, MultiStepActorWrapper, + OrnsteinUhlenbeckProcessModule, OrnsteinUhlenbeckProcessWrapper, ProbabilisticActor, QValueActor, diff --git a/torchrl/modules/distributions/__init__.py b/torchrl/modules/distributions/__init__.py index a3c5d0d4774..367765812bb 100644 --- a/torchrl/modules/distributions/__init__.py +++ b/torchrl/modules/distributions/__init__.py @@ -3,8 +3,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
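[Editorial aside, not part of the patch] The new ``unsqueeze_null_shapes`` switch of ``make_composite_from_td`` can be exercised as follows:

    import torch
    from tensordict import TensorDict
    from torchrl.envs.utils import make_composite_from_td

    td = TensorDict({"scalar": torch.tensor(1.0), "vec": torch.zeros(3)}, [])
    spec = make_composite_from_td(td, unsqueeze_null_shapes=False)
    print(spec["scalar"].shape, spec["vec"].shape)   # torch.Size([]) torch.Size([3])
    # with the default unsqueeze_null_shapes=True, "scalar" would be given shape torch.Size([1])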
+from tensordict.nn import NormalParamExtractor + from .continuous import ( - __all__ as _all_continuous, Delta, IndependentNormal, NormalParamWrapper, @@ -13,7 +14,6 @@ TruncatedNormal, ) from .discrete import ( - __all__ as _all_discrete, MaskedCategorical, MaskedOneHotCategorical, OneHotCategorical, @@ -21,6 +21,15 @@ ) distributions_maps = { - distribution_class.lower(): eval(distribution_class) - for distribution_class in _all_continuous + _all_discrete + str(dist).lower(): dist + for dist in ( + Delta, + IndependentNormal, + TanhDelta, + TanhNormal, + TruncatedNormal, + MaskedCategorical, + MaskedOneHotCategorical, + OneHotCategorical, + ) } diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 38d8d1dfd02..fddc2f3415d 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -25,15 +25,6 @@ ) from torchrl.modules.utils import mappings -__all__ = [ - "NormalParamWrapper", - "TanhNormal", - "Delta", - "TanhDelta", - "TruncatedNormal", - "IndependentNormal", -] - # speeds up distribution construction D.Distribution.set_default_validate_args(False) @@ -153,6 +144,10 @@ def __init__( scale_mapping: str = "biased_softplus_1.0", scale_lb: Number = 1e-4, ) -> None: + warnings.warn( + "The NormalParamWrapper class will be deprecated in v0.7 in favor of :class:`~tensordict.nn.NormalParamExtractor`.", + category=DeprecationWarning, + ) super().__init__() self.operator = operator self.scale_mapping = scale_mapping @@ -759,7 +754,10 @@ def mean(self) -> torch.Tensor: raise AttributeError("TanhDelta mean has not analytical form.") -def uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor: +def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor: if size is None: size = torch.Size([]) return torch.randn_like(dist.sample(size)) + + +uniform_sample_delta = _uniform_sample_delta diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 62ccf53c30a..9a814e35477 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -32,5 +32,11 @@ MLP, OnlineDTActor, ) -from .multiagent import MultiAgentConvNet, MultiAgentMLP, QMixer, VDNMixer +from .multiagent import ( + MultiAgentConvNet, + MultiAgentMLP, + MultiAgentNetBase, + QMixer, + VDNMixer, +) from .utils import Squeeze2dLayer, SqueezeLayer diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index c44042388a5..6ccc4721678 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -5,6 +5,7 @@ from __future__ import annotations import abc +from copy import deepcopy from textwrap import indent from typing import Optional, Sequence, Tuple, Type, Union @@ -21,7 +22,13 @@ class MultiAgentNetBase(nn.Module): - """A base class for multi-agent networks.""" + """A base class for multi-agent networks. + + .. note:: to initialize the MARL module parameters with the `torch.nn.init` + module, please refer to :meth:`~.get_stateful_net` and :meth:`~.from_stateful_net` + methods. + + """ _empty_net: nn.Module @@ -63,8 +70,23 @@ def __init__( break self.initialized = initialized self._make_params(agent_networks) + # We make sure all params and buffers are on 'meta' device + # To do this, we set the device keyword arg to 'meta', we also temporarily change + # the default device. Finally, we convert all params to 'meta' tensors that are not params. 
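[Editorial aside, not part of the patch] The ``NormalParamWrapper`` deprecation above points to ``tensordict.nn.NormalParamExtractor``; the replacement pattern, sketched:

    import torch
    from torch import nn
    from tensordict.nn import NormalParamExtractor

    # old: NormalParamWrapper(nn.Linear(4, 8)); new: a plain Sequential followed by the extractor
    net = nn.Sequential(nn.Linear(4, 8), NormalParamExtractor())
    loc, scale = net(torch.randn(2, 4))
    print(loc.shape, scale.shape)   # torch.Size([2, 4]) torch.Size([2, 4])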
kwargs["device"] = "meta" - self.__dict__["_empty_net"] = self._build_single_net(**kwargs) + with torch.device("meta"): + try: + self._empty_net = self._build_single_net(**kwargs) + except NotImplementedError as err: + if "Cannot copy out of meta tensor" in str(err): + raise RuntimeError( + "The network was built using `factory().to(device), build the network directly " + "on device using `factory(device=device)` instead." + ) + # Remove all parameters + TensorDict.from_module(self._empty_net).data.to("meta").to_module( + self._empty_net + ) @property def vmap_randomness(self): @@ -142,8 +164,83 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: return output + def get_stateful_net(self, copy: bool = True): + """Returns a stateful version of the network. + + This can be used to initialize parameters. + + Such networks will often not be callable out-of-the-box and will require a `vmap` call + to be executable. + + Args: + copy (bool, optional): if ``True``, a deepcopy of the network is made. + Defaults to ``True``. + + If the parameters are modified in-place (recommended) there is no need to copy the + parameters back into the MARL module. + See :meth:`~.from_stateful_net` for details on how to re-populate the MARL model with + parameters that have been re-initialized out-of-place. + + Examples: + >>> from torchrl.modules import MultiAgentMLP + >>> import torch + >>> n_agents = 6 + >>> n_agent_inputs=3 + >>> n_agent_outputs=2 + >>> batch = 64 + >>> obs = torch.zeros(batch, n_agents, n_agent_inputs) + >>> mlp = MultiAgentMLP( + ... n_agent_inputs=n_agent_inputs, + ... n_agent_outputs=n_agent_outputs, + ... n_agents=n_agents, + ... centralized=False, + ... share_params=False, + ... depth=2, + ... ) + >>> snet = mlp.get_stateful_net() + >>> def init(module): + ... if hasattr(module, "weight"): + ... torch.nn.init.kaiming_normal_(module.weight) + >>> snet.apply(init) + >>> # If the module has been updated out-of-place (not the case here) we can reset the params + >>> mlp.from_stateful_net(snet) + + """ + if copy: + try: + net = deepcopy(self._empty_net) + except RuntimeError as err: + raise RuntimeError( + "Failed to deepcopy the module, consider using copy=False." + ) from err + else: + net = self._empty_net + self.params.to_module(net) + return net + + def from_stateful_net(self, stateful_net: nn.Module): + """Populates the parameters given a stateful version of the network. + + See :meth:`~.get_stateful_net` for details on how to gather a stateful version of the network. + + Args: + stateful_net (nn.Module): the stateful network from which the params should be + gathered. + + """ + params = TensorDict.from_module(stateful_net, as_module=True) + keyset0 = set(params.keys(True, True)) + keyset1 = set(self.params.keys(True, True)) + if keyset0 != keyset1: + raise RuntimeError( + f"The keys of params and provided module differ: " + f"{keyset1-keyset0} are in self.params and not in the module, " + f"{keyset0-keyset1} are in the module but not in self.params." + ) + self.params.data.update_(params.data) + def __repr__(self): - empty_net = self.__dict__["_empty_net"] + empty_net = self._empty_net with self.params.to_module(empty_net): module_repr = indent(str(empty_net), 4 * " ") n_agents = indent(f"n_agents={self.n_agents}", 4 * " ") @@ -212,6 +309,10 @@ class MultiAgentMLP(MultiAgentNetBase): default: nn.Tanh. **kwargs: for :class:`torchrl.modules.models.MLP` can be passed to customize the MLPs. + .. 
note:: to initialize the MARL module parameters with the `torch.nn.init` + module, please refer to :meth:`~.get_stateful_net` and :meth:`~.from_stateful_net` + methods. + Examples: >>> from torchrl.modules import MultiAgentMLP >>> import torch @@ -219,8 +320,8 @@ class MultiAgentMLP(MultiAgentNetBase): >>> n_agent_inputs=3 >>> n_agent_outputs=2 >>> batch = 64 - >>> obs = torch.zeros(batch, n_agents, n_agent_inputs - First let's instantiate a local network shared by all agents (e.g. a parameter-shared policy) + >>> obs = torch.zeros(batch, n_agents, n_agent_inputs) + >>> # instantiate a local network shared by all agents (e.g. a parameter-shared policy) >>> mlp = MultiAgentMLP( ... n_agent_inputs=n_agent_inputs, ... n_agent_outputs=n_agent_outputs, @@ -357,6 +458,10 @@ class MultiAgentConvNet(MultiAgentNetBase): It expects inputs with shape ``(*B, n_agents, channels, x, y)``. + .. note:: to initialize the MARL module parameters with the `torch.nn.init` + module, please refer to :meth:`~.get_stateful_net` and :meth:`~.from_stateful_net` + methods. + Args: n_agents (int): number of agents. centralized (bool): If ``True``, each agent will use the inputs of all agents to compute its output, resulting in input of shape ``(*B, n_agents * channels, x, y)``. Otherwise, each agent will only use its data as input. @@ -388,7 +493,7 @@ class MultiAgentConvNet(MultiAgentNetBase): >>> n_agents = 7 >>> channels, x, y = 3, 100, 100 >>> obs = torch.randn(*batch, n_agents, channels, x, y) - >>> # First lets consider a centralized network with shared parameters. + >>> # Let's consider a centralized network with shared parameters. >>> cnn = MultiAgentConvNet( ... n_agents, ... centralized = True, diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index 98dfcf80f3b..202f84fd173 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -23,9 +23,11 @@ ) from .common import SafeModule, VmapModule from .exploration import ( + AdditiveGaussianModule, AdditiveGaussianWrapper, EGreedyModule, EGreedyWrapper, + OrnsteinUhlenbeckProcessModule, OrnsteinUhlenbeckProcessWrapper, ) from .probabilistic import ( diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 83b6a8d1fb3..81b7ec1e605 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -206,11 +206,11 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): >>> from tensordict import TensorDict >>> from tensordict.nn import TensorDictModule >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules import ProbabilisticActor, NormalParamWrapper, TanhNormal + >>> from torchrl.modules import ProbabilisticActor, NormalParamExtractor, TanhNormal >>> td = TensorDict({"observation": torch.randn(3, 4)}, [3,]) >>> action_spec = BoundedTensorSpec(shape=torch.Size([4]), ... low=-1, high=1) - >>> module = NormalParamWrapper(torch.nn.Linear(4, 8)) + >>> module = nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()) >>> tensordict_module = TensorDictModule(module, in_keys=["observation"], out_keys=["loc", "scale"]) >>> td_module = ProbabilisticActor( ... 
module=tensordict_module, @@ -1379,7 +1379,7 @@ class ActorValueOperator(SafeSequential): >>> import torch >>> from tensordict import TensorDict >>> from torchrl.modules import ProbabilisticActor, SafeModule - >>> from torchrl.modules import ValueOperator, TanhNormal, ActorValueOperator, NormalParamWrapper + >>> from torchrl.modules import ValueOperator, TanhNormal, ActorValueOperator, NormalParamExtractor >>> module_hidden = torch.nn.Linear(4, 4) >>> td_module_hidden = SafeModule( ... module=module_hidden, @@ -1387,7 +1387,7 @@ class ActorValueOperator(SafeSequential): ... out_keys=["hidden"], ... ) >>> module_action = TensorDictModule( - ... NormalParamWrapper(torch.nn.Linear(4, 8)), + ... nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()), ... in_keys=["hidden"], ... out_keys=["loc", "scale"], ... ) @@ -1531,14 +1531,14 @@ class ActorCriticOperator(ActorValueOperator): >>> import torch >>> from tensordict import TensorDict >>> from torchrl.modules import ProbabilisticActor - >>> from torchrl.modules import ValueOperator, TanhNormal, ActorCriticOperator, NormalParamWrapper, MLP + >>> from torchrl.modules import ValueOperator, TanhNormal, ActorCriticOperator, NormalParamExtractor, MLP >>> module_hidden = torch.nn.Linear(4, 4) >>> td_module_hidden = SafeModule( ... module=module_hidden, ... in_keys=["observation"], ... out_keys=["hidden"], ... ) - >>> module_action = NormalParamWrapper(torch.nn.Linear(4, 8)) + >>> module_action = nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()) >>> module_action = TensorDictModule(module_action, in_keys=["hidden"], out_keys=["loc", "scale"]) >>> td_module_action = ProbabilisticActor( ... module=module_action, @@ -1677,12 +1677,12 @@ class ActorCriticWrapper(SafeSequential): >>> from torchrl.modules import ( ... ActorCriticWrapper, ... ProbabilisticActor, - ... NormalParamWrapper, + ... NormalParamExtractor, ... TanhNormal, ... ValueOperator, ... ) >>> action_module = TensorDictModule( - ... NormalParamWrapper(torch.nn.Linear(4, 8)), + ... nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()), ... in_keys=["observation"], ... out_keys=["loc", "scale"], ... ) diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index e8a1e94698d..5a41f11bf76 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -23,7 +23,9 @@ __all__ = [ "EGreedyWrapper", "EGreedyModule", + "AdditiveGaussianModule", "AdditiveGaussianWrapper", + "OrnsteinUhlenbeckProcessModule", "OrnsteinUhlenbeckProcessWrapper", ] @@ -299,6 +301,12 @@ def __init__( spec: Optional[TensorSpec] = None, safe: Optional[bool] = True, ): + warnings.warn( + "AdditiveGaussianWrapper is deprecated and will be removed " + "in v0.7. Please use torchrl.modules.AdditiveGaussianModule " + "instead.", + category=DeprecationWarning, + ) super().__init__(policy) if sigma_end > sigma_init: raise RuntimeError("sigma should decrease over time or be constant") @@ -382,6 +390,117 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict +class AdditiveGaussianModule(TensorDictModuleBase): + """Additive Gaussian PO module. + + Args: + spec (TensorSpec): the spec used for sampling actions. The sampled + action will be projected onto the valid action space once explored. + sigma_init (scalar, optional): initial epsilon value. + default: 1.0 + sigma_end (scalar, optional): final epsilon value. 
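[Editorial aside, not part of the patch] A usage sketch of the module-style exploration API, mirroring the Ornstein-Uhlenbeck example further below; the hyperparameter values are illustrative only:

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictSequential
    from torchrl.data import BoundedTensorSpec
    from torchrl.modules import Actor, AdditiveGaussianModule

    spec = BoundedTensorSpec(-1, 1, torch.Size([4]))
    policy = Actor(module=torch.nn.Linear(4, 4, bias=False), spec=spec)
    noise = AdditiveGaussianModule(spec=spec, sigma_init=1.0, sigma_end=0.1, annealing_num_steps=1000)
    exploration_policy = TensorDictSequential(policy, noise)
    td = exploration_policy(TensorDict({"observation": torch.zeros(10, 4)}, [10]))
    noise.step(frames=10)   # decay sigma in the training loop, e.g. once per collected batch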
+ default: 0.1 + annealing_num_steps (int, optional): number of steps it will take for + sigma to reach the :obj:`sigma_end` value. + default: 1000 + mean (float, optional): mean of each output element’s normal distribution. + default: 0.0 + std (float, optional): standard deviation of each output element’s normal distribution. + default: 1.0 + + Keyword Args: + action_key (NestedKey, optional): if the policy module has more than one output key, + its output spec will be of type CompositeSpec. One needs to know where to + find the action spec. + default: "action" + + .. note:: + It is + crucial to incorporate a call to :meth:`~.step` in the training loop + to update the exploration factor. + Since it is not easy to capture this omission no warning or exception + will be raised if this is ommitted! + + + """ + + def __init__( + self, + spec: TensorSpec, + sigma_init: float = 1.0, + sigma_end: float = 0.1, + annealing_num_steps: int = 1000, + mean: float = 0.0, + std: float = 1.0, + *, + action_key: Optional[NestedKey] = "action", + ): + if not isinstance(sigma_init, float): + warnings.warn("eps_init should be a float.") + if sigma_end > sigma_init: + raise RuntimeError("sigma should decrease over time or be constant") + self.action_key = action_key + self.in_keys = [self.action_key] + self.out_keys = [self.action_key] + + super().__init__() + + self.register_buffer("sigma_init", torch.tensor([sigma_init])) + self.register_buffer("sigma_end", torch.tensor([sigma_end])) + self.annealing_num_steps = annealing_num_steps + self.register_buffer("mean", torch.tensor([mean])) + self.register_buffer("std", torch.tensor([std])) + self.register_buffer("sigma", torch.tensor([sigma_init], dtype=torch.float32)) + + if spec is not None: + if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: + spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) + else: + raise RuntimeError("spec cannot be None.") + self._spec = spec + self.register_forward_hook(_forward_hook_safe_action) + + @property + def spec(self): + return self._spec + + def step(self, frames: int = 1) -> None: + """A step of sigma decay. + + After `self.annealing_num_steps` calls to this method, calls result in no-op. + + Args: + frames (int): number of frames since last step. Defaults to ``1``. + + """ + for _ in range(frames): + self.sigma.data[0] = max( + self.sigma_end.item(), + ( + self.sigma + - (self.sigma_init - self.sigma_end) / self.annealing_num_steps + ).item(), + ) + + def _add_noise(self, action: torch.Tensor) -> torch.Tensor: + sigma = self.sigma.item() + noise = torch.normal( + mean=torch.ones(action.shape) * self.mean.item(), + std=torch.ones(action.shape) * self.std.item(), + ).to(action.device) + action = action + noise * sigma + spec = self.spec[self.action_key] + action = spec.project(action) + return action + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + if exploration_type() is ExplorationType.RANDOM or exploration_type() is None: + out = tensordict.get(self.action_key) + out = self._add_noise(out) + tensordict.set(self.action_key, out) + return tensordict + + class OrnsteinUhlenbeckProcessWrapper(TensorDictModuleWrapper): r"""Ornstein-Uhlenbeck exploration policy wrapper. @@ -491,6 +610,12 @@ def __init__( safe: bool = True, key: Optional[NestedKey] = None, ): + warnings.warn( + "OrnsteinUhlenbeckProcessWrapper is deprecated and will be removed " + "in v0.7. 
Please use torchrl.modules.OrnsteinUhlenbeckProcessModule " + "instead.", + category=DeprecationWarning, + ) if key is not None: action_key = key warnings.warn( @@ -593,6 +718,199 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: return tensordict +class OrnsteinUhlenbeckProcessModule(TensorDictModuleBase): + r"""Ornstein-Uhlenbeck exploration policy module. + + Presented in "CONTINUOUS CONTROL WITH DEEP REINFORCEMENT LEARNING", https://arxiv.org/pdf/1509.02971.pdf. + + The OU exploration is to be used with continuous control policies and introduces a auto-correlated exploration + noise. This enables a sort of 'structured' exploration. + + Noise equation: + + .. math:: + noise_t = noise_{t-1} + \theta * (mu - noise_{t-1}) * dt + \sigma_t * \sqrt{dt} * W + + Sigma equation: + + .. math:: + \sigma_t = max(\sigma^{min, (-(\sigma_{t-1} - \sigma^{min}) / (n^{\text{steps annealing}}) * n^{\text{steps}} + \sigma)) + + To keep track of the steps and noise from sample to sample, an :obj:`"ou_prev_noise{id}"` and :obj:`"ou_steps{id}"` keys + will be written in the input/output tensordict. It is expected that the tensordict will be zeroed at reset, + indicating that a new trajectory is being collected. If not, and is the same tensordict is used for consecutive + trajectories, the step count will keep on increasing across rollouts. Note that the collector classes take care of + zeroing the tensordict at reset time. + + .. note:: + It is + crucial to incorporate a call to :meth:`~.step` in the training loop + to update the exploration factor. + Since it is not easy to capture this omission no warning or exception + will be raised if this is ommitted! + + Args: + spec (TensorSpec): the spec used for sampling actions. The sampled + action will be projected onto the valid action space once explored. + eps_init (scalar): initial epsilon value, determining the amount of noise to be added. + default: 1.0 + eps_end (scalar): final epsilon value, determining the amount of noise to be added. + default: 0.1 + annealing_num_steps (int): number of steps it will take for epsilon to reach the eps_end value. + default: 1000 + theta (scalar): theta factor in the noise equation + default: 0.15 + mu (scalar): OU average (mu in the noise equation). + default: 0.0 + sigma (scalar): sigma value in the sigma equation. + default: 0.2 + dt (scalar): dt in the noise equation. + default: 0.01 + x0 (Tensor, ndarray, optional): initial value of the process. + default: 0.0 + sigma_min (number, optional): sigma_min in the sigma equation. + default: None + n_steps_annealing (int): number of steps for the sigma annealing. + default: 1000 + + Keyword Args: + action_key (NestedKey, optional): key of the action to be modified. + default: "action" + is_init_key (NestedKey, optional): key where to find the is_init flag used to reset the noise steps. 
+ default: "is_init" + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictSequential + >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.modules import OrnsteinUhlenbeckProcessModule, Actor + >>> torch.manual_seed(0) + >>> spec = BoundedTensorSpec(-1, 1, torch.Size([4])) + >>> module = torch.nn.Linear(4, 4, bias=False) + >>> policy = Actor(module=module, spec=spec) + >>> ou = OrnsteinUhlenbeckProcessModule(spec=spec) + >>> explorative_policy = TensorDictSequential(policy, ou) + >>> td = TensorDict({"observation": torch.zeros(10, 4)}, batch_size=[10]) + >>> print(explorative_policy(td)) + TensorDict( + fields={ + _ou_prev_noise: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False), + _ou_steps: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False), + action: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False), + observation: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([10]), + device=None, + is_shared=False) + """ + + def __init__( + self, + spec: TensorSpec, + eps_init: float = 1.0, + eps_end: float = 0.1, + annealing_num_steps: int = 1000, + theta: float = 0.15, + mu: float = 0.0, + sigma: float = 0.2, + dt: float = 1e-2, + x0: Optional[Union[torch.Tensor, np.ndarray]] = None, + sigma_min: Optional[float] = None, + n_steps_annealing: int = 1000, + *, + action_key: Optional[NestedKey] = "action", + is_init_key: Optional[NestedKey] = "is_init", + ): + super().__init__() + + self.ou = _OrnsteinUhlenbeckProcess( + theta=theta, + mu=mu, + sigma=sigma, + dt=dt, + x0=x0, + sigma_min=sigma_min, + n_steps_annealing=n_steps_annealing, + key=action_key, + ) + + self.register_buffer("eps_init", torch.tensor([eps_init])) + self.register_buffer("eps_end", torch.tensor([eps_end])) + if self.eps_end > self.eps_init: + raise ValueError( + "eps should decrease over time or be constant, " + f"got eps_init={eps_init} and eps_end={eps_end}" + ) + self.annealing_num_steps = annealing_num_steps + self.register_buffer("eps", torch.tensor([eps_init], dtype=torch.float32)) + + self.in_keys = [self.ou.key] + self.out_keys = [self.ou.key] + self.ou.out_keys + self.is_init_key = is_init_key + noise_key = self.ou.noise_key + steps_key = self.ou.steps_key + + if spec is not None: + if not isinstance(spec, CompositeSpec) and len(self.out_keys) >= 1: + spec = CompositeSpec({action_key: spec}, shape=spec.shape[:-1]) + self._spec = spec + else: + raise RuntimeError("spec cannot be None.") + ou_specs = { + noise_key: None, + steps_key: None, + } + self._spec.update(ou_specs) + if len(set(self.out_keys)) != len(self.out_keys): + raise RuntimeError(f"Got multiple identical output keys: {self.out_keys}") + self.register_forward_hook(_forward_hook_safe_action) + + @property + def spec(self): + return self._spec + + def step(self, frames: int = 1) -> None: + """Updates the eps noise factor. + + Args: + frames (int): number of frames of the current batch (corresponding to the number of updates to be made). + + """ + for _ in range(frames): + if self.annealing_num_steps > 0: + self.eps.data[0] = max( + self.eps_end.item(), + ( + self.eps + - (self.eps_init - self.eps_end) / self.annealing_num_steps + ).item(), + ) + else: + raise ValueError( + f"{self.__class__.__name__}.step() called when " + f"self.annealing_num_steps={self.annealing_num_steps}. Expected a strictly positive " + f"number of frames." 
+ ) + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + if exploration_type() == ExplorationType.RANDOM or exploration_type() is None: + is_init = tensordict.get(self.is_init_key, None) + if is_init is None: + warnings.warn( + f"The tensordict passed to {self.__class__.__name__} appears to be " + f"missing the '{self.is_init_key}' entry. This entry is used to " + f"reset the noise at the beginning of a trajectory, without it " + f"the behaviour of this exploration method is undefined. " + f"This is allowed for BC compatibility purposes but it will be deprecated soon! " + f"To create a '{self.is_init_key}' entry, simply append an torchrl.envs.InitTracker " + f"transform to your environment with `env = TransformedEnv(env, InitTracker())`." + ) + tensordict = self.ou.add_sample( + tensordict, self.eps.item(), is_init=is_init + ) + return tensordict + + # Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab class _OrnsteinUhlenbeckProcess: def __init__( diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index f0290d6a42f..878fb13ebb8 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -381,6 +381,8 @@ class LSTMModule(ModuleBase): Methods: set_recurrent_mode: controls whether the module should be executed in recurrent mode. + make_tensordict_primer: creates the TensorDictPrimer transforms for the environment to be aware of the + recurrent states of the RNN. .. note:: This module relies on specific ``recurrent_state`` keys being present in the input TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically @@ -521,6 +523,45 @@ def __init__( self._recurrent_mode = False def make_tensordict_primer(self): + """Makes a tensordict primer for the environment. + + A :class:`~torchrl.envs.TensorDictPrimer` object will ensure that the policy is aware of the supplementary + inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across + processes and dealt with properly. + + Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviours, for instance + in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root + tensordict, which the meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states + are not registered within the environment specs. + + Examples: + >>> from torchrl.collectors import SyncDataCollector + >>> from torchrl.envs import TransformedEnv, InitTracker + >>> from torchrl.envs import GymEnv + >>> from torchrl.modules import MLP, LSTMModule + >>> from torch import nn + >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod + >>> + >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) + >>> lstm_module = LSTMModule( + ... input_size=env.observation_spec["observation"].shape[-1], + ... hidden_size=64, + ... in_keys=["observation", "rs_h", "rs_c"], + ... out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")]) + >>> mlp = MLP(num_cells=[64], out_features=1) + >>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) + >>> policy(env.reset()) + >>> env = env.append_transform(lstm_module.make_tensordict_primer()) + >>> data_collector = SyncDataCollector( + ... env, + ... policy, + ... frames_per_batch=10 + ... ) + >>> for data in data_collector: + ... print(data) + ... 
break + + """ from torchrl.envs.transforms.transforms import TensorDictPrimer def make_tuple(key): @@ -1065,6 +1106,8 @@ class GRUModule(ModuleBase): Methods: set_recurrent_mode: controls whether the module should be executed in recurrent mode. + make_tensordict_primer: creates the TensorDictPrimer transforms for the environment to be aware of the + recurrent states of the RNN. .. note:: This module relies on specific ``recurrent_state`` keys being present in the input TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically @@ -1230,6 +1273,45 @@ def __init__( self._recurrent_mode = False def make_tensordict_primer(self): + """Makes a tensordict primer for the environment. + + A :class:`~torchrl.envs.TensorDictPrimer` object will ensure that the policy is aware of the supplementary + inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across + processes and dealt with properly. + + Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviours, for instance + in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root + tensordict, which the meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states + are not registered within the environment specs. + + Examples: + >>> from torchrl.collectors import SyncDataCollector + >>> from torchrl.envs import TransformedEnv, InitTracker + >>> from torchrl.envs import GymEnv + >>> from torchrl.modules import MLP, LSTMModule + >>> from torch import nn + >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod + >>> + >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker()) + >>> gru_module = GRUModule( + ... input_size=env.observation_spec["observation"].shape[-1], + ... hidden_size=64, + ... in_keys=["observation", "rs"], + ... out_keys=["intermediate", ("next", "rs")]) + >>> mlp = MLP(num_cells=[64], out_features=1) + >>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"])) + >>> policy(env.reset()) + >>> env = env.append_transform(gru_module.make_tensordict_primer()) + >>> data_collector = SyncDataCollector( + ... env, + ... policy, + ... frames_per_batch=10 + ... ) + >>> for data in data_collector: + ... print(data) + ... 
break + + """ from torchrl.envs import TensorDictPrimer def make_tuple(key): diff --git a/torchrl/modules/tensordict_module/sequence.py b/torchrl/modules/tensordict_module/sequence.py index 28f721ba6a1..41ddb55fb35 100644 --- a/torchrl/modules/tensordict_module/sequence.py +++ b/torchrl/modules/tensordict_module/sequence.py @@ -34,11 +34,11 @@ class SafeSequential(TensorDictSequential, SafeModule): >>> import torch >>> from tensordict import TensorDict >>> from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec - >>> from torchrl.modules import TanhNormal, SafeSequential, TensorDictModule, NormalParamWrapper + >>> from torchrl.modules import TanhNormal, SafeSequential, TensorDictModule, NormalParamExtractor >>> from torchrl.modules.tensordict_module import SafeProbabilisticModule >>> td = TensorDict({"input": torch.randn(3, 4)}, [3,]) >>> spec1 = CompositeSpec(hidden=UnboundedContinuousTensorSpec(4), loc=None, scale=None) - >>> net1 = NormalParamWrapper(torch.nn.Linear(4, 8)) + >>> net1 = nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()) >>> module1 = TensorDictModule(net1, in_keys=["input"], out_keys=["loc", "scale"]) >>> td_module1 = SafeProbabilisticModule( ... module=module1, diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 1471cde5141..4a0948e1bca 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -62,7 +62,7 @@ class A2CLoss(LossModule): Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. advantage_key (str): [Deprecated, use set_keys(advantage_key=advantage_key) instead] The input tensordict key where the advantage is expected to be written. default: "advantage" @@ -97,14 +97,14 @@ class A2CLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.a2c import A2CLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... 
module=module, @@ -148,14 +148,14 @@ class A2CLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.a2c import A2CLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... module=module, diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index cbfc218327d..5ceec84e36a 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -8,10 +8,10 @@ import abc import functools import warnings +from copy import deepcopy from dataclasses import dataclass from typing import Iterator, List, Optional, Tuple -import torch from tensordict import is_tensor_collection, TensorDict, TensorDictBase from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams @@ -138,6 +138,67 @@ def __init__(self): self._tensor_keys = self._AcceptedKeys() self.register_forward_pre_hook(_updater_check_forward_prehook) + @property + def functional(self): + """Whether the module is functional. + + Unless it has been specifically designed not to be functional, all losses are functional. + """ + return True + + def get_stateful_net(self, network_name: str, copy: bool | None = None): + """Returns a stateful version of the network. + + This can be used to initialize parameters. + + Such networks will often not be callable out-of-the-box and will require a `vmap` call + to be executable. + + Args: + network_name (str): the network name to gather. + copy (bool, optional): if ``True``, a deepcopy of the network is made. + Defaults to ``True``. + + .. note:: if the module is not functional, no copy is made. + """ + net = getattr(self, network_name) + if not self.functional: + if copy is not None and copy: + raise RuntimeError("Cannot copy module in non-functional mode.") + return net + copy = True if copy is None else copy + if copy: + net = deepcopy(net) + params = getattr(self, network_name + "_params") + params.to_module(net) + return net + + def from_stateful_net(self, network_name: str, stateful_net: nn.Module): + """Populates the parameters of a model given a stateful version of the network. + + See :meth:`~.get_stateful_net` for details on how to gather a stateful version of the network. + + Args: + network_name (str): the network name to reset. + stateful_net (nn.Module): the stateful network from which the params should be + gathered. 
+ + """ + if not self.functional: + getattr(self, network_name).load_state_dict(stateful_net.state_dict()) + return + params = TensorDict.from_module(stateful_net, as_module=True) + keyset0 = set(params.keys(True, True)) + self_params = getattr(self, network_name + "_params") + keyset1 = set(self_params.keys(True, True)) + if keyset0 != keyset1: + raise RuntimeError( + f"The keys of params and provided module differ: " + f"{keyset1-keyset0} are in self.params and not in the module, " + f"{keyset0-keyset1} are in the module but not in self.params." + ) + self_params.data.update_(params.data) + def _set_deprecated_ctor_keys(self, **kwargs) -> None: for key, value in kwargs.items(): if value is not None: @@ -255,57 +316,67 @@ def convert_to_functional( # Otherwise, casting the module to a device will keep old references # to uncast tensors sep = self.SEP - params = TensorDict.from_module(module, as_module=True) - - for key in params.keys(True): - if sep in key: - raise KeyError( - f"The key {key} contains the '_sep_' pattern which is prohibited. Consider renaming the parameter / buffer." + if isinstance(module, (list, tuple)): + if len(module) != expand_dim: + raise RuntimeError( + "The ``expand_dim`` value must match the length of the module list/tuple " + "if a single module isn't provided." ) - if compare_against is not None: - compare_against = set(compare_against) + params = TensorDict.from_modules( + *module, as_module=True, expand_identical=True + ) else: - compare_against = set() - if expand_dim: - # Expands the dims of params and buffers. - # If the param already exist in the module, we return a simple expansion of the - # original one. Otherwise, we expand and resample it. - # For buffers, a cloned expansion (or equivalently a repeat) is returned. - - def _compare_and_expand(param): - if is_tensor_collection(param): - return param._apply_nest( + params = TensorDict.from_module(module, as_module=True) + + for key in params.keys(True): + if sep in key: + raise KeyError( + f"The key {key} contains the '_sep_' pattern which is prohibited. Consider renaming the parameter / buffer." + ) + if compare_against is not None: + compare_against = set(compare_against) + else: + compare_against = set() + if expand_dim: + # Expands the dims of params and buffers. + # If the param already exist in the module, we return a simple expansion of the + # original one. Otherwise, we expand and resample it. + # For buffers, a cloned expansion (or equivalently a repeat) is returned. 
+ + def _compare_and_expand(param): + if is_tensor_collection(param): + return param._apply_nest( + _compare_and_expand, + batch_size=[expand_dim, *param.shape], + filter_empty=False, + call_on_nested=True, + ) + if not isinstance(param, nn.Parameter): + buffer = param.expand(expand_dim, *param.shape).clone() + return buffer + if param in compare_against: + expanded_param = param.data.expand(expand_dim, *param.shape) + # the expanded parameter must be sent to device when to() + # is called: + return expanded_param + else: + p_out = param.expand(expand_dim, *param.shape).clone() + p_out = nn.Parameter( + p_out.uniform_( + p_out.min().item(), p_out.max().item() + ).requires_grad_() + ) + return p_out + + params = TensorDictParams( + params.apply( _compare_and_expand, - batch_size=[expand_dim, *param.shape], + batch_size=[expand_dim, *params.shape], filter_empty=False, call_on_nested=True, - ) - if not isinstance(param, nn.Parameter): - buffer = param.expand(expand_dim, *param.shape).clone() - return buffer - if param in compare_against: - expanded_param = param.data.expand(expand_dim, *param.shape) - # the expanded parameter must be sent to device when to() - # is called: - return expanded_param - else: - p_out = param.expand(expand_dim, *param.shape).clone() - p_out = nn.Parameter( - p_out.uniform_( - p_out.min().item(), p_out.max().item() - ).requires_grad_() - ) - return p_out - - params = TensorDictParams( - params.apply( - _compare_and_expand, - batch_size=[expand_dim, *params.shape], - filter_empty=False, - call_on_nested=True, - ), - no_convert=True, - ) + ), + no_convert=True, + ) param_name = module_name + "_params" @@ -468,17 +539,34 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams @property def vmap_randomness(self): + """Vmap random mode. + + The vmap randomness mode controls what :func:`~torch.vmap` should do when dealing with + functions with a random outcome such as :func:`~torch.randn` and :func:`~torch.rand`. + If `"error"`, any random function will raise an exception indicating that `vmap` does not + know how to handle the random call. + + If `"different"`, every element of the batch along which vmap is being called will + behave differently. If `"same"`, vmaps will copy the same result across all elements. + + ``vmap_randomness`` defaults to `"error"` if no random module is detected, and to `"different"` in + other cases. By default, only a limited number of modules are listed as random, but the list can be extended + using the :func:`~torchrl.objectives.common.add_random_module` function. + + This property supports setting its value. 
+ + """ + if self._vmap_randomness is None: - do_break = False - for val in self.__dict__.values(): - if isinstance(val, torch.nn.Module): - for module in val.modules(): - if isinstance(module, RANDOM_MODULE_LIST): - self._vmap_randomness = "different" - do_break = True - break - if do_break: - # double break + main_modules = list(self.__dict__.values()) + list(self.children()) + modules = ( + module + for main_module in main_modules + if isinstance(main_module, nn.Module) + for module in main_module.modules() + ) + for val in modules: + if isinstance(val, RANDOM_MODULE_LIST): + self._vmap_randomness = "different" break else: self._vmap_randomness = "error" @@ -486,7 +574,12 @@ def vmap_randomness(self): return self._vmap_randomness def set_vmap_randomness(self, value): + if value not in ("error", "same", "different"): + raise ValueError( + "Wrong vmap randomness, should be one of 'error', 'same' or 'different'." + ) self._vmap_randomness = value + self._make_vmap() @staticmethod def _make_meta_params(param): @@ -498,6 +591,12 @@ def _make_meta_params(param): pd = nn.Parameter(pd, requires_grad=False) return pd + def _make_vmap(self): + """Caches the vmap callers to reduce the overhead at runtime.""" + raise NotImplementedError( + f"_make_vmap has been called but is not implemented for loss of type {type(self).__name__}." + ) + class _make_target_param: def __init__(self, clone): @@ -509,3 +608,9 @@ def __call__(self, x): x.data.clone() if self.clone else x.data, requires_grad=False ) return x.data.clone() if self.clone else x.data + + +def add_random_module(module): + """Adds a random module to the list of modules that will be detected by :meth:`~torchrl.objectives.LossModule.vmap_randomness` as random.""" + global RANDOM_MODULE_LIST + RANDOM_MODULE_LIST = RANDOM_MODULE_LIST + (module,) diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 98283b24ff7..f1e2aa9c532 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -9,7 +9,7 @@ from copy import deepcopy from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -46,8 +46,15 @@ class CQLLoss(LossModule): Args: actor_network (ProbabilisticActor): stochastic actor - qvalue_network (TensorDictModule): Q(s, a) parametric model. + qvalue_network (TensorDictModule or list of TensorDictModule): Q(s, a) parametric model. This module typically outputs a ``"state_action_value"`` entry. + If a single instance of `qvalue_network` is provided, it will be duplicated ``N`` + times (where ``N=2`` for this loss). If a list of modules is passed, their + parameters will be stacked unless they share the same identity (in which case + the original parameter will be expanded). + + .. warning:: When a list of parameters is passed, it will __not__ be compared against the policy parameters + and all the parameters will be considered as untied.
Keyword args: loss_function (str, optional): loss function to be used with @@ -94,14 +101,14 @@ class CQLLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.cql import CQLLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... module=module, @@ -154,14 +161,14 @@ class CQLLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.cql import CQLLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... module=module, @@ -266,7 +273,7 @@ class _AcceptedKeys: def __init__( self, actor_network: ProbabilisticActor, - qvalue_network: TensorDictModule, + qvalue_network: TensorDictModule | List[TensorDictModule], *, loss_function: str = "smooth_l1", alpha_init: float = 1.0, @@ -366,14 +373,16 @@ def __init__( "log_alpha_prime", torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)), ) + self._make_vmap() + self.reduction = reduction + def _make_vmap(self): self._vmap_qvalue_networkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) - self.reduction = reduction @property def target_entropy(self): diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 22d35bd5799..e76e3438c09 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -7,7 +7,7 @@ import math from dataclasses import dataclass from functools import wraps -from typing import Dict, Tuple, Union +from typing import Dict, List, Tuple, Union import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams @@ -54,6 +54,13 @@ class CrossQLoss(LossModule): actor_network (ProbabilisticActor): stochastic actor qvalue_network (TensorDictModule): Q(s, a) parametric model. This module typically outputs a ``"state_action_value"`` entry. + If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets`` + times. 
If a list of modules is passed, their + parameters will be stacked unless they share the same identity (in which case + the original parameter will be expanded). + + .. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters + and all the parameters will be considered as untied. Keyword Args: num_qvalue_nets (integer, optional): number of Q-Value networks used. @@ -81,7 +88,7 @@ class CrossQLoss(LossModule): priority (for prioritized replay buffer usage). Defaults to ``"td_error"``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, @@ -92,14 +99,14 @@ class CrossQLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.crossq import CrossQLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... module=module, @@ -150,14 +157,14 @@ class CrossQLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives import CrossQLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... 
module=module, @@ -248,7 +255,7 @@ class _AcceptedKeys: def __init__( self, actor_network: ProbabilisticActor, - qvalue_network: TensorDictModule, + qvalue_network: TensorDictModule | List[TensorDictModule], *, num_qvalue_nets: int = 2, loss_function: str = "smooth_l1", @@ -331,10 +338,13 @@ def __init__( self._target_entropy = target_entropy self._action_spec = action_spec + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) - self.reduction = reduction @property def target_entropy_buffer(self): diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 5ffbeaf029b..6e1cf0f5eb3 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -40,7 +40,7 @@ class DDPGLoss(LossModule): data collection. Default is ``True``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index dd2ac615b58..c1ed8b2cffe 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -7,7 +7,7 @@ import math from dataclasses import dataclass from numbers import Number -from typing import Tuple, Union +from typing import List, Tuple, Union import numpy as np import torch @@ -41,6 +41,13 @@ class REDQLoss_deprecated(LossModule): actor_network (TensorDictModule): the actor to be trained qvalue_network (TensorDictModule): a single Q-value network that will be multiplied as many times as needed. + If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets`` + times. If a list of modules is passed, their + parameters will be stacked unless they share the same identity (in which case + the original parameter will be expanded). + + .. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters + and all the parameters will be considered as untied. Keyword Args: num_qvalue_nets (int, optional): Number of Q-value networks to be trained. @@ -75,7 +82,7 @@ class REDQLoss_deprecated(LossModule): ``"td_error"``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. 
``"none"``: no reduction will be applied, @@ -134,7 +141,7 @@ class _AcceptedKeys: def __init__( self, actor_network: TensorDictModule, - qvalue_network: TensorDictModule, + qvalue_network: TensorDictModule | List[TensorDictModule], *, num_qvalue_nets: int = 10, sub_sample_len: int = 2, @@ -214,13 +221,15 @@ def __init__( self._action_spec = action_spec self.target_entropy_buffer = None self.gSDE = gSDE + self._make_vmap() self.reduction = reduction - self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0)) - if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + def _make_vmap(self): + self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0)) + @property def target_entropy(self): target_entropy = self.target_entropy_buffer diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index a60d010d480..74cfe504e78 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -6,7 +6,7 @@ import warnings from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams @@ -37,6 +37,14 @@ class IQLLoss(LossModule): Args: actor_network (ProbabilisticActor): stochastic actor qvalue_network (TensorDictModule): Q(s, a) parametric model + If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets`` + times. If a list of modules is passed, their + parameters will be stacked unless they share the same identity (in which case + the original parameter will be expanded). + + .. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters + and all the parameters will be considered as untied. + value_network (TensorDictModule, optional): V(s) parametric model. Keyword Args: @@ -55,7 +63,7 @@ class IQLLoss(LossModule): buffer usage). Default is `"td_error"`. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, @@ -66,14 +74,14 @@ class IQLLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.iql import IQLLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... 
module=module, @@ -129,14 +137,14 @@ class IQLLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.iql import IQLLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... module=module, @@ -247,7 +255,7 @@ class _AcceptedKeys: def __init__( self, actor_network: ProbabilisticActor, - qvalue_network: TensorDictModule, + qvalue_network: TensorDictModule | List[TensorDictModule], value_network: Optional[TensorDictModule], *, num_qvalue_nets: int = 2, @@ -310,10 +318,13 @@ def __init__( self.loss_function = loss_function if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qvalue_networkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) - self.reduction = reduction @property def device(self) -> torch.device: @@ -548,7 +559,7 @@ class DiscreteIQLLoss(IQLLoss): buffer usage). Default is `"td_error"`. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 08afc2a13f4..eb7f14f43c4 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -82,7 +82,7 @@ class PPOLoss(LossModule): before being used. Defaults to ``False``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. 
advantage_key (str, optional): [Deprecated, use set_keys(advantage_key=advantage_key) instead] The input tensordict key where the advantage is @@ -152,7 +152,7 @@ class PPOLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data.tensor_specs import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.ppo import PPOLoss @@ -160,7 +160,7 @@ class PPOLoss(LossModule): >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> base_layer = nn.Linear(n_obs, 5) - >>> net = NormalParamWrapper(nn.Sequential(base_layer, nn.Linear(5, 2 * n_act))) + >>> net = nn.Sequential(base_layer, nn.Linear(5, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... module=module, @@ -205,14 +205,14 @@ class PPOLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data.tensor_specs import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.ppo import PPOLoss >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> base_layer = nn.Linear(n_obs, 5) - >>> net = NormalParamWrapper(nn.Sequential(base_layer, nn.Linear(5, 2 * n_act))) + >>> net = nn.Sequential(base_layer, nn.Linear(5, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... module=module, @@ -657,7 +657,7 @@ class ClipPPOLoss(PPOLoss): before being used. Defaults to ``False``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. advantage_key (str, optional): [Deprecated, use set_keys(advantage_key=advantage_key) instead] The input tensordict key where the advantage is @@ -896,7 +896,7 @@ class KLPENPPOLoss(PPOLoss): before being used. Defaults to ``False``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. 
advantage_key (str, optional): [Deprecated, use set_keys(advantage_key=advantage_key) instead] The input tensordict key where the advantage is diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index 00e5c24f08c..1522fd7749e 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -7,7 +7,7 @@ import math from dataclasses import dataclass from numbers import Number -from typing import Union +from typing import List, Union import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams @@ -41,8 +41,14 @@ class REDQLoss(LossModule): Args: actor_network (TensorDictModule): the actor to be trained - qvalue_network (TensorDictModule): a single Q-value network that will - be multiplicated as many times as needed. + qvalue_network (TensorDictModule): a single Q-value network or a list of Q-value networks. + If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets`` + times. If a list of modules is passed, their + parameters will be stacked unless they share the same identity (in which case + the original parameter will be expanded). + + .. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters + and all the parameters will be considered as untied. Keyword Args: num_qvalue_nets (int, optional): Number of Q-value networks to be trained. @@ -77,7 +83,7 @@ class REDQLoss(LossModule): ``"td_error"``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, @@ -88,14 +94,14 @@ class REDQLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.redq import REDQLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... 
module=module, @@ -150,13 +156,13 @@ class REDQLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.redq import REDQLoss >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... module=module, @@ -250,7 +256,7 @@ class _AcceptedKeys: def __init__( self, actor_network: TensorDictModule, - qvalue_network: TensorDictModule, + qvalue_network: TensorDictModule | List[TensorDictModule], *, num_qvalue_nets: int = 10, sub_sample_len: int = 2, @@ -330,7 +336,9 @@ def __init__( self.gSDE = gSDE if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self._make_vmap() + def _make_vmap(self): self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index f32bea50d7e..3d867b8cb99 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -56,7 +56,7 @@ class ReinforceLoss(LossModule): value is expected to be written. Defaults to ``"value_target"``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. functional (bool, optional): whether modules should be functionalized. Functionalizing permits features like meta-RL, but makes it @@ -101,14 +101,14 @@ class ReinforceLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.reinforce import ReinforceLoss >>> from tensordict import TensorDict >>> n_obs, n_act = 3, 5 >>> value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"]) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor_net = ProbabilisticActor( ... 
module, @@ -147,13 +147,13 @@ class ReinforceLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.reinforce import ReinforceLoss >>> n_obs, n_act = 3, 5 >>> value_net = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"]) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor_net = ProbabilisticActor( ... module, diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 51017384dbe..df444eac053 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -9,7 +9,7 @@ from dataclasses import dataclass from functools import wraps from numbers import Number -from typing import Dict, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -57,6 +57,14 @@ class SACLoss(LossModule): actor_network (ProbabilisticActor): stochastic actor qvalue_network (TensorDictModule): Q(s, a) parametric model. This module typically outputs a ``"state_action_value"`` entry. + If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets`` + times. If a list of modules is passed, their + parameters will be stacked unless they share the same identity (in which case + the original parameter will be expanded). + + .. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters + and all the parameters will be considered as untied. + value_network (TensorDictModule, optional): V(s) parametric model. This module typically outputs a ``"state_value"`` entry. @@ -64,6 +72,7 @@ class SACLoss(LossModule): If not provided, the second version of SAC is assumed, where only the Q-Value network is needed. + Keyword Args: num_qvalue_nets (integer, optional): number of Q-Value networks used. Defaults to ``2``. loss_function (str, optional): loss function to be used with @@ -98,7 +107,7 @@ class SACLoss(LossModule): priority (for prioritized replay buffer usage). Defaults to ``"td_error"``. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. 
``"none"``: no reduction will be applied, @@ -109,14 +118,14 @@ class SACLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.sac import SACLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... module=module, @@ -172,14 +181,14 @@ class SACLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.sac import SACLoss >>> _ = torch.manual_seed(42) >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) >>> actor = ProbabilisticActor( ... module=module, @@ -280,7 +289,7 @@ class _AcceptedKeys: def __init__( self, actor_network: ProbabilisticActor, - qvalue_network: TensorDictModule, + qvalue_network: TensorDictModule | List[TensorDictModule], value_network: Optional[TensorDictModule] = None, *, num_qvalue_nets: int = 2, @@ -394,14 +403,17 @@ def __init__( ) if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) if self._version == 1: self._vmap_qnetwork00 = _vmap_func( - qvalue_network, randomness=self.vmap_randomness + self.qvalue_network, randomness=self.vmap_randomness ) - self.reduction = reduction @property def target_entropy_buffer(self): @@ -830,7 +842,7 @@ class DiscreteSACLoss(LossModule): Default is `"td_error"`. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. 
``"none"``: no reduction will be applied, @@ -841,8 +853,7 @@ class DiscreteSACLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper - >>> from torchrl.modules.distributions.discrete import OneHotCategorical + >>> from torchrl.modules.distributions import NormalParamExtractor, OneHotCategorical >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.sac import DiscreteSACLoss @@ -899,14 +910,13 @@ class DiscreteSACLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper - >>> from torchrl.modules.distributions.discrete import OneHotCategorical + >>> from torchrl.modules.distributions import NormalParamExtractor, OneHotCategorical >>> from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.sac import DiscreteSACLoss >>> n_act, n_obs = 4, 3 >>> spec = OneHotDiscreteTensorSpec(n_act) - >>> net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + >>> net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()) >>> module = SafeModule(net, in_keys=["observation"], out_keys=["logits"]) >>> actor = ProbabilisticActor( ... module=module, @@ -1092,10 +1102,13 @@ def __init__( self.register_buffer( "target_entropy", torch.tensor(target_entropy, device=device) ) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) - self.reduction = reduction def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index b569eb01345..eb1027ad936 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -5,7 +5,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, Tuple +from typing import List, Optional, Tuple import torch @@ -34,8 +34,15 @@ class TD3Loss(LossModule): Args: actor_network (TensorDictModule): the actor to be trained - qvalue_network (TensorDictModule): a single Q-value network that will - be multiplicated as many times as needed. + qvalue_network (TensorDictModule): a single Q-value network or a list of + Q-value networks. + If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets`` + times. If a list of modules is passed, their + parameters will be stacked unless they share the same identity (in which case + the original parameter will be expanded). + + .. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters + and all the parameters will be considered as untied. Keyword Args: bounds (tuple of float, optional): the bounds of the action space. @@ -66,7 +73,7 @@ class TD3Loss(LossModule): the actor. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. 
gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, @@ -77,7 +84,7 @@ class TD3Loss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.td3 import TD3Loss @@ -218,7 +225,7 @@ class _AcceptedKeys: def __init__( self, actor_network: TensorDictModule, - qvalue_network: TensorDictModule, + qvalue_network: TensorDictModule | List[TensorDictModule], *, action_spec: TensorSpec = None, bounds: Optional[Tuple[float]] = None, @@ -310,13 +317,16 @@ def __init__( self.register_buffer("min_action", low) if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) self._vmap_actor_network00 = _vmap_func( self.actor_network, randomness=self.vmap_randomness ) - self.reduction = reduction def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py index 93845bb00bd..aa87ea9aa1a 100644 --- a/torchrl/objectives/td3_bc.py +++ b/torchrl/objectives/td3_bc.py @@ -5,7 +5,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, Tuple +from typing import List, Optional, Tuple import torch @@ -43,8 +43,15 @@ class TD3BCLoss(LossModule): Args: actor_network (TensorDictModule): the actor to be trained - qvalue_network (TensorDictModule): a single Q-value network that will - be multiplicated as many times as needed. + qvalue_network (TensorDictModule): a single Q-value network or a list of + Q-value networks. + If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets`` + times. If a list of modules is passed, their + parameters will be stacked unless they share the same identity (in which case + the original parameter will be expanded). + + .. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters + and all the parameters will be considered as untied. Keyword Args: bounds (tuple of float, optional): the bounds of the action space. @@ -77,7 +84,7 @@ class TD3BCLoss(LossModule): the actor. separate_losses (bool, optional): if ``True``, shared parameters between policy and critic will only be trained on the policy loss. - Defaults to ``False``, ie. gradients are propagated to shared + Defaults to ``False``, i.e., gradients are propagated to shared parameters for both policy and critic losses. reduction (str, optional): Specifies the reduction to apply to the output: ``"none"`` | ``"mean"`` | ``"sum"``. 
``"none"``: no reduction will be applied, @@ -88,7 +95,7 @@ class TD3BCLoss(LossModule): >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec - >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal >>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule >>> from torchrl.objectives.td3_bc import TD3BCLoss @@ -233,7 +240,7 @@ class _AcceptedKeys: def __init__( self, actor_network: TensorDictModule, - qvalue_network: TensorDictModule, + qvalue_network: TensorDictModule | List[TensorDictModule], *, action_spec: TensorSpec = None, bounds: Optional[Tuple[float]] = None, @@ -324,13 +331,16 @@ def __init__( high = high.to(device) self.register_buffer("max_action", high) self.register_buffer("min_action", low) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) self._vmap_actor_network00 = _vmap_func( self.actor_network, randomness=self.vmap_randomness ) - self.reduction = reduction def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index b977a3440dd..b7db2e8242e 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -502,7 +502,7 @@ class TD0Estimator(ValueEstimatorBase): skip_existing (bool, optional): if ``True``, the value network will skip modules which outputs are already present in the tensordict. - Defaults to ``None``, ie. the value of :func:`tensordict.nn.skip_existing()` + Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()` is not affected. advantage_key (str or tuple of str, optional): [Deprecated] the key of the advantage entry. Defaults to ``"advantage"``. @@ -701,7 +701,7 @@ class TD1Estimator(ValueEstimatorBase): skip_existing (bool, optional): if ``True``, the value network will skip modules which outputs are already present in the tensordict. - Defaults to ``None``, ie. the value of :func:`tensordict.nn.skip_existing()` + Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()` is not affected. advantage_key (str or tuple of str, optional): [Deprecated] the key of the advantage entry. Defaults to ``"advantage"``. @@ -922,7 +922,7 @@ class TDLambdaEstimator(ValueEstimatorBase): lambda return. Default is `True`. skip_existing (bool, optional): if ``True``, the value network will skip modules which outputs are already present in the tensordict. - Defaults to ``None``, ie. the value of :func:`tensordict.nn.skip_existing()` + Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()` is not affected. advantage_key (str or tuple of str, optional): [Deprecated] the key of the advantage entry. Defaults to ``"advantage"``. @@ -1164,7 +1164,7 @@ class GAE(ValueEstimatorBase): lambda return. Default is `True`. skip_existing (bool, optional): if ``True``, the value network will skip modules which outputs are already present in the tensordict. - Defaults to ``None``, ie. the value of :func:`tensordict.nn.skip_existing()` + Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()` is not affected. Defaults to "state_value". 
advantage_key (str or tuple of str, optional): [Deprecated] the key of @@ -1476,7 +1476,7 @@ class VTrace(ValueEstimatorBase): pass detached parameters for functional modules. skip_existing (bool, optional): if ``True``, the value network will skip modules which outputs are already present in the tensordict. - Defaults to ``None``, ie. the value of :func:`tensordict.nn.skip_existing()` + Defaults to ``None``, i.e., the value of :func:`tensordict.nn.skip_existing()` is not affected. Defaults to "state_value". advantage_key (str or tuple of str, optional): [Deprecated] the key of diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index c52cb3bd5b2..1bf7fd57e83 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -11,7 +11,7 @@ # Overview # -------- # -# TorchRL separates the training of RL sota-implementations in various pieces that will be +# TorchRL separates the training of RL algorithms in various pieces that will be # assembled in your training script: the environment, the data collection and # storage, the model and finally the loss function. # @@ -167,7 +167,7 @@ # the losses without it. However, we encourage its usage for the following # reason. # -# The reason TorchRL does this is that RL sota-implementations often execute the same +# The reason TorchRL does this is that RL algorithms often execute the same # model with different sets of parameters, called "trainable" and "target" # parameters. # The "trainable" parameters are those that the optimizer needs to fit. The @@ -188,7 +188,7 @@ # Later, we will see how the target parameters should be updated in TorchRL. # -from tensordict.nn import TensorDictModule +from tensordict.nn import TensorDictModule, TensorDictSequential def _init( @@ -272,7 +272,7 @@ def make_value_estimator(self, value_type: ValueEstimators, **hyperparams): ############################################################################### -# The ``make_value_estimator`` method can but does not need to be called: ifgg +# The ``make_value_estimator`` method can but does not need to be called: if # not, the :class:`~torchrl.objectives.LossModule` will query this method with # its default estimator. # @@ -406,7 +406,7 @@ class DDPGLoss(LossModule): # Environment # ----------- # -# In most sota-implementations, the first thing that needs to be taken care of is the +# In most algorithms, the first thing that needs to be taken care of is the # construction of the environment as it conditions the remainder of the # training script. # @@ -722,7 +722,7 @@ def get_env_stats(): ActorCriticWrapper, DdpgMlpActor, DdpgMlpQNet, - OrnsteinUhlenbeckProcessWrapper, + OrnsteinUhlenbeckProcessModule, ProbabilisticActor, TanhDelta, ValueOperator, @@ -781,15 +781,18 @@ def make_ddpg_actor( # Exploration # ~~~~~~~~~~~ # -# The policy is wrapped in a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper` +# The policy is passed into a :class:`~torchrl.modules.OrnsteinUhlenbeckProcessModule` # exploration module, as suggested in the original paper. 
# Let's define the number of frames before OU noise reaches its minimum value annealing_frames = 1_000_000 -actor_model_explore = OrnsteinUhlenbeckProcessWrapper( +actor_model_explore = TensorDictSequential( actor, - annealing_num_steps=annealing_frames, -).to(device) + OrnsteinUhlenbeckProcessModule( + spec=actor.spec.clone(), + annealing_num_steps=annealing_frames, + ).to(device), +) if device == torch.device("cpu"): actor_model_explore.share_memory() @@ -1058,7 +1061,7 @@ def ceil_div(x, y): # Target network updater # ~~~~~~~~~~~~~~~~~~~~~~ # -# Target networks are a crucial part of off-policy RL sota-implementations. +# Target networks are a crucial part of off-policy RL algorithms. # Updating the target network parameters is made easy thanks to the # :class:`~torchrl.objectives.HardUpdate` and :class:`~torchrl.objectives.SoftUpdate` # classes. They're built with the loss module as argument, and the update is @@ -1173,7 +1176,7 @@ def ceil_div(x, y): ) # update the exploration strategy - actor_model_explore.step(current_frames) + actor_model_explore[1].step(current_frames) collector.shutdown() del collector diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index 3b9d712736a..e9f2085d3df 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -42,7 +42,7 @@ # estimated return; # - how to collect data from your environment efficiently and store them # in a replay buffer; -# - how to use multi-step, a simple preprocessing step for off-policy sota-implementations; +# - how to use multi-step, a simple preprocessing step for off-policy algorithms; # - and finally how to evaluate your model. # # **Prerequisites**: We encourage you to get familiar with torchrl through the @@ -365,7 +365,7 @@ def make_model(dummy_env): # Replay buffers # ~~~~~~~~~~~~~~ # -# Replay buffers play a central role in off-policy RL sota-implementations such as DQN. +# Replay buffers play a central role in off-policy RL algorithms such as DQN. # They constitute the dataset we will be sampling from during training. # # Here, we will use a regular sampling strategy, although a prioritized RB @@ -471,13 +471,13 @@ def get_collector( # Target parameters # ~~~~~~~~~~~~~~~~~ # -# Many off-policy RL sota-implementations use the concept of "target parameters" when it +# Many off-policy RL algorithms use the concept of "target parameters" when it # comes to estimate the value of the next state or state-action pair. # The target parameters are lagged copies of the model parameters. Because # their predictions mismatch those of the current model configuration, they # help learning by putting a pessimistic bound on the value being estimated. # This is a powerful trick (known as "Double Q-Learning") that is ubiquitous -# in similar sota-implementations. +# in similar algorithms. # diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index 40f71798a99..51229e1880d 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -518,7 +518,7 @@ # Replay buffer # ------------- # -# Replay buffers are a common building piece of off-policy RL sota-implementations. +# Replay buffers are a common building piece of off-policy RL algorithms. # In on-policy contexts, a replay buffer is refilled every time a batch of # data is collected, and its data is repeatedly consumed for a certain number # of epochs. 
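Editor's note: in the updated DDPG tutorial the policy is followed by an `OrnsteinUhlenbeckProcessModule` inside a `TensorDictSequential`, and the training loop now calls `actor_model_explore[1].step(current_frames)` to anneal the noise. The underlying process is a mean-reverting random walk added to the deterministic action. The sketch below uses made-up coefficients and an arbitrary action size, leaves out the sigma annealing that `step()` drives in the hunk, and is not the module's implementation.

```python
import torch

theta, sigma, dt = 0.15, 0.2, 1e-2   # made-up OU coefficients
noise = torch.zeros(4)               # one persistent noise state per action dim (4 is arbitrary)

def ou_explore(action: torch.Tensor) -> torch.Tensor:
    global noise
    # Mean-reverting update: decay the running noise toward 0, then add a Gaussian increment.
    noise = noise + theta * (0.0 - noise) * dt + sigma * (dt ** 0.5) * torch.randn_like(noise)
    return (action + noise).clamp(-1.0, 1.0)

for _ in range(3):
    print(ou_explore(torch.zeros(4)))  # successive calls produce temporally correlated noise
```

Because the noise state persists across calls, OU exploration produces smoother action sequences than i.i.d. Gaussian noise, which is why it is the default suggestion for DDPG-style continuous control.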
diff --git a/tutorials/sphinx-tutorials/getting-started-1.py b/tutorials/sphinx-tutorials/getting-started-1.py index fb33d520860..437cae26c42 100644 --- a/tutorials/sphinx-tutorials/getting-started-1.py +++ b/tutorials/sphinx-tutorials/getting-started-1.py @@ -117,7 +117,7 @@ # Probabilistic policies # ---------------------- # -# Policy-optimization sota-implementations like +# Policy-optimization algorithms like # `PPO `_ require the policy to be # stochastic: unlike in the examples above, the module now encodes a map from # the observation space to a parameter space encoding a distribution over the @@ -161,7 +161,7 @@ # # - Since we asked for it during the construction of the actor, the # log-probability of the actions given the distribution at that time is -# also written. This is necessary for sota-implementations like PPO. +# also written. This is necessary for algorithms like PPO. # - The parameters of the distribution are returned within the output # tensordict too under the ``"loc"`` and ``"scale"`` entries. # @@ -191,8 +191,8 @@ # also palliate to this with its exploration modules. # We will take the example of the :class:`~torchrl.modules.EGreedyModule` # exploration module (check also -# :class:`~torchrl.modules.AdditiveGaussianWrapper` and -# :class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper`). +# :class:`~torchrl.modules.AdditiveGaussianModule` and +# :class:`~torchrl.modules.OrnsteinUhlenbeckProcessModule`). # To see this module in action, let's revert to a deterministic policy: from tensordict.nn import TensorDictSequential diff --git a/tutorials/sphinx-tutorials/getting-started-2.py b/tutorials/sphinx-tutorials/getting-started-2.py index 22154cf4726..84fefc8197a 100644 --- a/tutorials/sphinx-tutorials/getting-started-2.py +++ b/tutorials/sphinx-tutorials/getting-started-2.py @@ -39,9 +39,9 @@ # ---------------------- # # In RL, innovation typically involves the exploration of novel methods -# for optimizing a policy (i.e., new sota-implementations), rather than focusing +# for optimizing a policy (i.e., new algorithms), rather than focusing # on new architectures, as seen in other domains. Within TorchRL, -# these sota-implementations are encapsulated within loss modules. A loss +# these algorithms are encapsulated within loss modules. A loss # module orchestrates the various components of your algorithm and # yields a set of loss values that can be backpropagated # through to train the corresponding components. @@ -145,7 +145,7 @@ # ----------------------------------------- # # Another important aspect to consider is the presence of target parameters -# in off-policy sota-implementations like DDPG. Target parameters typically represent +# in off-policy algorithms like DDPG. Target parameters typically represent # a delayed or smoothed version of the parameters over time, and they play # a crucial role in value estimation during policy training. Utilizing target # parameters for policy training often proves to be significantly more diff --git a/tutorials/sphinx-tutorials/getting-started-3.py b/tutorials/sphinx-tutorials/getting-started-3.py index ad6f6525a7c..3bd5c6ea5c3 100644 --- a/tutorials/sphinx-tutorials/getting-started-3.py +++ b/tutorials/sphinx-tutorials/getting-started-3.py @@ -29,7 +29,7 @@ # dataloaders are referred to as ``DataCollectors``. 
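Editor's note: getting-started-2, edited above, describes target parameters as a delayed or smoothed copy of the trainable ones. Stripped of TorchRL's `SoftUpdate`/`HardUpdate` helpers, the "smoothed" (soft) variant is a Polyak average. The network sizes and `tau` below are arbitrary; this is a generic sketch, not the updater's code.

```python
import torch
from torch import nn

qnet = nn.Linear(4, 1)                        # trainable parameters
target_qnet = nn.Linear(4, 1)                 # lagged copy used for bootstrapped value estimates
target_qnet.load_state_dict(qnet.state_dict())
for p in target_qnet.parameters():
    p.requires_grad_(False)

tau = 0.005                                   # small tau = slowly moving target

@torch.no_grad()
def soft_update():
    # Polyak average: target <- (1 - tau) * target + tau * online
    for p, p_target in zip(qnet.parameters(), target_qnet.parameters()):
        p_target.mul_(1.0 - tau).add_(tau * p)

# typically called once per optimization step, right after the optimizer update
soft_update()
```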
Most of the time, # data collection does not stop at the collection of raw data, # as the data needs to be stored temporarily in a buffer -# (or equivalent structure for on-policy sota-implementations) before being consumed +# (or equivalent structure for on-policy algorithms) before being consumed # by the :ref:`loss module `. This tutorial will explore # these two classes. # @@ -93,7 +93,7 @@ ################################# # Data collectors are very useful when it comes to coding state-of-the-art -# sota-implementations, as performance is usually measured by the capability of a +# algorithms, as performance is usually measured by the capability of a # specific technique to solve a problem in a given number of interactions with # the environment (the ``total_frames`` argument in the collector). # For this reason, most training loops in our examples look like this: diff --git a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py index b4bc38eb7bf..77574b765e7 100644 --- a/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py +++ b/tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py @@ -125,7 +125,7 @@ ) from torchrl.modules import ( - AdditiveGaussianWrapper, + AdditiveGaussianModule, MultiAgentMLP, ProbabilisticActor, TanhDelta, @@ -499,7 +499,7 @@ # Since the DDPG policy is deterministic, we need a way to perform exploration during collection. # # For this purpose, we need to append an exploration layer to our policies before passing them to the collector. -# In this case we use a :class:`~torchrl.modules.AdditiveGaussianWrapper`, which adds gaussian noise to our action +# In this case we use a :class:`~torchrl.modules.AdditiveGaussianModule`, which adds gaussian noise to our action # (and clamps it if the noise makes the action out of bounds). # # This exploration wrapper uses a ``sigma`` parameter which is multiplied by the noise to determine its magnitude. @@ -510,13 +510,16 @@ exploration_policies = {} for group, _agents in env.group_map.items(): - exploration_policy = AdditiveGaussianWrapper( + exploration_policy = TensorDictSequential( policies[group], - annealing_num_steps=total_frames - // 2, # Number of frames after which sigma is sigma_end - action_key=(group, "action"), - sigma_init=0.9, # Initial value of the sigma - sigma_end=0.1, # Final value of the sigma + AdditiveGaussianModule( + spec=policies[group].spec, + annealing_num_steps=total_frames + // 2, # Number of frames after which sigma is sigma_end + action_key=(group, "action"), + sigma_init=0.9, # Initial value of the sigma + sigma_end=0.1, # Final value of the sigma + ), ) exploration_policies[group] = exploration_policy @@ -648,7 +651,7 @@ # Replay buffer # ------------- # -# Replay buffers are a common building piece of off-policy RL sota-implementations. +# Replay buffers are a common building piece of off-policy RL algorithms. # There are many types of buffers, in this tutorial we use a basic buffer to store and sample tensordict # data randomly. # @@ -922,7 +925,7 @@ def process_batch(batch: TensorDictBase) -> TensorDictBase: # # Now that you are proficient with multi-agent DDPG, you can check out all the TorchRL multi-agent implementations in the # GitHub repository. -# These are code-only scripts of many MARL sota-implementations such as the ones seen in this tutorial, +# These are code-only scripts of many MARL algorithms such as the ones seen in this tutorial, # QMIX, MADDPG, IQL, and many more! 
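Editor's note: the multi-agent DDPG hunk appends an `AdditiveGaussianModule` to each deterministic policy, with `sigma` annealed from `sigma_init=0.9` to `sigma_end=0.1` over a number of frames and out-of-bounds actions clamped. A stand-alone illustration of that mechanism is below; the action bounds and the frame budget are assumptions, and this is not the module's code.

```python
import torch

sigma_init, sigma_end = 0.9, 0.1
annealing_num_steps = 500_000      # hypothetical stand-in for "total_frames // 2" in the tutorial
low, high = -1.0, 1.0              # assumed action bounds

def sigma_at(frame: int) -> float:
    # Linear annealing of the noise scale over the collected frames.
    frac = min(frame / annealing_num_steps, 1.0)
    return sigma_init + frac * (sigma_end - sigma_init)

def explore(action: torch.Tensor, frame: int) -> torch.Tensor:
    # Add i.i.d. Gaussian noise scaled by the current sigma, then clamp back into bounds.
    noisy = action + sigma_at(frame) * torch.randn_like(action)
    return noisy.clamp(low, high)

print(explore(torch.zeros(2), frame=0))
```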
# # Also do remember to check out our tutorial: :doc:`/tutorials/multiagent_ppo`. diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py index b163a5df64f..d7d906a4fb0 100644 --- a/tutorials/sphinx-tutorials/multiagent_ppo.py +++ b/tutorials/sphinx-tutorials/multiagent_ppo.py @@ -55,7 +55,7 @@ # the foundational policy-optimization algorithm. For more information, see the # `Proximal Policy Optimization Algorithms `_ paper. # -# This type of sota-implementations is usually trained *on-policy*. This means that, at every learning iteration, we have a +# This type of algorithms is usually trained *on-policy*. This means that, at every learning iteration, we have a # **sampling** and a **training** phase. In the **sampling** phase of iteration :math:`t`, rollouts are collected # form agents' interactions in the environment using the current policies :math:`\mathbf{\pi}_t`. # In the **training** phase, all the collected rollouts are immediately fed to the training process to perform @@ -551,7 +551,7 @@ # Replay buffer # ------------- # -# Replay buffers are a common building piece of off-policy RL sota-implementations. +# Replay buffers are a common building piece of off-policy RL algorithms. # In on-policy contexts, a replay buffer is refilled every time a batch of # data is collected, and its data is repeatedly consumed for a certain number # of epochs. @@ -780,7 +780,7 @@ # # Now that you are proficient with multi-agent DDPG, you can check out all the TorchRL multi-agent implementations in the # GitHub repository. -# These are code-only scripts of many popular MARL sota-implementations such as the ones seen in this tutorial, +# These are code-only scripts of many popular MARL algorithms such as the ones seen in this tutorial, # QMIX, MADDPG, IQL, and many more! # # You can also check out our other multi-agent tutorial on how to train competitive diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py index 4eda4ea8e91..d25bc2cdd8a 100644 --- a/tutorials/sphinx-tutorials/pendulum.py +++ b/tutorials/sphinx-tutorials/pendulum.py @@ -73,6 +73,8 @@ # simulation graph. # * Finally, we will train a simple policy to solve the system we implemented. # +# A built-in version of this environment can be found in class:`~torchrl.envs.PendulumEnv`. +# # sphinx_gallery_start_ignore import warnings diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 62419cbb3ef..9d25da0a4cd 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -162,13 +162,13 @@ # │ └── "trainers.py" # └── "version.py" # -# Unlike other domains, RL is less about media than *sota-implementations*. As such, it +# Unlike other domains, RL is less about media than *algorithms*. As such, it # is harder to make truly independent components. # # What TorchRL is not: # -# * a collection of sota-implementations: we do not intend to provide SOTA implementations of RL sota-implementations, -# but we provide these sota-implementations only as examples of how to use the library. +# * a collection of algorithms: we do not intend to provide SOTA implementations of RL algorithms, +# but we provide these algorithms only as examples of how to use the library. # # * a research framework: modularity in TorchRL comes in two flavours. First, we try # to build re-usable components, such that they can be easily swapped with each other. 
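Editor's note: the multi-agent PPO changes above keep the tutorial's description of on-policy training, where each iteration has a sampling phase followed by a training phase in which the freshly refilled buffer is consumed for several epochs. The skeleton below shows only that control flow; the random "observations", the linear policy and the squared-error loss are placeholders standing in for real rollouts and the PPO objective.

```python
import torch
from torch import nn

policy = nn.Linear(3, 1)                                  # toy stand-in for the actor/critic
optim = torch.optim.Adam(policy.parameters(), lr=1e-3)

n_iters, n_epochs, frames_per_batch, minibatch_size = 3, 4, 64, 16

for _ in range(n_iters):
    # --- sampling phase: pretend this batch was just collected with the current policy ---
    obs = torch.randn(frames_per_batch, 3)
    target = torch.randn(frames_per_batch, 1)

    # --- training phase: the freshly collected batch is consumed for several epochs ---
    for _ in range(n_epochs):
        perm = torch.randperm(frames_per_batch)
        for start in range(0, frames_per_batch, minibatch_size):
            idx = perm[start:start + minibatch_size]
            loss = ((policy(obs[idx]) - target[idx]) ** 2).mean()  # placeholder for the PPO loss
            optim.zero_grad()
            loss.backward()
            optim.step()
```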
@@ -605,10 +605,12 @@ def exec_sequence(params, data): ############################################################################### # Probabilistic modules -from torchrl.modules import NormalParamWrapper, TanhNormal +from torchrl.modules import NormalParamExtractor, TanhNormal td = TensorDict({"input": torch.randn(3, 5)}, [3]) -net = NormalParamWrapper(nn.Linear(5, 4)) # splits the output in loc and scale +net = nn.Sequential( + nn.Linear(5, 4), NormalParamExtractor() +) # splits the output in loc and scale module = TensorDictModule(net, in_keys=["input"], out_keys=["loc", "scale"]) td_module = ProbabilisticTensorDictSequential( module,
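Editor's note: the final hunk replaces `NormalParamWrapper` with a `NormalParamExtractor` appended after the linear layer, i.e. the last network dimension is split into a location and a positive scale. A simplified stand-in written with plain torch is shown below; the `LocScaleSplit` class and its softplus mapping are illustrative choices, not the extractor's actual transform.

```python
import torch
from torch import nn
from torch.distributions import Normal

class LocScaleSplit(nn.Module):
    """Simplified stand-in: split the last dim in two and make the scale positive."""

    def forward(self, x):
        loc, scale = x.chunk(2, dim=-1)
        return loc, nn.functional.softplus(scale) + 1e-4

net = nn.Linear(5, 4)                 # 4 outputs -> loc and scale of size 2 each
x = torch.randn(3, 5)
loc, scale = LocScaleSplit()(net(x))

dist = Normal(loc, scale)
action = torch.tanh(dist.rsample())   # tanh-squashed sample, in the spirit of TanhNormal
```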