Pin numpy to 1.26.4 in the CI PyTorch install commands, import
is_dynamo_compiling with a fallback to torch._dynamo.is_compiling in
torchrl/objectives/common.py, and create the clip_epsilon buffer on the device
of the first parameter (no explicit device when there are no parameters) in
torchrl/objectives/ppo.py.

The numpy pin is passed to pip as the single requirement numpy==1.26.4. An
unquoted "<" on a shell command line is an input redirection, so a version
range such as "<2.0" must never be written bare in these install commands.

diff --git a/.github/unittest/linux_examples/scripts/run_all.sh b/.github/unittest/linux_examples/scripts/run_all.sh
index 37719e51074..18e6075baae 100755
--- a/.github/unittest/linux_examples/scripts/run_all.sh
+++ b/.github/unittest/linux_examples/scripts/run_all.sh
@@ -150,15 +150,15 @@ git submodule sync && git submodule update --init --recursive
 printf "Installing PyTorch with %s\n" "${CU_VERSION}"
 if [[ "$TORCH_VERSION" == "nightly" ]]; then
   if [ "${CU_VERSION:-}" == cpu ] ; then
-    pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
+    pip3 install --pre torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cpu -U
   else
-    pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
+    pip3 install --pre torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
   fi
 elif [[ "$TORCH_VERSION" == "stable" ]]; then
   if [ "${CU_VERSION:-}" == cpu ] ; then
-    pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
+    pip3 install torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/cpu
   else
-    pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION
+    pip3 install torch torchvision numpy==1.26.4 --index-url https://download.pytorch.org/whl/$CU_VERSION
   fi
 else
   printf "Failed to install pytorch"
diff --git a/.github/unittest/linux_libs/scripts_rlhf/install.sh b/.github/unittest/linux_libs/scripts_rlhf/install.sh
index d0363186c1a..4c769ba9bd6 100755
--- a/.github/unittest/linux_libs/scripts_rlhf/install.sh
+++ b/.github/unittest/linux_libs/scripts_rlhf/install.sh
@@ -31,15 +31,15 @@ git submodule sync && git submodule update --init --recursive
 printf "Installing PyTorch with cu121"
 if [[ "$TORCH_VERSION" == "nightly" ]]; then
   if [ "${CU_VERSION:-}" == cpu ] ; then
-    pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
+    pip3 install --pre torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cpu -U
   else
-    pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U
+    pip3 install --pre torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cu121 -U
   fi
 elif [[ "$TORCH_VERSION" == "stable" ]]; then
   if [ "${CU_VERSION:-}" == cpu ] ; then
-    pip3 install torch --index-url https://download.pytorch.org/whl/cpu
+    pip3 install torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/cpu
   else
-    pip3 install torch --index-url https://download.pytorch.org/whl/cu121
+    pip3 install torch numpy==1.26.4 --index-url https://download.pytorch.org/whl/cu121
   fi
 else
   printf "Failed to install pytorch"
diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py
index 766e674eb77..27b2b86b7cf 100644
--- a/torchrl/objectives/common.py
+++ b/torchrl/objectives/common.py
@@ -12,7 +12,6 @@
 from dataclasses import dataclass
 from typing import Iterator, List, Optional, Tuple
 
-import torch.compiler
 from tensordict import is_tensor_collection, TensorDict, TensorDictBase
 
 from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams
@@ -25,12 +24,17 @@
 from torchrl.objectives.utils import RANDOM_MODULE_LIST, ValueEstimators
 from torchrl.objectives.value import ValueEstimatorBase
 
+try:
+    from torch.compiler import is_dynamo_compiling
+except ImportError:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
+
 
 def _updater_check_forward_prehook(module, *args, **kwargs):
     if (
         not all(module._has_update_associated.values())
         and RL_WARNINGS
-        and not torch.compiler.is_dynamo_compiling()
+        and not is_dynamo_compiling()
     ):
         warnings.warn(
             module.TARGET_NET_WARNING,
@@ -425,7 +429,7 @@ def __getattr__(self, item):
             elif (
                 not self._has_update_associated[item[7:-7]]
                 and RL_WARNINGS
-                and not torch.compiler.is_dynamo_compiling()
+                and not is_dynamo_compiling()
             ):
                 # no updater associated
                 warnings.warn(
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index 5758b1ed7d8..83da183148e 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -804,7 +804,12 @@ def __init__(
             clip_value=clip_value,
             **kwargs,
         )
-        self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon))
+        for p in self.parameters():
+            device = p.device
+            break
+        else:
+            device = None
+        self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon, device=device))
 
     @property
     def _clip_bounds(self):