Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jun 25, 2024
1 parent 66f86cf commit a0d3d20
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
4 changes: 3 additions & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11059,8 +11059,10 @@ def test_transform_compose(self):
],
)
def test_transform_env(self, envname, interval_as_tensor, categorical, sampling):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_env = GymEnv(
HALFCHEETAH_VERSIONED() if envname == "cheetah" else PENDULUM_VERSIONED()
HALFCHEETAH_VERSIONED() if envname == "cheetah" else PENDULUM_VERSIONED(),
device=device,
)
if interval_as_tensor:
num_intervals = torch.arange(5, 11 if envname == "cheetah" else 6)
Expand Down
7 changes: 5 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8283,7 +8283,10 @@ def __init__(
super().__init__(in_keys_inv=[action_key], out_keys_inv=[out_action_key])
self.action_key = action_key
self.out_action_key = out_action_key
self.num_intervals = num_intervals
if not isinstance(num_intervals, torch.Tensor):
self.num_intervals = num_intervals
else:
self.register_buffer("num_intervals", num_intervals)
if sampling is None:
sampling = self.SamplingStrategy.MEDIAN
self.sampling = sampling
Expand Down Expand Up @@ -8397,7 +8400,7 @@ def custom_arange(nint):
del input_spec["full_action_spec", self.in_keys_inv[0]]
return input_spec
except Exception as err:
print("here!")
# To avoid silent AttributeErrors
raise RuntimeError(str(err))

def _init(self):
Expand Down

0 comments on commit a0d3d20

Please sign in to comment.