diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py
index 31733887f..c9a1dcae8 100644
--- a/tensordict/_lazy.py
+++ b/tensordict/_lazy.py
@@ -50,7 +50,6 @@
     _get_shape_from_args,
     _getitem_batch_size,
     _is_number,
-    _parse_to,
     _renamed_inplace_method,
     _shape,
     _td_fields,
@@ -1324,33 +1323,41 @@ def pin_memory(self) -> T:
             td.pin_memory()
         return self
 
-    def to(self, *args, **kwargs) -> T:
-        non_blocking = kwargs.pop("non_blocking", None)
-        device, dtype, _, convert_to_format, batch_size = _parse_to(*args, **kwargs)
+    def _to(
+        self, *, device, dtype, convert_to_format, batch_size, non_blocking
+    ) -> Tuple[T, bool]:
         if batch_size is not None:
             raise TypeError("Cannot pass batch-size to a LazyStackedTensorDict.")
         result = self
 
         if device is not None and dtype is None and device == self.device:
-            return result
+            return result, False
 
         if non_blocking in (None, True):
-            kwargs["non_blocking"] = True
+            sub_non_blocking = True
         else:
-            kwargs["non_blocking"] = False
-            non_blocking = bool(non_blocking)
+            sub_non_blocking = False
+        tds, must_sync = zip(
+            *(
+                td._to(
+                    device=device,
+                    dtype=dtype,
+                    convert_to_format=convert_to_format,
+                    batch_size=batch_size,
+                    non_blocking=sub_non_blocking,
+                )
+                for td in self.tensordicts
+            )
+        )
+        must_sync = any(must_sync)
         result = type(self)(
-            *[td.to(*args, **kwargs) for td in self.tensordicts],
+            *tds,
             stack_dim=self.stack_dim,
             hook_out=self.hook_out,
             hook_in=self.hook_in,
             stack_dim_name=self._td_dim_name,
         )
-        if device is not None and not non_blocking:
-            self._sync_all()
-        if self.is_locked:
-            result.lock_()
-        return result
+        return result, must_sync
 
     def _check_new_batch_size(self, new_size: torch.Size) -> None:
         if len(new_size) <= self.stack_dim:
@@ -2992,20 +2999,26 @@ def del_(self, key: NestedKey) -> _CustomOpTensorDict:
         self._source = self._source.del_(key)
         return self
 
-    def to(self, *args, **kwargs) -> T:
-        non_blocking = kwargs.pop("non_blocking", None)
-        device, dtype, _, convert_to_format, batch_size = _parse_to(*args, **kwargs)
+    def _to(
+        self, *, device, dtype, convert_to_format, batch_size, non_blocking
+    ) -> Tuple[T, bool]:
         if batch_size is not None:
             raise TypeError(f"Cannot pass batch-size to a {type(self)}.")
         result = self
         if device is not None and dtype is None and device == self.device:
-            return result
+            return result, False
 
-        td = self._source.to(*args, non_blocking=non_blocking, **kwargs)
+        td, must_sync = self._source._to(
+            device=device,
+            dtype=dtype,
+            convert_to_format=convert_to_format,
+            batch_size=batch_size,
+            non_blocking=non_blocking,
+        )
         self_copy = copy(self)
         self_copy._source = td
-        return self_copy
+        return self_copy, must_sync
 
     def pin_memory(self) -> _CustomOpTensorDict:
         self._source.pin_memory()
diff --git a/tensordict/_td.py b/tensordict/_td.py
index 72f52c437..6fa08c981 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -61,7 +61,6 @@
     _LOCK_ERROR,
     _NON_STR_KEY_ERR,
     _NON_STR_KEY_TUPLE_ERR,
-    _parse_to,
     _prune_selected_keys,
     _set_item,
     _set_max_batch_size,
@@ -2507,36 +2506,41 @@ def make_memmap_from_tensor(
 
         return memmap_tensor
 
-    def to(self, *args, **kwargs: Any) -> T:
-        non_blocking = kwargs.pop("non_blocking", None)
-        device, dtype, _, convert_to_format, batch_size = _parse_to(*args, **kwargs)
+    def _to(
+        self, *, device, dtype, convert_to_format, batch_size, non_blocking
+    ) -> Tuple[T, bool]:
         result = self
 
         if device is not None and dtype is None and device == self.device:
-            return result
+            return result, False
 
-        if non_blocking is None:
-            sub_non_blocking = True
-            non_blocking = False
-        else:
-            sub_non_blocking = non_blocking
+        must_sync = False
+        # this may seem awkward but non_blocking=None acts as True and this makes only one check
+        # (checking against (True, None) would take longer)
+        sub_non_blocking = False if non_blocking is False else True
 
         if convert_to_format is not None:
 
             def to(tensor):
-                return tensor.to(
+                nonlocal must_sync
+                result = tensor.to(
                     device,
                     dtype,
                     non_blocking=sub_non_blocking,
                     convert_to_format=convert_to_format,
                 )
+                must_sync |= sub_non_blocking and result is not tensor
+                return result
 
         else:
 
             def to(tensor):
-                return tensor.to(
+                nonlocal must_sync
+                result = tensor.to(
                     device=device, dtype=dtype, non_blocking=sub_non_blocking
                 )
+                must_sync |= sub_non_blocking and result is not tensor
+                return result
 
         apply_kwargs = {}
         if device is not None or dtype is not None:
@@ -2545,9 +2549,10 @@ def to(tensor):
             result = result._fast_apply(to, propagate_lock=True, **apply_kwargs)
         elif batch_size is not None:
             result.batch_size = batch_size
-        if device is not None and sub_non_blocking and not non_blocking:
-            self._sync_all()
-        return result
+        if must_sync and device is not None and device.type == "cuda":
+            # cuda stream is clever enough to do the sync
+            must_sync = False
+        return result, must_sync
 
     def where(self, condition, other, *, out=None, pad=None):
         if _is_tensor_collection(other.__class__):
@@ -3177,15 +3182,8 @@ def _stack_onto_(self, list_item: list[CompatibleType], dim: int) -> _SubTensorD
         self._source._stack_onto_at_(list_item, dim=dim, idx=self.idx)
         return self
 
-    def to(self, *args, **kwargs: Any) -> T:
-        device, dtype, non_blocking, convert_to_format, batch_size = _parse_to(
-            *args, **kwargs
-        )
-        result = self
-
-        if device is not None and dtype is None and device == self.device:
-            return result
-        return self.to_tensordict().to(*args, **kwargs)
+    def _to(self, *args, **kwargs: Any) -> Tuple[T, bool]:
+        return self.to_tensordict()._to(*args, **kwargs)
 
     def _change_batch_size(self, new_size: torch.Size) -> None:
         if not hasattr(self, "_orig_batch_size"):
diff --git a/tensordict/base.py b/tensordict/base.py
index 5494c9037..9d3b10741 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -48,6 +48,7 @@
     _is_non_tensor,
     _is_tensorclass,
     _KEY_ERROR,
+    _parse_to,
     _proc_init,
     _prune_selected_keys,
     _set_max_batch_size,
@@ -7448,7 +7449,6 @@ def to(self: T, *, other: T, non_blocking: bool = ...) -> T:
     def to(self: T, *, batch_size: torch.Size) -> T: ...
 
-    @abc.abstractmethod
     def to(self, *args, **kwargs) -> T:
         """Maps a TensorDictBase subclass either on another device, dtype or to another TensorDictBase subclass (if permitted).
@@ -7489,6 +7489,23 @@ def to(self, *args, **kwargs) -> T:
             >>> data_cuda = data.to(torch.randn(3, device="cuda:0"))  # using an example tensor
             >>> data_cuda = data.to(other=TensorDict({}, [], device="cuda:0"))  # using a tensordict example
 
         """
+        non_blocking = kwargs.pop("non_blocking", None)
+        device, dtype, _, convert_to_format, batch_size = _parse_to(*args, **kwargs)
+        result, must_sync = self._to(
+            device=device,
+            dtype=dtype,
+            convert_to_format=convert_to_format,
+            batch_size=batch_size,
+            non_blocking=non_blocking,
+        )
+        if must_sync and non_blocking is None:
+            self._sync_all()
+        return result
+
+    @abc.abstractmethod
+    def _to(
+        self, *, device, dtype, convert_to_format, batch_size, non_blocking
+    ) -> Tuple[T, bool]: ...
 
     def _sync_all(self):
diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py
index 6495d049d..9490bce53 100644
--- a/tensordict/nn/params.py
+++ b/tensordict/nn/params.py
@@ -11,7 +11,7 @@
 import weakref
 from copy import copy
 from functools import wraps
-from typing import Any, Callable, Iterator, OrderedDict, Sequence, Type
+from typing import Any, Callable, Iterator, OrderedDict, Sequence, Tuple, Type
 
 import torch
 
@@ -526,11 +526,11 @@ def __getitem__(self, index: IndexType) -> TensorDictBase:
 
     __getitems__ = __getitem__
 
-    def to(self, *args, **kwargs) -> TensorDictBase:
-        params = self._param_td.to(*args, **kwargs)
+    def _to(self, *args, **kwargs) -> Tuple[TensorDictBase, bool]:
+        params, must_sync = self._param_td._to(*args, **kwargs)
         if params is self._param_td:
-            return self
-        return TensorDictParams(params)
+            return self, must_sync
+        return TensorDictParams(params), must_sync
 
     def cpu(self):
         params = self._param_td.cpu()
diff --git a/tensordict/persistent.py b/tensordict/persistent.py
index 89aaec338..070332254 100644
--- a/tensordict/persistent.py
+++ b/tensordict/persistent.py
@@ -39,7 +39,6 @@
     _CloudpickleWrapper,
     _KEY_ERROR,
     _LOCK_ERROR,
-    _parse_to,
     _proc_init,
     _split_tensordict,
     cache,
@@ -966,24 +965,35 @@ def share_memory_(self):
             "Create a regular tensordict first using the `to_tensordict` method."
         )
 
-    def to(self, *args, **kwargs: Any) -> PersistentTensorDict:
-        device, dtype, non_blocking, convert_to_format, batch_size = _parse_to(
-            *args, **kwargs
-        )
+    def _to(
+        self, *, device, dtype, convert_to_format, batch_size, non_blocking
+    ) -> Tuple[T, bool]:
         result = self
         if device is not None and dtype is None and device == self.device:
-            return result
+            return result, False
         if dtype is not None:
-            return self.to_tensordict().to(*args, **kwargs)
+            return self.to_tensordict()._to(
+                device=device,
+                dtype=dtype,
+                convert_to_format=convert_to_format,
+                batch_size=batch_size,
+                non_blocking=non_blocking,
+            )
         result = self
         if device is not None:
             result = result.clone(False)
             result._device = device
             for key, nested in list(result._nested_tensordicts.items()):
-                result._nested_tensordicts[key] = nested.to(device)
+                result._nested_tensordicts[key], _ = nested._to(
+                    device=device,
+                    dtype=dtype,
+                    convert_to_format=convert_to_format,
+                    batch_size=batch_size,
+                    non_blocking=non_blocking,
+                )
         if batch_size is not None:
             result.batch_size = batch_size
-        return result
+        return result, False
 
     def _to_numpy(self, value):
         if hasattr(value, "requires_grad") and value.requires_grad:
diff --git a/tensordict/utils.py b/tensordict/utils.py
index 67cb80474..074454309 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -2274,3 +2274,18 @@ def __missing__(self, key):
         value = self.fun(key)
         self[key] = value
         return value
+
+
+def _prefix_last_key(key, prefix):
+    if isinstance(key, str):
+        return prefix + key
+    if len(key) == 1:
+        return (_prefix_last_key(key[0], prefix),)
+    return key[:-1] + (_prefix_last_key(key[-1], prefix),)
+
+
+NESTED_TENSOR_ERR = (
+    "The PyTorch version isn't compatible with "
+    "nested tensors. Please upgrade to a more recent "
+    "version."
+)
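
Usage note (illustrative sketch, not part of the patch): the snippet below shows the to()/_to() contract introduced above, assuming a CUDA-capable build; the keys and tensors are made up for the example. The public to() in TensorDictBase parses its arguments once, delegates to the per-class _to() implementations (which return a (result, must_sync) pair), and only calls _sync_all() when a non-blocking copy was issued and the caller left non_blocking unset.

import torch

from tensordict import TensorDict

td = TensorDict({"a": torch.randn(3), "b": torch.zeros(3)}, batch_size=[3])

# Default (non_blocking left unset): sub-copies run with non_blocking=True and
# to() calls _sync_all() afterwards if _to() reported must_sync=True
# (the _td.py change skips the sync for CUDA targets, where the stream handles it).
td_default = td.to("cuda:0")

# Explicit non_blocking=True: to() does not synchronize; the caller is
# responsible for syncing before reading the result.
td_async = td.to("cuda:0", non_blocking=True)

# Explicit non_blocking=False: sub-copies are blocking, must_sync stays False
# and no synchronization is triggered.
td_blocking = td.to("cuda:0", non_blocking=False)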