diff --git a/test/test_modules.py b/test/test_modules.py index feff5ea6819..11cf11f41e6 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -863,6 +863,13 @@ def test_multiagent_mlp_init( share_params=share_params, depth=2, ) + for m in mlp.modules(): + if isinstance(m, nn.Linear): + assert not isinstance(m.weight, nn.Parameter) + assert m.weight.device == torch.device("meta") + break + else: + raise RuntimeError("could not find a Linear module") if n_agent_inputs is None: n_agent_inputs = 6 td = self._get_mock_input_td(n_agents, n_agent_inputs, batch=batch)