Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 14, 2024
1 parent 86d6406 commit 2ed9a70
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
2 changes: 1 addition & 1 deletion tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1975,7 +1975,7 @@ def memmap_like(
else:
return TensorDictFuture(futures, result)
input = self.apply(
lambda x: torch.empty((), device=x.device, dtype=x.dtype).expand(x.shape)
lambda x: torch.empty_like(x)
)
return input._memmap_(
prefix=prefix,
Expand Down
16 changes: 13 additions & 3 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@
torch.full_like,
torch.zeros_like,
torch.ones_like,
torch.empty_like,
torch.randn_like,
torch.rand_like,
torch.clone,
torch.squeeze,
torch.unsqueeze,
Expand Down Expand Up @@ -1246,6 +1249,13 @@ class NonTensorData:
# and all the overhead falls back on this class.
data: Any

@classmethod
def from_tensor(cls, value: torch.Tensor, batch_size, device=None, names=None):
"""A util to create a NonTensorData containing a tensor."""
out = cls(data=None, batch_size=batch_size, device=device, names=names)
out._non_tensordict["data"] = value
return out

def __post_init__(self):
if isinstance(self.data, NonTensorData):
self.data = self.data.data
Expand Down Expand Up @@ -1304,7 +1314,7 @@ def __or__(self, other):
self.__class__.__or__ = __or__

def empty(self, recurse=False):
return NonTensorData(
return type(self)(
data=self.data,
batch_size=self.batch_size,
names=self.names if self._has_names() else None,
Expand Down Expand Up @@ -1332,7 +1342,7 @@ def _check_equal(a, b):
if all(_check_equal(data.data, first.data) for data in list_of_non_tensor[1:]):
batch_size = list(first.batch_size)
batch_size.insert(dim, len(list_of_non_tensor))
return NonTensorData(
return type(self)(
data=first.data,
batch_size=batch_size,
names=first.names if first._has_names() else None,
Expand All @@ -1358,7 +1368,7 @@ def __torch_function__(
):
return NotImplemented

escape_conversion = func in (torch.stack,)
escape_conversion = func in (torch.stack, torch.ones_like, torch.zeros_like, torch.empty_like, torch.randn_like, torch.rand_like)

if kwargs is None:
kwargs = {}
Expand Down

0 comments on commit 2ed9a70

Please sign in to comment.