Merge remote-tracking branch 'origin/main' into gail
vmoens committed Jul 30, 2024
2 parents 8e7713f + bf91ff6 commit 714c35c
Showing 76 changed files with 3,166 additions and 660 deletions.
3 changes: 2 additions & 1 deletion .github/unittest/linux/scripts/environment.yml
@@ -28,6 +28,7 @@ dependencies:
- mlflow
- av
- coverage
-  - ray<2.8.0
+  - ray
- transformers
- ninja
- timm
3 changes: 2 additions & 1 deletion .github/unittest/linux/scripts/run_all.sh
@@ -88,7 +88,8 @@ conda deactivate
conda activate "${env_dir}"

echo "installing gymnasium"
pip3 install "gymnasium[atari,ale-py,accept-rom-license]"
pip3 install "gymnasium"
pip3 install ale_py
pip3 install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
pip3 install mujoco -U

2 changes: 1 addition & 1 deletion .github/unittest/linux_distributed/scripts/environment.yml
@@ -27,5 +27,5 @@ dependencies:
- mlflow
- av
- coverage
-  - ray<2.8.0
+  - ray
- virtualenv
@@ -24,5 +24,5 @@ dependencies:
- dm_control -e git+https://github.com/deepmind/dm_control.git@c053360edea6170acfd9c8f65446703307d9d352#egg={dm_control}
- patchelf
- pyopengl==3.1.4
-  - ray<2.8.0
+  - ray
- av
2 changes: 1 addition & 1 deletion .github/unittest/linux_optdeps/scripts/environment.yml
@@ -17,4 +17,4 @@ dependencies:
- pyyaml
- scipy
- coverage
-  - ray<2.8.0
+  - ray
1 change: 1 addition & 0 deletions .github/unittest/windows_optdepts/scripts/install.sh
@@ -37,6 +37,7 @@ fi

# submodules
git submodule sync && git submodule update --init --recursive
+python -m pip install "numpy<2.0"

printf "Installing PyTorch with %s\n" "${cudatoolkit}"
if [[ "$TORCH_VERSION" == "nightly" ]]; then
2 changes: 1 addition & 1 deletion .github/workflows/build-wheels-windows.yml
@@ -46,4 +46,4 @@ jobs:
package-name: ${{ matrix.package-name }}
smoke-test-script: ${{ matrix.smoke-test-script }}
trigger-event: ${{ github.event_name }}
-      env-var-script: .github/scripts/td_script.sh
+      pre-script: .github/scripts/td_script.sh
4 changes: 2 additions & 2 deletions .github/workflows/test-linux.yml
@@ -22,7 +22,7 @@ jobs:
tests-cpu:
strategy:
matrix:
python_version: ["3.8", "3.9", "3.10", "3.11"]
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
@@ -51,7 +51,7 @@ jobs:
tests-gpu:
strategy:
matrix:
python_version: ["3.10"]
python_version: ["3.11"]
cuda_arch_version: ["12.1"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
14 changes: 14 additions & 0 deletions docs/source/reference/envs.rst
@@ -335,6 +335,19 @@ etc.), but one can not use an arbitrary TorchRL environment, as it is possible w
ParallelEnv
EnvCreator


Custom native TorchRL environments
----------------------------------

TorchRL offers a series of custom built-in environments.

.. autosummary::
:toctree: generated/
:template: rl_template.rst

PendulumEnv
TicTacToeEnv

Multi-agent environments
------------------------

@@ -780,6 +793,7 @@ to be able to create this other composition:
CenterCrop
ClipTransform
Compose
Crop
DTypeCastTransform
DeviceCastTransform
DiscreteActionProjection
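For context, a minimal sketch of how the newly documented built-in environments might be used (assumptions, not taken from this diff: that `PendulumEnv` is exported from `torchrl.envs` and constructs with default arguments):

>>> from torchrl.envs import PendulumEnv
>>> from torchrl.envs.utils import check_env_specs
>>> env = PendulumEnv()
>>> check_env_specs(env)      # sanity-check the declared specs
>>> rollout = env.rollout(3)  # short random rollout; TicTacToeEnv would be exercised the same way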
11 changes: 9 additions & 2 deletions docs/source/reference/modules.rst
@@ -72,9 +72,11 @@ other cases, the action written in the tensordict is simply the network output.
:toctree: generated/
:template: rl_template_noinherit.rst

AdditiveGaussianModule
AdditiveGaussianWrapper
EGreedyModule
EGreedyWrapper
OrnsteinUhlenbeckProcessModule
OrnsteinUhlenbeckProcessWrapper

Probabilistic actors
@@ -353,13 +355,18 @@ algorithms, such as DQN, DDPG or Dreamer.
Multi-agent-specific modules
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-These networks implement models that can be used in
-multi-agent contexts.
+These networks implement models that can be used in multi-agent contexts.
+They use :func:`~torch.vmap` to execute multiple networks all at once on the
+network inputs. Because the parameters are batched, initialization may differ
+from what is usually done with other PyTorch modules, see
+:meth:`~torchrl.modules.MultiAgentNetBase.get_stateful_net`
+for more information.

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

MultiAgentNetBase
MultiAgentMLP
MultiAgentConvNet
QMixer
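To make the batched-parameter note above concrete, a rough sketch (the exact `MultiAgentMLP` keyword arguments and the no-argument `get_stateful_net()` call are assumptions, not taken from this diff):

>>> import torch
>>> from torchrl.modules import MultiAgentMLP
>>> mlp = MultiAgentMLP(n_agent_inputs=4, n_agent_outputs=2, n_agents=3,
...                     centralised=False, share_params=False, depth=2, num_cells=32)
>>> out = mlp(torch.randn(8, 3, 4))  # one vmap'd call over the 3 per-agent networks
>>> net = mlp.get_stateful_net()     # ordinary module view, initializable like any nn.Module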
29 changes: 29 additions & 0 deletions docs/source/reference/objectives.rst
@@ -29,6 +29,35 @@ The main characteristics of TorchRL losses are:

>>> loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))

.. note::
Initializing parameters in losses can be done via a query to :meth:`~torchrl.objectives.LossModule.get_stateful_net`
which will return a stateful version of the network that can be initialized like any other module.
    If the modification is done in-place, it will be propagated to any other module that uses the same parameter
set (within and outside of the loss): for instance, modifying the ``actor_network`` parameters from the loss
will also modify the actor in the collector.
If the parameters are modified out-of-place, :meth:`~torchrl.objectives.LossModule.from_stateful_net` can be
used to reset the parameters in the loss to the new value.

torch.vmap and randomness
-------------------------

TorchRL loss modules have plenty of calls to :func:`~torch.vmap` to amortize the cost of calling multiple similar models
in a loop, and instead vectorize these operations. `vmap` needs to be told explicitly what to do when random numbers
need to be generated within the call. To do this, a randomness mode needs to be set and must be one of `"error"` (default,
errors when dealing with pseudo-random functions), `"same"` (replicates the results across the batch) or `"different"`
(each element of the batch is treated separately).
Relying on the default will typically result in an error such as this one:

>>> RuntimeError: vmap: called random operation while in randomness error mode.

Since the calls to `vmap` are buried deep inside the loss modules, TorchRL
provides an interface to set the vmap randomness mode from the outside through `loss.vmap_randomness = str_value`; see
:meth:`~torchrl.objectives.LossModule.vmap_randomness` for more information.

``LossModule.vmap_randomness`` defaults to `"error"` if no random module is detected, and to `"different"` in
other cases. By default, only a limited number of modules are listed as random, but the list can be extended
using the :func:`~torchrl.objectives.common.add_random_module` function.

Training value functions
------------------------

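A hedged sketch of the two APIs the new documentation above describes (the `Actor`/`ValueOperator` construction and the network-name arguments to `get_stateful_net`/`from_stateful_net` are assumptions for illustration; the `loss.vmap_randomness = ...` assignment is as stated in the added text):

>>> import torch
>>> from torch import nn
>>> from torchrl.modules import Actor, ValueOperator
>>> from torchrl.objectives import DDPGLoss
>>> actor = Actor(nn.Linear(3, 1), in_keys=["observation"])
>>> value = ValueOperator(nn.Linear(4, 1), in_keys=["observation", "action"])
>>> loss = DDPGLoss(actor_network=actor, value_network=value)
>>> loss.vmap_randomness = "different"  # avoid the vmap randomness RuntimeError with stochastic layers
>>> stateful_actor = loss.get_stateful_net("actor_network")  # initialize like any other module
>>> loss.from_stateful_net("actor_network", stateful_actor)  # push out-of-place changes back into the loss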
1 change: 1 addition & 0 deletions setup.py
@@ -274,6 +274,7 @@ def _main(argv):
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Development Status :: 4 - Beta",
2 changes: 1 addition & 1 deletion sota-implementations/ddpg/ddpg.py
@@ -108,7 +108,7 @@ def main(cfg: "DictConfig"): # noqa: F821
for _, tensordict in enumerate(collector):
sampling_time = time.time() - sampling_start
# Update exploration policy
-        exploration_policy.step(tensordict.numel())
+        exploration_policy[1].step(tensordict.numel())

# Update weights of the inference policy
collector.update_policy_weights_()
30 changes: 19 additions & 11 deletions sota-implementations/ddpg/utils.py
@@ -6,6 +6,8 @@

import torch

+from tensordict.nn import TensorDictSequential

from torch import nn, optim
from torchrl.collectors import SyncDataCollector
from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
@@ -25,9 +27,9 @@
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
-    AdditiveGaussianWrapper,
+    AdditiveGaussianModule,
MLP,
-    OrnsteinUhlenbeckProcessWrapper,
+    OrnsteinUhlenbeckProcessModule,
SafeModule,
SafeSequential,
TanhModule,
@@ -227,18 +229,24 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):

# Exploration wrappers:
if cfg.network.noise_type == "ou":
-        actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
+        actor_model_explore = TensorDictSequential(
            model[0],
-            annealing_num_steps=1_000_000,
-        ).to(device)
+            OrnsteinUhlenbeckProcessModule(
+                spec=action_spec,
+                annealing_num_steps=1_000_000,
+            ).to(device),
+        )
    elif cfg.network.noise_type == "gaussian":
-        actor_model_explore = AdditiveGaussianWrapper(
+        actor_model_explore = TensorDictSequential(
            model[0],
-            sigma_end=1.0,
-            sigma_init=1.0,
-            mean=0.0,
-            std=0.1,
-        ).to(device)
+            AdditiveGaussianModule(
+                spec=action_spec,
+                sigma_end=1.0,
+                sigma_init=1.0,
+                mean=0.0,
+                std=0.1,
+            ).to(device),
+        )
else:
raise NotImplementedError

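The recurring change across these scripts is the move from the deprecated `*Wrapper` exploration classes to the `*Module` variants composed with `TensorDictSequential`. A minimal standalone sketch of that pattern (the `BoundedTensorSpec` and the toy linear policy are placeholders, not taken from the diff):

>>> import torch
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> from torchrl.data import BoundedTensorSpec
>>> from torchrl.modules import OrnsteinUhlenbeckProcessModule
>>> action_spec = BoundedTensorSpec(-1, 1, shape=(1,))
>>> policy = TensorDictModule(torch.nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> exploration_policy = TensorDictSequential(
...     policy,
...     OrnsteinUhlenbeckProcessModule(spec=action_spec, annealing_num_steps=1_000_000),
... )
>>> exploration_policy[1].step(100)  # anneal only the noise module, as ddpg.py now does above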
2 changes: 1 addition & 1 deletion sota-implementations/dreamer/dreamer.py
@@ -279,7 +279,7 @@ def compile_rssms(module):
if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)

-        policy.step(current_frames)
+        policy[1].step(current_frames)
collector.update_policy_weights_()
# Evaluation
if (i % eval_iter) == 0:
17 changes: 10 additions & 7 deletions sota-implementations/dreamer/dreamer_utils.py
@@ -52,7 +52,7 @@
)
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import (
-    AdditiveGaussianWrapper,
+    AdditiveGaussianModule,
DreamerActor,
IndependentNormal,
MLP,
@@ -266,13 +266,16 @@ def make_dreamer(
test_env=test_env,
)
# Exploration noise to be added to the actor_realworld
-    actor_realworld = AdditiveGaussianWrapper(
+    actor_realworld = TensorDictSequential(
        actor_realworld,
-        sigma_init=1.0,
-        sigma_end=1.0,
-        annealing_num_steps=1,
-        mean=0.0,
-        std=cfg.networks.exploration_noise,
+        AdditiveGaussianModule(
+            spec=test_env.action_spec,
+            sigma_init=1.0,
+            sigma_end=1.0,
+            annealing_num_steps=1,
+            mean=0.0,
+            std=cfg.networks.exploration_noise,
+        ),
)

# Make Critic
15 changes: 9 additions & 6 deletions sota-implementations/multiagent/maddpg_iddpg.py
@@ -7,7 +7,7 @@
import hydra
import torch

-from tensordict.nn import TensorDictModule
+from tensordict.nn import TensorDictModule, TensorDictSequential
from torch import nn
from torchrl._utils import logger as torchrl_logger
from torchrl.collectors import SyncDataCollector
@@ -18,7 +18,7 @@
from torchrl.envs.libs.vmas import VmasEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
-    AdditiveGaussianWrapper,
+    AdditiveGaussianModule,
ProbabilisticActor,
TanhDelta,
ValueOperator,
@@ -102,10 +102,13 @@ def train(cfg: "DictConfig"): # noqa: F821
return_log_prob=False,
)

-    policy_explore = AdditiveGaussianWrapper(
+    policy_explore = TensorDictSequential(
        policy,
-        annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
-        action_key=env.action_key,
+        AdditiveGaussianModule(
+            spec=env.unbatched_action_spec,
+            annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
+            action_key=env.action_key,
+        ),
)

# Critic
@@ -200,7 +203,7 @@ def train(cfg: "DictConfig"): # noqa: F821
optim.zero_grad()
target_net_updater.step()

-        policy_explore.step(frames=current_frames)  # Update exploration annealing
+        policy_explore[1].step(frames=current_frames)  # Update exploration annealing
collector.update_policy_weights_()

training_time = time.time() - training_start
16 changes: 10 additions & 6 deletions sota-implementations/redq/redq.py
@@ -8,10 +8,11 @@

import hydra
import torch.cuda
+from tensordict.nn import TensorDictSequential
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type
-from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
+from torchrl.modules import OrnsteinUhlenbeckProcessModule
from torchrl.record import VideoRecorder
from torchrl.record.loggers import get_logger
from utils import (
@@ -111,12 +112,15 @@ def main(cfg: "DictConfig"): # noqa: F821
if cfg.exploration.ou_exploration:
if cfg.exploration.gSDE:
raise RuntimeError("gSDE and ou_exploration are incompatible")
-        actor_model_explore = OrnsteinUhlenbeckProcessWrapper(
+        actor_model_explore = TensorDictSequential(
            actor_model_explore,
-            annealing_num_steps=cfg.exploration.annealing_frames,
-            sigma=cfg.exploration.ou_sigma,
-            theta=cfg.exploration.ou_theta,
-        ).to(device)
+            OrnsteinUhlenbeckProcessModule(
+                spec=actor_model_explore.spec,
+                annealing_num_steps=cfg.exploration.annealing_frames,
+                sigma=cfg.exploration.ou_sigma,
+                theta=cfg.exploration.ou_theta,
+            ).to(device),
+        )
if device == torch.device("cpu"):
# mostly for debugging
actor_model_explore.share_memory()
10 changes: 6 additions & 4 deletions sota-implementations/redq/utils.py
@@ -57,7 +57,7 @@
ActorCriticOperator,
ActorValueOperator,
NoisyLinear,
-    NormalParamWrapper,
+    NormalParamExtractor,
SafeModule,
SafeSequential,
)
@@ -483,10 +483,12 @@ def make_redq_model(
}

if not gSDE:
-        actor_net = NormalParamWrapper(
+        actor_net = nn.Sequential(
            actor_net,
-            scale_mapping=f"biased_softplus_{default_policy_scale}",
-            scale_lb=cfg.network.scale_lb,
+            NormalParamExtractor(
+                scale_mapping=f"biased_softplus_{default_policy_scale}",
+                scale_lb=cfg.network.scale_lb,
+            ),
)
actor_module = SafeModule(
actor_net,
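For context on the `NormalParamWrapper` to `NormalParamExtractor` swap above: the extractor is a plain module appended after the network that splits the last dimension into a location and a positively mapped scale. A small sketch, assuming the default biased-softplus scale mapping and the `tensordict.nn` import path:

>>> import torch
>>> from torch import nn
>>> from tensordict.nn import NormalParamExtractor
>>> net = nn.Sequential(nn.Linear(5, 8), NormalParamExtractor())
>>> loc, scale = net(torch.randn(2, 5))  # last dim 8 is split into two halves of 4
>>> loc.shape, scale.shape
(torch.Size([2, 4]), torch.Size([2, 4]))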
2 changes: 1 addition & 1 deletion sota-implementations/td3/td3.py
@@ -109,7 +109,7 @@ def main(cfg: "DictConfig"): # noqa: F821
sampling_start = time.time()
for tensordict in collector:
sampling_time = time.time() - sampling_start
-        exploration_policy.step(tensordict.numel())
+        exploration_policy[1].step(tensordict.numel())

# Update weights of the inference policy
collector.update_policy_weights_()