[Performance] Faster to (#524)
vmoens committed Oct 9, 2023
1 parent a61f7a0 commit d445682
Showing 3 changed files with 104 additions and 33 deletions.
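
In short, the commit adds an internal _fast_apply helper that forwards to a new _apply_nest with checked=True, skipping the per-entry metadata validation that apply performs, and routes conversion helpers (to, float, detach, pin_memory, state_dict, flatten/unflatten, ...) through it. A minimal sketch of the distinction follows; it is not taken from the repository and assumes torch and a tensordict build containing this commit are installed:

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.randn(3, 4), "b": torch.zeros(3, 1)}, batch_size=[3])

# apply() validates each resulting entry (shape/device) before writing it back.
out_checked = td.apply(lambda x: x.float())

# _fast_apply() (internal) calls _apply_nest(..., checked=True), so that
# validation is skipped; the callable must keep the leading batch dims intact.
out_fast = td._fast_apply(lambda x: x.float())

assert (out_checked["a"] == out_fast["a"]).all()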
.github/workflows/benchmarks_pr.yml (2 changes: 1 addition & 1 deletion)
@@ -1,4 +1,4 @@
- name: Continuous Benchmark
+ name: Continuous Benchmark (PR)

on:
pull_request:
tensordict/nn/params.py (4 changes: 4 additions & 0 deletions)
@@ -423,6 +423,10 @@ def apply(
) -> TensorDictBase:
...

+ @_unlock_and_set(inplace=True)
+ def _apply_nest(*args, **kwargs):
+ ...
+
@_get_post_hook
@_fallback
def get(
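In tensordict/nn/params.py, the new _apply_nest overload mirrors the existing apply overload: it is wrapped with _unlock_and_set(inplace=True) so that calls which now recurse through _apply_nest (via _fast_apply) still pass through the unlock-and-rewrap machinery of the locked parameter container. A rough, hypothetical illustration of the kind of call that relies on this (assuming tensordict.nn.TensorDictParams is available):

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictParams

params = TensorDictParams(TensorDict({"weight": torch.randn(4, 4)}, batch_size=[]))

# detach() is implemented with _fast_apply -> _apply_nest after this commit;
# the decorated overload lets it operate on the locked parameter tensordict
# the same way apply() already does.
detached = params.detach()
print(type(detached))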
tensordict/tensordict.py (131 changes: 99 additions & 32 deletions)
@@ -748,7 +748,7 @@ def state_dict(
"""
out = collections.OrderedDict()
- source = self.apply(memmap_tensor_as_tensor)
+ source = self._fast_apply(memmap_tensor_as_tensor)
if flatten:
source = source.flatten_keys(".")
for key, item in source.items():
@@ -1570,6 +1570,28 @@ def apply(
>>> assert (td_2["a"] == -2).all()
>>> assert (td_2["b", "c"] == 2).all()
"""
+ return self._apply_nest(
+ fn,
+ *others,
+ batch_size=batch_size,
+ device=device,
+ names=names,
+ inplace=inplace,
+ checked=False,
+ **constructor_kwargs,
+ )
+
+ def _apply_nest(
+ self,
+ fn: Callable,
+ *others: T,
+ batch_size: Sequence[int] | None = None,
+ device: torch.device | None = None,
+ names: Sequence[str] | None = None,
+ inplace: bool = False,
+ checked: bool = False,
+ **constructor_kwargs,
+ ) -> T:
if inplace:
out = self
elif batch_size is not None:
@@ -1596,14 +1618,15 @@ def apply(
out.unlock_()

for key, item in self.items():
- _others = [_other.get(key) for _other in others]
+ _others = [_other._get_str(key, default=NO_DEFAULT) for _other in others]
if _is_tensor_collection(item.__class__):
- item_trsf = item.apply(
+ item_trsf = item._apply_nest(
fn,
*_others,
inplace=inplace,
batch_size=batch_size,
device=device,
+ checked=checked,
**constructor_kwargs,
)
else:
@@ -1618,13 +1641,41 @@ def apply(
key,
item_trsf,
inplace=BEST_ATTEMPT_INPLACE if inplace else False,
- validated=False,
+ validated=checked,
)

if not inplace and is_locked:
out.lock_()
return out

+ def _fast_apply(
+ self,
+ fn: Callable,
+ *others: T,
+ batch_size: Sequence[int] | None = None,
+ device: torch.device | None = None,
+ names: Sequence[str] | None = None,
+ inplace: bool = False,
+ **constructor_kwargs,
+ ) -> T:
"""A faster apply method.
This method does not run any check after performing the func. This
means that one to make sure that the metadata of the resulting tensors
(device, shape etc.) match the :meth:`~.apply` ones.
"""
+ return self._apply_nest(
+ fn,
+ *others,
+ batch_size=batch_size,
+ device=device,
+ names=names,
+ inplace=inplace,
+ checked=True,
+ **constructor_kwargs,
+ )

def map(
self,
fn: Callable,
@@ -1756,7 +1807,7 @@ def as_tensor(self):
"""
try:
- return self.apply(lambda x: x.as_tensor())
+ return self._fast_apply(lambda x: x.as_tensor())
except AttributeError as err:
raise AttributeError(
f"{self.__class__.__name__} does not have an 'as_tensor' method "
@@ -2187,7 +2238,7 @@ def flatten(tensor):
)
else:
batch_size = [nelt] + list(self.batch_size[end_dim + 1 :])
- out = self.apply(flatten, batch_size=batch_size)
+ out = self._fast_apply(flatten, batch_size=batch_size)
if self._has_names():
names = [
name
@@ -2233,7 +2284,7 @@ def unflatten(tensor):
)
else:
batch_size = list(unflattened_size) + list(self.batch_size[1:])
- out = self.apply(unflatten, batch_size=batch_size)
+ out = self._fast_apply(unflatten, batch_size=batch_size)
if self._has_names():
names = copy(self.names)
for _ in range(len(unflattened_size) - 1):
@@ -2537,7 +2588,7 @@ def detach(self) -> T:
a new tensordict with no tensor requiring gradient.
"""
- return self.apply(lambda x: x.detach())
+ return self._fast_apply(lambda x: x.detach())

def to_h5(
self,
@@ -3732,7 +3783,7 @@ def fill_(self, key: NestedKey, value: float | bool) -> T:
key = _unravel_key_to_tuple(key)
data = self._get_tuple(key, NO_DEFAULT)
if _is_tensor_collection(data.__class__):
- data.apply_(lambda x: x.fill_(value))
+ data._fast_apply(lambda x: x.fill_(value), inplace=True)
# self._set(key, tensordict, inplace=True)
else:
data = data.fill_(value)
@@ -3872,27 +3923,27 @@ def is_floating_point(self):

def double(self):
r"""Casts all tensors to ``torch.bool``."""
- return self.apply(lambda x: x.double())
+ return self._fast_apply(lambda x: x.double())

def float(self):
r"""Casts all tensors to ``torch.float``."""
- return self.apply(lambda x: x.float())
+ return self._fast_apply(lambda x: x.float())

def int(self):
r"""Casts all tensors to ``torch.int``."""
- return self.apply(lambda x: x.int())
+ return self._fast_apply(lambda x: x.int())

def bool(self):
r"""Casts all tensors to ``torch.bool``."""
- return self.apply(lambda x: x.bool())
+ return self._fast_apply(lambda x: x.bool())

def half(self):
r"""Casts all tensors to ``torch.half``."""
- return self.apply(lambda x: x.half())
+ return self._fast_apply(lambda x: x.half())

def bfloat16(self):
r"""Casts all tensors to ``torch.bfloat16``."""
- return self.apply(lambda x: x.bfloat16())
+ return self._fast_apply(lambda x: x.bfloat16())

def type(self, dst_type):
r"""Casts all tensors to :attr:`dst_type`.
@@ -3901,7 +3952,7 @@ def type(self, dst_type):
dst_type (type or string): the desired type
"""
- return self.apply(lambda x: x.type(dst_type))
+ return self._fast_apply(lambda x: x.type(dst_type))


_ACCEPTED_CLASSES = [
@@ -4316,7 +4367,7 @@ def pin_memory(self) -> T:
def pin_mem(tensor):
return tensor.pin_memory()

- return self.apply(pin_mem)
+ return self._fast_apply(pin_mem)

@overload
def expand(self, *shape: int) -> T:
@@ -4719,6 +4770,18 @@ def to(self, *args, **kwargs: Any) -> T:

if device is not None and dtype is None and device == self.device:
return result
+ # if device is not None and dtype is None:
+ # if device == self.device:
+ # return result
+ # elif non_blocking:
+ # return TensorDict(
+ # self._tensordict,
+ # device=device,
+ # names=self.names,
+ # batch_size=batch_size
+ # if batch_size is not None
+ # else self.batch_size,
+ # )

if convert_to_format is not None:

@@ -4734,7 +4797,7 @@ def to(tensor):
if device is not None or dtype is not None:
apply_kwargs["device"] = device
apply_kwargs["batch_size"] = batch_size
- result = result.apply(to, **apply_kwargs)
+ result = result._fast_apply(to, **apply_kwargs)
elif batch_size is not None:
result.batch_size = batch_size
return result
@@ -4748,15 +4811,15 @@ def func(tensor, _other):
expand_as_right(condition, tensor), tensor, _other
)

- return self.apply(func, other)
+ return self._fast_apply(func, other)
else:

def func(tensor):
return torch.where(
expand_as_right(condition, tensor), tensor, other
)

- return self.apply(func)
+ return self._fast_apply(func)
else:
if _is_tensor_collection(other.__class__):

@@ -4765,15 +4828,15 @@ def func(tensor, _other, _out):
expand_as_right(condition, tensor), tensor, _other, out=_out
)

- return self.apply(func, other, out)
+ return self._fast_apply(func, other, out)
else:

def func(tensor, _out):
return torch.where(
expand_as_right(condition, tensor), tensor, other, out=_out
)

- return self.apply(func, out)
+ return self._fast_apply(func, out)

def masked_fill_(self, mask: Tensor, value: float | int | bool) -> T:
for item in self.values():
@@ -5132,7 +5195,7 @@ def _full_like(td: T, fill_value: float, **kwargs: Any) -> T:

@implements_for_td(torch.zeros_like)
def _zeros_like(td: T, **kwargs: Any) -> T:
- td_clone = td.apply(torch.zeros_like)
+ td_clone = td._fast_apply(torch.zeros_like)
if "dtype" in kwargs:
raise ValueError("Cannot pass dtype to full_like with TensorDict")
if "device" in kwargs:
@@ -5147,7 +5210,7 @@ def _zeros_like(td: T, **kwargs: Any) -> T:

@implements_for_td(torch.ones_like)
def _ones_like(td: T, **kwargs: Any) -> T:
- td_clone = td.apply(lambda x: torch.ones_like(x))
+ td_clone = td._fast_apply(lambda x: torch.ones_like(x))
if "device" in kwargs:
td_clone = td_clone.to(kwargs.pop("device"))
if len(kwargs):
@@ -5168,7 +5231,9 @@ def _empty_like(td: T, *args, **kwargs) -> T:
"cloned, preventing empty_like to be called. "
"Consider calling tensordict.to_tensordict() first."
) from err
- return tdclone.apply_(lambda x: torch.empty_like(x, *args, **kwargs))
+ return tdclone._fast_apply(
+ lambda x: torch.empty_like(x, *args, **kwargs), inplace=True
+ )


@implements_for_td(torch.clone)
@@ -5733,7 +5798,7 @@ def names(self):
@names.setter
def names(self, value):
raise RuntimeError(
"Names of a subtensordict cannot be modified. Instantiate the tensordict first."
"Names of a subtensordict cannot be modified. Instantiate it as a TensorDict first."
)

def _has_names(self):
@@ -6089,7 +6154,7 @@ def expand(self, *args: int, inplace: bool = False) -> T:
shape = tuple(args[0])
else:
shape = args
- return self.apply(
+ return self._fast_apply(
lambda x: x.expand((*shape, *x.shape[self.ndim :])), batch_size=shape
)

@@ -6954,7 +7019,7 @@ def _add_batch_dim(self, *, in_dim, vmap_level):
in_dim = in_dim - 1
stack_dim = td.stack_dim
tds = [
- td.apply(
+ td._fast_apply(
lambda _arg: _add_batch_dim(_arg, in_dim, vmap_level),
batch_size=[b for i, b in enumerate(td.batch_size) if i != in_dim],
names=[name for i, name in enumerate(td.names) if i != in_dim],
@@ -7187,17 +7252,18 @@ def entry_class(self, key: NestedKey) -> type:
def apply_(self, fn: Callable, *others):
for i, td in enumerate(self.tensordicts):
idx = (slice(None),) * self.stack_dim + (i,)
- td.apply_(fn, *[other[idx] for other in others])
+ td._fast_apply(fn, *[other[idx] for other in others], inplace=True)
return self

- def apply(
+ def _apply_nest(
self,
fn: Callable,
*others: T,
batch_size: Sequence[int] | None = None,
device: torch.device | None = None,
names: Sequence[str] | None = None,
inplace: bool = False,
+ checked: bool = False,
**constructor_kwargs,
) -> T:
if inplace:
@@ -7208,18 +7274,19 @@ def apply(
return self.apply_(fn, *others)
else:
if batch_size is not None:
- return super().apply(
+ return super()._apply_nest(
fn,
*others,
batch_size=batch_size,
device=device,
names=names,
+ checked=checked,
**constructor_kwargs,
)
others = (other.unbind(self.stack_dim) for other in others)
out = LazyStackedTensorDict(
*(
- td.apply(fn, *oth, device=device)
+ td._apply_nest(fn, *oth, checked=checked, device=device)
for td, *oth in zip(self.tensordicts, *others)
),
stack_dim=self.stack_dim,
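
Taken together, to() and the other conversion paths above (where(), expand(), the vmap batching helper, and the LazyStackedTensorDict recursion) now funnel through the unchecked _fast_apply/_apply_nest route, which is where the "Faster to" in the title comes from. A small sketch of the affected call follows; it is not the repository's benchmark suite (those runs go through the renamed "Continuous Benchmark (PR)" workflow), just an illustration assuming torch and a build containing this commit:

import torch
from tensordict import TensorDict

td = TensorDict(
    {"obs": torch.randn(1024, 84), "reward": torch.randn(1024, 1)},
    batch_size=[1024],
)

# A dtype move: each leaf is cast by to()'s inner function and written back
# without re-validating shapes/devices, thanks to _fast_apply.
td_double = td.to(torch.double)
assert td_double["reward"].dtype == torch.double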
