Skip to content

Commit

Permalink
[BugFix] Fix MARL-DDPG tutorial and other MODE usages (#2373)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Aug 6, 2024
1 parent 4348c84 commit a41da21
Show file tree
Hide file tree
Showing 13 changed files with 17 additions and 15 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ And it is `functorch` and `torch.compile` compatible!
policy_explore = EGreedyWrapper(policy)
with set_exploration_type(ExplorationType.RANDOM):
tensordict = policy_explore(tensordict) # will use eps-greedy
with set_exploration_type(ExplorationType.MODE):
with set_exploration_type(ExplorationType.DETERMINISTIC):
tensordict = policy_explore(tensordict) # will not use eps-greedy
```
</details>
Expand Down
2 changes: 1 addition & 1 deletion docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ Regular modules
Conv3dNet
SqueezeLayer
Squeeze2dLayer
BatchRenorm
BatchRenorm1d

Algorithm-specific modules
~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion docs/source/reference/objectives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ CrossQ
:toctree: generated/
:template: rl_template_noinherit.rst

CrossQ
CrossQLoss

IQL
----
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/crossq/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def main(cfg: "DictConfig"): # noqa: F821

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/td3_bc/td3_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def main(cfg: "DictConfig"): # noqa: F821

# evaluation
if i % evaluation_interval == 0:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
Expand Down
7 changes: 5 additions & 2 deletions test/test_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ def test_no_spec_error(self, device):
@pytest.mark.parametrize("safe", [True, False])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize(
"exploration_type", [InteractionType.RANDOM, InteractionType.MODE]
"exploration_type", [InteractionType.RANDOM, InteractionType.DETERMINISTIC]
)
def test_gsde(
state_dim, action_dim, gSDE, device, safe, exploration_type, batch=16, bound=0.1
Expand Down Expand Up @@ -708,7 +708,10 @@ def test_gsde(
with set_exploration_type(exploration_type):
action1 = module(td).get("action")
action2 = actor(td.exclude("action")).get("action")
if gSDE or exploration_type == InteractionType.MODE:
if gSDE or exploration_type in (
InteractionType.DETERMINISTIC,
InteractionType.MODE,
):
torch.testing.assert_close(action1, action2)
else:
with pytest.raises(AssertionError):
Expand Down
2 changes: 1 addition & 1 deletion test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def test_stateful(self, safe, spec_type, lazy):
@pytest.mark.parametrize("out_keys", [["loc", "scale"], ["loc_1", "scale_1"]])
@pytest.mark.parametrize("lazy", [True, False])
@pytest.mark.parametrize(
"exp_mode", [InteractionType.MODE, InteractionType.RANDOM, None]
"exp_mode", [InteractionType.DETERMINISTIC, InteractionType.RANDOM, None]
)
def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys):
torch.manual_seed(0)
Expand Down
1 change: 1 addition & 0 deletions torchrl/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
TruncatedNormal,
)
from .models import (
BatchRenorm1d,
Conv3dNet,
ConvNet,
DdpgCnnActor,
Expand Down
2 changes: 0 additions & 2 deletions torchrl/objectives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,3 @@
SoftUpdate,
ValueEstimators,
)

# from .value import bellman_max, c_val, dv_val, vtrace, GAE, TDLambdaEstimate, TDEstimate
2 changes: 1 addition & 1 deletion tutorials/sphinx-tutorials/coding_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def get_loss_module(actor, gamma):
frame_skip=1,
policy_exploration=actor_explore,
environment=test_env,
exploration_type=ExplorationType.MODE,
exploration_type=ExplorationType.DETERMINISTIC,
log_keys=[("next", "reward")],
out_keys={("next", "reward"): "rewards"},
log_pbar=True,
Expand Down
2 changes: 1 addition & 1 deletion tutorials/sphinx-tutorials/dqn_with_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@
exploration_module.step(data.numel())
updater.step()

with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
rollout = env.rollout(10000, stoch_policy)
traj_lens.append(rollout.get(("next", "step_count")).max().item())

Expand Down
4 changes: 2 additions & 2 deletions tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ def process_batch(batch: TensorDictBase) -> TensorDictBase:
target_updaters[group].step()

# Exploration sigma anneal update
exploration_policies[group].step(current_frames)
exploration_policies[group][-1].step(current_frames)

# Stop training a certain group when a condition is met (e.g., number of training iterations)
if iteration == iteration_when_stop_training_evaders:
Expand Down Expand Up @@ -903,7 +903,7 @@ def process_batch(batch: TensorDictBase) -> TensorDictBase:
env_with_render = env_with_render.append_transform(
VideoRecorder(logger=video_logger, tag="vmas_rendered")
)
with set_exploration_type(ExplorationType.MODE):
with set_exploration_type(ExplorationType.DETERMINISTIC):
print("Rendering rollout...")
env_with_render.rollout(100, policy=agents_exploration_policy)
print("Saving the video...")
Expand Down
2 changes: 1 addition & 1 deletion tutorials/sphinx-tutorials/torchrl_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ def exec_sequence(params, data):
td_module(td)
print("random:", td["action"])

with set_exploration_type(ExplorationType.MODE):
with set_exploration_type(ExplorationType.DETERMINISTIC):
td_module(td)
print("mode:", td["action"])

Expand Down

0 comments on commit a41da21

Please sign in to comment.