[Performance] Faster indexing (#376)
vmoens committed May 7, 2023
1 parent 0889ca3 commit 961bf9d
Showing 3 changed files with 59 additions and 13 deletions.
24 changes: 12 additions & 12 deletions tensordict/tensordict.py
@@ -3026,7 +3026,7 @@ def __getitem__(self, idx: IndexType) -> TensorDictBase:
             return self._index_tensordict((idx,))
 
         if isinstance(idx, list):
-            idx = torch.tensor(idx, device=self.device)
+            # idx = torch.tensor(idx, device=self.device)
             return self._index_tensordict(idx)
 
         if isinstance(idx, np.ndarray):
@@ -3578,17 +3578,17 @@ def _check_device(self) -> None:
                 "all elements must share that device."
             )
 
-    def _index_tensordict(self, idx: IndexType) -> TensorDictBase:
-        names = self._get_names_idx(idx)
-        self_copy = copy(self)
-        # self_copy = self.clone(False)
-        self_copy._tensordict = {
-            key: _get_item(item, idx) for key, item in self.items()
-        }
-        self_copy._batch_size = _getitem_batch_size(self_copy.batch_size, idx)
-        self_copy._device = self.device
-        self_copy.names = names
-        return self_copy
+    # def _index_tensordict(self, idx: IndexType) -> TensorDictBase:
+    #     names = self._get_names_idx(idx)
+    #     self_copy = copy(self)
+    #     # self_copy = self.clone(False)
+    #     self_copy._tensordict = {
+    #         key: _get_item(item, idx) for key, item in self.items()
+    #     }
+    #     self_copy._batch_size = _getitem_batch_size(self_copy.batch_size, idx)
+    #     self_copy._device = self.device
+    #     self_copy.names = names
+    #     return self_copy
 
     def pin_memory(self) -> TensorDictBase:
        def pin_mem(tensor):
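Illustrative usage, not part of the diff (assumes the tensordict package and the TensorDict API shown above): after this change, a Python list index is forwarded to _index_tensordict as-is instead of first being converted to a tensor inside __getitem__.

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.randn(4, 3), "b": torch.zeros(4, 1)}, batch_size=[4])

# Tensor index: unchanged by this commit.
sub_tensor = td[torch.tensor([1, 2])]  # batch_size: torch.Size([2])

# List index: now passed through without the torch.tensor(...) conversion
# that __getitem__ used to perform.
sub_list = td[[1, 2]]  # batch_size: torch.Size([2])

# Full slice, as exercised by the new test below.
sub_slice = td[:]  # batch_size: torch.Size([4])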
46 changes: 45 additions & 1 deletion tensordict/utils.py
@@ -7,6 +7,8 @@
 
 import math
 import time
+
+import warnings
 from functools import wraps
 from numbers import Number
 from typing import Any, List, Sequence, Tuple, TYPE_CHECKING, Union
@@ -95,6 +97,13 @@ def _getitem_batch_size(shape: torch.Size, items: IndexType) -> torch.Size:
     ) or isinstance(items, list):
         if isinstance(items, torch.Tensor) and not items.shape:
             return shape[1:]
+        if _is_lis_of_list_of_bools(items):
+            warnings.warn(
+                "Got a list of list of bools: this indexing behaviour will be deprecated soon.",
+                category=DeprecationWarning,
+            )
+            items = torch.tensor(items)
+            return torch.Size([items.sum(), *shape[items.ndimension() :]])
         if len(items):
             return torch.Size([len(items), *shape[1:]])
         else:
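Worked example of the new branch above (hypothetical shapes, not taken from the commit): for a boolean mask, the resulting batch size has one leading dimension equal to the number of True entries, followed by the dimensions not covered by the mask.

import torch

shape = torch.Size([3, 4, 5])
mask = torch.tensor(
    [[True, False, True, False],
     [False, True, False, False],
     [True, True, False, False]]
)  # shape (3, 4) with 5 True entries

# Mirrors torch.Size([items.sum(), *shape[items.ndimension():]]) above.
new_shape = torch.Size([int(mask.sum()), *shape[mask.ndimension():]])
print(new_shape)  # torch.Size([5, 5])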
@@ -692,7 +701,22 @@ def _dtype(tensor: torch.Tensor) -> torch.dtype:
 
 def _get_item(tensor: torch.Tensor, index: IndexType) -> torch.Tensor:
     if isinstance(tensor, torch.Tensor):
-        return tensor[index]
+        try:
+            return tensor[index]
+        except IndexError as err:
+            # try to map list index to tensor, and assess type. If bool, we
+            # likely have a nested list of booleans which is not supported by pytorch
+            if _is_lis_of_list_of_bools(index):
+                index = torch.tensor(index, device=tensor.device)
+                if index.dtype is torch.bool:
+                    warnings.warn(
+                        "Indexing a tensor with a nested list of boolean values is "
+                        "going to be deprecated as this functionality is not supported "
+                        f"by PyTorch. (follows error: {err})",
+                        category=DeprecationWarning,
+                    )
+                    return tensor[index]
+            raise err
     elif isinstance(tensor, KeyedJaggedTensor):
         return index_keyedjaggedtensor(tensor, index)
     else:
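Sketch of the failure mode this fallback handles (hypothetical values; whether the raw nested list raises depends on the PyTorch version):

import torch

x = torch.arange(6).reshape(2, 3)
mask = [[True, False, True], [False, True, False]]

try:
    print(x[mask])
except IndexError as err:
    # The path handled above: convert the nested list of bools to a
    # boolean tensor and retry, with a DeprecationWarning.
    print("raw list rejected:", err)
    print(x[torch.tensor(mask)])  # tensor([0, 2, 4])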
@@ -790,3 +814,23 @@ def int_generator(seed):
     rng = np.random.default_rng(seed)
     seed = int.from_bytes(rng.bytes(8), "big")
     return seed % max_seed_val
+
+
+def _is_lis_of_list_of_bools(index, first_level=True):
+    # determines if an index is a list of list of bools.
+    # this is aimed at catching a deprecation feature where list of list
+    # of bools are valid indices
+    if first_level:
+        if not isinstance(index, list):
+            return False
+        if not len(index):
+            return False
+        if isinstance(index[0], list):
+            return _is_lis_of_list_of_bools(index[0], False)
+        return False
+    # then we know it is a list of lists
+    if isinstance(index[0], bool):
+        return True
+    if isinstance(index[0], list):
+        return _is_lis_of_list_of_bools(index[0], False)
+    return False
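Illustrative calls for the helper above (assumes it is importable from tensordict.utils as added in this commit; note it only inspects the first element at each nesting level):

import torch
from tensordict.utils import _is_lis_of_list_of_bools

_is_lis_of_list_of_bools([[True, False], [False, True]])  # True
_is_lis_of_list_of_bools([[[True], [False]]])             # True (deeper nesting)
_is_lis_of_list_of_bools([True, False])                   # False (flat list)
_is_lis_of_list_of_bools([[0, 1], [1, 0]])                # False (ints, not bools)
_is_lis_of_list_of_bools(torch.tensor([True, False]))     # False (not a list)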
2 changes: 2 additions & 0 deletions test/test_tensordict.py
@@ -2749,6 +2749,8 @@ def test_batchsize_reset():
 
     # test index
     td[torch.tensor([1, 2])]
+    td[:]
+    td[[1, 2]]
     with pytest.raises(
         IndexError,
         match=re.escape("too many indices for tensor of dimension 1"),
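Sketch of what the two new test lines exercise (the assertions are illustrative and not part of the commit):

import torch
from tensordict import TensorDict

def test_list_indexing_matches_tensor_indexing():
    td = TensorDict({"a": torch.arange(12).reshape(4, 3)}, batch_size=[4])
    # A list index should select the same sub-tensordict as the equivalent tensor index.
    assert (td[[1, 2]]["a"] == td[torch.tensor([1, 2])]["a"]).all()
    # A full slice should preserve the batch size.
    assert td[:].batch_size == td.batch_size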
