Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Sep 17, 2024
2 parents 367d817 + ea42301 commit 22f1b52
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8562,17 +8562,17 @@ def _convert_to_tensor(
elif isinstance(array, np.ndarray) and array.dtype.names is not None:
return TensorDictBase.from_struct_array(array, device=self.device)
elif isinstance(array, np.ndarray):
castable = array.dtype.kind in ("i", "f")
castable = array.dtype.kind in ("c", "i", "f", "b", "u")
elif isinstance(array, np.bool_):
castable = True
array = array.item()
elif isinstance(array, (list, tuple)):
array = np.asarray(array)
castable = array.dtype.kind in ("i", "f")
castable = array.dtype.kind in ("c", "i", "f", "b", "u")
elif hasattr(array, "numpy"):
# tf.Tensor with no shape can't be converted otherwise
array = array.numpy()
castable = array.dtype.kind in ("i", "f")
castable = array.dtype.kind in ("c", "i", "f", "b", "u")
if castable:
return torch.as_tensor(array, device=self.device)
else:
Expand Down
3 changes: 3 additions & 0 deletions test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,7 @@ def remove_hidden(td):
assert module_compile(td) is not td

def test_dispatch_nontensor(self, mode):
torch._dynamo.reset_code_caches()

# Non tensor
x = torch.randn(3)
Expand All @@ -624,6 +625,8 @@ def test_dispatch_nontensor(self, mode):
torch.testing.assert_close(mod(x=x, y=y), mod_compile(x=x, y=y))

def test_dispatch_tensor(self, mode):
torch._dynamo.reset_code_caches()

x = torch.randn(3)
y = torch.randn(3)
mod = Seq(
Expand Down

0 comments on commit 22f1b52

Please sign in to comment.