Commit

init
vmoens committed Jul 26, 2024
1 parent c6ef080 commit 4b3b09e
Showing 1 changed file with 84 additions and 0 deletions: torchrl/modules/tensordict_module/rnn.py
@@ -381,6 +381,8 @@ class LSTMModule(ModuleBase):
Methods:
set_recurrent_mode: controls whether the module should be executed in
recurrent mode.
make_tensordict_primer: creates a TensorDictPrimer transform that makes the
environment aware of the recurrent states of the RNN.
.. note:: This module relies on specific ``recurrent_state`` keys being present in the input
TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
@@ -521,6 +523,46 @@ def __init__(
self._recurrent_mode = False

def make_tensordict_primer(self):
"""Makes a tensordict primer for the environment.
A :class:`~torchrl.envs.TensorDictPrimer` object will ensure that the policy is aware of the supplementary
inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across
processes and dealt with properly.
Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviours, for instance
in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root
tensordict, which the meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states
are not registered within the environment specs.
Examples:
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP, LSTMModule
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>>
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> assert env.base_env.batch_locked
>>> lstm_module = LSTMModule(
... input_size=env.observation_spec["observation"].shape[-1],
... hidden_size=64,
... in_keys=["observation", "rs_h", "rs_c"],
... out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
>>> env = env.append_transform(lstm_module.make_tensordict_primer())
>>> data_collector = SyncDataCollector(
... env,
... policy,
... frames_per_batch=10
... )
>>> for data in data_collector:
... print(data)
... break
"""
from torchrl.envs.transforms.transforms import TensorDictPrimer

def make_tuple(key):
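The diff truncates the body of ``make_tensordict_primer`` here. As a rough sketch only, not the code of this commit: the returned transform plausibly registers one unbounded spec per recurrent state, shaped ``(num_layers, hidden_size)``; the spec class, the key indexing, and the helper name below are all assumptions.

# Hypothetical sketch of the primer construction; the spec class and key
# handling are assumptions, not the (truncated) body shown in this commit.
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs.transforms.transforms import TensorDictPrimer

def make_lstm_primer_sketch(lstm_module):
    def make_tuple(key):
        # Normalize keys to tuple form so nested keys compare cleanly.
        return key if isinstance(key, tuple) else (key,)

    # One spec per recurrent-state in_key (e.g. "rs_h" and "rs_c"), shaped
    # (num_layers, hidden_size) to match the states produced by nn.LSTM.
    # Slicing past the observation key at index 0 is an assumption.
    specs = {
        make_tuple(in_key): UnboundedContinuousTensorSpec(
            shape=(lstm_module.lstm.num_layers, lstm_module.lstm.hidden_size)
        )
        for in_key in lstm_module.in_keys[1:3]
    }
    return TensorDictPrimer(specs)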
@@ -1065,6 +1107,8 @@ class GRUModule(ModuleBase):
Methods:
set_recurrent_mode: controls whether the module should be executed in
recurrent mode.
make_tensordict_primer: creates a TensorDictPrimer transform that makes the
environment aware of the recurrent states of the RNN.
.. note:: This module relies on specific ``recurrent_state`` keys being present in the input
TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
@@ -1230,6 +1274,46 @@ def __init__(
self._recurrent_mode = False

def make_tensordict_primer(self):
"""Makes a tensordict primer for the environment.
A :class:`~torchrl.envs.TensorDictPrimer` object will ensure that the policy is aware of the supplementary
inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across
processes and dealt with properly.
Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviours, for instance
in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root
tensordict, which the meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states
are not registered within the environment specs.
Examples:
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP, LSTMModule
>>> from torch import nn
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>>
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
>>> assert env.base_env.batch_locked
>>> gru_module = GRUModule(
... input_size=env.observation_spec["observation"].shape[-1],
... hidden_size=64,
... in_keys=["observation", "rs"],
... out_keys=["intermediate", ("next", "rs")])
>>> mlp = MLP(num_cells=[64], out_features=1)
>>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
>>> policy(env.reset())
>>> env = env.append_transform(gru_module.make_tensordict_primer())
>>> data_collector = SyncDataCollector(
... env,
... policy,
... frames_per_batch=10
... )
>>> for data in data_collector:
... print(data)
... break
"""
from torchrl.envs import TensorDictPrimer

def make_tuple(key):
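As a complement to the docstring examples above, a minimal sketch of the primer's observable effect, assuming gymnasium's Pendulum-v1 is installed; key names follow the LSTM example.

# Minimal sketch: after appending the primer, the recurrent states become part
# of the env specs, so reset() yields zero-initialized "rs_h"/"rs_c" entries
# and step_mdp can carry the ("next", ...) states back to the root tensordict.
from torchrl.envs import GymEnv, InitTracker, TransformedEnv
from torchrl.modules import LSTMModule

env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
lstm_module = LSTMModule(
    input_size=env.observation_spec["observation"].shape[-1],
    hidden_size=64,
    in_keys=["observation", "rs_h", "rs_c"],
    out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")],
)
env = env.append_transform(lstm_module.make_tensordict_primer())

td = env.reset()
assert "rs_h" in td.keys() and "rs_c" in td.keys()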
