
Commit

amend
vmoens committed Aug 8, 2024
1 parent aa06172 commit 1fe487e
Showing 4 changed files with 30 additions and 12 deletions.
13 changes: 8 additions & 5 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -804,12 +804,13 @@ def __getstate__(self) -> Dict[str, Any]:
         return state
 
     def __setstate__(self, state: Dict[str, Any]):
+        rngstate = None
         if "_rng" in state:
-            rng = state["_rng"]
-            if rng is not None:
-                rng = torch.Generator(device=rng.device)
-                rng.set_state(rng["rng_state"])
-                state["_rng"] = rng
+            rngstate = state["_rng"]
+            if rngstate is not None:
+                rng = torch.Generator(device=rngstate.device)
+                rng.set_state(rngstate["rng_state"])
+
         if "_replay_lock_placeholder" in state:
             state.pop("_replay_lock_placeholder")
             _replay_lock = threading.RLock()
@@ -819,6 +820,8 @@ def __setstate__(self, state: Dict[str, Any]):
             _futures_lock = threading.RLock()
             state["_futures_lock"] = _futures_lock
         self.__dict__.update(state)
+        if rngstate is not None:
+            self.set_rng(rng)
 
     @property
     def sampler(self):
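The rewritten hunk fixes a shadowing bug in the old code, which rebound rng to a fresh torch.Generator and then indexed the Generator itself with rng["rng_state"]. The new code keeps the pickled payload in rngstate, rebuilds the Generator from it, and defers self.set_rng(rng) until after self.__dict__.update(state), so the restored Generator is propagated to the buffer's components rather than clobbered by the dict update. A minimal sketch of the same save/restore round trip follows; it assumes the RNG travels as a plain dict with device and rng_state entries, which is an illustration rather than torchrl's exact container (the committed code reads rngstate.device as an attribute, suggesting a tensordict-like object).

    import torch

    # Hypothetical stand-in for the pickled "_rng" payload: a plain dict
    # holding the Generator's device and state tensor.
    gen = torch.Generator(device="cpu")
    gen.manual_seed(0)
    rngstate = {"device": gen.device, "rng_state": gen.get_state()}

    # __setstate__ side: rebuild a fresh Generator and restore its state.
    rng = torch.Generator(device=rngstate["device"])
    rng.set_state(rngstate["rng_state"])

    # Both generators now produce identical streams.
    assert torch.equal(
        torch.randn(3, generator=rng), torch.randn(3, generator=gen)
    )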
11 changes: 8 additions & 3 deletions torchrl/data/replay_buffers/samplers.py
@@ -108,6 +108,11 @@ def loads(self, path):
     def __repr__(self):
         return f"{self.__class__.__name__}()"
 
+    def __getstate__(self):
+        state = copy(self.__dict__)
+        state["_rng"] = None
+        return state
+
 
 class RandomSampler(Sampler):
     """A uniformly random sampler for composable replay buffers.
@@ -395,8 +400,7 @@ def __getstate__(self):
             raise RuntimeError(
                 f"Samplers of type {type(self)} cannot be shared between processes."
             )
-        state = copy(self.__dict__)
-        return state
+        return super().__getstate__()
 
     def _init(self):
         if self.dtype in (torch.float, torch.FloatType, torch.float32):
@@ -938,7 +942,7 @@ def __getstate__(self):
             f"one process will NOT erase the cache on another process's sampler, "
             f"which will cause synchronization issues."
         )
-        state = copy(self.__dict__)
+        state = super().__getstate__()
         state["_cache"] = {}
         return state
 
@@ -1812,6 +1816,7 @@ def __repr__(self):
     def __getstate__(self):
         state = SliceSampler.__getstate__(self)
         state.update(PrioritizedSampler.__getstate__(self))
+        return state
 
     def mark_update(
         self, index: Union[int, torch.Tensor], *, storage: Storage | None = None
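Across samplers, storages, and writers the commit applies one pattern: the base class's __getstate__ copies __dict__ and nulls out _rng, and subclasses now call super().__getstate__() instead of re-copying __dict__, so the unpicklable Generator is stripped in exactly one place (the last hunk above also fixes a missing return in PrioritizedSliceSampler.__getstate__). A minimal sketch of the pattern, with illustrative class names rather than the torchrl classes:

    from copy import copy

    class Base:
        def __init__(self):
            self._rng = object()  # stand-in for a torch.Generator, not picklable
            self.payload = 42

        def __getstate__(self):
            # The base class drops the RNG so every subclass inherits the behavior.
            state = copy(self.__dict__)
            state["_rng"] = None
            return state

    class Child(Base):
        def __init__(self):
            super().__init__()
            self._cache = {"expensive": "entries"}

        def __getstate__(self):
            # Subclasses extend rather than re-copy __dict__, so the RNG
            # stripping in Base is applied exactly once.
            state = super().__getstate__()
            state["_cache"] = {}
            return state

    state = Child().__getstate__()
    assert state["_rng"] is None and state["_cache"] == {}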
9 changes: 7 additions & 2 deletions torchrl/data/replay_buffers/storages.py
@@ -186,6 +186,11 @@ def load(self, *args, **kwargs):
         """Alias for :meth:`~.loads`."""
         return self.loads(*args, **kwargs)
 
+    def __getstate__(self):
+        state = copy(self.__dict__)
+        state["_rng"] = None
+        return state
+
 
 class ListStorage(Storage):
     """A storage stored in a list.
@@ -300,7 +305,7 @@ def __getstate__(self):
             raise RuntimeError(
                 f"Cannot share a storage of type {type(self)} between processes."
             )
-        state = copy(self.__dict__)
+        state = super().__getstate__()
         return state
 
     def __repr__(self):
@@ -525,7 +530,7 @@ def flatten(self):
         )
 
     def __getstate__(self):
-        state = copy(self.__dict__)
+        state = super().__getstate__()
         if get_spawning_popen() is None:
             length = self._len
             del state["_len_value"]
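The last hunk sits inside a storage class that keeps a shared-memory counter (_len_value), which should travel as-is into a spawned child process but be flattened to a plain Python value for ordinary pickling such as checkpointing. multiprocessing.context.get_spawning_popen() returns a non-None object only while a child process is being spawned, which is what the guard keys on. A rough sketch of the idiom, using a hypothetical Counter class rather than torchrl's storage:

    from copy import copy
    from multiprocessing import Value
    from multiprocessing.context import get_spawning_popen

    class Counter:
        def __init__(self):
            self._len_value = Value("i", 0)  # shared across processes

        def __getstate__(self):
            state = copy(self.__dict__)
            if get_spawning_popen() is None:
                # Ordinary pickling (e.g. saving to disk): keep the raw int so
                # the payload does not depend on a live shared-memory segment.
                state["_len"] = state.pop("_len_value").value
            # While spawning a child process, keep the shared Value so parent
            # and child keep counting through the same memory.
            return state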
9 changes: 7 additions & 2 deletions torchrl/data/replay_buffers/writers.py
@@ -104,6 +104,11 @@ def _replicate_index(self, index):
     def __repr__(self):
         return f"{self.__class__.__name__}()"
 
+    def __getstate__(self):
+        state = copy(self.__dict__)
+        state["_rng"] = None
+        return state
+
 
 class ImmutableDatasetWriter(Writer):
     """A blocking writer for immutable datasets."""
@@ -218,7 +223,7 @@ def _cursor(self, value):
         _cursor_value.value = value
 
     def __getstate__(self):
-        state = copy(self.__dict__)
+        state = super().__getstate__()
         if get_spawning_popen() is None:
             cursor = self._cursor
             del state["_cursor_value"]
@@ -514,7 +519,7 @@ def __getstate__(self):
             raise RuntimeError(
                 f"Writers of type {type(self)} cannot be shared between processes."
             )
-        state = copy(self.__dict__)
+        state = super().__getstate__()
         return state
 
     def dumps(self, path):
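Taken together, the change means the nested components always pickle with _rng = None, while ReplayBuffer.__setstate__ rebuilds the Generator and re-attaches it through set_rng after the instance dict is restored. A hedged sketch of the intended round trip; it assumes a torchrl install containing this commit, and the constructor arguments are illustrative:

    import pickle

    import torch
    from torchrl.data import ListStorage, ReplayBuffer

    rb = ReplayBuffer(storage=ListStorage(10))
    rb.set_rng(torch.Generator())  # one Generator shared by sampler/storage/writer

    # Pickling nulls _rng on each component; unpickling rebuilds the Generator
    # and ReplayBuffer.__setstate__ calls set_rng to propagate it again.
    rb2 = pickle.loads(pickle.dumps(rb))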
