[Performance] Fewer syncs during calls to `to` #819

Open · wants to merge 2 commits into base: main
53 changes: 33 additions & 20 deletions tensordict/_lazy.py
@@ -50,7 +50,6 @@
_get_shape_from_args,
_getitem_batch_size,
_is_number,
_parse_to,
_renamed_inplace_method,
_shape,
_td_fields,
@@ -1324,33 +1323,41 @@ def pin_memory(self) -> T:
td.pin_memory()
return self

def to(self, *args, **kwargs) -> T:
non_blocking = kwargs.pop("non_blocking", None)
device, dtype, _, convert_to_format, batch_size = _parse_to(*args, **kwargs)
def _to(
self, *, device, dtype, convert_to_format, batch_size, non_blocking
) -> Tuple[T, bool]:
if batch_size is not None:
raise TypeError("Cannot pass batch-size to a LazyStackedTensorDict.")
result = self

if device is not None and dtype is None and device == self.device:
return result
return result, False

if non_blocking in (None, True):
kwargs["non_blocking"] = True
sub_non_blocking = True
else:
kwargs["non_blocking"] = False
non_blocking = bool(non_blocking)
sub_non_blocking = False
tds, must_sync = zip(
*(
td._to(
device=device,
dtype=dtype,
convert_to_format=convert_to_format,
batch_size=batch_size,
non_blocking=sub_non_blocking,
)
for td in self.tensordicts
)
)
must_sync = any(must_sync)
result = type(self)(
*[td.to(*args, **kwargs) for td in self.tensordicts],
*tds,
stack_dim=self.stack_dim,
hook_out=self.hook_out,
hook_in=self.hook_in,
stack_dim_name=self._td_dim_name,
)
if device is not None and not non_blocking:
self._sync_all()
if self.is_locked:
result.lock_()
return result
return result, must_sync

def _check_new_batch_size(self, new_size: torch.Size) -> None:
if len(new_size) <= self.stack_dim:
@@ -2992,20 +2999,26 @@ def del_(self, key: NestedKey) -> _CustomOpTensorDict:
self._source = self._source.del_(key)
return self

def to(self, *args, **kwargs) -> T:
non_blocking = kwargs.pop("non_blocking", None)
device, dtype, _, convert_to_format, batch_size = _parse_to(*args, **kwargs)
def _to(
self, *, device, dtype, convert_to_format, batch_size, non_blocking
) -> Tuple[T, bool]:
if batch_size is not None:
raise TypeError(f"Cannot pass batch-size to a {type(self)}.")
result = self

if device is not None and dtype is None and device == self.device:
return result
return result, False

td = self._source.to(*args, non_blocking=non_blocking, **kwargs)
td, must_sync = self._source._to(
device=device,
dtype=dtype,
convert_to_format=convert_to_format,
batch_size=batch_size,
non_blocking=non_blocking,
)
self_copy = copy(self)
self_copy._source = td
return self_copy
return self_copy, must_sync

def pin_memory(self) -> _CustomOpTensorDict:
self._source.pin_memory()
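The `_lazy.py` changes above turn `LazyStackedTensorDict.to` into a private `_to` that collects `(result, must_sync)` pairs from every child and folds the sync flags into one boolean, rather than calling `self._sync_all()` itself. Below is a minimal sketch of that aggregation pattern; `_FakeLeaf` and `_FakeStack` are hypothetical stand-ins for illustration, not tensordict classes.

from typing import List, Tuple


class _FakeLeaf:
    # Hypothetical stand-in for a leaf tensordict; not tensordict code.
    def __init__(self, device: str):
        self.device = device

    def _to(self, *, device: str, non_blocking: bool) -> Tuple["_FakeLeaf", bool]:
        if device == self.device:
            return self, False  # nothing copied, nothing to sync
        # an asynchronous copy may need a sync later, so report it upward
        return _FakeLeaf(device), bool(non_blocking)


class _FakeStack:
    # Hypothetical stand-in for the stack's aggregation pattern.
    def __init__(self, children: List[_FakeLeaf]):
        self.children = children

    def _to(self, *, device: str, non_blocking: bool) -> Tuple["_FakeStack", bool]:
        moved, flags = zip(
            *(td._to(device=device, non_blocking=non_blocking) for td in self.children)
        )
        # one sync at the top covers every child, so only the aggregate flag
        # is propagated instead of syncing once per child
        return _FakeStack(list(moved)), any(flags)


stack = _FakeStack([_FakeLeaf("cpu"), _FakeLeaf("cpu")])
_, must_sync = stack._to(device="cuda", non_blocking=True)
assert must_sync  # at least one child issued an async copy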
46 changes: 22 additions & 24 deletions tensordict/_td.py
@@ -61,7 +61,6 @@
_LOCK_ERROR,
_NON_STR_KEY_ERR,
_NON_STR_KEY_TUPLE_ERR,
_parse_to,
_prune_selected_keys,
_set_item,
_set_max_batch_size,
@@ -2507,36 +2506,41 @@ def make_memmap_from_tensor(

return memmap_tensor

def to(self, *args, **kwargs: Any) -> T:
non_blocking = kwargs.pop("non_blocking", None)
device, dtype, _, convert_to_format, batch_size = _parse_to(*args, **kwargs)
def _to(
self, *, device, dtype, convert_to_format, batch_size, non_blocking
) -> Tuple[T, bool]:
result = self

if device is not None and dtype is None and device == self.device:
return result
return result, False

if non_blocking is None:
sub_non_blocking = True
non_blocking = False
else:
sub_non_blocking = non_blocking
must_sync = False
# this may seem awkward but non_blocking=None acts as True and this makes only one check
# (checking against (True, None) would take longer)
sub_non_blocking = False if non_blocking is False else True

if convert_to_format is not None:

def to(tensor):
return tensor.to(
nonlocal must_sync
result = tensor.to(
device,
dtype,
non_blocking=sub_non_blocking,
convert_to_format=convert_to_format,
)
must_sync |= sub_non_blocking and result is not tensor
return result

else:

def to(tensor):
return tensor.to(
nonlocal must_sync
result = tensor.to(
device=device, dtype=dtype, non_blocking=sub_non_blocking
)
must_sync |= sub_non_blocking and result is not tensor
return result

apply_kwargs = {}
if device is not None or dtype is not None:
@@ -2545,9 +2549,10 @@ def to(tensor):
result = result._fast_apply(to, propagate_lock=True, **apply_kwargs)
elif batch_size is not None:
result.batch_size = batch_size
if device is not None and sub_non_blocking and not non_blocking:
self._sync_all()
return result
if must_sync and device is not None and device.type == "cuda":
# cuda stream is clever enough to do the sync
must_sync = False
return result, must_sync

def where(self, condition, other, *, out=None, pad=None):
if _is_tensor_collection(other.__class__):
@@ -3177,15 +3182,8 @@ def _stack_onto_(self, list_item: list[CompatibleType], dim: int) -> _SubTensorD
self._source._stack_onto_at_(list_item, dim=dim, idx=self.idx)
return self

def to(self, *args, **kwargs: Any) -> T:
device, dtype, non_blocking, convert_to_format, batch_size = _parse_to(
*args, **kwargs
)
result = self

if device is not None and dtype is None and device == self.device:
return result
return self.to_tensordict().to(*args, **kwargs)
def _to(self, *args, **kwargs: Any) -> Tuple[T, bool]:
return self.to_tensordict()._to(*args, **kwargs)

def _change_batch_size(self, new_size: torch.Size) -> None:
if not hasattr(self, "_orig_batch_size"):
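At the leaf level (`TensorDict._to` above), a sync is only flagged when a non-blocking copy actually produced a new tensor, and the flag is dropped again for copies landing on a CUDA device, since the CUDA stream already orders them. A rough standalone illustration of that per-tensor decision follows; `_needs_host_sync` is our name for it, not a tensordict function.

import torch


def _needs_host_sync(src, device, non_blocking):
    # Illustrative only: mirrors the per-tensor decision in the diff above.
    device = torch.device(device)
    out = src.to(device, non_blocking=non_blocking)
    # A sync matters only if an asynchronous copy really happened, i.e. `.to`
    # returned a new tensor instead of `src` itself.
    must_sync = bool(non_blocking) and out is not src
    # Copies landing on a CUDA device are ordered by the CUDA stream, so only
    # device-to-host (or cross-device) transfers leave a sync to the caller.
    if must_sync and device.type == "cuda":
        must_sync = False
    return out, must_sync


t = torch.ones(3)
out, must_sync = _needs_host_sync(t, "cpu", non_blocking=True)
assert out is t and must_sync is False  # same device: no copy, no sync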
19 changes: 18 additions & 1 deletion tensordict/base.py
@@ -48,6 +48,7 @@
_is_non_tensor,
_is_tensorclass,
_KEY_ERROR,
_parse_to,
_proc_init,
_prune_selected_keys,
_set_max_batch_size,
@@ -7448,7 +7449,6 @@ def to(self: T, *, other: T, non_blocking: bool = ...) -> T:
def to(self: T, *, batch_size: torch.Size) -> T:
...

@abc.abstractmethod
def to(self, *args, **kwargs) -> T:
"""Maps a TensorDictBase subclass either on another device, dtype or to another TensorDictBase subclass (if permitted).

@@ -7489,6 +7489,23 @@ def to(self, *args, **kwargs) -> T:
>>> data_cuda = data.to(torch.randn(3, device="cuda:0")) # using an example tensor
>>> data_cuda = data.to(other=TensorDict({}, [], device="cuda:0")) # using a tensordict example
"""
non_blocking = kwargs.pop("non_blocking", None)
device, dtype, _, convert_to_format, batch_size = _parse_to(*args, **kwargs)
result, must_sync = self._to(
device=device,
dtype=dtype,
convert_to_format=convert_to_format,
batch_size=batch_size,
non_blocking=non_blocking,
)
if must_sync and non_blocking is None:
self._sync_all()
return result

@abc.abstractmethod
def _to(
self, *, device, dtype, convert_to_format, batch_size, non_blocking
) -> Tuple[T, bool]:
...

def _sync_all(self):
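With the per-class `_to` methods in place, the public `to` defined here in `base.py` becomes a thin wrapper: it parses the arguments once and only calls `self._sync_all()` when a sync is actually required and the user left `non_blocking` unset. A small usage sketch of the intended behaviour (assumes a CUDA device is available; not taken from the PR's tests):

import torch
from tensordict import TensorDict

if torch.cuda.is_available():
    td = TensorDict({"a": torch.randn(3), "b": torch.zeros(3)}, batch_size=[3])

    # Default: non_blocking is left unset, so copies may be issued
    # asynchronously and `to` synchronizes on your behalf only if needed.
    td_cuda = td.to("cuda")

    # Explicit non_blocking=True: no implicit sync; synchronize yourself
    # before trusting host-side values after a cuda -> cpu transfer.
    td_cpu = td_cuda.to("cpu", non_blocking=True)
    torch.cuda.synchronize()
    assert (td_cpu["a"] == td_cuda["a"].cpu()).all()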
10 changes: 5 additions & 5 deletions tensordict/nn/params.py
@@ -11,7 +11,7 @@
import weakref
from copy import copy
from functools import wraps
from typing import Any, Callable, Iterator, OrderedDict, Sequence, Type
from typing import Any, Callable, Iterator, OrderedDict, Sequence, Tuple, Type

import torch

@@ -526,11 +526,11 @@ def __getitem__(self, index: IndexType) -> TensorDictBase:

__getitems__ = __getitem__

def to(self, *args, **kwargs) -> TensorDictBase:
params = self._param_td.to(*args, **kwargs)
def _to(self, *args, **kwargs) -> Tuple[TensorDictBase, bool]:
params, must_sync = self._param_td._to(*args, **kwargs)
if params is self._param_td:
return self
return TensorDictParams(params)
return self, must_sync
return TensorDictParams(params), must_sync

def cpu(self):
params = self._param_td.cpu()
28 changes: 19 additions & 9 deletions tensordict/persistent.py
@@ -39,7 +39,6 @@
_CloudpickleWrapper,
_KEY_ERROR,
_LOCK_ERROR,
_parse_to,
_proc_init,
_split_tensordict,
cache,
@@ -966,24 +965,35 @@ def share_memory_(self):
"Create a regular tensordict first using the `to_tensordict` method."
)

def to(self, *args, **kwargs: Any) -> PersistentTensorDict:
device, dtype, non_blocking, convert_to_format, batch_size = _parse_to(
*args, **kwargs
)
def _to(
self, *, device, dtype, convert_to_format, batch_size, non_blocking
) -> Tuple[T, bool]:
result = self
if device is not None and dtype is None and device == self.device:
return result
return result, False
if dtype is not None:
return self.to_tensordict().to(*args, **kwargs)
return self.to_tensordict()._to(
device=device,
dtype=dtype,
convert_to_format=convert_to_format,
batch_size=batch_size,
non_blocking=non_blocking,
)
result = self
if device is not None:
result = result.clone(False)
result._device = device
for key, nested in list(result._nested_tensordicts.items()):
result._nested_tensordicts[key] = nested.to(device)
result._nested_tensordicts[key], _ = nested._to(
device=device,
dtype=dtype,
convert_to_format=convert_to_format,
batch_size=batch_size,
non_blocking=non_blocking,
)
if batch_size is not None:
result.batch_size = batch_size
return result
return result, False

def _to_numpy(self, value):
if hasattr(value, "requires_grad") and value.requires_grad:
15 changes: 15 additions & 0 deletions tensordict/utils.py
@@ -2274,3 +2274,18 @@ def __missing__(self, key):
value = self.fun(key)
self[key] = value
return value


def _prefix_last_key(key, prefix):
if isinstance(key, str):
return prefix + key
if len(key) == 1:
return (_prefix_last_key(key[0], prefix),)
return key[:-1] + (_prefix_last_key(key[-1], prefix),)


NESTED_TENSOR_ERR = (
"The PyTorch version isn't compatible with "
"nested tensors. Please upgrade to a more recent "
"version."
)
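
`utils.py` also gains a small helper, `_prefix_last_key`, which prefixes only the last element of a (possibly nested) key. A few illustrative calls, derived directly from the definition above:

from tensordict.utils import _prefix_last_key  # available once this diff is applied

_prefix_last_key("done", "next_")             # -> "next_done"
_prefix_last_key(("done",), "next_")          # -> ("next_done",)
_prefix_last_key(("agent", "done"), "next_")  # -> ("agent", "next_done")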