Skip to content

Commit

Permalink
[CI] Fix CI (#2245)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jun 24, 2024
1 parent 00b7c2e commit b7561b1
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 10 deletions.
1 change: 1 addition & 0 deletions .github/unittest/linux_libs/scripts_brax/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ dependencies:
- pyyaml
- scipy
- hydra-core
- jax[cuda12]
- brax
3 changes: 3 additions & 0 deletions .github/unittest/linux_libs/scripts_gym/batch_scripts.sh
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ do
echo "Testing gym version: ${GYM_VERSION}"
# handling https://github.com/openai/gym/issues/3202
pip3 install wheel==0.38.4
pip3 install "pip<24.1"
pip3 install gym==$GYM_VERSION
$DIR/run_test.sh

Expand All @@ -70,6 +71,7 @@ do

echo "Testing gym version: ${GYM_VERSION}"
pip3 install wheel==0.38.4
pip3 install "pip<24.1"
pip3 install 'gym[atari]'==$GYM_VERSION
pip3 install ale-py==0.7
$DIR/run_test.sh
Expand All @@ -88,6 +90,7 @@ do

echo "Testing gym version: ${GYM_VERSION}"
pip3 install 'gym[atari]'==$GYM_VERSION
pip3 install pip -U
$DIR/run_test.sh

# delete the conda copy
Expand Down
4 changes: 2 additions & 2 deletions .github/unittest/linux_libs/scripts_gym/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ git submodule sync && git submodule update --init --recursive

printf "Installing PyTorch with %s\n" "${CU_VERSION}"
if [ "${CU_VERSION:-}" == cpu ] ; then
conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch
conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch -y
else
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia -y
fi

# Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has
Expand Down
4 changes: 2 additions & 2 deletions .github/unittest/linux_olddeps/scripts_gym_0_13/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ git submodule sync && git submodule update --init --recursive

printf "Installing PyTorch with %s\n" "${CU_VERSION}"
if [ "${CU_VERSION:-}" == cpu ] ; then
conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch
conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch -y
else
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia -y
fi

# Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has
Expand Down
31 changes: 26 additions & 5 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2816,14 +2816,35 @@ def _minari_selected_datasets():

torch.manual_seed(0)

keys = list(minari.list_remote_datasets())
indices = torch.randperm(len(keys))[:20]
keys = [keys[idx] for idx in indices]
# We rely on sorting the keys as v0 < v1 but if the version is greater than 9 this won't work
total_keys = sorted(minari.list_remote_datasets())
assert not any(
key[-2:] == "10" for key in total_keys
), "You should adapt the Minari test scripts as some dataset have a version >= 10 and sorting will fail."
total_keys_splits = [key.split("-") for key in total_keys]
indices = torch.randperm(len(total_keys))[:20]
keys = [total_keys[idx] for idx in indices]
keys = [
key
for key in keys
if "=0.4" in minari.list_remote_datasets()[key]["minari_version"]
]

def _replace_with_max(key):
key_split = key.split("-")
same_entries = (
torch.tensor(
[total_key[:-1] == key_split[:-1] for total_key in total_keys_splits]
)
.nonzero()
.squeeze()
.tolist()
)
last_same_entry = same_entries[-1]
return total_keys[last_same_entry]

keys = [_replace_with_max(key) for key in keys]

assert len(keys) > 5, keys
_MINARI_DATASETS += keys

Expand Down Expand Up @@ -3669,10 +3690,10 @@ def test_collector(self, task, parallel):
seed=0,
use_mask=not parallel,
)
coll = SyncDataCollector(
collector = SyncDataCollector(
create_env_fn=env_fun, frames_per_batch=30, total_frames=60, policy=None
)
for _ in coll:
for _ in collector:
break


Expand Down
2 changes: 1 addition & 1 deletion torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ def __init__(
if policy is None:
from torchrl.collectors import RandomPolicy

policy = RandomPolicy(env.action_spec)
policy = RandomPolicy(env.full_action_spec)

##########################
# Setting devices:
Expand Down
1 change: 1 addition & 0 deletions torchrl/data/datasets/openx.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ def _download_and_preproc(self):
streaming=False,
split="train",
cache_dir=cache_dir,
trust_remote_code=True,
)
# iterate over the dataset a first time to count elements
total_frames = 0
Expand Down

0 comments on commit b7561b1

Please sign in to comment.