
[Feature] Consistent Dropout #2399

Merged — 12 commits merged into pytorch:main on Sep 10, 2024
Conversation

N00bcak (Contributor) commented Aug 15, 2024

Description

Introduces Consistent Dropout (Hausknecht & Wagener, 2022) to TorchRL.

Consists of the following changes:

  • ConsistentDropout PyTorch module
  • Companion TensorDictModule: ConsistentDropoutModule
  • Tests for ConsistentDropoutModule
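The core idea from Hausknecht & Wagener (2022) can be sketched in plain PyTorch. This is a hypothetical minimal sketch, not TorchRL's actual ConsistentDropout implementation: the class name `ConsistentDropoutSketch` and the `refresh_mask` method are invented for illustration. The point is that the dropout mask is sampled once and then reused on every forward pass until explicitly refreshed, so the policy behaves deterministically within a trajectory.

```python
import torch

class ConsistentDropoutSketch(torch.nn.Module):
    """Hypothetical sketch: dropout with a mask that persists across calls."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p
        self.mask = None

    def refresh_mask(self, shape):
        # Bernoulli keep-mask, scaled by 1/(1-p) so the expected
        # activation magnitude is unchanged (inverted dropout).
        keep = torch.bernoulli(torch.full(shape, 1.0 - self.p))
        self.mask = keep / (1.0 - self.p)

    def forward(self, x):
        # Sample the mask lazily on first use; reuse it afterwards.
        if self.mask is None:
            self.refresh_mask(x.shape)
        return x * self.mask
```

In an RL setting, `refresh_mask` would be invoked at episode reset, so the same mask applies to every step of a trajectory.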

Motivation and Context

Addresses a feature request and fleshes out PR draft #1587.

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

pytorch-bot commented Aug 15, 2024
🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2399

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 8 Unrelated Failures

As of commit 3231780 with merge base 6aa4b53:

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Aug 15, 2024
@vmoens changed the title from [New Feature] Consistent Dropout to [Feature] Consistent Dropout on Aug 30, 2024
@vmoens added the enhancement (New feature or request) label on Aug 30, 2024
vmoens (Contributor) commented Aug 30, 2024

I need to make a second pass over this but here's what I got running:

from torchrl.modules import ConsistentDropoutModule, get_primers_from_module
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
import torch
from torchrl.envs import GymEnv

m = Seq(
    Mod(torch.nn.Linear(3, 4), in_keys=["observation"], out_keys=["intermediate"]),
    ConsistentDropoutModule(p=0.5, input_shape=(4,), in_keys="intermediate"),
    Mod(torch.nn.Linear(4, 1), in_keys=["intermediate"], out_keys=["action"]),
)
primer = get_primers_from_module(m)
env = GymEnv("Pendulum-v1").append_transform(primer)
r = env.reset()
# The mask key name is auto-generated per module instance.
env.rollout(32, m)["next", "mask_13266468624"]

With this, the dropout mask is generated during reset, so it stays fixed for the rest of the rollout.

This works, but as you can see we need to change the size of the primer, which means the same primer can't be reused across envs with different batch sizes:

from torchrl.modules import ConsistentDropoutModule, get_primers_from_module
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
import torch
from torchrl.envs import GymEnv, SerialEnv, StepCounter

m = Seq(
    Mod(torch.nn.Linear(3, 4), in_keys=["observation"], out_keys=["intermediate"]),
    # The leading 2 in input_shape matches the batch size of the SerialEnv below.
    ConsistentDropoutModule(p=0.5, input_shape=(2, 4), in_keys="intermediate"),
    Mod(torch.nn.Linear(4, 1), in_keys=["intermediate"], out_keys=["action"]),
)
primer = get_primers_from_module(m)
env0 = GymEnv("Pendulum-v1").append_transform(StepCounter(5))
env1 = GymEnv("Pendulum-v1").append_transform(StepCounter(6))
env = SerialEnv(2, [lambda env=env0: env, lambda env=env1: env])
env = env.append_transform(primer)
r = env.reset()

There should be a way to solve this though!

@vmoens vmoens merged commit 0ad8e59 into pytorch:main Sep 10, 2024
63 of 69 checks passed