Skip to content

Commit

Permalink
[BugFix] Fix vmap monkey patching
Browse files Browse the repository at this point in the history
ghstack-source-id: 69f4795d5cb81db7b79d9c98626414c4cc5ce886
Pull Request resolved: #1009
  • Loading branch information
vmoens committed Sep 25, 2024
1 parent d9fece7 commit eba0769
Showing 1 changed file with 0 additions and 9 deletions.
9 changes: 0 additions & 9 deletions tensordict/nn/functional_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,6 @@ def _unwrap_batched(
with _exclude_td_from_pytree():
flat_batched_outputs, output_spec = tree_flatten(batched_outputs)

for out in flat_batched_outputs:
# Change here:
if isinstance(out, torch.Tensor) or is_tensor_collection(out):
continue
raise ValueError(
f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
f"Tensors, got type {type(out)} as a return."
)

def incompatible_error():
raise ValueError(
f"vmap({_get_name(func)}, ..., out_dims={out_dims})(<inputs>): "
Expand Down

0 comments on commit eba0769

Please sign in to comment.