diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 6a40bf9880c..11a5bb041a6 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -790,10 +790,10 @@ to be able to create this other composition: BurnInTransform CatFrames CatTensors - Crop CenterCrop ClipTransform Compose + Crop DTypeCastTransform DeviceCastTransform DiscreteActionProjection diff --git a/test/test_transforms.py b/test/test_transforms.py index 6ab17f628ef..34c40ef3f1e 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -2178,18 +2178,8 @@ def test_transform_no_env(self, keys, h, nchannels, batch, device): assert observation_spec[key].shape == torch.Size([nchannels, 20, h]) @pytest.mark.parametrize("nchannels", [3]) - @pytest.mark.parametrize( - "batch", - [ - [2] - ] - ) - @pytest.mark.parametrize( - "h", - [ - None - ] - ) + @pytest.mark.parametrize("batch", [[2]]) + @pytest.mark.parametrize("h", [None]) @pytest.mark.parametrize("keys", [["observation_pixels"]]) @pytest.mark.parametrize("device", get_default_devices()) def test_transform_model(self, keys, h, nchannels, batch, device): @@ -2214,18 +2204,8 @@ def test_transform_model(self, keys, h, nchannels, batch, device): assert (td.get("dont touch") == dont_touch).all() @pytest.mark.parametrize("nchannels", [3]) - @pytest.mark.parametrize( - "batch", - [ - [2] - ] - ) - @pytest.mark.parametrize( - "h", - [ - None - ] - ) + @pytest.mark.parametrize("batch", [[2]]) + @pytest.mark.parametrize("h", [None]) @pytest.mark.parametrize("keys", [["observation_pixels"]]) @pytest.mark.parametrize("device", get_default_devices()) def test_transform_compose(self, keys, h, nchannels, batch, device): @@ -2254,18 +2234,8 @@ def test_transform_compose(self, keys, h, nchannels, batch, device): assert (tdc.get("dont touch") == dont_touch).all() @pytest.mark.parametrize("nchannels", [3]) - @pytest.mark.parametrize( - "batch", - [ - [2] - ] - ) - @pytest.mark.parametrize( - "h", - [ - None - ] - ) + @pytest.mark.parametrize("batch", [[2]]) + @pytest.mark.parametrize("h", [None]) @pytest.mark.parametrize("keys", [["observation_pixels"]]) @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) def test_transform_rb(