[Feature] Consistent Dropout #2399

Merged · 12 commits · Sep 10, 2024
23 changes: 15 additions & 8 deletions docs/source/reference/modules.rst
@@ -57,23 +57,29 @@ projected (in a L1-manner) into the desired domain.
SafeSequential
TanhModule

Exploration wrappers
~~~~~~~~~~~~~~~~~~~~
Exploration wrappers and modules
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To efficiently explore the environment, TorchRL provides a series of wrappers
To efficiently explore the environment, TorchRL provides a series of modules
that will override the action sampled by the policy with a noisier version.
Their behavior is controlled by :func:`~torchrl.envs.utils.exploration_mode`:
if the exploration mode is set to ``"random"``, exploration is active. In all
other cases, the action written in the tensordict is simply the network output.

.. currentmodule:: torchrl.modules.tensordict_module
.. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule`
uses the ``train``/``eval`` mode to comply with the regular `Dropout` API in PyTorch.
The :func:`~torchrl.envs.utils.set_exploration_mode` context manager will have no effect on
this module.
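
A minimal sketch of that behavior (an illustration only, assuming the
constructor arguments shown in this PR's tests; the mask key name is
generated by the module):

.. code-block:: python

    import torch
    from tensordict import TensorDict
    from torchrl.modules import ConsistentDropoutModule

    dropout = ConsistentDropoutModule(p=0.5, in_keys="hidden")
    td = TensorDict({"hidden": torch.randn(2, 4)}, batch_size=[2])

    dropout.train()  # dropout active: a mask is sampled and reused consistently
    td_train = dropout(td.clone())

    dropout.eval()  # dropout disabled, as with nn.Dropout (see the note above)
    td_eval = dropout(td.clone())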

.. currentmodule:: torchrl.modules

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

AdditiveGaussianModule
AdditiveGaussianWrapper
ConsistentDropoutModule
EGreedyModule
EGreedyWrapper
OrnsteinUhlenbeckProcessModule
@@ -438,12 +444,13 @@ Regular modules
:toctree: generated/
:template: rl_template_noinherit.rst

MLP
ConvNet
BatchRenorm1d
ConsistentDropout
Conv3dNet
SqueezeLayer
ConvNet
MLP
Squeeze2dLayer
BatchRenorm1d
SqueezeLayer

Algorithm-specific modules
~~~~~~~~~~~~~~~~~~~~~~~~~~
152 changes: 151 additions & 1 deletion test/test_exploration.py
@@ -31,7 +31,7 @@
NormalParamExtractor,
TanhNormal,
)
from torchrl.modules.models.exploration import LazygSDEModule
from torchrl.modules.models.exploration import ConsistentDropoutModule, LazygSDEModule
from torchrl.modules.tensordict_module.actors import (
Actor,
ProbabilisticActor,
@@ -738,6 +738,156 @@ def test_gsde_init(sigma_init, state_dim, action_dim, mean, std, device, learn_s
), f"failed: mean={mean}, std={std}, sigma_init={sigma_init}, actual: {sigma.mean()}"


class TestConsistentDropout:
@pytest.mark.parametrize("dropout_p", [0.0, 0.1, 0.5])
@pytest.mark.parametrize("parallel_spec", [False, True])
@pytest.mark.parametrize("device", get_default_devices())
def test_consistent_dropout(self, dropout_p, parallel_spec, device):
"""

This preliminary test seeks to ensure two things for both
ConsistentDropout and ConsistentDropoutModule:
1. Rollout transitions generate a dropout mask as desired.
- We can easily verify the existence of a mask
2. The dropout mask is correctly applied.
- We will check with stochastic policies whether or not
the loc and scale are the same.
"""
torch.manual_seed(0)

# NOTE: use a module with exactly one dropout layer;
# that is how this test is constructed.
@torch.no_grad
def inner_verify_routine(module, env):
# Perform transitions.
collector = SyncDataCollector(
create_env_fn=env,
policy=module,
frames_per_batch=1,
total_frames=10,
device=device,
)
for frames in collector:
masks = [
(key, value)
for key, value in frames.items()
if key.startswith("mask_")
]
# Assert rollouts do indeed correctly generate the masks.
assert len(masks) == 1, (
"Expected exactly ONE mask since we only put "
f"one dropout module, got {len(masks)}."
)

# Verify that the results for this batch are reproducible:
# effectively a Monte Carlo check.
sentinel_mask = masks[0][1].clone()
sentinel_outputs = frames.select("loc", "scale").clone()

desired_dropout_mask = torch.full_like(
sentinel_mask, 1 / (1 - dropout_p)
)
desired_dropout_mask[sentinel_mask == 0.0] = 0.0
# As of 15/08/24, :func:`~torch.nn.functional.dropout`
# is being used under the hood. Never hurts to be safe.
assert torch.allclose(
sentinel_mask, desired_dropout_mask
), "Dropout was not scaled properly."

new_frames = module(frames.clone())
infer_mask = new_frames[masks[0][0]]
infer_outputs = new_frames.select("loc", "scale")
assert (infer_mask == sentinel_mask).all(), "Mask does not match"

assert all(
torch.allclose(infer_outputs[key], sentinel_outputs[key])
for key in ("loc", "scale")
), (
"Outputs do not match:\n "
f"{infer_outputs['loc']}\n--- vs ---\n{sentinel_outputs['loc']}\n"
f"{infer_outputs['scale']}\n--- vs ---\n{sentinel_outputs['scale']}"
)

env = SerialEnv(
2,
ContinuousActionVecMockEnv,
)
env = TransformedEnv(env.to(device), InitTracker())
# the module must work with the action spec of a single env or a serial env
if parallel_spec:
action_spec = env.action_spec
else:
action_spec = ContinuousActionVecMockEnv(device=device).action_spec
d_act = action_spec.shape[-1]

# NOTE: use a module with exactly one dropout layer;
# that is how this test is constructed.
module_td_seq = TensorDictSequential(
TensorDictModule(
nn.LazyLinear(2 * d_act), in_keys=["observation"], out_keys=["out"]
),
ConsistentDropoutModule(p=dropout_p, in_keys="out"),
TensorDictModule(
NormalParamExtractor(), in_keys=["out"], out_keys=["loc", "scale"]
),
)

policy_td_seq = ProbabilisticActor(
module=module_td_seq,
in_keys=["loc", "scale"],
distribution_class=TanhNormal,
default_interaction_type=InteractionType.RANDOM,
spec=action_spec,
).to(device)

# Run the policy once to initialize the lazy layers
policy_td_seq(env.reset())

# Test.
inner_verify_routine(policy_td_seq, env)

def test_consistent_dropout_primer(self):
import torch

from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torchrl.envs import SerialEnv, StepCounter
from torchrl.modules import ConsistentDropoutModule, get_primers_from_module

torch.manual_seed(0)

m = Seq(
Mod(
torch.nn.Linear(7, 4),
in_keys=["observation"],
out_keys=["intermediate"],
),
ConsistentDropoutModule(
p=0.5,
input_shape=(
2,
4,
),
in_keys="intermediate",
),
Mod(torch.nn.Linear(4, 7), in_keys=["intermediate"], out_keys=["action"]),
)
primer = get_primers_from_module(m)
env0 = ContinuousActionVecMockEnv().append_transform(StepCounter(5))
env1 = ContinuousActionVecMockEnv().append_transform(StepCounter(6))
env = SerialEnv(2, [lambda env=env0: env, lambda env=env1: env])
env = env.append_transform(primer)
r = env.rollout(10, m, break_when_any_done=False)
mask = [k for k in r.keys() if k.startswith("mask")][0]
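# StepCounter truncates env0 after 5 steps and env1 after 6: the mask
# must change across each reset boundary but stay constant within an episode.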
assert (r[mask][0, :5] != r[mask][0, 5:6]).any()
assert (r[mask][0, :4] == r[mask][0, 4:5]).all()

assert (r[mask][1, :6] != r[mask][1, 6:7]).any()
assert (r[mask][1, :5] == r[mask][1, 5:6]).all()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
9 changes: 7 additions & 2 deletions torchrl/envs/transforms/transforms.py
@@ -4597,7 +4597,7 @@ class TensorDictPrimer(Transform):

.. note:: Some TorchRL modules rely on specific keys being present in the environment TensorDicts,
like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`.
To facilitate this process, the method :func:`~torchrl.models.utils.get_primers_from_module`
To facilitate this process, the method :func:`~torchrl.modules.utils.get_primers_from_module`
automatically checks for required primer transforms in a module and its submodules and
generates them.
"""
@@ -4664,10 +4664,15 @@ def __init__(
def reset_key(self):
reset_key = self.__dict__.get("_reset_key", None)
if reset_key is None:
if self.parent is None:
raise RuntimeError(
"Missing parent, cannot infer reset_key automatically."
)
reset_keys = self.parent.reset_keys
if len(reset_keys) > 1:
raise RuntimeError(
f"Got more than one reset key in env {self.container}, cannot infer which one to use. Consider providing the reset key in the {type(self)} constructor."
f"Got more than one reset key in env {self.container}, cannot infer which one to use. "
f"Consider providing the reset key in the {type(self)} constructor."
)
reset_key = self._reset_key = reset_keys[0]
return reset_key
2 changes: 2 additions & 0 deletions torchrl/modules/__init__.py
@@ -21,6 +21,7 @@
)
from .models import (
BatchRenorm1d,
ConsistentDropoutModule,
Conv3dNet,
ConvNet,
DdpgCnnActor,
@@ -85,4 +86,5 @@
VmapModule,
WorldModelWrapper,
)
from .utils import get_primers_from_module
from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner # usort:skip
7 changes: 6 additions & 1 deletion torchrl/modules/models/__init__.py
@@ -9,7 +9,12 @@
from .batchrenorm import BatchRenorm1d

from .decision_transformer import DecisionTransformer
from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise
from .exploration import (
ConsistentDropoutModule,
NoisyLazyLinear,
NoisyLinear,
reset_noise,
)
from .model_based import (
DreamerActor,
ObsDecoder,