diff --git a/tensordict/base.py b/tensordict/base.py index f4f390487..f1493e1af 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -8562,17 +8562,17 @@ def _convert_to_tensor( elif isinstance(array, np.ndarray) and array.dtype.names is not None: return TensorDictBase.from_struct_array(array, device=self.device) elif isinstance(array, np.ndarray): - castable = array.dtype.kind in ("i", "f") + castable = array.dtype.kind in ("c", "i", "f", "b", "u") elif isinstance(array, np.bool_): castable = True array = array.item() elif isinstance(array, (list, tuple)): array = np.asarray(array) - castable = array.dtype.kind in ("i", "f") + castable = array.dtype.kind in ("c", "i", "f", "b", "u") elif hasattr(array, "numpy"): # tf.Tensor with no shape can't be converted otherwise array = array.numpy() - castable = array.dtype.kind in ("i", "f") + castable = array.dtype.kind in ("c", "i", "f", "b", "u") if castable: return torch.as_tensor(array, device=self.device) else: diff --git a/test/test_compile.py b/test/test_compile.py index 83edb746f..44d92822a 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -611,6 +611,7 @@ def remove_hidden(td): assert module_compile(td) is not td def test_dispatch_nontensor(self, mode): + torch._dynamo.reset_code_caches() # Non tensor x = torch.randn(3) @@ -624,6 +625,8 @@ def test_dispatch_nontensor(self, mode): torch.testing.assert_close(mod(x=x, y=y), mod_compile(x=x, y=y)) def test_dispatch_tensor(self, mode): + torch._dynamo.reset_code_caches() + x = torch.randn(3) y = torch.randn(3) mod = Seq(