[Feature] dense_stack_tds (#506)
matteobettini committed Aug 3, 2023
1 parent a7be2f4 commit 37e66d1
Showing 4 changed files with 146 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/tensordict.rst
@@ -66,3 +66,4 @@ Utils
merge_tensordicts
pad
pad_sequence
dense_stack_tds
2 changes: 2 additions & 0 deletions tensordict/__init__.py
@@ -7,6 +7,7 @@
from tensordict.persistent import PersistentTensorDict
from tensordict.tensorclass import tensorclass
from tensordict.tensordict import (
dense_stack_tds,
is_batchedtensor,
is_memmap,
is_tensor_collection,
@@ -46,6 +47,7 @@
"pad",
"PersistentTensorDict",
"tensorclass",
"dense_stack_tds",
]

# from tensordict._pytree import *
91 changes: 91 additions & 0 deletions tensordict/tensordict.py
@@ -8646,6 +8646,97 @@ def make_tensordict(
return TensorDict.from_dict(kwargs, batch_size=batch_size, device=device)


def dense_stack_tds(
td_list: Sequence[TensorDictBase] | LazyStackedTensorDict,
dim: int | None = None,
) -> TensorDictBase:
"""Densely stack a list of :class:`tensordict.TensorDictBase` objects (or a :class:`tensordict.LazyStackedTensorDict`) given that they have the same structure.
This function is called with a list of :class:`tensordict.TensorDictBase` (either passed directly or obtained from
a :class:`tensordict.LazyStackedTensorDict`).
Instead of calling ``torch.stack(td_list)``, which would return a :class:`tensordict.LazyStackedTensorDict`,
this function expands the first element of the input list to the stacked shape and stacks the whole list onto it.
This works only when all the elements of the input list have the same structure.
The :class:`tensordict.TensorDictBase` returned will have the same type as the elements of the input list.
This function is useful when some of the :class:`tensordict.TensorDictBase` objects that need to be stacked
are :class:`tensordict.LazyStackedTensorDict` or have :class:`tensordict.LazyStackedTensorDict`
among entries (or nested entries).
In those cases, calling ``torch.stack(td_list).to_tensordict()`` is infeasible.
Thus, this function provides an alternative for densely stacking the list provided.
Args:
td_list (list of TensorDictBase or LazyStackedTensorDict): the tensordicts to stack.
dim (int, optional): the dimension along which to stack them.
If ``td_list`` is a LazyStackedTensorDict, the dimension is retrieved automatically.
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict import dense_stack_tds
>>> from tensordict.tensordict import assert_allclose_td
>>> td0 = TensorDict({"a": torch.zeros(3)},[])
>>> td1 = TensorDict({"a": torch.zeros(4), "b": torch.zeros(2)},[])
>>> td_lazy = torch.stack([td0, td1], dim=0)
>>> td_container = TensorDict({"lazy": td_lazy}, [])
>>> td_container_clone = td_container.clone()
>>> td_stack = torch.stack([td_container, td_container_clone], dim=0)
>>> td_stack
LazyStackedTensorDict(
fields={
lazy: LazyStackedTensorDict(
fields={
a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)},
exclusive_fields={
},
batch_size=torch.Size([2, 2]),
device=None,
is_shared=False,
stack_dim=0)},
exclusive_fields={
},
batch_size=torch.Size([2]),
device=None,
is_shared=False,
stack_dim=0)
>>> td_stack = dense_stack_tds(td_stack)  # Automatically use the LazyStackedTensorDict stack_dim
>>> td_stack
TensorDict(
fields={
lazy: LazyStackedTensorDict(
fields={
a: Tensor(shape=torch.Size([2, 2, -1]), device=cpu, dtype=torch.float32, is_shared=False)},
exclusive_fields={
1 ->
b: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 2]),
device=None,
is_shared=False,
stack_dim=1)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
# Note that
# (1) td_stack is now a TensorDict
# (2) this has pushed the stack_dim of "lazy" (0 -> 1)
# (3) this has revealed the exclusive keys.
>>> assert_allclose_td(td_stack, dense_stack_tds([td_container, td_container_clone], dim=0))
# This shows it is the same to pass a list or a LazyStackedTensorDict
"""
if isinstance(td_list, LazyStackedTensorDict):
dim = td_list.stack_dim
td_list = td_list.tensordicts
elif dim is None:
raise ValueError(
"If a list of tensordicts is provided, stack_dim must not be None"
)
shape = list(td_list[0].shape)
shape.insert(dim, len(td_list))

out = td_list[0].unsqueeze(dim).expand(shape).clone()
return torch.stack(td_list, dim=dim, out=out)


def _set_max_batch_size(source: TensorDictBase, batch_dims=None):
"""Updates a tensordict with its maximium batch size."""
tensor_data = list(source.values())
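The mechanism in ``dense_stack_tds`` above is preallocation: the first element is unsqueezed and expanded to the stacked shape, cloned into a dense buffer, and ``torch.stack(..., out=out)`` then writes every element into that buffer, so the output keeps the dense type of the list's elements instead of becoming lazy. A minimal sketch of the same pattern on plain tensors (a standalone illustration, not part of this commit; for plain tensors ``torch.stack`` is dense anyway, so the ``out=`` buffer only mirrors the mechanics):

import torch

def dense_stack(tensors, dim=0):
    # Mirror dense_stack_tds: build the stacked shape explicitly...
    shape = list(tensors[0].shape)
    shape.insert(dim, len(tensors))
    # ...preallocate a dense buffer expanded from the first element...
    out = tensors[0].unsqueeze(dim).expand(shape).clone()
    # ...and let torch.stack write every element into it in place.
    return torch.stack(tensors, dim=dim, out=out)

ts = [torch.zeros(2, 3), torch.ones(2, 3)]
assert torch.equal(dense_stack(ts, dim=1), torch.stack(ts, dim=1))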
52 changes: 52 additions & 0 deletions test/test_tensordict.py
@@ -35,6 +35,7 @@
_CustomOpTensorDict,
_stack as stack_td,
assert_allclose_td,
dense_stack_tds,
is_tensor_collection,
make_tensordict,
pad,
@@ -5906,6 +5907,57 @@ def test_empty():
assert len(list(td_empty.get("b").keys())) == 1


@pytest.mark.parametrize(
"stack_dim",
[0, 1, 2, 3],
)
@pytest.mark.parametrize(
"nested_stack_dim",
[0, 1, 2],
)
def test_dense_stack_tds(stack_dim, nested_stack_dim):
batch_size = (5, 6)
td0 = TensorDict(
{"a": torch.zeros(*batch_size, 3)},
batch_size,
)
td1 = TensorDict(
{"a": torch.zeros(*batch_size, 4), "b": torch.zeros(*batch_size, 2)},
batch_size,
)
td_lazy = torch.stack([td0, td1], dim=nested_stack_dim)
td_container = TensorDict({"lazy": td_lazy}, td_lazy.batch_size)
td_container_clone = td_container.clone()
td_container_clone.apply_(lambda x: x + 1)

assert td_lazy.stack_dim == nested_stack_dim
td_stack = torch.stack([td_container, td_container_clone], dim=stack_dim)
assert td_stack.stack_dim == stack_dim

assert isinstance(td_stack, LazyStackedTensorDict)
dense_td_stack = dense_stack_tds(td_stack)
assert isinstance(dense_td_stack, TensorDict) # check outer layer is non-lazy
assert isinstance(
dense_td_stack["lazy"], LazyStackedTensorDict
) # while inner layer is still lazy
assert "b" not in dense_td_stack["lazy"].tensordicts[0].keys()
assert "b" in dense_td_stack["lazy"].tensordicts[1].keys()

assert assert_allclose_td(
dense_td_stack,
dense_stack_tds([td_container, td_container_clone], dim=stack_dim),
) # This shows it is the same to pass a list or a LazyStackedTensorDict

for i in range(2):
index = (slice(None),) * stack_dim + (i,)
assert (dense_td_stack[index] == i).all()

if stack_dim > nested_stack_dim:
assert dense_td_stack["lazy"].stack_dim == nested_stack_dim
else:
assert dense_td_stack["lazy"].stack_dim == nested_stack_dim + 1


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
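One note on the indexing idiom in ``test_dense_stack_tds`` above: ``(slice(None),) * stack_dim + (i,)`` builds an index tuple with ``i`` at position ``stack_dim`` and full slices before it, i.e. the equivalent of ``t[:, ..., :, i]``; the test uses it to check that element ``i`` of the stack landed at index ``i`` along the stack dimension. A small standalone check of the idiom on a plain tensor (illustrative only, not from this commit):

import torch

t = torch.arange(24).reshape(2, 3, 4)
stack_dim, i = 2, 1
# (slice(None),) * 2 + (1,) indexes the same elements as t[:, :, 1]
index = (slice(None),) * stack_dim + (i,)
assert torch.equal(t[index], t[:, :, 1])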
