From 271557fa6b5943e83eba2d5ea2a9e8bf8d62ed07 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 7 Aug 2024 11:21:31 -0400 Subject: [PATCH 01/12] Update [ghstack-poisoned] --- test/test_env.py | 88 +++++++++++++++ torchrl/envs/batched_envs.py | 208 +++++++++++++++++++++++++---------- 2 files changed, 235 insertions(+), 61 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index b945498573d..fc95131975e 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -3340,6 +3340,94 @@ def test_pendulum_env(self): assert r.shape == torch.Size((5, 10)) +@pytest.mark.parametrize("device", [None, *get_default_devices()]) +@pytest.mark.parametrize("env_device", [None, *get_default_devices()]) +class TestPartialSteps: + @pytest.mark.parametrize("use_buffers", [False, True]) + def test_parallel_partial_steps(self, use_buffers, device, env_device): + penv = ParallelEnv( + 4, + lambda: CountingEnv(max_steps=10, start_val=2, device=env_device), + mp_start_method=mp_ctx, + use_buffers=use_buffers, + device=device, + ) + td = penv.reset() + psteps = torch.zeros(4, dtype=torch.bool) + psteps[[1, 3]] = True + td.set("_partial_steps", psteps) + + td.set("action", penv.action_spec.one()) + td = penv.step(td) + assert (td[0].get("next") == 0).all() + assert (td[1].get("next") != 0).any() + assert (td[2].get("next") == 0).all() + assert (td[3].get("next") != 0).any() + + @pytest.mark.parametrize("use_buffers", [False, True]) + def test_parallel_partial_step_and_maybe_reset( + self, use_buffers, device, env_device + ): + penv = ParallelEnv( + 4, + lambda: CountingEnv(max_steps=10, start_val=2, device=env_device), + mp_start_method=mp_ctx, + use_buffers=use_buffers, + device=device, + ) + td = penv.reset() + psteps = torch.zeros(4, dtype=torch.bool) + psteps[[1, 3]] = True + td.set("_partial_steps", psteps) + + td.set("action", penv.action_spec.one()) + td, tdreset = penv.step_and_maybe_reset(td) + assert (td[0].get("next") == 0).all() + assert (td[1].get("next") != 0).any() + assert (td[2].get("next") == 0).all() + assert (td[3].get("next") != 0).any() + + @pytest.mark.parametrize("use_buffers", [False, True]) + def test_serial_partial_steps(self, use_buffers, device, env_device): + penv = SerialEnv( + 4, + lambda: CountingEnv(max_steps=10, start_val=2, device=env_device), + use_buffers=use_buffers, + device=device, + ) + td = penv.reset() + psteps = torch.zeros(4, dtype=torch.bool) + psteps[[1, 3]] = True + td.set("_partial_steps", psteps) + + td.set("action", penv.action_spec.one()) + td = penv.step(td) + assert (td[0].get("next") == 0).all() + assert (td[1].get("next") != 0).any() + assert (td[2].get("next") == 0).all() + assert (td[3].get("next") != 0).any() + + @pytest.mark.parametrize("use_buffers", [False, True]) + def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_device): + penv = SerialEnv( + 4, + lambda: CountingEnv(max_steps=10, start_val=2, device=env_device), + use_buffers=use_buffers, + device=device, + ) + td = penv.reset() + psteps = torch.zeros(4, dtype=torch.bool) + psteps[[1, 3]] = True + td.set("_partial_steps", psteps) + + td.set("action", penv.action_spec.one()) + td = penv.step(td) + assert (td[0].get("next") == 0).all() + assert (td[1].get("next") != 0).any() + assert (td[2].get("next") == 0).all() + assert (td[3].get("next") != 0).any() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index f915af52bcc..c886fe6ef6b 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1066,18 +1066,31 @@ def _step( self, tensordict: TensorDict, ) -> TensorDict: - tensordict_in = tensordict.clone(False) + partial_steps = tensordict.get("_partial_steps", None) + tensordict_save = tensordict + if partial_steps is not None: + tensordict = tensordict[partial_steps] + workers_range = partial_steps.nonzero().squeeze().tolist() + tensordict_in = tensordict + if self._use_buffers: + shared_tensordict_parent = ( + self.shared_tensordict_parent._get_sub_tensordict(partial_steps) + ) + else: + workers_range = range(self.num_workers) + tensordict_in = tensordict.clone(False) + if self._use_buffers: + shared_tensordict_parent = self.shared_tensordict_parent + data_in = [] - for i in range(self.num_workers): + for i, td_ in zip(workers_range, tensordict_in): # shared_tensordicts are locked, and we need to select the keys since we update in-place. # There may be unexpected keys, such as "_reset", that we should comfortably ignore here. env_device = self._envs[i].device if env_device != self.device and env_device is not None: - data_in.append( - tensordict_in[i].to(env_device, non_blocking=self.non_blocking) - ) + data_in.append(td_.to(env_device, non_blocking=self.non_blocking)) else: - data_in.append(tensordict_in[i]) + data_in.append(td_) self._sync_m2w() out_tds = None @@ -1085,42 +1098,51 @@ def _step( out_tds = [] if self._use_buffers: - next_td = self.shared_tensordict_parent.get("next") - for i, _data_in in enumerate(data_in): + next_td = shared_tensordict_parent.get("next") + for i, _next_td, _data_in in zip(workers_range, next_td, data_in): out_td = self._envs[i]._step(_data_in) - next_td[i].update_( + _next_td.update_( out_td, keys_to_update=list(self._env_output_keys), non_blocking=self.non_blocking, ) if out_tds is not None: out_tds.append(out_td) - else: - for i, _data_in in enumerate(data_in): - out_td = self._envs[i]._step(_data_in) - out_tds.append(out_td) - return LazyStackedTensorDict.maybe_dense_stack(out_tds) - # We must pass a clone of the tensordict, as the values of this tensordict - # will be modified in-place at further steps - device = self.device + # We must pass a clone of the tensordict, as the values of this tensordict + # will be modified in-place at further steps + device = self.device - def select_and_clone(name, tensor): - if name in self._selected_step_keys: - return tensor.clone() + def select_and_clone(name, tensor): + if name in self._selected_step_keys: + return tensor.clone() - out = next_td.named_apply(select_and_clone, nested_keys=True, filter_empty=True) - if out_tds is not None: - out.update( - LazyStackedTensorDict(*out_tds), keys_to_update=self._non_tensor_keys + out = next_td.named_apply( + select_and_clone, nested_keys=True, filter_empty=True ) + if out_tds is not None: + out.update( + LazyStackedTensorDict(*out_tds), + keys_to_update=self._non_tensor_keys, + ) + + if out.device != device: + if device is None: + out = out.clear_device_() + elif out.device != device: + out = out.to(device, non_blocking=self.non_blocking) + self._sync_w2m() + else: + for i, _data_in in zip(workers_range, data_in): + out_td = self._envs[i]._step(_data_in) + out_tds.append(out_td) + out = LazyStackedTensorDict.maybe_dense_stack(out_tds) + + if partial_steps is not None: + result = out.new_zeros(tensordict_save.shape) + result[partial_steps] = out + return result - if out.device != device: - if device is None: - out = out.clear_device_() - elif out.device != device: - out = out.to(device, non_blocking=self.non_blocking) - self._sync_w2m() return out def __getattr__(self, attr: str) -> Any: @@ -1435,20 +1457,27 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: def _step_and_maybe_reset_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: + partial_steps = tensordict.get("_partial_steps", None) + tensordict_save = tensordict + if partial_steps is not None: + tensordict = tensordict[partial_steps] + workers_range = partial_steps.nonzero().squeeze().tolist() + else: + workers_range = range(self.num_workers) td = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1) - for i in range(td.shape[0]): + for i in workers_range: # We send the same td multiple times as it is in shared mem and we just need to index it # in each process. # If we don't do this, we need to unbind it but then the custom pickler will require # some extra metadata to be collected. self.parent_channels[i].send(("step_and_maybe_reset", (td, i))) - results = [None] * self.num_workers + results = [None] * len(workers_range) consumed_indices = [] - events = set(range(self.num_workers)) - while len(consumed_indices) < self.num_workers: + events = set(workers_range) + while len(consumed_indices) < len(workers_range): for i in list(events): if self._events[i].is_set(): results[i] = self.parent_channels[i].recv() @@ -1457,9 +1486,14 @@ def _step_and_maybe_reset_no_buffers( events.discard(i) out_next, out_root = zip(*(future for future in results)) - return TensorDict.maybe_dense_stack(out_next), TensorDict.maybe_dense_stack( + out = TensorDict.maybe_dense_stack(out_next), TensorDict.maybe_dense_stack( out_root ) + if partial_steps is not None: + result = out.new_zeros(tensordict_save.shape) + result[partial_steps] = out + return result + return out @torch.no_grad() @_check_start @@ -1471,6 +1505,26 @@ def step_and_maybe_reset( # return self._step_and_maybe_reset_no_buffers(tensordict) return super().step_and_maybe_reset(tensordict) + partial_steps = tensordict.get("_partial_steps", None) + tensordict_save = tensordict + if partial_steps is not None: + shared_tensordict_parent = ( + self.shared_tensordict_parent._get_sub_tensordict(partial_steps) + ) + next_td = self._shared_tensordict_parent_next._get_sub_tensordict( + partial_steps + ) + tensordict_ = self._shared_tensordict_parent_root._get_sub_tensordict( + partial_steps + ) + tensordict = tensordict[partial_steps] + workers_range = partial_steps.nonzero().squeeze().tolist() + else: + workers_range = range(self.num_workers) + shared_tensordict_parent = self.shared_tensordict_parent + next_td = self._shared_tensordict_parent_next + tensordict_ = self._shared_tensordict_parent_root + # We must use the in_keys and nothing else for the following reasons: # - efficiency: copying all the keys will in practice mean doing a lot # of writing operations since the input tensordict may (and often will) @@ -1479,7 +1533,7 @@ def step_and_maybe_reset( # and this transform overrides an observation key (eg, CatFrames) # the shape, dtype or device may not necessarily match and writing # the value in-place will fail. - self.shared_tensordict_parent.update_( + shared_tensordict_parent.update_( tensordict, keys_to_update=self._env_input_keys, non_blocking=self.non_blocking, @@ -1489,46 +1543,41 @@ def step_and_maybe_reset( # if we have input "next" data (eg, RNNs which pass the next state) # the sub-envs will need to process them through step_and_maybe_reset. # We keep track of which keys are present to let the worker know what - # should be passd to the env (we don't want to pass done states for instance) + # should be passed to the env (we don't want to pass done states for instance) next_td_keys = list(next_td_passthrough.keys(True, True)) - data = [ - {"next_td_passthrough_keys": next_td_keys} - for _ in range(self.num_workers) - ] - self.shared_tensordict_parent.get("next").update_( + data = [{"next_td_passthrough_keys": next_td_keys} for _ in workers_range] + shared_tensordict_parent.get("next").update_( next_td_passthrough, non_blocking=self.non_blocking ) else: # next_td_keys = None - data = [{} for _ in range(self.num_workers)] + data = [{} for _ in workers_range] if self._non_tensor_keys: - for i in range(self.num_workers): + for i in workers_range: data[i]["non_tensor_data"] = tensordict[i].select( *self._non_tensor_keys, strict=False ) self._sync_m2w() - for i in range(self.num_workers): - self.parent_channels[i].send(("step_and_maybe_reset", data[i])) + for i, _data in zip(workers_range, data): + self.parent_channels[i].send(("step_and_maybe_reset", _data)) - for i in range(self.num_workers): + for i in workers_range: event = self._events[i] event.wait(self._timeout) event.clear() if self._non_tensor_keys: non_tensor_tds = [] - for i in range(self.num_workers): + for i in workers_range: msg, non_tensor_td = self.parent_channels[i].recv() non_tensor_tds.append(non_tensor_td) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps - next_td = self._shared_tensordict_parent_next - tensordict_ = self._shared_tensordict_parent_root device = self.device - if self.shared_tensordict_parent.device == device: + if shared_tensordict_parent.device == device: next_td = next_td.clone() tensordict_ = tensordict_.clone() elif device is not None: @@ -1558,22 +1607,44 @@ def step_and_maybe_reset( keys_to_update=[("next", key) for key in self._non_tensor_keys], ) tensordict_.update(non_tensor_tds, keys_to_update=self._non_tensor_keys) + + if partial_steps is not None: + result = tensordict.new_zeros(tensordict_save.shape) + result_ = tensordict_.new_zeros(tensordict_save.shape) + result[partial_steps] = tensordict + result_[partial_steps] = tensordict_ + return result, result_ + return tensordict, tensordict_ def _step_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: + partial_steps = tensordict.get("_partial_steps", None) + tensordict_save = tensordict + if partial_steps is not None: + tensordict = tensordict[partial_steps] + workers_range = partial_steps.nonzero().squeeze().tolist() + else: + workers_range = range(self.num_workers) + data = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1) - for i, local_data in enumerate(data.unbind(0)): + for i, local_data in zip(workers_range, data.unbind(0)): self.parent_channels[i].send(("step", local_data)) # for i in range(data.shape[0]): # self.parent_channels[i].send(("step", (data, i))) out_tds = [] - for i, channel in enumerate(self.parent_channels): + for i in workers_range: + channel = self.parent_channels[i] self._events[i].wait() td = channel.recv() out_tds.append(td) - return LazyStackedTensorDict.maybe_dense_stack(out_tds) + out = LazyStackedTensorDict.maybe_dense_stack(out_tds) + if partial_steps is not None: + result = out.new_zeros(tensordict_save.shape) + result[partial_steps] = out + return result + return out @torch.no_grad() @_check_start @@ -1588,8 +1659,19 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # and this transform overrides an observation key (eg, CatFrames) # the shape, dtype or device may not necessarily match and writing # the value in-place will fail. + partial_steps = tensordict.get("_partial_steps", None) + tensordict_save = tensordict + if partial_steps is not None: + shared_tensordict_parent = ( + self.shared_tensordict_parent._get_sub_tensordict(partial_steps) + ) + tensordict = tensordict[partial_steps] + workers_range = partial_steps.nonzero().squeeze().tolist() + else: + workers_range = range(self.num_workers) + shared_tensordict_parent = self.shared_tensordict_parent - self.shared_tensordict_parent.update_( + shared_tensordict_parent.update_( tensordict, keys_to_update=list(self._env_input_keys), non_blocking=self.non_blocking, @@ -1605,14 +1687,14 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: {"next_td_passthrough_keys": next_td_keys} for _ in range(self.num_workers) ] - self.shared_tensordict_parent.get("next").update_( + shared_tensordict_parent.get("next").update_( next_td_passthrough, non_blocking=self.non_blocking ) else: data = [{} for _ in range(self.num_workers)] if self._non_tensor_keys: - for i in range(self.num_workers): + for i in workers_range: data[i]["non_tensor_data"] = tensordict[i].select( *self._non_tensor_keys, strict=False ) @@ -1622,23 +1704,23 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if self.event is not None: self.event.record() self.event.synchronize() - for i in range(self.num_workers): + for i in workers_range: self.parent_channels[i].send(("step", data[i])) - for i in range(self.num_workers): + for i in workers_range: event = self._events[i] event.wait(self._timeout) event.clear() if self._non_tensor_keys: non_tensor_tds = [] - for i in range(self.num_workers): + for i in workers_range: msg, non_tensor_td = self.parent_channels[i].recv() non_tensor_tds.append(non_tensor_td) # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps - next_td = self.shared_tensordict_parent.get("next") + next_td = shared_tensordict_parent.get("next") device = self.device if next_td.device != device and device is not None: @@ -1665,6 +1747,10 @@ def select_and_clone(name, tensor): keys_to_update=self._non_tensor_keys, ) self._sync_w2m() + if partial_steps is not None: + result = out.new_zeros(tensordict_save.shape) + result[partial_steps] = out + return result return out def _reset_no_buffers( From 61e2eb6b2162997b2d51d071f9d0bdd632c4ee01 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 7 Aug 2024 20:53:33 -0400 Subject: [PATCH 02/12] Update [ghstack-poisoned] --- torchrl/envs/batched_envs.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index c886fe6ef6b..66f72245289 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1068,6 +1068,8 @@ def _step( ) -> TensorDict: partial_steps = tensordict.get("_partial_steps", None) tensordict_save = tensordict + if partial_steps is not None and partial_steps.all(): + partial_steps = None if partial_steps is not None: tensordict = tensordict[partial_steps] workers_range = partial_steps.nonzero().squeeze().tolist() @@ -1459,6 +1461,8 @@ def _step_and_maybe_reset_no_buffers( ) -> Tuple[TensorDictBase, TensorDictBase]: partial_steps = tensordict.get("_partial_steps", None) tensordict_save = tensordict + if partial_steps is not None and partial_steps.all(): + partial_steps = None if partial_steps is not None: tensordict = tensordict[partial_steps] workers_range = partial_steps.nonzero().squeeze().tolist() @@ -1507,6 +1511,8 @@ def step_and_maybe_reset( partial_steps = tensordict.get("_partial_steps", None) tensordict_save = tensordict + if partial_steps is not None and partial_steps.all(): + partial_steps = None if partial_steps is not None: shared_tensordict_parent = ( self.shared_tensordict_parent._get_sub_tensordict(partial_steps) @@ -1622,6 +1628,8 @@ def _step_no_buffers( ) -> Tuple[TensorDictBase, TensorDictBase]: partial_steps = tensordict.get("_partial_steps", None) tensordict_save = tensordict + if partial_steps is not None and partial_steps.all(): + partial_steps = None if partial_steps is not None: tensordict = tensordict[partial_steps] workers_range = partial_steps.nonzero().squeeze().tolist() @@ -1661,6 +1669,8 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # the value in-place will fail. partial_steps = tensordict.get("_partial_steps", None) tensordict_save = tensordict + if partial_steps is not None and partial_steps.all(): + partial_steps = None if partial_steps is not None: shared_tensordict_parent = ( self.shared_tensordict_parent._get_sub_tensordict(partial_steps) From e783d975bc67901600665b7930f7386d2def20a5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 9 Aug 2024 10:40:13 -0400 Subject: [PATCH 03/12] Update [ghstack-poisoned] --- test/test_env.py | 17 +++++++++-------- torchrl/envs/batched_envs.py | 10 +++++----- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index 78a9a46ea20..bbec29a0d78 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import argparse +import contextlib import functools import gc import os.path @@ -3347,7 +3348,7 @@ class TestPartialSteps: def test_parallel_partial_steps( self, use_buffers, device, env_device, maybe_fork_ParallelEnv ): - with torch.device(device): + with torch.device(device) if device is not None else contextlib.nullcontext(): penv = maybe_fork_ParallelEnv( 4, lambda: CountingEnv(max_steps=10, start_val=2, device=env_device), @@ -3357,7 +3358,7 @@ def test_parallel_partial_steps( td = penv.reset() psteps = torch.zeros(4, dtype=torch.bool) psteps[[1, 3]] = True - td.set("_partial_steps", psteps) + td.set("_step", psteps) td.set("action", penv.action_spec.one()) td = penv.step(td) @@ -3370,7 +3371,7 @@ def test_parallel_partial_steps( def test_parallel_partial_step_and_maybe_reset( self, use_buffers, device, env_device, maybe_fork_ParallelEnv ): - with torch.device(device): + with torch.device(device) if device is not None else contextlib.nullcontext(): penv = maybe_fork_ParallelEnv( 4, lambda: CountingEnv(max_steps=10, start_val=2, device=env_device), @@ -3380,7 +3381,7 @@ def test_parallel_partial_step_and_maybe_reset( td = penv.reset() psteps = torch.zeros(4, dtype=torch.bool) psteps[[1, 3]] = True - td.set("_partial_steps", psteps) + td.set("_step", psteps) td.set("action", penv.action_spec.one()) td, tdreset = penv.step_and_maybe_reset(td) @@ -3391,7 +3392,7 @@ def test_parallel_partial_step_and_maybe_reset( @pytest.mark.parametrize("use_buffers", [False, True]) def test_serial_partial_steps(self, use_buffers, device, env_device): - with torch.device(device): + with torch.device(device) if device is not None else contextlib.nullcontext(): penv = SerialEnv( 4, lambda: CountingEnv(max_steps=10, start_val=2, device=env_device), @@ -3401,7 +3402,7 @@ def test_serial_partial_steps(self, use_buffers, device, env_device): td = penv.reset() psteps = torch.zeros(4, dtype=torch.bool) psteps[[1, 3]] = True - td.set("_partial_steps", psteps) + td.set("_step", psteps) td.set("action", penv.action_spec.one()) td = penv.step(td) @@ -3412,7 +3413,7 @@ def test_serial_partial_steps(self, use_buffers, device, env_device): @pytest.mark.parametrize("use_buffers", [False, True]) def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_device): - with torch.device(device): + with torch.device(device) if device is not None else contextlib.nullcontext(): penv = SerialEnv( 4, lambda: CountingEnv(max_steps=10, start_val=2, device=env_device), @@ -3422,7 +3423,7 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi td = penv.reset() psteps = torch.zeros(4, dtype=torch.bool) psteps[[1, 3]] = True - td.set("_partial_steps", psteps) + td.set("_step", psteps) td.set("action", penv.action_spec.one()) td = penv.step(td) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index b1f3c7dad76..65b19dc6645 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1066,7 +1066,7 @@ def _step( self, tensordict: TensorDict, ) -> TensorDict: - partial_steps = tensordict.get("_partial_steps", None) + partial_steps = tensordict.get("_step", None) tensordict_save = tensordict if partial_steps is not None and partial_steps.all(): partial_steps = None @@ -1461,7 +1461,7 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: def _step_and_maybe_reset_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: - partial_steps = tensordict.get("_partial_steps", None) + partial_steps = tensordict.get("_step", None) tensordict_save = tensordict if partial_steps is not None and partial_steps.all(): partial_steps = None @@ -1511,7 +1511,7 @@ def step_and_maybe_reset( # return self._step_and_maybe_reset_no_buffers(tensordict) return super().step_and_maybe_reset(tensordict) - partial_steps = tensordict.get("_partial_steps", None) + partial_steps = tensordict.get("_step", None) tensordict_save = tensordict if partial_steps is not None and partial_steps.all(): partial_steps = None @@ -1628,7 +1628,7 @@ def step_and_maybe_reset( def _step_no_buffers( self, tensordict: TensorDictBase ) -> Tuple[TensorDictBase, TensorDictBase]: - partial_steps = tensordict.get("_partial_steps", None) + partial_steps = tensordict.get("_step", None) tensordict_save = tensordict if partial_steps is not None and partial_steps.all(): partial_steps = None @@ -1669,7 +1669,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # and this transform overrides an observation key (eg, CatFrames) # the shape, dtype or device may not necessarily match and writing # the value in-place will fail. - partial_steps = tensordict.get("_partial_steps", None) + partial_steps = tensordict.get("_step", None) tensordict_save = tensordict if partial_steps is not None and partial_steps.all(): partial_steps = None From be1c691a5d54d18e9e636616a69a7ff28a2fa73d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 11 Aug 2024 09:31:53 -0400 Subject: [PATCH 04/12] Update [ghstack-poisoned] --- torchrl/envs/batched_envs.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 65b19dc6645..378b85ffba9 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1031,12 +1031,18 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: if out_tds is not None: out_tds[i] = _td + device = self.device if not self._use_buffers: result = LazyStackedTensorDict.maybe_dense_stack(out_tds) + if result.device != device: + if device is None: + result = result.clear_device_() + else: + result = result.to(device, non_blocking=self.non_blocking) + self._sync_w2m() return result selected_output_keys = self._selected_reset_keys_filt - device = self.device # select + clone creates 2 tds, but we can create one only def select_and_clone(name, tensor): @@ -1650,6 +1656,8 @@ def _step_no_buffers( td = channel.recv() out_tds.append(td) out = LazyStackedTensorDict.maybe_dense_stack(out_tds) + if self.device is not None and out.device != self.device: + out = out.to(self.device, non_blocking=self.non_blocking) if partial_steps is not None: result = out.new_zeros(tensordict_save.shape) result[partial_steps] = out @@ -1796,7 +1804,11 @@ def _reset_no_buffers( self._events[i].wait() td = channel.recv() out_tds[i] = td - return LazyStackedTensorDict.maybe_dense_stack(out_tds) + result = LazyStackedTensorDict.maybe_dense_stack(out_tds) + device = self.device + if device is not None and result.device != device: + return result.to(self.device, non_blocking=self.non_blocking) + return result @torch.no_grad() @_check_start From 79d0bb12d48838881142eaa8c6f703a342a9e9ad Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 11 Aug 2024 09:36:21 -0400 Subject: [PATCH 05/12] Update [ghstack-poisoned] --- test/mocking_classes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 795fda399de..af7e3354d89 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1038,7 +1038,7 @@ def _step( tensordict: TensorDictBase, ) -> TensorDictBase: action = tensordict.get(self.action_key) - self.count += action.to(dtype=torch.int, device=self.device) + self.count += action.to(dtype=torch.int, device=self.full_observation_spec["observation"].device) tensordict = TensorDict( source={ "observation": self.count.clone(), From 7e73fd036912bb0ce2697a4aca5b4aa2bbd1298c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 11 Aug 2024 09:51:28 -0400 Subject: [PATCH 06/12] Update [ghstack-poisoned] --- torchrl/envs/batched_envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 378b85ffba9..b1874a56122 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1685,7 +1685,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: shared_tensordict_parent = ( self.shared_tensordict_parent._get_sub_tensordict(partial_steps) ) - tensordict = tensordict[partial_steps] + tensordict = tensordict._fast_apply(lambda x, y: x[partial_steps].to(y.device), shared_tensordict_parent) workers_range = partial_steps.nonzero().squeeze().tolist() else: workers_range = range(self.num_workers) From 6c01e9c9759e997815b0989dd764a9163c234e10 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 11 Aug 2024 09:55:17 -0400 Subject: [PATCH 07/12] Update [ghstack-poisoned] --- torchrl/envs/batched_envs.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index b1874a56122..3f69a4a0446 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1685,7 +1685,14 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: shared_tensordict_parent = ( self.shared_tensordict_parent._get_sub_tensordict(partial_steps) ) - tensordict = tensordict._fast_apply(lambda x, y: x[partial_steps].to(y.device), shared_tensordict_parent) + tensordict = tensordict._fast_apply( + lambda x, y: x[partial_steps].to(y.device) + if y is not None + else x[partial_steps], + shared_tensordict_parent, + default=None, + device=shared_tensordict_parent.device, + ) workers_range = partial_steps.nonzero().squeeze().tolist() else: workers_range = range(self.num_workers) From 8a0308bbfaa38bb682cc535ab875637badcb43bd Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 11 Aug 2024 11:34:10 -0400 Subject: [PATCH 08/12] Update [ghstack-poisoned] --- test/mocking_classes.py | 4 ++- torchrl/envs/batched_envs.py | 61 ++++++++++++++++++++++-------------- 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index af7e3354d89..a90b1fa28ec 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1038,7 +1038,9 @@ def _step( tensordict: TensorDictBase, ) -> TensorDictBase: action = tensordict.get(self.action_key) - self.count += action.to(dtype=torch.int, device=self.full_observation_spec["observation"].device) + self.count += action.to( + dtype=torch.int, device=self.full_observation_spec["observation"].device + ) tensordict = TensorDict( source={ "observation": self.count.clone(), diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 3f69a4a0446..560e652e782 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1080,10 +1080,6 @@ def _step( tensordict = tensordict[partial_steps] workers_range = partial_steps.nonzero().squeeze().tolist() tensordict_in = tensordict - # if self._use_buffers: - # shared_tensordict_parent = ( - # self.shared_tensordict_parent._get_sub_tensordict(partial_steps) - # ) else: workers_range = range(self.num_workers) tensordict_in = tensordict.clone(False) @@ -1522,17 +1518,30 @@ def step_and_maybe_reset( if partial_steps is not None and partial_steps.all(): partial_steps = None if partial_steps is not None: - shared_tensordict_parent = ( - self.shared_tensordict_parent._get_sub_tensordict(partial_steps) + workers_range = partial_steps.nonzero().squeeze().tolist() + shared_tensordict_parent = TensorDict.lazy_stack( + [self._shared_tensordict[i] for i in workers_range] ) - next_td = self._shared_tensordict_parent_next._get_sub_tensordict( - partial_steps + next_td = TensorDict.lazy_stack( + [self._shared_tensordict_parent_next[i] for i in workers_range] ) - tensordict_ = self._shared_tensordict_parent_root._get_sub_tensordict( - partial_steps + tensordict_ = TensorDict.lazy_stack( + [self._shared_tensordict_parent_root[i] for i in workers_range] ) - tensordict = tensordict[partial_steps] - workers_range = partial_steps.nonzero().squeeze().tolist() + if self.shared_tensordict_parent.device is None: + tensordict = tensordict._fast_apply( + lambda x, y: x[partial_steps].to(y.device) + if y is not None + else x[partial_steps], + self.shared_tensordict_parent, + default=None, + device=None, + batch_size=shared_tensordict_parent.shape, + ) + else: + tensordict = tensordict[partial_steps].to( + self.shared_tensordict_parent.device + ) else: workers_range = range(self.num_workers) shared_tensordict_parent = self.shared_tensordict_parent @@ -1682,18 +1691,24 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if partial_steps is not None and partial_steps.all(): partial_steps = None if partial_steps is not None: - shared_tensordict_parent = ( - self.shared_tensordict_parent._get_sub_tensordict(partial_steps) - ) - tensordict = tensordict._fast_apply( - lambda x, y: x[partial_steps].to(y.device) - if y is not None - else x[partial_steps], - shared_tensordict_parent, - default=None, - device=shared_tensordict_parent.device, - ) workers_range = partial_steps.nonzero().squeeze().tolist() + shared_tensordict_parent = TensorDict.lazy_stack( + [self.shared_tensordicts[i] for i in workers_range] + ) + if self.shared_tensordict_parent.device is None: + tensordict = tensordict._fast_apply( + lambda x, y: x[partial_steps].to(y.device) + if y is not None + else x[partial_steps], + self.shared_tensordict_parent, + default=None, + device=None, + batch_size=shared_tensordict_parent.shape, + ) + else: + tensordict = tensordict[partial_steps].to( + self.shared_tensordict_parent.device + ) else: workers_range = range(self.num_workers) shared_tensordict_parent = self.shared_tensordict_parent From 5409a2a2f1550ef8fb148d21b83c360680f5a4c4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 11 Aug 2024 11:39:17 -0400 Subject: [PATCH 09/12] Update [ghstack-poisoned] --- torchrl/envs/batched_envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 560e652e782..0e9fd121286 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1520,7 +1520,7 @@ def step_and_maybe_reset( if partial_steps is not None: workers_range = partial_steps.nonzero().squeeze().tolist() shared_tensordict_parent = TensorDict.lazy_stack( - [self._shared_tensordict[i] for i in workers_range] + [self.shared_tensordict_parent[i] for i in workers_range] ) next_td = TensorDict.lazy_stack( [self._shared_tensordict_parent_next[i] for i in workers_range] From 3b7c5a7bf987ff883e0e5b36bac57601bde46bc1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 11 Aug 2024 11:46:16 -0400 Subject: [PATCH 10/12] Update [ghstack-poisoned] --- torchrl/envs/batched_envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 0e9fd121286..6a1d5375f85 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1122,7 +1122,7 @@ def select_and_clone(name, tensor): return tensor.clone() if partial_steps is not None: - next_td = next_td._get_sub_tensordict(partial_steps) + next_td = LazyStackedTensorDict([next_td[i] for i in workers_range]) out = next_td.named_apply( select_and_clone, nested_keys=True, filter_empty=True ) From 65ca9b82467121e0796a187827efdf063a4f2904 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 11 Aug 2024 11:48:09 -0400 Subject: [PATCH 11/12] Update [ghstack-poisoned] --- torchrl/envs/batched_envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 6a1d5375f85..73ecdba64a9 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1122,7 +1122,7 @@ def select_and_clone(name, tensor): return tensor.clone() if partial_steps is not None: - next_td = LazyStackedTensorDict([next_td[i] for i in workers_range]) + next_td = TensorDict.lazy_stack([next_td[i] for i in workers_range]) out = next_td.named_apply( select_and_clone, nested_keys=True, filter_empty=True ) From f56b8e0871a85349be547186b3c47e8a325821ff Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 11 Aug 2024 12:34:08 -0400 Subject: [PATCH 12/12] Update [ghstack-poisoned] --- test/mocking_classes.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index a90b1fa28ec..4d86d8ec0ac 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1039,7 +1039,8 @@ def _step( ) -> TensorDictBase: action = tensordict.get(self.action_key) self.count += action.to( - dtype=torch.int, device=self.full_observation_spec["observation"].device + dtype=torch.int, + device=self.action_spec.device if self.device is None else self.device, ) tensordict = TensorDict( source={