From 0e08c1240028505cf56e42581b9080a712e6b14c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 16 Sep 2024 17:32:38 -0700 Subject: [PATCH] [Doc] export tutorial, TDM tuto refactoring ghstack-source-id: 190c37737e970a4bedd7e3bcdb31a6dafef1fdb7 Pull Request resolved: https://github.com/pytorch/tensordict/pull/994 --- docs/source/index.rst | 2 +- tutorials/sphinx_tuto/export.py | 14 + tutorials/sphinx_tuto/tensordict_module.py | 827 +++++------------- .../tensordict_module_functional.py | 78 -- 4 files changed, 221 insertions(+), 700 deletions(-) create mode 100644 tutorials/sphinx_tuto/export.py delete mode 100644 tutorials/sphinx_tuto/tensordict_module_functional.py diff --git a/docs/source/index.rst b/docs/source/index.rst index f76312dd2..4dedb7751 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -80,7 +80,7 @@ tensordict.nn :maxdepth: 1 tutorials/tensordict_module - tutorials/tensordict_module_functional + tutorials/export Dataloading ----------- diff --git a/tutorials/sphinx_tuto/export.py b/tutorials/sphinx_tuto/export.py new file mode 100644 index 000000000..ea0b8b409 --- /dev/null +++ b/tutorials/sphinx_tuto/export.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- + +""" +Exporting tensordict modules +============================ + +**Author**: `Vincent Moens `_ + +Prerequisites +~~~~~~~~~~~~~ + + + +""" diff --git a/tutorials/sphinx_tuto/tensordict_module.py b/tutorials/sphinx_tuto/tensordict_module.py index 96274ea0a..031977976 100644 --- a/tutorials/sphinx_tuto/tensordict_module.py +++ b/tutorials/sphinx_tuto/tensordict_module.py @@ -1,20 +1,26 @@ """ TensorDictModule ================ + +**Author**: `Nicolas Dufour `_, `Vincent Moens `_ + In this tutorial you will learn how to use :class:`~.TensorDictModule` and :class:`~.TensorDictSequential` to create generic and reusable modules that can accept :class:`~.TensorDict` as input. + """ ############################################################################## -# For a convenient usage of the :class:`~.TensorDict` class with ``nn.Module``, -# :mod:`tensordict` provides an interface between the two named ``TensorDictModule``. -# The ``TensorDictModule`` class is an ``nn.Module`` that takes a -# :class:`~.TensorDict` as input when called. +# +# For a convenient usage of the :class:`~.TensorDict` class with :class:`~torch.nn.Module`, +# :mod:`tensordict` provides an interface between the two named :class:`~tensordict.nn.TensorDictModule`. +# +# The :class:`~tensordict.nn.TensorDictModule` class is an :class:`~torch.nn.Module` that takes a +# :class:`~tensordict.TensorDict` as input when called. It will read a sequence of input keys, pass them to the wrapped +# module or function as input, and write the outputs in the same tensordict after completing the execution. +# # It is up to the user to define the keys to be read as input and output. # -# TensorDictModule by examples -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # sphinx_gallery_start_ignore import warnings @@ -28,159 +34,237 @@ from tensordict.nn import TensorDictModule, TensorDictSequential ############################################################################### -# Example 1: Simple usage -# -------------------------------------- -# We have a :class:`~.TensorDict` with 2 entries ``"a"`` and ``"b"`` but only the -# value associated with ``"a"`` has to be read by the network. +# +# Simple example: coding a recurrent layer +# ---------------------------------------- +# +# The simplest usage of :class:`~tensordict.nn.TensorDictModule` is exemplified below. +# If at first it may look like using this class introduces an unwated level of complexity, we will see +# later on that this API enables users to programatically concatenate modules together, cache values +# in between modules or programmatically build one. +# One of the simplest examples of this is a recurrent module in an architecture like ResNet, where the input of the +# module is cached and added to the output of a tiny multi-layered perceptron (MLP). +# +# To start, let's first consider we you would chunk an MLP, and code it using :mod:`tensordict.nn`. +# The first layer of the stack would presumably be a :class:`~torch.nn.Linear` layer, taking an entry as input +# (let us name it `x`) and outputting another entry (which we will name `y`). +# +# To feed to our module, we have a :class:`~tensordict.TensorDict` instance with a single entry, +# ``"x"``: tensordict = TensorDict( - {"a": torch.randn(5, 3), "b": torch.zeros(5, 4, 3)}, + x=torch.randn(5, 3), batch_size=[5], ) -linear = TensorDictModule(nn.Linear(3, 10), in_keys=["a"], out_keys=["a_out"]) -linear(tensordict) -assert (tensordict.get("b") == 0).all() -print(tensordict) ############################################################################### -# Example 2: Multiple inputs -# -------------------------------------- -# Suppose we have a slightly more complex network that takes 2 entries and -# averages them into a single output tensor. To make a ``TensorDictModule`` -# instance read multiple input values, one must register them in the -# ``in_keys`` keyword argument of the constructor. +# Now, we build our simple module using :class:`tensordict.nn.TensorDictModule`. By default, this class writes in the +# input tensordict in-place (meaning that entries are written in the same tensordict as the input, not that entries +# are overwritten in-place!), such that we don't need to explicitly indicate what the output is: +# +linear0 = TensorDictModule(nn.Linear(3, 128), in_keys=["x"], out_keys=["linear0"]) +linear0(tensordict) +assert "linear0" in tensordict -class MergeLinear(nn.Module): - def __init__(self, in_1, in_2, out): - super().__init__() - self.linear_1 = nn.Linear(in_1, out) - self.linear_2 = nn.Linear(in_2, out) +############################################################################### +# +# If the module outputs multiple tensors (or tensordicts!) their entries must be passed to +# :class:`~tensordict.nn.TensorDictModule` in the right order. +# +# Stacking modules +# ~~~~~~~~~~~~~~~~ +# +# Our MLP isn't made of a single layer, so we now need to add another layer to it. +# This layer will be an activation function, for instance :class:`~torch.nn.ReLU`. +# We can stack this module and the previous one using :class:`~tensordict.nn.TensorDictSequential`. +# +# .. note:: Here comes the true power of ``tensordict.nn``: unlike :class:`~torch.nn.Sequential`, +# :class:`~tensordict.nn.TensorDictSequential` will keep in memory all the previous inputs and outputs +# (with the possibility to filter them out afterwards), making it easy to have complex network structures +# built on-the-fly and programmatically. +# - def forward(self, x_1, x_2): - return (self.linear_1(x_1) + self.linear_2(x_2)) / 2 +relu0 = TensorDictModule(nn.ReLU(), in_keys=["linear0"], out_keys=["relu0"]) +block0 = TensorDictSequential(linear0, relu0) +block0(tensordict) +assert "linear0" in tensordict +assert "relu0" in tensordict ############################################################################### +# We can repeat this logic to get a full MLP: +# -tensordict = TensorDict( - { - "a": torch.randn(5, 3), - "b": torch.randn(5, 4), - }, - batch_size=[5], -) +linear1 = TensorDictModule(nn.Linear(128, 128), in_keys=["relu0"], out_keys=["linear1"]) +relu1 = TensorDictModule(nn.ReLU(), in_keys=["linear1"], out_keys=["relu1"]) +linear2 = TensorDictModule(nn.Linear(128, 3), in_keys=["relu1"], out_keys=["linear2"]) +block1 = TensorDictSequential(linear1, relu1, linear2) -mergelinear = TensorDictModule( - MergeLinear(3, 4, 10), in_keys=["a", "b"], out_keys=["output"] -) +############################################################################### +# Multiple input keys +# ~~~~~~~~~~~~~~~~~~~ +# +# The last step of the residual network is to add the input to the output of the last linear layer. +# No need to write a special :class:`~torch.nn.Module` subclass for this! :class:`~tensordict.nn.TensorDictModule` +# can be used to wrap simple functions too: -mergelinear(tensordict) +residual = TensorDictModule( + lambda x, y: x + y, in_keys=["x", "linear2"], out_keys=["y"] +) ############################################################################### -# Example 3: Multiple outputs -# -------------------------------------- -# Similarly, ``TensorDictModule`` not only supports multiple inputs but also -# multiple outputs. To make a ``TensorDictModule`` instance write to multiple -# output values, one must register them in the ``out_keys`` keyword argument -# of the constructor. - +# And we can now put together ``block0``, ``block1`` and ``residual`` for a fully fleshed residual block: -class MultiHeadLinear(nn.Module): - def __init__(self, in_1, out_1, out_2): - super().__init__() - self.linear_1 = nn.Linear(in_1, out_1) - self.linear_2 = nn.Linear(in_1, out_2) +block = TensorDictSequential(block0, block1, residual) +block(tensordict) +assert "y" in tensordict - def forward(self, x): - return self.linear_1(x), self.linear_2(x) +############################################################################### +# A genuine concern may be the accumulation of entries in the tensordict used as input: in some cases (e.g., when +# gradients are required) intermediate values may be cached anyway, but this isn't always the case and it can be useful +# to let the garbage collector know that some entries can be discarded. :class:`tensordict.nn.TensorDictModuleBase` and +# its subclasses (including :class:`tensordict.nn.TensorDictModule` and :class:`tensordict.nn.TensorDictSequential`) +# have the option of seeing their output keys filtered after execution. To do this, just call the +# :class:`tensordict.nn.TensorDictModuleBase.select_out_keys` method. This will update the module in-place and all the +# unwanted entries will be discarded: +block.select_out_keys("y") -############################################################################### +tensordict = TensorDict(x=torch.randn(1, 3), batch_size=[1]) +block(tensordict) +assert "y" in tensordict -tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5]) +assert "linear1" not in tensordict -splitlinear = TensorDictModule( - MultiHeadLinear(3, 4, 10), - in_keys=["a"], - out_keys=["output_1", "output_2"], -) -splitlinear(tensordict) +############################################################################### +# However, the input keys are preserved: +assert "x" in tensordict ############################################################################### -# When having multiple input keys and output keys, make sure they match the -# order in the module. -# -# ``TensorDictModule`` can work with :class:`~.TensorDict` instances that contain -# more tensors than what the ``in_keys`` attribute indicates. +# As a side note, ``selected_out_keys`` may also be passed to :class:`tensordict.nn.TensorDictSequential` to avoid +# calling this method separately. # -# Unless a ``vmap`` operator is used, the :class:`~.TensorDict` is modified in-place. +# Using `TensorDictModule` without tensordict +# ------------------------------------------- # -# **Ignoring some outputs** -# -# Note that it is possible to avoid writing some of the tensors to the -# :class:`~.TensorDict` output, using ``"_"`` in ``out_keys``. +# The opportunity offered by :class:`tensordict.nn.TensorDictSequential` to build complex architectures on-the-go +# does not mean that one necessarily has to switch to tensordict to represent the data. Thanks to +# :class:`~tensordict.nn.dispatch`, modules from `tensordict.nn` support arguments and keyword arguments that match the +# entry names too: + +x = torch.randn(1, 3) +y = block(x=x) +assert isinstance(y, torch.Tensor) + +############################################################################### +# Under the hood, :class:`~tensordict.nn.dispatch` rebuilds a tensordict, runs the module and then deconstructs it. +# This may cause some overhead but, as we will see just after, there is a solution to get rid of this. # -# Example 4: Combining multiple ``TensorDictModule`` with ``TensorDictSequential`` -# ---------------------------------------------------------------------------------- -# To combine multiple ``TensorDictModule`` instances, we can use -# ``TensorDictSequential``. We create a list where each ``TensorDictModule`` must -# be executed sequentially. ``TensorDictSequential`` will read and write keys to the -# tensordict following the sequence of modules provided. +# Runtime +# ------- # -# We can also gather the inputs needed by ``TensorDictSequential`` with the -# ``in_keys`` property, and the outputs keys are found at the ``out_keys`` attribute. +# :class:`tensordict.nn.TensorDictModule` and :class:`tensordict.nn.TensorDictSequential` do incur some overhead when +# executed, as they need to read and write from a tensordict. However, we can greatly reduce this overhead by using +# :func:`~torch.compile`. For this, let us compare the three versions of this code with and without compile: + + +class ResidualBlock(nn.Module): + def __init__(self): + super().__init__() + self.linear0 = nn.Linear(3, 128) + self.relu0 = nn.ReLU() + self.linear1 = nn.Linear(128, 128) + self.relu1 = nn.ReLU() + self.linear2 = nn.Linear(128, 3) + + def forward(self, x): + y = self.linear0(x) + y = self.relu0(y) + y = self.linear1(y) + y = self.relu1(y) + return self.linear2(y) + x + + +print("Without compile") +x = torch.randn(256, 3) +block_notd = ResidualBlock() +block_tdm = TensorDictModule(block_notd, in_keys=["x"], out_keys=["y"]) +block_tds = block -tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5]) +from torch.utils.benchmark import Timer -splitlinear = TensorDictModule( - MultiHeadLinear(3, 4, 10), - in_keys=["a"], - out_keys=["output_1", "output_2"], +print( + f"Regular: {Timer('block_notd(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us" ) -mergelinear = TensorDictModule( - MergeLinear(4, 10, 13), - in_keys=["output_1", "output_2"], - out_keys=["output"], +print( + f"TDM: {Timer('block_tdm(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us" +) +print( + f"Sequential: {Timer('block_tds(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us" ) -split_and_merge_linear = TensorDictSequential(splitlinear, mergelinear) - -assert split_and_merge_linear(tensordict)["output"].shape == torch.Size([5, 13]) +print("Compiled versions") +block_notd_c = torch.compile(block_notd, mode="reduce-overhead") +for _ in range(5): # warmup + block_notd_c(x) +print( + f"Compiled regular: {Timer('block_notd_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us" +) +block_tdm_c = torch.compile(block_tdm, mode="reduce-overhead") +for _ in range(5): # warmup + block_tdm_c(x=x) +print( + f"Compiled TDM: {Timer('block_tdm_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us" +) +block_tds_c = torch.compile(block_tds, mode="reduce-overhead") +for _ in range(5): # warmup + block_tds_c(x=x) +print( + f"Compiled sequential: {Timer('block_tds_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us" +) ############################################################################### +# As one can see, the onverhead introduced by :class:`~tensordict.nn.TensorDictSequential` has been completely resolved. +# # Do's and don't with TensorDictModule -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# ------------------------------------ +# +# - Don't use :class:`~torch.nn.Sequence` around modules from :mod:`tensordict.nn`. It would break the input/output +# key structure. +# Always try to rely on :class:`~tensordict.nn:TensorDictSequential` instead. # -# Don't use ``nn.Sequence``, similar to ``nn.Module``, it would break features -# such as ``functorch`` compatibility. Do use ``TensorDictSequential`` instead. +# - Don't assign the output tensordict to a new variable, as the output tensordict is just the input modified in-place. +# Assigning a new variable name isn't strictly prohibited, but it means that you may wish for both of them to disappear +# when one is deleted, when in fact the garbage collector will still see the tensors in the workspace and the no memory +# will be freed: # -# Don't assign the output tensordict to a new variable, as the output -# tensordict is just the input modified in-place: +# .. code-block:: # -# tensordict = module(tensordict) # ok! +# >>> tensordict = module(tensordict) # ok! +# >>> tensordict_out = module(tensordict) # don't! # -# tensordict_out = module(tensordict) # don't! +# Working with distributions: :class:`~tensordict.nn.ProbabilisticTensorDictModule` +# --------------------------------------------------------------------------------- # -# ``ProbabilisticTensorDictModule`` -# ---------------------------------- -# ``ProbabilisticTensorDictModule`` is a non-parametric module representing a +# :class:`~tensordict.nn.ProbabilisticTensorDictModule` is a non-parametric module representing a # probability distribution. Distribution parameters are read from tensordict # input, and the output is written to an output tensordict. The output is # sampled given some rule, specified by the input ``default_interaction_type`` -# argument and the ``exploration_mode()`` global function. If they conflict, +# argument and the :func:`~tensordict.nn.interaction_type` global function. If they conflict, # the context manager precedes. # -# It can be wired together with a ``TensorDictModule`` that returns +# It can be wired together with a :class:`~tensordict.nn.TensorDictModule` that returns # a tensordict updated with the distribution parameters using -# ``ProbabilisticTensorDictSequential``. This is a special case of -# ``TensorDictSequential`` that terminates in a -# ``ProbabilisticTensorDictModule``. +# :class:`~tensordict.nn.ProbabilisticTensorDictSequential`. This is a special case of +# :class:`~tensordict.nn.TensorDictSequential` whose last layer is a +# :class:`~tensordict.nn.ProbabilisticTensorDictModule` instance. # -# ``ProbabilisticTensorDictModule`` is responsible for constructing the -# distribution (through the ``get_dist()`` method) and/or sampling from this -# distribution (through a regular ``__call__()`` to the module). The same -# ``get_dist()`` method is exposed on ``ProbabilisticTensorDictSequential. +# :class:`~tensordict.nn.ProbabilisticTensorDictModule` is responsible for constructing the +# distribution (through the :meth:`~tensordict.nn.ProbabilisticTensorDictModule.get_dist` method) and/or +# sampling from this distribution (through a regular `forward` call to the module). The same +# :meth:`~tensordict.nn.ProbabilisticTensorDictModule.get_dist` method is exposed within +# :class:`~tensordict.nn.ProbabilisticTensorDictSequential`. # # One can find the parameters in the output tensordict as well as the log # probability if needed. @@ -211,523 +295,24 @@ def forward(self, x): td_module(td) print(f"TensorDict after going through module now as keys action, loc and scale: {td}") -################################################################################# -# Showcase: Implementing a transformer using TensorDictModule -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -# To demonstrate the flexibility of ``TensorDictModule``, we are going to -# create a transformer that reads :class:`~.TensorDict` objects using ``TensorDictModule``. -# -# The following figure shows the classical transformer architecture -# (Vaswani et al, 2017). -# -# .. image:: /reference/generated/tutorials/media/transformer.png -# :alt: The transformer png +############################################################################### +# Conclusion +# ---------- # -# We have let the positional encoders aside for simplicity. +# We have seen how `tensordict.nn` can be used to dynamically build complex neural architectures on-the-fly. +# This opens the possibility of building pipelines that are oblivious to the model signature, i.e., write generic codes +# that use networks with an arbitrary number of inputs or outputs in a flexible manner. # -# Let's re-write the classical transformers blocks: - - -class TokensToQKV(nn.Module): - def __init__(self, to_dim, from_dim, latent_dim): - super().__init__() - self.q = nn.Linear(to_dim, latent_dim) - self.k = nn.Linear(from_dim, latent_dim) - self.v = nn.Linear(from_dim, latent_dim) - - def forward(self, X_to, X_from): - Q = self.q(X_to) - K = self.k(X_from) - V = self.v(X_from) - return Q, K, V - - -class SplitHeads(nn.Module): - def __init__(self, num_heads): - super().__init__() - self.num_heads = num_heads - - def forward(self, Q, K, V): - batch_size, to_num, latent_dim = Q.shape - _, from_num, _ = K.shape - d_tensor = latent_dim // self.num_heads - Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2) - K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2) - V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2) - return Q, K, V - - -class Attention(nn.Module): - def __init__(self, latent_dim, to_dim): - super().__init__() - self.softmax = nn.Softmax(dim=-1) - self.out = nn.Linear(latent_dim, to_dim) - - def forward(self, Q, K, V): - batch_size, n_heads, to_num, d_in = Q.shape - attn = self.softmax(Q @ K.transpose(2, 3) / d_in) - out = attn @ V - out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads * d_in)) - return out, attn - - -class SkipLayerNorm(nn.Module): - def __init__(self, to_len, to_dim): - super().__init__() - self.layer_norm = nn.LayerNorm((to_len, to_dim)) - - def forward(self, x_0, x_1): - return self.layer_norm(x_0 + x_1) - - -class FFN(nn.Module): - def __init__(self, to_dim, hidden_dim, dropout_rate=0.2): - super().__init__() - self.FFN = nn.Sequential( - nn.Linear(to_dim, hidden_dim), - nn.ReLU(), - nn.Linear(hidden_dim, to_dim), - nn.Dropout(dropout_rate), - ) - - def forward(self, X): - return self.FFN(X) - - -class AttentionBlock(nn.Module): - def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads): - super().__init__() - self.tokens_to_qkv = TokensToQKV(to_dim, from_dim, latent_dim) - self.split_heads = SplitHeads(num_heads) - self.attention = Attention(latent_dim, to_dim) - self.skip = SkipLayerNorm(to_len, to_dim) - - def forward(self, X_to, X_from): - Q, K, V = self.tokens_to_qkv(X_to, X_from) - Q, K, V = self.split_heads(Q, K, V) - out, attention = self.attention(Q, K, V) - out = self.skip(X_to, out) - return out - - -class EncoderTransformerBlock(nn.Module): - def __init__(self, to_dim, to_len, latent_dim, num_heads): - super().__init__() - self.attention_block = AttentionBlock( - to_dim, to_len, to_dim, latent_dim, num_heads - ) - self.FFN = FFN(to_dim, 4 * to_dim) - self.skip = SkipLayerNorm(to_len, to_dim) - - def forward(self, X_to): - X_to = self.attention_block(X_to, X_to) - X_out = self.FFN(X_to) - return self.skip(X_out, X_to) - - -class DecoderTransformerBlock(nn.Module): - def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads): - super().__init__() - self.attention_block = AttentionBlock( - to_dim, to_len, from_dim, latent_dim, num_heads - ) - self.encoder_block = EncoderTransformerBlock( - to_dim, to_len, latent_dim, num_heads - ) - - def forward(self, X_to, X_from): - X_to = self.attention_block(X_to, X_from) - X_to = self.encoder_block(X_to) - return X_to - - -class TransformerEncoder(nn.Module): - def __init__(self, num_blocks, to_dim, to_len, latent_dim, num_heads): - super().__init__() - self.encoder = nn.ModuleList( - [ - EncoderTransformerBlock(to_dim, to_len, latent_dim, num_heads) - for i in range(num_blocks) - ] - ) - - def forward(self, X_to): - for i in range(len(self.encoder)): - X_to = self.encoder[i](X_to) - return X_to - - -class TransformerDecoder(nn.Module): - def __init__(self, num_blocks, to_dim, to_len, from_dim, latent_dim, num_heads): - super().__init__() - self.decoder = nn.ModuleList( - [ - DecoderTransformerBlock(to_dim, to_len, from_dim, latent_dim, num_heads) - for i in range(num_blocks) - ] - ) - - def forward(self, X_to, X_from): - for i in range(len(self.decoder)): - X_to = self.decoder[i](X_to, X_from) - return X_to - - -class Transformer(nn.Module): - def __init__( - self, num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads - ): - super().__init__() - self.encoder = TransformerEncoder( - num_blocks, to_dim, to_len, latent_dim, num_heads - ) - self.decoder = TransformerDecoder( - num_blocks, from_dim, from_len, to_dim, latent_dim, num_heads - ) - - def forward(self, X_to, X_from): - X_to = self.encoder(X_to) - X_out = self.decoder(X_from, X_to) - return X_out - - -############################################################################### -# We first create the ``AttentionBlockTensorDict``, the attention block using -# ``TensorDictModule`` and ``TensorDictSequential``. -# -# The wiring operation that connects the modules to each other requires us -# to indicate which key each of them must read and write. Unlike -# ``nn.Sequence``, a ``TensorDictSequential`` can read/write more than one -# input/output. Moreover, its components inputs need not be identical to the -# previous layers outputs, allowing us to code complicated neural architecture. - - -class AttentionBlockTensorDict(TensorDictSequential): - def __init__( - self, - to_name, - from_name, - to_dim, - to_len, - from_dim, - latent_dim, - num_heads, - ): - super().__init__( - TensorDictModule( - TokensToQKV(to_dim, from_dim, latent_dim), - in_keys=[to_name, from_name], - out_keys=["Q", "K", "V"], - ), - TensorDictModule( - SplitHeads(num_heads), - in_keys=["Q", "K", "V"], - out_keys=["Q", "K", "V"], - ), - TensorDictModule( - Attention(latent_dim, to_dim), - in_keys=["Q", "K", "V"], - out_keys=["X_out", "Attn"], - ), - TensorDictModule( - SkipLayerNorm(to_len, to_dim), - in_keys=[to_name, "X_out"], - out_keys=[to_name], - ), - ) - - -############################################################################### -# We build the encoder and decoder blocks that will be part of the transformer -# thanks to ``TensorDictModule``. - - -class TransformerBlockEncoderTensorDict(TensorDictSequential): - def __init__( - self, - to_name, - from_name, - to_dim, - to_len, - from_dim, - latent_dim, - num_heads, - ): - super().__init__( - AttentionBlockTensorDict( - to_name, - from_name, - to_dim, - to_len, - from_dim, - latent_dim, - num_heads, - ), - TensorDictModule( - FFN(to_dim, 4 * to_dim), - in_keys=[to_name], - out_keys=["X_out"], - ), - TensorDictModule( - SkipLayerNorm(to_len, to_dim), - in_keys=[to_name, "X_out"], - out_keys=[to_name], - ), - ) - - -class TransformerBlockDecoderTensorDict(TensorDictSequential): - def __init__( - self, - to_name, - from_name, - to_dim, - to_len, - from_dim, - latent_dim, - num_heads, - ): - super().__init__( - AttentionBlockTensorDict( - to_name, - to_name, - to_dim, - to_len, - to_dim, - latent_dim, - num_heads, - ), - TransformerBlockEncoderTensorDict( - to_name, - from_name, - to_dim, - to_len, - from_dim, - latent_dim, - num_heads, - ), - ) - - -############################################################################### -# We create the transformer encoder and decoder. -# -# For an encoder, we just need to take the same tokens for both queries, -# keys and values. -# -# For a decoder, we now can extract info from ``X_from`` into ``X_to``. -# ``X_from`` will map to queries whereas ``X_from`` will map to keys and values. - - -class TransformerEncoderTensorDict(TensorDictSequential): - def __init__( - self, - num_blocks, - to_name, - from_name, - to_dim, - to_len, - from_dim, - latent_dim, - num_heads, - ): - super().__init__( - *[ - TransformerBlockEncoderTensorDict( - to_name, - from_name, - to_dim, - to_len, - from_dim, - latent_dim, - num_heads, - ) - for _ in range(num_blocks) - ] - ) - - -class TransformerDecoderTensorDict(TensorDictSequential): - def __init__( - self, - num_blocks, - to_name, - from_name, - to_dim, - to_len, - from_dim, - latent_dim, - num_heads, - ): - super().__init__( - *[ - TransformerBlockDecoderTensorDict( - to_name, - from_name, - to_dim, - to_len, - from_dim, - latent_dim, - num_heads, - ) - for _ in range(num_blocks) - ] - ) - - -class TransformerTensorDict(TensorDictSequential): - def __init__( - self, - num_blocks, - to_name, - from_name, - to_dim, - to_len, - from_dim, - from_len, - latent_dim, - num_heads, - ): - super().__init__( - TransformerEncoderTensorDict( - num_blocks, - to_name, - to_name, - to_dim, - to_len, - to_dim, - latent_dim, - num_heads, - ), - TransformerDecoderTensorDict( - num_blocks, - from_name, - to_name, - from_dim, - from_len, - to_dim, - latent_dim, - num_heads, - ), - ) - - -############################################################################### -# We now test our new ``TransformerTensorDict``. - -to_dim = 5 -from_dim = 6 -latent_dim = 10 -to_len = 3 -from_len = 10 -batch_size = 8 -num_heads = 2 -num_blocks = 6 - -tokens = TensorDict( - { - "X_encode": torch.randn(batch_size, to_len, to_dim), - "X_decode": torch.randn(batch_size, from_len, from_dim), - }, - batch_size=[batch_size], -) - -transformer = TransformerTensorDict( - num_blocks, - "X_encode", - "X_decode", - to_dim, - to_len, - from_dim, - from_len, - latent_dim, - num_heads, -) - -transformer(tokens) -tokens - -############################################################################### -# We've achieved to create a transformer with ``TensorDictModule``. This -# shows that ``TensorDictModule`` is a flexible module that can implement -# complex operarations. +# We have also seen how :class:`~tensordict.nn.dispatch` enables to use `tensordict.nn` to build such networks and use +# them without recurring to :class:`~tensordict.TensorDict` directly. Thanks to :func:`~torch.compile`, the overhead +# introduced by :class:`tensordict.nn.TensorDictSequential` can be completely removed, leaving users with a neat, +# tensordict-free version of their module. # -# Benchmarking -# ------------------------------ - -############################################################################### - -to_dim = 5 -from_dim = 6 -latent_dim = 10 -to_len = 3 -from_len = 10 -batch_size = 8 -num_heads = 2 -num_blocks = 6 - -############################################################################### - -td_tokens = TensorDict( - { - "X_encode": torch.randn(batch_size, to_len, to_dim), - "X_decode": torch.randn(batch_size, from_len, from_dim), - }, - batch_size=[batch_size], -) - -############################################################################### - -X_encode = torch.randn(batch_size, to_len, to_dim) -X_decode = torch.randn(batch_size, from_len, from_dim) - -############################################################################### - -tdtransformer = TransformerTensorDict( - num_blocks, - "X_encode", - "X_decode", - to_dim, - to_len, - from_dim, - from_len, - latent_dim, - num_heads, -) - -############################################################################### - -transformer = Transformer( - num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads -) - -############################################################################### -# **Inference Time** - -import time - -############################################################################### - -t1 = time.time() -tokens = tdtransformer(td_tokens) -t2 = time.time() -print("Execution time:", t2 - t1, "seconds") - -############################################################################### - -t3 = time.time() -X_out = transformer(X_encode, X_decode) -t4 = time.time() -print("Execution time:", t4 - t3, "seconds") - -############################################################################### -# We can see on this minimal example that the overhead introduced by -# ``TensorDictModule`` is marginal. +# In the next tutorial, we will be seeing how ``torch.export`` can be used to isolate a module and export it. # -# Have fun with TensorDictModule! # sphinx_gallery_start_ignore import time -time.sleep(10) +time.sleep(3) # sphinx_gallery_end_ignore diff --git a/tutorials/sphinx_tuto/tensordict_module_functional.py b/tutorials/sphinx_tuto/tensordict_module_functional.py deleted file mode 100644 index fcb09894d..000000000 --- a/tutorials/sphinx_tuto/tensordict_module_functional.py +++ /dev/null @@ -1,78 +0,0 @@ -""" -Functionalizing TensorDictModule -================================ -In this tutorial you will learn how to use :class:`~.TensorDictModule` in conjunction -with functorch to create functionlized modules. -""" - -############################################################################## -# Before we take a look at the functional utilities in :mod:`tensordict.nn`, let us -# reintroduce one of the example modules from the :class:`~.TensorDictModule` tutorial. -# -# We'll create a simple module that has two linear layers, which share the input and -# return separate outputs. - -import functorch -import torch -import torch.nn as nn -from tensordict import TensorDict -from tensordict.nn import TensorDictModule - - -class MultiHeadLinear(nn.Module): - def __init__(self, in_1, out_1, out_2): - super().__init__() - self.linear_1 = nn.Linear(in_1, out_1) - self.linear_2 = nn.Linear(in_1, out_2) - - def forward(self, x): - return self.linear_1(x), self.linear_2(x) - - -############################################################################## -# We can now create a :class:`~.TensorDictModule` that will read the input from a key -# ``"a"``, and write to the keys ``"output_1"`` and ``"output_2"``. -splitlinear = TensorDictModule( - MultiHeadLinear(3, 4, 10), in_keys=["a"], out_keys=["output_1", "output_2"] -) - -############################################################################## -# Ordinarily we would use this module by simply calling it on a :class:`~.TensorDict` -# with the required input keys. - -tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5]) -splitlinear(tensordict) -print(tensordict) - - -############################################################################## -# However, we can also use :func:`functorch.make_functional_with_buffers` in order to -# functionalise the module. -func, params, buffers = functorch.make_functional_with_buffers(splitlinear) -print(func(params, buffers, tensordict)) - -############################################################################### -# This can be used with the vmap operator. For example, we use 3 replicas of the -# params and buffers and execute a vectorized map over these for a single batch -# of data: - -params_expand = [p.expand(3, *p.shape) for p in params] -buffers_expand = [p.expand(3, *p.shape) for p in buffers] -print(torch.vmap(func, (0, 0, None))(params_expand, buffers_expand, tensordict)) - -############################################################################### -# We can also use the native :func:`make_functional ` -# function from :mod:`tensordict.nn``, which modifies the module to make it accept the -# parameters as regular inputs: - -from tensordict.nn import make_functional - -tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5]) - -num_models = 10 -model = TensorDictModule(nn.Linear(3, 4), in_keys=["a"], out_keys=["output"]) -params = make_functional(model) -# we stack two groups of parameters to show the vmap usage: -params = torch.stack([params, params.apply(lambda x: torch.zeros_like(x))], 0) -result_td = torch.vmap(model, (None, 0))(tensordict, params) -print("the output tensordict shape is: ", result_td.shape)