Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Algorithm] GAIL #2273

Merged
merged 35 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
6d7c5c4
fix norm
BY571 Jul 4, 2024
2d31f33
update docs
BY571 Jul 5, 2024
79bda13
update comments
BY571 Jul 5, 2024
1391b50
add sota-example-test
BY571 Jul 5, 2024
f444b72
update collection data slice
BY571 Jul 8, 2024
244b7ab
update docstring
BY571 Jul 8, 2024
db635d7
update config and objective with gp param
BY571 Jul 9, 2024
434622c
init cost tests gail
BY571 Jul 9, 2024
baca70f
update cost test
BY571 Jul 9, 2024
956567f
Merge branch 'main' into gail
BY571 Jul 11, 2024
8e7713f
add gail cost tests
BY571 Jul 11, 2024
714c35c
Merge remote-tracking branch 'origin/main' into gail
vmoens Jul 30, 2024
a05bef3
Merge branch 'main' into gail
BY571 Jul 31, 2024
b31da8a
Update config
BY571 Jul 31, 2024
63885b0
update gail device
BY571 Jul 31, 2024
739332c
update example tests
BY571 Jul 31, 2024
7a9919c
Merge branch 'gail' of github.com:BY571/rl into gail
BY571 Jul 31, 2024
3fd3c32
Merge remote-tracking branch 'origin/main' into gail
vmoens Aug 2, 2024
9455fef
gymnasium backend
BY571 Aug 5, 2024
fba43d2
Merge branch 'gail' of https://github.com/BY571/rl into gail
vmoens Aug 5, 2024
5e41d89
Merge remote-tracking branch 'origin/main' into gail
vmoens Aug 5, 2024
6c3f7d2
Merge remote-tracking branch 'origin/main' into gail
vmoens Aug 5, 2024
415443b
Merge remote-tracking branch 'origin/main' into gail
vmoens Aug 6, 2024
4926d80
fixes
vmoens Aug 6, 2024
cbd5dfa
init
vmoens Aug 6, 2024
70e1f49
Merge branch 'pin-mujoco' into gail
vmoens Aug 6, 2024
b8ca705
amend
vmoens Aug 6, 2024
6a00bda
Merge branch 'pin-mujoco' into gail
vmoens Aug 6, 2024
f0c225f
amend
vmoens Aug 6, 2024
511fa95
amend
vmoens Aug 6, 2024
4bc316b
amend
vmoens Aug 6, 2024
2f7e64c
Merge branch 'pin-mujoco' into gail
vmoens Aug 6, 2024
3d43e42
amend
vmoens Aug 6, 2024
63398d1
Merge branch 'pin-mujoco' into gail
vmoens Aug 6, 2024
c488bcd
amend
vmoens Aug 6, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/unittest/linux/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ dependencies:
- tensorboard
- imageio==2.26.0
- wandb
- dm_control
- dm_control<1.0.21
- mujoco<3.2.1
- mlflow
- av
- coverage
Expand Down
2 changes: 1 addition & 1 deletion .github/unittest/linux/scripts/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ echo "installing gymnasium"
pip3 install "gymnasium"
pip3 install ale_py
pip3 install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
pip3 install mujoco -U
pip3 install "mujoco<3.2.1" -U

# sanity check: remove?
python3 -c """
Expand Down
3 changes: 2 additions & 1 deletion .github/unittest/linux_distributed/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ dependencies:
- tensorboard
- imageio==2.26.0
- wandb
- dm_control
- dm_control<1.0.21
- mujoco<3.2.1
- mlflow
- av
- coverage
Expand Down
3 changes: 2 additions & 1 deletion .github/unittest/linux_examples/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ dependencies:
- scipy
- hydra-core
- imageio==2.26.0
- dm_control
- dm_control<1.0.21
- mujoco<3.2.1
- mlflow
- av
- coverage
Expand Down
7 changes: 7 additions & 0 deletions .github/unittest/linux_examples/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,13 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq
env.train_num_envs=2 \
logger.mode=offline \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/gail/gail.py \
ppo.collector.total_frames=48 \
replay_buffer.batch_size=16 \
ppo.loss.mini_batch_size=10 \
ppo.collector.frames_per_batch=16 \
logger.mode=offline \
logger.backend=

# With single envs
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
Expand Down
3 changes: 2 additions & 1 deletion .github/unittest/linux_libs/scripts_envpool/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ dependencies:
- expecttest
- pyyaml
- scipy
- dm_control
- dm_control<1.0.21
- mujoco<3.2.1
- coverage
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies:
- scipy
- hydra-core
- dm_control -e git+https://github.com/deepmind/dm_control.git@c053360edea6170acfd9c8f65446703307d9d352#egg={dm_control}
- mujoco<3.2.1
- patchelf
- pyopengl==3.1.4
- ray
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
python3 setup.py develop
python3 -m pip install pytest pytest-benchmark
python3 -m pip install "gym[accept-rom-license,atari]"
python3 -m pip install dm_control
python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
export TD_GET_DEFAULTS_TO_NONE=1
- name: Run benchmarks
run: |
Expand Down Expand Up @@ -97,7 +97,7 @@ jobs:
python3 setup.py develop
python3 -m pip install pytest pytest-benchmark
python3 -m pip install "gym[accept-rom-license,atari]"
python3 -m pip install dm_control
python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
export TD_GET_DEFAULTS_TO_NONE=1
- name: check GPU presence
run: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/benchmarks_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
python3 setup.py develop
python3 -m pip install pytest pytest-benchmark
python3 -m pip install "gym[accept-rom-license,atari]"
python3 -m pip install dm_control
python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
export TD_GET_DEFAULTS_TO_NONE=1
- name: Setup benchmarks
run: |
Expand Down Expand Up @@ -108,7 +108,7 @@ jobs:
python3 setup.py develop
python3 -m pip install pytest pytest-benchmark
python3 -m pip install "gym[accept-rom-license,atari]"
python3 -m pip install dm_control
python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
export TD_GET_DEFAULTS_TO_NONE=1
- name: check GPU presence
run: |
Expand Down
3 changes: 2 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ docutils
sphinx_design

torchvision
dm_control
dm_control<1.0.21
mujoco<3.2.1
atari-py
ale-py
gym[classic_control,accept-rom-license]
Expand Down
9 changes: 9 additions & 0 deletions docs/source/reference/objectives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,15 @@ CQL
CQLLoss
DiscreteCQLLoss

GAIL
----

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

GAILLoss

DT
----

Expand Down
46 changes: 46 additions & 0 deletions sota-implementations/gail/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
env:
env_name: HalfCheetah-v4
seed: 42
backend: gymnasium

logger:
backend: wandb
project_name: gail
group_name: null
exp_name: gail_ppo
test_interval: 5000
num_test_episodes: 5
video: False
mode: online

ppo:
collector:
frames_per_batch: 2048
total_frames: 1_000_000

optim:
lr: 3e-4
weight_decay: 0.0
anneal_lr: True

loss:
gamma: 0.99
mini_batch_size: 64
ppo_epochs: 10
gae_lambda: 0.95
clip_epsilon: 0.2
anneal_clip_epsilon: False
critic_coef: 0.25
entropy_coef: 0.0
loss_critic_type: l2

gail:
hidden_dim: 128
lr: 3e-4
use_grad_penalty: False
gp_lambda: 10.0
device: null

replay_buffer:
dataset: halfcheetah-expert-v2
batch_size: 256
Loading
Loading