Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Aug 10, 2024
1 parent a6310ae commit 4499d50
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 15 deletions.
47 changes: 47 additions & 0 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2751,6 +2751,53 @@ def test_async(self, use_buffers):
del collector


class TestCollectorRB:
def test_collector_rb_sync(self):
env = SerialEnv(8, lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp))
env.set_seed(0)
rb = ReplayBuffer(storage=LazyTensorStorage(256, ndim=2), batch_size=5)
collector = SyncDataCollector(
env,
RandomPolicy(env.action_spec),
replay_buffer=rb,
total_frames=256,
frames_per_batch=16,
)
torch.manual_seed(0)

for c in collector:
assert c is None
rb.sample()
rbdata0 = rb[:].clone()
collector.shutdown()
if not env.is_closed:
env.close()
del collector, env

env = SerialEnv(8, lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp))
env.set_seed(0)
rb = ReplayBuffer(storage=LazyTensorStorage(256, ndim=2), batch_size=5)
collector = SyncDataCollector(
env, RandomPolicy(env.action_spec), total_frames=256, frames_per_batch=16
)
torch.manual_seed(0)

for i, c in enumerate(collector):
rb.extend(c)
torch.testing.assert_close(
rbdata0[:, : (i + 1) * 2]["observation"], rb[:]["observation"]
)
assert c is not None
rb.sample()

rbdata1 = rb[:].clone()
collector.shutdown()
if not env.is_closed:
env.close()
del collector, env
assert assert_allclose_td(rbdata0, rbdata1)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
56 changes: 43 additions & 13 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
VERBOSE,
)
from torchrl.collectors.utils import split_trajectories
from torchrl.data import ReplayBuffer
from torchrl.data.tensor_specs import TensorSpec
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs.common import _do_nothing, EnvBase
Expand Down Expand Up @@ -357,6 +358,8 @@ class SyncDataCollector(DataCollectorBase):
use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
This isn't compatible with environments with dynamic specs. Defaults to ``True``
for envs without dynamic specs, ``False`` for others.
replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordict
but populate the buffer instead. Defaults to ``None``.
Examples:
>>> from torchrl.envs.libs.gym import GymEnv
Expand Down Expand Up @@ -446,6 +449,7 @@ def __init__(
interruptor=None,
set_truncated: bool = False,
use_buffers: bool | None = None,
replay_buffer: ReplayBuffer | None = None,
):
from torchrl.envs.batched_envs import BatchedEnvBase

Expand Down Expand Up @@ -538,9 +542,17 @@ def __init__(

self.env: EnvBase = env
del env
self.replay_buffer = replay_buffer
if self.replay_buffer is not None:
if postproc is not None:
raise TypeError("postproc must be None when a replay buffer is passed.")
if use_buffers:
raise TypeError("replay_buffer is exclusive with use_buffers.")
if use_buffers is None:
use_buffers = not self.env._has_dynamic_specs
use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None
self._use_buffers = use_buffers
self.replay_buffer = replay_buffer

self.closed = False
if not reset_when_done:
raise ValueError("reset_when_done is deprectated.")
Expand Down Expand Up @@ -871,7 +883,15 @@ def set_seed(self, seed: int, static_seed: bool = False) -> int:
>>> out_seed = collector.set_seed(1) # out_seed = 6
"""
return self.env.set_seed(seed, static_seed=static_seed)
out = self.env.set_seed(seed, static_seed=static_seed)
return out

def _increment_frames(self, numel):
self._frames += numel
completed = self._frames >= self.total_frames
if completed:
self.env.close()
return completed

def iterator(self) -> Iterator[TensorDictBase]:
"""Iterates through the DataCollector.
Expand Down Expand Up @@ -917,14 +937,15 @@ def cuda_check(tensor: torch.Tensor):
for stream in streams:
stack.enter_context(torch.cuda.stream(stream))

total_frames = self.total_frames

while self._frames < self.total_frames:
self._iter += 1
tensordict_out = self.rollout()
self._frames += tensordict_out.numel()
if self._frames >= total_frames:
self.env.close()
if tensordict_out is None:
# if a replay buffer is passed, there is no tensordict_out
# frames are updated within the rollout function
yield
continue
self._increment_frames(tensordict_out.numel())

if self.split_trajs:
tensordict_out = split_trajectories(
Expand Down Expand Up @@ -1053,13 +1074,18 @@ def rollout(self) -> TensorDictBase:
next_data.clear_device_()
self._shuttle.set("next", next_data)

if self.storing_device is not None:
tensordicts.append(
self._shuttle.to(self.storing_device, non_blocking=True)
)
self._sync_storage()
if self.replay_buffer is not None:
self.replay_buffer.add(self._shuttle)
if self._increment_frames(self._shuttle.numel()):
return
else:
tensordicts.append(self._shuttle)
if self.storing_device is not None:
tensordicts.append(
self._shuttle.to(self.storing_device, non_blocking=True)
)
self._sync_storage()
else:
tensordicts.append(self._shuttle)

# carry over collector data without messing up devices
collector_data = self._shuttle.get("collector").copy()
Expand All @@ -1074,6 +1100,8 @@ def rollout(self) -> TensorDictBase:
self.interruptor is not None
and self.interruptor.collection_stopped()
):
if self.replay_buffer is not None:
return
result = self._final_rollout
if self._use_buffers:
try:
Expand Down Expand Up @@ -1109,6 +1137,8 @@ def rollout(self) -> TensorDictBase:
self._final_rollout.ndim - 1,
out=self._final_rollout,
)
elif self.replay_buffer is not None:
return
else:
result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
result.refine_names(..., "time")
Expand Down
6 changes: 4 additions & 2 deletions torchrl/envs/custom/pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ class PendulumEnv(EnvBase):
"render_fps": 30,
}
batch_locked = False
rng = None

def __init__(self, td_params=None, seed=None, device=None):
if td_params is None:
Expand All @@ -224,7 +225,7 @@ def __init__(self, td_params=None, seed=None, device=None):
super().__init__(device=device)
self._make_spec(td_params)
if seed is None:
seed = torch.empty((), dtype=torch.int64).random_().item()
seed = torch.empty((), dtype=torch.int64).random_(generator=self.rng).item()
self.set_seed(seed)

@classmethod
Expand Down Expand Up @@ -354,7 +355,8 @@ def make_composite_from_td(td):
return composite

def _set_seed(self, seed: int):
rng = torch.manual_seed(seed)
rng = torch.Generator()
rng.manual_seed(seed)
self.rng = rng

@staticmethod
Expand Down

0 comments on commit 4499d50

Please sign in to comment.