diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 9b8bb838a..2ab80511e 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -10,8 +10,6 @@ on: - v[0-9]+.[0-9]+.[0-9] - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ pull_request: - branches: - - "*" workflow_dispatch: concurrency: @@ -90,13 +88,13 @@ jobs: # 10. Build doc cd ./docs - make docs + sphinx-build ./source _local_build cd .. - cp -r docs/build/html/* "${RUNNER_ARTIFACT_DIR}" + cp -r docs/_local_build/* "${RUNNER_ARTIFACT_DIR}" echo $(ls "${RUNNER_ARTIFACT_DIR}") if [[ ${{ github.event_name == 'pull_request' }} ]]; then - cp -r docs/build/html/* "${RUNNER_DOCS_DIR}" + cp -r docs/_local_build/* "${RUNNER_DOCS_DIR}" fi upload: @@ -118,7 +116,12 @@ jobs: REF_NAME=${{ github.ref_name }} if [[ "${REF_TYPE}" == branch ]]; then - TARGET_FOLDER="${REF_NAME}" + if [[ "${REF_NAME}" == main ]]; then + TARGET_FOLDER="${REF_NAME}" + # Bebug: + # else + # TARGET_FOLDER="release-doc" + fi elif [[ "${REF_TYPE}" == tag ]]; then case "${REF_NAME}" in *-rc*) diff --git a/docs/source/fx.rst b/docs/source/fx.rst index 248695dc9..327c873f7 100644 --- a/docs/source/fx.rst +++ b/docs/source/fx.rst @@ -51,8 +51,8 @@ We'll illustrate with an example from the overview. We create a :obj:`TensorDict We can check that a forward pass with each module results in the same outputs. >>> tensordict = TensorDict({"input": torch.randn(32, 100)}, [32]) - >>> module_out = module(tensordict, tensordict_out=TensorDict({}, [])) - >>> graph_module_out = graph_module(tensordict, tensordict_out=TensorDict({}, [])) + >>> module_out = module(tensordict, tensordict_out=TensorDict()) + >>> graph_module_out = graph_module(tensordict, tensordict_out=TensorDict()) >>> assert ( ... module_out["outputs", "logits"] == graph_module_out["outputs", "logits"] ... ).all() diff --git a/docs/source/index.rst b/docs/source/index.rst index f76312dd2..4dedb7751 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -80,7 +80,7 @@ tensordict.nn :maxdepth: 1 tutorials/tensordict_module - tutorials/tensordict_module_functional + tutorials/export Dataloading ----------- diff --git a/docs/source/saving.rst b/docs/source/saving.rst index 6081b6fa3..408edeca2 100644 --- a/docs/source/saving.rst +++ b/docs/source/saving.rst @@ -233,7 +233,7 @@ Here is a full example: >>> snapshot = torchsnapshot.Snapshot.take(app_state=state, path=path) >>> # later >>> snapshot = torchsnapshot.Snapshot(path=path) - >>> tensordict2 = TensorDict({}, []) + >>> tensordict2 = TensorDict() >>> target_state = { >>> "state": tensordict2 >>> } diff --git a/tensordict/base.py b/tensordict/base.py index 79d5ab4ed..1431ed511 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -5409,6 +5409,10 @@ def keys( ): """Returns a generator of tensordict keys. + .. warning:: TensorDict ``keys()`` method returns a lazy view of the keys. If the ``keys`` + are queried but not iterated over and then the tensordict is modified, iterating over + the keys later will return the new configuration of the keys. + Args: include_nested (bool, optional): if ``True``, nested values will be returned. Defaults to ``False``. @@ -10276,7 +10280,7 @@ def is_tensor_collection(datatype: type | Any) -> bool: Examples: >>> is_tensor_collection(TensorDictBase) # True - >>> is_tensor_collection(TensorDict({}, [])) # True + >>> is_tensor_collection(TensorDict()) # True >>> @tensorclass ... class MyClass: ... 
pass diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 0736524e4..6a4b69e7e 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -32,7 +32,6 @@ from tensordict.utils import ( _unravel_key_to_tuple, _zip_strict, - implement_for, NestedKey, unravel_key_list, ) @@ -319,8 +318,10 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: else: out = func(tensordict, *args, **kwargs) + # This makes dispatch responsible of handling partial outputs (such as selected through select_out_keys) out = tuple(out[key] for key in dest) return out[0] if len(out) == 1 else out + if _self is not None: return func(_self, tensordict, *args, **kwargs) return func(tensordict, *args, **kwargs) @@ -374,98 +375,46 @@ def get_source(self, func, self_func): class _OutKeysSelect: + module: nn.Module | None = None + def __init__(self, out_keys): - self.out_keys = out_keys - self._initialized = False + self.out_keys = list(out_keys) + self._initialized = None + self._is_dispatched = None def _init(self, module): if self._initialized: return self._initialized = True self.module = module - module.out_keys = list(self.out_keys) + if not all(key in module.out_keys for key in self.out_keys): + raise RuntimeError("Some keys are not part of the module out_keys.") + module.out_keys = self.out_keys - @implement_for("torch", None, "2.0") def __call__( # noqa: F811 self, module: TensorDictModuleBase, tensordict_in: TensorDictBase, + kwargs: Dict, tensordict_out: TensorDictBase, ): - if not isinstance(tensordict_out, TensorDictBase): + if not self._initialized: raise RuntimeError( - "You are likely using tensordict.nn.dispatch with keyword arguments with an older (< 2.0) version of pytorch. " - "This is currently not supported. Please use unnamed arguments or upgrade pytorch." + "_OutKeysSelect must be initialized before being called." ) # detect dispatch calls in_keys = module.in_keys - is_dispatched = self._detect_dispatch(tensordict_in, in_keys) - out_keys = self.out_keys - # if dispatch filtered the out keys as they should we're happy - if is_dispatched: - if (not isinstance(tensordict_out, tuple) and len(out_keys) == 1) or ( - len(out_keys) == len(tensordict_out) - ): - return tensordict_out - self._init(module) - if is_dispatched: - # it might be the case that dispatch was not aware of what the out-keys were. - if isinstance(tensordict_out, tuple): - out = tuple( - item - for i, item in enumerate(tensordict_out) - if module._out_keys[i] in module.out_keys - ) - if len(out) == 1: - return out[0] - return out - elif module._out_keys[0] in module.out_keys and len(module._out_keys) == 1: - return tensordict_out - elif ( - module._out_keys[0] not in module.out_keys - and len(module._out_keys) == 1 - ): - return () - else: - raise RuntimeError( - f"Selecting out-keys failed. Original out_keys: {module._out_keys}, selected: {module.out_keys}." 
- ) - if tensordict_out is tensordict_in: - return tensordict_out.select( - *in_keys, - *out_keys, - inplace=True, - ) - else: - return tensordict_out.select( - *in_keys, - *out_keys, - inplace=True, - strict=False, - ) - - @implement_for("torch", "2.0", None) - def __call__( # noqa: F811 - self, - module: TensorDictModuleBase, - tensordict_in: TensorDictBase, - kwargs: Dict, - tensordict_out: TensorDictBase, - ): - # detect dispatch calls - in_keys = module.in_keys - # TODO: v0.7: remove the None - if not tensordict_in and kwargs.get("tensordict", None) is not None: + if not tensordict_in and kwargs.get("tensordict") is not None: tensordict_in = kwargs.pop("tensordict") is_dispatched = self._detect_dispatch(tensordict_in, kwargs, in_keys) out_keys = self.out_keys # if dispatch filtered the out keys as they should we're happy if is_dispatched: if (not isinstance(tensordict_out, tuple) and len(out_keys) == 1) or ( - len(out_keys) == len(tensordict_out) + isinstance(tensordict_out, tuple) + and len(out_keys) == len(tensordict_out) ): return tensordict_out - self._init(module) if is_dispatched: # it might be the case that dispatch was not aware of what the out-keys were. if isinstance(tensordict_out, tuple): @@ -488,36 +437,10 @@ def __call__( # noqa: F811 raise RuntimeError( f"Selecting out-keys failed. Original out_keys: {module._out_keys}, selected: {module.out_keys}." ) - if tensordict_out is tensordict_in: - return tensordict_out.select( - *in_keys, - *out_keys, - inplace=True, - ) - else: - return tensordict_out.select( - *in_keys, - *out_keys, - inplace=True, - strict=False, - ) - - @implement_for("torch", None, "2.0") - def _detect_dispatch(self, tensordict_in, in_keys): # noqa: F811 - if isinstance(tensordict_in, TensorDictBase) and all( - key in tensordict_in.keys() for key in in_keys - ): - return False - elif isinstance(tensordict_in, tuple): - if len(tensordict_in): - if isinstance(tensordict_in[0], TensorDictBase): - return self._detect_dispatch(tensordict_in[0], in_keys) - return True - return not len(in_keys) - # not a TDBase: must be True - return True + return tensordict_out.select( + *in_keys, *out_keys, inplace=True, strict=tensordict_out is tensordict_in + ) - @implement_for("torch", "2.0", None) def _detect_dispatch(self, tensordict_in, kwargs, in_keys): # noqa: F811 if isinstance(tensordict_in, TensorDictBase) and all( key in tensordict_in.keys(include_nested=True) for key in in_keys @@ -530,8 +453,7 @@ def _detect_dispatch(self, tensordict_in, kwargs, in_keys): # noqa: F811 elif ( not len(tensordict_in) and len(kwargs) - # TODO: v0.7: remove the None - and isinstance(kwargs.get("tensordict", None), TensorDictBase) + and isinstance(kwargs.get("tensordict"), TensorDictBase) ): return self._detect_dispatch(kwargs["tensordict"], in_keys) return True @@ -541,6 +463,8 @@ def _detect_dispatch(self, tensordict_in, kwargs, in_keys): # noqa: F811 def remove(self): # reset ground truth + if self.module is None: + return if self.module._out_keys is not None: self.module.out_keys = self.module._out_keys @@ -614,122 +538,7 @@ def out_keys(self, value: List[Union[str, Tuple[str]]]): self._out_keys = value self._out_keys_apparent = value - @implement_for("torch", None, "2.0") - def select_out_keys(self, *out_keys): # noqa: F811 - """Selects the keys that will be found in the output tensordict. - - This is useful whenever one wants to get rid of intermediate keys in a - complicated graph, or when the presence of these keys may trigger unexpected - behaviours. 
- - The original ``out_keys`` can still be accessed via ``module.out_keys_source``. - - Args: - *out_keys (a sequence of strings or tuples of strings): the - out_keys that should be found in the output tensordict. - - Returns: the same module, modified in-place with updated ``out_keys``. - - The simplest usage is with :class:`~.TensorDictModule`: - - Examples: - >>> from tensordict import TensorDict - >>> from tensordict.nn import TensorDictModule, TensorDictSequential - >>> import torch - >>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"]) - >>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, []) - >>> mod(td) - TensorDict( - fields={ - a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([]), - device=None, - is_shared=False) - >>> mod.select_out_keys("d") - >>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, []) - >>> mod(td) - TensorDict( - fields={ - a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([]), - device=None, - is_shared=False) - - This feature will also work with dispatched arguments: - Examples: - >>> mod(torch.zeros(()), torch.ones(())) - tensor(2.) - - This change will occur in-place (ie the same module will be returned - with an updated list of out_keys). It can be reverted using the - :meth:`TensorDictModuleBase.reset_out_keys` method. - - Examples: - >>> mod.reset_out_keys() - >>> mod(TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])) - TensorDict( - fields={ - a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([]), - device=None, - is_shared=False) - - This will work with other classes too, such as Sequential: - Examples: - >>> from tensordict.nn import TensorDictSequential - >>> seq = TensorDictSequential( - ... TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"]), - ... TensorDictModule(lambda x: x+1, in_keys=["y"], out_keys=["z"]), - ... 
) - >>> td = TensorDict({"x": torch.zeros(())}, []) - >>> seq(td) - TensorDict( - fields={ - x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - y: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([]), - device=None, - is_shared=False) - >>> seq.select_out_keys("z") - >>> td = TensorDict({"x": torch.zeros(())}, []) - >>> seq(td) - TensorDict( - fields={ - x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([]), - device=None, - is_shared=False) - - """ - out_keys = unravel_key_list(list(out_keys)) - if len(out_keys) == 1: - if out_keys[0] not in self.out_keys: - err_msg = f"Can't select non existent key: {out_keys[0]}. " - if ( - out_keys[0] - and isinstance(out_keys[0], (tuple, list)) - and out_keys[0][0] in self.out_keys - ): - err_msg += f"Are you passing the keys in a list? Try unpacking as: `{', '.join(out_keys[0])}`" - raise ValueError(err_msg) - self.register_forward_hook(_OutKeysSelect(out_keys)) - for hook in self._forward_hooks.values(): - if isinstance(hook, _OutKeysSelect): - hook._init(self) - return self - - @implement_for("torch", "2.0", None) - def select_out_keys(self, *out_keys): # noqa: F811 + def select_out_keys(self, *out_keys) -> TensorDictModuleBase: # noqa: F811 """Selects the keys that will be found in the output tensordict. This is useful whenever one wants to get rid of intermediate keys in a @@ -877,6 +686,7 @@ def reset_out_keys(self): """ for i, hook in list(self._forward_hooks.items()): if isinstance(hook, _OutKeysSelect): + hook.remove() del self._forward_hooks[i] return self diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index 8cefc6755..4e3ac7b93 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -163,6 +163,7 @@ class TensorDictSequential(TensorDictModule): """ module: nn.ModuleList + _select_before_return = False def __init__( self, @@ -172,6 +173,7 @@ def __init__( ) -> None: modules = self._convert_modules(modules) in_keys, out_keys = self._compute_in_and_out_keys(modules) + self._complete_out_keys = list(out_keys) super().__init__( module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys @@ -179,7 +181,17 @@ def __init__( self.partial_tolerant = partial_tolerant if selected_out_keys: - self.select_out_keys(*selected_out_keys) + self._select_before_return = True + selected_out_keys = unravel_key_list(selected_out_keys) + if not all(key in self.out_keys for key in selected_out_keys): + raise ValueError("All keys in selected_out_keys must be in out_keys.") + self.out_keys = selected_out_keys + else: + self._select_before_return = False + + def reset_out_keys(self): + self.out_keys = list(self._complete_out_keys) + return self @staticmethod def _convert_modules(modules): @@ -230,6 +242,14 @@ def _find_functional_module(module: TensorDictModuleBase) -> nn.Module: ) return fmodule + def select_out_keys(self, *selected_out_keys) -> TensorDictSequential: + self._select_before_return = True + selected_out_keys = unravel_key_list(selected_out_keys) + if not all(key in self.out_keys for key in selected_out_keys): + raise ValueError("All keys in selected_out_keys must be in out_keys.") + self.out_keys = selected_out_keys + return self + def select_subsequence( self, in_keys: Iterable[NestedKey] | 
None = None, @@ -449,17 +469,27 @@ def forward( tensordict_out: TensorDictBase | None = None, **kwargs: Any, ) -> TensorDictBase: + if tensordict_out is None and self._select_before_return: + tensordict_exec = tensordict.copy() + else: + tensordict_exec = tensordict if not len(kwargs): + if tensordict_out is not None: + tensordict_exec = tensordict_exec.copy() for module in self.module: - tensordict = self._run_module(module, tensordict, **kwargs) + tensordict_exec = self._run_module(module, tensordict_exec, **kwargs) else: raise RuntimeError( f"TensorDictSequential does not support keyword arguments other than 'tensordict_out' or in_keys: {self.in_keys}. Got {kwargs.keys()} instead." ) if tensordict_out is not None: - tensordict_out.update(tensordict, inplace=True) - return tensordict_out - return tensordict + result = tensordict_out + result.update(tensordict_exec, keys_to_update=self.out_keys) + else: + result = tensordict_exec + if self._select_before_return: + return tensordict.update(result.select(*self.out_keys)) + return result def __len__(self) -> int: return len(self.module) diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py index 152223a4e..dee383cd4 100644 --- a/tensordict/nn/utils.py +++ b/tensordict/nn/utils.py @@ -195,7 +195,7 @@ class set_skip_existing(_DecoratorContextManager): batch_size=torch.Size([]), device=None, is_shared=False) - >>> module(TensorDict({}, [])) # prints hello + >>> module(TensorDict()) # prints hello hello TensorDict( fields={ diff --git a/tensordict/utils.py b/tensordict/utils.py index ecf584d02..772214458 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1204,7 +1204,7 @@ def _as_context_manager(attr=None): Examples: >>> from tensordict import TensorDict - >>> data = TensorDict({}, []) + >>> data = TensorDict() >>> with data.lock_(): # lock_ is decorated ... 
assert data.is_locked >>> assert not data.is_locked diff --git a/test/test_fx.py b/test/test_fx.py index 186d63113..50cd8e762 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -33,8 +33,8 @@ def forward(self, x): tensordict = TensorDict({"input": torch.randn(32, 100)}, [32]) - module_out = TensorDict({}, []) - graph_module_out = TensorDict({}, []) + module_out = TensorDict() + graph_module_out = TensorDict() module(tensordict, tensordict_out=module_out) graph_module(tensordict, tensordict_out=graph_module_out) @@ -84,8 +84,8 @@ def forward(self, x, mask): batch_size=[32], ) - module_out = TensorDict({}, []) - graph_module_out = TensorDict({}, []) + module_out = TensorDict() + graph_module_out = TensorDict() module(tensordict, tensordict_out=module_out) graph_module(tensordict, tensordict_out=graph_module_out) @@ -131,8 +131,8 @@ def forward(self, x): tensordict = TensorDict({"input": torch.rand(32, 100)}, [32]) - module_out = TensorDict({}, []) - graph_module_out = TensorDict({}, []) + module_out = TensorDict() + graph_module_out = TensorDict() tdmodule(tensordict, tensordict_out=module_out) graph_module(tensordict, tensordict_out=graph_module_out) diff --git a/test/test_nn.py b/test/test_nn.py index 1216aa6ea..76ac44ea1 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1069,6 +1069,25 @@ def fn(a, b=None, *, c=None): module = TensorDictSequential(module0, module1) assert (module(td)["a"] == 2).all() + def test_tdseq_tdoutput(self): + mod = TensorDictSequential( + TensorDictModule(lambda x: x + 2, in_keys=["a"], out_keys=["c"]), + TensorDictModule(lambda x: (x + 2, x), in_keys=["b"], out_keys=["d", "e"]), + ) + inp = TensorDict({"a": 0, "b": 1}) + inp_clone = inp.clone() + out = TensorDict() + out2 = mod(inp, tensordict_out=out) + assert out is out2 + assert set(out.keys()) == set(mod.out_keys) + assert set(inp.keys()) == set(inp_clone.keys()) + mod.select_out_keys("d") + out = TensorDict() + out2 = mod(inp, tensordict_out=out) + assert out is out2 + assert set(out.keys()) == set(mod.out_keys) == {"d"} + assert set(inp.keys()) == set(inp_clone.keys()) + def test_key_exclusion(self): module1 = TensorDictModule( nn.Linear(3, 4), in_keys=["key1", "key2"], out_keys=["foo1"] @@ -1099,6 +1118,30 @@ def test_key_exclusion_constructor(self): assert set(seq.in_keys) == set(unravel_key_list(("key1", "key2", "key3"))) assert seq.out_keys == ["key2"] + def test_key_exclusion_constructor_exec(self): + module1 = TensorDictModule( + lambda x, y: x + y, in_keys=["key1", "key2"], out_keys=["foo1"] + ) + module2 = TensorDictModule( + lambda x, y: x + y, in_keys=["key1", "key3"], out_keys=["key1"] + ) + module3 = TensorDictModule( + lambda x, y: x + y, in_keys=["foo1", "key3"], out_keys=["key2"] + ) + seq = TensorDictSequential( + module1, module2, module3, selected_out_keys=["key2"] + ) + assert set(seq.in_keys) == set(unravel_key_list(("key1", "key2", "key3"))) + assert seq.out_keys == ["key2"] + td = TensorDict(key1=0, key2=0, key3=1) + out = seq(td) + assert out is td + assert "key1" in out + assert "key2" in out + assert "key3" in out + assert "foo1" not in out + assert out["key2"] == 1 + @pytest.mark.parametrize("lazy", [True, False]) def test_stateful(self, lazy): torch.manual_seed(0) @@ -2437,7 +2480,7 @@ def forward(self, tensordict): module = MyModule() td = module(TensorDict({"out": torch.zeros(())}, [])) assert (td["out"] == 0).all() - td = module(TensorDict({}, [])) # prints hello + td = module(TensorDict()) # prints hello assert (td["out"] == 1).all() def test_tdmodule(self): @@ 
-2536,8 +2579,11 @@ def test_tdmodule_dispatch(self, out_d_key, unpack):
             res = [res]
         for i, v in enumerate(list(out_d_key)):
             assert (res[i] == exp_res[v]).all()
+
         mod2 = mod.reset_out_keys()
         assert mod2 is mod
+        assert mod.out_keys == ["c", "d", "e"]
+
         res = mod(torch.zeros(()), torch.ones(()))
         assert len(res) == 3
         for i, v in enumerate(["c", "d", "e"]):
@@ -2637,7 +2683,7 @@ def test_tdseq(self, out_d_key, unpack):
         else:
             with pytest.raises(
                 (RuntimeError, ValueError),
-                match=r"key should be a |Can't select non existent",
+                match=r"key should be a |Can't select non existent|All keys in selected_out_keys must be in out_keys",
             ):
                 mod2 = mod.select_out_keys(out_d_key)
@@ -2670,7 +2716,7 @@ def test_tdseq_dispatch(self, out_d_key, unpack):
         else:
             with pytest.raises(
                 (RuntimeError, ValueError),
-                match=r"key should be a |Can't select non existent",
+                match=r"key should be a |Can't select non existent|All keys in selected_out_keys must be in out_keys",
             ):
                 mod2 = mod.select_out_keys(out_d_key)
diff --git a/test/test_torchrec.py b/test/test_torchrec.py
index 54096b0a8..292406a86 100644
--- a/test/test_torchrec.py
+++ b/test/test_torchrec.py
@@ -51,7 +51,7 @@ def test_kjt_indexing(self, index):

     def test_td_build(self):
         jag_tensor = _get_kjt()
-        _ = TensorDict({}, [])
+        _ = TensorDict()
         _ = TensorDict({"b": jag_tensor}, [])
         _ = TensorDict({"b": jag_tensor}, [3])
diff --git a/tutorials/sphinx_tuto/export.py b/tutorials/sphinx_tuto/export.py
new file mode 100644
index 000000000..1d1e5f30b
--- /dev/null
+++ b/tutorials/sphinx_tuto/export.py
@@ -0,0 +1,228 @@
+# -*- coding: utf-8 -*-
+
+"""
+Exporting tensordict modules
+============================
+
+**Author**: `Vincent Moens `_
+
+Prerequisites
+~~~~~~~~~~~~~
+
+Reading the :ref:`TensorDictModule ` tutorial is recommended to fully benefit from this tutorial.
+
+Once a module has been written using ``tensordict.nn``, it is often useful to isolate the computational graph and export
+that graph. The goal of this may be to execute the model on hardware (e.g., robots, drones, edge devices) or eliminate
+the dependency on tensordict altogether.
+
+PyTorch provides multiple methods for exporting modules, including ``onnx`` and ``torch.export``, both of which are
+compatible with ``tensordict``.
+
+In this short tutorial, we will see how one can use ``torch.export`` to isolate the computational graph of a model.
+``torch.onnx`` support follows the same logic.
+
+Key learnings
+~~~~~~~~~~~~~
+
+- Executing a ``tensordict.nn`` module without :class:`~tensordict.TensorDict` inputs;
+- Selecting the output(s) of a model;
+- Handling stochastic models;
+- Exporting such a model using `torch.export`;
+- Saving the model to a file;
+- Isolating the PyTorch model;
+
+
+"""
+import time
+
+import torch
+from tensordict.nn import (
+    InteractionType,
+    NormalParamExtractor,
+    ProbabilisticTensorDictModule as Prob,
+    set_interaction_type,
+    TensorDictModule as Mod,
+    TensorDictSequential as Seq,
+)
+from torch import distributions as dists, nn
+
+##################################################
+# Designing the model
+# -------------------
+#
+# In many applications, it is useful to work with stochastic models, i.e., models that output a variable that is not
+# deterministically defined but that is sampled according to a parametric distribution. For instance, generative AI
+# models will often generate different outputs when the same input is provided, because they sample the output based
+# on a distribution whose parameters are defined by the input.
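+#
+# As a minimal, plain-PyTorch sketch of that idea (purely illustrative, not part of the model we build below):
+# the network produces the distribution parameters, and the output is a sample rather than a deterministic value.
+#
+# >>> params = torch.nn.Linear(3, 2)(torch.randn(1, 3))
+# >>> loc, scale = params.chunk(2, -1)
+# >>> sample = torch.distributions.Normal(loc, scale.exp()).sample()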
+#
+# The ``tensordict`` library deals with this through the :class:`~tensordict.nn.ProbabilisticTensorDictModule` class.
+# This primitive is built using a distribution class (:class:`~torch.distributions.Normal` in our case) and an indicator
+# of the input keys that will be used at execution time to build that distribution.
+#
+# The network we are building is therefore going to be the combination of three main components:
+#
+# - A network mapping the input to a latent parameter;
+# - A :class:`tensordict.nn.NormalParamExtractor` module splitting the input into a location `"loc"` and a scale
+#   `"scale"` parameter, to be passed to the ``Normal`` distribution;
+# - A distribution constructor module.
+#
+model = Seq(
+    # 1. A small network for embedding
+    Mod(nn.Linear(3, 4), in_keys=["x"], out_keys=["hidden"]),
+    Mod(nn.ReLU(), in_keys=["hidden"], out_keys=["hidden"]),
+    Mod(nn.Linear(4, 4), in_keys=["hidden"], out_keys=["latent"]),
+    # 2. Extracting params
+    Mod(NormalParamExtractor(), in_keys=["latent"], out_keys=["loc", "scale"]),
+    # 3. Probabilistic module
+    Prob(
+        in_keys=["loc", "scale"],
+        out_keys=["sample"],
+        distribution_class=dists.Normal,
+    ),
+)
+
+##################################################
+# Let us run this model and see what the output looks like:
+#
+
+x = torch.randn(1, 3)
+print(model(x=x))
+
+##################################################
+# As expected, running the model with a tensor input returns as many tensors as the module's output keys! For large
+# models, this can be quite annoying and wasteful. Later, we will see how we can limit the number of outputs of the
+# model to deal with this issue.
+#
+# Using ``torch.export`` with a ``TensorDictModule``
+# ---------------------------------------------------
+#
+# Now that we have successfully built our model, we would like to extract its computational graph into a single object
+# that is independent of ``tensordict``. ``torch.export`` is a PyTorch module dedicated to isolating the graph of a
+# module and representing it in a standardized way. Its main entry point is :func:`~torch.export.export`, which returns
+# an ``ExportedProgram`` object. In turn, this object has several attributes of interest that we will explore below:
+# a ``graph_module``, which represents the FX graph captured by ``export``, a ``graph_signature`` with the inputs,
+# outputs etc. of the graph, and finally a ``module()`` that returns a callable that can be used in place of the
+# original module.
+#
+# Although our module accepts both args and kwargs, we will focus on its usage with kwargs as this is clearer.
+
+from torch.export import export
+
+model_export = export(model, args=(), kwargs={"x": x})
+
+##################################################
+# Let us look at the module:
+#
+print("module:", model_export.module())
+
+##################################################
+# This module can be run exactly like our original module (with a lower overhead):
+#
+
+t0 = time.time()
+model(x=x)
+print(f"Time for TDModule: {(time.time()-t0)*1e6: 4.2f} micro-seconds")
+exported = model_export.module()
+
+# Exported version
+t0 = time.time()
+exported(x=x)
+print(f"Time for exported module: {(time.time()-t0)*1e6: 4.2f} micro-seconds")
+
+##################################################
+# and the FX graph:
+print("fx graph:", model_export.graph_module.print_readable())
+
+##################################################
+# Note that the callable returned by `module()` is a pure Python callable that can in turn be compiled using
+# :func:`~torch.compile`.
+#
+# Saving the exported module
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# ``torch.export`` has its own serialization protocol, :func:`~torch.export.save` and :func:`~torch.export.load`.
+# Conventionally, the `".pt2"` extension is to be used:
+#
+# >>> torch.export.save(model_export, "model.pt2")
+#
+# Selecting the outputs
+# ---------------------
+#
+# Recall that the ``tensordict.nn`` approach is to keep every intermediate value in the output, unless the user
+# specifically asks for only a specific value. During training, this can be very useful: one can easily log intermediate
+# values of the graph, or use them for other purposes (e.g., reconstruct a distribution based on its saved parameters,
+# rather than saving the :class:`~torch.distributions.Distribution` object itself). One could also argue that, during
+# training, the impact on memory of registering intermediate values is negligible since they are part of the
+# computational graph used by ``torch.autograd`` to compute the parameter gradients.
+#
+# During inference, though, we are most likely only interested in the final sample of the model.
+# Because we want to extract the model for usages that are independent of the ``tensordict`` library, it makes sense to
+# isolate the only output we desire.
+# To do this, we have several options:
+#
+# 1. Build the :class:`~tensordict.nn.TensorDictSequential` with the ``selected_out_keys`` keyword argument, which will
+#    induce the selection of the desired entries during calls to the module;
+# 2. Use the :meth:`~tensordict.nn.TensorDictModule.select_out_keys` method, which will modify the ``out_keys``
+#    attribute in-place (this can be reverted through :meth:`~tensordict.nn.TensorDictModule.reset_out_keys`).
+# 3. Wrap the existing instance in a :class:`~tensordict.nn.TensorDictSequential` that will filter out the unwanted keys:
+#
+#    >>> module_filtered = Seq(module, selected_out_keys=["sample"])
+#
+# Let us test the model after selecting its output keys.
+# When an `x` input is provided, we expect our model to output a single tensor corresponding to a sample of the
+# distribution:
+model.select_out_keys("sample")
+print(model(x=x))
+
+##################################################
+# We see that the output is now a single tensor, corresponding to the sample of the distribution.
+# We can create a new exported graph from this. Its computational graph should be simplified:
+
+model_export = export(model, args=(), kwargs={"x": x})
+print("module:", model_export.module())
+
+##################################################
+# Controlling the Sampling Strategy
+# ---------------------------------
+#
+# We have not yet discussed how the :class:`~tensordict.nn.ProbabilisticTensorDictModule` samples from the distribution.
+# By sampling, we mean obtaining a value within the space defined by the distribution according to a specific strategy.
+# For instance, one may desire to get stochastic samples during training but deterministic samples (e.g., the mean or
+# the mode) at inference time. To address this, ``tensordict`` utilizes the :class:`~tensordict.nn.set_interaction_type`
+# decorator and context manager, which accepts ``InteractionType`` Enum inputs:
+#
+# >>> with set_interaction_type(InteractionType.MEAN):
+# ...     output = module(input)  # takes the mean of the distribution, if ProbabilisticTensorDictModule is invoked
+#
+# The default ``InteractionType`` is ``InteractionType.DETERMINISTIC``, which, if not implemented directly, is either
+# the mean of distributions with a real domain, or the mode of distributions with a discrete domain. This default value
+# can be changed using the ``default_interaction_type`` keyword argument of ``ProbabilisticTensorDictModule``.
+#
+# Let us recap: to control the sampling strategy of our network, we can either define a default sampling strategy in the
+# constructor, or override it at runtime through the ``set_interaction_type`` context manager.
+#
+# As we can see from the following example, ``torch.export`` responds correctly to the usage of the decorator: if we ask
+# for a random sample, the output is different from the one we get if we ask for the mean:
+#
+
+with set_interaction_type(InteractionType.RANDOM):
+    model_export = export(model, args=(), kwargs={"x": x})
+    print(model_export.module())
+
+with set_interaction_type(InteractionType.MEAN):
+    model_export = export(model, args=(), kwargs={"x": x})
+    print(model_export.module())
+
+##################################################
+# This is all you need to know to use ``torch.export``. Please refer to the
+# `official documentation `_ for more info.
+#
+# Next steps and further reading
+# ------------------------------
+#
+# - Check the ``torch.export`` tutorial, available `here `__;
+# - ONNX support: check the `ONNX tutorials `_
+#   to learn more about this feature. Exporting to ONNX is very similar to the `torch.export` approach explained here.
+# - For deployment of PyTorch code on servers without a Python environment, check the
+#   `AOTInductor `_ documentation.
+#
diff --git a/tutorials/sphinx_tuto/tensordict_keys.py b/tutorials/sphinx_tuto/tensordict_keys.py
index 913b0eb44..b7ae63a56 100644
--- a/tutorials/sphinx_tuto/tensordict_keys.py
+++ b/tutorials/sphinx_tuto/tensordict_keys.py
@@ -21,7 +21,7 @@
 import torch
 from tensordict.tensordict import TensorDict

-tensordict = TensorDict({}, [])
+tensordict = TensorDict()

 # set a key
 a = torch.rand(10)
@@ -39,7 +39,7 @@
 #
 # We can also use the methods ``.get()`` and ``.set`` to accomplish the same thing.

-tensordict = TensorDict({}, [])
+tensordict = TensorDict()

 # set a key
 a = torch.rand(10)
diff --git a/tutorials/sphinx_tuto/tensordict_module.py b/tutorials/sphinx_tuto/tensordict_module.py
index 96274ea0a..b31b6c7bd 100644
--- a/tutorials/sphinx_tuto/tensordict_module.py
+++ b/tutorials/sphinx_tuto/tensordict_module.py
@@ -1,20 +1,28 @@
 """
 TensorDictModule
 ================
+
+.. _tensordictmodule:
+
+**Author**: `Nicolas Dufour `_, `Vincent Moens `_
+
 In this tutorial you will learn how to use :class:`~.TensorDictModule` and
 :class:`~.TensorDictSequential` to create generic and reusable modules that can
 accept :class:`~.TensorDict` as input.
+
 """

 ###############################################################################
-# For a convenient usage of the :class:`~.TensorDict` class with ``nn.Module``,
-# :mod:`tensordict` provides an interface between the two named ``TensorDictModule``.
-# The ``TensorDictModule`` class is an ``nn.Module`` that takes a
-# :class:`~.TensorDict` as input when called.
+#
+# For a convenient usage of the :class:`~.TensorDict` class with :class:`~torch.nn.Module`,
+# :mod:`tensordict` provides an interface between the two, named :class:`~tensordict.nn.TensorDictModule`.
+#
+# The :class:`~tensordict.nn.TensorDictModule` class is an :class:`~torch.nn.Module` that takes a
+# :class:`~tensordict.TensorDict` as input when called. It will read a sequence of input keys, pass them to the wrapped
+# module or function as input, and write the outputs in the same tensordict after completing the execution.
+#
 # It is up to the user to define the keys to be read as input and output.
 #
-# TensorDictModule by examples
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

 # sphinx_gallery_start_ignore
 import warnings
@@ -28,159 +36,253 @@
 from tensordict.nn import TensorDictModule, TensorDictSequential

 ###############################################################################
-# Example 1: Simple usage
-# --------------------------------------
-# We have a :class:`~.TensorDict` with 2 entries ``"a"`` and ``"b"`` but only the
-# value associated with ``"a"`` has to be read by the network.
+#
+# Simple example: coding a residual layer
+# ----------------------------------------
+#
+# The simplest usage of :class:`~tensordict.nn.TensorDictModule` is exemplified below.
+# While it may at first look like using this class introduces an unwanted level of complexity, we will see
+# later on that this API enables users to programmatically concatenate modules together, cache values
+# in between modules, or build modules on the fly.
+# One of the simplest examples of this is a residual module in an architecture like ResNet, where the input of the
+# module is cached and added to the output of a tiny multi-layered perceptron (MLP).
+#
+# To start, let's first consider how you would chunk an MLP and code it using :mod:`tensordict.nn`.
+# The first layer of the stack would presumably be a :class:`~torch.nn.Linear` layer, taking an entry as input
+# (let us name it `x`) and outputting another entry (which we will name `y`).
+#
+# To feed our module, we have a :class:`~tensordict.TensorDict` instance with a single entry,
+# ``"x"``:

 tensordict = TensorDict(
-    {"a": torch.randn(5, 3), "b": torch.zeros(5, 4, 3)},
+    x=torch.randn(5, 3),
     batch_size=[5],
 )
-linear = TensorDictModule(nn.Linear(3, 10), in_keys=["a"], out_keys=["a_out"])
-linear(tensordict)
-assert (tensordict.get("b") == 0).all()
-print(tensordict)

 ###############################################################################
-# Example 2: Multiple inputs
-# --------------------------------------
-# Suppose we have a slightly more complex network that takes 2 entries and
-# averages them into a single output tensor. To make a ``TensorDictModule``
-# instance read multiple input values, one must register them in the
-# ``in_keys`` keyword argument of the constructor.
+# Now, we build our simple module using :class:`tensordict.nn.TensorDictModule`. By default, this class writes in the +# input tensordict in-place (meaning that entries are written in the same tensordict as the input, not that entries +# are overwritten in-place!), such that we don't need to explicitly indicate what the output is: +# +linear0 = TensorDictModule(nn.Linear(3, 128), in_keys=["x"], out_keys=["linear0"]) +linear0(tensordict) +assert "linear0" in tensordict -class MergeLinear(nn.Module): - def __init__(self, in_1, in_2, out): - super().__init__() - self.linear_1 = nn.Linear(in_1, out) - self.linear_2 = nn.Linear(in_2, out) - - def forward(self, x_1, x_2): - return (self.linear_1(x_1) + self.linear_2(x_2)) / 2 +############################################################################### +# +# If the module outputs multiple tensors (or tensordicts!) their entries must be passed to +# :class:`~tensordict.nn.TensorDictModule` in the right order. +# +# Support for Callables +# ~~~~~~~~~~~~~~~~~~~~~ +# +# When designing a model, it often happens that you want to incorporate an arbitrary non-parametric function into +# the network. For instance, you may wish to permute the dimensions of an image when it is passed to a convolutional network +# or a vision transformer, or divide the values by 255. +# There are several ways to do this: you could use a `forward_hook`, for example, or design a new +# :class:`~torch.nn.Module` that performs this operation. +# +# :class:`~tensordict.nn.TensorDictModule` works with any callable, not just modules, which makes it easy to +# incorporate arbitrary functions into a module. For instance, let's see how we can integrate the ``relu`` activation +# function without using the :class:`~torch.nn.ReLU` module: +relu0 = TensorDictModule(torch.relu, in_keys=["linear0"], out_keys=["relu0"]) ############################################################################### +# +# Stacking modules +# ~~~~~~~~~~~~~~~~ +# +# Our MLP isn't made of a single layer, so we now need to add another layer to it. +# This layer will be an activation function, for instance :class:`~torch.nn.ReLU`. +# We can stack this module and the previous one using :class:`~tensordict.nn.TensorDictSequential`. +# +# .. note:: Here comes the true power of ``tensordict.nn``: unlike :class:`~torch.nn.Sequential`, +# :class:`~tensordict.nn.TensorDictSequential` will keep in memory all the previous inputs and outputs +# (with the possibility to filter them out afterwards), making it easy to have complex network structures +# built on-the-fly and programmatically. +# -tensordict = TensorDict( - { - "a": torch.randn(5, 3), - "b": torch.randn(5, 4), - }, - batch_size=[5], -) - -mergelinear = TensorDictModule( - MergeLinear(3, 4, 10), in_keys=["a", "b"], out_keys=["output"] -) +block0 = TensorDictSequential(linear0, relu0) -mergelinear(tensordict) +block0(tensordict) +assert "linear0" in tensordict +assert "relu0" in tensordict ############################################################################### -# Example 3: Multiple outputs -# -------------------------------------- -# Similarly, ``TensorDictModule`` not only supports multiple inputs but also -# multiple outputs. To make a ``TensorDictModule`` instance write to multiple -# output values, one must register them in the ``out_keys`` keyword argument -# of the constructor. 
+# We can repeat this logic to get a full MLP: +# +linear1 = TensorDictModule(nn.Linear(128, 128), in_keys=["relu0"], out_keys=["linear1"]) +relu1 = TensorDictModule(nn.ReLU(), in_keys=["linear1"], out_keys=["relu1"]) +linear2 = TensorDictModule(nn.Linear(128, 3), in_keys=["relu1"], out_keys=["linear2"]) +block1 = TensorDictSequential(linear1, relu1, linear2) -class MultiHeadLinear(nn.Module): - def __init__(self, in_1, out_1, out_2): - super().__init__() - self.linear_1 = nn.Linear(in_1, out_1) - self.linear_2 = nn.Linear(in_1, out_2) +############################################################################### +# Multiple input keys +# ~~~~~~~~~~~~~~~~~~~ +# +# The last step of the residual network is to add the input to the output of the last linear layer. +# No need to write a special :class:`~torch.nn.Module` subclass for this! :class:`~tensordict.nn.TensorDictModule` +# can be used to wrap simple functions too: - def forward(self, x): - return self.linear_1(x), self.linear_2(x) +residual = TensorDictModule( + lambda x, y: x + y, in_keys=["x", "linear2"], out_keys=["y"] +) +############################################################################### +# And we can now put together ``block0``, ``block1`` and ``residual`` for a fully fleshed residual block: + +block = TensorDictSequential(block0, block1, residual) +block(tensordict) +assert "y" in tensordict ############################################################################### +# A genuine concern may be the accumulation of entries in the tensordict used as input: in some cases (e.g., when +# gradients are required) intermediate values may be cached anyway, but this isn't always the case and it can be useful +# to let the garbage collector know that some entries can be discarded. :class:`tensordict.nn.TensorDictModuleBase` and +# its subclasses (including :class:`tensordict.nn.TensorDictModule` and :class:`tensordict.nn.TensorDictSequential`) +# have the option of seeing their output keys filtered after execution. To do this, just call the +# :class:`tensordict.nn.TensorDictModuleBase.select_out_keys` method. This will update the module in-place and all the +# unwanted entries will be discarded: -tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5]) +block.select_out_keys("y") -splitlinear = TensorDictModule( - MultiHeadLinear(3, 4, 10), - in_keys=["a"], - out_keys=["output_1", "output_2"], -) -splitlinear(tensordict) +tensordict = TensorDict(x=torch.randn(1, 3), batch_size=[1]) +block(tensordict) +assert "y" in tensordict + +assert "linear1" not in tensordict ############################################################################### -# When having multiple input keys and output keys, make sure they match the -# order in the module. -# -# ``TensorDictModule`` can work with :class:`~.TensorDict` instances that contain -# more tensors than what the ``in_keys`` attribute indicates. -# -# Unless a ``vmap`` operator is used, the :class:`~.TensorDict` is modified in-place. +# However, the input keys are preserved: +assert "x" in tensordict + +############################################################################### +# As a side note, ``selected_out_keys`` may also be passed to :class:`tensordict.nn.TensorDictSequential` to avoid +# calling this method separately. 
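+#
+# A minimal sketch of that constructor-based variant (reusing the ``linear0`` and ``relu0`` modules defined above;
+# shown as a doctest, so the exact names are illustrative):
+#
+# >>> filtered_block = TensorDictSequential(linear0, relu0, selected_out_keys=["relu0"])
+# >>> out = filtered_block(TensorDict(x=torch.randn(1, 3), batch_size=[1]))
+# >>> assert "relu0" in out and "linear0" not in out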
# -# **Ignoring some outputs** +# Using `TensorDictModule` without tensordict +# ------------------------------------------- # -# Note that it is possible to avoid writing some of the tensors to the -# :class:`~.TensorDict` output, using ``"_"`` in ``out_keys``. +# The opportunity offered by :class:`tensordict.nn.TensorDictSequential` to build complex architectures on-the-go +# does not mean that one necessarily has to switch to tensordict to represent the data. Thanks to +# :class:`~tensordict.nn.dispatch`, modules from `tensordict.nn` support arguments and keyword arguments that match the +# entry names too: + +x = torch.randn(1, 3) +y = block(x=x) +assert isinstance(y, torch.Tensor) + +############################################################################### +# Under the hood, :class:`~tensordict.nn.dispatch` rebuilds a tensordict, runs the module and then deconstructs it. +# This may cause some overhead but, as we will see just after, there is a solution to get rid of this. # -# Example 4: Combining multiple ``TensorDictModule`` with ``TensorDictSequential`` -# ---------------------------------------------------------------------------------- -# To combine multiple ``TensorDictModule`` instances, we can use -# ``TensorDictSequential``. We create a list where each ``TensorDictModule`` must -# be executed sequentially. ``TensorDictSequential`` will read and write keys to the -# tensordict following the sequence of modules provided. +# Runtime +# ------- # -# We can also gather the inputs needed by ``TensorDictSequential`` with the -# ``in_keys`` property, and the outputs keys are found at the ``out_keys`` attribute. +# :class:`tensordict.nn.TensorDictModule` and :class:`tensordict.nn.TensorDictSequential` do incur some overhead when +# executed, as they need to read and write from a tensordict. However, we can greatly reduce this overhead by using +# :func:`~torch.compile`. 
For this, let us compare the three versions of this code with and without compile: -tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5]) -splitlinear = TensorDictModule( - MultiHeadLinear(3, 4, 10), - in_keys=["a"], - out_keys=["output_1", "output_2"], +class ResidualBlock(nn.Module): + def __init__(self): + super().__init__() + self.linear0 = nn.Linear(3, 128) + self.relu0 = nn.ReLU() + self.linear1 = nn.Linear(128, 128) + self.relu1 = nn.ReLU() + self.linear2 = nn.Linear(128, 3) + + def forward(self, x): + y = self.linear0(x) + y = self.relu0(y) + y = self.linear1(y) + y = self.relu1(y) + return self.linear2(y) + x + + +print("Without compile") +x = torch.randn(256, 3) +block_notd = ResidualBlock() +block_tdm = TensorDictModule(block_notd, in_keys=["x"], out_keys=["y"]) +block_tds = block + +from torch.utils.benchmark import Timer + +print( + f"Regular: {Timer('block_notd(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us" ) -mergelinear = TensorDictModule( - MergeLinear(4, 10, 13), - in_keys=["output_1", "output_2"], - out_keys=["output"], +print( + f"TDM: {Timer('block_tdm(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us" +) +print( + f"Sequential: {Timer('block_tds(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us" ) -split_and_merge_linear = TensorDictSequential(splitlinear, mergelinear) - -assert split_and_merge_linear(tensordict)["output"].shape == torch.Size([5, 13]) +print("Compiled versions") +block_notd_c = torch.compile(block_notd, mode="reduce-overhead") +for _ in range(5): # warmup + block_notd_c(x) +print( + f"Compiled regular: {Timer('block_notd_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us" +) +block_tdm_c = torch.compile(block_tdm, mode="reduce-overhead") +for _ in range(5): # warmup + block_tdm_c(x=x) +print( + f"Compiled TDM: {Timer('block_tdm_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us" +) +block_tds_c = torch.compile(block_tds, mode="reduce-overhead") +for _ in range(5): # warmup + block_tds_c(x=x) +print( + f"Compiled sequential: {Timer('block_tds_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us" +) ############################################################################### +# As one can see, the onverhead introduced by :class:`~tensordict.nn.TensorDictSequential` has been completely resolved. +# # Do's and don't with TensorDictModule -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ------------------------------------ +# +# - Don't use :class:`~torch.nn.Sequence` around modules from :mod:`tensordict.nn`. It would break the input/output +# key structure. +# Always try to rely on :class:`~tensordict.nn:TensorDictSequential` instead. # -# Don't use ``nn.Sequence``, similar to ``nn.Module``, it would break features -# such as ``functorch`` compatibility. Do use ``TensorDictSequential`` instead. +# - Don't assign the output tensordict to a new variable, as the output tensordict is just the input modified in-place. +# Assigning a new variable name isn't strictly prohibited, but it means that you may wish for both of them to disappear +# when one is deleted, when in fact the garbage collector will still see the tensors in the workspace and the no memory +# will be freed: # -# Don't assign the output tensordict to a new variable, as the output -# tensordict is just the input modified in-place: +# .. code-block:: # -# tensordict = module(tensordict) # ok! +# >>> tensordict = module(tensordict) # ok! 
+# >>> tensordict_out = module(tensordict) # don't! # -# tensordict_out = module(tensordict) # don't! +# Working with distributions: :class:`~tensordict.nn.ProbabilisticTensorDictModule` +# --------------------------------------------------------------------------------- # -# ``ProbabilisticTensorDictModule`` -# ---------------------------------- -# ``ProbabilisticTensorDictModule`` is a non-parametric module representing a +# :class:`~tensordict.nn.ProbabilisticTensorDictModule` is a non-parametric module representing a # probability distribution. Distribution parameters are read from tensordict # input, and the output is written to an output tensordict. The output is # sampled given some rule, specified by the input ``default_interaction_type`` -# argument and the ``exploration_mode()`` global function. If they conflict, +# argument and the :func:`~tensordict.nn.interaction_type` global function. If they conflict, # the context manager precedes. # -# It can be wired together with a ``TensorDictModule`` that returns +# It can be wired together with a :class:`~tensordict.nn.TensorDictModule` that returns # a tensordict updated with the distribution parameters using -# ``ProbabilisticTensorDictSequential``. This is a special case of -# ``TensorDictSequential`` that terminates in a -# ``ProbabilisticTensorDictModule``. +# :class:`~tensordict.nn.ProbabilisticTensorDictSequential`. This is a special case of +# :class:`~tensordict.nn.TensorDictSequential` whose last layer is a +# :class:`~tensordict.nn.ProbabilisticTensorDictModule` instance. # -# ``ProbabilisticTensorDictModule`` is responsible for constructing the -# distribution (through the ``get_dist()`` method) and/or sampling from this -# distribution (through a regular ``__call__()`` to the module). The same -# ``get_dist()`` method is exposed on ``ProbabilisticTensorDictSequential. +# :class:`~tensordict.nn.ProbabilisticTensorDictModule` is responsible for constructing the +# distribution (through the :meth:`~tensordict.nn.ProbabilisticTensorDictModule.get_dist` method) and/or +# sampling from this distribution (through a regular `forward` call to the module). The same +# :meth:`~tensordict.nn.ProbabilisticTensorDictModule.get_dist` method is exposed within +# :class:`~tensordict.nn.ProbabilisticTensorDictSequential`. # # One can find the parameters in the output tensordict as well as the log # probability if needed. @@ -211,523 +313,24 @@ def forward(self, x): td_module(td) print(f"TensorDict after going through module now as keys action, loc and scale: {td}") -################################################################################# -# Showcase: Implementing a transformer using TensorDictModule -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# To demonstrate the flexibility of ``TensorDictModule``, we are going to -# create a transformer that reads :class:`~.TensorDict` objects using ``TensorDictModule``. -# -# The following figure shows the classical transformer architecture -# (Vaswani et al, 2017). +############################################################################### +# Conclusion +# ---------- # -# .. image:: /reference/generated/tutorials/media/transformer.png -# :alt: The transformer png +# We have seen how `tensordict.nn` can be used to dynamically build complex neural architectures on-the-fly. 
+# This opens the possibility of building pipelines that are oblivious to the model signature, i.e., write generic codes +# that use networks with an arbitrary number of inputs or outputs in a flexible manner. # -# We have let the positional encoders aside for simplicity. +# We have also seen how :class:`~tensordict.nn.dispatch` enables to use `tensordict.nn` to build such networks and use +# them without recurring to :class:`~tensordict.TensorDict` directly. Thanks to :func:`~torch.compile`, the overhead +# introduced by :class:`tensordict.nn.TensorDictSequential` can be completely removed, leaving users with a neat, +# tensordict-free version of their module. # -# Let's re-write the classical transformers blocks: - - -class TokensToQKV(nn.Module): - def __init__(self, to_dim, from_dim, latent_dim): - super().__init__() - self.q = nn.Linear(to_dim, latent_dim) - self.k = nn.Linear(from_dim, latent_dim) - self.v = nn.Linear(from_dim, latent_dim) - - def forward(self, X_to, X_from): - Q = self.q(X_to) - K = self.k(X_from) - V = self.v(X_from) - return Q, K, V - - -class SplitHeads(nn.Module): - def __init__(self, num_heads): - super().__init__() - self.num_heads = num_heads - - def forward(self, Q, K, V): - batch_size, to_num, latent_dim = Q.shape - _, from_num, _ = K.shape - d_tensor = latent_dim // self.num_heads - Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2) - K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2) - V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2) - return Q, K, V - - -class Attention(nn.Module): - def __init__(self, latent_dim, to_dim): - super().__init__() - self.softmax = nn.Softmax(dim=-1) - self.out = nn.Linear(latent_dim, to_dim) - - def forward(self, Q, K, V): - batch_size, n_heads, to_num, d_in = Q.shape - attn = self.softmax(Q @ K.transpose(2, 3) / d_in) - out = attn @ V - out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads * d_in)) - return out, attn - - -class SkipLayerNorm(nn.Module): - def __init__(self, to_len, to_dim): - super().__init__() - self.layer_norm = nn.LayerNorm((to_len, to_dim)) - - def forward(self, x_0, x_1): - return self.layer_norm(x_0 + x_1) - - -class FFN(nn.Module): - def __init__(self, to_dim, hidden_dim, dropout_rate=0.2): - super().__init__() - self.FFN = nn.Sequential( - nn.Linear(to_dim, hidden_dim), - nn.ReLU(), - nn.Linear(hidden_dim, to_dim), - nn.Dropout(dropout_rate), - ) - - def forward(self, X): - return self.FFN(X) - - -class AttentionBlock(nn.Module): - def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads): - super().__init__() - self.tokens_to_qkv = TokensToQKV(to_dim, from_dim, latent_dim) - self.split_heads = SplitHeads(num_heads) - self.attention = Attention(latent_dim, to_dim) - self.skip = SkipLayerNorm(to_len, to_dim) - - def forward(self, X_to, X_from): - Q, K, V = self.tokens_to_qkv(X_to, X_from) - Q, K, V = self.split_heads(Q, K, V) - out, attention = self.attention(Q, K, V) - out = self.skip(X_to, out) - return out - - -class EncoderTransformerBlock(nn.Module): - def __init__(self, to_dim, to_len, latent_dim, num_heads): - super().__init__() - self.attention_block = AttentionBlock( - to_dim, to_len, to_dim, latent_dim, num_heads - ) - self.FFN = FFN(to_dim, 4 * to_dim) - self.skip = SkipLayerNorm(to_len, to_dim) - - def forward(self, X_to): - X_to = self.attention_block(X_to, X_to) - X_out = self.FFN(X_to) - return self.skip(X_out, X_to) - - -class DecoderTransformerBlock(nn.Module): - def 
__init__(self, to_dim, to_len, from_dim, latent_dim, num_heads): - super().__init__() - self.attention_block = AttentionBlock( - to_dim, to_len, from_dim, latent_dim, num_heads - ) - self.encoder_block = EncoderTransformerBlock( - to_dim, to_len, latent_dim, num_heads - ) - - def forward(self, X_to, X_from): - X_to = self.attention_block(X_to, X_from) - X_to = self.encoder_block(X_to) - return X_to - - -class TransformerEncoder(nn.Module): - def __init__(self, num_blocks, to_dim, to_len, latent_dim, num_heads): - super().__init__() - self.encoder = nn.ModuleList( - [ - EncoderTransformerBlock(to_dim, to_len, latent_dim, num_heads) - for i in range(num_blocks) - ] - ) - - def forward(self, X_to): - for i in range(len(self.encoder)): - X_to = self.encoder[i](X_to) - return X_to - - -class TransformerDecoder(nn.Module): - def __init__(self, num_blocks, to_dim, to_len, from_dim, latent_dim, num_heads): - super().__init__() - self.decoder = nn.ModuleList( - [ - DecoderTransformerBlock(to_dim, to_len, from_dim, latent_dim, num_heads) - for i in range(num_blocks) - ] - ) - - def forward(self, X_to, X_from): - for i in range(len(self.decoder)): - X_to = self.decoder[i](X_to, X_from) - return X_to - - -class Transformer(nn.Module): - def __init__( - self, num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads - ): - super().__init__() - self.encoder = TransformerEncoder( - num_blocks, to_dim, to_len, latent_dim, num_heads - ) - self.decoder = TransformerDecoder( - num_blocks, from_dim, from_len, to_dim, latent_dim, num_heads - ) - - def forward(self, X_to, X_from): - X_to = self.encoder(X_to) - X_out = self.decoder(X_from, X_to) - return X_out - - -############################################################################### -# We first create the ``AttentionBlockTensorDict``, the attention block using -# ``TensorDictModule`` and ``TensorDictSequential``. -# -# The wiring operation that connects the modules to each other requires us -# to indicate which key each of them must read and write. Unlike -# ``nn.Sequence``, a ``TensorDictSequential`` can read/write more than one -# input/output. Moreover, its components inputs need not be identical to the -# previous layers outputs, allowing us to code complicated neural architecture. - - -class AttentionBlockTensorDict(TensorDictSequential): - def __init__( - self, - to_name, - from_name, - to_dim, - to_len, - from_dim, - latent_dim, - num_heads, - ): - super().__init__( - TensorDictModule( - TokensToQKV(to_dim, from_dim, latent_dim), - in_keys=[to_name, from_name], - out_keys=["Q", "K", "V"], - ), - TensorDictModule( - SplitHeads(num_heads), - in_keys=["Q", "K", "V"], - out_keys=["Q", "K", "V"], - ), - TensorDictModule( - Attention(latent_dim, to_dim), - in_keys=["Q", "K", "V"], - out_keys=["X_out", "Attn"], - ), - TensorDictModule( - SkipLayerNorm(to_len, to_dim), - in_keys=[to_name, "X_out"], - out_keys=[to_name], - ), - ) - - -############################################################################### -# We build the encoder and decoder blocks that will be part of the transformer -# thanks to ``TensorDictModule``. 
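+#
+# First sketch: a hypothetical two-headed pipeline (layer sizes and key names are made up). Downstream code only
+# needs the container and ``out_keys``, not the exact signature of each sub-module:
+#
+# >>> import torch
+# >>> from tensordict import TensorDict
+# >>> from tensordict.nn import TensorDictModule, TensorDictSequential
+# >>> pipeline = TensorDictSequential(
+# ...     TensorDictModule(torch.nn.Linear(3, 4), in_keys=["x"], out_keys=["hidden"]),
+# ...     TensorDictModule(torch.nn.Linear(4, 1), in_keys=["hidden"], out_keys=["y1"]),
+# ...     TensorDictModule(torch.nn.Linear(4, 2), in_keys=["hidden"], out_keys=["y2"]),
+# ... )
+# >>> td = pipeline(TensorDict({"x": torch.randn(3)}))
+# >>> outputs = {key: td[key] for key in pipeline.out_keys}  # generic, signature-agnostic consumption
+#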
- - -class TransformerBlockEncoderTensorDict(TensorDictSequential): - def __init__( - self, - to_name, - from_name, - to_dim, - to_len, - from_dim, - latent_dim, - num_heads, - ): - super().__init__( - AttentionBlockTensorDict( - to_name, - from_name, - to_dim, - to_len, - from_dim, - latent_dim, - num_heads, - ), - TensorDictModule( - FFN(to_dim, 4 * to_dim), - in_keys=[to_name], - out_keys=["X_out"], - ), - TensorDictModule( - SkipLayerNorm(to_len, to_dim), - in_keys=[to_name, "X_out"], - out_keys=[to_name], - ), - ) - - -class TransformerBlockDecoderTensorDict(TensorDictSequential): - def __init__( - self, - to_name, - from_name, - to_dim, - to_len, - from_dim, - latent_dim, - num_heads, - ): - super().__init__( - AttentionBlockTensorDict( - to_name, - to_name, - to_dim, - to_len, - to_dim, - latent_dim, - num_heads, - ), - TransformerBlockEncoderTensorDict( - to_name, - from_name, - to_dim, - to_len, - from_dim, - latent_dim, - num_heads, - ), - ) - - -############################################################################### -# We create the transformer encoder and decoder. -# -# For an encoder, we just need to take the same tokens for both queries, -# keys and values. -# -# For a decoder, we now can extract info from ``X_from`` into ``X_to``. -# ``X_from`` will map to queries whereas ``X_from`` will map to keys and values. - - -class TransformerEncoderTensorDict(TensorDictSequential): - def __init__( - self, - num_blocks, - to_name, - from_name, - to_dim, - to_len, - from_dim, - latent_dim, - num_heads, - ): - super().__init__( - *[ - TransformerBlockEncoderTensorDict( - to_name, - from_name, - to_dim, - to_len, - from_dim, - latent_dim, - num_heads, - ) - for _ in range(num_blocks) - ] - ) - - -class TransformerDecoderTensorDict(TensorDictSequential): - def __init__( - self, - num_blocks, - to_name, - from_name, - to_dim, - to_len, - from_dim, - latent_dim, - num_heads, - ): - super().__init__( - *[ - TransformerBlockDecoderTensorDict( - to_name, - from_name, - to_dim, - to_len, - from_dim, - latent_dim, - num_heads, - ) - for _ in range(num_blocks) - ] - ) - - -class TransformerTensorDict(TensorDictSequential): - def __init__( - self, - num_blocks, - to_name, - from_name, - to_dim, - to_len, - from_dim, - from_len, - latent_dim, - num_heads, - ): - super().__init__( - TransformerEncoderTensorDict( - num_blocks, - to_name, - to_name, - to_dim, - to_len, - to_dim, - latent_dim, - num_heads, - ), - TransformerDecoderTensorDict( - num_blocks, - from_name, - to_name, - from_dim, - from_len, - to_dim, - latent_dim, - num_heads, - ), - ) - - -############################################################################### -# We now test our new ``TransformerTensorDict``. - -to_dim = 5 -from_dim = 6 -latent_dim = 10 -to_len = 3 -from_len = 10 -batch_size = 8 -num_heads = 2 -num_blocks = 6 - -tokens = TensorDict( - { - "X_encode": torch.randn(batch_size, to_len, to_dim), - "X_decode": torch.randn(batch_size, from_len, from_dim), - }, - batch_size=[batch_size], -) - -transformer = TransformerTensorDict( - num_blocks, - "X_encode", - "X_decode", - to_dim, - to_len, - from_dim, - from_len, - latent_dim, - num_heads, -) - -transformer(tokens) -tokens - -############################################################################### -# We've achieved to create a transformer with ``TensorDictModule``. This -# shows that ``TensorDictModule`` is a flexible module that can implement -# complex operarations. 
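+#
+# Second sketch: calling the same kind of module through :class:`~tensordict.nn.dispatch` and :func:`~torch.compile`
+# (module and key names are again made up; the actual speed-up depends on the model):
+#
+# >>> import torch
+# >>> from tensordict.nn import TensorDictModule, TensorDictSequential
+# >>> net = TensorDictSequential(
+# ...     TensorDictModule(torch.nn.Linear(3, 4), in_keys=["x"], out_keys=["hidden"]),
+# ...     TensorDictModule(torch.nn.Linear(4, 1), in_keys=["hidden"], out_keys=["y"]),
+# ... )
+# >>> hidden, y = net(x=torch.randn(3))  # dispatch: keyword arguments in, plain tensors out
+# >>> compiled = torch.compile(net)
+# >>> hidden, y = compiled(x=torch.randn(3))  # the tensordict plumbing is compiled away
+#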
-# -# Benchmarking -# ------------------------------ - -############################################################################### - -to_dim = 5 -from_dim = 6 -latent_dim = 10 -to_len = 3 -from_len = 10 -batch_size = 8 -num_heads = 2 -num_blocks = 6 - -############################################################################### - -td_tokens = TensorDict( - { - "X_encode": torch.randn(batch_size, to_len, to_dim), - "X_decode": torch.randn(batch_size, from_len, from_dim), - }, - batch_size=[batch_size], -) - -############################################################################### - -X_encode = torch.randn(batch_size, to_len, to_dim) -X_decode = torch.randn(batch_size, from_len, from_dim) - -############################################################################### - -tdtransformer = TransformerTensorDict( - num_blocks, - "X_encode", - "X_decode", - to_dim, - to_len, - from_dim, - from_len, - latent_dim, - num_heads, -) - -############################################################################### - -transformer = Transformer( - num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads -) - -############################################################################### -# **Inference Time** - -import time - -############################################################################### - -t1 = time.time() -tokens = tdtransformer(td_tokens) -t2 = time.time() -print("Execution time:", t2 - t1, "seconds") - -############################################################################### - -t3 = time.time() -X_out = transformer(X_encode, X_decode) -t4 = time.time() -print("Execution time:", t4 - t3, "seconds") - -############################################################################### -# We can see on this minimal example that the overhead introduced by -# ``TensorDictModule`` is marginal. +# In the next tutorial, we will be seeing how ``torch.export`` can be used to isolate a module and export it. # -# Have fun with TensorDictModule! # sphinx_gallery_start_ignore import time -time.sleep(10) +time.sleep(3) # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx_tuto/tensordict_module_functional.py b/tutorials/sphinx_tuto/tensordict_module_functional.py deleted file mode 100644 index fcb09894d..000000000 --- a/tutorials/sphinx_tuto/tensordict_module_functional.py +++ /dev/null @@ -1,78 +0,0 @@ -""" -Functionalizing TensorDictModule -================================ -In this tutorial you will learn how to use :class:`~.TensorDictModule` in conjunction -with functorch to create functionlized modules. -""" - -############################################################################## -# Before we take a look at the functional utilities in :mod:`tensordict.nn`, let us -# reintroduce one of the example modules from the :class:`~.TensorDictModule` tutorial. -# -# We'll create a simple module that has two linear layers, which share the input and -# return separate outputs. 
- -import functorch -import torch -import torch.nn as nn -from tensordict import TensorDict -from tensordict.nn import TensorDictModule - - -class MultiHeadLinear(nn.Module): - def __init__(self, in_1, out_1, out_2): - super().__init__() - self.linear_1 = nn.Linear(in_1, out_1) - self.linear_2 = nn.Linear(in_1, out_2) - - def forward(self, x): - return self.linear_1(x), self.linear_2(x) - - -############################################################################## -# We can now create a :class:`~.TensorDictModule` that will read the input from a key -# ``"a"``, and write to the keys ``"output_1"`` and ``"output_2"``. -splitlinear = TensorDictModule( - MultiHeadLinear(3, 4, 10), in_keys=["a"], out_keys=["output_1", "output_2"] -) - -############################################################################## -# Ordinarily we would use this module by simply calling it on a :class:`~.TensorDict` -# with the required input keys. - -tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5]) -splitlinear(tensordict) -print(tensordict) - - -############################################################################## -# However, we can also use :func:`functorch.make_functional_with_buffers` in order to -# functionalise the module. -func, params, buffers = functorch.make_functional_with_buffers(splitlinear) -print(func(params, buffers, tensordict)) - -############################################################################### -# This can be used with the vmap operator. For example, we use 3 replicas of the -# params and buffers and execute a vectorized map over these for a single batch -# of data: - -params_expand = [p.expand(3, *p.shape) for p in params] -buffers_expand = [p.expand(3, *p.shape) for p in buffers] -print(torch.vmap(func, (0, 0, None))(params_expand, buffers_expand, tensordict)) - -############################################################################### -# We can also use the native :func:`make_functional ` -# function from :mod:`tensordict.nn``, which modifies the module to make it accept the -# parameters as regular inputs: - -from tensordict.nn import make_functional - -tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5]) - -num_models = 10 -model = TensorDictModule(nn.Linear(3, 4), in_keys=["a"], out_keys=["output"]) -params = make_functional(model) -# we stack two groups of parameters to show the vmap usage: -params = torch.stack([params, params.apply(lambda x: torch.zeros_like(x))], 0) -result_td = torch.vmap(model, (None, 0))(tensordict, params) -print("the output tensordict shape is: ", result_td.shape)