Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Sep 25, 2024
1 parent ae8dd4f commit 7549295
Showing 1 changed file with 1 addition and 10 deletions.
11 changes: 1 addition & 10 deletions tensordict/nn/functional_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,6 @@ def _unwrap_batched(
with _exclude_td_from_pytree():
flat_batched_outputs, output_spec = tree_flatten(batched_outputs)

for out in flat_batched_outputs:
# Change here:
if isinstance(out, torch.Tensor) or is_tensor_collection(out):
continue
raise ValueError(
f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
f"Tensors, got type {type(out)} as a return."
)

def incompatible_error():
raise ValueError(
f"vmap({_get_name(func)}, ..., out_dims={out_dims})(<inputs>): "
Expand Down Expand Up @@ -320,7 +311,7 @@ def incompatible_error():
flat_outputs.append(out)
return tree_unflatten(flat_outputs, output_spec)

vmap_src._unwrap_batched = _unwrap_batched
vmap_src._unwrap_batched = _unwrap_batched


# Tensordict-compatible Functional modules
Expand Down

0 comments on commit 7549295

Please sign in to comment.