diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py index 5250bbd4a..fecb4e456 100644 --- a/tensordict/nn/cudagraphs.py +++ b/tensordict/nn/cudagraphs.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +import contextlib import functools import os import warnings @@ -168,6 +169,12 @@ def __init__( if not isinstance(warmup, int) or warmup < 1: raise ValueError("warmup must be an integer greater than 0.") self._warmup = warmup + if torch.cuda.is_available(): + self._warmup_stream = torch.cuda.Stream() + self._warmup_stream_cm = torch.cuda.Stream(self._warmup_stream) + else: + self._warmup_stream = None + self._warmup_stream_cm = contextlib.nullcontext() if hasattr(module, "in_keys"): self.in_keys = module.in_keys @@ -202,12 +209,17 @@ def _call( **kwargs: Any, ) -> Any: if self.counter < self._warmup: - if tensordict_out is not None: - kwargs["tensordict_out"] = tensordict_out - out = self.module(tensordict, *args, **kwargs) - if self._out_matches_in is None: - self._out_matches_in = out is tensordict + if self._warmup_stream is not None: + self._warmup_stream.wait_stream(torch.cuda.current_stream()) + with self._warmup_stream_cm: + if tensordict_out is not None: + kwargs["tensordict_out"] = tensordict_out + out = self.module(tensordict, *args, **kwargs) + if self._out_matches_in is None: + self._out_matches_in = out is tensordict self.counter += self._has_cuda + if self._warmup_stream is not None: + torch.cuda.current_stream().wait_stream(self._warmup_stream) return out elif self.counter == self._warmup: if tensordict.device is None: @@ -273,7 +285,12 @@ def check_tensor_id(name, t0, t1): def _call(*args: torch.Tensor, **kwargs: torch.Tensor): if self.counter < self._warmup: - out = self.module(*args, **kwargs) + if self._warmup_stream is not None: + self._warmup_stream.wait_stream(torch.cuda.current_stream()) + with self._warmup_stream_cm: + out = self.module(*args, **kwargs) + if self._warmup_stream is not None: + torch.cuda.current_stream().wait_stream(self._warmup_stream) self.counter += self._has_cuda return out elif self.counter == self._warmup: