Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 10, 2023
1 parent db918a7 commit 86ac8fa
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1764,7 +1764,7 @@ def map(
@cache # noqa: B019
def _add_batch_dim(self, *, in_dim, vmap_level):
if self.is_memmap():
td = self.cpu().as_tensor()
td = self.cpu()
else:
td = self
out = TensorDict(
Expand Down
14 changes: 7 additions & 7 deletions test/test_memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,15 +519,15 @@ def dummy_memmap():
@pytest.mark.parametrize("device", get_available_devices())
class TestOps:
def test_eq(self, device, dummy_memmap):
memmap = dummy_memmap.to(device)
assert (memmap == memmap.clone()).all()
assert (memmap.clone() == memmap).all()
dummy_memmap.device = device
assert (dummy_memmap == dummy_memmap.clone()).all()
assert (dummy_memmap.clone() == dummy_memmap).all()
if device.type == "cpu":
assert (memmap == memmap.as_tensor()).all()
assert (memmap.as_tensor() == memmap).all()
assert (dummy_memmap == dummy_memmap.as_tensor()).all()
assert (dummy_memmap.as_tensor() == dummy_memmap).all()
else:
assert (memmap == memmap._tensor).all()
assert (memmap._tensor == memmap).all()
assert (dummy_memmap == dummy_memmap._tensor).all()
assert (dummy_memmap._tensor == dummy_memmap).all()

def test_fill_(self, device, dummy_memmap):
memmap = dummy_memmap.to(device)
Expand Down

0 comments on commit 86ac8fa

Please sign in to comment.