[BugFix] Fix MARL-DDPG tutorial and other MODE usages #2373

Merged · 12 commits · Aug 6, 2024
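Summary of the changes below: the older `ExplorationType.MODE` / `InteractionType.MODE` spellings are replaced by `DETERMINISTIC` across the README, tutorials, SOTA scripts and tests; CI and docs pin `dm_control<1.0.21` together with `mujoco<3.2.1`; the docs references are corrected to `BatchRenorm1d` and `CrossQLoss`; and the MARL-DDPG tutorial now calls `step()` on the exploration module itself. A minimal sketch of the exploration-type usage after this change (illustrative, not taken from the diff):

```python
from torchrl.envs.utils import ExplorationType, set_exploration_type

# After this PR, deterministic evaluation code selects DETERMINISTIC rather
# than the older MODE member:
with set_exploration_type(ExplorationType.DETERMINISTIC):
    ...  # policy calls inside this block take the deterministic action
```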
3 changes: 2 additions & 1 deletion .github/unittest/linux/scripts/environment.yml
@@ -24,7 +24,8 @@ dependencies:
- tensorboard
- imageio==2.26.0
- wandb
- dm_control
- dm_control<1.0.21
- mujoco<3.2.1
- mlflow
- av
- coverage
2 changes: 1 addition & 1 deletion .github/unittest/linux/scripts/run_all.sh
@@ -91,7 +91,7 @@ echo "installing gymnasium"
pip3 install "gymnasium"
pip3 install ale_py
pip3 install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
pip3 install mujoco -U
pip3 install "mujoco<3.2.1" -U

# sanity check: remove?
python3 -c """
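A hedged companion to the truncated sanity check above: after installing with the pins, the installed versions can be checked directly (this snippet is illustrative and not part of `run_all.sh`; it assumes the `packaging` helper is available):

```python
# Verify the pins used above: dm_control < 1.0.21 and mujoco < 3.2.1.
from importlib.metadata import version
from packaging.version import Version

for pkg, bound in {"dm_control": "1.0.21", "mujoco": "3.2.1"}.items():
    installed = Version(version(pkg))
    print(pkg, installed)
    assert installed < Version(bound), f"{pkg} {installed} violates the <{bound} pin"
```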
3 changes: 2 additions & 1 deletion .github/unittest/linux_distributed/scripts/environment.yml
@@ -23,7 +23,8 @@ dependencies:
- tensorboard
- imageio==2.26.0
- wandb
- dm_control
- dm_control<1.0.21
- mujoco<3.2.1
- mlflow
- av
- coverage
3 changes: 2 additions & 1 deletion .github/unittest/linux_examples/scripts/environment.yml
@@ -21,7 +21,8 @@ dependencies:
- scipy
- hydra-core
- imageio==2.26.0
- dm_control
- dm_control<1.0.21
- mujoco<3.2.1
- mlflow
- av
- coverage
3 changes: 2 additions & 1 deletion .github/unittest/linux_libs/scripts_envpool/environment.yml
@@ -18,5 +18,6 @@ dependencies:
- expecttest
- pyyaml
- scipy
- dm_control
- dm_control<1.0.21
- mujoco<3.2.1
- coverage
@@ -22,6 +22,7 @@ dependencies:
- scipy
- hydra-core
- dm_control -e git+https://github.com/deepmind/dm_control.git@c053360edea6170acfd9c8f65446703307d9d352#egg={dm_control}
- mujoco<3.2.1
- patchelf
- pyopengl==3.1.4
- ray
4 changes: 2 additions & 2 deletions .github/workflows/benchmarks.yml
@@ -35,7 +35,7 @@ jobs:
python3 setup.py develop
python3 -m pip install pytest pytest-benchmark
python3 -m pip install "gym[accept-rom-license,atari]"
python3 -m pip install dm_control
python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
export TD_GET_DEFAULTS_TO_NONE=1
- name: Run benchmarks
run: |
@@ -97,7 +97,7 @@ jobs:
python3 setup.py develop
python3 -m pip install pytest pytest-benchmark
python3 -m pip install "gym[accept-rom-license,atari]"
python3 -m pip install dm_control
python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
export TD_GET_DEFAULTS_TO_NONE=1
- name: check GPU presence
run: |
4 changes: 2 additions & 2 deletions .github/workflows/benchmarks_pr.yml
@@ -34,7 +34,7 @@ jobs:
python3 setup.py develop
python3 -m pip install pytest pytest-benchmark
python3 -m pip install "gym[accept-rom-license,atari]"
python3 -m pip install dm_control
python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
export TD_GET_DEFAULTS_TO_NONE=1
- name: Setup benchmarks
run: |
@@ -108,7 +108,7 @@ jobs:
python3 setup.py develop
python3 -m pip install pytest pytest-benchmark
python3 -m pip install "gym[accept-rom-license,atari]"
python3 -m pip install dm_control
python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
export TD_GET_DEFAULTS_TO_NONE=1
- name: check GPU presence
run: |
2 changes: 1 addition & 1 deletion README.md
@@ -478,7 +478,7 @@ And it is `functorch` and `torch.compile` compatible!
policy_explore = EGreedyWrapper(policy)
with set_exploration_type(ExplorationType.RANDOM):
tensordict = policy_explore(tensordict) # will use eps-greedy
with set_exploration_type(ExplorationType.MODE):
with set_exploration_type(ExplorationType.DETERMINISTIC):
tensordict = policy_explore(tensordict) # will not use eps-greedy
```
</details>
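The same context manager governs stochastic policies: under `ExplorationType.DETERMINISTIC`, a `ProbabilisticActor` returns the distribution's deterministic sample instead of a random draw. A minimal, self-contained sketch (the toy network, keys and distribution are illustrative and not taken from the README):

```python
import torch
from tensordict import TensorDict
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal

# Toy stochastic policy: a linear layer produces loc/scale for a TanhNormal.
net = torch.nn.Sequential(torch.nn.Linear(3, 8), NormalParamExtractor())
module = TensorDictModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
actor = ProbabilisticActor(module, in_keys=["loc", "scale"], distribution_class=TanhNormal)

td = TensorDict({"observation": torch.randn(5, 3)}, batch_size=[5])
with set_exploration_type(ExplorationType.DETERMINISTIC):
    a1 = actor(td.clone())["action"]
    a2 = actor(td.clone())["action"]
torch.testing.assert_close(a1, a2)  # deterministic: identical actions

with set_exploration_type(ExplorationType.RANDOM):
    a3 = actor(td.clone())["action"]  # sampled: varies from call to call
```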
3 changes: 2 additions & 1 deletion docs/requirements.txt
@@ -14,7 +14,8 @@ docutils
sphinx_design

torchvision
dm_control
dm_control<1.0.21
mujoco<3.2.1
atari-py
ale-py
gym[classic_control,accept-rom-license]
2 changes: 1 addition & 1 deletion docs/source/reference/modules.rst
@@ -319,7 +319,7 @@ Regular modules
Conv3dNet
SqueezeLayer
Squeeze2dLayer
BatchRenorm
BatchRenorm1d

Algorithm-specific modules
~~~~~~~~~~~~~~~~~~~~~~~~~~
2 changes: 1 addition & 1 deletion docs/source/reference/objectives.rst
@@ -157,7 +157,7 @@ CrossQ
:toctree: generated/
:template: rl_template_noinherit.rst

CrossQ
CrossQLoss

IQL
----
2 changes: 1 addition & 1 deletion sota-implementations/crossq/crossq.py
@@ -203,7 +203,7 @@ def main(cfg: "DictConfig"): # noqa: F821

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
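The evaluation gate above relies on `collected_frames` growing in steps of `frames_per_batch`, so `abs(collected_frames % eval_iter) < frames_per_batch` fires on the first batch that reaches or passes each multiple of `eval_iter`. A toy check with made-up numbers (the real config values are not shown in this diff):

```python
# collected_frames advances by frames_per_batch each iteration; the condition
# therefore triggers roughly once every eval_iter frames.
frames_per_batch, eval_iter, total_frames = 1000, 5000, 20000
eval_points = [
    f
    for f in range(frames_per_batch, total_frames + 1, frames_per_batch)
    if abs(f % eval_iter) < frames_per_batch
]
print(eval_points)  # [5000, 10000, 15000, 20000]
```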
2 changes: 1 addition & 1 deletion sota-implementations/td3_bc/td3_bc.py
@@ -128,7 +128,7 @@ def main(cfg: "DictConfig"): # noqa: F821

# evaluation
if i % evaluation_interval == 0:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
7 changes: 5 additions & 2 deletions test/test_exploration.py
@@ -644,7 +644,7 @@ def test_no_spec_error(self, device):
@pytest.mark.parametrize("safe", [True, False])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize(
"exploration_type", [InteractionType.RANDOM, InteractionType.MODE]
"exploration_type", [InteractionType.RANDOM, InteractionType.DETERMINISTIC]
)
def test_gsde(
state_dim, action_dim, gSDE, device, safe, exploration_type, batch=16, bound=0.1
@@ -708,7 +708,10 @@ def test_gsde(
with set_exploration_type(exploration_type):
action1 = module(td).get("action")
action2 = actor(td.exclude("action")).get("action")
if gSDE or exploration_type == InteractionType.MODE:
if gSDE or exploration_type in (
InteractionType.DETERMINISTIC,
InteractionType.MODE,
):
torch.testing.assert_close(action1, action2)
else:
with pytest.raises(AssertionError):
2 changes: 1 addition & 1 deletion test/test_tensordictmodules.py
@@ -189,7 +189,7 @@ def test_stateful(self, safe, spec_type, lazy):
@pytest.mark.parametrize("out_keys", [["loc", "scale"], ["loc_1", "scale_1"]])
@pytest.mark.parametrize("lazy", [True, False])
@pytest.mark.parametrize(
"exp_mode", [InteractionType.MODE, InteractionType.RANDOM, None]
"exp_mode", [InteractionType.DETERMINISTIC, InteractionType.RANDOM, None]
)
def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys):
torch.manual_seed(0)
1 change: 1 addition & 0 deletions torchrl/modules/__init__.py
@@ -20,6 +20,7 @@
TruncatedNormal,
)
from .models import (
BatchRenorm1d,
Conv3dNet,
ConvNet,
DdpgCnnActor,
2 changes: 0 additions & 2 deletions torchrl/objectives/__init__.py
@@ -29,5 +29,3 @@
SoftUpdate,
ValueEstimators,
)

# from .value import bellman_max, c_val, dv_val, vtrace, GAE, TDLambdaEstimate, TDEstimate
2 changes: 1 addition & 1 deletion tutorials/sphinx-tutorials/coding_dqn.py
@@ -672,7 +672,7 @@ def get_loss_module(actor, gamma):
frame_skip=1,
policy_exploration=actor_explore,
environment=test_env,
exploration_type=ExplorationType.MODE,
exploration_type=ExplorationType.DETERMINISTIC,
log_keys=[("next", "reward")],
out_keys={("next", "reward"): "rewards"},
log_pbar=True,
2 changes: 1 addition & 1 deletion tutorials/sphinx-tutorials/dqn_with_rnn.py
@@ -440,7 +440,7 @@
exploration_module.step(data.numel())
updater.step()

with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
rollout = env.rollout(10000, stoch_policy)
traj_lens.append(rollout.get(("next", "step_count")).max().item())

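The `rollout.get(("next", "step_count"))` call above reads a nested key from the rollout; a tiny illustration of that access pattern on a hand-built TensorDict (shapes and values are made up):

```python
import torch
from tensordict import TensorDict

# A fake 10-step rollout carrying a nested ("next", "step_count") entry.
rollout = TensorDict(
    {"next": {"step_count": torch.arange(1, 11).view(10, 1)}}, batch_size=[10]
)
print(rollout.get(("next", "step_count")).max().item())  # 10
```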
4 changes: 2 additions & 2 deletions tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py
@@ -817,7 +817,7 @@ def process_batch(batch: TensorDictBase) -> TensorDictBase:
target_updaters[group].step()

# Exploration sigma anneal update
exploration_policies[group].step(current_frames)
exploration_policies[group][-1].step(current_frames)

# Stop training a certain group when a condition is met (e.g., number of training iterations)
if iteration == iteration_when_stop_training_evaders:
@@ -903,7 +903,7 @@ def process_batch(batch: TensorDictBase) -> TensorDictBase:
env_with_render = env_with_render.append_transform(
VideoRecorder(logger=video_logger, tag="vmas_rendered")
)
with set_exploration_type(ExplorationType.MODE):
with set_exploration_type(ExplorationType.DETERMINISTIC):
print("Rendering rollout...")
env_with_render.rollout(100, policy=agents_exploration_policy)
print("Saving the video...")
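The `[-1]` fix above reflects that each group's exploration policy in the tutorial is a `TensorDictSequential` whose last module is the exploration-noise module; the annealing counter lives on that module, so `step()` must be called on it rather than on the sequence. A minimal sketch of the structure, using an `EGreedyModule` for brevity (the tutorial itself uses an additive-Gaussian exploration module; the toy policy and spec below are illustrative):

```python
import torch
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import EGreedyModule

# Toy deterministic policy followed by an exploration module, mirroring the
# assumed structure: the exploration module is the last element of the sequence.
policy = TensorDictModule(
    torch.nn.Linear(3, 4), in_keys=["observation"], out_keys=["action"]
)
exploration = EGreedyModule(
    spec=OneHotDiscreteTensorSpec(4), eps_init=1.0, eps_end=0.1, annealing_num_steps=1_000
)
exploration_policy = TensorDictSequential(policy, exploration)

# The annealing state lives on the exploration module, so index the last
# element of the sequence before calling step():
exploration_policy[-1].step(100)
print(exploration_policy[-1].eps)  # eps has annealed from 1.0 towards 0.1
```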
2 changes: 1 addition & 1 deletion tutorials/sphinx-tutorials/torchrl_demo.py
@@ -652,7 +652,7 @@ def exec_sequence(params, data):
td_module(td)
print("random:", td["action"])

with set_exploration_type(ExplorationType.MODE):
with set_exploration_type(ExplorationType.DETERMINISTIC):
td_module(td)
print("mode:", td["action"])
