[Feature] Propagate existsok in memmap* methods
ghstack-source-id: 6dcab0ff5e2ae2bb9b8d3bbf18cfb524c51d144d
Pull Request resolved: #990
vmoens committed Sep 16, 2024
1 parent fc323e5 commit 48d52d2
Showing 6 changed files with 59 additions and 17 deletions.
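For context, a minimal usage sketch of the propagated keyword at the TensorDict level (prefix, keys and shapes below are illustrative and not part of the commit): with the default ``existsok=True`` existing memory-mapped files under the prefix are reused/overwritten, while ``existsok=False`` is expected to raise.

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3), "b": {"c": torch.ones(3)}}, batch_size=[3])
td.memmap_(prefix="/tmp/td_example")  # first write: memmap files and metadata are created under the prefix

other = TensorDict({"a": torch.zeros(3), "b": {"c": torch.ones(3)}}, batch_size=[3])
other.memmap_(prefix="/tmp/td_example")  # existsok defaults to True: existing files are overwritten

third = TensorDict({"a": torch.zeros(3), "b": {"c": torch.ones(3)}}, batch_size=[3])
try:
    third.memmap_(prefix="/tmp/td_example", existsok=False)  # expected to raise: files already exist
except Exception as exc:  # the exact exception type is not specified by the commit
    print(f"refused to overwrite existing memmap files: {exc}")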
4 changes: 4 additions & 0 deletions tensordict/_lazy.py
@@ -2458,6 +2458,7 @@ def _memmap_(
inplace=True,
like=False,
share_non_tensor,
existsok,
) -> T:
if prefix is not None:
prefix = Path(prefix)
@@ -2489,6 +2490,7 @@ def save_metadata(prefix=prefix, self=self):
inplace=inplace,
like=like,
share_non_tensor=share_non_tensor,
existsok=existsok,
)
)
if not inplace:
@@ -3526,6 +3528,7 @@ def _memmap_(
inplace,
like,
share_non_tensor,
existsok,
) -> T:
def save_metadata(data: TensorDictBase, filepath, metadata=None):
if metadata is None:
@@ -3558,6 +3561,7 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None):
inplace=inplace,
like=like,
share_non_tensor=share_non_tensor,
existsok=existsok,
)
if not inplace:
dest = type(self)(
20 changes: 16 additions & 4 deletions tensordict/_td.py
@@ -2535,6 +2535,7 @@ def _memmap_(
inplace,
like,
share_non_tensor,
existsok,
) -> T:

if prefix is not None:
@@ -2569,6 +2570,7 @@ def _memmap_(
inplace=inplace,
like=like,
share_non_tensor=share_non_tensor,
existsok=existsok,
)
if prefix is not None:
_update_metadata(
@@ -2585,6 +2587,7 @@ def _memmap_(
copy_existing=copy_existing,
prefix=prefix,
like=like,
existsok=existsok,
)
else:
futures.append(
@@ -2596,6 +2599,7 @@ def _memmap_(
copy_existing=copy_existing,
prefix=prefix,
like=like,
existsok=existsok,
)
)
if prefix is not None:
@@ -2847,7 +2851,12 @@ def make_memmap_from_storage(
return memmap_tensor

def make_memmap_from_tensor(
self, key: NestedKey, tensor: torch.Tensor, *, copy_data: bool = True
self,
key: NestedKey,
tensor: torch.Tensor,
*,
copy_data: bool = True,
existsok: bool = True,
) -> MemoryMappedTensor:
if not self.is_memmap():
raise RuntimeError(
@@ -2876,6 +2885,7 @@ def make_memmap_from_tensor(
copy_existing=True,
prefix=last_node._memmap_prefix,
like=not copy_data,
existsok=existsok,
)
_update_metadata(
metadata=metadata,
@@ -3906,6 +3916,7 @@ def _memmap_(
inplace,
like,
share_non_tensor,
existsok,
) -> T:
if prefix is not None:

@@ -3936,6 +3947,7 @@ def save_metadata(prefix=prefix, self=self):
inplace=inplace,
like=like,
share_non_tensor=share_non_tensor,
existsok=existsok,
)
if not inplace:
result = _SubTensorDict(_source, idx=self.idx)
@@ -4404,7 +4416,7 @@ def _save_metadata(data: TensorDictBase, prefix: Path, metadata=None):


# user did specify location and memmap is in wrong place, so we copy
def _populate_memmap(*, dest, value, key, copy_existing, prefix, like):
def _populate_memmap(*, dest, value, key, copy_existing, prefix, like, existsok):
filename = None if prefix is None else str(prefix / f"{key}.memmap")
if value.is_nested:
shape = value._nested_tensor_size()
@@ -4416,7 +4428,7 @@ def _populate_memmap(*, dest, value, key, copy_existing, prefix, like):
shape,
filename=shape_filename,
copy_existing=copy_existing,
existsok=True,
existsok=existsok,
copy_data=True,
)
else:
@@ -4425,9 +4437,9 @@ def _populate_memmap(*, dest, value, key, copy_existing, prefix, like):
value.data if value.requires_grad else value,
filename=filename,
copy_existing=copy_existing,
existsok=True,
copy_data=not like,
shape=shape,
existsok=existsok,
)
dest._tensordict[key] = memmap_tensor
return memmap_tensor
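A hedged sketch of the ``make_memmap_from_tensor`` change above: the method now exposes ``existsok`` instead of hard-coding ``existsok=True`` when writing the new leaf. Key, path and shape below are made up for illustration.

import torch
from tensordict import TensorDict

td = TensorDict({"x": torch.zeros(4)}, batch_size=[4])
td.memmap_(prefix="/tmp/td_make_memmap")  # make_memmap_from_tensor requires an already memory-mapped tensordict

# Writes a new leaf under the memmap prefix; with existsok=False a pre-existing
# file for this key would be expected to raise instead of being overwritten.
leaf = td.make_memmap_from_tensor("y", torch.arange(4, dtype=torch.float32), existsok=True)
print(type(leaf))  # MemoryMappedTensor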
16 changes: 16 additions & 0 deletions tensordict/base.py
@@ -3211,6 +3211,7 @@ def _memmap_(
inplace,
like,
share_non_tensor,
existsok,
) -> T: ...

def densify(self, layout: torch.layout = torch.strided):
@@ -3743,6 +3744,7 @@ def memmap_(
num_threads: int = 0,
return_early: bool = False,
share_non_tensor: bool = False,
existsok: bool = True,
) -> T:
"""Writes all tensors onto a corresponding memory-mapped Tensor, in-place.
@@ -3767,6 +3769,8 @@ def memmap_(
on all other workers. If the number of non-tensor leaves is high (e.g.,
sharing large stacks of non-tensor data) this may result in OOM or similar
errors. Defaults to ``False``.
existsok (bool, optional): if ``False``, an exception will be raised if a tensor already
exists in the same path. Defaults to ``True``.
The TensorDict is then locked, meaning that any writing operations that
isn't in-place will throw an exception (eg, rename, set or remove an
@@ -3799,6 +3803,7 @@ def memmap_(
inplace=True,
like=False,
share_non_tensor=share_non_tensor,
existsok=existsok,
)
if not return_early:
concurrent.futures.wait(futures)
@@ -3813,6 +3818,7 @@ def memmap_(
executor=None,
like=False,
share_non_tensor=share_non_tensor,
existsok=existsok,
).lock_()

@abc.abstractmethod
@@ -3935,6 +3941,7 @@ def memmap(
num_threads: int = 0,
return_early: bool = False,
share_non_tensor: bool = False,
existsok: bool = True,
) -> T:
"""Writes all tensors onto a corresponding memory-mapped Tensor in a new tensordict.
@@ -3958,6 +3965,8 @@ def memmap(
on all other workers. If the number of non-tensor leaves is high (e.g.,
sharing large stacks of non-tensor data) this may result in OOM or similar
errors. Defaults to ``False``.
existsok (bool, optional): if ``False``, an exception will be raised if a tensor already
exists in the same path. Defaults to ``True``.
The TensorDict is then locked, meaning that any writing operations that
isn't in-place will throw an exception (eg, rename, set or remove an
@@ -3992,6 +4001,7 @@ def memmap(
inplace=False,
like=False,
share_non_tensor=share_non_tensor,
existsok=existsok,
)
if not return_early:
concurrent.futures.wait(futures)
@@ -4007,13 +4017,15 @@ def memmap(
like=False,
futures=None,
share_non_tensor=share_non_tensor,
existsok=existsok,
).lock_()

def memmap_like(
self,
prefix: str | None = None,
copy_existing: bool = False,
*,
existsok: bool = True,
num_threads: int = 0,
return_early: bool = False,
share_non_tensor: bool = False,
@@ -4040,6 +4052,8 @@ def memmap_like(
on all other workers. If the number of non-tensor leaves is high (e.g.,
sharing large stacks of non-tensor data) this may result in OOM or similar
errors. Defaults to ``False``.
existsok (bool, optional): if ``False``, an exception will be raised if a tensor already
exists in the same path. Defaults to ``True``.
The TensorDict is then locked, meaning that any writing operations that
isn't in-place will throw an exception (eg, rename, set or remove an
@@ -4089,6 +4103,7 @@ def memmap_like(
inplace=False,
like=True,
share_non_tensor=share_non_tensor,
existsok=existsok,
)
if not return_early:
concurrent.futures.wait(futures)
@@ -4106,6 +4121,7 @@ def memmap_like(
executor=None,
futures=None,
share_non_tensor=share_non_tensor,
existsok=existsok,
).lock_()

@classmethod
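The same keyword also reaches ``memmap_like``, whose docstring is updated above. A hedged sketch under an illustrative prefix:

import torch
from tensordict import TensorDict

source = TensorDict({"obs": torch.randn(8, 3)}, batch_size=[8])
empty_copy = source.memmap_like(prefix="/tmp/td_like")  # same structure on disk, data not copied
try:
    source.memmap_like(prefix="/tmp/td_like", existsok=False)  # expected to raise: prefix already populated
except Exception as exc:  # exact exception type not specified by the commit
    print(f"prefix already populated: {exc}")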
26 changes: 14 additions & 12 deletions tensordict/memmap.py
@@ -134,12 +134,12 @@ def from_tensor(
cls,
input,
*,
filename=None,
existsok=False,
copy_existing=False,
copy_data=True,
shape=None,
):
filename: Path | str = None,
existsok: bool = False,
copy_existing: bool = False,
copy_data: bool = True,
shape: torch.Size | None = None,
): # noqa: D417
"""Creates a MemoryMappedTensor with the same content as another tensor.
If the tensor is already a MemoryMappedTensor the original tensor is
@@ -149,6 +149,8 @@ def from_tensor(
Args:
input (torch.Tensor): the tensor which content must be copied onto
the MemoryMappedTensor.
Keyword Args:
filename (path to a file): the path to the file where the tensor
should be stored. If none is provided, a file handler is used
instead.
@@ -280,12 +282,12 @@ def from_storage(
cls,
storage,
*,
shape=None,
dtype=None,
device=None,
index=None,
filename=None,
handler=None,
shape: torch.Size | None = None,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
index: IndexType | None = None,
filename: Path | str = None,
handler: _handler = None,
):
if getattr(storage, "filename", None) is not None:
if filename is None:
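At the lowest level, the keyword that the tensordict methods now forward is the one already defined on ``MemoryMappedTensor.from_tensor`` (where it defaults to ``False``). A hedged sketch with an illustrative path:

import torch
from tensordict import MemoryMappedTensor

t = torch.randn(2, 3)
mm = MemoryMappedTensor.from_tensor(t, filename="/tmp/example.memmap")
# Re-using the same file requires opting in, since existsok defaults to False here.
mm_again = MemoryMappedTensor.from_tensor(t, filename="/tmp/example.memmap", existsok=True)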
4 changes: 3 additions & 1 deletion tensordict/persistent.py
@@ -701,6 +701,7 @@ def _memmap_(
inplace,
like,
share_non_tensor,
existsok,
) -> T:
if inplace:
raise RuntimeError("Cannot call memmap inplace in a persistent tensordict.")
@@ -749,6 +750,7 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None):
futures=futures,
inplace=inplace,
share_non_tensor=share_non_tensor,
existsok=existsok,
),
inplace=False,
validated=True,
@@ -776,7 +778,7 @@ def _populate(
),
copy_data=not like,
copy_existing=copy_existing,
existsok=True,
existsok=existsok,
)
tensordict._set_str(
key,
6 changes: 6 additions & 0 deletions tensordict/tensorclass.py
@@ -953,6 +953,7 @@ def _memmap_(
like=False,
memmaped: bool = False,
share_non_tensor: bool = False,
existsok: bool = True,
):
_non_tensordict = dict(self._non_tensordict)
cls = type(self)
@@ -997,6 +998,7 @@ def save_metadata(cls=cls, _non_tensordict=_non_tensordict, prefix=prefix):
like=like,
copy_existing=copy_existing,
share_non_tensor=share_non_tensor,
existsok=existsok,
)
if new_futures:
futures += new_futures
@@ -2816,6 +2818,7 @@ def _memmap_(
like=False,
memmaped: bool = False,
share_non_tensor: bool = False,
existsok: bool = True,
):
# For efficiency, we can avoid doing this saving
# if the data is already there.
@@ -2842,6 +2845,7 @@ def _memmap_(
like=like,
memmaped=memmaped,
share_non_tensor=share_non_tensor,
existsok=existsok,
)
_metadata["_share_non_tensor"] = share_non_tensor
out._non_tensordict["_metadata"] = _metadata
@@ -2967,6 +2971,7 @@ def _memmap_(
like=False,
memmaped: bool = False,
share_non_tensor: bool = False,
existsok: bool = True,
) -> T:

memmaped_leaves = memmaped
@@ -3013,6 +3018,7 @@ def save_metadata(prefix=prefix, self=self):
# no memmapping should be executed
memmaped=memmaped_leaves,
share_non_tensor=share_non_tensor,
existsok=existsok,
)
)
if not inplace:
