Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jul 25, 2024
1 parent c6ef080 commit 6268443
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion torchrl/modules/models/multiagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,23 @@ def __init__(
break
self.initialized = initialized
self._make_params(agent_networks)
# We make sure all params and buffers are on 'meta' device
# To do this, we set the device keyword arg to 'meta', we also temporarily change
# the default device. Finally, we convert all params to 'meta' tensors that are not params.
kwargs["device"] = "meta"
self.__dict__["_empty_net"] = self._build_single_net(**kwargs)
with torch.device("meta"):
try:
self._empty_net = self._build_single_net(**kwargs)
except NotImplementedError as err:
if "Cannot copy out of meta tensor" in str(err):
raise RuntimeError(
"The network was built using `factory().to(device), build the network directly "
"on device using `factory(device=device)` instead."
)
# Remove all parameters
TensorDict.from_module(self._empty_net).data.to("meta").to_module(
self._empty_net
)

@property
def vmap_randomness(self):
Expand Down

0 comments on commit 6268443

Please sign in to comment.