diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index d04256008..ad46217bb 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -1915,7 +1915,10 @@ def update_( # ) if clone: value = value.clone() - self.set_(key, value) + if is_tensor_collection(value) or isinstance(value, dict): + self._get_str(key, default=NO_DEFAULT).update_(value) + else: + self.set_(key, value) return self def update_at_( @@ -7885,7 +7888,10 @@ def update_( ) if clone: value = value.clone() - self.set_(key, value, **kwargs) + if is_tensor_collection(value) or isinstance(value, dict): + self._get_str(key, default=NO_DEFAULT).update_(value) + else: + self.set_(key, value, **kwargs) return self def update_at_(