diff --git a/tensordict/memmap.py b/tensordict/memmap.py
index 046024657..3af163c22 100644
--- a/tensordict/memmap.py
+++ b/tensordict/memmap.py
@@ -842,6 +842,13 @@ def _cat(
 implements_for_memmap(torch.cat)(_cat)
 
 
+def _where(condition, input, other):
+    return torch.where(condition=condition, input=input.as_tensor(), other=other)
+
+
+implements_for_memmap(torch.where)(_where)
+
+
 def set_transfer_ownership(memmap: MemmapTensor, value: bool = True) -> None:
     """Changes the transfer_ownership attribute of a MemmapTensor."""
     if isinstance(memmap, MemmapTensor):
diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py
index 62917f5a8..d43703e5a 100644
--- a/tensordict/nn/params.py
+++ b/tensordict/nn/params.py
@@ -714,7 +714,7 @@ def transpose(self, dim0, dim1):
         ...
 
     @_carry_over
-    def where(self, condition, other, *, out=None):
+    def where(self, condition, other, *, out=None, pad=None):
         ...
 
     @_carry_over
diff --git a/tensordict/persistent.py b/tensordict/persistent.py
index d0b8bd92d..cf75d3ecf 100644
--- a/tensordict/persistent.py
+++ b/tensordict/persistent.py
@@ -508,8 +508,10 @@ def is_contiguous(self):
     def masked_fill(self, mask, value):
         return self.to_tensordict().masked_fill(mask, value)
 
-    def where(self, condition, other, *, out=None):
-        return self.to_tensordict().where(condition=condition, other=other, out=out)
+    def where(self, condition, other, *, out=None, pad=None):
+        return self.to_tensordict().where(
+            condition=condition, other=other, out=out, pad=pad
+        )
 
     def masked_fill_(self, mask, value):
         for key in self.keys(include_nested=True, leaves_only=True):
diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py
index de444bf16..0cde92342 100644
--- a/tensordict/tensordict.py
+++ b/tensordict/tensordict.py
@@ -2932,7 +2932,7 @@ def masked_fill(self, mask: Tensor, value: float | bool) -> T:
         """
         raise NotImplementedError
 
-    def where(self, condition, other, *, out=None):
+    def where(self, condition, other, *, out=None, pad=None):  # noqa: D417
         """Return a ``TensorDict`` of elements selected from either self or other, depending on condition.
 
         Args:
@@ -2940,7 +2940,13 @@ def where(self, condition, other, *, out=None):
                 otherwise yields ``other``.
             other (TensorDictBase or Scalar): value (if ``other`` is a scalar)
                 or values selected at indices where condition is ``False``.
-            out (Tensor, optional): the output ``TensorDictBase`` instance.
+
+        Keyword Args:
+            out (TensorDictBase, optional): the output ``TensorDictBase`` instance.
+            pad (scalar, optional): if provided, missing keys from the source
+                or destination tensordict will be written as ``torch.where(mask, self, pad)``
+                or ``torch.where(mask, pad, other)``. Defaults to ``None``, i.e.
+                missing keys are not tolerated.
 
         """
         raise NotImplementedError
@@ -4802,38 +4808,87 @@ def to(tensor):
             result.batch_size = batch_size
         return result
 
-    def where(self, condition, other, *, out=None):
-        if out is None:
-            if _is_tensor_collection(other.__class__):
-
-                def func(tensor, _other):
-                    return torch.where(
-                        expand_as_right(condition, tensor), tensor, _other
-                    )
+    def where(self, condition, other, *, out=None, pad=None):
+        if _is_tensor_collection(other.__class__):
 
-                return self._fast_apply(func, other)
-            else:
+            def func(tensor, _other, key):
+                if tensor is None:
+                    if pad is not None:
+                        tensor = _other
+                        _other = pad
+                    else:
+                        raise KeyError(
+                            f"Key {key} not found and no pad value provided."
+                        )
+                    cond = expand_as_right(~condition, tensor)
+                elif _other is None:
+                    if pad is not None:
+                        _other = pad
+                    else:
+                        raise KeyError(
+                            f"Key {key} not found and no pad value provided."
+                        )
+                    cond = expand_as_right(condition, tensor)
+                else:
+                    cond = expand_as_right(condition, tensor)
+                return torch.where(
+                    condition=cond,
+                    input=tensor,
+                    other=_other,
+                )
 
-            def func(tensor):
-                return torch.where(
-                    expand_as_right(condition, tensor), tensor, other
+            result = self.empty() if out is None else out
+            other_keys = set(other.keys())
+            # we materialize the keys as a list because out may be the same
+            # tensordict as self!
+            for key in list(self.keys()):
+                tensor = self._get_str(key, default=NO_DEFAULT)
+                _other = other._get_str(key, default=None)
+                if _is_tensor_collection(type(tensor)):
+                    _out = None if out is None else out._get_str(key, None)
+                    if _other is None:
+                        _other = tensor.empty()
+                    val = tensor.where(
+                        condition=condition, other=_other, out=_out, pad=pad
                     )
-
-                return self._fast_apply(func)
+                else:
+                    val = func(tensor, _other, key)
+                result._set_str(key, val, inplace=False, validated=True)
+                other_keys.discard(key)
+            for key in other_keys:
+                tensor = None
+                _other = other._get_str(key, default=NO_DEFAULT)
+                if _is_tensor_collection(type(_other)):
+                    try:
+                        tensor = _other.empty()
+                    except NotImplementedError:
+                        # H5 tensordicts do not support select(), which empty() relies on
+                        tensor = _other.to_tensordict().empty()
+                    val = _other.where(
+                        condition=~condition, other=tensor, out=None, pad=pad
+                    )
+                else:
+                    val = func(tensor, _other, key)
+                result._set_str(key, val, inplace=False, validated=True)
+            return result
         else:
-            if _is_tensor_collection(other.__class__):
+            if out is None:
 
-                def func(tensor, _other, _out):
+                def func(tensor):
                     return torch.where(
-                        expand_as_right(condition, tensor), tensor, _other, out=_out
+                        condition=expand_as_right(condition, tensor),
+                        input=tensor,
+                        other=other,
                     )
 
-                return self._fast_apply(func, other, out)
+                return self._fast_apply(func)
             else:
 
                 def func(tensor, _out):
                     return torch.where(
-                        expand_as_right(condition, tensor), tensor, other, out=_out
+                        condition=expand_as_right(condition, tensor),
+                        input=tensor,
+                        other=other,
+                        out=_out,
                     )
 
                 return self._fast_apply(func, out)
@@ -6179,8 +6234,10 @@ def pin_memory(self) -> T:
     def detach_(self) -> T:
         raise RuntimeError("Detaching a sub-tensordict in-place cannot be done.")
 
-    def where(self, condition, other, *, out=None):
-        return self.to_tensordict().where(condition=condition, other=other, out=out)
+    def where(self, condition, other, *, out=None, pad=None):
+        return self.to_tensordict().where(
+            condition=condition, other=other, out=out, pad=pad
+        )
 
     def masked_fill_(self, mask: Tensor, value: float | bool) -> T:
         for key, item in self.items():
@@ -7866,26 +7923,34 @@ def sort_keys(element):
 
     rename_key = _renamed_inplace_method(rename_key_)
 
-    def where(self, condition, other, *, out=None):
+    def where(self, condition, other, *, out=None, pad=None):
        condition = condition.unbind(self.stack_dim)
         if _is_tensor_collection(other.__class__) or (
             isinstance(other, Tensor)
             and other.shape[: self.stack_dim] == self.shape[: self.stack_dim]
         ):
             other = other.unbind(self.stack_dim)
-            return torch.stack(
+            result = torch.stack(
                 [
-                    td.where(cond, _other)
+                    td.where(cond, _other, pad=pad)
                     for td, cond, _other in zip(self.tensordicts, condition, other)
                 ],
                 self.stack_dim,
-                out=out,
             )
-        return torch.stack(
-            [td.where(cond, other) for td, cond in zip(self.tensordicts, condition)],
-            self.stack_dim,
-            out=out,
-        )
+        else:
+            result = torch.stack(
+                [
+                    td.where(cond, other, pad=pad)
+                    for td, cond in zip(self.tensordicts, condition)
+                ],
+                self.stack_dim,
+            )
+        # We should not pass out to stack because that would overwrite the
+        # tensors in-place, which we don't want.
+        if out is not None:
+            out.update(result)
+            return out
+        return result
 
     def masked_fill_(self, mask: Tensor, value: float | bool) -> T:
         mask_unbind = mask.unbind(dim=self.stack_dim)
@@ -8371,8 +8436,10 @@ def detach_(self) -> _CustomOpTensorDict:
         self._source.detach_()
         return self
 
-    def where(self, condition, other, *, out=None):
-        return self.to_tensordict().where(condition=condition, other=other, out=out)
+    def where(self, condition, other, *, out=None, pad=None):
+        return self.to_tensordict().where(
+            condition=condition, other=other, out=out, pad=pad
+        )
 
     def masked_fill_(self, mask: Tensor, value: float | bool) -> _CustomOpTensorDict:
         for key, item in self.items():
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 8ec9edd9f..b0a7e28c0 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -1241,15 +1241,7 @@ def test_where(self, td_name, device):
         for k in td.keys(True, True):
             assert (td_where.get(k)[~mask] == 1).all()
         td_where = td.clone()
-        # torch.where(mask, td, torch.zeros((), device=device), out=td_where)
-        # for k in td.keys(True, True):
-        #     assert (td_where.get(k)[~mask] == 0).all()
-        if td_name == "td_params":
-            with pytest.raises(
-                RuntimeError, match="don't support automatic differentiation"
-            ):
-                torch.where(mask, td, torch.ones_like(td), out=td_where)
-            return
+
         if td_name == "td_h5":
             with pytest.raises(
                 RuntimeError,
@@ -1261,6 +1253,44 @@ def test_where(self, td_name, device):
         for k in td.keys(True, True):
             assert (td_where.get(k)[~mask] == 1).all()
 
+    def test_where_pad(self, td_name, device):
+        torch.manual_seed(1)
+        td = getattr(self, td_name)(device)
+        # test where `other` is an empty td
+        mask = torch.zeros(td.shape, dtype=torch.bool, device=td.device).bernoulli_()
+        if td_name in ("td_h5",):
+            td_full = td.to_tensordict()
+        else:
+            td_full = td
+        td_empty = td_full.empty()
+        result = td.where(mask, td_empty, pad=1)
+        for v in result.values(True, True):
+            assert (v[~mask] == 1).all()
+        td_empty = td_full.empty()
+        result = td_empty.where(~mask, td, pad=1)
+        for v in result.values(True, True):
+            assert (v[~mask] == 1).all()
+        # with output
+        td_out = td_full.empty()
+        result = td.where(mask, td_empty, pad=1, out=td_out)
+        for v in result.values(True, True):
+            assert (v[~mask] == 1).all()
+        if td_name not in ("td_params",):
+            assert result is td_out
+        else:
+            assert isinstance(result, TensorDictParams)
+        td_out = td_full.empty()
+        td_empty = td_full.empty()
+        result = td_empty.where(~mask, td, pad=1, out=td_out)
+        for v in result.values(True, True):
+            assert (v[~mask] == 1).all()
+        assert result is td_out
+
+        with pytest.raises(KeyError, match="not found and no pad value provided"):
+            td.where(mask, td_full.empty())
+        with pytest.raises(KeyError, match="not found and no pad value provided"):
+            td_full.empty().where(mask, td)
+
     def test_masking_set(self, td_name, device):
         torch.manual_seed(1)
         td = getattr(self, td_name)(device)
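
For reviewers, here is a minimal usage sketch of the new `pad` keyword (and its interaction with `out`), mirroring the semantics exercised by `test_where_pad` above. The tensordict contents, key names and shapes below are illustrative only:

```python
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.ones(4, 3)}, batch_size=[4])
other = td.empty()  # same batch size and device, but no keys
mask = torch.tensor([True, False, True, False])

# Keys present in td but absent from other fall back to pad wherever
# the mask is False, instead of raising a KeyError.
result = td.where(mask, other, pad=0)
assert (result["a"][mask] == 1).all()
assert (result["a"][~mask] == 0).all()

# An out tensordict is filled and returned by the same call.
td_out = td.empty()
result = td.where(mask, other, pad=0, out=td_out)
assert result is td_out

# Without pad, a missing key raises a KeyError.
try:
    td.where(mask, other)
except KeyError:
    pass
```

Note that for lazily stacked tensordicts the result is stacked first and then copied into `out` via `update`, rather than by passing `out` to `torch.stack`, so the stored tensors are never overwritten in-place.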