diff --git a/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh b/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh index 11921b44821..9622984a421 100755 --- a/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh +++ b/.github/unittest/linux_libs/scripts_gym/batch_scripts.sh @@ -6,6 +6,7 @@ DIR="$(cd "$(dirname "$0")" && pwd)" set -e +set -v eval "$(./conda/bin/conda shell.bash hook)" conda activate ./env diff --git a/.github/unittest/linux_libs/scripts_gym/install.sh b/.github/unittest/linux_libs/scripts_gym/install.sh index d3eac779861..a66fe5fddd1 100755 --- a/.github/unittest/linux_libs/scripts_gym/install.sh +++ b/.github/unittest/linux_libs/scripts_gym/install.sh @@ -7,6 +7,7 @@ unset PYTORCH_VERSION apt-get update && apt-get install -y git wget gcc g++ set -e +set -v eval "$(./conda/bin/conda shell.bash hook)" conda activate ./env @@ -39,7 +40,7 @@ printf "Installing PyTorch with %s\n" "${CU_VERSION}" if [ "${CU_VERSION:-}" == cpu ] ; then conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch -y else - conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia -y + conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 numpy==1.26 numpy-base==1.26 -c pytorch -c nvidia -y fi # Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has diff --git a/.github/unittest/linux_libs/scripts_habitat/setup_env.sh b/.github/unittest/linux_libs/scripts_habitat/setup_env.sh index fc182a669ea..6ad970c3f47 100755 --- a/.github/unittest/linux_libs/scripts_habitat/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_habitat/setup_env.sh @@ -67,8 +67,9 @@ pip install pip --upgrade conda env update --file "${this_dir}/environment.yml" --prune -#conda install habitat-sim withbullet headless -c conda-forge -c aihabitat -y conda install habitat-sim withbullet headless -c conda-forge -c aihabitat -y -conda run python -m pip install git+https://github.com/facebookresearch/habitat-lab.git@stable#subdirectory=habitat-lab -#conda run python -m pip install git+https://github.com/facebookresearch/habitat-lab.git#subdirectory=habitat-baselines +git clone https://github.com/facebookresearch/habitat-lab.git +cd habitat-lab +pip3 install -e habitat-lab +pip3 install -e habitat-baselines # install habitat_baselines conda run python -m pip install "gym[atari,accept-rom-license]" pygame diff --git a/.github/unittest/linux_libs/scripts_minari/environment.yml b/.github/unittest/linux_libs/scripts_minari/environment.yml index 27963a42a24..23aedb4cc23 100644 --- a/.github/unittest/linux_libs/scripts_minari/environment.yml +++ b/.github/unittest/linux_libs/scripts_minari/environment.yml @@ -17,4 +17,4 @@ dependencies: - pyyaml - scipy - hydra-core - - minari + - minari[gcs] diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh b/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh index 58a33cd43f4..7b7c857c37a 100755 --- a/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh +++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh @@ -39,7 +39,7 @@ printf "Installing PyTorch with %s\n" "${CU_VERSION}" if [ "${CU_VERSION:-}" == cpu ] ; then conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch -y else - conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia -y + conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 
pytorch-cuda=11.8 numpy==1.26 numpy-base==1.26 -c pytorch -c nvidia -y fi # Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has diff --git a/.github/workflows/build-wheels-aarch64-linux.yml b/.github/workflows/build-wheels-aarch64-linux.yml new file mode 100644 index 00000000000..63818f07365 --- /dev/null +++ b/.github/workflows/build-wheels-aarch64-linux.yml @@ -0,0 +1,51 @@ +name: Build Aarch64 Linux Wheels + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + tags: + # NOTE: Binary build pipelines should only get triggered on release candidate builds + # Release candidate tags look like: v1.11.0-rc1 + - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + workflow_dispatch: + +permissions: + id-token: write + contents: read + +jobs: + generate-matrix: + uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main + with: + package-type: wheel + os: linux-aarch64 + test-infra-repository: pytorch/test-infra + test-infra-ref: main + with-cuda: disable + build: + needs: generate-matrix + strategy: + fail-fast: false + matrix: + include: + - repository: pytorch/rl + smoke-test-script: test/smoke_test.py + package-name: torchrl + name: pytorch/rl + uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main + with: + repository: ${{ matrix.repository }} + ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main + build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + package-name: ${{ matrix.package-name }} + smoke-test-script: ${{ matrix.smoke-test-script }} + trigger-event: ${{ github.event_name }} + env-var-script: .github/scripts/td_script.sh + architecture: aarch64 + setup-miniconda: false diff --git a/.github/workflows/build-wheels-windows.yml b/.github/workflows/build-wheels-windows.yml index 9f2666ccdbf..1d47449568c 100644 --- a/.github/workflows/build-wheels-windows.yml +++ b/.github/workflows/build-wheels-windows.yml @@ -32,6 +32,8 @@ jobs: matrix: include: - repository: pytorch/rl + pre-script: .github/scripts/td_script.sh + env-script: .github/scripts/version_script.bat post-script: "python packaging/wheel/relocate.py" smoke-test-script: test/smoke_test.py package-name: torchrl @@ -43,8 +45,9 @@ jobs: test-infra-repository: pytorch/test-infra test-infra-ref: main build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + pre-script: ${{ matrix.pre-script }} + env-script: ${{ matrix.env-script }} + post-script: ${{ matrix.post-script }} package-name: ${{ matrix.package-name }} smoke-test-script: ${{ matrix.smoke-test-script }} trigger-event: ${{ github.event_name }} - pre-script: .github/scripts/td_script.sh - env-script: .github/scripts/version_script.bat diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index c0edc906e78..f5fa29ab7ca 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -50,7 +50,7 @@ jobs: conda activate "${env_dir}" # 2. upgrade pip, ninja and packaging - apt-get install python3.9 python3-pip -y + # apt-get install python3.9 python3-pip -y python3 -m pip install --upgrade pip python3 -m pip install setuptools ninja packaging -U diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 62cf1dedf35..2d6a6344970 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -57,16 +57,21 @@ projected (in a L1-manner) into the desired domain. 
SafeSequential TanhModule -Exploration wrappers -~~~~~~~~~~~~~~~~~~~~ +Exploration wrappers and modules +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -To efficiently explore the environment, TorchRL proposes a series of wrappers +To efficiently explore the environment, TorchRL proposes a series of modules that will override the action sampled by the policy by a noisier version. Their behavior is controlled by :func:`~torchrl.envs.utils.exploration_mode`: if the exploration is set to ``"random"``, the exploration is active. In all other cases, the action written in the tensordict is simply the network output. -.. currentmodule:: torchrl.modules.tensordict_module +.. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule` + uses the ``train``/``eval`` mode to comply with the regular `Dropout` API in PyTorch. + The :func:`~torchrl.envs.utils.set_exploration_mode` context manager will have no effect on + this module. + +.. currentmodule:: torchrl.modules .. autosummary:: :toctree: generated/ @@ -74,6 +79,7 @@ other cases, the action written in the tensordict is simply the network output. AdditiveGaussianModule AdditiveGaussianWrapper + ConsistentDropoutModule EGreedyModule EGreedyWrapper OrnsteinUhlenbeckProcessModule @@ -438,12 +444,13 @@ Regular modules :toctree: generated/ :template: rl_template_noinherit.rst - MLP - ConvNet + BatchRenorm1d + ConsistentDropout Conv3dNet - SqueezeLayer + ConvNet + MLP Squeeze2dLayer - BatchRenorm1d + SqueezeLayer Algorithm-specific modules ~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/test/test_cost.py b/test/test_cost.py index 2af5a88f9fa..ab95c55ef83 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -13,8 +13,8 @@ from packaging import version as pack_version from tensordict._C import unravel_keys - from tensordict.nn import ( + CompositeDistribution, InteractionType, ProbabilisticTensorDictModule, ProbabilisticTensorDictModule as ProbMod, @@ -25,7 +25,6 @@ TensorDictSequential as Seq, ) from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type - from torchrl.modules.models import QMixer _has_functorch = True @@ -7544,21 +7543,45 @@ def _create_mock_actor( obs_dim=3, action_dim=4, device="cpu", + action_key="action", observation_key="observation", sample_log_prob_key="sample_log_prob", + composite_action_dist=False, ): # Actor action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) + if composite_action_dist: + action_spec = Composite({action_key: {"action1": action_spec}}) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) + if composite_action_dist: + distribution_class = functools.partial( + CompositeDistribution, + distribution_map={ + "action1": TanhNormal, + }, + name_map={ + "action1": (action_key, "action1"), + }, + log_prob_key=sample_log_prob_key, + ) + module_out_keys = [ + ("params", "action1", "loc"), + ("params", "action1", "scale"), + ] + actor_in_keys = ["params"] + else: + distribution_class = TanhNormal + module_out_keys = actor_in_keys = ["loc", "scale"] module = TensorDictModule( - net, in_keys=[observation_key], out_keys=["loc", "scale"] + net, in_keys=[observation_key], out_keys=module_out_keys ) actor = ProbabilisticActor( module=module, - distribution_class=TanhNormal, - in_keys=["loc", "scale"], + distribution_class=distribution_class, + in_keys=actor_in_keys, + out_keys=[action_key], spec=action_spec, return_log_prob=True, log_prob_key=sample_log_prob_key, @@ -7582,22 +7605,51 @@ def _create_mock_value( ) return 
value.to(device) - def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + def _create_mock_actor_value( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + composite_action_dist=False, + sample_log_prob_key="sample_log_prob", + ): # Actor action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) + if composite_action_dist: + action_spec = Composite({"action": {"action1": action_spec}}) base_layer = nn.Linear(obs_dim, 5) net = nn.Sequential( base_layer, nn.Linear(5, 2 * action_dim), NormalParamExtractor() ) + if composite_action_dist: + distribution_class = functools.partial( + CompositeDistribution, + distribution_map={ + "action1": TanhNormal, + }, + name_map={ + "action1": ("action", "action1"), + }, + log_prob_key=sample_log_prob_key, + ) + module_out_keys = [ + ("params", "action1", "loc"), + ("params", "action1", "scale"), + ] + actor_in_keys = ["params"] + else: + distribution_class = TanhNormal + module_out_keys = actor_in_keys = ["loc", "scale"] module = TensorDictModule( - net, in_keys=["observation"], out_keys=["loc", "scale"] + net, in_keys=["observation"], out_keys=module_out_keys ) actor = ProbabilisticActor( module=module, - distribution_class=TanhNormal, - in_keys=["loc", "scale"], + distribution_class=distribution_class, + in_keys=actor_in_keys, spec=action_spec, return_log_prob=True, ) @@ -7609,22 +7661,49 @@ def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu return actor.to(device), value.to(device) def _create_mock_actor_value_shared( - self, batch=2, obs_dim=3, action_dim=4, device="cpu" + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + composite_action_dist=False, + sample_log_prob_key="sample_log_prob", ): # Actor action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) + if composite_action_dist: + action_spec = Composite({"action": {"action1": action_spec}}) base_layer = nn.Linear(obs_dim, 5) common = TensorDictModule( base_layer, in_keys=["observation"], out_keys=["hidden"] ) net = nn.Sequential(nn.Linear(5, 2 * action_dim), NormalParamExtractor()) - module = TensorDictModule(net, in_keys=["hidden"], out_keys=["loc", "scale"]) + if composite_action_dist: + distribution_class = functools.partial( + CompositeDistribution, + distribution_map={ + "action1": TanhNormal, + }, + name_map={ + "action1": ("action", "action1"), + }, + log_prob_key=sample_log_prob_key, + ) + module_out_keys = [ + ("params", "action1", "loc"), + ("params", "action1", "scale"), + ] + actor_in_keys = ["params"] + else: + distribution_class = TanhNormal + module_out_keys = actor_in_keys = ["loc", "scale"] + module = TensorDictModule(net, in_keys=["hidden"], out_keys=module_out_keys) actor_head = ProbabilisticActor( module=module, - distribution_class=TanhNormal, - in_keys=["loc", "scale"], + distribution_class=distribution_class, + in_keys=actor_in_keys, spec=action_spec, return_log_prob=True, ) @@ -7654,6 +7733,7 @@ def _create_mock_data_ppo( done_key="done", terminated_key="terminated", sample_log_prob_key="sample_log_prob", + composite_action_dist=False, ): # create a tensordict obs = torch.randn(batch, obs_dim, device=device) @@ -7679,13 +7759,17 @@ def _create_mock_data_ppo( terminated_key: terminated, reward_key: reward, }, - action_key: action, + action_key: {"action1": action} if composite_action_dist else action, sample_log_prob_key: torch.randn_like(action[..., 1]) / 10, - loc_key: loc, - scale_key: scale, }, device=device, ) + if 
composite_action_dist: + td[("params", "action1", loc_key)] = loc + td[("params", "action1", scale_key)] = scale + else: + td[loc_key] = loc + td[scale_key] = scale return td def _create_seq_mock_data_ppo( @@ -7698,6 +7782,7 @@ def _create_seq_mock_data_ppo( device="cpu", sample_log_prob_key="sample_log_prob", action_key="action", + composite_action_dist=False, ): # create a tensordict total_obs = torch.randn(batch, T + 1, obs_dim, device=device) @@ -7713,8 +7798,11 @@ def _create_seq_mock_data_ppo( done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = torch.ones(batch, T, dtype=torch.bool, device=device) + action = action.masked_fill_(~mask.unsqueeze(-1), 0.0) params_mean = torch.randn_like(action) / 10 params_scale = torch.rand_like(action) / 10 + loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0) + scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0) td = TensorDict( batch_size=(batch, T), source={ @@ -7726,16 +7814,21 @@ def _create_seq_mock_data_ppo( "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, - action_key: action.masked_fill_(~mask.unsqueeze(-1), 0.0), + action_key: {"action1": action} if composite_action_dist else action, sample_log_prob_key: ( torch.randn_like(action[..., 1]) / 10 ).masked_fill_(~mask, 0.0), - "loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0), - "scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0), }, device=device, names=[None, "time"], ) + if composite_action_dist: + td[("params", "action1", "loc")] = loc + td[("params", "action1", "scale")] = scale + else: + td["loc"] = loc + td["scale"] = scale + return td @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @@ -7744,6 +7837,7 @@ def _create_seq_mock_data_ppo( @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) @pytest.mark.parametrize("functional", [True, False]) + @pytest.mark.parametrize("composite_action_dist", [True, False]) def test_ppo( self, loss_class, @@ -7752,11 +7846,16 @@ def test_ppo( advantage, td_est, functional, + composite_action_dist, ): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_ppo(device=device) + td = self._create_seq_mock_data_ppo( + device=device, composite_action_dist=composite_action_dist + ) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) if advantage == "gae": advantage = GAE( @@ -7796,7 +7895,10 @@ def test_ppo( loss = loss_fn(td) if isinstance(loss_fn, KLPENPPOLoss): - kl = loss.pop("kl") + if composite_action_dist: + kl = loss.pop("kl_approx") + else: + kl = loss.pop("kl") assert (kl != 0).any() loss_critic = loss["loss_critic"] @@ -7833,10 +7935,15 @@ def test_ppo( @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True,)) @pytest.mark.parametrize("device", get_default_devices()) - def test_ppo_state_dict(self, loss_class, device, gradient_mode): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_ppo_state_dict( + self, loss_class, device, gradient_mode, composite_action_dist + ): torch.manual_seed(self.seed) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = 
self._create_mock_value(device=device) loss_fn = loss_class(actor, value, loss_critic_type="l2") sd = loss_fn.state_dict() @@ -7846,11 +7953,16 @@ def test_ppo_state_dict(self, loss_class, device, gradient_mode): @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) - def test_ppo_shared(self, loss_class, device, advantage): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_ppo(device=device) + td = self._create_seq_mock_data_ppo( + device=device, composite_action_dist=composite_action_dist + ) - actor, value = self._create_mock_actor_value(device=device) + actor, value = self._create_mock_actor_value( + device=device, composite_action_dist=composite_action_dist + ) if advantage == "gae": advantage = GAE( gamma=0.9, @@ -7932,18 +8044,24 @@ def test_ppo_shared(self, loss_class, device, advantage): ) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("separate_losses", [True, False]) + @pytest.mark.parametrize("composite_action_dist", [True, False]) def test_ppo_shared_seq( self, loss_class, device, advantage, separate_losses, + composite_action_dist, ): """Tests PPO with shared module with and without passing twice across the common module.""" torch.manual_seed(self.seed) - td = self._create_seq_mock_data_ppo(device=device) + td = self._create_seq_mock_data_ppo( + device=device, composite_action_dist=composite_action_dist + ) - model, actor, value = self._create_mock_actor_value_shared(device=device) + model, actor, value = self._create_mock_actor_value_shared( + device=device, composite_action_dist=composite_action_dist + ) value2 = value[-1] # prune the common module if advantage == "gae": advantage = GAE( @@ -8001,8 +8119,20 @@ def test_ppo_shared_seq( grad2 = TensorDict(dict(model.named_parameters()), []).apply( lambda x: x.grad.clone() ) - assert_allclose_td(loss, loss2) - assert_allclose_td(grad, grad2) + if composite_action_dist and loss_class is KLPENPPOLoss: + # KL computation for composite dist is based on randomly + # sampled data, thus will not be the same. + # Similarly, objective loss depends on the KL, so it will + # not be the same either. + # Finally, gradients will be different too. 
+ loss.pop("kl", None) + loss2.pop("kl", None) + loss.pop("loss_objective", None) + loss2.pop("loss_objective", None) + assert_allclose_td(loss, loss2) + else: + assert_allclose_td(loss, loss2) + assert_allclose_td(grad, grad2) model.zero_grad() @pytest.mark.skipif( @@ -8012,11 +8142,18 @@ def test_ppo_shared_seq( @pytest.mark.parametrize("gradient_mode", (True, False)) @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) - def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_ppo_diff( + self, loss_class, device, gradient_mode, advantage, composite_action_dist + ): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_ppo(device=device) + td = self._create_seq_mock_data_ppo( + device=device, composite_action_dist=composite_action_dist + ) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) if advantage == "gae": advantage = GAE( @@ -8105,8 +8242,9 @@ def zero_param(p): ValueEstimators.TDLambda, ], ) - def test_ppo_tensordict_keys(self, loss_class, td_est): - actor = self._create_mock_actor() + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_ppo_tensordict_keys(self, loss_class, td_est, composite_action_dist): + actor = self._create_mock_actor(composite_action_dist=composite_action_dist) value = self._create_mock_value() loss_fn = loss_class(actor, value, loss_critic_type="l2") @@ -8145,7 +8283,10 @@ def test_ppo_tensordict_keys(self, loss_class, td_est): @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) - def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_ppo_tensordict_keys_run( + self, loss_class, advantage, td_est, composite_action_dist + ): """Test PPO loss module with non-default tensordict keys.""" torch.manual_seed(self.seed) gradient_mode = True @@ -8160,9 +8301,12 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): td = self._create_seq_mock_data_ppo( sample_log_prob_key=tensor_keys["sample_log_prob"], action_key=tensor_keys["action"], + composite_action_dist=composite_action_dist, ) actor = self._create_mock_actor( - sample_log_prob_key=tensor_keys["sample_log_prob"] + sample_log_prob_key=tensor_keys["sample_log_prob"], + composite_action_dist=composite_action_dist, + action_key=tensor_keys["action"], ) value = self._create_mock_value(out_keys=[tensor_keys["value"]]) @@ -8253,6 +8397,12 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + @pytest.mark.parametrize( + "composite_action_dist", + [ + False, + ], + ) def test_ppo_notensordict( self, loss_class, @@ -8262,6 +8412,7 @@ def test_ppo_notensordict( reward_key, done_key, terminated_key, + composite_action_dist, ): torch.manual_seed(self.seed) td = self._create_mock_data_ppo( @@ -8271,10 +8422,14 @@ def test_ppo_notensordict( reward_key=reward_key, done_key=done_key, 
terminated_key=terminated_key, + composite_action_dist=composite_action_dist, ) actor = self._create_mock_actor( - observation_key=observation_key, sample_log_prob_key=sample_log_prob_key + observation_key=observation_key, + sample_log_prob_key=sample_log_prob_key, + composite_action_dist=composite_action_dist, + action_key=action_key, ) value = self._create_mock_value(observation_key=observation_key) @@ -8297,7 +8452,9 @@ def test_ppo_notensordict( f"next_{observation_key}": td.get(("next", observation_key)), } if loss_class is KLPENPPOLoss: - kwargs.update({"loc": td.get("loc"), "scale": td.get("scale")}) + loc_key = "params" if composite_action_dist else "loc" + scale_key = "params" if composite_action_dist else "scale" + kwargs.update({loc_key: td.get(loc_key), scale_key: td.get(scale_key)}) td = TensorDict(kwargs, td.batch_size, names=["time"]).unflatten_keys("_") @@ -8310,6 +8467,7 @@ def test_ppo_notensordict( loss_val = loss(**kwargs) torch.manual_seed(self.seed) if beta is not None: + loss.beta = beta.clone() loss_val_td = loss(td) @@ -8337,15 +8495,20 @@ def test_ppo_notensordict( @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) - def test_ppo_reduction(self, reduction, loss_class): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_ppo_reduction(self, reduction, loss_class, composite_action_dist): torch.manual_seed(self.seed) device = ( torch.device("cpu") if torch.cuda.device_count() == 0 else torch.device("cuda") ) - td = self._create_seq_mock_data_ppo(device=device) - actor = self._create_mock_actor(device=device) + td = self._create_seq_mock_data_ppo( + device=device, composite_action_dist=composite_action_dist + ) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) advantage = GAE( gamma=0.9, @@ -8373,10 +8536,17 @@ def test_ppo_reduction(self, reduction, loss_class): @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("clip_value", [True, False, None, 0.5, torch.tensor(0.5)]) - def test_ppo_value_clipping(self, clip_value, loss_class, device): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_ppo_value_clipping( + self, clip_value, loss_class, device, composite_action_dist + ): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_ppo(device=device) - actor = self._create_mock_actor(device=device) + td = self._create_seq_mock_data_ppo( + device=device, composite_action_dist=composite_action_dist + ) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) advantage = GAE( gamma=0.9, @@ -8435,22 +8605,46 @@ def _create_mock_actor( obs_dim=3, action_dim=4, device="cpu", + action_key="action", observation_key="observation", sample_log_prob_key="sample_log_prob", + composite_action_dist=False, ): # Actor action_spec = Bounded( -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) ) + if composite_action_dist: + action_spec = Composite({action_key: {"action1": action_spec}}) net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor()) + if composite_action_dist: + distribution_class = functools.partial( + CompositeDistribution, + distribution_map={ + "action1": TanhNormal, + }, + name_map={ + "action1": 
(action_key, "action1"), + }, + log_prob_key=sample_log_prob_key, + ) + module_out_keys = [ + ("params", "action1", "loc"), + ("params", "action1", "scale"), + ] + actor_in_keys = ["params"] + else: + distribution_class = TanhNormal + module_out_keys = actor_in_keys = ["loc", "scale"] module = TensorDictModule( - net, in_keys=[observation_key], out_keys=["loc", "scale"] + net, in_keys=[observation_key], out_keys=module_out_keys ) actor = ProbabilisticActor( module=module, - in_keys=["loc", "scale"], + in_keys=actor_in_keys, + out_keys=[action_key], spec=action_spec, - distribution_class=TanhNormal, + distribution_class=distribution_class, return_log_prob=True, log_prob_key=sample_log_prob_key, ) @@ -8474,7 +8668,15 @@ def _create_mock_value( return value.to(device) def _create_mock_common_layer_setup( - self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2, T=10 + self, + n_obs=3, + n_act=4, + ncells=4, + batch=2, + n_hidden=2, + T=10, + composite_action_dist=False, + sample_log_prob_key="sample_log_prob", ): common_net = MLP( num_cells=ncells, @@ -8495,10 +8697,11 @@ def _create_mock_common_layer_setup( out_features=1, ) batch = [batch, T] + action = torch.randn(*batch, n_act) td = TensorDict( { "obs": torch.randn(*batch, n_obs), - "action": torch.randn(*batch, n_act), + "action": {"action1": action} if composite_action_dist else action, "sample_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), "terminated": torch.zeros(*batch, 1, dtype=torch.bool), @@ -8513,14 +8716,35 @@ def _create_mock_common_layer_setup( names=[None, "time"], ) common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"]) + + if composite_action_dist: + distribution_class = functools.partial( + CompositeDistribution, + distribution_map={ + "action1": TanhNormal, + }, + name_map={ + "action1": ("action", "action1"), + }, + log_prob_key=sample_log_prob_key, + ) + module_out_keys = [ + ("params", "action1", "loc"), + ("params", "action1", "scale"), + ] + actor_in_keys = ["params"] + else: + distribution_class = TanhNormal + module_out_keys = actor_in_keys = ["loc", "scale"] + actor = ProbSeq( common, Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), - Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), + Mod(NormalParamExtractor(), in_keys=["param"], out_keys=module_out_keys), ProbMod( - in_keys=["loc", "scale"], + in_keys=actor_in_keys, out_keys=["action"], - distribution_class=TanhNormal, + distribution_class=distribution_class, ), ) critic = Seq( @@ -8544,6 +8768,7 @@ def _create_seq_mock_data_a2c( done_key="done", terminated_key="terminated", sample_log_prob_key="sample_log_prob", + composite_action_dist=False, ): # create a tensordict total_obs = torch.randn(batch, T + 1, obs_dim, device=device) @@ -8559,8 +8784,11 @@ def _create_seq_mock_data_a2c( done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) + action = action.masked_fill_(~mask.unsqueeze(-1), 0.0) params_mean = torch.randn_like(action) / 10 params_scale = torch.rand_like(action) / 10 + loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0) + scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0) td = TensorDict( batch_size=(batch, T), source={ @@ -8572,17 +8800,21 @@ def _create_seq_mock_data_a2c( reward_key: reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, - action_key: action.masked_fill_(~mask.unsqueeze(-1), 
0.0), + action_key: {"action1": action} if composite_action_dist else action, sample_log_prob_key: torch.randn_like(action[..., 1]).masked_fill_( ~mask, 0.0 ) / 10, - "loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0), - "scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0), }, device=device, names=[None, "time"], ) + if composite_action_dist: + td[("params", "action1", "loc")] = loc + td[("params", "action1", "scale")] = scale + else: + td["loc"] = loc + td["scale"] = scale return td @pytest.mark.parametrize("gradient_mode", (True, False)) @@ -8590,11 +8822,24 @@ def _create_seq_mock_data_a2c( @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) @pytest.mark.parametrize("functional", (True, False)) - def test_a2c(self, device, gradient_mode, advantage, td_est, functional): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c( + self, + device, + gradient_mode, + advantage, + td_est, + functional, + composite_action_dist, + ): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_a2c(device=device) + td = self._create_seq_mock_data_a2c( + device=device, composite_action_dist=composite_action_dist + ) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) if advantage == "gae": advantage = GAE( @@ -8627,14 +8872,24 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional): functional=functional, ) + def set_requires_grad(tensor, requires_grad): + tensor.requires_grad = requires_grad + return tensor + # Check error is raised when actions require grads - td["action"].requires_grad = True + if composite_action_dist: + td["action"].apply_(lambda x: set_requires_grad(x, True)) + else: + td["action"].requires_grad = True with pytest.raises( RuntimeError, - match="tensordict stored action require grad.", + match="tensordict stored action requires grad.", ): _ = loss_fn._log_probs(td) - td["action"].requires_grad = False + if composite_action_dist: + td["action"].apply_(lambda x: set_requires_grad(x, False)) + else: + td["action"].requires_grad = False td = td.exclude(loss_fn.tensor_keys.value_target) if advantage is not None: @@ -8675,9 +8930,12 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional): @pytest.mark.parametrize("gradient_mode", (True, False)) @pytest.mark.parametrize("device", get_default_devices()) - def test_a2c_state_dict(self, device, gradient_mode): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_state_dict(self, device, gradient_mode, composite_action_dist): torch.manual_seed(self.seed) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) loss_fn = A2CLoss(actor, value, loss_critic_type="l2") sd = loss_fn.state_dict() @@ -8685,23 +8943,36 @@ def test_a2c_state_dict(self, device, gradient_mode): loss_fn2.load_state_dict(sd) @pytest.mark.parametrize("separate_losses", [False, True]) - def test_a2c_separate_losses(self, separate_losses): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_separate_losses(self, separate_losses, composite_action_dist): torch.manual_seed(self.seed) - actor, critic, common, td = self._create_mock_common_layer_setup() + actor, critic, common, td = 
self._create_mock_common_layer_setup( + composite_action_dist=composite_action_dist + ) loss_fn = A2CLoss( actor_network=actor, critic_network=critic, separate_losses=separate_losses, ) + def set_requires_grad(tensor, requires_grad): + tensor.requires_grad = requires_grad + return tensor + # Check error is raised when actions require grads - td["action"].requires_grad = True + if composite_action_dist: + td["action"].apply_(lambda x: set_requires_grad(x, True)) + else: + td["action"].requires_grad = True with pytest.raises( RuntimeError, - match="tensordict stored action require grad.", + match="tensordict stored action requires grad.", ): _ = loss_fn._log_probs(td) - td["action"].requires_grad = False + if composite_action_dist: + td["action"].apply_(lambda x: set_requires_grad(x, False)) + else: + td["action"].requires_grad = False td = td.exclude(loss_fn.tensor_keys.value_target) loss = loss_fn(td) @@ -8745,13 +9016,18 @@ def test_a2c_separate_losses(self, separate_losses): @pytest.mark.parametrize("gradient_mode", (True, False)) @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) - def test_a2c_diff(self, device, gradient_mode, advantage): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_diff(self, device, gradient_mode, advantage, composite_action_dist): if pack_version.parse(torch.__version__) > pack_version.parse("1.14"): raise pytest.skip("make_functional_with_buffers needs to be changed") torch.manual_seed(self.seed) - td = self._create_seq_mock_data_a2c(device=device) + td = self._create_seq_mock_data_a2c( + device=device, composite_action_dist=composite_action_dist + ) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) if advantage == "gae": advantage = GAE( @@ -8821,8 +9097,9 @@ def test_a2c_diff(self, device, gradient_mode, advantage): ValueEstimators.TDLambda, ], ) - def test_a2c_tensordict_keys(self, td_est): - actor = self._create_mock_actor() + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_tensordict_keys(self, td_est, composite_action_dist): + actor = self._create_mock_actor(composite_action_dist=composite_action_dist) value = self._create_mock_value() loss_fn = A2CLoss(actor, value, loss_critic_type="l2") @@ -8867,7 +9144,10 @@ def test_a2c_tensordict_keys(self, td_est): ) @pytest.mark.parametrize("advantage", ("gae", "vtrace", None)) @pytest.mark.parametrize("device", get_default_devices()) - def test_a2c_tensordict_keys_run(self, device, advantage, td_est): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_tensordict_keys_run( + self, device, advantage, td_est, composite_action_dist + ): """Test A2C loss module with non-default tensordict keys.""" torch.manual_seed(self.seed) gradient_mode = True @@ -8887,10 +9167,14 @@ def test_a2c_tensordict_keys_run(self, device, advantage, td_est): done_key=done_key, terminated_key=terminated_key, sample_log_prob_key=sample_log_prob_key, + composite_action_dist=composite_action_dist, ) actor = self._create_mock_actor( - device=device, sample_log_prob_key=sample_log_prob_key + device=device, + sample_log_prob_key=sample_log_prob_key, + composite_action_dist=composite_action_dist, + action_key=action_key, ) value = self._create_mock_value(device=device, out_keys=[value_key]) if advantage == "gae": @@ 
-8969,12 +9253,26 @@ def test_a2c_tensordict_keys_run(self, device, advantage, td_est): @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + @pytest.mark.parametrize( + "composite_action_dist", + [ + False, + ], + ) def test_a2c_notensordict( - self, action_key, observation_key, reward_key, done_key, terminated_key + self, + action_key, + observation_key, + reward_key, + done_key, + terminated_key, + composite_action_dist, ): torch.manual_seed(self.seed) - actor = self._create_mock_actor(observation_key=observation_key) + actor = self._create_mock_actor( + observation_key=observation_key, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(observation_key=observation_key) td = self._create_seq_mock_data_a2c( action_key=action_key, @@ -8982,6 +9280,7 @@ def test_a2c_notensordict( reward_key=reward_key, done_key=done_key, terminated_key=terminated_key, + composite_action_dist=composite_action_dist, ) loss = A2CLoss(actor, value) @@ -9026,15 +9325,20 @@ def test_a2c_notensordict( assert loss_critic == loss_val_td["loss_critic"] @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) - def test_a2c_reduction(self, reduction): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_reduction(self, reduction, composite_action_dist): torch.manual_seed(self.seed) device = ( torch.device("cpu") if torch.cuda.device_count() == 0 else torch.device("cuda") ) - td = self._create_seq_mock_data_a2c(device=device) - actor = self._create_mock_actor(device=device) + td = self._create_seq_mock_data_a2c( + device=device, composite_action_dist=composite_action_dist + ) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) advantage = GAE( gamma=0.9, @@ -9061,10 +9365,15 @@ def test_a2c_reduction(self, reduction): @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("clip_value", [True, None, 0.5, torch.tensor(0.5)]) - def test_a2c_value_clipping(self, clip_value, device): + @pytest.mark.parametrize("composite_action_dist", [True, False]) + def test_a2c_value_clipping(self, clip_value, device, composite_action_dist): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_a2c(device=device) - actor = self._create_mock_actor(device=device) + td = self._create_seq_mock_data_a2c( + device=device, composite_action_dist=composite_action_dist + ) + actor = self._create_mock_actor( + device=device, composite_action_dist=composite_action_dist + ) value = self._create_mock_value(device=device) advantage = GAE( gamma=0.9, diff --git a/test/test_exploration.py b/test/test_exploration.py index 3bb05708d83..f6a3ab7041b 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -31,7 +31,7 @@ NormalParamExtractor, TanhNormal, ) -from torchrl.modules.models.exploration import LazygSDEModule +from torchrl.modules.models.exploration import ConsistentDropoutModule, LazygSDEModule from torchrl.modules.tensordict_module.actors import ( Actor, ProbabilisticActor, @@ -738,6 +738,156 @@ def test_gsde_init(sigma_init, state_dim, action_dim, mean, std, device, learn_s ), f"failed: mean={mean}, std={std}, sigma_init={sigma_init}, actual: {sigma.mean()}" +class TestConsistentDropout: + @pytest.mark.parametrize("dropout_p", [0.0, 0.1, 0.5]) + @pytest.mark.parametrize("parallel_spec", [False, 
True]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_consistent_dropout(self, dropout_p, parallel_spec, device): + """ + + This preliminary test seeks to ensure two things for both + ConsistentDropout and ConsistentDropoutModule: + 1. Rollout transitions generate a dropout mask as desired. + - We can easily verify the existence of a mask + 2. The dropout mask is correctly applied. + - We will check with stochastic policies whether or not + the loc and scale are the same. + """ + torch.manual_seed(0) + + # NOTE: Please only put a module with one dropout layer. + # That's how this test is constructed anyways. + @torch.no_grad + def inner_verify_routine(module, env): + # Perform transitions. + collector = SyncDataCollector( + create_env_fn=env, + policy=module, + frames_per_batch=1, + total_frames=10, + device=device, + ) + for frames in collector: + masks = [ + (key, value) + for key, value in frames.items() + if key.startswith("mask_") + ] + # Assert rollouts do indeed correctly generate the masks. + assert len(masks) == 1, ( + "Expected exactly ONE mask since we only put " + f"one dropout module, got {len(masks)}." + ) + + # Verify that the result for this batch is the same. + # Kind of Monte Carlo, to be honest. + sentinel_mask = masks[0][1].clone() + sentinel_outputs = frames.select("loc", "scale").clone() + + desired_dropout_mask = torch.full_like( + sentinel_mask, 1 / (1 - dropout_p) + ) + desired_dropout_mask[sentinel_mask == 0.0] = 0.0 + # As of 15/08/24, :meth:`~torch.nn.functional.dropout` + # is being used. Never hurts to be safe. + assert torch.allclose( + sentinel_mask, desired_dropout_mask + ), "Dropout was not scaled properly." + + new_frames = module(frames.clone()) + infer_mask = new_frames[masks[0][0]] + infer_outputs = new_frames.select("loc", "scale") + assert (infer_mask == sentinel_mask).all(), "Mask does not match" + + assert all( + [ + torch.allclose(infer_outputs[key], sentinel_outputs[key]) + for key in ("loc", "scale") + ] + ), ( + "Outputs do not match:\n " + f"{infer_outputs['loc']}\n--- vs ---\n{sentinel_outputs['loc']}" + f"{infer_outputs['scale']}\n--- vs ---\n{sentinel_outputs['scale']}" + ) + + env = SerialEnv( + 2, + ContinuousActionVecMockEnv, + ) + env = TransformedEnv(env.to(device), InitTracker()) + env = env.to(device) + # the module must work with the action spec of a single env or a serial env + if parallel_spec: + action_spec = env.action_spec + else: + action_spec = ContinuousActionVecMockEnv(device=device).action_spec + d_act = action_spec.shape[-1] + + # NOTE: Please only put a module with one dropout layer. + # That's how this test is constructed anyways. + module_td_seq = TensorDictSequential( + TensorDictModule( + nn.LazyLinear(2 * d_act), in_keys=["observation"], out_keys=["out"] + ), + ConsistentDropoutModule(p=dropout_p, in_keys="out"), + TensorDictModule( + NormalParamExtractor(), in_keys=["out"], out_keys=["loc", "scale"] + ), + ) + + policy_td_seq = ProbabilisticActor( + module=module_td_seq, + in_keys=["loc", "scale"], + distribution_class=TanhNormal, + default_interaction_type=InteractionType.RANDOM, + spec=action_spec, + ).to(device) + + # Wake up the policies + policy_td_seq(env.reset()) + + # Test. 
+ inner_verify_routine(policy_td_seq, env) + + def test_consistent_dropout_primer(self): + import torch + + from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq + from torchrl.envs import SerialEnv, StepCounter + from torchrl.modules import ConsistentDropoutModule, get_primers_from_module + + torch.manual_seed(0) + + m = Seq( + Mod( + torch.nn.Linear(7, 4), + in_keys=["observation"], + out_keys=["intermediate"], + ), + ConsistentDropoutModule( + p=0.5, + input_shape=( + 2, + 4, + ), + in_keys="intermediate", + ), + Mod(torch.nn.Linear(4, 7), in_keys=["intermediate"], out_keys=["action"]), + ) + primer = get_primers_from_module(m) + env0 = ContinuousActionVecMockEnv().append_transform(StepCounter(5)) + env1 = ContinuousActionVecMockEnv().append_transform(StepCounter(6)) + env = SerialEnv(2, [lambda env=env0: env, lambda env=env1: env]) + env = env.append_transform(primer) + r = env.rollout(10, m, break_when_any_done=False) + mask = [k for k in r.keys() if k.startswith("mask")][0] + assert (r[mask][0, :5] != r[mask][0, 5:6]).any() + assert (r[mask][0, :4] == r[mask][0, 4:5]).all() + + assert (r[mask][1, :6] != r[mask][1, 6:7]).any() + assert (r[mask][1, :5] == r[mask][1, 5:6]).all() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_libs.py b/test/test_libs.py index cb551473690..6f5cc1bebeb 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -153,6 +153,16 @@ _has_meltingpot = importlib.util.find_spec("meltingpot") is not None +_has_minigrid = importlib.util.find_spec("minigrid") is not None + + +@pytest.fixture(scope="session", autouse=True) +def maybe_init_minigrid(): + if _has_minigrid and _has_gymnasium: + import minigrid + + minigrid.register_minigrid_envs() + def get_gym_pixel_wrapper(): try: @@ -1279,6 +1289,24 @@ def test_resetting_strategies(self, heterogeneous): gc.collect() +@pytest.mark.skipif( + not _has_minigrid or not _has_gymnasium, reason="MiniGrid not found" +) +class TestMiniGrid: + @pytest.mark.parametrize( + "id", + [ + "BabyAI-KeyCorridorS6R3-v0", + "MiniGrid-Empty-16x16-v0", + "MiniGrid-BlockedUnlockPickup-v0", + ], + ) + def test_minigrid(self, id): + env_base = gymnasium.make(id) + env = GymWrapper(env_base) + check_env_specs(env) + + @implement_for("gym", None, "0.26") def _make_gym_environment(env_name): # noqa: F811 gym = gym_backend() diff --git a/test/test_loggers.py b/test/test_loggers.py index 735911bd95c..eb40ca1fdb8 100644 --- a/test/test_loggers.py +++ b/test/test_loggers.py @@ -281,25 +281,27 @@ def test_log_video(self, wandb_logger): # C - number of image channels (e.g. 3 for RGB), H, W - image dimensions. 
# the first 64 frames are black and the next 64 are white video = torch.cat( - (torch.zeros(64, 1, 32, 32), torch.full((64, 1, 32, 32), 255)) + (torch.zeros(128, 1, 32, 32), torch.full((128, 1, 32, 32), 255)) ) video = video[None, :] wandb_logger.log_video( name="foo", video=video, - fps=6, + fps=4, + format="mp4", ) wandb_logger.log_video( - name="foo_12fps", + name="foo_16fps", video=video, - fps=24, + fps=16, + format="mp4", ) sleep(0.01) # wait until events are registered # check that fps can be passed and that it has impact on the length of the video - video_6fps_size = wandb_logger.experiment.summary["foo"]["size"] - video_24fps_size = wandb_logger.experiment.summary["foo_12fps"]["size"] - assert video_6fps_size > video_24fps_size, video_6fps_size + video_4fps_size = wandb_logger.experiment.summary["foo"]["size"] + video_16fps_size = wandb_logger.experiment.summary["foo_16fps"]["size"] + assert video_4fps_size > video_16fps_size, (video_4fps_size, video_16fps_size) # check that we catch the error in case the format of the tensor is wrong video_wrong_format = torch.zeros(64, 2, 32, 32) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 4a10f4304f2..80fb1c1f768 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -507,7 +507,8 @@ def __init__( # Cuda handles sync if torch.cuda.is_available(): self._sync_storage = torch.cuda.synchronize - elif torch.backends.mps.is_available(): + elif torch.backends.mps.is_available() and hasattr(torch, "mps"): + # Will break for older PT versions which don't have torch.mps self._sync_storage = torch.mps.synchronize elif self.storing_device.type == "cpu": self._sync_storage = _do_nothing @@ -521,7 +522,7 @@ def __init__( # Cuda handles sync if torch.cuda.is_available(): self._sync_env = torch.cuda.synchronize - elif torch.backends.mps.is_available(): + elif torch.backends.mps.is_available() and hasattr(torch, "mps"): self._sync_env = torch.mps.synchronize elif self.env_device.type == "cpu": self._sync_env = _do_nothing @@ -534,7 +535,7 @@ def __init__( # Cuda handles sync if torch.cuda.is_available(): self._sync_policy = torch.cuda.synchronize - elif torch.backends.mps.is_available(): + elif torch.backends.mps.is_available() and hasattr(torch, "mps"): self._sync_policy = torch.mps.synchronize elif self.policy_device.type == "cpu": self._sync_policy = _do_nothing diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 8338fdff74b..869ea5cdae3 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -22,7 +22,7 @@ from torchrl._extension import EXTENSION_WARNING -from torchrl._utils import _replace_last, implement_for, logger +from torchrl._utils import _replace_last, logger from torchrl.data.replay_buffers.storages import Storage, StorageEnsemble, TensorStorage from torchrl.data.replay_buffers.utils import _is_int, unravel_index @@ -1076,9 +1076,28 @@ def _tensor_slices_from_startend(self, seq_length, start, storage_length): # seq_length is a 1d tensor indicating the desired length of each sequence if isinstance(seq_length, int): - result = torch.cat( - [self._start_to_end(_start, length=seq_length) for _start in start] + arange = torch.arange(seq_length, device=start.device, dtype=start.dtype) + ndims = start.shape[-1] - 1 if (start.ndim - 1) else 0 + if ndims: + arange_reshaped = torch.empty( + arange.shape + torch.Size([ndims + 1]), + device=start.device, + dtype=start.dtype, + ) + 
arange_reshaped[..., 0] = arange + arange_reshaped[..., 1:] = 0 + else: + arange_reshaped = arange.unsqueeze(-1) + arange_expanded = arange_reshaped.expand( + torch.Size([start.shape[0]]) + arange_reshaped.shape ) + if start.shape != arange_expanded.shape: + n_missing_dims = arange_expanded.dim() - start.dim() + start_expanded = start[ + (slice(None),) + (None,) * n_missing_dims + ].expand_as(arange_expanded) + result = (start_expanded + arange_expanded).flatten(0, 1) + else: # when padding is needed result = torch.cat( @@ -1823,28 +1842,31 @@ def mark_update( ) -> None: return PrioritizedSampler.mark_update(self, index, storage=storage) - @implement_for("torch", "2.4") def _padded_indices(self, shapes, arange) -> torch.Tensor: # this complex mumbo jumbo creates a left padded tensor with valid indices on the right, e.g. # tensor([[ 0, 1, 2, 3, 4], # [-1, -1, 5, 6, 7], # [-1, 8, 9, 10, 11]]) # where the -1 items on the left are padded values - st, off = torch._nested_compute_contiguous_strides_offsets(shapes.flip(0)) - nt = torch._nested_view_from_buffer( - arange.flip(0).contiguous(), shapes.flip(0), st, off + num_groups = shapes.shape[0] + max_group_len = shapes.max() + pad_lengths = max_group_len - shapes + + # Get all the start and end indices within arange for each group + group_ends = shapes.cumsum(0) + group_starts = torch.empty_like(group_ends) + group_starts[0] = 0 + group_starts[1:] = group_ends[:-1] + pad = torch.empty( + (num_groups, max_group_len), dtype=arange.dtype, device=arange.device ) - pad = nt.to_padded_tensor(-1).flip(-1).flip(0) - return pad + for pad_row, group_start, group_end, pad_len in zip( + pad, group_starts, group_ends, pad_lengths + ): + pad_row[:pad_len] = -1 + pad_row[pad_len:] = arange[group_start:group_end] - @implement_for("torch", None, "2.4") - def _padded_indices(self, shapes, arange) -> torch.Tensor: # noqa: F811 - arange = arange.flip(0).split(shapes.flip(0).squeeze().unbind()) - return ( - torch.nn.utils.rnn.pad_sequence(arange, batch_first=True, padding_value=-1) - .flip(-1) - .flip(0) - ) + return pad def _preceding_stop_idx(self, storage, lengths, seq_length, start_idx): preceding_stop_idx = self._cache.get("preceding_stop_idx") diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 9bbd068b434..60c1009990e 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -397,16 +397,6 @@ def high(self, value): self.device = value.device self._high = value.cpu() - @low.setter - def low(self, value): - self.device = value.device - self._low = value.cpu() - - @high.setter - def high(self, value): - self.device = value.device - self._high = value.cpu() - def __post_init__(self): self.low = self.low.clone() self.high = self.high.clone() @@ -2269,9 +2259,10 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Bounded: dest_device = torch.device(dest) if dest_device == self.device and dest_dtype == self.dtype: return self + self.space.device = dest_device return Bounded( - low=self.space.low.to(dest), - high=self.space.high.to(dest), + low=self.space.low, + high=self.space.high, shape=self.shape, device=dest_device, dtype=dest_dtype, diff --git a/torchrl/envs/custom/tictactoeenv.py b/torchrl/envs/custom/tictactoeenv.py index 6e5dee781e8..2c93a5748ef 100644 --- a/torchrl/envs/custom/tictactoeenv.py +++ b/torchrl/envs/custom/tictactoeenv.py @@ -218,7 +218,7 @@ def _step(self, state: TensorDict) -> TensorDict: turn = state["turn"].clone() action = state["action"] board.flatten(-2, 
-1).scatter_(index=action.unsqueeze(-1), dim=-1, value=1) - wins = self.win(state["board"], action) + wins = self.win(board, action) mask = board.flatten(-2, -1) == -1 done = wins | ~mask.any(-1, keepdim=True) @@ -234,7 +234,7 @@ def _step(self, state: TensorDict) -> TensorDict: ("player0", "reward"): reward_0.float(), ("player1", "reward"): reward_1.float(), "board": torch.where(board == -1, board, 1 - board), - "turn": 1 - state["turn"], + "turn": 1 - turn, "mask": mask, }, batch_size=state.batch_size, @@ -260,13 +260,15 @@ def _set_seed(self, seed: int | None): def win(board: torch.Tensor, action: torch.Tensor): row = action // 3 # type: ignore col = action % 3 # type: ignore - return ( - board[..., row, :].sum() - == 3 | board[..., col].sum() - == 3 | board.diagonal(0, -2, -1).sum() - == 3 | board.flip(-1).diagonal(0, -2, -1).sum() - == 3 - ) + if board[..., row, :].sum() == 3: + return True + if board[..., col].sum() == 3: + return True + if board.diagonal(0, -2, -1).sum() == 3: + return True + if board.flip(-1).diagonal(0, -2, -1).sum() == 3: + return True + return False @staticmethod def full(board: torch.Tensor) -> bool: diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 9092d419075..995f245a8ac 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -12,10 +12,10 @@ import numpy as np import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import NonTensorData, TensorDict, TensorDictBase from torchrl._utils import logger as torchrl_logger -from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded +from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded from torchrl.envs.common import _EnvWrapper, EnvBase @@ -283,9 +283,12 @@ def read_obs( observations = observations_dict else: for key, val in observations.items(): - observations[key] = self.observation_spec[key].encode( - val, ignore_device=True - ) + if isinstance(self.observation_spec[key], NonTensor): + observations[key] = NonTensorData(val) + else: + observations[key] = self.observation_spec[key].encode( + val, ignore_device=True + ) return observations def _step(self, tensordict: TensorDictBase) -> TensorDictBase: diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 34af87b75f9..a82286659cb 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -29,6 +29,7 @@ Composite, MultiCategorical, MultiOneHot, + NonTensor, OneHot, TensorSpec, Unbounded, @@ -55,6 +56,14 @@ _has_mo = importlib.util.find_spec("mo_gymnasium") is not None _has_sb3 = importlib.util.find_spec("stable_baselines3") is not None +_has_minigrid = importlib.util.find_spec("minigrid") is not None + + +def _minigrid_lib(): + assert _has_minigrid, "minigrid not found" + import minigrid + + return minigrid class set_gym_backend(_DecoratorContextManager): @@ -369,6 +378,8 @@ def _gym_to_torchrl_spec_transform( categorical_action_encoding=categorical_action_encoding, remap_state_to_observation=remap_state_to_observation, ) + elif _has_minigrid and isinstance(spec, _minigrid_lib().core.mission.MissionSpace): + return NonTensor((), device=device) else: raise NotImplementedError( f"spec of type {type(spec).__name__} is currently unaccounted for" @@ -766,14 +777,20 @@ def __init__(self, env=None, categorical_action_encoding=False, **kwargs): self._seed_calls_reset = None self._categorical_action_encoding = categorical_action_encoding if env is not None: - if "EnvCompatibility" in str( - env - ): # a hacky way of knowing if EnvCompatibility is 
part of the wrappers of env - raise ValueError( - "GymWrapper does not support the gym.wrapper.compatibility.EnvCompatibility wrapper. " - "If this feature is needed, detail your use case in an issue of " - "https://github.com/pytorch/rl/issues." - ) + try: + env_str = str(env) + except TypeError: + # MiniGrid has a bug where the __str__ method fails + pass + else: + if ( + "EnvCompatibility" in env_str + ): # a hacky way of knowing if EnvCompatibility is part of the wrappers of env + raise ValueError( + "GymWrapper does not support the gym.wrapper.compatibility.EnvCompatibility wrapper. " + "If this feature is needed, detail your use case in an issue of " + "https://github.com/pytorch/rl/issues." + ) libname = self.get_library_name(env) with set_gym_backend(libname): kwargs["env"] = env diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index 5811580826d..22f9835303b 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -94,6 +94,16 @@ def _vmas_to_torchrl_spec_transform( device=device, ) ) + elif isinstance(spec, gym_spaces.Dict): + spec_out = {} + for key in spec.keys(): + spec_out[key] = _vmas_to_torchrl_spec_transform( + spec[key], + device=device, + categorical_action_encoding=categorical_action_encoding, + ) + # the batch-size must be set later + return Composite(spec_out, device=device) else: raise NotImplementedError( f"spec of type {type(spec).__name__} is currently unaccounted for vmas" diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 50c1d2c3f88..efa2fcfb270 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4601,7 +4601,7 @@ class TensorDictPrimer(Transform): .. note:: Some TorchRL modules rely on specific keys being present in the environment TensorDicts, like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`. - To facilitate this process, the method :func:`~torchrl.models.utils.get_primers_from_module` + To facilitate this process, the method :func:`~torchrl.modules.utils.get_primers_from_module` automatically checks for required primer transforms in a module and its submodules and generates them. """ @@ -4668,10 +4668,15 @@ def __init__( def reset_key(self): reset_key = self.__dict__.get("_reset_key", None) if reset_key is None: + if self.parent is None: + raise RuntimeError( + "Missing parent, cannot infer reset_key automatically." + ) reset_keys = self.parent.reset_keys if len(reset_keys) > 1: raise RuntimeError( - f"Got more than one reset key in env {self.container}, cannot infer which one to use. Consider providing the reset key in the {type(self)} constructor." + f"Got more than one reset key in env {self.container}, cannot infer which one to use. " + f"Consider providing the reset key in the {type(self)} constructor." 
) reset_key = self._reset_key = reset_keys[0] return reset_key diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index c246b553e95..f65461842bb 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -21,6 +21,7 @@ ) from .models import ( BatchRenorm1d, + ConsistentDropoutModule, Conv3dNet, ConvNet, DdpgCnnActor, @@ -85,4 +86,5 @@ VmapModule, WorldModelWrapper, ) +from .utils import get_primers_from_module from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner # usort:skip diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 9a814e35477..90b9fadd747 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -9,7 +9,12 @@ from .batchrenorm import BatchRenorm1d from .decision_transformer import DecisionTransformer -from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise +from .exploration import ( + ConsistentDropoutModule, + NoisyLazyLinear, + NoisyLinear, + reset_noise, +) from .model_based import ( DreamerActor, ObsDecoder, diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py index 16c6ac5ff30..720934a6809 100644 --- a/torchrl/modules/models/exploration.py +++ b/torchrl/modules/models/exploration.py @@ -2,16 +2,24 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import functools import math import warnings -from typing import Optional, Sequence, Union +from typing import List, Optional, Sequence, Union import torch + +from tensordict.nn import TensorDictModuleBase +from tensordict.utils import NestedKey from torch import distributions as d, nn +from torch.nn import functional as F +from torch.nn.modules.dropout import _DropoutNd from torch.nn.modules.lazy import LazyModuleMixin from torch.nn.parameter import UninitializedBuffer, UninitializedParameter - from torchrl._utils import prod +from torchrl.data.tensor_specs import Unbounded from torchrl.data.utils import DEVICE_TYPING, DEVICE_TYPING_ARGS from torchrl.envs.utils import exploration_type, ExplorationType from torchrl.modules.distributions.utils import _cast_transform_device @@ -520,3 +528,203 @@ def initialize_parameters( ) self._sigma.materialize((action_dim, state_dim)) self._sigma.data.copy_(self.sigma_init.expand_as(self._sigma)) + + +class ConsistentDropout(_DropoutNd): + """Implements a :class:`~torch.nn.Dropout` variant with consistent dropout. + + This method is proposed in `"Consistent Dropout for Policy Gradient Reinforcement Learning" (Hausknecht & Wagener, 2022) + `_. + + This :class:`~torch.nn.Dropout` variant attempts to increase training stability and + reduce update variance by caching the dropout masks used during rollout + and reusing them during the update phase. + + This class is independent of the rest of TorchRL's API and does not require tensordict to run. + :class:`~torchrl.modules.ConsistentDropoutModule` is a wrapper around ``ConsistentDropout`` that capitalizes on the extensibility + of ``TensorDict``s by storing the generated dropout masks in the transition ``TensorDict`` itself. + See this class for a detailed explanation as well as usage examples. + + There is otherwise little conceptual deviation from the PyTorch + :class:`~torch.nn.Dropout` implementation.
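+ + A minimal sketch of the mask-caching behaviour follows (the import path is inferred from this file's location and is otherwise an assumption; adapt it if the class is re-exported elsewhere): + + >>> import torch + >>> from torchrl.modules.models.exploration import ConsistentDropout + >>> dropout = ConsistentDropout(p=0.5) + >>> x = torch.ones(3, 4) + >>> y, mask = dropout(x) # train mode: returns the masked input and the mask that was drawn + >>> y2, _ = dropout(x, mask=mask) # reusing the cached mask reproduces the same output + >>> assert (y == y2).all()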
+ + .. note:: TorchRL's data collectors perform rollouts in :meth:`~torch.no_grad` mode but not in `eval` mode, + so the dropout masks will be applied unless the policy passed to the collector is in eval mode. + + .. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule` + uses the ``train``/``eval`` mode to comply with the regular `Dropout` API in PyTorch. + The :func:`~torchrl.envs.utils.set_exploration_mode` context manager will have no effect on + this module. + + Args: + p (float, optional): Dropout probability. Defaults to ``0.5``. + + .. seealso:: + + - :class:`~torchrl.collectors.SyncDataCollector`: + :meth:`~torchrl.collectors.SyncDataCollector.rollout()` and :meth:`~torchrl.collectors.SyncDataCollector.iterator()` + - :class:`~torchrl.collectors.MultiSyncDataCollector`: + Uses :meth:`~torchrl.collectors.collectors._main_async_collector` (:class:`~torchrl.collectors.SyncDataCollector`) + under the hood + - :class:`~torchrl.collectors.MultiaSyncDataCollector`, :class:`~torchrl.collectors.aSyncDataCollector`: Ditto. + + """ + + def __init__(self, p: float = 0.5): + super().__init__() + self.p = p + + def forward( + self, x: torch.Tensor, mask: torch.Tensor | None = None + ) -> torch.Tensor: + """During training (rollouts & updates), this call draws a dropout mask from a tensor of ones and multiplies the input tensor by it. + + During evaluation, this call results in a no-op and only the input is returned. + + Args: + x (torch.Tensor): the input tensor. + mask (torch.Tensor, optional): the optional mask for the dropout. + + Returns: a tensor and a corresponding mask in train mode, and only a tensor in eval mode. + """ + if self.training: + if mask is None: + mask = self.make_mask(input=x) + return x * mask, mask + + return x + + def make_mask(self, *, input=None, shape=None): + if input is not None: + return F.dropout( + torch.ones_like(input), self.p, self.training, inplace=False + ) + elif shape is not None: + return F.dropout(torch.ones(shape), self.p, self.training, inplace=False) + else: + raise RuntimeError("input or shape must be passed to make_mask.") + + +class ConsistentDropoutModule(TensorDictModuleBase): + """A TensorDictModule wrapper for :class:`~ConsistentDropout`. + + Args: + p (float, optional): Dropout probability. Default: ``0.5``. + in_keys (NestedKey or list of NestedKeys): keys to be read + from input tensordict and passed to this module. + out_keys (NestedKey or iterable of NestedKeys): keys to be written to the input tensordict. + Defaults to ``in_keys`` values. + + Keyword Args: + input_shape (tuple, optional): the shape of the input (non-batched), used to generate the + tensordict primers with :meth:`~.make_tensordict_primer`. + input_dtype (torch.dtype, optional): the dtype of the input for the primer. If none is passed, + ``torch.get_default_dtype`` is assumed. + + .. note:: To use this class within a policy, one needs the mask to be reset at reset time. + This can be achieved through a :class:`~torchrl.envs.TensorDictPrimer` transform that can be obtained + with :meth:`~.make_tensordict_primer`. See this method for more information.
+ + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> module = ConsistentDropoutModule(p=0.1) + >>> td = TensorDict({"x": torch.randn(3, 4)}, [3]) + >>> module(td) + TensorDict( + fields={ + mask_6127171760: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), + x: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False) + """ + + def __init__( + self, + p: float, + in_keys: NestedKey | List[NestedKey], + out_keys: NestedKey | List[NestedKey] | None = None, + input_shape: torch.Size | None = None, + input_dtype: torch.dtype | None = None, + ): + if isinstance(in_keys, NestedKey): + in_keys = [in_keys, f"mask_{id(self)}"] + if out_keys is None: + out_keys = list(in_keys) + if isinstance(out_keys, NestedKey): + out_keys = [out_keys, f"mask_{id(self)}"] + if len(in_keys) != 2 or len(out_keys) != 2: + raise ValueError( + "in_keys and out_keys length must be 2 for consistent dropout." + ) + self.in_keys = in_keys + self.out_keys = out_keys + self.input_shape = input_shape + self.input_dtype = input_dtype + super().__init__() + + if not 0 <= p < 1: + raise ValueError(f"p must be in [0,1), got p={p: 4.4f}.") + + self.consistent_dropout = ConsistentDropout(p) + + def forward(self, tensordict): + x = tensordict.get(self.in_keys[0]) + mask = tensordict.get(self.in_keys[1], default=None) + if self.consistent_dropout.training: + x, mask = self.consistent_dropout(x, mask=mask) + tensordict.set(self.out_keys[0], x) + tensordict.set(self.out_keys[1], mask) + else: + x = self.consistent_dropout(x, mask=mask) + tensordict.set(self.out_keys[0], x) + + return tensordict + + def make_tensordict_primer(self): + """Makes a tensordict primer for the environment to generate random masks during reset calls. + + .. seealso:: :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given + module. + + Examples: + >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod + >>> from torchrl.envs import GymEnv, StepCounter, SerialEnv + >>> from torchrl.modules import get_primers_from_module + >>> m = Seq( + ... Mod(torch.nn.Linear(7, 4), in_keys=["observation"], out_keys=["intermediate"]), + ... ConsistentDropoutModule( + ... p=0.5, + ... input_shape=(2, 4), + ... in_keys="intermediate", + ... ), + ... Mod(torch.nn.Linear(4, 7), in_keys=["intermediate"], out_keys=["action"]), + ... ) + >>> primer = get_primers_from_module(m) + >>> env0 = GymEnv("Pendulum-v1").append_transform(StepCounter(5)) + >>> env1 = GymEnv("Pendulum-v1").append_transform(StepCounter(6)) + >>> env = SerialEnv(2, [lambda env=env0: env, lambda env=env1: env]) + >>> env = env.append_transform(primer) + >>> r = env.rollout(10, m, break_when_any_done=False) + >>> mask = [k for k in r.keys() if k.startswith("mask")][0] + >>> assert (r[mask][0, :5] != r[mask][0, 5:6]).any() + >>> assert (r[mask][0, :4] == r[mask][0, 4:5]).all() + + """ + from torchrl.envs.transforms.transforms import TensorDictPrimer + + shape = self.input_shape + dtype = self.input_dtype + if dtype is None: + dtype = torch.get_default_dtype() + if shape is None: + raise RuntimeError( + "Cannot infer the shape of the input automatically. " + "Please pass the shape of the tensor to `ConsistentDropoutModule` during construction " + "with the `input_shape` kwarg."
+ ) + return TensorDictPrimer( + primers={self.in_keys[1]: Unbounded(dtype=dtype, shape=shape)}, + default_value=functools.partial( + self.consistent_dropout.make_mask, shape=shape + ), + ) diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index 48756683c11..f538f8e95c5 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + from typing import Optional, Tuple import torch @@ -387,7 +389,7 @@ class LSTMModule(ModuleBase): .. note:: This module relies on specific ``recurrent_state`` keys being present in the input TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically add hidden states to the environment TensorDicts, use the method :func:`~torchrl.modules.rnn.LSTMModule.make_tensordict_primer`. - If this class is a submodule in a larger module, the method :func:`~torchrl.models.utils.get_primers_from_module` can be called + If this class is a submodule in a larger module, the method :func:`~torchrl.modules.utils.get_primers_from_module` can be called on the parent module to automatically generate the primer transforms required for all submodules, including this one. @@ -534,6 +536,9 @@ def make_tensordict_primer(self): tensordict, which the meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states are not registered within the environment specs. + See :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given + module. + Examples: >>> from torchrl.collectors import SyncDataCollector >>> from torchrl.envs import TransformedEnv, InitTracker @@ -1108,7 +1113,7 @@ class GRUModule(ModuleBase): .. note:: This module relies on specific ``recurrent_state`` keys being present in the input TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically add hidden states to the environment TensorDicts, use the method :func:`~torchrl.modules.rnn.GRUModule.make_tensordict_primer`. - If this class is a submodule in a larger module, the method :func:`~torchrl.models.utils.get_primers_from_module` can be called + If this class is a submodule in a larger module, the method :func:`~torchrl.modules.utils.get_primers_from_module` can be called on the parent module to automatically generate the primer transforms required for all submodules, including this one. Examples: @@ -1280,6 +1285,9 @@ def make_tensordict_primer(self): tensordict, which the meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states are not registered within the environment specs. + See :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given + module. 
+ Examples: >>> from torchrl.collectors import SyncDataCollector >>> from torchrl.envs import TransformedEnv, InitTracker diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index cac81ec253e..56f6fe90a0b 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -10,7 +10,12 @@ from typing import Tuple import torch -from tensordict import TensorDict, TensorDictBase, TensorDictParams +from tensordict import ( + is_tensor_collection, + TensorDict, + TensorDictBase, + TensorDictParams, +) from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule from tensordict.utils import NestedKey from torch import distributions as d @@ -191,6 +196,13 @@ class A2CLoss(LossModule): ... next_reward = torch.randn(*batch, 1), ... next_observation = torch.randn(*batch, n_obs)) >>> loss_obj.backward() + + .. note:: + There is an exception regarding compatibility with non-tensordict-based modules. + If the actor network is probabilistic and uses a :class:`~tensordict.nn.distributions.CompositeDistribution`, + this class must be used with tensordicts and cannot function as a tensordict-independent module. + This is because composite action spaces inherently rely on the structured representation of data provided by + tensordicts to handle their actions. """ @dataclass @@ -390,7 +402,10 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: entropy = dist.entropy() except NotImplementedError: x = dist.rsample((self.samples_mc_entropy,)) - entropy = -dist.log_prob(x).mean(0) + log_prob = dist.log_prob(x) + if is_tensor_collection(log_prob): + log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + entropy = -log_prob.mean(0) return entropy.unsqueeze(-1) def _log_probs( @@ -398,10 +413,6 @@ def _log_probs( ) -> Tuple[torch.Tensor, d.Distribution]: # current log_prob of actions action = tensordict.get(self.tensor_keys.action) - if action.requires_grad: - raise RuntimeError( - f"tensordict stored {self.tensor_keys.action} require grad." - ) tensordict_clone = tensordict.select( *self.actor_network.in_keys, strict=False ).clone() @@ -409,7 +420,15 @@ def _log_probs( self.actor_network ) if self.functional else contextlib.nullcontext(): dist = self.actor_network.get_dist(tensordict_clone) - log_prob = dist.log_prob(action) + if action.requires_grad: + raise RuntimeError( + f"tensordict stored {self.tensor_keys.action} requires grad." + ) + if isinstance(action, torch.Tensor): + log_prob = dist.log_prob(action) + else: + tensordict = dist.log_prob(tensordict) + log_prob = tensordict.get(self.tensor_keys.sample_log_prob) log_prob = log_prob.unsqueeze(-1) return log_prob, dist diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index a9d50cadd50..6cbb8b02426 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -47,7 +47,7 @@ class DQNLoss(LossModule): Defaults to "l2". delay_value (bool, optional): whether to duplicate the value network into a new target value network to - create a DQN with a target network. Default is ``False``. + create a DQN with a target network. Default is ``True``. double_dqn (bool, optional): whether to use Double DQN, as described in https://arxiv.org/abs/1509.06461. Defaults to ``False``. action_space (str or TensorSpec, optional): Action space. 
Must be one of diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 8f32d04b94b..5758b1ed7d8 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -11,7 +11,12 @@ from typing import Tuple import torch -from tensordict import TensorDict, TensorDictBase, TensorDictParams +from tensordict import ( + is_tensor_collection, + TensorDict, + TensorDictBase, + TensorDictParams, +) from tensordict.nn import ( dispatch, ProbabilisticTensorDictModule, @@ -238,6 +243,12 @@ class PPOLoss(LossModule): ... next_observation=torch.randn(*batch, n_obs)) >>> loss_objective.backward() + .. note:: + There is an exception regarding compatibility with non-tensordict-based modules. + If the actor network is probabilistic and uses a :class:`~tensordict.nn.distributions.CompositeDistribution`, + this class must be used with tensordicts and cannot function as a tensordict-independent module. + This is because composite action spaces inherently rely on the structured representation of data provided by + tensordicts to handle their actions. """ @dataclass @@ -454,7 +465,10 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: entropy = dist.entropy() except NotImplementedError: x = dist.rsample((self.samples_mc_entropy,)) - entropy = -dist.log_prob(x).mean(0) + log_prob = dist.log_prob(x) + if is_tensor_collection(log_prob): + log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + entropy = -log_prob.mean(0) return entropy.unsqueeze(-1) def _log_weight( @@ -462,20 +476,27 @@ def _log_weight( ) -> Tuple[torch.Tensor, d.Distribution]: # current log_prob of actions action = tensordict.get(self.tensor_keys.action) - if action.requires_grad: - raise RuntimeError( - f"tensordict stored {self.tensor_keys.action} requires grad." - ) with self.actor_network_params.to_module( self.actor_network ) if self.functional else contextlib.nullcontext(): dist = self.actor_network.get_dist(tensordict) - log_prob = dist.log_prob(action) prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob) if prev_log_prob.requires_grad: - raise RuntimeError("tensordict prev_log_prob requires grad.") + raise RuntimeError( + f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad." + ) + + if action.requires_grad: + raise RuntimeError( + f"tensordict stored {self.tensor_keys.action} requires grad." 
+ ) + if isinstance(action, torch.Tensor): + log_prob = dist.log_prob(action) + else: + tensordict = dist.log_prob(tensordict) + log_prob = tensordict.get(self.tensor_keys.sample_log_prob) log_weight = (log_prob - prev_log_prob).unsqueeze(-1) kl_approx = (prev_log_prob - log_prob).unsqueeze(-1) @@ -1116,7 +1137,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist) except NotImplementedError: x = previous_dist.sample((self.samples_mc_kl,)) - kl = (previous_dist.log_prob(x) - current_dist.log_prob(x)).mean(0) + previous_log_prob = previous_dist.log_prob(x) + current_log_prob = current_dist.log_prob(x) + if is_tensor_collection(x): + previous_log_prob = previous_log_prob.get( + self.tensor_keys.sample_log_prob + ) + current_log_prob = current_log_prob.get( + self.tensor_keys.sample_log_prob + ) + + kl = (previous_log_prob - current_log_prob).mean(0) kl = kl.unsqueeze(-1) neg_loss = neg_loss - self.beta * kl if kl.mean() > self.dtarg * 1.5: diff --git a/tutorials/sphinx-tutorials/getting-started-5.py b/tutorials/sphinx-tutorials/getting-started-5.py index 5f95fe1e534..d355d1888c5 100644 --- a/tutorials/sphinx-tutorials/getting-started-5.py +++ b/tutorials/sphinx-tutorials/getting-started-5.py @@ -89,7 +89,7 @@ optim_steps = 10 collector = SyncDataCollector( env, - policy, + policy_explore, frames_per_batch=frames_per_batch, total_frames=-1, init_random_frames=init_rand_steps,