diff --git a/test/test_rb.py b/test/test_rb.py
index 03787d9f1f3..4243917c627 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -1232,6 +1232,35 @@ def test_slice_rng(self):
         c = rb.sample(32)
         assert (a != c).any()
 
+    def test_rng_state_dict(self):
+        state = torch.random.get_rng_state()
+        rb = ReplayBufferRNG(sampler=RandomSampler(), storage=LazyTensorStorage(100))
+        rb.extend(torch.arange(100))
+        rb._rng.set_state(state)
+        sd = rb.state_dict()
+        assert sd.get("_rng") is not None
+        a = rb.sample(32)
+
+        rb.load_state_dict(sd)
+        b = rb.sample(32)
+        assert (a == b).all()
+        c = rb.sample(32)
+        assert (a != c).any()
+
+    def test_rng_dumps(self, tmpdir):
+        state = torch.random.get_rng_state()
+        rb = ReplayBufferRNG(sampler=RandomSampler(), storage=LazyTensorStorage(100))
+        rb.extend(torch.arange(100))
+        rb._rng.set_state(state)
+        rb.dumps(tmpdir)
+        a = rb.sample(32)
+
+        rb.loads(tmpdir)
+        b = rb.sample(32)
+        assert (a == b).all()
+        c = rb.sample(32)
+        assert (a != c).any()
+
 
 @pytest.mark.parametrize(
     "rbtype,storage",
diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index b8387f97ddd..7c3a2a950a4 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -23,6 +23,7 @@
     is_tensorclass,
     LazyStackedTensorDict,
     NestedKey,
+    TensorDict,
     TensorDictBase,
     unravel_key,
 )
@@ -269,7 +270,9 @@ def __init__(
             raise ValueError("dim_extend must be a positive value.")
         self.dim_extend = dim_extend
         self._storage.checkpointer = checkpointer
+        self.set_rng(generator=generator)
 
+    def set_rng(self, generator):
         self._rng = generator
         self._storage._rng = generator
         self._sampler._rng = generator
@@ -426,6 +429,9 @@ def state_dict(self) -> Dict[str, Any]:
             "_writer": self._writer.state_dict(),
             "_transforms": self._transform.state_dict(),
             "_batch_size": self._batch_size,
+            "_rng": (self._rng.get_state().clone(), str(self._rng.device))
+            if self._rng is not None
+            else None,
         }
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -434,6 +440,12 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         self._writer.load_state_dict(state_dict["_writer"])
         self._transform.load_state_dict(state_dict["_transforms"])
         self._batch_size = state_dict["_batch_size"]
+        rng = state_dict.get("_rng")
+        if rng is not None:
+            state, device = rng
+            rng = torch.Generator(device=device)
+            rng.set_state(state)
+            self.set_rng(generator=rng)
 
     def dumps(self, path):
         """Saves the replay buffer on disk at the specified path.
@@ -477,6 +489,13 @@ def dumps(self, path):
         self._storage.dumps(path / "storage")
         self._sampler.dumps(path / "sampler")
         self._writer.dumps(path / "writer")
+        if self._rng is not None:
+            rng_state = TensorDict(
+                rng_state=self._rng.get_state().clone(),
+                device=self._rng.device,
+            )
+            rng_state.memmap(path / "rng_state")
+
         # fall back on state_dict for transforms
         transform_sd = self._transform.state_dict()
         if transform_sd:
@@ -499,6 +518,11 @@ def loads(self, path):
         self._storage.loads(path / "storage")
         self._sampler.loads(path / "sampler")
         self._writer.loads(path / "writer")
+        if (path / "rng_state").exists():
+            rng_state = TensorDict.load_memmap(path / "rng_state")
+            rng = torch.Generator(device=rng_state.device)
+            rng.set_state(rng_state["rng_state"])
+            self.set_rng(rng)
         # fall back on state_dict for transforms
         if (path / "transform.t").exists():
             self._transform.load_state_dict(torch.load(path / "transform.t"))
@@ -765,6 +789,12 @@ def __iter__(self):
 
     def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
+        if self._rng is not None:
+            rng_state = TensorDict(
+                rng_state=self._rng.get_state().clone(),
+                device=self._rng.device,
+            )
+            state["_rng"] = rng_state
         _replay_lock = state.pop("_replay_lock", None)
         _futures_lock = state.pop("_futures_lock", None)
         if _replay_lock is not None:
@@ -774,6 +804,12 @@ def __setstate__(self, state: Dict[str, Any]):
         return state
 
     def __setstate__(self, state: Dict[str, Any]):
+        if "_rng" in state:
+            rng_state = state["_rng"]
+            if rng_state is not None:
+                rng = torch.Generator(device=rng_state.device)
+                rng.set_state(rng_state["rng_state"])
+                state["_rng"] = rng
         if "_replay_lock_placeholder" in state:
             state.pop("_replay_lock_placeholder")
             _replay_lock = threading.RLock()
@@ -1482,12 +1518,18 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None:
 
 class InPlaceSampler:
     """A sampler to write tensordicts in-place.
 
+    .. warning:: This class is deprecated and will be removed in v0.7.
+
     To be used cautiously as this may lead to unexpected behaviour (i.e. tensordicts
     overwritten during execution).
     """
 
     def __init__(self, device: DEVICE_TYPING | None = None):
+        warnings.warn(
+            "InPlaceSampler has been deprecated and will be removed in v0.7.",
+            category=DeprecationWarning,
+        )
         self.out = None
         if device is None:
             device = "cpu"
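For reference, a minimal sketch of the round trip these changes enable, mirroring `test_rng_state_dict` above. It assumes the public `ReplayBuffer` constructor accepts the `generator` keyword that `set_rng` is wired to (the context lines of the `__init__` hunk suggest it already does); the seed and buffer sizes here are arbitrary:

```python
import torch

from torchrl.data.replay_buffers import LazyTensorStorage, RandomSampler, ReplayBuffer

# A buffer whose sampler draws from a dedicated generator instead of the
# global torch RNG.
rng = torch.Generator()
rng.manual_seed(0)
rb = ReplayBuffer(
    storage=LazyTensorStorage(100), sampler=RandomSampler(), generator=rng
)
rb.extend(torch.arange(100))

# With this patch, state_dict() also captures the generator state as a
# (state_tensor, device_string) pair under the "_rng" key.
sd = rb.state_dict()
a = rb.sample(32)

# Loading the state_dict rewinds the generator, so sampling replays the
# same indices as right after the checkpoint.
rb.load_state_dict(sd)
b = rb.sample(32)
assert (a == b).all()
```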
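The `dumps`/`loads` path covers the cross-process case: the generator state is persisted as a memory-mapped TensorDict under `<path>/rng_state`, alongside the storage, sampler, and writer checkpoints. A sketch continuing from the buffer built above, mirroring `test_rng_dumps`; the temporary directory is just an example:

```python
import tempfile

# Dump the whole buffer, RNG state included, then advance the sampler.
with tempfile.TemporaryDirectory() as tmpdir:
    rb.dumps(tmpdir)
    a = rb.sample(32)

    # loads() restores the generator to its state at dump time, so the
    # next sample reproduces `a`.
    rb.loads(tmpdir)
    b = rb.sample(32)
    assert (a == b).all()
```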