Commit 3231780

amend

vmoens committed Sep 10, 2024
1 parent be095c4 commit 3231780

Showing 3 changed files with 38 additions and 7 deletions.
torchrl/envs/transforms/transforms.py (2 changes: 1 addition & 1 deletion)
@@ -4597,7 +4597,7 @@ class TensorDictPrimer(Transform):
.. note:: Some TorchRL modules rely on specific keys being present in the environment TensorDicts,
like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`.
- To facilitate this process, the method :func:`~torchrl.models.utils.get_primers_from_module`
+ To facilitate this process, the method :func:`~torchrl.modules.utils.get_primers_from_module`
automatically checks for required primer transforms in a module and its submodules and
generates them.
"""
torchrl/modules/models/exploration.py (41 changes: 35 additions & 6 deletions)
@@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
import functools
import math
import warnings
@@ -531,16 +533,17 @@ def initialize_parameters(
class ConsistentDropout(_DropoutNd):
"""Implements a :class:`~torch.nn.Dropout` variant with consistent dropout.
- This method is proposed in `"Consistent Dropout for Policy Gradient Reinforcement Learning" (Hausknecht & Wagener, 2022) <https://arxiv.org/abs/2202.11818>`_.
+ This method is proposed in `"Consistent Dropout for Policy Gradient Reinforcement Learning" (Hausknecht & Wagener, 2022)
+ <https://arxiv.org/abs/2202.11818>`_.
This :class:`~torch.nn.Dropout` variant attempts to increase training stability and
reduce update variance by caching the dropout masks used during rollout
and reusing them during the update phase.
- TorchRL's implementation capitalizes on the extensibility of ``TensorDict``s by storing generated dropout masks
- in the transition ``TensorDict`` themselves. This class can be used through :class:`~torchrl.modules.ConsistentDropoutModule`
- within policies coded using the :class:`~tensordict.nn.TensorDictModuleBase` API. See this class for a detailed
- explanation as well as usage examples.
+ The class you are looking at is independent of the rest of TorchRL's API and does not require tensordict to be run.
+ :class:`~torchrl.modules.ConsistentDropoutModule` is a wrapper around ``ConsistentDropout`` that capitalizes on the extensibility
+ of ``TensorDict``s by storing generated dropout masks in the transition ``TensorDict`` themselves.
+ See this class for a detailed explanation as well as usage examples.
There is otherwise little conceptual deviance from the PyTorch
:class:`~torch.nn.Dropout` implementation.
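The mechanism can be stated outside of any framework: sample a dropout mask once when the action is computed, store it with the transition, and re-apply that same mask when the loss is evaluated. A minimal, self-contained sketch of the idea (illustrative only, not TorchRL's implementation):

    import torch

    def sample_mask(x: torch.Tensor, p: float) -> torch.Tensor:
        # Inverted-dropout mask: keep each unit with probability 1 - p,
        # rescaled so the expected activation is unchanged.
        return x.new_empty(x.shape).bernoulli_(1 - p).div_(1 - p)

    x = torch.randn(2, 4)
    mask = sample_mask(x, p=0.5)  # drawn once, at rollout time
    y_rollout = x * mask          # activation used to pick the action
    y_update = x * mask           # same cached mask reused in the loss pass
    assert (y_rollout == y_update).all()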
@@ -618,6 +621,10 @@ class ConsistentDropoutModule(TensorDictModuleBase):
input_dtype (torch.dtype, optional): the dtype of the input for the primer. If none is passed,
``torch.get_default_dtype`` is assumed.
+ .. note:: To use this class within a policy, one needs the mask to be reset at reset time.
+     This can be achieved through a :class:`~torchrl.envs.TensorDictPrimer` transform that can be obtained
+     with :meth:`~.make_tensordict_primer`. See this method for more information.
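A two-line sketch of that wiring, assuming a hypothetical ``dropout_module`` and ``env``:

    primer = dropout_module.make_tensordict_primer()
    env = env.append_transform(primer)  # the mask entry is now primed at env.reset()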
Examples:
>>> from tensordict import TensorDict
>>> module = ConsistentDropoutModule(p = 0.1)
@@ -680,8 +687,30 @@ def make_tensordict_primer(self):
.. seealso:: :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given
module.
+ Examples:
+     >>> import torch
+     >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
+     >>> from torchrl.envs import GymEnv, StepCounter, SerialEnv
+     >>> from torchrl.modules import ConsistentDropoutModule
+     >>> from torchrl.modules.utils import get_primers_from_module
+     >>> m = Seq(
+     ...     Mod(torch.nn.Linear(7, 4), in_keys=["observation"], out_keys=["intermediate"]),
+     ...     ConsistentDropoutModule(
+     ...         p=0.5,
+     ...         input_shape=(2, 4),
+     ...         in_keys="intermediate",
+     ...     ),
+     ...     Mod(torch.nn.Linear(4, 7), in_keys=["intermediate"], out_keys=["action"]),
+     ... )
+     >>> primer = get_primers_from_module(m)
+     >>> env0 = GymEnv("Pendulum-v1").append_transform(StepCounter(5))
+     >>> env1 = GymEnv("Pendulum-v1").append_transform(StepCounter(6))
+     >>> env = SerialEnv(2, [lambda env=env0: env, lambda env=env1: env])
+     >>> env = env.append_transform(primer)
+     >>> r = env.rollout(10, m, break_when_any_done=False)
+     >>> mask = [k for k in r.keys() if k.startswith("mask")][0]
+     >>> assert (r[mask][0, :5] != r[mask][0, 5:6]).any()
+     >>> assert (r[mask][0, :4] == r[mask][0, 4:5]).all()
"""
- from torchrl.envs import TensorDictPrimer
+ from torchrl.envs.transforms.transforms import TensorDictPrimer

shape = self.input_shape
dtype = self.input_dtype
torchrl/modules/tensordict_module/rnn.py (2 changes: 2 additions & 0 deletions)
@@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
from typing import Optional, Tuple

import torch
