Skip to content

Commit

Permalink
[Refactor] _set_str and _set_tuple (#480)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jul 7, 2023
1 parent 200302e commit 938b166
Show file tree
Hide file tree
Showing 5 changed files with 472 additions and 419 deletions.
107 changes: 45 additions & 62 deletions tensordict/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
TensorDictBase,
)
from tensordict.utils import (
_shape,
cache,
DeviceType,
expand_right,
Expand Down Expand Up @@ -659,7 +658,7 @@ def to(
f"instance, {dest} not allowed"
)

def _validate_value(self, value, check_shape=True):
def _to_numpy(self, value):
if hasattr(value, "requires_grad") and value.requires_grad:
raise RuntimeError("Cannot set a tensor that has requires_grad=True.")
if isinstance(value, torch.Tensor):
Expand All @@ -676,28 +675,24 @@ def _validate_value(self, value, check_shape=True):
raise NotImplementedError(
f"Cannot set values of type {value} in a PersistentTensorDict."
)
if check_shape and _shape(out)[: self.batch_dims] != self.batch_size:
# if TensorDict, let's try to map it to the desired shape
if is_tensor_collection(out):
out = out.clone(recurse=False)
out.batch_size = self.batch_size
else:
raise RuntimeError(
f"batch dimension mismatch, got self.batch_size"
f"={self.batch_size} and value.shape[:self.batch_dims]"
f"={_shape(out)[: self.batch_dims]} with value {out}"
)
return out

def _set(
self, key: str, value, inplace: bool = False, idx=None, check_shape=True
self,
key: str,
value,
inplace: bool = False,
idx=None,
validated=False,
) -> PersistentTensorDict:
# although it is expected that _set will run as few tests as possible,
# we must do the value transformation here as _set can be called by other
# methods from TensorDictBase.
value = self._validate_value(value, check_shape=check_shape)
if not inplace and idx is not None:
raise RuntimeError("Cannot pass an index to _set when inplace=False.")
if not validated:
value = self._validate_value(value, check_shape=idx is None)
value = self._to_numpy(value)
if not inplace:
if idx is not None:
raise RuntimeError("Cannot pass an index to _set when inplace=False.")
elif self.is_locked:
raise RuntimeError(self.LOCK_ERROR)
# shortcut set if we're placing a tensordict
if is_tensor_collection(value):
if isinstance(key, tuple):
Expand Down Expand Up @@ -735,7 +730,6 @@ def _set(
idx = ()
else:
idx = self._process_index(idx, array)

try:
array[idx] = value
except TypeError as err:
Expand All @@ -760,6 +754,7 @@ def _set(
raise err

else:
key = self._process_key(key)
try:
self.file.create_dataset(key, data=value, **self.kwargs)
except (ValueError, OSError) as err:
Expand All @@ -772,50 +767,38 @@ def _set(
self.file.create_dataset(key, data=value, **self.kwargs)
return self

def set(
self,
key: NestedKey,
value: dict[str, CompatibleType] | CompatibleType,
inplace: bool = False,
) -> PersistentTensorDict:

def _convert_inplace(self, inplace, key):
key = self._process_key(key)
if inplace is not False:
has_key = key in self.file
if inplace is True and not has_key: # inplace could be None
raise KeyError(
TensorDictBase.KEY_ERROR.format(
key, self.__class__.__name__, sorted(self.keys())
)
)
inplace = has_key
return inplace

def _set_str(self, key, value, *, inplace, validated):
inplace = self._convert_inplace(inplace, key)
return self._set(key, value, inplace=inplace, validated=validated)

def _set_tuple(self, key, value, *, inplace, validated):
if len(key) == 1:
return self._set_str(key[0], value, inplace=inplace, validated=validated)
elif key[0] in self.keys():
return self._get_str(key[0])._set_tuple(
key[1:], value, inplace=inplace, validated=validated
)
inplace = self._convert_inplace(inplace, key)
return self._set(key, value, inplace=inplace, validated=validated)

visitor = _Visitor()
self.file.visit(visitor)
inplace = inplace and key in visitor
if self.is_locked and not inplace:
raise RuntimeError(TensorDictBase.LOCK_ERROR)

# not calling set_ to avoid re-validate key
return self._set(key, value, inplace=inplace)

def set_(
self, key: str, value: dict[str, CompatibleType] | CompatibleType
) -> PersistentTensorDict:
visitor = _Visitor()
self.file.visit(visitor)
key = self._process_key(key)
if key not in visitor:
raise KeyError(f'key "{key}" not found in h5.')
# we don't need to check shape as the modification will be done
# in-place and an error will be thrown anyway if shapes don't match
return self._set(key, value, inplace=True, check_shape=False)
def _set_at_str(self, key, value, idx, *, validated):
return self._set(key, value, inplace=True, idx=idx, validated=validated)

def set_at_(
self,
key: NestedKey,
value: dict[str, CompatibleType] | CompatibleType,
idx: IndexType,
) -> PersistentTensorDict:
visitor = _Visitor()
self.file.visit(visitor)
key = self._process_key(key)
if key not in visitor:
raise KeyError(f'key "{key}" not found in h5.')
# we don't need to check shape as the modification will be done
# in-place and an error will be thrown anyway if shapes don't match
return self._set(key, value, inplace=True, idx=idx, check_shape=False)
def _set_at_tuple(self, key, value, idx, *, validated):
return self._set(key, value, inplace=True, idx=idx, validated=validated)

def _set_metadata(self, orig_metadata_container: PersistentTensorDict):
for key, td in orig_metadata_container._nested_tensordicts.items():
Expand Down
7 changes: 5 additions & 2 deletions tensordict/prototype/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from itertools import filterfalse, tee
from typing import Any, Callable, Iterable

from tensordict._tensordict import _unravel_key_to_tuple
from tensordict.nn import TensorDictModule, TensorDictSequential
from tensordict.tensordict import TensorDictBase
from tensordict.utils import NestedKey
Expand All @@ -26,7 +27,7 @@ def __init__(
out_keys: list[NestedKey],
) -> None:
super().__init__()
self.out_keys = out_keys
self.out_keys = [_unravel_key_to_tuple(ok) for ok in out_keys]
self._gm = graph_module

def forward(
Expand All @@ -42,7 +43,9 @@ def forward(

for out_key, output in zip(self.out_keys, outputs):
if out_key != "_":
tensordict_out._set(out_key, output)
tensordict_out._set_tuple(
out_key, output, inplace=False, validated=True
)

return tensordict_out

Expand Down
Loading

0 comments on commit 938b166

Please sign in to comment.