diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 170251a81fa..c73ed5083fd 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -355,13 +355,18 @@ algorithms, such as DQN, DDPG or Dreamer. Multi-agent-specific modules ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -These networks implement models that can be used in -multi-agent contexts. +These networks implement models that can be used in multi-agent contexts. +They use :func:`~torch.vmap` to execute multiple networks all at once on the +network inputs. Because the parameters are batched, initialization may differ +from what is usually done with other PyTorch modules, see +:meth:`~torchrl.modules.MultiAgentNetBase.get_stateful_net` +for more information. .. autosummary:: :toctree: generated/ :template: rl_template_noinherit.rst + MultiAgentNetBase MultiAgentMLP MultiAgentConvNet QMixer diff --git a/test/test_modules.py b/test/test_modules.py index 592464f0a96..feff5ea6819 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -839,6 +839,58 @@ def test_multiagent_mlp( agent_dim={-2}\)""" assert re.match(pattern, str(mlp), re.DOTALL) + @retry(AssertionError, 5) + @pytest.mark.parametrize("n_agents", [1, 3]) + @pytest.mark.parametrize("share_params", [True, False]) + @pytest.mark.parametrize("centralized", [True, False]) + @pytest.mark.parametrize("n_agent_inputs", [6, None]) + @pytest.mark.parametrize("batch", [(4,), (4, 3), ()]) + def test_multiagent_mlp_init( + self, + n_agents, + centralized, + share_params, + batch, + n_agent_inputs, + n_agent_outputs=2, + ): + torch.manual_seed(1) + mlp = MultiAgentMLP( + n_agent_inputs=n_agent_inputs, + n_agent_outputs=n_agent_outputs, + n_agents=n_agents, + centralized=centralized, + share_params=share_params, + depth=2, + ) + if n_agent_inputs is None: + n_agent_inputs = 6 + td = self._get_mock_input_td(n_agents, n_agent_inputs, batch=batch) + obs = td.get(("agents", "observation")) + mlp(obs) + snet = 
mlp.get_stateful_net() + assert snet is not mlp._empty_net + + def zero_inplace(mod): + if hasattr(mod, "weight"): + mod.weight.data *= 0 + if hasattr(mod, "bias"): + mod.bias.data *= 0 + + snet.apply(zero_inplace) + assert (mlp.params == 0).all() + + def one_outofplace(mod): + if hasattr(mod, "weight"): + mod.weight = nn.Parameter(torch.ones_like(mod.weight.data)) + if hasattr(mod, "bias"): + mod.bias = nn.Parameter(torch.ones_like(mod.bias.data)) + + snet.apply(one_outofplace) + assert (mlp.params == 0).all() + mlp.from_stateful_net(snet) + assert (mlp.params == 1).all() + def test_multiagent_mlp_lazy(self): mlp = MultiAgentMLP( n_agent_inputs=None, diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 944210386e9..0c6505602f3 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -32,6 +32,7 @@ MLP, MultiAgentConvNet, MultiAgentMLP, + MultiAgentNetBase, NoisyLazyLinear, NoisyLinear, ObsDecoder, diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index 62ccf53c30a..9a814e35477 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -32,5 +32,11 @@ MLP, OnlineDTActor, ) -from .multiagent import MultiAgentConvNet, MultiAgentMLP, QMixer, VDNMixer +from .multiagent import ( + MultiAgentConvNet, + MultiAgentMLP, + MultiAgentNetBase, + QMixer, + VDNMixer, +) from .utils import Squeeze2dLayer, SqueezeLayer diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index c44042388a5..a48ad5b634b 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -5,6 +5,7 @@ from __future__ import annotations import abc +from copy import deepcopy from textwrap import indent from typing import Optional, Sequence, Tuple, Type, Union @@ -21,7 +22,13 @@ class MultiAgentNetBase(nn.Module): - """A base class for multi-agent networks.""" + """A base class for multi-agent networks. + + .. 
note:: to initialize the MARL module parameters with the `torch.nn.init` + module, please refer to :meth:`~.get_stateful_net` and :meth:`~.from_stateful_net` + methods. + + """ _empty_net: nn.Module @@ -142,6 +149,82 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor: return output + def get_stateful_net(self, copy: bool = True): + """Returns a stateful version of the network. + + This can be used to initialize parameters. + + Such networks will generally not be callable out-of-the-box and will require some `vmap` + execution to work. + + Args: + copy (bool, optional): if ``True``, a deepcopy of the network is made. + Defaults to ``True``. + + If the parameters are modified in-place (recommended) there is no need to copy the + parameters back into the MARL module. + See :meth:`~.from_stateful_net` for details on how to re-populate the MARL model with + parameters that have been re-initialized out-of-place. + + Examples: + >>> from torchrl.modules import MultiAgentMLP + >>> import torch + >>> n_agents = 6 + >>> n_agent_inputs=3 + >>> n_agent_outputs=2 + >>> batch = 64 + >>> obs = torch.zeros(batch, n_agents, n_agent_inputs) + >>> mlp = MultiAgentMLP( + ... n_agent_inputs=n_agent_inputs, + ... n_agent_outputs=n_agent_outputs, + ... n_agents=n_agents, + ... centralized=False, + ... share_params=False, + ... depth=2, + ... ) + >>> snet = mlp.get_stateful_net() + >>> def init(module): + ... if hasattr(module, "weight"): + ... torch.nn.init.kaiming_normal_(module.weight) + >>> snet.apply(init) + >>> # If the module has been updated out-of-place (not the case here) we can reset the params + >>> mlp.from_stateful_net(snet) + + """ + if copy: + try: + net = deepcopy(self._empty_net) + except RuntimeError as err: + raise RuntimeError( + "Failed to deepcopy the module, consider using copy=False." 
+ ) from err + else: + net = self._empty_net + self.params.to_module(net) + return net + + @abc.abstractmethod + def from_stateful_net(self, stateful_net: nn.Module): + """Populates the parameters given a stateful version of the network. + + See :meth:`~.get_stateful_net` for details on how to gather a stateful version of the network. + + Args: + stateful_net (nn.Module): the stateful network from which the params should be + gathered. + + """ + params = TensorDict.from_module(stateful_net, as_module=True) + keyset0 = set(params.keys(True, True)) + keyset1 = set(self.params.keys(True, True)) + if keyset0 != keyset1: + raise RuntimeError( + f"The keys of params and provided module differ: " + f"{keyset1-keyset0} are in self.params and not in the module, " + f"{keyset0-keyset1} are in the module but not in self.params." + ) + self.params.data.update_(params.data) + def __repr__(self): empty_net = self.__dict__["_empty_net"] with self.params.to_module(empty_net): @@ -212,6 +295,10 @@ class MultiAgentMLP(MultiAgentNetBase): default: nn.Tanh. **kwargs: for :class:`torchrl.modules.models.MLP` can be passed to customize the MLPs. + .. note:: to initialize the MARL module parameters with the `torch.nn.init` + module, please refer to :meth:`~.get_stateful_net` and :meth:`~.from_stateful_net` + methods. + Examples: >>> from torchrl.modules import MultiAgentMLP >>> import torch @@ -219,8 +306,8 @@ class MultiAgentMLP(MultiAgentNetBase): >>> n_agent_inputs=3 >>> n_agent_outputs=2 >>> batch = 64 - >>> obs = torch.zeros(batch, n_agents, n_agent_inputs - First let's instantiate a local network shared by all agents (e.g. a parameter-shared policy) + >>> obs = torch.zeros(batch, n_agents, n_agent_inputs) + >>> # instantiate a local network shared by all agents (e.g. a parameter-shared policy) >>> mlp = MultiAgentMLP( ... n_agent_inputs=n_agent_inputs, ... 
n_agent_outputs=n_agent_outputs, @@ -357,6 +444,10 @@ class MultiAgentConvNet(MultiAgentNetBase): It expects inputs with shape ``(*B, n_agents, channels, x, y)``. + .. note:: to initialize the MARL module parameters with the `torch.nn.init` + module, please refer to :meth:`~.get_stateful_net` and :meth:`~.from_stateful_net` + methods. + Args: n_agents (int): number of agents. centralized (bool): If ``True``, each agent will use the inputs of all agents to compute its output, resulting in input of shape ``(*B, n_agents * channels, x, y)``. Otherwise, each agent will only use its data as input. @@ -388,7 +479,7 @@ class MultiAgentConvNet(MultiAgentNetBase): >>> n_agents = 7 >>> channels, x, y = 3, 100, 100 >>> obs = torch.randn(*batch, n_agents, channels, x, y) - >>> # First lets consider a centralized network with shared parameters. + >>> # Let's consider a centralized network with shared parameters. >>> cnn = MultiAgentConvNet( ... n_agents, ... centralized = True,