Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jul 2, 2024
1 parent 5b92b23 commit ff0101d
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 2 deletions.
51 changes: 50 additions & 1 deletion test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
dense_stack_tds,
LazyStackedTensorDict,
TensorDict,
TensorDictBase,
)
from tensordict.nn import TensorDictModuleBase
from tensordict.utils import _unravel_key_to_tuple
Expand All @@ -68,6 +69,7 @@
from torchrl.data.tensor_specs import (
CompositeSpec,
DiscreteTensorSpec,
NonTensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import (
Expand All @@ -84,7 +86,11 @@
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, GymWrapper
from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv
from torchrl.envs.transforms.transforms import AutoResetEnv, AutoResetTransform
from torchrl.envs.transforms.transforms import (
AutoResetEnv,
AutoResetTransform,
Transform,
)
from torchrl.envs.utils import (
_StepMDP,
_terminated_or_truncated,
Expand Down Expand Up @@ -3188,6 +3194,49 @@ def test_parallel(self, bwad, use_buffers):
r = env.rollout(N, break_when_any_done=bwad)
assert r.get("non_tensor").tolist() == [list(range(N))] * 2

class AddString(Transform):
def __init__(self):
super().__init__()
self._str = "0"

def _call(self, td):
td["string"] = str(int(self._str) + 1)
self._str = td["string"]
return td

def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
self._str = "0"
tensordict_reset["string"] = self._str
return tensordict_reset

def transform_observation_spec(self, observation_spec):
observation_spec["string"] = NonTensorSpec(())
return observation_spec

@pytest.mark.parametrize("batched", ["serial", "parallel"])
def test_partial_rest(self, batched):
env0 = lambda: CountingEnv(5).append_transform(self.AddString())
env1 = lambda: CountingEnv(6).append_transform(self.AddString())
if batched == "parallel":
env = ParallelEnv(2, [env0, env1], mp_start_method=mp_ctx)
else:
env = SerialEnv(2, [env0, env1])
s = env.reset()
i = 0
for i in range(10):
s, s_ = env.step_and_maybe_reset(
s.set("action", torch.ones(2, 1, dtype=torch.int))
)
if s.get(("next", "done")).any():
break
s = s_
assert i == 5
assert (s["next", "done"] == torch.tensor([[True], [False]])).all()
assert s_["string"] == ["0", "6"]
assert s["next", "string"] == ["6", "6"]


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down
3 changes: 2 additions & 1 deletion torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,7 +988,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
tds = []
for i, _env in enumerate(self._envs):
if not needs_resetting[i]:
if not self._use_buffers and tensordict is not None:
if out_tds is not None and tensordict is not None:
out_tds[i] = tensordict[i].exclude(*self._envs[i].reset_keys)
continue
if tensordict is not None:
Expand Down Expand Up @@ -1047,6 +1047,7 @@ def select_and_clone(name, tensor):
filter_empty=True,
)
if out_tds is not None:
print("out_tds", out_tds)
out.update(
LazyStackedTensorDict(*out_tds), keys_to_update=self._non_tensor_keys
)
Expand Down

0 comments on commit ff0101d

Please sign in to comment.