From 4b3b09e80584aba02b49ccb4055a2e562f2fd52d Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 26 Jul 2024 15:57:58 +0100
Subject: [PATCH] init

---
 torchrl/modules/tensordict_module/rnn.py | 84 ++++++++++++++++++++++++
 1 file changed, 84 insertions(+)

diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py
index f0290d6a42f..68af5f44972 100644
--- a/torchrl/modules/tensordict_module/rnn.py
+++ b/torchrl/modules/tensordict_module/rnn.py
@@ -381,6 +381,8 @@ class LSTMModule(ModuleBase):
     Methods:
         set_recurrent_mode: controls whether the module should be executed in
            recurrent mode.
+        make_tensordict_primer: creates the TensorDictPrimer transform for the environment to be aware of the
+            recurrent states of the RNN.
 
     .. note:: This module relies on specific ``recurrent_state`` keys being present in the input
        TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
@@ -521,6 +523,46 @@ def __init__(
         self._recurrent_mode = False
 
     def make_tensordict_primer(self):
+        """Makes a tensordict primer for the environment.
+
+        A :class:`~torchrl.envs.TensorDictPrimer` object will ensure that the environment is aware of the
+        supplementary inputs and outputs (recurrent states) during rollout execution. That way, the data can be
+        shared across processes and dealt with properly.
+
+        Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviours, for instance
+        in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root
+        tensordict, which the :meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states
+        are not registered within the environment specs.
+
+        Examples:
+            >>> from torchrl.collectors import SyncDataCollector
+            >>> from torchrl.envs import TransformedEnv, InitTracker
+            >>> from torchrl.envs import GymEnv
+            >>> from torchrl.modules import MLP, LSTMModule
+            >>> from torch import nn
+            >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
+            >>>
+            >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
+            >>> assert env.base_env.batch_locked
+            >>> lstm_module = LSTMModule(
+            ...     input_size=env.observation_spec["observation"].shape[-1],
+            ...     hidden_size=64,
+            ...     in_keys=["observation", "rs_h", "rs_c"],
+            ...     out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")])
+            >>> mlp = MLP(num_cells=[64], out_features=1)
+            >>> policy = Seq(lstm_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
+            >>> policy(env.reset())
+            >>> env = env.append_transform(lstm_module.make_tensordict_primer())
+            >>> data_collector = SyncDataCollector(
+            ...     env,
+            ...     policy,
+            ...     frames_per_batch=10
+            ... )
+            >>> for data in data_collector:
+            ...     print(data)
+            ...     break
+
+        """
         from torchrl.envs.transforms.transforms import TensorDictPrimer
 
         def make_tuple(key):
@@ -1065,6 +1107,8 @@ class GRUModule(ModuleBase):
     Methods:
         set_recurrent_mode: controls whether the module should be executed in
            recurrent mode.
+        make_tensordict_primer: creates the TensorDictPrimer transform for the environment to be aware of the
+            recurrent states of the RNN.
 
     .. note:: This module relies on specific ``recurrent_state`` keys being present in the input
        TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
@@ -1230,6 +1274,46 @@ def __init__(
         self._recurrent_mode = False
 
     def make_tensordict_primer(self):
+        """Makes a tensordict primer for the environment.
+
+        A :class:`~torchrl.envs.TensorDictPrimer` object will ensure that the environment is aware of the
+        supplementary inputs and outputs (recurrent states) during rollout execution. That way, the data can be
+        shared across processes and dealt with properly.
+
+        Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviours, for instance
+        in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root
+        tensordict, which the :meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states
+        are not registered within the environment specs.
+
+        Examples:
+            >>> from torchrl.collectors import SyncDataCollector
+            >>> from torchrl.envs import TransformedEnv, InitTracker
+            >>> from torchrl.envs import GymEnv
+            >>> from torchrl.modules import MLP, GRUModule
+            >>> from torch import nn
+            >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
+            >>>
+            >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
+            >>> assert env.base_env.batch_locked
+            >>> gru_module = GRUModule(
+            ...     input_size=env.observation_spec["observation"].shape[-1],
+            ...     hidden_size=64,
+            ...     in_keys=["observation", "rs"],
+            ...     out_keys=["intermediate", ("next", "rs")])
+            >>> mlp = MLP(num_cells=[64], out_features=1)
+            >>> policy = Seq(gru_module, Mod(mlp, in_keys=["intermediate"], out_keys=["action"]))
+            >>> policy(env.reset())
+            >>> env = env.append_transform(gru_module.make_tensordict_primer())
+            >>> data_collector = SyncDataCollector(
+            ...     env,
+            ...     policy,
+            ...     frames_per_batch=10
+            ... )
+            >>> for data in data_collector:
+            ...     print(data)
+            ...     break
+
+        """
        from torchrl.envs import TensorDictPrimer

        def make_tuple(key):
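
For context, the primer the patch builds can be approximated by hand. The sketch below (not part of the patch) registers the LSTM recurrent-state keys from the example above as environment specs; the shapes are an assumption based on a single-layer LSTM with hidden_size=64, and the spec class carried the name UnboundedContinuousTensorSpec in releases contemporary with this patch (later shortened to Unbounded). It reuses the `env` and key names from the LSTM docstring example.

    # Hypothetical, hand-rolled equivalent of lstm_module.make_tensordict_primer();
    # a sketch under the assumptions above, not the actual implementation.
    from torchrl.data import UnboundedContinuousTensorSpec
    from torchrl.envs import TensorDictPrimer

    primer = TensorDictPrimer(
        # One spec per recurrent state; shape is (num_layers, hidden_size).
        rs_h=UnboundedContinuousTensorSpec(shape=(1, 64)),
        rs_c=UnboundedContinuousTensorSpec(shape=(1, 64)),
    )
    # Registering the primer adds "rs_h"/"rs_c" to the env specs, so reset()
    # initializes them and step_mdp can copy them from "next" to the root.
    env = env.append_transform(primer)

Because the states are now part of the specs, batched and multiprocessed settings allocate and propagate them between steps, which is exactly the failure mode the docstring warns about when the primer is missing.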