diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index 7c3a2a950a4..fafc120fe94 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -804,12 +804,13 @@ def __getstate__(self) -> Dict[str, Any]:
         return state

     def __setstate__(self, state: Dict[str, Any]):
+        rngstate = None
         if "_rng" in state:
-            rng = state["_rng"]
-            if rng is not None:
-                rng = torch.Generator(device=rng.device)
-                rng.set_state(rng["rng_state"])
-            state["_rng"] = rng
+            rngstate = state["_rng"]
+            if rngstate is not None:
+                rng = torch.Generator(device=rngstate.device)
+                rng.set_state(rngstate["rng_state"])
+
         if "_replay_lock_placeholder" in state:
             state.pop("_replay_lock_placeholder")
             _replay_lock = threading.RLock()
@@ -819,6 +820,8 @@ def __setstate__(self, state: Dict[str, Any]):
             _futures_lock = threading.RLock()
             state["_futures_lock"] = _futures_lock
         self.__dict__.update(state)
+        if rngstate is not None:
+            self.set_rng(rng)

     @property
     def sampler(self):
diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index c34c964c808..8e9cf2d695b 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -108,6 +108,11 @@ def loads(self, path):
     def __repr__(self):
         return f"{self.__class__.__name__}()"

+    def __getstate__(self):
+        state = copy(self.__dict__)
+        state["_rng"] = None
+        return state
+

 class RandomSampler(Sampler):
     """A uniformly random sampler for composable replay buffers.
@@ -395,8 +400,7 @@ def __getstate__(self):
            raise RuntimeError(
                f"Samplers of type {type(self)} cannot be shared between processes."
            )
-        state = copy(self.__dict__)
-        return state
+        return super().__getstate__()

     def _init(self):
         if self.dtype in (torch.float, torch.FloatType, torch.float32):
@@ -938,7 +942,7 @@ def __getstate__(self):
                f"one process will NOT erase the cache on another process's sampler, "
                f"which will cause synchronization issues."
            )
-        state = copy(self.__dict__)
+        state = super().__getstate__()
         state["_cache"] = {}
         return state

@@ -1812,6 +1816,7 @@ def __repr__(self):
     def __getstate__(self):
         state = SliceSampler.__getstate__(self)
         state.update(PrioritizedSampler.__getstate__(self))
+        return state

     def mark_update(
         self, index: Union[int, torch.Tensor], *, storage: Storage | None = None
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index 7c73ba4ed67..04cc63e231d 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -186,6 +186,11 @@ def load(self, *args, **kwargs):
         """Alias for :meth:`~.loads`."""
         return self.loads(*args, **kwargs)

+    def __getstate__(self):
+        state = copy(self.__dict__)
+        state["_rng"] = None
+        return state
+

 class ListStorage(Storage):
     """A storage stored in a list.
@@ -300,7 +305,7 @@ def __getstate__(self):
            raise RuntimeError(
                f"Cannot share a storage of type {type(self)} between processes."
            )
-        state = copy(self.__dict__)
+        state = super().__getstate__()
         return state

     def __repr__(self):
@@ -525,7 +530,7 @@ def flatten(self):
         )

     def __getstate__(self):
-        state = copy(self.__dict__)
+        state = super().__getstate__()
         if get_spawning_popen() is None:
             length = self._len
             del state["_len_value"]
diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py
index e94117153b6..066658993b1 100644
--- a/torchrl/data/replay_buffers/writers.py
+++ b/torchrl/data/replay_buffers/writers.py
@@ -104,6 +104,11 @@ def _replicate_index(self, index):
     def __repr__(self):
         return f"{self.__class__.__name__}()"

+    def __getstate__(self):
+        state = copy(self.__dict__)
+        state["_rng"] = None
+        return state
+

 class ImmutableDatasetWriter(Writer):
     """A blocking writer for immutable datasets."""
@@ -218,7 +223,7 @@ def _cursor(self, value):
         _cursor_value.value = value

     def __getstate__(self):
-        state = copy(self.__dict__)
+        state = super().__getstate__()
         if get_spawning_popen() is None:
             cursor = self._cursor
             del state["_cursor_value"]
@@ -514,7 +519,7 @@ def __getstate__(self):
            raise RuntimeError(
                f"Writers of type {type(self)} cannot be shared between processes."
            )
-        state = copy(self.__dict__)
+        state = super().__getstate__()
         return state

     def dumps(self, path):
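
Note on the pattern applied throughout the patch: `torch.Generator` objects are not picklable, so each component (sampler, storage, writer) blanks its `_rng` reference in `__getstate__`, while the buffer serializes a snapshot of the generator's state and rebuilds and re-propagates it in `__setstate__` via `set_rng`. The sketch below illustrates this outside torchrl; `Component`, `Buffer`, `_sampler`, and `set_rng` are hypothetical stand-ins that only mimic the shape of the real classes, and only the `torch.Generator` calls are actual APIs.

import pickle
from copy import copy

import torch


class Component:
    # Stand-in for a Sampler/Storage/Writer holding a reference to a shared RNG.
    _rng = None

    def __getstate__(self):
        state = copy(self.__dict__)
        # The generator itself cannot be pickled; the owning buffer re-attaches it.
        state["_rng"] = None
        return state


class Buffer:
    # Stand-in for the replay buffer that owns the RNG and its components.
    def __init__(self):
        self._sampler = Component()
        self.set_rng(torch.Generator())

    def set_rng(self, rng):
        # Propagate a single generator to every component.
        self._rng = rng
        self._sampler._rng = rng

    def __getstate__(self):
        state = copy(self.__dict__)
        rng = state.pop("_rng", None)
        if rng is not None:
            # Replace the Generator with a picklable snapshot of its state.
            state["_rng"] = {"device": str(rng.device), "rng_state": rng.get_state()}
        return state

    def __setstate__(self, state):
        rngstate = state.pop("_rng", None)
        self.__dict__.update(state)
        if rngstate is not None:
            rng = torch.Generator(device=rngstate["device"])
            rng.set_state(rngstate["rng_state"])
            self.set_rng(rng)  # rebuild and re-propagate after unpickling


buf = Buffer()
buf._rng.manual_seed(0)
clone = pickle.loads(pickle.dumps(buf))
assert torch.equal(buf._rng.get_state(), clone._rng.get_state())
assert clone._sampler._rng is clone._rng

Re-propagating the rebuilt generator through a single `set_rng` call keeps every component sampling from the same stream after deserialization, which is what the patch's `self.set_rng(rng)` addition in `__setstate__` accomplishes.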