From ff0101d04ec6b1a722fe4ce0ea01b87c667dd1c1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 2 Jul 2024 09:11:19 +0100 Subject: [PATCH] amend --- test/test_env.py | 51 +++++++++++++++++++++++++++++++++++- torchrl/envs/batched_envs.py | 3 ++- 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index bfda10f0e93..0bbef301369 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -59,6 +59,7 @@ dense_stack_tds, LazyStackedTensorDict, TensorDict, + TensorDictBase, ) from tensordict.nn import TensorDictModuleBase from tensordict.utils import _unravel_key_to_tuple @@ -68,6 +69,7 @@ from torchrl.data.tensor_specs import ( CompositeSpec, DiscreteTensorSpec, + NonTensorSpec, UnboundedContinuousTensorSpec, ) from torchrl.envs import ( @@ -84,7 +86,11 @@ from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, GymWrapper from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv -from torchrl.envs.transforms.transforms import AutoResetEnv, AutoResetTransform +from torchrl.envs.transforms.transforms import ( + AutoResetEnv, + AutoResetTransform, + Transform, +) from torchrl.envs.utils import ( _StepMDP, _terminated_or_truncated, @@ -3188,6 +3194,49 @@ def test_parallel(self, bwad, use_buffers): r = env.rollout(N, break_when_any_done=bwad) assert r.get("non_tensor").tolist() == [list(range(N))] * 2 + class AddString(Transform): + def __init__(self): + super().__init__() + self._str = "0" + + def _call(self, td): + td["string"] = str(int(self._str) + 1) + self._str = td["string"] + return td + + def _reset( + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase + ) -> TensorDictBase: + self._str = "0" + tensordict_reset["string"] = self._str + return tensordict_reset + + def transform_observation_spec(self, observation_spec): + observation_spec["string"] = NonTensorSpec(()) + return observation_spec + + 
@pytest.mark.parametrize("batched", ["serial", "parallel"])
+    def test_partial_reset(self, batched):
+        env0 = lambda: CountingEnv(5).append_transform(self.AddString())
+        env1 = lambda: CountingEnv(6).append_transform(self.AddString())
+        if batched == "parallel":
+            env = ParallelEnv(2, [env0, env1], mp_start_method=mp_ctx)
+        else:
+            env = SerialEnv(2, [env0, env1])
+        s = env.reset()
+        i = 0
+        for i in range(10):
+            s, s_ = env.step_and_maybe_reset(
+                s.set("action", torch.ones(2, 1, dtype=torch.int))
+            )
+            if s.get(("next", "done")).any():
+                break
+            s = s_
+        assert i == 5
+        assert (s["next", "done"] == torch.tensor([[True], [False]])).all()
+        assert s_["string"] == ["0", "6"]
+        assert s["next", "string"] == ["6", "6"]
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index 500f457ad20..7fb180ac121 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -988,7 +988,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
         tds = []
         for i, _env in enumerate(self._envs):
             if not needs_resetting[i]:
-                if not self._use_buffers and tensordict is not None:
+                if out_tds is not None and tensordict is not None:
                     out_tds[i] = tensordict[i].exclude(*self._envs[i].reset_keys)
                     continue
             if tensordict is not None: