diff --git a/tensordict/nn/functional_modules.py b/tensordict/nn/functional_modules.py index e9d570b1a..bed1ba00f 100644 --- a/tensordict/nn/functional_modules.py +++ b/tensordict/nn/functional_modules.py @@ -311,7 +311,7 @@ def incompatible_error(): flat_outputs.append(out) return tree_unflatten(flat_outputs, output_spec) - vmap_src._unwrap_batched = _unwrap_batched + vmap_src._unwrap_batched = _unwrap_batched # Tensordict-compatible Functional modules