Skip to content

Commit

Permalink
[BugFix] Fix indexing of lazy stacks (#488)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jul 20, 2023
1 parent 9fa6e4f commit 38f89ed
Show file tree
Hide file tree
Showing 6 changed files with 519 additions and 388 deletions.
3 changes: 2 additions & 1 deletion tensordict/memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from tensordict.utils import (
_getitem_batch_size,
convert_ellipsis_to_idx,
DeviceType,
IndexType,
NUMPY_TO_TORCH_DTYPE_DICT,
Expand Down Expand Up @@ -283,7 +284,7 @@ def _create_memmap_with_index(memmap_tensor, index):
else:
# avoid extending someone else's index
memmap_copy._index = deepcopy(memmap_copy._index)
memmap_copy._index.append(index)
memmap_copy._index.append(convert_ellipsis_to_idx(index, memmap_tensor.shape))
memmap_copy._shape_indexed = None
memmap_copy.file = memmap_tensor.file
memmap_copy._memmap_array = memmap_tensor._memmap_array
Expand Down
24 changes: 21 additions & 3 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def __torch_function__(
cls.__ne__ = __ne__
cls.set = _set
cls.set_at_ = _set_at_
cls.del_ = _del_
cls.get = _get
cls.get_at = _get_at
cls.unbind = _unbind
Expand Down Expand Up @@ -551,7 +552,7 @@ def _setitem(self, item: NestedKey, value: Any) -> None: # noqa: D417
def _repr(self) -> str:
"""Return a string representation of Tensor class object."""
fields = _all_td_fields_as_str(self._tensordict)
field_str = fields
field_str = [fields] if fields else []
non_tensor_fields = _all_non_td_fields_as_str(self._non_tensordict)
batch_size_str = indent(f"batch_size={self.batch_size}", 4 * " ")
device_str = indent(f"device={self.device}", 4 * " ")
Expand All @@ -562,10 +563,11 @@ def _repr(self) -> str:
4 * " ",
)
string = ",\n".join(
[field_str, non_tensor_field_str, batch_size_str, device_str, is_shared_str]
field_str
+ [non_tensor_field_str, batch_size_str, device_str, is_shared_str]
)
else:
string = ",\n".join([field_str, batch_size_str, device_str, is_shared_str])
string = ",\n".join(field_str + [batch_size_str, device_str, is_shared_str])
return f"{self.__class__.__name__}(\n{string})"


Expand Down Expand Up @@ -660,6 +662,22 @@ def _set(self, key: NestedKey, value: Any, inplace: bool = False):
)


def _del_(self, key):
key = _unravel_key_to_tuple(key)
if len(key) > 1:
td = self.get(key[0])
td.del_(key[1:])
return
if key[0] in self._tensordict.keys():
self._tensordict.del_(key[0])
# self.set(key[0], None)
elif key[0] in self._non_tensordict.keys():
self._non_tensordict[key[0]] = None
else:
raise KeyError(f"Key {key} could not be found in tensorclass {self}.")
return


def _set_at_(self, key: NestedKey, value: Any, idx: IndexType):
if key in self._non_tensordict:
del self._non_tensordict[key]
Expand Down
Loading

0 comments on commit 38f89ed

Please sign in to comment.