Skip to content

Commit

Permalink
[Feature] selected_out_keys arg in TDS constructor
Browse files Browse the repository at this point in the history
ghstack-source-id: 73667464576900c92b50e89e9f6a431da88a956f
Pull Request resolved: #993
  • Loading branch information
vmoens committed Sep 17, 2024
1 parent f696e64 commit 78b8905
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
12 changes: 11 additions & 1 deletion tensordict/nn/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import logging
from copy import deepcopy
from typing import Any, Iterable
from typing import Any, Iterable, List

from tensordict._nestedkey import NestedKey

Expand Down Expand Up @@ -49,12 +49,19 @@ class TensorDictSequential(TensorDictModule):
Args:
modules (iterable of TensorDictModules): ordered sequence of TensorDictModule instances to be run sequentially.
Keyword Args:
partial_tolerant (bool, optional): if True, the input tensordict can miss some of the input keys.
If so, the only module that will be executed are those who can be executed given the keys that
are present.
Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is :obj:`True` AND if the
stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts
looking for those that have the required keys, if any.
selected_out_keys (iterable of NestedKeys, optional): the list of out-keys to select. If not provided, all
``out_keys`` will be written.
.. note:: A :class:`TensorDictSequential` instance may have a long list of output keys, and one may wish to remove
some of them after execution for clarity or memory purposes. If this is the case, the method :meth:`~.select_out_keys`
can be used after instantiation, or `selected_out_keys` may be passed to the constructor.
Examples:
>>> import torch
Expand Down Expand Up @@ -161,6 +168,7 @@ def __init__(
self,
*modules: TensorDictModuleBase,
partial_tolerant: bool = False,
selected_out_keys: List[NestedKey] | None = None,
) -> None:
modules = self._convert_modules(modules)
in_keys, out_keys = self._compute_in_and_out_keys(modules)
Expand All @@ -170,6 +178,8 @@ def __init__(
)

self.partial_tolerant = partial_tolerant
if selected_out_keys:
self.select_out_keys(*selected_out_keys)

@staticmethod
def _convert_modules(modules):
Expand Down
16 changes: 16 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,22 @@ def test_key_exclusion(self):
assert set(seq.in_keys) == set(unravel_key_list(("key1", "key2", "key3")))
assert set(seq.out_keys) == set(unravel_key_list(("foo1", "key1", "key2")))

def test_key_exclusion(self):
module1 = TensorDictModule(
nn.Linear(3, 4), in_keys=["key1", "key2"], out_keys=["foo1"]
)
module2 = TensorDictModule(
nn.Linear(3, 4), in_keys=["key1", "key3"], out_keys=["key1"]
)
module3 = TensorDictModule(
nn.Linear(3, 4), in_keys=["foo1", "key3"], out_keys=["key2"]
)
seq = TensorDictSequential(
module1, module2, module3, selected_out_keys=["key2"]
)
assert set(seq.in_keys) == set(unravel_key_list(("key1", "key2", "key3")))
assert seq.out_keys == ["key2"]

@pytest.mark.parametrize("lazy", [True, False])
def test_stateful(self, lazy):
torch.manual_seed(0)
Expand Down

0 comments on commit 78b8905

Please sign in to comment.