[BugFix] Fix params.clone (#509)
vmoens committed Aug 10, 2023
commit 8796784 (parent ad8609c)
Showing 2 changed files with 51 additions and 2 deletions.
30 changes: 29 additions & 1 deletion tensordict/nn/params.py
@@ -440,7 +440,35 @@ def cuda(self, device=None):
         return TensorDictParams(params)

     def clone(self, recurse: bool = True) -> TensorDictBase:
-        return TensorDictParams(self._param_td.clone(recurse=recurse))
+        """Clones the TensorDictParams.
+
+        The effect of this call is different from a regular torch.Tensor.clone call
+        in that it will create a TensorDictParams instance with a new copy of the
+        parameters and buffers __detached__ from the current graph.
+
+        See :meth:`tensordict.TensorDictBase.clone` for more info on the clone
+        method.
+
+        """
+        if not recurse:
+            return TensorDictParams(self._param_td.clone(False), no_convert=True)
+        out = {}
+        for key, val in self._param_td.items(True, True):
+            if isinstance(val, nn.Parameter):
+                out[key] = nn.Parameter(
+                    val.data.clone(), requires_grad=val.requires_grad
+                )
+            else:
+                out[key] = Buffer(val.data.clone(), requires_grad=val.requires_grad)
+        return TensorDictParams(
+            TensorDict(
+                out,
+                batch_size=self._param_td.batch_size,
+                device=self._param_td.device,
+                names=self._param_td.names,
+            ),
+            no_convert=True,
+        )

     @_fallback
     def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]:
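
A minimal sketch of the behavior the new method provides. This example is not part of the commit; it assumes tensordict is installed, uses illustrative names, and relies on TensorDictBase.clone(recurse=False) copying the container structure while sharing the leaf tensors, as its documentation describes.

    import torch
    from torch import nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictParams

    # Wrap a plain TensorDict of parameters ("weight" is an illustrative name).
    params = TensorDictParams(
        TensorDict({"weight": nn.Parameter(torch.zeros(3))}, batch_size=[])
    )

    # Recursive clone: a new TensorDictParams whose leaves are fresh
    # nn.Parameter copies, detached from the original graph.
    deep = params.clone()
    assert isinstance(deep, TensorDictParams)
    assert isinstance(deep.get("weight"), nn.Parameter)
    assert deep.get("weight").requires_grad
    assert deep.get("weight").data_ptr() != params.get("weight").data_ptr()

    # Non-recursive clone: copies the container only; leaf tensors are shared.
    shallow = params.clone(recurse=False)
    assert shallow.get("weight").data_ptr() == params.get("weight").data_ptr()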
23 changes: 22 additions & 1 deletion test/test_nn.py
@@ -28,7 +28,7 @@
 from tensordict.nn.ensemble import EnsembleModule
 from tensordict.nn.functional_modules import is_functional, make_functional
 from tensordict.nn.probabilistic import InteractionType, set_interaction_type
-from tensordict.nn.utils import set_skip_existing, skip_existing
+from tensordict.nn.utils import Buffer, set_skip_existing, skip_existing
 from torch import nn
 from torch.distributions import Normal

@@ -2875,6 +2875,27 @@ def test_td_params_post_hook(self):
         assert not param_td.get("e").requires_grad
         assert not param_td.get(("a", "b", "c")).requires_grad

+    def test_tdparams_clone(self):
+        td = TensorDict(
+            {
+                "a": {
+                    "b": {"c": nn.Parameter(torch.zeros((), requires_grad=True))},
+                    "d": Buffer(torch.zeros((), requires_grad=False)),
+                },
+                "e": nn.Parameter(torch.zeros((), requires_grad=True)),
+                "f": Buffer(torch.zeros((), requires_grad=False)),
+            },
+            [],
+        )
+        td = TensorDictParams(td, no_convert=True)
+        tdclone = td.clone()
+        assert type(tdclone) == type(td)  # noqa
+        for key, val in tdclone.items(True, True):
+            assert type(val) == type(td.get(key))  # noqa
+            assert val.requires_grad == td.get(key).requires_grad
+            assert val.data_ptr() != td.get(key).data_ptr()
+            assert (val == td.get(key)).all()
+

 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
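
A note on the test: items(True, True) is shorthand for items(include_nested=True, leaves_only=True), so the loop visits nested leaves such as ("a", "b", "c") alongside the top-level entries, checking type, requires_grad, storage separation, and value equality for each. Assuming a development checkout with pytest available, the test can be run in isolation with pytest test/test_nn.py -k test_tdparams_clone.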
