Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 12, 2023
1 parent f96c05c commit b25485e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 20 deletions.
9 changes: 2 additions & 7 deletions tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1812,13 +1812,8 @@ def as_tensor(self):
and will raise an exception in all other cases.
"""
try:
return self._fast_apply(lambda x: x.as_tensor())
except AttributeError as err:
raise AttributeError(
f"{self.__class__.__name__} does not have an 'as_tensor' method "
f"because at least one of its tensors does not support this method."
) from err
warnings.warn("as_tensor will soon be deprecated.", category=DeprecationWarning)
return self

def update(
self,
Expand Down
26 changes: 13 additions & 13 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2272,19 +2272,19 @@ def test_chunk(self, td_name, device, dim, chunks):
assert sum([_td.shape[dim] for _td in td_chunks]) == td.shape[dim]
assert (torch.cat(td_chunks, dim) == td).all()

def test_as_tensor(self, td_name, device):
td = getattr(self, td_name)(device)
if "memmap" in td_name and device == torch.device("cpu"):
tdt = td.as_tensor()
assert (tdt == td).all()
elif "memmap" in td_name:
with pytest.raises(
RuntimeError, match="can only be called with MemmapTensors stored"
):
td.as_tensor()
else:
with pytest.raises(AttributeError):
td.as_tensor()
# def test_as_tensor(self, td_name, device):
# td = getattr(self, td_name)(device)
# if "memmap" in td_name and device == torch.device("cpu"):
# tdt = td.as_tensor()
# assert (tdt == td).all()
# elif "memmap" in td_name:
# with pytest.raises(
# RuntimeError, match="can only be called with MemmapTensors stored"
# ):
# td.as_tensor()
# else:
# with pytest.raises(AttributeError):
# td.as_tensor()

def test_items_values_keys(self, td_name, device):
torch.manual_seed(1)
Expand Down

0 comments on commit b25485e

Please sign in to comment.