Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Cudagraphs #986

Merged
merged 24 commits
Sep 16, 2024
16 changes: 8 additions & 8 deletions benchmarks/compile/compile_td_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def get_flat_tc():


# Tests runtime of a simple arithmetic op over a highly nested tensordict
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
def test_compile_add_one_nested(mode, dict_type, benchmark):
Expand All @@ -128,7 +128,7 @@ def test_compile_add_one_nested(mode, dict_type, benchmark):


# Tests the speed of copying a nested tensordict
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
def test_compile_copy_nested(mode, dict_type, benchmark):
Expand All @@ -150,7 +150,7 @@ def test_compile_copy_nested(mode, dict_type, benchmark):


# Tests runtime of a simple arithmetic op over a flat tensordict
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
def test_compile_add_one_flat(mode, dict_type, benchmark):
Expand All @@ -177,7 +177,7 @@ def test_compile_add_one_flat(mode, dict_type, benchmark):
benchmark(func, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["eager", "compile"])
@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
def test_compile_add_self_flat(mode, dict_type, benchmark):
Expand Down Expand Up @@ -207,7 +207,7 @@ def test_compile_add_self_flat(mode, dict_type, benchmark):


# Tests the speed of copying a flat tensordict
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
def test_compile_copy_flat(mode, dict_type, benchmark):
Expand Down Expand Up @@ -235,7 +235,7 @@ def test_compile_copy_flat(mode, dict_type, benchmark):


# Tests the speed of assigning entries to an empty tensordict
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
def test_compile_assign_and_add(mode, dict_type, benchmark):
Expand Down Expand Up @@ -264,7 +264,7 @@ def test_compile_assign_and_add(mode, dict_type, benchmark):
# Tests the speed of assigning entries to a lazy stacked tensordict


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.skipif(
torch.cuda.is_available(), reason="max recursion depth error with cuda"
)
Expand All @@ -285,7 +285,7 @@ def test_compile_assign_and_add_stack(mode, benchmark):


# Tests indexing speed
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
@pytest.mark.parametrize("index_type", ["tensor", "slice", "int"])
Expand Down
18 changes: 9 additions & 9 deletions benchmarks/compile/tensordict_nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def mlp(device, depth=2, num_cells=32, feature_dim=3):
)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
def test_mod_add(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -64,7 +64,7 @@ def test_mod_add(mode, benchmark):
benchmark(module, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
def test_mod_wrap(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -80,7 +80,7 @@ def test_mod_wrap(mode, benchmark):
benchmark(module, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
def test_mod_wrap_and_backward(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -104,7 +104,7 @@ def module_exec(td):
benchmark(module_exec, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
def test_seq_add(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -129,7 +129,7 @@ def delhidden(td):
benchmark(module, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
def test_seq_wrap(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -161,7 +161,7 @@ def delhidden(td):
benchmark(module, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.slow
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
def test_seq_wrap_and_backward(mode, benchmark):
Expand Down Expand Up @@ -201,7 +201,7 @@ def module_exec(td):
benchmark(module_exec, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
@pytest.mark.parametrize("functional", [False, True])
def test_func_call_runtime(mode, functional, benchmark):
Expand Down Expand Up @@ -272,7 +272,7 @@ def call(x, td):
benchmark(call, x)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.slow
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -354,7 +354,7 @@ def call(x, td):
benchmark(call_vmap, x, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.slow
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
@pytest.mark.parametrize("plain_decorator", [None, False, True])
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ to build distributions from network outputs and get summary statistics or sample
ProbabilisticTensorDictModule
TensorDictSequential
TensorDictModuleWrapper
CudaGraphModule

Functional
----------
Expand Down
2 changes: 2 additions & 0 deletions tensordict/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,5 @@
set_skip_existing,
skip_existing,
)

from .cudagraphs import CudaGraphModule
36 changes: 31 additions & 5 deletions tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,28 +240,44 @@ def __init__(
self.auto_batch_size = auto_batch_size

def __call__(self, func: Callable) -> Callable:

is_method = inspect.ismethod(func) or (
inspect.isfunction(func)
and func.__code__.co_argcount > 0
and func.__code__.co_varnames[0] == "self"
)
# sanity check
for i, key in enumerate(inspect.signature(func).parameters):
if i == 0:
if (is_method or (key == "self")) and (i == 0):
is_method = True
# skip self
continue
if key != "tensordict":
raise RuntimeError(
"the first argument of the wrapped function must be "
"named 'tensordict'."
f"named 'tensordict'. Got {key} instead."
)
break
# if the env variable was used, we can skip the wrapper altogether
if not _dispatch_td_nn_modules():
return func

@functools.wraps(func)
def wrapper(_self, *args: Any, **kwargs: Any) -> Any:
def wrapper(*args: Any, **kwargs: Any) -> Any:
if is_method:
_self = args[0]
args = args[1:]
else:
_self = None
if not _dispatch_td_nn_modules():
return func(_self, *args, **kwargs)

source = self.source
if isinstance(source, str):
if _self is None:
raise RuntimeError(
"The in keys must be passed to dispatch when func is not a method but a function."
)
source = getattr(_self, source)
tensordict = None
if len(args):
Expand All @@ -271,6 +287,10 @@ def wrapper(_self, *args: Any, **kwargs: Any) -> Any:
tensordict_values = {}
dest = self.dest
if isinstance(dest, str):
if _self is None:
raise RuntimeError(
"The in keys must be passed to dispatch when func is not a method but a function."
)
dest = getattr(_self, dest)
for key in source:
expected_key = self.separator.join(_unravel_key_to_tuple(key))
Expand All @@ -293,10 +313,16 @@ def wrapper(_self, *args: Any, **kwargs: Any) -> Any:
tensordict_values,
batch_size=torch.Size([]) if not self.auto_batch_size else None,
)
out = func(_self, tensordict, *args, **kwargs)
if _self is not None:
out = func(_self, tensordict, *args, **kwargs)
else:
out = func(tensordict, *args, **kwargs)

out = tuple(out[key] for key in dest)
return out[0] if len(out) == 1 else out
return func(_self, tensordict, *args, **kwargs)
if _self is not None:
return func(_self, tensordict, *args, **kwargs)
return func(tensordict, *args, **kwargs)

return wrapper

Expand Down
Loading
Loading