Commit: Update

[ghstack-poisoned]

vmoens committed Aug 10, 2024
1 parent 4499d50 commit 81ee73b
Showing 4 changed files with 150 additions and 62 deletions.
25 changes: 25 additions & 0 deletions test/test_collector.py
@@ -9,6 +9,7 @@
import gc

import sys
import time

import numpy as np
import pytest
@@ -2797,6 +2798,30 @@ def test_collector_rb_sync(self):
del collector, env
assert assert_allclose_td(rbdata0, rbdata1)

def test_collector_rb_multisync(self):
env = GymEnv(CARTPOLE_VERSIONED())
env.set_seed(0)

rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
rb.add(env.rand_step(env.reset()))
rb.empty()

collector = MultiSyncDataCollector(
[lambda: env, lambda: env],
RandomPolicy(env.action_spec),
replay_buffer=rb,
total_frames=256,
frames_per_batch=16,
)
torch.manual_seed(0)
pred_len = 0
for c in collector:
pred_len += 16
assert c is None
assert len(rb) == pred_len
collector.shutdown()
assert len(rb) == 256


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
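The test above exercises the new collector-to-buffer path end to end. For orientation, here is a minimal sketch of the same user-facing pattern; the environment name, frame counts, and import paths are assumptions for illustration and may differ across torchrl versions:

```python
from torchrl.collectors import MultiSyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import RandomPolicy


def make_env():
    return GymEnv("CartPole-v1")


if __name__ == "__main__":
    env = make_env()
    rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
    collector = MultiSyncDataCollector(
        [make_env, make_env],
        RandomPolicy(env.action_spec),
        replay_buffer=rb,
        total_frames=256,
        frames_per_batch=16,
    )
    for batch in collector:
        # With replay_buffer set, the iterator yields None: the workers
        # have already written the frames into the shared buffer.
        assert batch is None
        sample = rb.sample()  # a batch of 5 transitions, ready for training
    collector.shutdown()
```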
139 changes: 89 additions & 50 deletions torchrl/collectors/collectors.py
@@ -450,6 +450,7 @@ def __init__(
set_truncated: bool = False,
use_buffers: bool | None = None,
replay_buffer: ReplayBuffer | None = None,
**kwargs,
):
from torchrl.envs.batched_envs import BatchedEnvBase

@@ -476,6 +477,14 @@

policy = RandomPolicy(env.full_action_spec)

##########################
# Trajectory pool
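# Passed privately via **kwargs by the multiprocessed collectors so that
# every worker draws trajectory ids from one shared counter (see
# _run_processes below); user code is not expected to set it.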
self._traj_pool_val = kwargs.pop("traj_pool", None)
if kwargs:
raise TypeError(
f"Keys {list(kwargs.keys())} are unknown to {type(self).__name__}."
)

##########################
# Setting devices:
# The rule is the following:
@@ -554,6 +563,7 @@ def __init__(
self.replay_buffer = replay_buffer

self.closed = False

if not reset_when_done:
raise ValueError("reset_when_done is deprectated.")
self.reset_when_done = reset_when_done
@@ -667,6 +677,13 @@ def __init__(
self._frames = 0
self._iter = -1

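# Single-process collectors lazily create a private pool on first access;
# multiprocessed collectors inject a shared, locked pool through the
# "traj_pool" keyword argument instead.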
@property
def _traj_pool(self):
pool = getattr(self, "_traj_pool_val", None)
if pool is None:
pool = self._traj_pool_val = _TrajectoryPool()
return pool

def _make_shuttle(self):
# Shuttle is a deviceless tensordict that just carries data from env to policy and from policy to env
with torch.no_grad():
@@ -677,9 +694,9 @@ def _make_shuttle(self):
else:
self._shuttle_has_no_device = False

traj_ids = torch.arange(self.n_env, device=self.storing_device).view(
self.env.batch_size
)
traj_ids = self._traj_pool.get_traj_and_increment(
self.n_env, device=self.storing_device
).view(self.env.batch_size)
self._shuttle.set(
("collector", "traj_ids"),
traj_ids,
@@ -999,11 +1016,12 @@ def _update_traj_ids(self, env_output) -> None:
if traj_sop.any():
traj_ids = self._shuttle.get(("collector", "traj_ids"))
traj_sop = traj_sop.to(self.storing_device)
traj_ids = traj_ids.clone().to(self.storing_device)
traj_ids[traj_sop] = traj_ids.max() + torch.arange(
1,
traj_sop.sum() + 1,
device=self.storing_device,
pool = self._traj_pool
new_traj = pool.get_traj_and_increment(
traj_sop.sum(), device=self.storing_device
)
traj_ids = traj_ids.to(self.storing_device).masked_scatter(
traj_sop, new_traj
)
self._shuttle.set(("collector", "traj_ids"), traj_ids)
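For reference, `Tensor.masked_scatter` fills the `True` positions of the mask, in order, with consecutive values from the source tensor, which is exactly what the pool-based update relies on. A minimal sketch with made-up values:

```python
import torch

traj_ids = torch.tensor([4, 5, 6])             # current ids, one per env
traj_sop = torch.tensor([False, True, False])  # env 1 just reset
new_traj = torch.tensor([7])                   # next id from the pool
updated = traj_ids.masked_scatter(traj_sop, new_traj)
print(updated)  # tensor([4, 7, 6])
```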

@@ -1410,6 +1428,8 @@ class _MultiDataCollector(DataCollectorBase):
use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
This isn't compatible with environments with dynamic specs. Defaults to ``True``
for envs without dynamic specs, ``False`` for others.
replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
but will populate the buffer instead. Defaults to ``None``.
"""

@@ -1445,6 +1465,7 @@ def __init__(
cat_results: str | int | None = None,
set_truncated: bool = False,
use_buffers: bool | None = None,
replay_buffer: ReplayBuffer | None = None,
):
exploration_type = _convert_exploration_type(
exploration_mode=exploration_mode, exploration_type=exploration_type
@@ -1488,6 +1509,13 @@
del storing_device, env_device, policy_device, device

self._use_buffers = use_buffers
self.replay_buffer = replay_buffer
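# Workers write into the buffer directly, so its storage must live in
# shared memory before the worker processes are spawned.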
if (
replay_buffer is not None
and hasattr(replay_buffer, "shared")
and not replay_buffer.shared
):
replay_buffer.share()

_policy_weights_dict = {}
_get_weights_fn_dict = {}
@@ -1724,6 +1752,7 @@ def _run_processes(self) -> None:
queue_out = mp.Queue(self._queue_len) # sends data from proc to main
self.procs = []
self.pipes = []
traj_pool = _TrajectoryPool(lock=True)
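# One locked pool is shared by all workers so that concurrently
# assigned trajectory ids never collide across processes.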
for i, (env_fun, env_fun_kwargs) in enumerate(
zip(self.create_env_fn, self.create_env_kwargs)
):
@@ -1760,6 +1789,8 @@
"interruptor": self.interruptor,
"set_truncated": self.set_truncated,
"use_buffers": self._use_buffers,
"replay_buffer": self.replay_buffer,
"traj_pool": traj_pool,
}
proc = _ProcessNoWarn(
target=_main_async_collector,
@@ -2118,10 +2149,6 @@ def iterator(self) -> Iterator[TensorDictBase]:
workers_frames = [0 for _ in range(self.num_workers)]
same_device = None
self.out_buffer = None
last_traj_ids = [-10 for _ in range(self.num_workers)]
last_traj_ids_subs = [None for _ in range(self.num_workers)]
traj_max = -1
traj_ids_list = [None for _ in range(self.num_workers)]
preempt = self.interruptor is not None and self.preemptive_threshold < 1.0

while not all(dones) and self._frames < self.total_frames:
@@ -2155,7 +2182,13 @@
for _ in range(self.num_workers):
new_data, j = self.queue_out.get()
use_buffers = self._use_buffers
if j == 0 or not use_buffers:
if self.replay_buffer is not None:
idx = new_data
workers_frames[idx] = (
workers_frames[idx] + self.frames_per_batch_worker
)
continue
elif j == 0 or not use_buffers:
try:
data, idx = new_data
self.buffers[idx] = data
@@ -2197,51 +2230,25 @@ def iterator(self) -> Iterator[TensorDictBase]:
if workers_frames[idx] >= self.total_frames:
dones[idx] = True

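# When a replay buffer is used, the workers have already written the
# data into shared storage; nothing is yielded and only the frame
# count advances.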
if self.replay_buffer is not None:
yield
self._frames += self.frames_per_batch_worker * self.num_workers
continue

# we have to correct the traj_ids to make sure that they don't overlap
# We can count the number of frames collected for free in this loop
n_collected = 0
for idx in range(self.num_workers):
buffer = buffers[idx]
traj_ids = buffer.get(("collector", "traj_ids"))
is_last = traj_ids == last_traj_ids[idx]
# If we `cat` interrupted data, we have already filtered out
# non-valid steps. If we stack, we haven't.
if preempt and cat_results == "stack":
valid = buffer.get(("collector", "traj_ids")) != -1
if valid.ndim > 2:
valid = valid.flatten(0, -2)
if valid.ndim == 2:
valid = valid.any(0)
last_traj_ids[idx] = traj_ids[..., valid][..., -1:].clone()
else:
last_traj_ids[idx] = traj_ids[..., -1:].clone()
if not is_last.all():
traj_to_correct = traj_ids[~is_last]
traj_to_correct = (
traj_to_correct + (traj_max + 1) - traj_to_correct.min()
)
traj_ids = traj_ids.masked_scatter(~is_last, traj_to_correct)
# is_last can only be true if we're after the first iteration
if is_last.any():
traj_ids = torch.where(
is_last, last_traj_ids_subs[idx].expand_as(traj_ids), traj_ids
)

if preempt:
if cat_results == "stack":
mask_frames = buffer.get(("collector", "traj_ids")) != -1
traj_ids = torch.where(mask_frames, traj_ids, -1)
n_collected += mask_frames.sum().cpu()
last_traj_ids_subs[idx] = traj_ids[..., valid][..., -1:].clone()
else:
last_traj_ids_subs[idx] = traj_ids[..., -1:].clone()
n_collected += traj_ids.numel()
else:
last_traj_ids_subs[idx] = traj_ids[..., -1:].clone()
n_collected += traj_ids.numel()
traj_ids_list[idx] = traj_ids

traj_max = max(traj_max, traj_ids.max())

if same_device is None:
prev_device = None
@@ -2262,9 +2269,6 @@
self.out_buffer = stack(
[item.cpu() for item in buffers.values()], 0
)
self.out_buffer.set(
("collector", "traj_ids"), torch.stack(traj_ids_list), inplace=True
)
else:
if self._use_buffers is None:
torchrl_logger.warning(
@@ -2281,9 +2285,6 @@
self.out_buffer = torch.cat(
[item.cpu() for item in buffers.values()], cat_results
)
self.out_buffer.set_(
("collector", "traj_ids"), torch.cat(traj_ids_list, cat_results)
)
except RuntimeError as err:
if (
preempt
@@ -2792,6 +2793,8 @@ def _main_async_collector(
interruptor=None,
set_truncated: bool = False,
use_buffers: bool | None = None,
replay_buffer: ReplayBuffer | None = None,
traj_pool: _TrajectoryPool | None = None,
) -> None:
pipe_parent.close()
# init variables that will be cleared when closing
@@ -2816,6 +2819,8 @@
interruptor=interruptor,
set_truncated=set_truncated,
use_buffers=use_buffers,
replay_buffer=replay_buffer,
traj_pool=traj_pool,
)
use_buffers = inner_collector._use_buffers
if verbose:
@@ -2878,6 +2883,21 @@
# In that case, we skip the collected trajectory and get the message from main. This is faster than
# sending the trajectory in the queue until timeout when it's never going to be received.
continue

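# With a replay buffer, the collected frames are already in the shared
# storage; only the worker index and iteration counter travel through
# the queue.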
if replay_buffer is not None:
try:
queue_out.put((idx, j), timeout=_TIMEOUT)
if verbose:
torchrl_logger.info(f"worker {idx} successfully sent data")
j += 1
has_timed_out = False
continue
except queue.Full:
if verbose:
torchrl_logger.info(f"worker {idx} has timed out")
has_timed_out = True
continue

if j == 0 or not use_buffers:
collected_tensordict = next_data
if (
@@ -2986,3 +3006,22 @@ def _make_meta_params(param):
if is_param:
pd = nn.Parameter(pd, requires_grad=False)
return pd


class _TrajectoryPool:
def __init__(self, ctx=None, lock: bool = False):
self.ctx = ctx
if ctx is None:
self._traj_id = mp.Value("i", 0)
self.lock = contextlib.nullcontext() if not lock else mp.Lock()
else:
self._traj_id = ctx.Value("i", 0)
self.lock = contextlib.nullcontext() if not lock else ctx.Lock()

def get_traj_and_increment(self, n=1, device=None):
traj_id = []
with self.lock:
for i in range(n):
traj_id.append(int(self._traj_id.value))
self._traj_id.value += 1
return torch.as_tensor(traj_id, device=device)
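The pool is a process-shared counter: each call reserves the next `n` ids under the lock. A minimal sketch of the behavior, using only the names defined above:

```python
pool = _TrajectoryPool(lock=True)
print(pool.get_traj_and_increment(3))  # tensor([0, 1, 2])
print(pool.get_traj_and_increment(2))  # tensor([3, 4])
```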