diff --git a/.github/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml index 38d95eb335..8b512526a0 100644 --- a/.github/unittest/linux/scripts/environment.yml +++ b/.github/unittest/linux/scripts/environment.yml @@ -34,3 +34,5 @@ dependencies: - transformers - ninja - timm + - gymnasium[atari,accept-rom-license] + - mo-gymnasium[mujoco] diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh index 1026b31cf0..dec06bd2d8 100755 --- a/.github/unittest/linux/scripts/run_all.sh +++ b/.github/unittest/linux/scripts/run_all.sh @@ -87,11 +87,6 @@ conda env update --file "${this_dir}/environment.yml" --prune conda deactivate conda activate "${env_dir}" -echo "installing gymnasium" -pip3 install "gymnasium[atari,accept-rom-license]" -pip3 install mo-gymnasium[mujoco] # requires here bc needs mujoco-py -pip3 install "mujoco" -U - # sanity check: remove? python3 -c """ import dm_control diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py index c48d816888..d2ffba3068 100644 --- a/torchrl/modules/distributions/discrete.py +++ b/torchrl/modules/distributions/discrete.py @@ -389,6 +389,17 @@ def sample( ) -> torch.Tensor: ... + @property + def deterministic_sample(self): + return self.mode + + @property + def mode(self) -> torch.Tensor: + if hasattr(self, "logits"): + return (self.logits == self.logits.max(-1, True)[0]).to(torch.long) + else: + return (self.probs == self.probs.max(-1, True)[0]).to(torch.long) + def log_prob(self, value: torch.Tensor) -> torch.Tensor: return super().log_prob(value.argmax(dim=-1))