From 879678435e52d2610e4bb820cb4878cd01caec7e Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 10 Aug 2023 08:55:34 +0200
Subject: [PATCH] [BugFix] Fix params.clone (#509)

---
 tensordict/nn/params.py | 30 +++++++++++++++++++++++++++++-
 test/test_nn.py         | 23 ++++++++++++++++++++++-
 2 files changed, 51 insertions(+), 2 deletions(-)

diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py
index eb4e55980..d8378482f 100644
--- a/tensordict/nn/params.py
+++ b/tensordict/nn/params.py
@@ -440,7 +440,35 @@ def cuda(self, device=None):
         return TensorDictParams(params)
 
     def clone(self, recurse: bool = True) -> TensorDictBase:
-        return TensorDictParams(self._param_td.clone(recurse=recurse))
+        """Clones the TensorDictParams.
+
+        The effect of this call is different from a regular torch.Tensor.clone call
+        in that it will create a TensorDictParams instance with a new copy of the
+        parameters and buffers __detached__ from the current graph.
+
+        See :meth:`tensordict.TensorDictBase.clone` for more info on the clone
+        method.
+
+        """
+        if not recurse:
+            return TensorDictParams(self._param_td.clone(False), no_convert=True)
+        out = {}
+        for key, val in self._param_td.items(True, True):
+            if isinstance(val, nn.Parameter):
+                out[key] = nn.Parameter(
+                    val.data.clone(), requires_grad=val.requires_grad
+                )
+            else:
+                out[key] = Buffer(val.data.clone(), requires_grad=val.requires_grad)
+        return TensorDictParams(
+            TensorDict(
+                out,
+                batch_size=self._param_td.batch_size,
+                device=self._param_td.device,
+                names=self._param_td.names,
+            ),
+            no_convert=True,
+        )
 
     @_fallback
     def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]:
diff --git a/test/test_nn.py b/test/test_nn.py
index 198a038ee..45b4a7aee 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -28,7 +28,7 @@
 from tensordict.nn.ensemble import EnsembleModule
 from tensordict.nn.functional_modules import is_functional, make_functional
 from tensordict.nn.probabilistic import InteractionType, set_interaction_type
-from tensordict.nn.utils import set_skip_existing, skip_existing
+from tensordict.nn.utils import Buffer, set_skip_existing, skip_existing
 from torch import nn
 from torch.distributions import Normal
 
@@ -2875,6 +2875,27 @@ def test_td_params_post_hook(self):
         assert not param_td.get("e").requires_grad
         assert not param_td.get(("a", "b", "c")).requires_grad
 
+    def test_tdparams_clone(self):
+        td = TensorDict(
+            {
+                "a": {
+                    "b": {"c": nn.Parameter(torch.zeros((), requires_grad=True))},
+                    "d": Buffer(torch.zeros((), requires_grad=False)),
+                },
+                "e": nn.Parameter(torch.zeros((), requires_grad=True)),
+                "f": Buffer(torch.zeros((), requires_grad=False)),
+            },
+            [],
+        )
+        td = TensorDictParams(td, no_convert=True)
+        tdclone = td.clone()
+        assert type(tdclone) == type(td)  # noqa
+        for key, val in tdclone.items(True, True):
+            assert type(val) == type(td.get(key))  # noqa
+            assert val.requires_grad == td.get(key).requires_grad
+            assert val.data_ptr() != td.get(key).data_ptr()
+            assert (val == td.get(key)).all()
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
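
Note (not part of the patch): a minimal usage sketch of the fixed clone behaviour,
mirroring the assertions in test_tdparams_clone above. It assumes a tensordict build
that includes this change and imports TensorDictParams from tensordict.nn.params,
the module touched by the diff.

    import torch
    from torch import nn

    from tensordict import TensorDict
    from tensordict.nn.params import TensorDictParams

    params = TensorDictParams(
        TensorDict({"weight": nn.Parameter(torch.randn(3, 3))}, []),
        no_convert=True,
    )
    clone = params.clone()
    # The clone keeps the nn.Parameter type and the requires_grad flag, but its
    # storage is a fresh, detached copy rather than a view of the original data.
    assert isinstance(clone.get("weight"), nn.Parameter)
    assert clone.get("weight").requires_grad
    assert clone.get("weight").data_ptr() != params.get("weight").data_ptr()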