# Recovered from a whitespace-collapsed git patch:
#   commit 1d43d8b2bf16a0ca437afe7e31b043e2a4b94f23, author BY571,
#   Thu, 11 Jul 2024 17:56:29 +0200, subject "[PATCH] init".
# The patch appends the class below to torchrl/envs/transforms/transforms.py
# (hunk "@@ -8557,3 +8557,157 @@", directly after the body of ``_inv_call``).


class AbsorbingStateTransform(ObservationTransform):
    """Adds an absorbing state to the observation space.

    A transform that introduces an absorbing state to the environment. This absorbing state is typically used
    in reinforcement learning to handle terminal states by creating an additional state that signifies
    the end of an episode but allows for additional steps in the transition to better handle
    learning algorithms.

    Every transformed observation gains one trailing feature (the absorbing
    flag). Regular observations are padded with ``0``; the absorbing state is
    all zeros with the flag set to ``1``.

    The done signal is delayed by one call: on the call where the incoming
    ``done_key`` entry first contains a ``True`` value, the observation is
    replaced by the absorbing state while ``done_key`` / ``terminate_key`` are
    written as ``False``. On every following call (until the next reset) the
    absorbing state is emitted again and both flags are written as ``True``.

    Args:
        max_episode_length (int): Maximum length of an episode. Stored on the
            instance and shown in ``repr``; it is not otherwise read by this class.
        in_keys (Sequence[NestedKey], optional): Keys to use for input observation. Defaults to ``["observation"]``.
        out_keys (Sequence[NestedKey], optional): Keys to use for output observation. Defaults to ``in_keys``.
        done_key (Optional[NestedKey]): Key indicating if the episode is done. Defaults to ``"done"``.
        terminate_key (Optional[NestedKey]): Key indicating if the episode is terminated. Defaults to ``"terminated"``.

    Examples:
        >>> from torchrl.envs import GymEnv, TransformedEnv
        >>> t = AbsorbingStateTransform(max_episode_length=1000)
        >>> base_env = GymEnv("HalfCheetah-v4")
        >>> env = TransformedEnv(base_env, t)
    """

    def __init__(
        self,
        max_episode_length: int,
        in_keys: Sequence[NestedKey] | None = None,
        out_keys: Sequence[NestedKey] | None = None,
        done_key: Optional[NestedKey] = "done",
        terminate_key: Optional[NestedKey] = "terminated",
    ):
        if in_keys is None:
            # A list, not a bare string: the keys are iterated with ``zip``.
            in_keys = ["observation"]
        if out_keys is None:
            out_keys = copy(in_keys)
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self.max_episode_length = max_episode_length
        self.done_key = done_key
        self.terminate_key = terminate_key
        # Latched once a done signal has been seen; cleared by ``_reset``.
        self._done = False
        # Zeroed by ``_reset``; not read anywhere else in this class.
        self._curr_timestep = 0

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        raise RuntimeError(FORWARD_NOT_IMPLEMENTED.format(type(self)))

    @staticmethod
    def _append_flag(tensor: torch.Tensor, value: float) -> torch.Tensor:
        """Return ``tensor`` with one extra trailing feature filled with ``value``.

        The padding inherits ``tensor``'s dtype and device, and any number of
        leading (batch) dimensions is supported.

        Raises:
            ValueError: if ``tensor`` is 0-dimensional (no feature axis to extend).
        """
        if tensor.dim() == 0:
            raise ValueError(
                "Unsupported observation dimension: {}".format(tensor.dim())
            )
        flag = tensor.new_full((*tensor.shape[:-1], 1), value)
        return torch.cat((tensor, flag), dim=-1)

    def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
        """Pad ``observation`` with the absorbing flag.

        Returns the absorbing state ``[0, ..., 0, 1]`` (broadcast over the
        leading dimensions) when the transform is latched as done, otherwise
        the observation with a trailing ``0``.
        """
        if self._done:
            return self._append_flag(torch.zeros_like(observation), 1)
        return self._append_flag(observation, 0)

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        self._curr_timestep = 0
        self._done = False
        with _set_missing_tolerance(self, True):
            return self._call(tensordict_reset)

    def _flag_entry(self, tensordict: TensorDictBase, key):
        """Fetch a done-like entry, or ``None`` if it is absent and tolerated.

        Raises:
            KeyError: if ``key`` is missing and ``missing_tolerance`` is off.
        """
        if key is None:
            return None
        entry = tensordict.get(key, default=None)
        if entry is None and not self.missing_tolerance:
            raise KeyError(f"{self}: '{key}' not found in tensordict {tensordict}")
        return entry

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        if self.parent is None:
            raise RuntimeError(
                f"{type(self)}.parent cannot be None: make sure this transform is executed within an environment."
            )
        # Flags written at the end reflect the latch state *before* this call,
        # which is what delays the done signal by one step.
        was_done = bool(self._done)
        if not was_done:
            done = self._flag_entry(tensordict, self.done_key)
            # NOTE(review): a single latch is shared by the whole batch, so one
            # finished sub-environment switches every entry to the absorbing
            # state -- confirm this is intended for batched environments.
            self._done = done is not None and bool(done.any())

        # The latch is updated first on purpose: the observation of the
        # terminal step itself is already replaced by the absorbing state.
        for in_key, out_key in zip(self.in_keys, self.out_keys):
            value = tensordict.get(in_key, default=None)
            if value is not None:
                tensordict.set(out_key, self._apply_transform(value))
            elif not self.missing_tolerance:
                raise KeyError(
                    f"{self}: '{in_key}' not found in tensordict {tensordict}"
                )

        for key in (self.done_key, self.terminate_key):
            entry = self._flag_entry(tensordict, key)
            if entry is not None:
                tensordict.set(
                    key, torch.full_like(entry, was_done, dtype=torch.bool)
                )
        return tensordict

    @property
    def is_done(self) -> bool:
        """Whether a done signal has been latched since the last reset."""
        return bool(self._done)

    @_apply_to_composite
    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        """Extend the last dimension of the observation spec by one feature."""
        space = observation_spec.space

        if isinstance(space, ContinuousBox):
            # Bounds are built independently of the runtime done latch. They
            # must contain both a regular observation (flag 0) and the
            # absorbing state (all zeros, flag 1), hence the clamps to include
            # zero and the [0, 1] range for the flag feature.
            space.low = self._append_flag(space.low.clamp(max=0), 0)
            space.high = self._append_flag(space.high.clamp(min=0), 1)
            observation_spec.shape = space.low.shape
        else:
            shape = observation_spec.shape
            if len(shape) == 0:
                raise ValueError(
                    "Unsupported observation dimension: {}".format(len(shape))
                )
            observation_spec.shape = torch.Size((*shape[:-1], shape[-1] + 1))
        return observation_spec

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"max_episode_length={self.max_episode_length}, "
            f"keys={self.in_keys})"
        )