Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Jul 30, 2024
1 parent 90ad18b commit 07e4e1a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 37 deletions.
2 changes: 1 addition & 1 deletion docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -790,10 +790,10 @@ to be able to create this other composition:
BurnInTransform
CatFrames
CatTensors
Crop
CenterCrop
ClipTransform
Compose
Crop
DTypeCastTransform
DeviceCastTransform
DiscreteActionProjection
Expand Down
42 changes: 6 additions & 36 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2178,18 +2178,8 @@ def test_transform_no_env(self, keys, h, nchannels, batch, device):
assert observation_spec[key].shape == torch.Size([nchannels, 20, h])

@pytest.mark.parametrize("nchannels", [3])
@pytest.mark.parametrize(
"batch",
[
[2]
]
)
@pytest.mark.parametrize(
"h",
[
None
]
)
@pytest.mark.parametrize("batch", [[2]])
@pytest.mark.parametrize("h", [None])
@pytest.mark.parametrize("keys", [["observation_pixels"]])
@pytest.mark.parametrize("device", get_default_devices())
def test_transform_model(self, keys, h, nchannels, batch, device):
Expand All @@ -2214,18 +2204,8 @@ def test_transform_model(self, keys, h, nchannels, batch, device):
assert (td.get("dont touch") == dont_touch).all()

@pytest.mark.parametrize("nchannels", [3])
@pytest.mark.parametrize(
"batch",
[
[2]
]
)
@pytest.mark.parametrize(
"h",
[
None
]
)
@pytest.mark.parametrize("batch", [[2]])
@pytest.mark.parametrize("h", [None])
@pytest.mark.parametrize("keys", [["observation_pixels"]])
@pytest.mark.parametrize("device", get_default_devices())
def test_transform_compose(self, keys, h, nchannels, batch, device):
Expand Down Expand Up @@ -2254,18 +2234,8 @@ def test_transform_compose(self, keys, h, nchannels, batch, device):
assert (tdc.get("dont touch") == dont_touch).all()

@pytest.mark.parametrize("nchannels", [3])
@pytest.mark.parametrize(
"batch",
[
[2]
]
)
@pytest.mark.parametrize(
"h",
[
None
]
)
@pytest.mark.parametrize("batch", [[2]])
@pytest.mark.parametrize("h", [None])
@pytest.mark.parametrize("keys", [["observation_pixels"]])
@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
def test_transform_rb(
Expand Down

0 comments on commit 07e4e1a

Please sign in to comment.