Skip to content

Commit

Permalink
Update
Browse files: browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Sep 26, 2024
1 parent eb49a70 commit b628d70
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions tensordict/nn/cudagraphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import contextlib
import functools
import os
import warnings
Expand Down Expand Up @@ -168,6 +169,12 @@ def __init__(
if not isinstance(warmup, int) or warmup < 1:
raise ValueError("warmup must be an integer greater than 0.")
self._warmup = warmup
if torch.cuda.is_available():
self._warmup_stream = torch.cuda.Stream()
self._warmup_stream_cm = torch.cuda.Stream(self._warmup_stream)
else:
self._warmup_stream = None
self._warmup_stream_cm = contextlib.nullcontext()

if hasattr(module, "in_keys"):
self.in_keys = module.in_keys
Expand Down Expand Up @@ -202,12 +209,17 @@ def _call(
**kwargs: Any,
) -> Any:
if self.counter < self._warmup:
if tensordict_out is not None:
kwargs["tensordict_out"] = tensordict_out
out = self.module(tensordict, *args, **kwargs)
if self._out_matches_in is None:
self._out_matches_in = out is tensordict
if self._warmup_stream is not None:
self._warmup_stream.wait_stream(torch.cuda.current_stream())
with self._warmup_stream_cm:
if tensordict_out is not None:
kwargs["tensordict_out"] = tensordict_out
out = self.module(tensordict, *args, **kwargs)
if self._out_matches_in is None:
self._out_matches_in = out is tensordict
self.counter += self._has_cuda
if self._warmup_stream is not None:
torch.cuda.current_stream().wait_stream(self._warmup_stream)
return out
elif self.counter == self._warmup:
if tensordict.device is None:
Expand Down Expand Up @@ -273,7 +285,12 @@ def check_tensor_id(name, t0, t1):

def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
if self.counter < self._warmup:
out = self.module(*args, **kwargs)
if self._warmup_stream is not None:
self._warmup_stream.wait_stream(torch.cuda.current_stream())
with self._warmup_stream_cm:
out = self.module(*args, **kwargs)
if self._warmup_stream is not None:
torch.cuda.current_stream().wait_stream(self._warmup_stream)
self.counter += self._has_cuda
return out
elif self.counter == self._warmup:
Expand Down

0 comments on commit b628d70

Please sign in to comment.