diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst
index 62cf1dedf35..2d6a6344970 100644
--- a/docs/source/reference/modules.rst
+++ b/docs/source/reference/modules.rst
@@ -57,16 +57,21 @@ projected (in a L1-manner) into the desired domain.
     SafeSequential
     TanhModule
 
-Exploration wrappers
-~~~~~~~~~~~~~~~~~~~~
+Exploration wrappers and modules
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-To efficiently explore the environment, TorchRL proposes a series of wrappers
+To efficiently explore the environment, TorchRL proposes a series of modules
 that will override the action sampled by the policy by a noisier version.
 Their behavior is controlled by :func:`~torchrl.envs.utils.exploration_mode`:
 if the exploration is set to ``"random"``, the exploration is active. In all
 other cases, the action written in the tensordict is simply the network output.
 
-.. currentmodule:: torchrl.modules.tensordict_module
+.. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule`
+   uses the ``train``/``eval`` mode to comply with the regular `Dropout` API in PyTorch.
+   The :func:`~torchrl.envs.utils.set_exploration_mode` context manager will have no effect on
+   this module.
+
+.. currentmodule:: torchrl.modules
 
 .. autosummary::
     :toctree: generated/
@@ -74,6 +79,7 @@ other cases, the action written in the tensordict is simply the network output.
 
     AdditiveGaussianModule
     AdditiveGaussianWrapper
+    ConsistentDropoutModule
     EGreedyModule
     EGreedyWrapper
     OrnsteinUhlenbeckProcessModule
@@ -438,12 +444,13 @@ Regular modules
     :toctree: generated/
     :template: rl_template_noinherit.rst
 
-    MLP
-    ConvNet
+    BatchRenorm1d
+    ConsistentDropout
     Conv3dNet
-    SqueezeLayer
+    ConvNet
+    MLP
     Squeeze2dLayer
-    BatchRenorm1d
+    SqueezeLayer
 
 Algorithm-specific modules
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
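For context, the train/eval switch described in the note above boils down to the following sketch. It is illustrative only (not part of the patch) and uses nothing beyond what the patch itself introduces:

    import torch
    from tensordict import TensorDict
    from torchrl.modules import ConsistentDropoutModule

    module = ConsistentDropoutModule(p=0.5, in_keys="x")
    td = TensorDict({"x": torch.randn(3, 4)}, [3])
    module.train()  # dropout active: a mask is drawn and stored under "mask_<id>"
    td_train = module(td.clone())
    module.eval()   # no-op: "x" passes through unchanged, no mask is written
    td_eval = module(td.clone())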
diff --git a/test/test_exploration.py b/test/test_exploration.py
index 3bb05708d83..f6a3ab7041b 100644
--- a/test/test_exploration.py
+++ b/test/test_exploration.py
@@ -31,7 +31,7 @@
     NormalParamExtractor,
     TanhNormal,
 )
-from torchrl.modules.models.exploration import LazygSDEModule
+from torchrl.modules.models.exploration import ConsistentDropoutModule, LazygSDEModule
 from torchrl.modules.tensordict_module.actors import (
     Actor,
     ProbabilisticActor,
@@ -738,6 +738,156 @@ def test_gsde_init(sigma_init, state_dim, action_dim, mean, std, device, learn_s
         ), f"failed: mean={mean}, std={std}, sigma_init={sigma_init}, actual: {sigma.mean()}"
 
 
+class TestConsistentDropout:
+    @pytest.mark.parametrize("dropout_p", [0.0, 0.1, 0.5])
+    @pytest.mark.parametrize("parallel_spec", [False, True])
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_consistent_dropout(self, dropout_p, parallel_spec, device):
+        """This preliminary test seeks to ensure two things for both
+        ConsistentDropout and ConsistentDropoutModule:
+
+        1. Rollout transitions generate a dropout mask as desired.
+           - We can easily verify the existence of a mask.
+        2. The dropout mask is correctly applied.
+           - We check with stochastic policies whether
+             the loc and scale are the same.
+        """
+        torch.manual_seed(0)
+
+        # NOTE: Please only put a module with one dropout layer.
+        # That's how this test is constructed anyway.
+        @torch.no_grad()
+        def inner_verify_routine(module, env):
+            # Perform transitions.
+            collector = SyncDataCollector(
+                create_env_fn=env,
+                policy=module,
+                frames_per_batch=1,
+                total_frames=10,
+                device=device,
+            )
+            for frames in collector:
+                masks = [
+                    (key, value)
+                    for key, value in frames.items()
+                    if key.startswith("mask_")
+                ]
+                # Assert rollouts do indeed correctly generate the masks.
+                assert len(masks) == 1, (
+                    "Expected exactly ONE mask since we only put "
+                    f"one dropout module, got {len(masks)}."
+                )
+
+                # Verify that the result for this batch is the same:
+                # a Monte Carlo-style check over the collected frames.
+                sentinel_mask = masks[0][1].clone()
+                sentinel_outputs = frames.select("loc", "scale").clone()
+
+                desired_dropout_mask = torch.full_like(
+                    sentinel_mask, 1 / (1 - dropout_p)
+                )
+                desired_dropout_mask[sentinel_mask == 0.0] = 0.0
+                # As of 15/08/24, :meth:`~torch.nn.functional.dropout`
+                # is used under the hood; this guards against changes to its scaling.
+                assert torch.allclose(
+                    sentinel_mask, desired_dropout_mask
+                ), "Dropout was not scaled properly."
+
+                new_frames = module(frames.clone())
+                infer_mask = new_frames[masks[0][0]]
+                infer_outputs = new_frames.select("loc", "scale")
+                assert (infer_mask == sentinel_mask).all(), "Mask does not match"
+
+                assert all(
+                    torch.allclose(infer_outputs[key], sentinel_outputs[key])
+                    for key in ("loc", "scale")
+                ), (
+                    "Outputs do not match:\n"
+                    f"{infer_outputs['loc']}\n--- vs ---\n{sentinel_outputs['loc']}\n"
+                    f"{infer_outputs['scale']}\n--- vs ---\n{sentinel_outputs['scale']}"
+                )
+
+        env = SerialEnv(
+            2,
+            ContinuousActionVecMockEnv,
+        )
+        env = TransformedEnv(env.to(device), InitTracker())
+        env = env.to(device)
+        # The module must work with the action spec of a single env or a serial env.
+        if parallel_spec:
+            action_spec = env.action_spec
+        else:
+            action_spec = ContinuousActionVecMockEnv(device=device).action_spec
+        d_act = action_spec.shape[-1]
+
+        # NOTE: Please only put a module with one dropout layer.
+        # That's how this test is constructed anyway.
+        module_td_seq = TensorDictSequential(
+            TensorDictModule(
+                nn.LazyLinear(2 * d_act), in_keys=["observation"], out_keys=["out"]
+            ),
+            ConsistentDropoutModule(p=dropout_p, in_keys="out"),
+            TensorDictModule(
+                NormalParamExtractor(), in_keys=["out"], out_keys=["loc", "scale"]
+            ),
+        )
+
+        policy_td_seq = ProbabilisticActor(
+            module=module_td_seq,
+            in_keys=["loc", "scale"],
+            distribution_class=TanhNormal,
+            default_interaction_type=InteractionType.RANDOM,
+            spec=action_spec,
+        ).to(device)
+
+        # Wake up the policies (initialize the lazy modules).
+        policy_td_seq(env.reset())
+
+        # Run the verification routine.
+        inner_verify_routine(policy_td_seq, env)
+
+    def test_consistent_dropout_primer(self):
+        import torch
+
+        from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
+        from torchrl.envs import SerialEnv, StepCounter
+        from torchrl.modules import ConsistentDropoutModule, get_primers_from_module
+
+        torch.manual_seed(0)
+
+        m = Seq(
+            Mod(
+                torch.nn.Linear(7, 4),
+                in_keys=["observation"],
+                out_keys=["intermediate"],
+            ),
+            ConsistentDropoutModule(
+                p=0.5,
+                input_shape=(2, 4),
+                in_keys="intermediate",
+            ),
+            Mod(torch.nn.Linear(4, 7), in_keys=["intermediate"], out_keys=["action"]),
+        )
+        primer = get_primers_from_module(m)
+        env0 = ContinuousActionVecMockEnv().append_transform(StepCounter(5))
+        env1 = ContinuousActionVecMockEnv().append_transform(StepCounter(6))
+        env = SerialEnv(2, [lambda env=env0: env, lambda env=env1: env])
+        env = env.append_transform(primer)
+        r = env.rollout(10, m, break_when_any_done=False)
+        mask = [k for k in r.keys() if k.startswith("mask")][0]
+        assert (r[mask][0, :5] != r[mask][0, 5:6]).any()
+        assert (r[mask][0, :4] == r[mask][0, 4:5]).all()
+
+        assert (r[mask][1, :6] != r[mask][1, 6:7]).any()
+        assert (r[mask][1, :5] == r[mask][1, 5:6]).all()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
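Background for the `desired_dropout_mask` check above: `torch.nn.functional.dropout` rescales surviving entries by 1/(1-p), so a mask built from a tensor of ones contains only 0 and 1/(1-p). A self-contained illustration (not part of the patch):

    import torch
    from torch.nn import functional as F

    p = 0.5
    mask = F.dropout(torch.ones(2, 4), p, training=True)
    # Entries are either dropped (0.0) or rescaled to 1 / (1 - p) = 2.0.
    assert set(mask.unique().tolist()) <= {0.0, 1.0 / (1.0 - p)}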
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 7f8403c793e..34a1d61bfc5 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -4597,7 +4597,7 @@ class TensorDictPrimer(Transform):
     .. note:: Some TorchRL modules rely on specific keys being present in the environment
         TensorDicts, like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`.
-        To facilitate this process, the method :func:`~torchrl.models.utils.get_primers_from_module`
+        To facilitate this process, the method :func:`~torchrl.modules.utils.get_primers_from_module`
         automatically checks for required primer transforms in a module and its submodules and
         generates them.
     """
@@ -4664,10 +4664,15 @@ def __init__(
     def reset_key(self):
         reset_key = self.__dict__.get("_reset_key", None)
         if reset_key is None:
+            if self.parent is None:
+                raise RuntimeError(
+                    "Missing parent, cannot infer reset_key automatically."
+                )
             reset_keys = self.parent.reset_keys
             if len(reset_keys) > 1:
                 raise RuntimeError(
-                    f"Got more than one reset key in env {self.container}, cannot infer which one to use. Consider providing the reset key in the {type(self)} constructor."
+                    f"Got more than one reset key in env {self.container}, cannot infer which one to use. "
+                    f"Consider providing the reset key in the {type(self)} constructor."
                 )
             reset_key = self._reset_key = reset_keys[0]
         return reset_key
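The new guard raises when the transform has no parent environment, while the pre-existing error suggests naming the reset key explicitly. A hedged sketch of that workaround, assuming the constructor's `reset_key` argument that the error message refers to (key names are illustrative):

    import torch
    from torchrl.data.tensor_specs import Unbounded
    from torchrl.envs import TensorDictPrimer

    # When the parent env exposes several reset keys, name the one to use
    # rather than letting the transform infer it.
    primer = TensorDictPrimer(
        primers={"hidden": Unbounded(shape=(4,), dtype=torch.float32)},
        reset_key="_reset",
    )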
diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py
index c246b553e95..f65461842bb 100644
--- a/torchrl/modules/__init__.py
+++ b/torchrl/modules/__init__.py
@@ -21,6 +21,8 @@
 )
 from .models import (
     BatchRenorm1d,
+    ConsistentDropout,
+    ConsistentDropoutModule,
     Conv3dNet,
     ConvNet,
     DdpgCnnActor,
@@ -85,4 +87,5 @@
     VmapModule,
     WorldModelWrapper,
 )
+from .utils import get_primers_from_module
 from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner  # usort:skip
diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py
index 9a814e35477..90b9fadd747 100644
--- a/torchrl/modules/models/__init__.py
+++ b/torchrl/modules/models/__init__.py
@@ -9,7 +9,13 @@
 
 from .batchrenorm import BatchRenorm1d
 from .decision_transformer import DecisionTransformer
-from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise
+from .exploration import (
+    ConsistentDropout,
+    ConsistentDropoutModule,
+    NoisyLazyLinear,
+    NoisyLinear,
+    reset_noise,
+)
 from .model_based import (
     DreamerActor,
     ObsDecoder,
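With this wiring, the import paths match the documentation listings above (note that ``ConsistentDropout`` is exported too, since the ``Regular modules`` autosummary documents it under ``torchrl.modules``). A quick smoke test of the intended paths (illustrative, not part of the patch):

    # Public exports added by this patch:
    from torchrl.modules import (
        ConsistentDropout,
        ConsistentDropoutModule,
        get_primers_from_module,
    )
    # The layer is also reachable at its definition site:
    from torchrl.modules.models.exploration import ConsistentDropout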
diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py
index 16c6ac5ff30..720934a6809 100644
--- a/torchrl/modules/models/exploration.py
+++ b/torchrl/modules/models/exploration.py
@@ -2,16 +2,24 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import functools
 import math
 import warnings
-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union
 
 import torch
+
+from tensordict.nn import TensorDictModuleBase
+from tensordict.utils import NestedKey
 from torch import distributions as d, nn
+from torch.nn import functional as F
+from torch.nn.modules.dropout import _DropoutNd
 from torch.nn.modules.lazy import LazyModuleMixin
 from torch.nn.parameter import UninitializedBuffer, UninitializedParameter
-
 from torchrl._utils import prod
+from torchrl.data.tensor_specs import Unbounded
 from torchrl.data.utils import DEVICE_TYPING, DEVICE_TYPING_ARGS
 from torchrl.envs.utils import exploration_type, ExplorationType
 from torchrl.modules.distributions.utils import _cast_transform_device
@@ -520,3 +528,203 @@ def initialize_parameters(
         )
         self._sigma.materialize((action_dim, state_dim))
         self._sigma.data.copy_(self.sigma_init.expand_as(self._sigma))
+
+
+class ConsistentDropout(_DropoutNd):
+    """Implements a :class:`~torch.nn.Dropout` variant with consistent dropout.
+
+    This method is proposed in `"Consistent Dropout for Policy Gradient Reinforcement Learning" (Hausknecht & Wagener, 2022)
+    <https://arxiv.org/abs/2202.11818>`_.
+
+    This :class:`~torch.nn.Dropout` variant attempts to increase training stability and
+    reduce update variance by caching the dropout masks used during rollout
+    and reusing them during the update phase.
+
+    This class is independent of the rest of TorchRL's API and does not require tensordict to run.
+    :class:`~torchrl.modules.ConsistentDropoutModule` is a wrapper around ``ConsistentDropout`` that capitalizes on the extensibility
+    of ``TensorDict`` by storing the generated dropout masks in the transition ``TensorDict`` itself.
+    See this class for a detailed explanation as well as usage examples.
+
+    There is otherwise little conceptual deviation from the PyTorch
+    :class:`~torch.nn.Dropout` implementation.
+
+    .. note:: TorchRL's data collectors perform rollouts in :class:`~torch.no_grad` mode but not in `eval` mode,
+        so the dropout masks will be applied unless the policy passed to the collector is in eval mode.
+
+    .. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule`
+        uses the ``train``/``eval`` mode to comply with the regular `Dropout` API in PyTorch.
+        The :func:`~torchrl.envs.utils.set_exploration_mode` context manager will have no effect on
+        this module.
+
+    Args:
+        p (float, optional): Dropout probability. Defaults to ``0.5``.
+
+    .. seealso::
+
+        - :class:`~torchrl.collectors.SyncDataCollector`:
+          :meth:`~torchrl.collectors.SyncDataCollector.rollout()` and :meth:`~torchrl.collectors.SyncDataCollector.iterator()`
+        - :class:`~torchrl.collectors.MultiSyncDataCollector`:
+          Uses :meth:`~torchrl.collectors.collectors._main_async_collector` (:class:`~torchrl.collectors.SyncDataCollector`)
+          under the hood
+        - :class:`~torchrl.collectors.MultiaSyncDataCollector`, :class:`~torchrl.collectors.aSyncDataCollector`: Ditto.
+
+    """
+
+    def __init__(self, p: float = 0.5):
+        super().__init__()
+        self.p = p
+
+    def forward(
+        self, x: torch.Tensor, mask: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        """During training (rollouts and updates), this call applies dropout to a tensor of ones to build a mask, then multiplies the input by that mask.
+
+        During evaluation, this call is a no-op and only the input is returned.
+
+        Args:
+            x (torch.Tensor): the input tensor.
+            mask (torch.Tensor, optional): the optional mask for the dropout.
+
+        Returns: a tensor and a corresponding mask in train mode, and only a tensor in eval mode.
+        """
+        if self.training:
+            if mask is None:
+                mask = self.make_mask(input=x)
+            return x * mask, mask
+
+        return x
+
+    def make_mask(self, *, input=None, shape=None):
+        if input is not None:
+            return F.dropout(
+                torch.ones_like(input), self.p, self.training, inplace=False
+            )
+        elif shape is not None:
+            return F.dropout(torch.ones(shape), self.p, self.training, inplace=False)
+        else:
+            raise RuntimeError("input or shape must be passed to make_mask.")
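A minimal sketch of the caching contract implemented above, using only the class just defined: replaying the returned mask reproduces the rollout-time output exactly at update time.

    import torch
    from torchrl.modules.models.exploration import ConsistentDropout

    layer = ConsistentDropout(p=0.5)
    x = torch.randn(8, 16)
    y_rollout, mask = layer(x)         # rollout: a fresh mask is drawn and returned
    y_update, _ = layer(x, mask=mask)  # update: the cached mask is replayed
    assert torch.equal(y_rollout, y_update)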
+
+
+class ConsistentDropoutModule(TensorDictModuleBase):
+    """A TensorDictModule wrapper for :class:`~ConsistentDropout`.
+
+    Args:
+        p (float, optional): Dropout probability. Default: ``0.5``.
+        in_keys (NestedKey or list of NestedKeys): keys to be read
+            from input tensordict and passed to this module.
+        out_keys (NestedKey or list of NestedKeys): keys to be written to the input tensordict.
+            Defaults to the ``in_keys`` values.
+
+    Keyword Args:
+        input_shape (tuple, optional): the shape of the input (non-batched), used to generate the
+            tensordict primers with :meth:`~.make_tensordict_primer`.
+        input_dtype (torch.dtype, optional): the dtype of the input for the primer. If none is passed,
+            ``torch.get_default_dtype()`` is assumed.
+
+    .. note:: To use this class within a policy, one needs the mask to be reset at reset time.
+        This can be achieved through a :class:`~torchrl.envs.TensorDictPrimer` transform that can be obtained
+        with :meth:`~.make_tensordict_primer`. See this method for more information.
+
+    Examples:
+        >>> import torch
+        >>> from tensordict import TensorDict
+        >>> module = ConsistentDropoutModule(p=0.1)
+        >>> td = TensorDict({"x": torch.randn(3, 4)}, [3])
+        >>> module(td)
+        TensorDict(
+            fields={
+                mask_6127171760: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                x: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
+            batch_size=torch.Size([3]),
+            device=None,
+            is_shared=False)
+    """
+
+    def __init__(
+        self,
+        p: float,
+        in_keys: NestedKey | List[NestedKey],
+        out_keys: NestedKey | List[NestedKey] | None = None,
+        input_shape: torch.Size | None = None,
+        input_dtype: torch.dtype | None = None,
+    ):
+        if isinstance(in_keys, NestedKey):
+            in_keys = [in_keys, f"mask_{id(self)}"]
+        if out_keys is None:
+            out_keys = list(in_keys)
+        if isinstance(out_keys, NestedKey):
+            out_keys = [out_keys, f"mask_{id(self)}"]
+        if len(in_keys) != 2 or len(out_keys) != 2:
+            raise ValueError(
+                "in_keys and out_keys length must be 2 for consistent dropout."
+            )
+        self.in_keys = in_keys
+        self.out_keys = out_keys
+        self.input_shape = input_shape
+        self.input_dtype = input_dtype
+        super().__init__()
+
+        if not 0 <= p < 1:
+            raise ValueError(f"p must be in [0,1), got p={p:.4f}.")
+
+        self.consistent_dropout = ConsistentDropout(p)
+
+    def forward(self, tensordict):
+        x = tensordict.get(self.in_keys[0])
+        mask = tensordict.get(self.in_keys[1], default=None)
+        if self.consistent_dropout.training:
+            x, mask = self.consistent_dropout(x, mask=mask)
+            tensordict.set(self.out_keys[0], x)
+            tensordict.set(self.out_keys[1], mask)
+        else:
+            x = self.consistent_dropout(x, mask=mask)
+            tensordict.set(self.out_keys[0], x)
+
+        return tensordict
+
+    def make_tensordict_primer(self):
+        """Makes a tensordict primer for the environment to generate random masks during reset calls.
+
+        .. seealso:: :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given
+            module.
+
+        Examples:
+            >>> import torch
+            >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
+            >>> from torchrl.envs import GymEnv, StepCounter, SerialEnv
+            >>> from torchrl.modules import get_primers_from_module
+            >>> m = Seq(
+            ...     Mod(torch.nn.Linear(7, 4), in_keys=["observation"], out_keys=["intermediate"]),
+            ...     ConsistentDropoutModule(
+            ...         p=0.5,
+            ...         input_shape=(2, 4),
+            ...         in_keys="intermediate",
+            ...     ),
+            ...     Mod(torch.nn.Linear(4, 7), in_keys=["intermediate"], out_keys=["action"]),
+            ... )
+            >>> primer = get_primers_from_module(m)
+            >>> env0 = GymEnv("Pendulum-v1").append_transform(StepCounter(5))
+            >>> env1 = GymEnv("Pendulum-v1").append_transform(StepCounter(6))
+            >>> env = SerialEnv(2, [lambda env=env0: env, lambda env=env1: env])
+            >>> env = env.append_transform(primer)
+            >>> r = env.rollout(10, m, break_when_any_done=False)
+            >>> mask = [k for k in r.keys() if k.startswith("mask")][0]
+            >>> assert (r[mask][0, :5] != r[mask][0, 5:6]).any()
+            >>> assert (r[mask][0, :4] == r[mask][0, 4:5]).all()
+
+        """
+        from torchrl.envs.transforms.transforms import TensorDictPrimer
+
+        shape = self.input_shape
+        dtype = self.input_dtype
+        if dtype is None:
+            dtype = torch.get_default_dtype()
+        if shape is None:
+            raise RuntimeError(
+                "Cannot infer the shape of the input automatically. "
+                "Please pass the shape of the tensor to `ConsistentDropoutModule` during construction "
+                "with the `input_shape` kwarg."
+            )
+        return TensorDictPrimer(
+            primers={self.in_keys[1]: Unbounded(dtype=dtype, shape=shape)},
+            default_value=functools.partial(
+                self.consistent_dropout.make_mask, shape=shape
+            ),
+        )
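Wiring the module into an environment thus takes two steps: construct it with a static `input_shape`, then append the primer it produces to the env. A compact sketch (shapes and key names are illustrative):

    import torch
    from torchrl.modules import ConsistentDropoutModule

    dropout = ConsistentDropoutModule(p=0.5, in_keys="hidden", input_shape=(4,))
    primer = dropout.make_tensordict_primer()
    # env = env.append_transform(primer)  # a fresh mask is then drawn at every reset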
diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py
index 48756683c11..f538f8e95c5 100644
--- a/torchrl/modules/tensordict_module/rnn.py
+++ b/torchrl/modules/tensordict_module/rnn.py
@@ -2,6 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 from typing import Optional, Tuple
 
 import torch
@@ -387,7 +389,7 @@ class LSTMModule(ModuleBase):
     .. note:: This module relies on specific ``recurrent_state`` keys being present in the input
         TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
         add hidden states to the environment TensorDicts, use the method :func:`~torchrl.modules.rnn.LSTMModule.make_tensordict_primer`.
-        If this class is a submodule in a larger module, the method :func:`~torchrl.models.utils.get_primers_from_module` can be called
+        If this class is a submodule in a larger module, the method :func:`~torchrl.modules.utils.get_primers_from_module` can be called
         on the parent module to automatically generate the primer transforms required for all submodules, including this one.
 
 
@@ -534,6 +536,9 @@ def make_tensordict_primer(self):
         tensordict, which the meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states
         are not registered within the environment specs.
 
+        See :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given
+        module.
+
         Examples:
             >>> from torchrl.collectors import SyncDataCollector
             >>> from torchrl.envs import TransformedEnv, InitTracker
@@ -1108,7 +1113,7 @@ class GRUModule(ModuleBase):
     .. note:: This module relies on specific ``recurrent_state`` keys being present in the input
         TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
         add hidden states to the environment TensorDicts, use the method :func:`~torchrl.modules.rnn.GRUModule.make_tensordict_primer`.
-        If this class is a submodule in a larger module, the method :func:`~torchrl.models.utils.get_primers_from_module` can be called
+        If this class is a submodule in a larger module, the method :func:`~torchrl.modules.utils.get_primers_from_module` can be called
         on the parent module to automatically generate the primer transforms required for all submodules, including this one.
 
     Examples:
@@ -1280,6 +1285,9 @@ def make_tensordict_primer(self):
         tensordict, which the meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states
         are not registered within the environment specs.
 
+        See :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given
+        module.
+
         Examples:
             >>> from torchrl.collectors import SyncDataCollector
             >>> from torchrl.envs import TransformedEnv, InitTracker
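Taken together, the docstring pointers above mean a single call can gather primers for recurrent and dropout submodules alike. A final hedged sketch (sizes and key names are illustrative, assuming the public `LSTMModule` API):

    import torch
    from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
    from torchrl.modules import ConsistentDropoutModule, get_primers_from_module, LSTMModule

    net = Seq(
        LSTMModule(input_size=4, hidden_size=4, in_key="observation", out_key="intermediate"),
        ConsistentDropoutModule(p=0.5, in_keys="intermediate", input_shape=(4,)),
        Mod(torch.nn.Linear(4, 2), in_keys=["intermediate"], out_keys=["action"]),
    )
    # One transform priming both the LSTM recurrent states and the dropout mask:
    primers = get_primers_from_module(net)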