From 918bfe614b6b0312d0e4faf83d26503eca0ac622 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 8 Aug 2024 16:30:41 +0100
Subject: [PATCH] [Feature] RNG for RBs (#2379)
---
 test/test_rb.py                               | 152 +++++++++++++++---
 torchrl/data/replay_buffers/replay_buffers.py |  82 ++++++++++
 torchrl/data/replay_buffers/samplers.py       |  46 +++++-
 torchrl/data/replay_buffers/storages.py       |  27 +++-
 torchrl/data/replay_buffers/writers.py        |  21 ++-
 5 files changed, 295 insertions(+), 33 deletions(-)

diff --git a/test/test_rb.py b/test/test_rb.py
index e17cd410c49..4243917c627 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -109,6 +109,11 @@
     ".".join([str(s) for s in version.parse(str(torch.__version__)).release])
 ) >= version.parse("2.3.0")
 
+ReplayBufferRNG = functools.partial(ReplayBuffer, generator=torch.Generator())
+TensorDictReplayBufferRNG = functools.partial(
+    TensorDictReplayBuffer, generator=torch.Generator()
+)
+
 
 @pytest.mark.parametrize(
     "sampler",
@@ -125,17 +130,27 @@
     "rb_type,storage,datatype",
     [
         [ReplayBuffer, ListStorage, None],
+        [ReplayBufferRNG, ListStorage, None],
         [TensorDictReplayBuffer, ListStorage, "tensordict"],
+        [TensorDictReplayBufferRNG, ListStorage, "tensordict"],
         [RemoteTensorDictReplayBuffer, ListStorage, "tensordict"],
         [ReplayBuffer, LazyTensorStorage, "tensor"],
         [ReplayBuffer, LazyTensorStorage, "tensordict"],
         [ReplayBuffer, LazyTensorStorage, "pytree"],
+        [ReplayBufferRNG, LazyTensorStorage, "tensor"],
+        [ReplayBufferRNG, LazyTensorStorage, "tensordict"],
+        [ReplayBufferRNG, LazyTensorStorage, "pytree"],
         [TensorDictReplayBuffer, LazyTensorStorage, "tensordict"],
+        [TensorDictReplayBufferRNG, LazyTensorStorage, "tensordict"],
         [RemoteTensorDictReplayBuffer, LazyTensorStorage, "tensordict"],
         [ReplayBuffer, LazyMemmapStorage, "tensor"],
         [ReplayBuffer, LazyMemmapStorage, "tensordict"],
         [ReplayBuffer, LazyMemmapStorage, "pytree"],
+        [ReplayBufferRNG, LazyMemmapStorage, "tensor"],
+        [ReplayBufferRNG, LazyMemmapStorage, "tensordict"],
+        [ReplayBufferRNG, LazyMemmapStorage, "pytree"],
         [TensorDictReplayBuffer, LazyMemmapStorage, "tensordict"],
+        [TensorDictReplayBufferRNG, LazyMemmapStorage, "tensordict"],
         [RemoteTensorDictReplayBuffer, LazyMemmapStorage, "tensordict"],
     ],
 )
@@ -1155,17 +1170,115 @@ def test_replay_buffer_trajectories(stack, reduction, datatype):
     # sampled_td_filtered.batch_size = [3, 4]
 
 
+class TestRNG:
+    def test_rb_rng(self):
+        state = torch.random.get_rng_state()
+        rb = ReplayBufferRNG(sampler=RandomSampler(), storage=LazyTensorStorage(100))
+        rb.extend(torch.arange(100))
+        rb._rng.set_state(state)
+        a = rb.sample(32)
+        rb._rng.set_state(state)
+        b = rb.sample(32)
+        assert (a == b).all()
+        c = rb.sample(32)
+        assert (a != c).any()
+
+    def test_prb_rng(self):
+        state = torch.random.get_rng_state()
+        rb = ReplayBuffer(
+            sampler=PrioritizedSampler(100, 1.0, 1.0),
+            storage=LazyTensorStorage(100),
+            generator=torch.Generator(),
+        )
+        rb.extend(torch.arange(100))
+        rb.update_priority(index=torch.arange(100), priority=torch.arange(1, 101))
+
+        rb._rng.set_state(state)
+        a = rb.sample(32)
+
+        rb._rng.set_state(state)
+        b = rb.sample(32)
+        assert (a == b).all()
+
+        c = rb.sample(32)
+        assert (a != c).any()
+
+    def test_slice_rng(self):
+        state = torch.random.get_rng_state()
+        rb = ReplayBuffer(
+            sampler=SliceSampler(num_slices=4),
+            storage=LazyTensorStorage(100),
+            generator=torch.Generator(),
+        )
+        done = torch.zeros(100, 1, dtype=torch.bool)
+        done[49] = 1
+        done[-1] = 1
+        data = TensorDict(
+            {
+                "data": torch.arange(100),
+                ("next", "done"): done,
+            },
+            batch_size=[100],
+        )
+        rb.extend(data)
+
+        rb._rng.set_state(state)
+        a = rb.sample(32)
+
+        rb._rng.set_state(state)
+        b = rb.sample(32)
+        assert (a == b).all()
+
+        c = rb.sample(32)
+        assert (a != c).any()
+
+    def test_rng_state_dict(self):
+        state = torch.random.get_rng_state()
+        rb = ReplayBufferRNG(sampler=RandomSampler(), storage=LazyTensorStorage(100))
+        rb.extend(torch.arange(100))
+        rb._rng.set_state(state)
+        sd = rb.state_dict()
+        assert sd.get("_rng") is not None
+        a = rb.sample(32)
+
+        rb.load_state_dict(sd)
+        b = rb.sample(32)
+        assert (a == b).all()
+        c = rb.sample(32)
+        assert (a != c).any()
+
+    def test_rng_dumps(self, tmpdir):
+        state = torch.random.get_rng_state()
+        rb = ReplayBufferRNG(sampler=RandomSampler(), storage=LazyTensorStorage(100))
+        rb.extend(torch.arange(100))
+        rb._rng.set_state(state)
+        rb.dumps(tmpdir)
+        a = rb.sample(32)
+
+        rb.loads(tmpdir)
+        b = rb.sample(32)
+        assert (a == b).all()
+        c = rb.sample(32)
+        assert (a != c).any()
+
+
 @pytest.mark.parametrize(
     "rbtype,storage",
     [
         (ReplayBuffer, None),
         (ReplayBuffer, ListStorage),
+        (ReplayBufferRNG, None),
+        (ReplayBufferRNG, ListStorage),
         (PrioritizedReplayBuffer, None),
         (PrioritizedReplayBuffer, ListStorage),
         (TensorDictReplayBuffer, None),
         (TensorDictReplayBuffer, ListStorage),
         (TensorDictReplayBuffer, LazyTensorStorage),
         (TensorDictReplayBuffer, LazyMemmapStorage),
+        (TensorDictReplayBufferRNG, None),
+        (TensorDictReplayBufferRNG, ListStorage),
+        (TensorDictReplayBufferRNG, LazyTensorStorage),
+        (TensorDictReplayBufferRNG, LazyMemmapStorage),
         (TensorDictPrioritizedReplayBuffer, None),
         (TensorDictPrioritizedReplayBuffer, ListStorage),
         (TensorDictPrioritizedReplayBuffer, LazyTensorStorage),
@@ -1175,33 +1288,34 @@ def test_replay_buffer_trajectories(stack, reduction, datatype):
 @pytest.mark.parametrize("size", [3, 5, 100])
 @pytest.mark.parametrize("prefetch", [0])
 class TestBuffers:
-    _default_params_rb = {}
-    _default_params_td_rb = {}
-    _default_params_prb = {"alpha": 0.8, "beta": 0.9}
-    _default_params_td_prb = {"alpha": 0.8, "beta": 0.9}
+
+    default_constr = {
+        ReplayBuffer: ReplayBuffer,
+        PrioritizedReplayBuffer: functools.partial(
+            PrioritizedReplayBuffer, alpha=0.8, beta=0.9
+        ),
+        TensorDictReplayBuffer: TensorDictReplayBuffer,
+        TensorDictPrioritizedReplayBuffer: functools.partial(
+            TensorDictPrioritizedReplayBuffer, alpha=0.8, beta=0.9
+        ),
+        TensorDictReplayBufferRNG: TensorDictReplayBufferRNG,
+        ReplayBufferRNG: ReplayBufferRNG,
+    }
 
     def _get_rb(self, rbtype, size, storage, prefetch):
         if storage is not None:
             storage = storage(size)
-        if rbtype is ReplayBuffer:
-            params = self._default_params_rb
-        elif rbtype is PrioritizedReplayBuffer:
-            params = self._default_params_prb
-        elif rbtype is TensorDictReplayBuffer:
-            params = self._default_params_td_rb
-        elif rbtype is TensorDictPrioritizedReplayBuffer:
-            params = self._default_params_td_prb
-        else:
-            raise NotImplementedError(rbtype)
-        rb = rbtype(storage=storage, prefetch=prefetch, batch_size=3, **params)
+        rb = self.default_constr[rbtype](
+            storage=storage, prefetch=prefetch, batch_size=3
+        )
         return rb
 
     def _get_datum(self, rbtype):
-        if rbtype is ReplayBuffer:
+        if rbtype in (ReplayBuffer, ReplayBufferRNG):
             data = torch.randint(100, (1,))
         elif rbtype is PrioritizedReplayBuffer:
             data = torch.randint(100, (1,))
-        elif rbtype is TensorDictReplayBuffer:
+        elif rbtype in (TensorDictReplayBuffer, TensorDictReplayBufferRNG):
             data = TensorDict({"a": torch.randint(100, (1,))}, [])
         elif rbtype is TensorDictPrioritizedReplayBuffer:
             data = TensorDict({"a": torch.randint(100, (1,))}, [])
@@ -1210,11 +1324,11 @@ def _get_datum(self, rbtype):
         return data
 
     def _get_data(self, rbtype, size):
-        if rbtype is ReplayBuffer:
+        if rbtype in (ReplayBuffer, ReplayBufferRNG):
             data = [torch.randint(100, (1,)) for _ in range(size)]
         elif rbtype is PrioritizedReplayBuffer:
             data = [torch.randint(100, (1,)) for _ in range(size)]
-        elif rbtype is TensorDictReplayBuffer:
+        elif rbtype in (TensorDictReplayBuffer, TensorDictReplayBufferRNG):
             data = TensorDict(
                 {
                     "a": torch.randint(100, (size,)),
diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index a688bd8585e..fafc120fe94 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -23,6 +23,7 @@
     is_tensorclass,
     LazyStackedTensorDict,
     NestedKey,
+    TensorDict,
     TensorDictBase,
     unravel_key,
 )
@@ -120,7 +121,13 @@ class ReplayBuffer:
             >>> for d in data.unbind(1):
             ...     rb.add(d)
             >>> rb.extend(data)
+        generator (torch.Generator, optional): a generator to use for sampling.
+            Using a dedicated generator for the replay buffer allows fine-grained control
+            over seeding, for instance keeping the global seed different but the RB seed identical
+            for distributed jobs.
+            Defaults to ``None`` (global default generator).
+            .. warning:: As of now, the generator has no effect on the transforms.
 
     Examples:
         >>> import torch
@@ -204,6 +211,7 @@ def __init__(
         batch_size: int | None = None,
         dim_extend: int | None = None,
         checkpointer: "StorageCheckpointerBase" | None = None,  # noqa: F821
+        generator: torch.Generator | None = None,
     ) -> None:
         self._storage = storage if storage is not None else ListStorage(max_size=1_000)
         self._storage.attach(self)
@@ -262,6 +270,13 @@ def __init__(
             raise ValueError("dim_extend must be a positive value.")
         self.dim_extend = dim_extend
         self._storage.checkpointer = checkpointer
+        self.set_rng(generator=generator)
+
+    def set_rng(self, generator):
+        self._rng = generator
+        self._storage._rng = generator
+        self._sampler._rng = generator
+        self._writer._rng = generator
 
     @property
     def dim_extend(self):
@@ -414,6 +429,9 @@ def state_dict(self) -> Dict[str, Any]:
             "_writer": self._writer.state_dict(),
             "_transforms": self._transform.state_dict(),
             "_batch_size": self._batch_size,
+            "_rng": (self._rng.get_state().clone(), str(self._rng.device))
+            if self._rng is not None
+            else None,
         }
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -422,6 +440,12 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         self._writer.load_state_dict(state_dict["_writer"])
         self._transform.load_state_dict(state_dict["_transforms"])
         self._batch_size = state_dict["_batch_size"]
+        rng = state_dict.get("_rng")
+        if rng is not None:
+            state, device = rng
+            rng = torch.Generator(device=device)
+            rng.set_state(state)
+            self.set_rng(generator=rng)
 
     def dumps(self, path):
         """Saves the replay buffer on disk at the specified path.
@@ -465,6 +489,13 @@ def dumps(self, path):
         self._storage.dumps(path / "storage")
         self._sampler.dumps(path / "sampler")
         self._writer.dumps(path / "writer")
+        if self._rng is not None:
+            rng_state = TensorDict(
+                rng_state=self._rng.get_state().clone(),
+                device=self._rng.device,
+            )
+            rng_state.memmap(path / "rng_state")
+
         # fall back on state_dict for transforms
         transform_sd = self._transform.state_dict()
         if transform_sd:
@@ -487,6 +518,11 @@ def loads(self, path):
         self._storage.loads(path / "storage")
         self._sampler.loads(path / "sampler")
         self._writer.loads(path / "writer")
+        if (path / "rng_state").exists():
+            rng_state = TensorDict.load_memmap(path / "rng_state")
+            rng = torch.Generator(device=rng_state.device)
+            rng.set_state(rng_state["rng_state"])
+            self.set_rng(rng)
         # fall back on state_dict for transforms
         if (path / "transform.t").exists():
             self._transform.load_state_dict(torch.load(path / "transform.t"))
@@ -753,6 +789,12 @@ def __iter__(self):
 
     def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
+        if self._rng is not None:
+            rng_state = TensorDict(
+                rng_state=self._rng.get_state().clone(),
+                device=self._rng.device,
+            )
+            state["_rng"] = rng_state
         _replay_lock = state.pop("_replay_lock", None)
         _futures_lock = state.pop("_futures_lock", None)
         if _replay_lock is not None:
@@ -762,6 +804,13 @@ def __setstate__(self, state: Dict[str, Any]):
+        rngstate = None
+        if "_rng" in state:
+            rngstate = state["_rng"]
+            if rngstate is not None:
+                rng = torch.Generator(device=rngstate.device)
+                rng.set_state(rngstate["rng_state"])
+
         if "_replay_lock_placeholder" in state:
             state.pop("_replay_lock_placeholder")
             _replay_lock = threading.RLock()
@@ -771,6 +820,8 @@ def __setstate__(self, state: Dict[str, Any]):
             _futures_lock = threading.RLock()
             state["_futures_lock"] = _futures_lock
         self.__dict__.update(state)
+        if rngstate is not None:
+            self.set_rng(rng)
 
     @property
     def sampler(self):
@@ -995,6 +1046,13 @@ class TensorDictReplayBuffer(ReplayBuffer):
             >>> for d in data.unbind(1):
             ...     rb.add(d)
             >>> rb.extend(data)
+        generator (torch.Generator, optional): a generator to use for sampling.
+            Using a dedicated generator for the replay buffer allows fine-grained control
+            over seeding, for instance keeping the global seed different but the RB seed identical
+            for distributed jobs.
+            Defaults to ``None`` (global default generator).
+
+            .. warning:: As of now, the generator has no effect on the transforms.
 
     Examples:
         >>> import torch
@@ -1327,6 +1385,13 @@ class TensorDictPrioritizedReplayBuffer(TensorDictReplayBuffer):
             >>> for d in data.unbind(1):
             ...     rb.add(d)
             >>> rb.extend(data)
+        generator (torch.Generator, optional): a generator to use for sampling.
+            Using a dedicated generator for the replay buffer allows fine-grained control
+            over seeding, for instance keeping the global seed different but the RB seed identical
+            for distributed jobs.
+            Defaults to ``None`` (global default generator).
+
+            .. warning:: As of now, the generator has no effect on the transforms.
 
     Examples:
         >>> import torch
@@ -1400,6 +1465,7 @@ def __init__(
         reduction: str = "max",
         batch_size: int | None = None,
         dim_extend: int | None = None,
+        generator: torch.Generator | None = None,
     ) -> None:
         if storage is None:
             storage = ListStorage(max_size=1_000)
@@ -1416,6 +1482,7 @@ def __init__(
             transform=transform,
             batch_size=batch_size,
             dim_extend=dim_extend,
+            generator=generator,
         )
 
 
@@ -1454,12 +1521,18 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None:
 class InPlaceSampler:
     """A sampler to write tensordicts in-place.
 
+    .. warning:: This class is deprecated and will be removed in v0.7.
+
     To be used cautiously as this may lead to unexpected behaviour (i.e. tensordicts
     overwritten during execution).
 
     """
 
     def __init__(self, device: DEVICE_TYPING | None = None):
+        warnings.warn(
+            "InPlaceSampler has been deprecated and will be removed in v0.7.",
+            category=DeprecationWarning,
+        )
         self.out = None
         if device is None:
             device = "cpu"
@@ -1555,6 +1628,13 @@ class ReplayBufferEnsemble(ReplayBuffer):
             sampled according to the probabilities ``p``. Can also be passed to
             torchrl.data.replay_buffers.samplers.SamplerEnsemble` if the buffer is
             built explicitly.
+        generator (torch.Generator, optional): a generator to use for sampling.
+            Using a dedicated generator for the replay buffer allows fine-grained control
+            over seeding, for instance keeping the global seed different but the RB seed identical
+            for distributed jobs.
+            Defaults to ``None`` (global default generator).
+
+            .. warning:: As of now, the generator has no effect on the transforms.
 
     Examples:
         >>> from torchrl.envs import Compose, ToTensorImage, Resize, RenameTransform
@@ -1644,6 +1724,7 @@ def __init__(
         p: Tensor = None,
         sample_from_all: bool = False,
         num_buffer_sampled: int | None = None,
+        generator: torch.Generator | None = None,
         **kwargs,
     ):
 
@@ -1680,6 +1761,7 @@ def __init__(
             transform=transform,
             batch_size=batch_size,
             collate_fn=collate_fn,
+            generator=generator,
             **kwargs,
         )
 
diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index 582ac88f52d..8e9cf2d695b 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -46,6 +46,9 @@ class Sampler(ABC):
     # need to keep track of the number of remaining batches
     _remaining_batches = int(torch.iinfo(torch.int64).max)
 
+    # The RNG is set by the replay buffer
+    _rng: torch.Generator | None = None
+
     @abstractmethod
     def sample(self, storage: Storage, batch_size: int) -> Tuple[Any, dict]:
         ...
@@ -105,6 +108,11 @@ def loads(self, path):
     def __repr__(self):
         return f"{self.__class__.__name__}()"
 
+    def __getstate__(self):
+        state = copy(self.__dict__)
+        state["_rng"] = None
+        return state
+
 
 class RandomSampler(Sampler):
     """A uniformly random sampler for composable replay buffers.
@@ -192,7 +200,9 @@ def _get_sample_list(self, storage: Storage, len_storage: int, batch_size: int):
         device = storage.device if hasattr(storage, "device") else None
 
         if self.shuffle:
-            _sample_list = torch.randperm(len_storage, device=device)
+            _sample_list = torch.randperm(
+                len_storage, device=device, generator=self._rng
+            )
         else:
             _sample_list = torch.arange(len_storage, device=device)
         self._sample_list = _sample_list
@@ -390,8 +400,7 @@ def __getstate__(self):
             raise RuntimeError(
                 f"Samplers of type {type(self)} cannot be shared between processes."
             )
-        state = copy(self.__dict__)
-        return state
+        return super().__getstate__()
 
     def _init(self):
         if self.dtype in (torch.float, torch.FloatType, torch.float32):
@@ -473,7 +482,11 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor:
             raise RuntimeError("non-positive p_min")
         # For some undefined reason, only np.random works here.
         # All PT attempts fail, even when subsequently transformed into numpy
-        mass = np.random.uniform(0.0, p_sum, size=batch_size)
+        if self._rng is None:
+            mass = np.random.uniform(0.0, p_sum, size=batch_size)
+        else:
+            mass = torch.rand(batch_size, generator=self._rng) * p_sum
+
         # mass = torch.zeros(batch_size, dtype=torch.double).uniform_(0.0, p_sum)
         # mass = torch.rand(batch_size).mul_(p_sum)
         index = self._sum_tree.scan_lower_bound(mass)
@@ -929,7 +942,7 @@ def __getstate__(self):
                 f"one process will NOT erase the cache on another process's sampler, "
                 f"which will cause synchronization issues."
             )
-        state = copy(self.__dict__)
+        state = super().__getstate__()
         state["_cache"] = {}
         return state
 
@@ -1187,7 +1200,9 @@ def _sample_slices(
         # start_idx and stop_idx are 2d tensors organized like a non-zero
 
         def get_traj_idx(maxval):
-            return torch.randint(maxval, (num_slices,), device=lengths.device)
+            return torch.randint(
+                maxval, (num_slices,), device=lengths.device, generator=self._rng
+            )
 
         if (lengths < seq_length).any():
             if self.strict_length:
@@ -1290,7 +1305,8 @@ def _get_index(
                 start_point = -span_right
 
         relative_starts = (
-            torch.rand(num_slices, device=lengths.device) * (end_point - start_point)
+            torch.rand(num_slices, device=lengths.device, generator=self._rng)
+            * (end_point - start_point)
         ).floor().to(start_idx.dtype) + start_point
 
         if self.span[0]:
@@ -1800,6 +1816,7 @@ def __repr__(self):
     def __getstate__(self):
         state = SliceSampler.__getstate__(self)
         state.update(PrioritizedSampler.__getstate__(self))
+        return state
 
     def mark_update(
         self, index: Union[int, torch.Tensor], *, storage: Storage | None = None
@@ -2033,6 +2050,7 @@ class SamplerEnsemble(Sampler):
     def __init__(
         self, *samplers, p=None, sample_from_all=False, num_buffer_sampled=None
     ):
+        self._rng_private = None
         self._samplers = samplers
         self.sample_from_all = sample_from_all
         if sample_from_all and p is not None:
@@ -2042,6 +2060,16 @@ def __init__(
         self.p = p
         self.num_buffer_sampled = num_buffer_sampled
 
+    @property
+    def _rng(self):
+        return self._rng_private
+
+    @_rng.setter
+    def _rng(self, value):
+        self._rng_private = value
+        for sampler in self._samplers:
+            sampler._rng = value
+
     @property
     def p(self):
         return self._p
@@ -2082,7 +2110,9 @@ def sample(self, storage, batch_size):
         else:
             if self.p is None:
                 buffer_ids = torch.randint(
-                    len(self._samplers), (self.num_buffer_sampled,)
+                    len(self._samplers),
+                    (self.num_buffer_sampled,),
+                    generator=self._rng,
                 )
             else:
                 buffer_ids = torch.multinomial(self.p, self.num_buffer_sampled, True)
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index 58b1729296d..04cc63e231d 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -54,6 +54,7 @@ class Storage:
     ndim = 1
     max_size: int
     _default_checkpointer: StorageCheckpointerBase = StorageCheckpointerBase
+    _rng: torch.Generator | None = None
 
     def __init__(
         self, max_size: int, checkpointer: StorageCheckpointerBase | None = None
@@ -142,7 +143,7 @@ def _empty(self):
     def _rand_given_ndim(self, batch_size):
         # a method to return random indices given the storage ndim
         if self.ndim == 1:
-            return torch.randint(0, len(self), (batch_size,))
+            return torch.randint(0, len(self), (batch_size,), generator=self._rng)
         raise RuntimeError(
             f"Random number generation is not implemented for storage of type {type(self)} with ndim {self.ndim}. "
             f"Please report this exception as well as the use case (incl. buffer construction) on github."
@@ -185,6 +186,11 @@ def load(self, *args, **kwargs):
         """Alias for :meth:`~.loads`."""
         return self.loads(*args, **kwargs)
 
+    def __getstate__(self):
+        state = copy(self.__dict__)
+        state["_rng"] = None
+        return state
+
 
 class ListStorage(Storage):
     """A storage stored in a list.
@@ -299,7 +305,7 @@ def __getstate__(self):
             raise RuntimeError(
                 f"Cannot share a storage of type {type(self)} between processes."
             )
-        state = copy(self.__dict__)
+        state = super().__getstate__()
         return state
 
     def __repr__(self):
@@ -497,7 +503,9 @@ def _rand_given_ndim(self, batch_size):
         if self.ndim == 1:
             return super()._rand_given_ndim(batch_size)
         shape = self.shape
-        return tuple(torch.randint(_dim, (batch_size,)) for _dim in shape)
+        return tuple(
+            torch.randint(_dim, (batch_size,), generator=self._rng) for _dim in shape
+        )
 
     def flatten(self):
         if self.ndim == 1:
@@ -522,7 +530,7 @@ def flatten(self):
         )
 
     def __getstate__(self):
-        state = copy(self.__dict__)
+        state = super().__getstate__()
         if get_spawning_popen() is None:
             length = self._len
             del state["_len_value"]
@@ -1142,6 +1150,7 @@ def __init__(
         *storages: Storage,
         transforms: List["Transform"] = None,  # noqa: F821
     ):
+        self._rng_private = None
         self._storages = storages
         self._transforms = transforms
         if transforms is not None and len(transforms) != len(storages):
@@ -1149,6 +1158,16 @@ def __init__(
                 "transforms must have the same length as the storages " "provided."
             )
 
+    @property
+    def _rng(self):
+        return self._rng_private
+
+    @_rng.setter
+    def _rng(self, value):
+        self._rng_private = value
+        for storage in self._storages:
+            storage._rng = value
+
     @property
     def _attached_entities(self):
         return set()
diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py
index ea3b2b4a047..066658993b1 100644
--- a/torchrl/data/replay_buffers/writers.py
+++ b/torchrl/data/replay_buffers/writers.py
@@ -38,6 +38,7 @@ class Writer(ABC):
     """A ReplayBuffer base Writer class."""
 
     _storage: Storage
+    _rng: torch.Generator | None = None
 
     def __init__(self) -> None:
         self._storage = None
@@ -103,6 +104,11 @@ def _replicate_index(self, index):
     def __repr__(self):
         return f"{self.__class__.__name__}()"
 
+    def __getstate__(self):
+        state = copy(self.__dict__)
+        state["_rng"] = None
+        return state
+
 
 class ImmutableDatasetWriter(Writer):
     """A blocking writer for immutable datasets."""
@@ -217,7 +223,7 @@ def _cursor(self, value):
         _cursor_value.value = value
 
     def __getstate__(self):
-        state = copy(self.__dict__)
+        state = super().__getstate__()
         if get_spawning_popen() is None:
             cursor = self._cursor
             del state["_cursor_value"]
@@ -513,7 +519,7 @@ def __getstate__(self):
             raise RuntimeError(
                 f"Writers of type {type(self)} cannot be shared between processes."
            )
-        state = copy(self.__dict__)
+        state = super().__getstate__()
         return state
 
     def dumps(self, path):
@@ -582,8 +588,19 @@ class WriterEnsemble(Writer):
     """
 
     def __init__(self, *writers):
+        self._rng_private = None
         self._writers = writers
 
+    @property
+    def _rng(self):
+        return self._rng_private
+
+    @_rng.setter
+    def _rng(self, value):
+        self._rng_private = value
+        for writer in self._writers:
+            writer._rng = value
+
     def _empty(self):
         raise NotImplementedError
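
Below is a minimal usage sketch of the `generator` argument introduced by this patch, distilled from the `TestRNG` cases above. It is illustrative commentary rather than part of the diff; the constructor arguments and the `state_dict()` round-trip come straight from the patch, while the seed value and buffer contents are arbitrary placeholders.

```python
import torch

from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.data.replay_buffers.samplers import RandomSampler

# Dedicated RNG: the buffer's sampling stream is decoupled from the global torch seed.
gen = torch.Generator()
gen.manual_seed(0)

rb = ReplayBuffer(
    storage=LazyTensorStorage(100),
    sampler=RandomSampler(),
    generator=gen,
    batch_size=32,
)
rb.extend(torch.arange(100))

state = gen.get_state()  # snapshot of the buffer's RNG state
a = rb.sample()
gen.set_state(state)     # rewind: the next sample repeats the previous one
b = rb.sample()
assert (a == b).all()

# The RNG state travels with state_dict()/load_state_dict() (and dumps()/loads()),
# so restoring a checkpoint resumes the same sampling stream.
sd = rb.state_dict()
c = rb.sample()
rb.load_state_dict(sd)
d = rb.sample()
assert (c == d).all()
```

Keeping the generator on the buffer, and propagating it to the sampler, storage and writer through `set_rng`, is what lets the sampling stream stay reproducible independently of the global seed, for instance when distributed workers share a buffer seed but use different global seeds.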