Skip to content

Commit

Permalink
[Feature,Doc] get_stateful_net and document MARL initialization
Browse files Browse the repository at this point in the history
ghstack-source-id: 96f3be31ee968ab4431d0717f6168564bb598391
Pull Request resolved: #2309
  • Loading branch information
vmoens committed Jul 23, 2024
1 parent 59c3374 commit b5ddef3
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 7 deletions.
9 changes: 7 additions & 2 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -355,13 +355,18 @@ algorithms, such as DQN, DDPG or Dreamer.
Multi-agent-specific modules
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

These networks implement models that can be used in
multi-agent contexts.
These networks implement models that can be used in multi-agent contexts.
They use :func:`~torch.vmap` to execute multiple networks all at once on the
network inputs. Because the parameters are batched, initialization may differ
from what is usually done with other PyTorch modules, see
:meth:`~torchrl.modules.MultiAgentNetBase.get_stateful_net`
for more information.

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

MultiAgentNetBase
MultiAgentMLP
MultiAgentConvNet
QMixer
Expand Down
52 changes: 52 additions & 0 deletions test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,58 @@ def test_multiagent_mlp(
agent_dim={-2}\)"""
assert re.match(pattern, str(mlp), re.DOTALL)

@retry(AssertionError, 5)
@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralized", [True, False])
@pytest.mark.parametrize("n_agent_inputs", [6, None])
@pytest.mark.parametrize("batch", [(4,), (4, 3), ()])
def test_multiagent_mlp_init(
self,
n_agents,
centralized,
share_params,
batch,
n_agent_inputs,
n_agent_outputs=2,
):
torch.manual_seed(1)
mlp = MultiAgentMLP(
n_agent_inputs=n_agent_inputs,
n_agent_outputs=n_agent_outputs,
n_agents=n_agents,
centralized=centralized,
share_params=share_params,
depth=2,
)
if n_agent_inputs is None:
n_agent_inputs = 6
td = self._get_mock_input_td(n_agents, n_agent_inputs, batch=batch)
obs = td.get(("agents", "observation"))
mlp(obs)
snet = mlp.get_stateful_net()
assert snet is not mlp._empty_net

def zero_inplace(mod):
if hasattr(mod, "weight"):
mod.weight.data *= 0
if hasattr(mod, "bias"):
mod.bias.data *= 0

snet.apply(zero_inplace)
assert (mlp.params == 0).all()

def one_outofplace(mod):
if hasattr(mod, "weight"):
mod.weight = nn.Parameter(torch.ones_like(mod.weight.data))
if hasattr(mod, "bias"):
mod.bias = nn.Parameter(torch.ones_like(mod.bias.data))

snet.apply(one_outofplace)
assert (mlp.params == 0).all()
mlp.from_stateful_net(snet)
assert (mlp.params == 1).all()

def test_multiagent_mlp_lazy(self):
mlp = MultiAgentMLP(
n_agent_inputs=None,
Expand Down
1 change: 1 addition & 0 deletions torchrl/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
MLP,
MultiAgentConvNet,
MultiAgentMLP,
MultiAgentNetBase,
NoisyLazyLinear,
NoisyLinear,
ObsDecoder,
Expand Down
8 changes: 7 additions & 1 deletion torchrl/modules/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,11 @@
MLP,
OnlineDTActor,
)
from .multiagent import MultiAgentConvNet, MultiAgentMLP, QMixer, VDNMixer
from .multiagent import (
MultiAgentConvNet,
MultiAgentMLP,
MultiAgentNetBase,
QMixer,
VDNMixer,
)
from .utils import Squeeze2dLayer, SqueezeLayer
99 changes: 95 additions & 4 deletions torchrl/modules/models/multiagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import abc
from copy import deepcopy
from textwrap import indent
from typing import Optional, Sequence, Tuple, Type, Union

Expand All @@ -21,7 +22,13 @@


class MultiAgentNetBase(nn.Module):
"""A base class for multi-agent networks."""
"""A base class for multi-agent networks.
.. note:: to initialize the MARL module parameters with the `torch.nn.init`
module, please refer to :meth:`~.get_stateful_net` and :meth:`~.from_stateful_net`
methods.
"""

_empty_net: nn.Module

Expand Down Expand Up @@ -142,6 +149,82 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:

return output

def get_stateful_net(self, copy: bool = True):
"""Returns a stateful version of the network.
This can be used to initialize parameters.
Such networks will generally not be callable out-of-the-box and will require some `vmap`
execution. to work
Args:
copy (bool, optional): if ``True``, a deepcopy of the network is made.
Defaults to ``True``.
If the parameters are modified in-place (recommended) there is no need to copy the
parameters back into the MARL module.
See :meth:`~.from_stateful_net` for details on how to re-populate the MARL model with
parameters that have been re-initialized out-of-place.
Examples:
>>> from torchrl.modules import MultiAgentMLP
>>> import torch
>>> n_agents = 6
>>> n_agent_inputs=3
>>> n_agent_outputs=2
>>> batch = 64
>>> obs = torch.zeros(batch, n_agents, n_agent_inputs)
>>> mlp = MultiAgentMLP(
... n_agent_inputs=n_agent_inputs,
... n_agent_outputs=n_agent_outputs,
... n_agents=n_agents,
... centralized=False,
... share_params=False,
... depth=2,
... )
>>> snet = mlp.get_stateful_net()
>>> def init(module):
... if hasattr(module, "weight"):
... torch.nn.init.kaiming_normal_(module.weight)
>>> snet.apply(init)
>>> # If the module has been updated out-of-place (not the case here) we can reset the params
>>> mlp.from_stateful_net(snet)
"""
if copy:
try:
net = deepcopy(self._empty_net)
except RuntimeError as err:
raise RuntimeError(
"Failed to deepcopy the module, consider using copy=False."
) from err
else:
net = self._empty_net
self.params.to_module(net)
return net

@abc.abstractmethod
def from_stateful_net(self, stateful_net: nn.Module):
"""Populates the parameters given a stateful version of the network.
See :meth:`~.get_stateful_net` for details on how to gather a stateful version of the network.
Args:
stateful_net (nn.Module): the stateful network from which the params should be
gathered.
"""
params = TensorDict.from_module(stateful_net, as_module=True)
keyset0 = set(params.keys(True, True))
keyset1 = set(self.params.keys(True, True))
if keyset0 != keyset1:
raise RuntimeError(
f"The keys of params and provided module differ: "
f"{keyset1-keyset0} are in self.params and not in the module, "
f"{keyset0-keyset1} are in the module but not in self.params."
)
self.params.data.update_(params.data)

def __repr__(self):
empty_net = self.__dict__["_empty_net"]
with self.params.to_module(empty_net):
Expand Down Expand Up @@ -212,15 +295,19 @@ class MultiAgentMLP(MultiAgentNetBase):
default: nn.Tanh.
**kwargs: for :class:`torchrl.modules.models.MLP` can be passed to customize the MLPs.
.. note:: to initialize the MARL module parameters with the `torch.nn.init`
module, please refer to :meth:`~.get_stateful_net` and :meth:`~.from_stateful_net`
methods.
Examples:
>>> from torchrl.modules import MultiAgentMLP
>>> import torch
>>> n_agents = 6
>>> n_agent_inputs=3
>>> n_agent_outputs=2
>>> batch = 64
>>> obs = torch.zeros(batch, n_agents, n_agent_inputs
First let's instantiate a local network shared by all agents (e.g. a parameter-shared policy)
>>> obs = torch.zeros(batch, n_agents, n_agent_inputs)
>>> # instantiate a local network shared by all agents (e.g. a parameter-shared policy)
>>> mlp = MultiAgentMLP(
... n_agent_inputs=n_agent_inputs,
... n_agent_outputs=n_agent_outputs,
Expand Down Expand Up @@ -357,6 +444,10 @@ class MultiAgentConvNet(MultiAgentNetBase):
It expects inputs with shape ``(*B, n_agents, channels, x, y)``.
.. note:: to initialize the MARL module parameters with the `torch.nn.init`
module, please refer to :meth:`~.get_stateful_net` and :meth:`~.from_stateful_net`
methods.
Args:
n_agents (int): number of agents.
centralized (bool): If ``True``, each agent will use the inputs of all agents to compute its output, resulting in input of shape ``(*B, n_agents * channels, x, y)``. Otherwise, each agent will only use its data as input.
Expand Down Expand Up @@ -388,7 +479,7 @@ class MultiAgentConvNet(MultiAgentNetBase):
>>> n_agents = 7
>>> channels, x, y = 3, 100, 100
>>> obs = torch.randn(*batch, n_agents, channels, x, y)
>>> # First lets consider a centralized network with shared parameters.
>>> # Let's consider a centralized network with shared parameters.
>>> cnn = MultiAgentConvNet(
... n_agents,
... centralized = True,
Expand Down

0 comments on commit b5ddef3

Please sign in to comment.