Commit

Update
[ghstack-poisoned]
vmoens committed Sep 29, 2024
2 parents 522111a + e8fbd71 commit 82cc9c4
Showing 4 changed files with 25 additions and 38 deletions.
7 changes: 4 additions & 3 deletions test/test_collector.py
@@ -3045,9 +3045,10 @@ def make_and_test_policy(

# If the policy is a CudaGraphModule, we know it's on cuda - no need to warn
if torch.cuda.is_available():
policy = make_policy(original_device)
cudagraph_policy = CudaGraphModule(policy)
make_and_test_policy(cudagraph_policy, policy_device=original_device)
with pytest.warns(UserWarning, match="Tensordict is registered in PyTree"):
policy = make_policy(original_device)
cudagraph_policy = CudaGraphModule(policy)
make_and_test_policy(cudagraph_policy, policy_device=original_device)


if __name__ == "__main__":
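
For context, a minimal standalone sketch of the `pytest.warns` pattern the updated test relies on (the policy construction here is a hypothetical stand-in, not TorchRL's actual `make_policy`/`CudaGraphModule`): the block fails unless a `UserWarning` matching the given pattern is emitted while it runs.

import warnings

import pytest

def build_policy():
    # stand-in for make_policy() + CudaGraphModule(...); wrapping the policy
    # is expected to warn that TensorDict is registered in PyTree
    warnings.warn("Tensordict is registered in PyTree", UserWarning)
    return lambda x: x

def test_policy_emits_pytree_warning():
    # pytest.warns fails the test if no matching UserWarning is raised inside
    with pytest.warns(UserWarning, match="Tensordict is registered in PyTree"):
        policy = build_policy()
        assert policy(1) == 1
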
21 changes: 3 additions & 18 deletions torchrl/collectors/collectors.py
@@ -169,13 +169,10 @@ def _get_policy_and_device(
env_maker=env_maker,
env_maker_kwargs=env_maker_kwargs,
)
if not policy_device:
if not policy_device or not isinstance(policy, nn.Module):
return policy, None

if isinstance(policy, nn.Module):
param_and_buf = TensorDict.from_module(policy, as_module=True)
else:
param_and_buf = TensorDict()
param_and_buf = TensorDict.from_module(policy, as_module=True)

i = -1
for p in param_and_buf.values(True, True):
@@ -192,24 +189,12 @@ def _get_policy_and_device(
"The collector will trust that the devices match. To suppress this "
"warning, set `trust_policy=True` when building the collector."
)
else:
# We checked and all params are on the appropriate device
pass
return policy, None

def get_weights_fn(param_and_buf=param_and_buf):
return param_and_buf.data

# create a stateless policy and populate it with params

# TODO: merge these two funcs
has_different_device = False
return policy, None

def map_weight(
weight,
policy_device=policy_device,
):
nonlocal has_different_device

is_param = isinstance(weight, nn.Parameter)
is_buffer = isinstance(weight, nn.Buffer)
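
For context, a rough sketch (assumed helper name, not the collector's exact code) of the simplified flow above: the early return now also covers policies that are not `nn.Module`s, and `TensorDict.from_module` is only called on genuine modules.

import torch
from torch import nn
from tensordict import TensorDict

def check_policy_device(policy, policy_device):
    # Nothing to inspect if no target device was requested, or if the policy
    # carries no parameters/buffers (i.e. it is not an nn.Module).
    if not policy_device or not isinstance(policy, nn.Module):
        return policy, None

    # Gather every parameter and buffer of the module as a TensorDict.
    param_and_buf = TensorDict.from_module(policy, as_module=True)

    # Flag whether any leaf already lives on a different device.
    mismatch = any(
        p.device != torch.device(policy_device)
        for p in param_and_buf.values(True, True)
    )
    if not mismatch:
        # Every parameter is already where it should be: keep the policy as-is.
        return policy, None
    # ... in the real collector, a cast/warning path would follow here.
    return policy, torch.device(policy_device)
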
23 changes: 10 additions & 13 deletions torchrl/data/replay_buffers/samplers.py
@@ -982,22 +982,20 @@ def _find_start_stop_traj(

# faster
end = trajectory[:-1] != trajectory[1:]
end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
if not at_capacity:
end = torch.cat([end, torch.ones_like(end[:1])], 0)
else:
end = torch.cat([end, trajectory[-1:] != trajectory[:1]], 0)
length = trajectory.shape[0]
else:
# TODO: check that storage is at capacity here, if not we need to assume that the last element of end is True

# We presume that not done at the end means that the traj spans across end and beginning of storage
length = end.shape[0]
if not at_capacity:
end = end.clone()
end[length - 1] = True
ndim = end.ndim

if not at_capacity:
end = torch.index_fill(
end,
index=torch.tensor(-1, device=end.device, dtype=torch.long),
dim=0,
value=1,
)
else:
if at_capacity:
# we must have at least one end by traj to individuate trajectories
# so if no end can be found we set it manually
if cursor is not None:
@@ -1019,7 +1017,6 @@ def _find_start_stop_traj(
mask = ~end.any(0, True)
mask = torch.cat([torch.zeros_like(end[:-1]), mask])
end = torch.masked_fill(mask, end, 1)
ndim = end.ndim
if ndim == 0:
raise RuntimeError(
"Expected the end-of-trajectory signal to be at least 1-dimensional."
@@ -1126,7 +1123,7 @@ def _get_stop_and_length(self, storage, fallback=True):
"Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories."
)
vals = self._find_start_stop_traj(
trajectory=trajectory,
trajectory=trajectory.clone(),
at_capacity=storage._is_full,
cursor=getattr(storage, "_last_cursor", None),
)
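
The `.clone()` added in `_get_stop_and_length` hands `_find_start_stop_traj` an independent copy of the trajectory tensor instead of the storage-backed original, which (presumably) guards against in-place work aliasing the stored data; a trivial sketch of why that matters:

import torch

storage_backed = torch.tensor([0, 0, 1, 1])  # imagine this lives in the buffer

view = storage_backed[:]          # shares memory with the storage
copy = storage_backed.clone()     # independent buffer

view[-1] = 99                     # mutates the storage through the view
copy[-1] = 99                     # leaves the storage untouched
print(storage_backed)             # tensor([ 0,  0,  1, 99])
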
12 changes: 8 additions & 4 deletions torchrl/data/replay_buffers/storages.py
Expand Up @@ -24,6 +24,7 @@
TensorDict,
TensorDictBase,
)
from tensordict.base import _NESTED_TENSORS_AS_LISTS
from tensordict.memmap import MemoryMappedTensor
from torch import multiprocessing as mp
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
@@ -901,9 +902,7 @@ def max_size_along_dim0(data_shape):

if is_tensor_collection(data):
out = data.to(self.device)
out = out.expand(max_size_along_dim0(data.shape))
out = out.clone()
out = out.zero_()
out = torch.empty_like(out.expand(max_size_along_dim0(data.shape)))
else:
# if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
out = tree_map(
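
For context, a small sketch of the allocation change above, with a plain tensor standing in for the tensordict leaves: `torch.empty_like` on the expanded sample allocates a buffer of the full storage shape without initializing it, whereas the previous `expand(...).clone().zero_()` also paid for a zero-fill pass that is unnecessary as long as slots are written before they are read back, which is how the lazy storage is used.

import torch

sample = torch.randn(4)   # a single entry, standing in for one tensordict leaf
max_size = 1000           # hypothetical storage capacity

# previous pattern: expand to capacity, materialize a real copy, then zero it
old = sample.expand(max_size, 4).clone().zero_()

# new pattern: allocate uninitialized storage of the same shape/dtype/device
new = torch.empty_like(sample.expand(max_size, 4))

assert old.shape == new.shape == (max_size, 4)
# `new` holds arbitrary values until the buffer writes into it, which is fine
# because only slots that have already been written are ever read back.
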
@@ -1120,7 +1119,12 @@ def max_size_along_dim0(data_shape):
out = out.memmap_like(prefix=self.scratch_dir, existsok=self.existsok)
if torchrl_logger.isEnabledFor(logging.DEBUG):
for key, tensor in sorted(
out.items(include_nested=True, leaves_only=True), key=str
out.items(
include_nested=True,
leaves_only=True,
is_leaf=_NESTED_TENSORS_AS_LISTS,
),
key=str,
):
try:
filesize = os.path.getsize(tensor.filename) / 1024 / 1024
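
And a hedged, self-contained example (made-up keys and sizes) of the debug-logging loop above: iterating a memory-mapped tensordict's leaves with `is_leaf=_NESTED_TENSORS_AS_LISTS` (the internal tensordict helper imported in this commit, which enumerates nested-tensor leaves individually) and reporting each backing file's size.

import os
import tempfile

import torch
from tensordict import TensorDict
from tensordict.base import _NESTED_TENSORS_AS_LISTS

td = TensorDict(
    {"obs": torch.zeros(100, 4), ("next", "obs"): torch.zeros(100, 4)},
    batch_size=[100],
)
# memory-map a tensordict with the same structure, as the lazy storage does
out = td.memmap_like(prefix=tempfile.mkdtemp())

for key, tensor in sorted(
    out.items(
        include_nested=True,
        leaves_only=True,
        is_leaf=_NESTED_TENSORS_AS_LISTS,
    ),
    key=str,
):
    filesize = os.path.getsize(tensor.filename) / 1024 / 1024
    print(f"{key}: {filesize:.4f} MB")
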
