Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Sep 17, 2024
1 parent 5997674 commit eeed744
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions tensordict/_torch_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,11 +419,11 @@ def _stack(
if not len(list_of_tensordicts):
raise RuntimeError("list_of_tensordicts cannot be empty")
is_tc = any(is_tensorclass(td) for td in list_of_tensordicts)
if all(is_non_tensor(td) for td in list_of_tensordicts):
from tensordict.tensorclass import NonTensorData
if is_tc:
if all(is_non_tensor(td) for td in list_of_tensordicts):
from tensordict.tensorclass import NonTensorData

return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim)
elif is_tc:
return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim)
tc_type = type(list_of_tensordicts[0])
list_of_tensordicts = [tc._tensordict for tc in list_of_tensordicts]

Expand Down
8 changes: 4 additions & 4 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8559,13 +8559,13 @@ def _convert_to_tensor(
castable = None
if isinstance(array, (float, int, bool)):
castable = True
elif isinstance(array, np.ndarray) and array.dtype.names is not None:
return TensorDictBase.from_struct_array(array, device=self.device)
elif isinstance(array, np.ndarray):
castable = array.dtype.kind in ("c", "i", "f", "b", "u")
elif isinstance(array, np.bool_):
castable = True
array = array.item()
elif isinstance(array, (np.ndarray, np.number)):
if array.dtype.names is not None:
return TensorDictBase.from_struct_array(array, device=self.device)
castable = array.dtype.kind in ("c", "i", "f", "b", "u")
elif isinstance(array, (list, tuple)):
array = np.asarray(array)
castable = array.dtype.kind in ("c", "i", "f", "b", "u")
Expand Down

0 comments on commit eeed744

Please sign in to comment.