Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Sep 25, 2024
1 parent 1ea0601 commit 2df6393
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-wheels-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
post-script: "python packaging/wheel/relocate.py"
smoke-test-script: test/smoke_test.py
package-name: torchrl
name: pytorch/rl
name: ${{ matrix.repository }}
uses: pytorch/test-infra/.github/workflows/build_wheels_windows.yml@main
with:
repository: ${{ matrix.repository }}
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/wheels-legacy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
runs-on: windows-latest
strategy:
matrix:
python_version: [["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]]
python_version: [["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]]
steps:
- name: Setup Python
uses: actions/setup-python@v2
Expand All @@ -37,12 +37,12 @@ jobs:
python3 -mpip install wheel
TORCHRL_BUILD_VERSION=0.5.0 python3 setup.py bdist_wheel
- name: Upload wheel for the test-wheel job
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v3
with:
name: torchrl-win-${{ matrix.python_version[0] }}.whl
path: dist/torchrl-*.whl
- name: Upload wheel for download
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v3
with:
name: torchrl-batch.whl
path: dist/*.whl
Expand Down Expand Up @@ -77,7 +77,7 @@ jobs:
run: |
python3 -mpip install numpy pytest pytest-cov codecov unittest-xml-reporting pillow>=4.1.1 scipy av networkx expecttest pyyaml
- name: Download built wheels
uses: actions/download-artifact@v2
uses: actions/download-artifact@v3
with:
name: torchrl-win-${{ matrix.python_version }}.whl
path: wheels
Expand Down
15 changes: 15 additions & 0 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
assert_allclose_td,
is_tensor_collection,
is_tensorclass,
LazyStackedTensorDict,
tensorclass,
TensorDict,
TensorDictBase,
Expand Down Expand Up @@ -715,6 +716,20 @@ def test_storage_state_dict(self, storage_in, storage_out, init_out, backend):
s = new_replay_buffer.sample()
assert (s.exclude("index") == 1).all()

@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
def test_extend_lazystack(self, storage_type):

rb = ReplayBuffer(
storage=storage_type(6),
batch_size=2,
)
td1 = TensorDict(a=torch.rand(5, 4, 8), batch_size=5)
td2 = TensorDict(a=torch.rand(5, 3, 8), batch_size=5)
ltd = LazyStackedTensorDict(td1, td2, stack_dim=1)
rb.extend(ltd)
rb.sample(3)
assert len(rb) == 5

@pytest.mark.parametrize("device_data", get_default_devices())
@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
@pytest.mark.parametrize("data_type", ["tensor", "tc", "td", "pytree"])
Expand Down
23 changes: 13 additions & 10 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from __future__ import annotations

import abc

import logging
import os
import textwrap
import warnings
Expand Down Expand Up @@ -1116,16 +1118,17 @@ def max_size_along_dim0(data_shape):
out = data.clone().to(self.device)
out = out.expand(max_size_along_dim0(data.shape))
out = out.memmap_like(prefix=self.scratch_dir, existsok=self.existsok)
for key, tensor in sorted(
out.items(include_nested=True, leaves_only=True), key=str
):
try:
filesize = os.path.getsize(tensor.filename) / 1024 / 1024
torchrl_logger.debug(
f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})."
)
except (AttributeError, RuntimeError):
pass
if torchrl_logger.getEffectiveLevel() == logging.DEBUG:
for key, tensor in sorted(
out.items(include_nested=True, leaves_only=True), key=str
):
try:
filesize = os.path.getsize(tensor.filename) / 1024 / 1024
torchrl_logger.debug(
f"\t{key}: {tensor.filename}, {filesize} Mb of storage (size: {tensor.shape})."
)
except (AttributeError, RuntimeError):
pass
else:
out = _init_pytree(self.scratch_dir, max_size_along_dim0, data)
self._storage = out
Expand Down

0 comments on commit 2df6393

Please sign in to comment.