Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Sep 16, 2024
2 parents d9fda4a + 72afb16 commit 367d817
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8566,7 +8566,7 @@ def _convert_to_tensor(
elif isinstance(array, np.bool_):
castable = True
array = array.item()
elif isinstance(array, list):
elif isinstance(array, (list, tuple)):
array = np.asarray(array)
castable = array.dtype.kind in ("i", "f")
elif hasattr(array, "numpy"):
Expand Down
11 changes: 10 additions & 1 deletion tensordict/nn/cudagraphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,16 @@
from tensordict.utils import strtobool
from torch import Tensor

from torch.utils._pytree import SUPPORTED_NODES, tree_leaves, tree_map
from torch.utils._pytree import SUPPORTED_NODES, tree_map

try:
from torch.utils._pytree import tree_leaves
except ImportError:
from torch.utils._pytree import tree_flatten

def tree_leaves(pytree):
"""Torch 2.0 compatible version of tree_leaves."""
return tree_flatten(pytree)[0]


class CudaGraphModule:
Expand Down
1 change: 1 addition & 0 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def __subclasscheck__(self, subclass):
"_get_names_idx", # no wrap output
"_get_str",
"_get_tuple",
"_get_tuple_maybe_non_tensor",
"_has_names",
"_items_list",
"_maybe_names",
Expand Down
7 changes: 6 additions & 1 deletion test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import argparse
import contextlib
import importlib.util
import os
from pathlib import Path
from typing import Any
Expand All @@ -19,6 +20,8 @@

TORCH_VERSION = version.parse(torch.__version__).base_version

_has_onnx = importlib.util.find_spec("onnxruntime", None) is not None


def test_vmap_compile():
# Since we monkey patch vmap we need to make sure compile is happy with it
Expand Down Expand Up @@ -765,13 +768,15 @@ def call(x, td):

class TestExport:
def test_export_module(self):
torch._dynamo.reset_code_caches()
tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"])
x = torch.randn(3)
y = torch.randn(3)
out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
assert (out.module()(x=x, y=y) == tdm(x=x, y=y)).all()

def test_export_seq(self):
torch._dynamo.reset_code_caches()
tdm = Seq(
Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]),
Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]),
Expand All @@ -782,7 +787,7 @@ def test_export_seq(self):
torch.testing.assert_close(out.module()(x=x, y=y), tdm(x=x, y=y))


@pytest.mark.slow
@pytest.mark.skipif(not _has_onnx, reason="ONNX is not available")
class TestONNXExport:
def test_onnx_export_module(self, tmpdir):
tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"])
Expand Down

0 comments on commit 367d817

Please sign in to comment.