diff --git a/test/test_actors.py b/test/test_actors.py
index 388120a4ba7..2d160e31bba 100644
--- a/test/test_actors.py
+++ b/test/test_actors.py
@@ -130,8 +130,8 @@ def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions
         out_keys=[("data", "action")],
         distribution_class=TanhNormal,
         distribution_kwargs={
-            "min": action_spec.space.low,
-            "max": action_spec.space.high,
+            "low": action_spec.space.low,
+            "high": action_spec.space.high,
         },
         log_prob_key=log_prob_key,
         return_log_prob=True,
@@ -153,8 +153,8 @@ def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions
         out_keys=[("data", "action")],
         distribution_class=TanhNormal,
         distribution_kwargs={
-            "min": action_spec.space.low,
-            "max": action_spec.space.high,
+            "low": action_spec.space.low,
+            "high": action_spec.space.high,
         },
         log_prob_key=log_prob_key,
         return_log_prob=True,
diff --git a/test/test_cost.py b/test/test_cost.py
index 25a30d973f9..76fc4e651f4 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -13519,17 +13519,36 @@ def __init__(self):
 def test_loss_exploration():
     class DummyLoss(LossModule):
-        def forward(self, td):
-            assert exploration_type() == InteractionType.MODE
+        def forward(self, td, mode):
+            if mode is None:
+                mode = self.deterministic_sampling_mode
+            assert exploration_type() == mode
             with set_exploration_type(ExplorationType.RANDOM):
                 assert exploration_type() == ExplorationType.RANDOM
-            assert exploration_type() == ExplorationType.MODE
+            assert exploration_type() == mode
             return td
 
     loss_fn = DummyLoss()
     with set_exploration_type(ExplorationType.RANDOM):
         assert exploration_type() == ExplorationType.RANDOM
-        loss_fn(None)
+        loss_fn(None, None)
+    assert exploration_type() == ExplorationType.RANDOM
+
+    with set_exploration_type(ExplorationType.RANDOM):
+        assert exploration_type() == ExplorationType.RANDOM
+        loss_fn(None, ExplorationType.DETERMINISTIC)
+    assert exploration_type() == ExplorationType.RANDOM
+
+    loss_fn.deterministic_sampling_mode = ExplorationType.MODE
+    with set_exploration_type(ExplorationType.RANDOM):
+        assert exploration_type() == ExplorationType.RANDOM
+        loss_fn(None, ExplorationType.MODE)
+    assert exploration_type() == ExplorationType.RANDOM
+
+    loss_fn.deterministic_sampling_mode = ExplorationType.MEAN
+    with set_exploration_type(ExplorationType.RANDOM):
+        assert exploration_type() == ExplorationType.RANDOM
+        loss_fn(None, ExplorationType.MEAN)
     assert exploration_type() == ExplorationType.RANDOM
diff --git a/test/test_distributions.py b/test/test_distributions.py
index a3308a57d3a..53bfda343a2 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -85,10 +85,10 @@ def _map_all(*tensors_or_other, device):
 class TestTanhNormal:
     @pytest.mark.parametrize(
-        "min", [-torch.ones(3), -1, 3 * torch.tensor([-1.0, -2.0, -0.5]), -0.1]
+        "low", [-torch.ones(3), -1, 3 * torch.tensor([-1.0, -2.0, -0.5]), -0.1]
     )
     @pytest.mark.parametrize(
-        "max", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 0.1]
+        "high", [torch.ones(3), 1, 3 * torch.tensor([1.0, 2.0, 0.5]), 0.1]
     )
     @pytest.mark.parametrize(
         "vecs",
@@ -102,25 +102,64 @@ class TestTanhNormal:
     )
     @pytest.mark.parametrize("shape", [torch.Size([]), torch.Size([3, 4])])
     @pytest.mark.parametrize("device", get_default_devices())
-    def test_tanhnormal(self, min, max, vecs, upscale, shape, device):
-        min, max, vecs, upscale, shape = _map_all(
-            min, max, vecs, upscale, shape, device=device
+    def test_tanhnormal(self, low, high, vecs, upscale, shape, device):
+        torch.manual_seed(0)
+        low, high, vecs, upscale, shape = _map_all(
+            low, high, vecs, upscale, shape, device=device
         )
         torch.manual_seed(0)
         d = TanhNormal(
             *vecs,
             upscale=upscale,
-            min=min,
-            max=max,
+            low=low,
+            high=high,
         )
         for _ in range(100):
             a = d.rsample(shape)
             assert a.shape[: len(shape)] == shape
-            assert (a >= d.min).all()
-            assert (a <= d.max).all()
+            assert (a >= d.low).all()
+            assert (a <= d.high).all()
             lp = d.log_prob(a)
             assert torch.isfinite(lp).all()
 
+    def test_tanhnormal_mode(self):
+        # Checks that the mode computed by TanhNormal is stable: nearby
+        # starting points should yield nearby modes.
+        torch.manual_seed(0)
+        # 10 start points with 1000 jittered copies around each;
+        # the std of the loc is about 1e-4
+        loc = torch.randn(10) + torch.randn(1000, 10) / 10000
+
+        t = TanhNormal(loc=loc, scale=0.5, low=-1, high=1, event_dims=0)
+
+        mode = t.get_mode()
+        assert mode.shape == loc.shape
+        empirical_mode, empirical_mode_lp = torch.zeros_like(loc), -float("inf")
+        for v in torch.arange(-1, 1, step=0.01):
+            lp = t.log_prob(v.expand_as(t.loc))
+            empirical_mode = torch.where(lp > empirical_mode_lp, v, empirical_mode)
+            empirical_mode_lp = torch.where(
+                lp > empirical_mode_lp, lp, empirical_mode_lp
+            )
+        assert abs(empirical_mode - mode).max() < 0.1, abs(empirical_mode - mode).max()
+        assert mode.shape == loc.shape
+        assert (mode.std(0).max() < 0.1).all(), mode.std(0)
+
+    @pytest.mark.parametrize("event_dims", [0, 1, 2])
+    def test_tanhnormal_event_dims(self, event_dims):
+        scale = 1
+        loc = torch.randn(1, 2, 3, 4)
+        t = TanhNormal(loc=loc, scale=scale, event_dims=event_dims)
+        sample = t.sample()
+        assert sample.shape == loc.shape
+        exp_shape = loc.shape[:-event_dims] if event_dims > 0 else loc.shape
+        assert t.log_prob(sample).shape == exp_shape, (
+            t.log_prob(sample).shape,
+            event_dims,
+            exp_shape,
+        )
 
 class TestTruncatedNormal:
     @pytest.mark.parametrize(
@@ -159,13 +198,13 @@ def test_truncnormal(self, min, max, vecs, upscale, shape, device):
             a = d.rsample(shape)
             assert a.device == device
             assert a.shape[: len(shape)] == shape
-            assert (a >= d.min).all()
-            assert (a <= d.max).all()
+            assert (a >= d.low).all()
+            assert (a <= d.high).all()
             lp = d.log_prob(a)
             assert torch.isfinite(lp).all()
-        oob_min = d.min.expand((*d.batch_shape, *d.event_shape)) - 1e-2
+        oob_min = d.low.expand((*d.batch_shape, *d.event_shape)) - 1e-2
         assert not torch.isfinite(d.log_prob(oob_min)).any()
-        oob_max = d.max.expand((*d.batch_shape, *d.event_shape)) + 1e-2
+        oob_max = d.high.expand((*d.batch_shape, *d.event_shape)) + 1e-2
         assert not torch.isfinite(d.log_prob(oob_max)).any()
 
     @pytest.mark.skipif(not _has_scipy, reason="scipy not installed")
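For reference, a minimal sketch of the renamed bounds and the new ``event_dims`` semantics exercised by the tests above; the import path and shapes are illustrative assumptions rather than part of the diff:

```python
import torch
from torchrl.modules.distributions import TanhNormal

# `low` / `high` replace the deprecated `min` / `max` keyword arguments.
d = TanhNormal(loc=torch.zeros(3), scale=torch.ones(3), low=-2.0, high=2.0)
sample = d.rsample((10,))
assert (sample >= d.low).all() and (sample <= d.high).all()

# `event_dims` controls how many trailing dimensions log_prob sums over:
# 0 keeps the per-element shape, 1 reduces the last dimension.
d0 = TanhNormal(loc=torch.zeros(3), scale=torch.ones(3), event_dims=0)
d1 = TanhNormal(loc=torch.zeros(3), scale=torch.ones(3), event_dims=1)
x = d0.sample()
print(d0.log_prob(x).shape)  # torch.Size([3])
print(d1.log_prob(x).shape)  # torch.Size([])
```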
diff --git a/test/test_exploration.py b/test/test_exploration.py
index 89de8005555..f65ea655de2 100644
--- a/test/test_exploration.py
+++ b/test/test_exploration.py
@@ -585,7 +585,7 @@ def test_gsde(
     wrapper = NormalParamWrapper(model)
     module = SafeModule(wrapper, in_keys=in_keys, out_keys=["loc", "scale"])
     distribution_class = TanhNormal
-    distribution_kwargs = {"min": -bound, "max": bound}
+    distribution_kwargs = {"low": -bound, "high": bound}
     spec = BoundedTensorSpec(
         -torch.ones(action_dim) * bound, torch.ones(action_dim) * bound, (action_dim,)
     ).to(device)
diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py
index 59510dae0d9..6f81a9748bc 100644
--- a/test/test_tensordictmodules.py
+++ b/test/test_tensordictmodules.py
@@ -1416,8 +1416,8 @@ def test_dt_inference_wrapper(self, online):
         )
         dist_class = TanhDelta
         dist_kwargs = {
-            "min": -1.0,
-            "max": 1.0,
+            "low": -1.0,
+            "high": 1.0,
         }
         actor = ProbabilisticActor(
             in_keys=in_keys,
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 11befbf0ee3..65206525bd2 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -335,9 +335,10 @@ class SyncDataCollector(DataCollectorBase):
             information.
             Defaults to ``False``.
         exploration_type (ExplorationType, optional): interaction mode to be used when
-            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
-            ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
-            Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
+            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
+            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
+            or ``torchrl.envs.utils.ExplorationType.MEAN``.
+            Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
         return_same_td (bool, optional): if ``True``, the same TensorDict
             will be returned at each iteration, with its values
             updated. This feature should be used cautiously: if the same
@@ -1336,9 +1336,10 @@ class _MultiDataCollector(DataCollectorBase):
             information.
             Defaults to ``False``.
         exploration_type (ExplorationType, optional): interaction mode to be used when
-            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
-            ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
-            Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
+            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
+            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
+            or ``torchrl.envs.utils.ExplorationType.MEAN``.
+            Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
         reset_when_done (bool, optional): if ``True`` (default), an environment
             that return a ``True`` value in its ``"done"`` or ``"truncated"``
             entry will be reset at the corresponding indices.
@@ -2635,9 +2635,10 @@ class aSyncDataCollector(MultiaSyncDataCollector):
             information.
             Defaults to ``False``.
         exploration_type (ExplorationType, optional): interaction mode to be used when
-            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
-            ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
-            Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
+            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
+            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
+            or ``torchrl.envs.utils.ExplorationType.MEAN``.
+            Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
         reset_when_done (bool, optional): if ``True`` (default), an environment
             that return a ``True`` value in its ``"done"`` or ``"truncated"``
             entry will be reset at the corresponding indices.
diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py
index e69032d01c1..596c1f5d191 100644
--- a/torchrl/collectors/distributed/generic.py
+++ b/torchrl/collectors/distributed/generic.py
@@ -346,9 +346,10 @@ class DistributedDataCollector(DataCollectorBase):
            information.
            Defaults to ``False``.
        exploration_type (ExplorationType, optional): interaction mode to be used when
-            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
-            ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
-            Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
+            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
+            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
+            or ``torchrl.envs.utils.ExplorationType.MEAN``.
+            Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
        collector_class (type or str, optional): a collector class for the remote node. Can be
            :class:`~torchrl.collectors.SyncDataCollector`,
            :class:`~torchrl.collectors.MultiSyncDataCollector`,
diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py
index faf4d4a6cce..79b3ee9063c 100644
--- a/torchrl/collectors/distributed/ray.py
+++ b/torchrl/collectors/distributed/ray.py
@@ -211,9 +211,10 @@ class RayCollector(DataCollectorBase):
            information.
            Defaults to ``False``.
        exploration_type (ExplorationType, optional): interaction mode to be used when
-            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
-            ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
-            Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
+            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
+            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
+            or ``torchrl.envs.utils.ExplorationType.MEAN``.
+            Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
        collector_class (Python class): a collector class to be remotely instantiated. Can be
            :class:`~torchrl.collectors.SyncDataCollector`,
            :class:`~torchrl.collectors.MultiSyncDataCollector`,
diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py
index c32dbc8fea9..b6c324bb7b5 100644
--- a/torchrl/collectors/distributed/rpc.py
+++ b/torchrl/collectors/distributed/rpc.py
@@ -187,8 +187,9 @@ class RPCDataCollector(DataCollectorBase):
            information.
            Defaults to ``False``.
        exploration_type (ExplorationType, optional): interaction mode to be used when
-            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
-            ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
+            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
+            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
+            or ``torchrl.envs.utils.ExplorationType.MEAN``.
            Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
        collector_class (type or str, optional): a collector class for the remote node.
            Can be :class:`~torchrl.collectors.SyncDataCollector`,
diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py
index 3cd0728dd49..6f959086c83 100644
--- a/torchrl/collectors/distributed/sync.py
+++ b/torchrl/collectors/distributed/sync.py
@@ -226,9 +226,10 @@ class DistributedSyncDataCollector(DataCollectorBase):
            information.
            Defaults to ``False``.
        exploration_type (ExplorationType, optional): interaction mode to be used when
-            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.RANDOM``,
-            ``torchrl.envs.utils.ExplorationType.MODE`` or ``torchrl.envs.utils.ExplorationType.MEAN``.
-            Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
+            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
+            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
+            or ``torchrl.envs.utils.ExplorationType.MEAN``.
+            Defaults to ``torchrl.envs.utils.ExplorationType.RANDOM``.
        collector_class (type or str, optional): a collector class for the remote node. Can be
            :class:`~torchrl.collectors.SyncDataCollector`,
            :class:`~torchrl.collectors.MultiSyncDataCollector`,
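A hedged sketch of how the expanded ``exploration_type`` argument would be passed to a collector; the environment name and the one-layer policy are hypothetical stand-ins:

```python
import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv
from torchrl.envs.utils import ExplorationType

env = GymEnv("Pendulum-v1")  # hypothetical env choice
policy = TensorDictModule(
    nn.Linear(3, 1),  # obs dim 3 -> action dim 1 for Pendulum
    in_keys=["observation"],
    out_keys=["action"],
)
collector = SyncDataCollector(
    env,
    policy,
    frames_per_batch=64,
    total_frames=128,
    # DETERMINISTIC now joins RANDOM, MODE and MEAN as a valid option;
    # the default remains RANDOM.
    exploration_type=ExplorationType.DETERMINISTIC,
)
for batch in collector:
    print(batch["action"].shape)
collector.shutdown()
```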
diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py
index 3010539f3e6..9c5fa50ea8d 100644
--- a/torchrl/modules/distributions/continuous.py
+++ b/torchrl/modules/distributions/continuous.py
@@ -2,7 +2,9 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
 
+import warnings
 from numbers import Number
 from typing import Dict, Optional, Sequence, Tuple, Union
 
@@ -90,6 +92,10 @@ def update(self, loc, scale):
     def mode(self):
         return self.base_dist.mean
 
+    @property
+    def deterministic_sample(self):
+        return self.mean
+
 
 class SafeTanhTransform(D.TanhTransform):
     """TanhTransform subclass that ensured that the transformation is numerically invertible."""
@@ -201,46 +207,71 @@ class TruncatedNormal(D.Independent):
         "scale": constraints.greater_than(1e-6),
     }
 
+    def _warn_minmax(self):
+        warnings.warn(
+            f"the min / max keyword arguments are deprecated in favor of low / high in {type(self).__name__} "
+            f"and will be removed entirely in v0.6.",
+            DeprecationWarning,
+        )
+
     def __init__(
         self,
         loc: torch.Tensor,
         scale: torch.Tensor,
         upscale: Union[torch.Tensor, float] = 5.0,
-        min: Union[torch.Tensor, float] = -1.0,
-        max: Union[torch.Tensor, float] = 1.0,
+        low: Union[torch.Tensor, float] = -1.0,
+        high: Union[torch.Tensor, float] = 1.0,
         tanh_loc: bool = False,
+        **kwargs,
     ):
-        err_msg = "TanhNormal max values must be strictly greater than min values"
-        if isinstance(max, torch.Tensor) or isinstance(min, torch.Tensor):
-            if not (max > min).all():
+        if "max" in kwargs:
+            self._warn_minmax()
+            high = kwargs.pop("max")
+        if "min" in kwargs:
+            self._warn_minmax()
+            low = kwargs.pop("min")
+
+        err_msg = "TruncatedNormal high values must be strictly greater than low values"
+        if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor):
+            if not (high > low).all():
                 raise RuntimeError(err_msg)
-        elif isinstance(max, Number) and isinstance(min, Number):
-            if not max > min:
+        elif isinstance(high, Number) and isinstance(low, Number):
+            if not high > low:
                 raise RuntimeError(err_msg)
         else:
-            if not all(max > min):
+            if not all(high > low):
                 raise RuntimeError(err_msg)
 
-        if isinstance(max, torch.Tensor):
-            self.non_trivial_max = (max != 1.0).any()
+        if isinstance(high, torch.Tensor):
+            self.non_trivial_max = (high != 1.0).any()
         else:
-            self.non_trivial_max = max != 1.0
+            self.non_trivial_max = high != 1.0
 
-        if isinstance(min, torch.Tensor):
-            self.non_trivial_min = (min != -1.0).any()
+        if isinstance(low, torch.Tensor):
+            self.non_trivial_min = (low != -1.0).any()
         else:
-            self.non_trivial_min = min != -1.0
+            self.non_trivial_min = low != -1.0
 
         self.tanh_loc = tanh_loc
         self.device = loc.device
         self.upscale = torch.as_tensor(upscale, device=self.device)
 
-        max = torch.as_tensor(max, device=self.device)
-        min = torch.as_tensor(min, device=self.device)
-        self.min = min
-        self.max = max
+        high = torch.as_tensor(high, device=self.device)
+        low = torch.as_tensor(low, device=self.device)
+        self.low = low
+        self.high = high
         self.update(loc, scale)
 
+    @property
+    def min(self):
+        self._warn_minmax()
+        return self.low
+
+    @property
+    def max(self):
+        self._warn_minmax()
+        return self.high
+
     def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
         if self.tanh_loc:
             loc = (loc / self.upscale).tanh() * self.upscale
@@ -250,8 +281,8 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
         base_dist = _TruncatedNormal(
             loc,
             scale,
-            self.min.expand_as(loc),
-            self.max.expand_as(scale),
+            a=self.low.expand_as(loc),
+            b=self.high.expand_as(scale),
             device=self.device,
         )
         super().__init__(base_dist, 1, validate_args=False)
@@ -264,8 +295,12 @@ def mode(self):
         m = torch.min(torch.stack([m, b], -1), dim=-1)[0]
         return torch.max(torch.stack([m, a], -1), dim=-1)[0]
 
+    @property
+    def deterministic_sample(self):
+        return self.mean
+
     def log_prob(self, value, **kwargs):
-        above_or_below = (self.min > value) | (self.max < value)
+        above_or_below = (self.low > value) | (self.high < value)
         a = self.base_dist._non_std_a + self.base_dist._dtype_min_gt_0
         a = a.expand_as(value)
         b = self.base_dist._non_std_b - self.base_dist._dtype_min_gt_0
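To illustrate the deprecation path above, a small sketch assuming the shim behaves as written (legacy kwargs rerouted to ``low``/``high`` with a ``DeprecationWarning``):

```python
import warnings
import torch
from torchrl.modules.distributions import TruncatedNormal

loc, scale = torch.zeros(3), torch.ones(3)

# Legacy min/max kwargs still work but warn.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    d = TruncatedNormal(loc, scale, min=-2.0, max=2.0)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
assert (d.low == -2.0).all() and (d.high == 2.0).all()

# Reading the deprecated attributes warns too, but still returns the bounds.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert (d.min == d.low).all()
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```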
@@ -310,7 +345,8 @@ class TanhNormal(FasterTransformedDistribution):
         min (torch.Tensor or number, optional): minimum value of the distribution. Default is -1.0;
         max (torch.Tensor or number, optional): maximum value of the distribution. Default is 1.0;
         event_dims (int, optional): number of dimensions describing the action.
-            Default is 1;
+            Default is 1. Setting ``event_dims`` to ``0`` will result in a log-probability that has the same shape
+            as the input, ``1`` will reduce (sum over) the last dimension, ``2`` the last two, etc.
         tanh_loc (bool, optional): if ``True``, the above formula is used for the location scaling, otherwise the raw
             value is kept. Default is ``False``;
     """
@@ -322,36 +358,54 @@ class TanhNormal(FasterTransformedDistribution):
 
     num_params = 2
 
+    def _warn_minmax(self):
+        warnings.warn(
+            f"the min / max keyword arguments are deprecated in favor of low / high in {type(self).__name__} "
+            f"and will be removed entirely in v0.6.",
+            DeprecationWarning,
+        )
+
     def __init__(
         self,
         loc: torch.Tensor,
         scale: torch.Tensor,
         upscale: Union[torch.Tensor, Number] = 5.0,
-        min: Union[torch.Tensor, Number] = -1.0,
-        max: Union[torch.Tensor, Number] = 1.0,
-        event_dims: int = 1,
+        low: Union[torch.Tensor, Number] = -1.0,
+        high: Union[torch.Tensor, Number] = 1.0,
+        event_dims: int | None = None,
         tanh_loc: bool = False,
+        **kwargs,
     ):
-        err_msg = "TanhNormal max values must be strictly greater than min values"
-        if isinstance(max, torch.Tensor) or isinstance(min, torch.Tensor):
-            if not (max > min).all():
+        if "max" in kwargs:
+            self._warn_minmax()
+            high = kwargs.pop("max")
+        if "min" in kwargs:
+            self._warn_minmax()
+            low = kwargs.pop("min")
+
+        if not isinstance(loc, torch.Tensor):
+            loc = torch.as_tensor(loc, dtype=torch.get_default_dtype())
+        if not isinstance(scale, torch.Tensor):
+            scale = torch.as_tensor(scale, dtype=torch.get_default_dtype())
+        if event_dims is None:
+            event_dims = min(1, loc.ndim)
+
+        err_msg = "TanhNormal high values must be strictly greater than low values"
+        if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor):
+            if not (high > low).all():
                 raise RuntimeError(err_msg)
-        elif isinstance(max, Number) and isinstance(min, Number):
-            if not max > min:
+        elif isinstance(high, Number) and isinstance(low, Number):
+            if not high > low:
                 raise RuntimeError(err_msg)
         else:
-            if not all(max > min):
+            if not all(high > low):
                 raise RuntimeError(err_msg)
 
-        if isinstance(max, torch.Tensor):
-            self.non_trivial_max = (max != 1.0).any()
-        else:
-            self.non_trivial_max = max != 1.0
+        high = torch.as_tensor(high, device=loc.device)
+        low = torch.as_tensor(low, device=loc.device)
+        self.non_trivial_max = (high != 1.0).any()
 
-        if isinstance(min, torch.Tensor):
-            self.non_trivial_min = (min != -1.0).any()
-        else:
-            self.non_trivial_min = min != -1.0
+        self.non_trivial_min = (low != -1.0).any()
 
         self.tanh_loc = tanh_loc
         self._event_dims = event_dims
@@ -363,51 +417,124 @@ def __init__(
             else upscale.to(self.device)
         )
 
-        if isinstance(max, torch.Tensor):
-            max = max.to(loc.device)
-        if isinstance(min, torch.Tensor):
-            min = min.to(loc.device)
-        self.min = min
-        self.max = max
+        if isinstance(high, torch.Tensor):
+            high = high.to(loc.device)
+        if isinstance(low, torch.Tensor):
+            low = low.to(loc.device)
+        self.low = low
+        self.high = high
 
         t = SafeTanhTransform()
         if self.non_trivial_max or self.non_trivial_min:
             t = D.ComposeTransform(
                 [
                     t,
-                    D.AffineTransform(loc=(max + min) / 2, scale=(max - min) / 2),
+                    D.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2),
                 ]
             )
         self._t = t
 
         self.update(loc, scale)
 
+    @property
+    def min(self):
+        self._warn_minmax()
+        return self.low
+
+    @property
+    def max(self):
+        self._warn_minmax()
+        return self.high
+
     def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
         if self.tanh_loc:
             loc = (loc / self.upscale).tanh() * self.upscale
-        if self.non_trivial_max or self.non_trivial_min:
-            loc = loc + (self.max - self.min) / 2 + self.min
+        # loc must be rescaled if tanh_loc
+        if self.non_trivial_max or self.non_trivial_min:
+            loc = loc + (self.high - self.low) / 2 + self.low
         self.loc = loc
         self.scale = scale
 
         if (
             hasattr(self, "base_dist")
-            and (self.base_dist.base_dist.loc.shape == self.loc.shape)
-            and (self.base_dist.base_dist.scale.shape == self.scale.shape)
+            and (self.root_dist.loc.shape == self.loc.shape)
+            and (self.root_dist.scale.shape == self.scale.shape)
         ):
-            self.base_dist.base_dist.loc = self.loc
-            self.base_dist.base_dist.scale = self.scale
+            self.root_dist.loc = self.loc
+            self.root_dist.scale = self.scale
         else:
-            base = D.Independent(D.Normal(self.loc, self.scale), self._event_dims)
-            super().__init__(base, self._t)
+            if self._event_dims > 0:
+                base = D.Independent(D.Normal(self.loc, self.scale), self._event_dims)
+                super().__init__(base, self._t)
+            else:
+                base = D.Normal(self.loc, self.scale)
+                super().__init__(base, self._t)
+
+    @property
+    def root_dist(self):
+        bd = self
+        while hasattr(bd, "base_dist"):
+            bd = bd.base_dist
+        return bd
 
     @property
     def mode(self):
-        m = self.base_dist.base_dist.mean
+        warnings.warn(
+            "This computation of the mode is based on the first-order Taylor expansion "
+            "of the transform around the normal mean value, which can be inaccurate. "
+            "To use a more stable implementation of the mode, use the dist.get_mode() method instead. "
+            "This implementation will be removed in v0.6.",
+            category=DeprecationWarning,
+        )
+        return self.deterministic_sample
+
+    @property
+    def deterministic_sample(self):
+        m = self.root_dist.mean
         for t in self.transforms:
             m = t(m)
         return m
 
+    @torch.enable_grad()
+    def get_mode(self):
+        """Computes an estimation of the mode using the Adam optimizer."""
+        # Get starting point
+        m = self.sample((1000,)).mean(0)
+        m = torch.nn.Parameter(m.clamp(self.low, self.high).detach())
+        optim = torch.optim.Adam((m,), lr=1e-2)
+        self_copy = type(self)(
+            loc=self.loc.detach(),
+            scale=self.scale.detach(),
+            low=self.low.detach(),
+            high=self.high.detach(),
+            event_dims=self._event_dims,
+            upscale=self.upscale,
+            tanh_loc=False,
+        )
+        for _ in range(200):
+            lp = -self_copy.log_prob(m)
+            lp.mean().backward()
+            mc = m.clone().detach()
+            m.grad.clamp_max_(1)
+            optim.step()
+            optim.zero_grad()
+            m.data.clamp_(self_copy.low, self_copy.high)
+            nans = m.isnan()
+            if nans.any():
+                m.data = torch.where(nans, mc, m.data)
+            if (m - mc).norm() < 1e-3:
+                break
+        return m.detach()
+
+    @property
+    def mean(self):
+        raise NotImplementedError(
+            f"{type(self).__name__} does not have a closed-form formula for the average. "
+            "An estimate of this value can be computed using dist.sample((N,)).mean(dim=0), "
+            "where N is a large number of samples."
+        )
+
 
 def uniform_sample_tanhnormal(dist: TanhNormal, size=None) -> torch.Tensor:
     """Defines what uniform sampling looks like for a TanhNormal distribution.
", + DeprecationWarning, + ) + def __init__( self, param: torch.Tensor, - min: Union[torch.Tensor, float] = -1.0, - max: Union[torch.Tensor, float] = 1.0, + low: Union[torch.Tensor, float] = -1.0, + high: Union[torch.Tensor, float] = 1.0, event_dims: int = 1, atol: float = 1e-6, rtol: float = 1e-6, **kwargs, ): - minmax_msg = "max value has been found to be equal or less than min value" - if isinstance(max, torch.Tensor) or isinstance(min, torch.Tensor): - if not (max > min).all(): + if "max" in kwargs: + self._warn_minmax() + high = kwargs.pop("max") + if "min" in kwargs: + self._warn_minmax() + low = kwargs.pop("min") + + minmax_msg = "high value has been found to be equal or less than low value" + if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): + if not (high > low).all(): raise ValueError(minmax_msg) - elif isinstance(max, Number) and isinstance(min, Number): - if max <= min: + elif isinstance(high, Number) and isinstance(low, Number): + if high <= low: raise ValueError(minmax_msg) else: - if not all(max > min): + if not all(high > low): raise ValueError(minmax_msg) t = SafeTanhTransform() - non_trivial_min = (isinstance(min, torch.Tensor) and (min != -1.0).any()) or ( - not isinstance(min, torch.Tensor) and min != -1.0 + non_trivial_min = (isinstance(low, torch.Tensor) and (low != -1.0).any()) or ( + not isinstance(low, torch.Tensor) and low != -1.0 ) - non_trivial_max = (isinstance(max, torch.Tensor) and (max != 1.0).any()) or ( - not isinstance(max, torch.Tensor) and max != 1.0 + non_trivial_max = (isinstance(high, torch.Tensor) and (high != 1.0).any()) or ( + not isinstance(high, torch.Tensor) and high != 1.0 ) self.non_trivial = non_trivial_min or non_trivial_max - self.min = _cast_device(min, param.device) - self.max = _cast_device(max, param.device) + self.low = _cast_device(low, param.device) + self.high = _cast_device(high, param.device) loc = self.update(param) if self.non_trivial: @@ -559,7 +704,7 @@ def __init__( [ t, D.AffineTransform( - loc=(self.max + self.min) / 2, scale=(self.max - self.min) / 2 + loc=(self.high + self.low) / 2, scale=(self.high - self.low) / 2 ), ] ) @@ -576,12 +721,22 @@ def __init__( super().__init__(base, t) + @property + def min(self): + self._warn_minmax() + return self.low + + @property + def max(self): + self._warn_minmax() + return self.high + def update(self, net_output: torch.Tensor) -> Optional[torch.Tensor]: loc = net_output if self.non_trivial: device = loc.device - shift = _cast_device(self.max - self.min, device) - loc = loc + shift / 2 + _cast_device(self.min, device) + shift = _cast_device(self.high - self.low, device) + loc = loc + shift / 2 + _cast_device(self.low, device) if hasattr(self, "base_dist"): self.base_dist.update(loc) else: @@ -594,6 +749,10 @@ def mode(self) -> torch.Tensor: mode = t(mode) return mode + @property + def deterministic_sample(self): + return self.mode + @property def mean(self) -> torch.Tensor: raise AttributeError("TanhDelta mean has not analytical form.") diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py index d73457b2261..c48d8168887 100644 --- a/torchrl/modules/distributions/discrete.py +++ b/torchrl/modules/distributions/discrete.py @@ -106,6 +106,10 @@ def mode(self) -> torch.Tensor: else: return (self.probs == self.probs.max(-1, True)[0]).to(torch.long) + @property + def deterministic_sample(self): + return self.mode + @_one_hot_wrapper(D.Categorical) def sample( self, sample_shape: Optional[Union[torch.Size, Sequence]] = None diff 
diff --git a/torchrl/modules/distributions/truncated_normal.py b/torchrl/modules/distributions/truncated_normal.py
index ea8ee0b9fe2..1350aeb2bc3 100644
--- a/torchrl/modules/distributions/truncated_normal.py
+++ b/torchrl/modules/distributions/truncated_normal.py
@@ -85,6 +85,10 @@ def support(self):
     def mean(self):
         return self._mean
 
+    @property
+    def deterministic_sample(self):
+        return self.mean
+
     @property
     def variance(self):
         return self._variance
diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py
index 59819d940d0..2ec51b46559 100644
--- a/torchrl/modules/models/exploration.py
+++ b/torchrl/modules/models/exploration.py
@@ -370,7 +370,7 @@ def forward(self, mu, state, _eps_gSDE):
             _err_msg = f"noise and state are expected to have matching batch size, got shapes {_eps_gSDE.shape} and {state.shape}"
             raise RuntimeError(_err_msg)
 
-        if _eps_gSDE is None and exploration_type() == ExplorationType.MODE:
+        if _eps_gSDE is None and exploration_type() != ExplorationType.RANDOM:
             # noise is irrelevant in with no exploration
             _eps_gSDE = torch.zeros(
                 *state.shape[:-1], *sigma.shape, device=sigma.device, dtype=sigma.dtype
@@ -391,7 +391,11 @@ def forward(self, mu, state, _eps_gSDE):
 
         if exploration_type() in (ExplorationType.RANDOM,):
             action = mu + eps
-        elif exploration_type() in (ExplorationType.MODE,):
+        elif exploration_type() in (
+            ExplorationType.MODE,
+            ExplorationType.MEAN,
+            ExplorationType.DETERMINISTIC,
+        ):
             action = mu
         else:
             raise RuntimeError(_err_explo)
diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py
index d2ea7af0af8..b6a91db7bb9 100644
--- a/torchrl/modules/tensordict_module/actors.py
+++ b/torchrl/modules/tensordict_module/actors.py
@@ -156,10 +156,12 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
             method. Default is ``False``.
         default_interaction_type (str, optional): keyword-only argument.
             Default method to be used to retrieve
-            the output value. Should be one of: 'InteractionType.MODE',
+            the output value. Should be one of: 'InteractionType.MODE', 'InteractionType.DETERMINISTIC',
             'InteractionType.MEDIAN', 'InteractionType.MEAN' or
             'InteractionType.RANDOM' (in which case the value is sampled
-            randomly from the distribution). Defaults to is 'InteractionType.MODE'.
+            randomly from the distribution).
+            TorchRL's ``ExplorationType`` class is a proxy to ``InteractionType``.
+            Defaults to 'InteractionType.DETERMINISTIC'.
 
             .. note:: When a sample is drawn, the :class:`ProbabilisticActor` instance will
                 first look for the interaction mode dictated by the
diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py
index f7af8a477cb..725323e1a28 100644
--- a/torchrl/modules/tensordict_module/probabilistic.py
+++ b/torchrl/modules/tensordict_module/probabilistic.py
@@ -105,7 +105,7 @@ def __init__(
         spec: Optional[TensorSpec] = None,
         safe: bool = False,
         default_interaction_mode: str = None,
-        default_interaction_type: str = InteractionType.MODE,
+        default_interaction_type: str = InteractionType.DETERMINISTIC,
         distribution_class: Type = Delta,
         distribution_kwargs: Optional[dict] = None,
         return_log_prob: bool = False,
diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py
index 4f41c884244..cbfc218327d 100644
--- a/torchrl/objectives/common.py
+++ b/torchrl/objectives/common.py
@@ -6,6 +6,7 @@
 from __future__ import annotations
 
 import abc
+import functools
 import warnings
 from dataclasses import dataclass
 from typing import Iterator, List, Optional, Tuple
@@ -31,10 +32,19 @@ def _updater_check_forward_prehook(module, *args, **kwargs):
     )
 
 
+def _forward_wrapper(func):
+    @functools.wraps(func)
+    def new_forward(self, *args, **kwargs):
+        with set_exploration_type(self.deterministic_sampling_mode):
+            return func(self, *args, **kwargs)
+
+    return new_forward
+
+
 class _LossMeta(abc.ABCMeta):
     def __init__(cls, name, bases, attr_dict):
         super().__init__(name, bases, attr_dict)
-        cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward)
+        cls.forward = _forward_wrapper(cls.forward)
 
 
 class LossModule(TensorDictModuleBase, metaclass=_LossMeta):
@@ -55,7 +65,7 @@ class LossModule(TensorDictModuleBase, metaclass=_LossMeta):
     The value estimator can be changed using the :meth:`~.make_value_estimator` method.
 
     By default, the forward method is always decorated with a
-    gh :class:`torchrl.envs.ExplorationType.MODE`
+    :class:`torchrl.envs.ExplorationType.DETERMINISTIC` exploration context.
 
     To utilize the ability configuring the tensordict keys via
     :meth:`~.set_keys()` a subclass must define an _AcceptedKeys dataclass.
@@ -75,6 +85,15 @@ class LossModule(TensorDictModuleBase, metaclass=_LossMeta):
         >>>
         >>> loss = MyLoss()
         >>> loss.set_keys(action="action2")
+
+    .. note:: When a policy that is wrapped or augmented with an exploration module is passed
+        to the loss, we want to deactivate the exploration through ``set_exploration_type(<exploration>)``,
+        where ``<exploration>`` is either ``ExplorationType.MEAN``, ``ExplorationType.MODE`` or
+        ``ExplorationType.DETERMINISTIC``. The default value is ``DETERMINISTIC`` and it is set
+        through the ``deterministic_sampling_mode`` loss attribute. If another
+        exploration mode is required (or if ``DETERMINISTIC`` is not available), one can
+        change the value of this attribute, which will change the mode.
+
     """
 
     @dataclass
@@ -89,6 +108,9 @@ class _AcceptedKeys:
 
     _vmap_randomness = None
     default_value_estimator: ValueEstimators = None
+
+    deterministic_sampling_mode: ExplorationType = ExplorationType.DETERMINISTIC
+
     SEP = "."
     TARGET_NET_WARNING = (
         "No target network updater has been associated "
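Mirroring ``test_loss_exploration`` above, a minimal sketch of how ``deterministic_sampling_mode`` drives the exploration type inside ``forward`` (``MyLoss`` is a hypothetical subclass):

```python
from torchrl.envs.utils import ExplorationType, exploration_type
from torchrl.objectives import LossModule

class MyLoss(LossModule):
    def forward(self, td):
        # The metaclass wraps forward() so that it always runs under
        # self.deterministic_sampling_mode.
        assert exploration_type() == self.deterministic_sampling_mode
        return td

loss = MyLoss()
assert loss.deterministic_sampling_mode == ExplorationType.DETERMINISTIC
loss(None)

# Losses without a DETERMINISTIC interaction can fall back to another mode:
loss.deterministic_sampling_mode = ExplorationType.MEAN
loss(None)
```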
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index 1b35c660f2f..9594895394a 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -647,7 +647,7 @@ def _qvalue_v1_loss(
         self, tensordict: TensorDictBase
     ) -> Tuple[Tensor, Dict[str, Tensor]]:
         target_params = self._cached_target_params_actor_value
-        with set_exploration_type(ExplorationType.MODE):
+        with set_exploration_type(self.deterministic_sampling_mode):
             target_value = self.value_estimator.value_estimate(
                 tensordict, target_params=target_params
             ).squeeze(-1)
diff --git a/torchrl/trainers/helpers/models.py b/torchrl/trainers/helpers/models.py
index 0a3cea40b36..0c9ec92cff4 100644
--- a/torchrl/trainers/helpers/models.py
+++ b/torchrl/trainers/helpers/models.py
@@ -450,7 +450,7 @@ def _dreamer_make_actor_real(
         SafeProbabilisticModule(
             in_keys=["loc", "scale"],
             out_keys=[action_key],
-            default_interaction_type=InteractionType.MODE,
+            default_interaction_type=InteractionType.DETERMINISTIC,
             distribution_class=TanhNormal,
             distribution_kwargs={"tanh_loc": True},
             spec=CompositeSpec(
diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py
index ccd9bb23bb3..247d039eb1e 100644
--- a/torchrl/trainers/trainers.py
+++ b/torchrl/trainers/trainers.py
@@ -1146,12 +1146,12 @@ class Recorder(TrainerHookBase):
             Given that this instance is supposed to both explore and render
             the performance of the policy, it should be
             possible to turn off the explorative behaviour by calling the
-            `set_exploration_type(ExplorationType.MODE)` context manager.
+            `set_exploration_type(ExplorationType.DETERMINISTIC)` context manager.
         environment (EnvBase): An environment instance to be used
             for testing.
         exploration_type (ExplorationType, optional): exploration mode to use for the
             policy. By default, no exploration is used and the value used is
-            ExplorationType.MODE. Set to ExplorationType.RANDOM to enable exploration
+            ``ExplorationType.DETERMINISTIC``. Set to ``ExplorationType.RANDOM`` to enable exploration.
         log_keys (sequence of str or tuples or str, optional): keys to read in the tensordict
             for logging. Defaults to ``[("next", "reward")]``.
         out_keys (Dict[str, str], optional): a dictionary mapping the ``log_keys``