[Feature] Pass replay buffers to MultiaSyncDataCollector #2387

Merged: 13 commits, Aug 13, 2024
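For context, here is a minimal usage sketch of the feature this PR adds, distilled from the new test below. It is a sketch, not part of the PR: the environment id, worker count, and the `RandomPolicy` import path are assumptions (the test uses `CARTPOLE_VERSIONED()` and TorchRL test helpers, and import locations vary across TorchRL versions). With a `replay_buffer` passed to `MultiaSyncDataCollector`, the workers write collected frames straight into the buffer and the iterator yields `None`:

```python
from torchrl.collectors import MultiaSyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import RandomPolicy  # import path is version-dependent

env_maker = lambda: GymEnv("CartPole-v1")  # assumed env id
policy = RandomPolicy(env_maker().action_spec)

# Buffer handed to the collector; the workers extend it directly.
rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)

collector = MultiaSyncDataCollector(
    [env_maker, env_maker],  # two asynchronous workers
    policy,
    replay_buffer=rb,        # the new argument from this PR
    total_frames=256,
    frames_per_batch=16,
)
for batch in collector:
    assert batch is None  # nothing is yielded; data lands in `rb` instead
    print(f"buffer holds {len(rb)} frames")
collector.shutdown()
```

The new test below exercises exactly this contract.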
25 changes: 25 additions & 0 deletions test/test_collector.py
@@ -2830,6 +2830,31 @@ def test_collector_rb_multisync(self):
         collector.shutdown()
         assert len(rb) == 256
 
+    @pytest.mark.skipif(not _has_gym, reason="requires gym.")
+    def test_collector_rb_multiasync(self):
+        env = GymEnv(CARTPOLE_VERSIONED())
+        env.set_seed(0)
+
+        rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
+        rb.add(env.rand_step(env.reset()))
+        rb.empty()
+
+        collector = MultiaSyncDataCollector(
+            [lambda: env, lambda: env],
+            RandomPolicy(env.action_spec),
+            replay_buffer=rb,
+            total_frames=256,
+            frames_per_batch=16,
+        )
+        torch.manual_seed(0)
+        pred_len = 0
+        for c in collector:
+            pred_len += 16
+            assert c is None
+            assert len(rb) >= pred_len
+        collector.shutdown()
+        assert len(rb) == 256
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
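One non-obvious detail in the test above is the `rb.add(...)` / `rb.empty()` pair before the collector is built. My reading (an assumption, not stated in the diff): `LazyTensorStorage` allocates its backing tensors on the first write, so writing one dummy transition and then emptying the buffer fixes the storage layout in the parent process before the worker processes start extending it.

```python
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.libs.gym import GymEnv

env = GymEnv("CartPole-v1")  # stand-in for the test's CARTPOLE_VERSIONED()

rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
rb.add(env.rand_step(env.reset()))  # first write materializes the lazy storage
rb.empty()                          # drops the dummy transition, keeps the allocation
assert len(rb) == 0
```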
21 changes: 13 additions & 8 deletions torchrl/collectors/collectors.py
@@ -16,7 +16,7 @@
 import sys
 import time
 import warnings
-from collections import OrderedDict
+from collections import defaultdict, OrderedDict
 from copy import deepcopy
 from multiprocessing import connection, queues
 from multiprocessing.managers import SyncManager
@@ -2433,7 +2433,7 @@ class MultiaSyncDataCollector(_MultiDataCollector):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.out_tensordicts = {}
+        self.out_tensordicts = defaultdict(lambda: None)
         self.running = False
 
         if self.postprocs is not None:
@@ -2478,7 +2478,9 @@ def frames_per_batch_worker(self):
     def _get_from_queue(self, timeout=None) -> Tuple[int, int, TensorDictBase]:
         new_data, j = self.queue_out.get(timeout=timeout)
         use_buffers = self._use_buffers
-        if j == 0 or not use_buffers:
+        if self.replay_buffer is not None:
+            idx = new_data
+        elif j == 0 or not use_buffers:
             try:
                 data, idx = new_data
                 self.out_tensordicts[idx] = data
@@ -2493,7 +2495,7 @@ def _get_from_queue(self, timeout=None) -> Tuple[int, int, TensorDictBase]:
         else:
             idx = new_data
             out = self.out_tensordicts[idx]
-        if j == 0 or use_buffers:
+        if not self.replay_buffer and (j == 0 or use_buffers):
             # we clone the data to make sure that we'll be working with a fixed copy
             out = out.clone()
         return idx, j, out
@@ -2518,9 +2520,12 @@ def iterator(self) -> Iterator[TensorDictBase]:
             _check_for_faulty_process(self.procs)
             self._iter += 1
             idx, j, out = self._get_from_queue()
-            worker_frames = out.numel()
-            if self.split_trajs:
-                out = split_trajectories(out, prefix="collector")
+            if self.replay_buffer is None:
+                worker_frames = out.numel()
+                if self.split_trajs:
+                    out = split_trajectories(out, prefix="collector")
+            else:
+                worker_frames = self.frames_per_batch_worker
             self._frames += worker_frames
             workers_frames[idx] = workers_frames[idx] + worker_frames
             if self.postprocs:
@@ -2536,7 +2541,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
             else:
                 msg = "continue"
             self.pipes[idx].send((idx, msg))
-            if self._exclude_private_keys:
+            if out is not None and self._exclude_private_keys:
                 excluded_keys = [key for key in out.keys() if key.startswith("_")]
                 out = out.exclude(*excluded_keys)
             yield out
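Why `defaultdict(lambda: None)` instead of a plain dict: in replay-buffer mode a worker sends back only its index, so no tensordict is ever stored under that key, and `self.out_tensordicts[idx]` must come back as `None` rather than raise. A standard-library-only sketch of the lookup behavior:

```python
from collections import defaultdict

out_tensordicts = defaultdict(lambda: None)

idx = 3                     # worker index sent through the queue
out = out_tensordicts[idx]  # a plain dict would raise KeyError here
assert out is None          # the iterator guards on `out is not None` downstream
```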
Expand Down
Loading