[Feature] torch.where (#507)

vmoens committed Aug 8, 2023
1 parent 37e66d1 commit fa19024

Showing 4 changed files with 135 additions and 5 deletions.
6 changes: 4 additions & 2 deletions tensordict/nn/params.py
@@ -190,9 +190,7 @@ def _carry_over(func):
     @wraps(func)
     def new_func(self, *args, **kwargs):
         out = getattr(self._param_td, name)(*args, **kwargs)
-        print("out is", out)
         out = TensorDictParams(out, no_convert=True)
-        print("out is (2)", out)
         out.no_convert = self.no_convert
         return out

@@ -667,6 +665,10 @@ def exclude(self, *keys: str, inplace: bool = False) -> TensorDictBase:
     def transpose(self, dim0, dim1):
         ...
 
+    @_carry_over
+    def where(self, condition, other, *, out=None):
+        ...
+
     @_carry_over
     def permute(
         self,
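With `where` carried over, calling it on a `TensorDictParams` re-wraps the result so it stays a parameter container. A minimal usage sketch (shapes and values are hypothetical; it assumes `TensorDictParams` is importable from `tensordict.nn`):

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictParams

    # Hypothetical parameter container; shapes are illustrative.
    params = TensorDictParams(
        TensorDict({"weight": torch.randn(3, 4), "bias": torch.randn(3)}, batch_size=[3])
    )
    mask = torch.tensor([True, False, True])

    # where is delegated to the wrapped tensordict and the output is
    # re-wrapped by _carry_over, so it remains a TensorDictParams.
    masked = params.where(mask, 0.0)
    assert isinstance(masked, TensorDictParams)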
3 changes: 3 additions & 0 deletions tensordict/persistent.py
@@ -506,6 +506,9 @@ def is_contiguous(self):
     def masked_fill(self, mask, value):
         return self.to_tensordict().masked_fill(mask, value)
 
+    def where(self, condition, other, *, out=None):
+        return self.to_tensordict().where(condition=condition, other=other, out=out)
+
     def masked_fill_(self, mask, value):
         for key in self.keys(include_nested=True, leaves_only=True):
             array = self._get_array(key)
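Like `masked_fill`, the persistent variant falls back to an in-memory copy rather than writing through to the HDF5 datasets. A minimal sketch, assuming `PersistentTensorDict.from_dict` accepts a dict, a filename and a batch size (h5py required):

    import torch
    from tensordict.persistent import PersistentTensorDict

    # Hypothetical HDF5-backed tensordict written to a scratch file.
    td_h5 = PersistentTensorDict.from_dict(
        {"a": torch.randn(4, 5)}, filename="scratch.h5", batch_size=[4]
    )
    mask = torch.zeros(4, dtype=torch.bool).bernoulli_()

    # to_tensordict() materializes the data first, so the result is a plain
    # in-memory TensorDict rather than a PersistentTensorDict.
    result = td_h5.where(mask, 0.0)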
100 changes: 97 additions & 3 deletions tensordict/tensordict.py
@@ -2593,6 +2593,19 @@ def masked_fill(self, mask: Tensor, value: float | bool) -> TensorDictBase:
         """
         raise NotImplementedError
 
+    def where(self, condition, other, *, out=None):
+        """Return a ``TensorDict`` of elements selected from either self or other, depending on condition.
+
+        Args:
+            condition (BoolTensor): When ``True`` (nonzero), yields ``self``,
+                otherwise yields ``other``.
+            other (TensorDictBase or Scalar): value (if ``other`` is a scalar)
+                or values selected at indices where ``condition`` is ``False``.
+            out (TensorDictBase, optional): the output ``TensorDictBase`` instance.
+        """
+        raise NotImplementedError
+
     def masked_select(self, mask: Tensor) -> TensorDictBase:
         """Masks all tensors of the TensorDict and returns a new TensorDict instance with similar keys pointing to masked values.
@@ -4431,6 +4444,42 @@ def to(tensor):
             f"instance, {dest} not allowed"
         )
 
+    def where(self, condition, other, *, out=None):
+        if out is None:
+            if _is_tensor_collection(other.__class__):
+
+                def func(tensor, _other):
+                    return torch.where(
+                        expand_as_right(condition, tensor), tensor, _other
+                    )
+
+                return self.apply(func, other)
+            else:
+
+                def func(tensor):
+                    return torch.where(
+                        expand_as_right(condition, tensor), tensor, other
+                    )
+
+                return self.apply(func)
+        else:
+            if _is_tensor_collection(other.__class__):
+
+                def func(tensor, _other, _out):
+                    return torch.where(
+                        expand_as_right(condition, tensor), tensor, _other, out=_out
+                    )
+
+                return self.apply(func, other, out)
+            else:
+
+                def func(tensor, _out):
+                    return torch.where(
+                        expand_as_right(condition, tensor), tensor, other, out=_out
+                    )
+
+                return self.apply(func, out)
+
     def masked_fill_(self, mask: Tensor, value: float | int | bool) -> TensorDictBase:
         for item in self.values():
             mask_expand = expand_as_right(mask, item)
@@ -4807,9 +4856,7 @@ def _zeros_like(td: TensorDictBase, **kwargs: Any) -> TensorDictBase:
 
 @implements_for_td(torch.ones_like)
 def _ones_like(td: TensorDictBase, **kwargs: Any) -> TensorDictBase:
-    td_clone = td.clone()
-    for key in td_clone.keys():
-        td_clone.fill_(key, 1.0)
+    td_clone = td.apply(lambda x: torch.ones_like(x))
     if "device" in kwargs:
         td_clone = td_clone.to(kwargs.pop("device"))
     if len(kwargs):
@@ -5798,6 +5845,9 @@ def pin_memory(self) -> TensorDictBase:
     def detach_(self) -> TensorDictBase:
         raise RuntimeError("Detaching a sub-tensordict in-place cannot be done.")
 
+    def where(self, condition, other, *, out=None):
+        return self.to_tensordict().where(condition=condition, other=other, out=out)
+
     def masked_fill_(self, mask: Tensor, value: float | bool) -> TensorDictBase:
         for key, item in self.items():
             self.set_(key, torch.full_like(item, value))
@@ -7536,6 +7586,27 @@ def sort_keys(element):
 
     rename_key = _renamed_inplace_method(rename_key_)
 
+    def where(self, condition, other, *, out=None):
+        condition = condition.unbind(self.stack_dim)
+        if _is_tensor_collection(other.__class__) or (
+            isinstance(other, Tensor)
+            and other.shape[: self.stack_dim] == self.shape[: self.stack_dim]
+        ):
+            other = other.unbind(self.stack_dim)
+            return torch.stack(
+                [
+                    td.where(cond, _other)
+                    for td, cond, _other in zip(self.tensordicts, condition, other)
+                ],
+                self.stack_dim,
+                out=out,
+            )
+        return torch.stack(
+            [td.where(cond, other) for td, cond in zip(self.tensordicts, condition)],
+            self.stack_dim,
+            out=out,
+        )
+
     def masked_fill_(self, mask: Tensor, value: float | bool) -> TensorDictBase:
         mask_unbind = mask.unbind(dim=self.stack_dim)
         for _mask, td in zip(mask_unbind, self.tensordicts):
@@ -8027,6 +8098,9 @@ def detach_(self) -> _CustomOpTensorDict:
         self._source.detach_()
         return self
 
+    def where(self, condition, other, *, out=None):
+        return self.to_tensordict().where(condition=condition, other=other, out=out)
+
     def masked_fill_(self, mask: Tensor, value: float | bool) -> _CustomOpTensorDict:
         for key, item in self.items():
             val = self._source.get(key)
@@ -8892,3 +8966,23 @@ def _dispatch(remaining_index, stack_index, i=None):
     remaining_index = _dispatch(remaining_index, stack_index.tolist())
     out["remaining_index"] = _reduce_index(remaining_index)
     return out
+
+
+@implements_for_td(torch.where)
+def where(condition, input, other, *, out=None):
+    """Return a ``TensorDict`` of elements selected from either input or other, depending on condition.
+
+    Args:
+        condition (BoolTensor): When ``True`` (nonzero), yield ``input``, otherwise yield ``other``.
+        input (TensorDictBase or Scalar): value (if ``input`` is a scalar) or values selected at indices where ``condition`` is ``True``.
+        other (TensorDictBase or Scalar): value (if ``other`` is a scalar) or values selected at indices where ``condition`` is ``False``.
+        out (TensorDictBase, optional): the output ``TensorDictBase`` instance.
+    """
+    from tensordict.persistent import PersistentTensorDict
+
+    if isinstance(out, PersistentTensorDict):
+        raise RuntimeError(
+            "Cannot use a persistent tensordict as output of torch.where."
+        )
+    return input.where(condition, other, out=out)
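With the `implements_for_td` registration above, `torch.where` dispatches to `TensorDictBase.where` whenever a tensordict is passed as `input`. A short usage sketch mirroring the new test:

    import torch
    from tensordict import TensorDict

    td = TensorDict(
        {"a": torch.randn(3, 4), "b": {"c": torch.randn(3, 4, 5)}},
        batch_size=[3, 4],
    )
    mask = torch.zeros(3, 4, dtype=torch.bool).bernoulli_()

    # Scalar `other`: entries where the condition is False are replaced by 0.
    td_zero = torch.where(mask, td, 0)
    assert (td_zero["a"][~mask] == 0).all()

    # TensorDict `other`: values are selected leaf by leaf; the condition is
    # right-expanded to each leaf's shape.
    td_one = torch.where(mask, td, torch.ones_like(td))
    assert (td_one["b", "c"][~mask] == 1).all()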
31 changes: 31 additions & 0 deletions test/test_tensordict.py
@@ -1143,6 +1143,37 @@ def test_gather(self, td_name, device, dim):
         td_gather2 = torch.gather(td, dim=dim, index=index, out=out)
         assert (td_gather2 != 0).any()
 
+    def test_where(self, td_name, device):
+        torch.manual_seed(1)
+        td = getattr(self, td_name)(device)
+        mask = torch.zeros(td.shape, dtype=torch.bool, device=device).bernoulli_()
+        td_where = torch.where(mask, td, 0)
+        for k in td.keys(True, True):
+            assert (td_where.get(k)[~mask] == 0).all()
+        td_where = torch.where(mask, td, torch.ones_like(td))
+        for k in td.keys(True, True):
+            assert (td_where.get(k)[~mask] == 1).all()
+        td_where = td.clone()
+        # torch.where(mask, td, torch.zeros((), device=device), out=td_where)
+        # for k in td.keys(True, True):
+        #     assert (td_where.get(k)[~mask] == 0).all()
+        if td_name == "td_params":
+            with pytest.raises(
+                RuntimeError, match="don't support automatic differentiation"
+            ):
+                torch.where(mask, td, torch.ones_like(td), out=td_where)
+            return
+        if td_name == "td_h5":
+            with pytest.raises(
+                RuntimeError,
+                match="Cannot use a persistent tensordict as output of torch.where",
+            ):
+                torch.where(mask, td, torch.ones_like(td), out=td_where)
+            return
+        torch.where(mask, td, torch.ones_like(td), out=td_where)
+        for k in td.keys(True, True):
+            assert (td_where.get(k)[~mask] == 1).all()
+
     def test_masking_set(self, td_name, device):
         torch.manual_seed(1)
         td = getattr(self, td_name)(device)
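The test is parametrized over tensordict variants, including lazily stacked ones, which exercise the unbind-and-restack branch added to `LazyStackedTensorDict`. A standalone sketch of that case:

    import torch
    from tensordict import TensorDict

    # Two tensordicts stacked lazily along dim 0 -> LazyStackedTensorDict.
    tds = [TensorDict({"a": torch.randn(4)}, batch_size=[4]) for _ in range(2)]
    stacked = torch.stack(tds, 0)
    mask = torch.zeros(2, 4, dtype=torch.bool).bernoulli_()

    # The condition is unbound along the stack dim, each slice is applied to
    # the matching sub-tensordict, and the results are re-stacked.
    out = torch.where(mask, stacked, 0.0)
    assert (out["a"][~mask] == 0).all()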

1 comment on commit fa19024

@github-actions
⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
The benchmark result of this commit is worse than the previous benchmark result by more than the alert threshold of 2.

| Benchmark suite | Current: fa19024 | Previous: 37e66d1 | Ratio |
| --- | --- | --- | --- |
| benchmarks/common/common_ops_test.py::test_set_shared | 3062.702857166452 iter/sec (stddev: 0.00015942139860248967) | 6496.093604059103 iter/sec (stddev: 0.00007794659396342823) | 2.12 |
| benchmarks/common/memmap_benchmarks_test.py::test_add_one[memmap_tensor0] | 15349.889247256275 iter/sec (stddev: 0.00004804067948678243) | 32504.810915040307 iter/sec (stddev: 0.000020580883111082216) | 2.12 |
| benchmarks/common/memmap_benchmarks_test.py::test_memmaptd_index_astensor | 402.4138834674256 iter/sec (stddev: 0.000516733072034344) | 811.4651896320653 iter/sec (stddev: 0.000015199764819465709) | 2.02 |
| benchmarks/common/memmap_benchmarks_test.py::test_memmaptd_index_op | 187.5052763699166 iter/sec (stddev: 0.0008793127779968145) | 420.32222778212656 iter/sec (stddev: 0.000034711878889733113) | 2.24 |
| benchmarks/nn/functional_benchmarks_test.py::test_exec_td | 2996.31221453342 iter/sec (stddev: 0.00012387417140327504) | 5994.428079468301 iter/sec (stddev: 0.0000036718719775930427) | 2.00 |

This comment was automatically generated by a workflow using github-action-benchmark.

CC: @vmoens
