From 6d7c5c45554112f8b24122fced99b751927094e5 Mon Sep 17 00:00:00 2001
From: BY571
Date: Thu, 4 Jul 2024 13:19:22 +0200
Subject: [PATCH 01/22] fix norm

---
 sota-implementations/gail/config.yaml   |  50 +++++
 sota-implementations/gail/gail.py       | 270 ++++++++++++++++++++++++
 sota-implementations/gail/gail_utils.py |  69 ++++++
 sota-implementations/gail/ppo_utils.py  | 150 +++++++++++++
 torchrl/objectives/__init__.py          |   1 +
 torchrl/objectives/gail.py              | 214 +++++++++++++++++++
 6 files changed, 754 insertions(+)
 create mode 100644 sota-implementations/gail/config.yaml
 create mode 100644 sota-implementations/gail/gail.py
 create mode 100644 sota-implementations/gail/gail_utils.py
 create mode 100644 sota-implementations/gail/ppo_utils.py
 create mode 100644 torchrl/objectives/gail.py

diff --git a/sota-implementations/gail/config.yaml b/sota-implementations/gail/config.yaml
new file mode 100644
index 00000000000..c2a77337526
--- /dev/null
+++ b/sota-implementations/gail/config.yaml
@@ -0,0 +1,50 @@
+env:
+  env_name: HalfCheetah-v4
+  seed: 42
+  backend: gym
+
+
+# logger
+logger:
+  backend: wandb
+  project_name: gail
+  group_name: null
+  exp_name: gail_ppo
+  test_interval: 5000
+  num_test_episodes: 5
+  video: False
+  mode: online
+
+ppo:
+  # collector
+  collector:
+    frames_per_batch: 2048
+    total_frames: 1_000_000
+
+  # Optim
+  optim:
+    lr: 3e-4
+    weight_decay: 0.0
+    anneal_lr: True
+
+  # loss
+  loss:
+    gamma: 0.99
+    mini_batch_size: 64
+    ppo_epochs: 10
+    gae_lambda: 0.95
+    clip_epsilon: 0.2
+    anneal_clip_epsilon: False
+    critic_coef: 0.25
+    entropy_coef: 0.0
+    loss_critic_type: l2
+
+gail:
+  hidden_dim: 128
+  lr: 3e-4
+  use_grad_penalty: True
+  gp_lambda: 10.0
+
+replay_buffer:
+  dataset: halfcheetah-expert-v2
+  batch_size: 256
diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py
new file mode 100644
index 00000000000..c1f60418f24
--- /dev/null
+++ b/sota-implementations/gail/gail.py
@@ -0,0 +1,270 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""GAIL Example.
+
+This is a self-contained example of a GAIL training script: the policy is
+trained online with PPO against a discriminator fitted to offline (D4RL)
+expert demonstrations.
+
+The GAIL helper functions live in gail_utils.py and the PPO helpers in
+ppo_utils.py.
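+
+Each collection batch first updates the discriminator on expert vs. collected
+transitions; the discriminator output is then turned into a surrogate reward
+that PPO maximizes. Hydra overrides can be passed on the command line, for
+example (using the keys defined in config.yaml):
+
+    python gail.py ppo.collector.total_frames=100_000 logger.mode=offline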
+
+"""
+import hydra
+import numpy as np
+import torch
+import tqdm
+
+from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer
+from ppo_utils import eval_model, make_env, make_ppo_models
+from torchrl.collectors import SyncDataCollector
+from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
+
+from torchrl.envs import set_gym_backend
+from torchrl.envs.utils import ExplorationType, set_exploration_type
+from torchrl.objectives import ClipPPOLoss, GAILLoss
+from torchrl.objectives.value.advantages import GAE
+from torchrl.record import VideoRecorder
+from torchrl.record.loggers import generate_exp_name, get_logger
+
+
+@hydra.main(config_path="", config_name="config")
+def main(cfg: "DictConfig"):  # noqa: F821
+    set_gym_backend(cfg.env.backend).set()
+
+    device = "cpu" if not torch.cuda.device_count() else "cuda"
+    num_mini_batches = (
+        cfg.ppo.collector.frames_per_batch // cfg.ppo.loss.mini_batch_size
+    )
+    total_network_updates = (
+        (cfg.ppo.collector.total_frames // cfg.ppo.collector.frames_per_batch)
+        * cfg.ppo.loss.ppo_epochs
+        * num_mini_batches
+    )
+
+    # Create logger
+    exp_name = generate_exp_name("Gail-offline", cfg.logger.exp_name)
+    logger = None
+    if cfg.logger.backend:
+        logger = get_logger(
+            logger_type=cfg.logger.backend,
+            logger_name="gail_logging",
+            experiment_name=exp_name,
+            wandb_kwargs={
+                "mode": cfg.logger.mode,
+                "config": dict(cfg),
+                "project": cfg.logger.project_name,
+                "group": cfg.logger.group_name,
+            },
+        )
+
+    # Set seeds
+    torch.manual_seed(cfg.env.seed)
+    np.random.seed(cfg.env.seed)
+
+    # Create models (see ppo_utils.py)
+    actor, critic = make_ppo_models(cfg.env.env_name)
+    actor, critic = actor.to(device), critic.to(device)
+
+    # Create collector
+    collector = SyncDataCollector(
+        create_env_fn=make_env(cfg.env.env_name, device),
+        policy=actor,
+        frames_per_batch=cfg.ppo.collector.frames_per_batch,
+        total_frames=cfg.ppo.collector.total_frames,
+        device=device,
+        storing_device=device,
+        max_frames_per_traj=-1,
+    )
+
+    # Create data buffer
+    data_buffer = TensorDictReplayBuffer(
+        storage=LazyMemmapStorage(cfg.ppo.collector.frames_per_batch),
+        sampler=SamplerWithoutReplacement(),
+        batch_size=cfg.ppo.loss.mini_batch_size,
+    )
+
+    # Create loss and adv modules
+    adv_module = GAE(
+        gamma=cfg.ppo.loss.gamma,
+        lmbda=cfg.ppo.loss.gae_lambda,
+        value_network=critic,
+        average_gae=False,
+    )
+
+    loss_module = ClipPPOLoss(
+        actor_network=actor,
+        critic_network=critic,
+        clip_epsilon=cfg.ppo.loss.clip_epsilon,
+        loss_critic_type=cfg.ppo.loss.loss_critic_type,
+        entropy_coef=cfg.ppo.loss.entropy_coef,
+        critic_coef=cfg.ppo.loss.critic_coef,
+        normalize_advantage=True,
+    )
+
+    # Create optimizers
+    actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
+    critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
+
+    # Create replay buffer
+    replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
+
+    # Create Discriminator
+    discriminator = make_gail_discriminator(cfg, collector.env, device)
+
+    # Create loss
+    discriminator_loss = GAILLoss(
+        discriminator,
+        use_grad_penalty=cfg.gail.use_grad_penalty,
+        gp_lambda=cfg.gail.gp_lambda,
+    )
+
+    # Create optimizer
+    discriminator_optim = torch.optim.Adam(
+        params=discriminator.parameters(), lr=cfg.gail.lr
+    )
+
+    # Create test environment
+    logger_video = cfg.logger.video
+    test_env = make_env(cfg.env.env_name, device,
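+        # Pixels are only rendered when video logging is requested, since
+        # `from_pixels=True` adds rendering overhead to every rollout.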
from_pixels=logger_video) + if logger_video: + test_env = test_env.append_transform( + VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"]) + ) + test_env.eval() + + # Training loop + collected_frames = 0 + num_network_updates = 0 + pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames) + + # extract cfg variables + cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs + cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr + cfg_optim_lr = cfg.ppo.optim.lr + cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon + cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon + cfg_logger_test_interval = cfg.logger.test_interval + cfg_logger_num_test_episodes = cfg.logger.num_test_episodes + + for i, data in enumerate(collector): + + log_info = {} + frames_in_batch = data.numel() + collected_frames += frames_in_batch + pbar.update(data.numel()) + + # Update discriminator + + # Get expert data + expert_data = replay_buffer.sample() + expert_data = expert_data.to(device) + d_loss = discriminator_loss(expert_data, data) + + # Backward pass + discriminator_optim.zero_grad() + d_loss.get("loss").backward() + discriminator_optim.step() + + # Compute discriminator reward + with torch.no_grad(): + data = discriminator(data) + d_rewards = -torch.log(1 - data["d_logits"] + 1e-8) + d_rewards = torch.log(data["d_logits"] + 1e-8) - torch.log( + 1 - data["d_logits"] + 1e-8 + ) + + # set d_rewards to tensordict + data.set(("next", "reward"), d_rewards) + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "done"]] + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) + # Update PPO + for _ in range(cfg_loss_ppo_epochs): + + # Compute GAE + with torch.no_grad(): + data = adv_module(data) + data_reshape = data.reshape(-1) + + # Update the data buffer + data_buffer.extend(data_reshape) + + for _, batch in enumerate(data_buffer): + + # Get a data batch + batch = batch.to(device) + + # Linearly decrease the learning rate and clip epsilon + alpha = 1.0 + if cfg_optim_anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in actor_optim.param_groups: + group["lr"] = cfg_optim_lr * alpha + for group in critic_optim.param_groups: + group["lr"] = cfg_optim_lr * alpha + if cfg_loss_anneal_clip_eps: + loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha) + num_network_updates += 1 + + # Forward pass PPO loss + loss = loss_module(batch) + critic_loss = loss["loss_critic"] + actor_loss = loss["loss_objective"] + loss["loss_entropy"] + + # Backward pass + actor_loss.backward() + critic_loss.backward() + + # Update the networks + actor_optim.step() + critic_optim.step() + actor_optim.zero_grad() + critic_optim.zero_grad() + + log_info.update( + { + "train/actor_loss": actor_loss.item(), + "train/critic_loss": critic_loss.item(), + "train/discriminator_loss": d_loss["loss"].item(), + "train/lr": alpha * cfg_optim_lr, + "train/clip_epsilon": ( + alpha * cfg_loss_clip_epsilon + if cfg_loss_anneal_clip_eps + else cfg_loss_clip_epsilon + ), + } + ) + + # evaluation + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < ( + i * frames_in_batch + ) // cfg_logger_test_interval: + actor.eval() + test_rewards = eval_model( + actor, test_env, 
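+                    # Evaluation rollouts use the policy's deterministic output
+                    # (exploration type set by the surrounding context manager).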
num_episodes=cfg_logger_num_test_episodes + ) + log_info.update( + { + "eval/reward": test_rewards.mean(), + } + ) + actor.train() + if logger is not None: + log_metrics(logger, log_info, i) + + pbar.close() + + +if __name__ == "__main__": + main() diff --git a/sota-implementations/gail/gail_utils.py b/sota-implementations/gail/gail_utils.py new file mode 100644 index 00000000000..067e9c8c927 --- /dev/null +++ b/sota-implementations/gail/gail_utils.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn +import torch.optim + +from torchrl.data.datasets.d4rl import D4RLExperienceReplay +from torchrl.data.replay_buffers import SamplerWithoutReplacement +from torchrl.envs import DoubleToFloat + +from torchrl.modules import SafeModule + + +# ==================================================================== +# Offline Replay buffer +# --------------------------- + + +def make_offline_replay_buffer(rb_cfg): + data = D4RLExperienceReplay( + dataset_id=rb_cfg.dataset, + split_trajs=False, + batch_size=rb_cfg.batch_size, + sampler=SamplerWithoutReplacement(drop_last=False), + prefetch=4, + direct_download=True, + ) + + data.append_transform(DoubleToFloat()) + + return data + + +def make_gail_discriminator(cfg, train_env, device="cpu"): + """Make GAIL discriminator.""" + + state_dim = train_env.observation_spec["observation"].shape[0] + action_dim = train_env.action_spec.shape[0] + + hidden_dim = cfg.gail.hidden_dim + + # Define Discriminator Network + class Discriminator(nn.Module): + def __init__(self, state_dim, action_dim): + super(Discriminator, self).__init__() + self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, hidden_dim) + self.fc3 = nn.Linear(hidden_dim, 1) + + def forward(self, state, action): + x = torch.cat([state, action], dim=1) + x = torch.relu(self.fc1(x)) + x = torch.relu(self.fc2(x)) + return torch.sigmoid(self.fc3(x)) + + d_module = SafeModule( + module=Discriminator(state_dim, action_dim), + in_keys=["observation", "action"], + out_keys=["d_logits"], + ) + return d_module.to(device) + + +def log_metrics(logger, metrics, step): + if logger is not None: + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) diff --git a/sota-implementations/gail/ppo_utils.py b/sota-implementations/gail/ppo_utils.py new file mode 100644 index 00000000000..7986738f8e6 --- /dev/null +++ b/sota-implementations/gail/ppo_utils.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
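+"""PPO helpers for the GAIL example: environment construction (``make_env``),
+actor-critic model creation (``make_ppo_models``) and evaluation utilities
+(``eval_model``, ``dump_video``)."""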
+ +import torch.nn +import torch.optim + +from tensordict.nn import AddStateIndependentNormalScale, TensorDictModule +from torchrl.data import CompositeSpec +from torchrl.envs import ( + ClipTransform, + DoubleToFloat, + ExplorationType, + RewardSum, + StepCounter, + TransformedEnv, + VecNorm, +) +from torchrl.envs.libs.gym import GymEnv +from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator +from torchrl.record import VideoRecorder + + +# ==================================================================== +# Environment utils +# -------------------------------------------------------------------- + + +def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False): + env = GymEnv(env_name, device=device, from_pixels=from_pixels, pixels_only=False) + env = TransformedEnv(env) + env.append_transform(VecNorm(in_keys=["observation"], decay=0.99999, eps=1e-2)) + env.append_transform(ClipTransform(in_keys=["observation"], low=-10, high=10)) + env.append_transform(RewardSum()) + env.append_transform(StepCounter()) + env.append_transform(DoubleToFloat(in_keys=["observation"])) + return env + + +# ==================================================================== +# Model utils +# -------------------------------------------------------------------- + + +def make_ppo_models_state(proof_environment): + + # Define input shape + input_shape = proof_environment.observation_spec["observation"].shape + + # Define policy output distribution class + num_outputs = proof_environment.action_spec.shape[-1] + distribution_class = TanhNormal + distribution_kwargs = { + "low": proof_environment.action_spec.space.low, + "high": proof_environment.action_spec.space.high, + "tanh_loc": False, + } + + # Define policy architecture + policy_mlp = MLP( + in_features=input_shape[-1], + activation_class=torch.nn.Tanh, + out_features=num_outputs, # predict only loc + num_cells=[64, 64], + ) + + # Initialize policy weights + for layer in policy_mlp.modules(): + if isinstance(layer, torch.nn.Linear): + torch.nn.init.orthogonal_(layer.weight, 1.0) + layer.bias.data.zero_() + + # Add state-independent normal scale + policy_mlp = torch.nn.Sequential( + policy_mlp, + AddStateIndependentNormalScale( + proof_environment.action_spec.shape[-1], scale_lb=1e-8 + ), + ) + + # Add probabilistic sampling of the actions + policy_module = ProbabilisticActor( + TensorDictModule( + module=policy_mlp, + in_keys=["observation"], + out_keys=["loc", "scale"], + ), + in_keys=["loc", "scale"], + spec=CompositeSpec(action=proof_environment.action_spec), + distribution_class=distribution_class, + distribution_kwargs=distribution_kwargs, + return_log_prob=True, + default_interaction_type=ExplorationType.RANDOM, + ) + + # Define value architecture + value_mlp = MLP( + in_features=input_shape[-1], + activation_class=torch.nn.Tanh, + out_features=1, + num_cells=[64, 64], + ) + + # Initialize value weights + for layer in value_mlp.modules(): + if isinstance(layer, torch.nn.Linear): + torch.nn.init.orthogonal_(layer.weight, 0.01) + layer.bias.data.zero_() + + # Define value module + value_module = ValueOperator( + value_mlp, + in_keys=["observation"], + ) + + return policy_module, value_module + + +def make_ppo_models(env_name): + proof_environment = make_env(env_name, device="cpu") + actor, critic = make_ppo_models_state(proof_environment) + return actor, critic + + +# ==================================================================== +# Evaluation utils +# 
--------------------------------------------------------------------
+
+
+def dump_video(module):
+    if isinstance(module, VideoRecorder):
+        module.dump()
+
+
+def eval_model(actor, test_env, num_episodes=3):
+    test_rewards = []
+    for _ in range(num_episodes):
+        td_test = test_env.rollout(
+            policy=actor,
+            auto_reset=True,
+            auto_cast_to_device=True,
+            break_when_any_done=True,
+            max_steps=10_000_000,
+        )
+        reward = td_test["next", "episode_reward"][td_test["next", "done"]]
+        test_rewards.append(reward.cpu())
+    test_env.apply(dump_video)
+    del td_test
+    return torch.cat(test_rewards, 0).mean()
diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py
index f8d2bd1d977..f40366cf2e3 100644
--- a/torchrl/objectives/__init__.py
+++ b/torchrl/objectives/__init__.py
@@ -10,6 +10,7 @@
 from .decision_transformer import DTLoss, OnlineDTLoss
 from .dqn import DistributionalDQNLoss, DQNLoss
 from .dreamer import DreamerActorLoss, DreamerModelLoss, DreamerValueLoss
+from .gail import GAILLoss
 from .iql import DiscreteIQLLoss, IQLLoss
 from .multiagent import QMixerLoss
 from .ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss
diff --git a/torchrl/objectives/gail.py b/torchrl/objectives/gail.py
new file mode 100644
index 00000000000..aaaee0b9aa3
--- /dev/null
+++ b/torchrl/objectives/gail.py
@@ -0,0 +1,214 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import torch
+
+import torch.autograd as autograd
+from tensordict import TensorDict, TensorDictBase, TensorDictParams
+from tensordict.nn import dispatch, TensorDictModule
+from tensordict.utils import NestedKey
+
+from torchrl.objectives.common import LossModule
+from torchrl.objectives.utils import _reduce
+
+
+class GAILLoss(LossModule):
+    r"""TorchRL implementation of the Generative Adversarial Imitation Learning (GAIL) loss.
+
+    Presented in `"Generative Adversarial Imitation Learning" <https://arxiv.org/abs/1606.03476>`_
+
+    Args:
+        discriminator_network (TensorDictModule): the discriminator network,
+            mapping observation-action pairs to a score between 0 and 1.
+
+    Keyword Args:
+        use_grad_penalty (bool, optional): Whether to use gradient penalty. Default: ``False``.
+        gp_lambda (float, optional): Gradient penalty lambda. Default: ``10``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
+            ``"mean"``: the sum of the output will be divided by the number of
+            elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
+    """
+
+    @dataclass
+    class _AcceptedKeys:
+        """Maintains default values for all configurable tensordict keys.
+
+        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+        default values.
+
+        Attributes:
+            action (NestedKey): The input tensordict key where the action is expected.
+                Defaults to ``"action"``.
+            observation (NestedKey): The tensordict key where the observation is expected.
+                Defaults to ``"observation"``.
+ """ + + action: NestedKey = "action" + observation: NestedKey = "observation" + discriminator_pred: NestedKey = "d_logits" + + default_keys = _AcceptedKeys() + + discriminator_network: TensorDictModule + discriminator_network_params: TensorDictParams + + def __init__( + self, + discriminator_network: TensorDictModule, + *, + use_grad_penalty: bool = False, + gp_lambda: float = 10, + reduction: str = None, + ) -> None: + self._in_keys = None + self._out_keys = None + if reduction is None: + reduction = "mean" + super().__init__() + + # Discriminator Network + self.convert_to_functional( + discriminator_network, + "discriminator_network", + create_target_params=False, + ) + self.loss_function = torch.nn.BCELoss() + self.use_grad_penalty = use_grad_penalty + self.gp_lambda = gp_lambda + + self.reduction = reduction + + def _set_in_keys(self): + keys = self.discriminator_network.in_keys + keys = set(keys) + self._in_keys = sorted(keys, key=str) + + def _forward_value_estimator_keys(self, **kwargs) -> None: + pass + + @property + def in_keys(self): + if self._in_keys is None: + self._set_in_keys() + return self._in_keys + + @in_keys.setter + def in_keys(self, values): + self._in_keys = values + + @property + def out_keys(self): + if self._out_keys is None: + keys = ["loss"] + self._out_keys = keys + return self._out_keys + + @out_keys.setter + def out_keys(self, values): + self._out_keys = values + + @dispatch + def forward( + self, tensordict: TensorDictBase, collection_tensordict: TensorDictBase + ) -> TensorDictBase: + """Compute the GAIL discriminator loss.""" + expert_tensordict = tensordict.clone(False) + expert_input = expert_tensordict.select(*self.in_keys).detach() + + collection_tensordict = collection_tensordict.clone(False) + collection_input = collection_tensordict.select(*self.in_keys).detach() + + combined_inputs = torch.cat([expert_input, collection_input], dim=0) + + # create labels + collection_bs = collection_tensordict.batch_size[0] + expert_bs = expert_tensordict.batch_size[0] + fake_labels = torch.zeros((collection_bs, 1), dtype=torch.float32).to( + collection_tensordict.device + ) + real_labels = torch.ones((expert_bs, 1), dtype=torch.float32).to( + expert_tensordict.device + ) + + with self.discriminator_network_params.to_module(self.discriminator_network): + d_logits = self.discriminator_network(combined_inputs).get( + self.tensor_keys.discriminator_pred + ) + + expert_preds, collection_preds = torch.split( + d_logits, [expert_bs, collection_bs], dim=0 + ) + + expert_loss = self.loss_function(expert_preds, real_labels) + collection_loss = self.loss_function(collection_preds, fake_labels) + + loss = expert_loss + collection_loss + out = {"loss": loss} + if not self.use_grad_penalty: + obs = collection_tensordict.get(self.tensor_keys.observation) + acts = collection_tensordict.get(self.tensor_keys.action) + obs_e = expert_tensordict.get(self.tensor_keys.observation) + acts_e = expert_tensordict.get(self.tensor_keys.action) + + obs = obs[:expert_bs] + acts = acts[:expert_bs] + + obss_noise = ( + torch.distributions.Uniform(0.0, 1.0) + .sample(obs_e.shape) + .to(tensordict.device) + ) + acts_noise = ( + torch.distributions.Uniform(0.0, 1.0) + .sample(acts_e.shape) + .to(tensordict.device) + ) + obss_mixture = obss_noise * obs + (1 - obss_noise) * obs_e + acts_mixture = acts_noise * acts + (1 - acts_noise) * acts_e + obss_mixture.requires_grad_(True) + acts_mixture.requires_grad_(True) + + pg_input_td = TensorDict( + { + self.tensor_keys.observation: obss_mixture, + 
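+                    # The expert/policy interpolates are packed under the
+                    # discriminator's regular input keys so it can score them.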
self.tensor_keys.action: acts_mixture, + }, + [], + ) + + with self.discriminator_network_params.to_module( + self.discriminator_network + ): + d_logits_mixture = self.discriminator_network(pg_input_td).get( + self.tensor_keys.discriminator_pred + ) + + gradients = torch.cat( + autograd.grad( + outputs=d_logits_mixture, + inputs=(obss_mixture, acts_mixture), + grad_outputs=torch.ones( + d_logits_mixture.size(), device=tensordict.device + ), + create_graph=True, + retain_graph=True, + only_inputs=True, + ), + dim=-1, + ) + + gp_loss = self.gp_lambda * torch.mean( + (torch.linalg.norm(gradients, dim=-1) - 1) ** 2 + ) + + loss += gp_loss + out["gp_loss"] = gp_loss + loss = _reduce(loss, reduction=self.reduction) + + td_out = TensorDict(out, []) + return td_out From 2d31f33c08958a4503b63a4e68a370c99e9d1cca Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 5 Jul 2024 09:20:30 +0200 Subject: [PATCH 02/22] update docs --- docs/source/reference/objectives.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index c2f43d8e9b6..993429548eb 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -141,6 +141,15 @@ CQL CQLLoss DiscreteCQLLoss +GAIL +---- + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + GAILLoss + DT ---- From 79bda13019b1a7db8d7e2b6ac828f30f8c850d5d Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 5 Jul 2024 09:20:44 +0200 Subject: [PATCH 03/22] update comments --- sota-implementations/gail/gail.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index c1f60418f24..a469848b10f 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -43,7 +43,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create logger - exp_name = generate_exp_name("Gail-offline", cfg.logger.exp_name) + exp_name = generate_exp_name("Gail", cfg.logger.exp_name) logger = None if cfg.logger.backend: logger = get_logger( @@ -170,11 +170,8 @@ def main(cfg: "DictConfig"): # noqa: F821 with torch.no_grad(): data = discriminator(data) d_rewards = -torch.log(1 - data["d_logits"] + 1e-8) - d_rewards = torch.log(data["d_logits"] + 1e-8) - torch.log( - 1 - data["d_logits"] + 1e-8 - ) - # set d_rewards to tensordict + # Set discriminator rewards to tensordict data.set(("next", "reward"), d_rewards) # Get training rewards and episode lengths From 1391b50dfecc645019b72dbdddcdcb48898c1198 Mon Sep 17 00:00:00 2001 From: BY571 Date: Fri, 5 Jul 2024 09:25:37 +0200 Subject: [PATCH 04/22] add sota-example-test --- .github/unittest/linux_examples/scripts/run_test.sh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 075489b208d..2fc08908343 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -192,6 +192,12 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq env.train_num_envs=2 \ logger.mode=offline \ logger.backend= + python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/gail/gail.py \ + collector.total_frames=48 \ + loss.mini_batch_size=10 \ + collector.frames_per_batch=16 \ + logger.mode=offline \ + logger.backend= # With single envs python .github/unittest/helpers/coverage_run_parallel.py 
sota-implementations/dreamer/dreamer.py \ From f444b72029f870a7eee0419555d34e7af30b90f5 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 8 Jul 2024 14:53:09 +0200 Subject: [PATCH 05/22] update collection data slice --- sota-implementations/gail/gail.py | 12 +++++- torchrl/objectives/gail.py | 63 ++++++++++++++++++------------- 2 files changed, 47 insertions(+), 28 deletions(-) diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index a469848b10f..0ce2c4ac081 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -155,11 +155,19 @@ def main(cfg: "DictConfig"): # noqa: F821 pbar.update(data.numel()) # Update discriminator - # Get expert data expert_data = replay_buffer.sample() expert_data = expert_data.to(device) - d_loss = discriminator_loss(expert_data, data) + # Add collector data to expert data + expert_data.set( + discriminator_loss.tensor_keys.collector_action, + data["action"][: expert_data.batch_size[0]], + ) + expert_data.set( + discriminator_loss.tensor_keys.collector_observation, + data["observation"][: expert_data.batch_size[0]], + ) + d_loss = discriminator_loss(expert_data) # Backward pass discriminator_optim.zero_grad() diff --git a/torchrl/objectives/gail.py b/torchrl/objectives/gail.py index aaaee0b9aa3..8efd0c736d8 100644 --- a/torchrl/objectives/gail.py +++ b/torchrl/objectives/gail.py @@ -48,8 +48,10 @@ class _AcceptedKeys: Defaults to ``"observation"``. """ - action: NestedKey = "action" - observation: NestedKey = "observation" + expert_action: NestedKey = "action" + expert_observation: NestedKey = "observation" + collector_action: NestedKey = "collector_action" + collector_observation: NestedKey = "collector_observation" discriminator_pred: NestedKey = "d_logits" default_keys = _AcceptedKeys() @@ -86,6 +88,8 @@ def __init__( def _set_in_keys(self): keys = self.discriminator_network.in_keys keys = set(keys) + keys.add(self.tensor_keys.expert_observation) + keys.add(self.tensor_keys.expert_action) self._in_keys = sorted(keys, key=str) def _forward_value_estimator_keys(self, **kwargs) -> None: @@ -114,25 +118,35 @@ def out_keys(self, values): @dispatch def forward( - self, tensordict: TensorDictBase, collection_tensordict: TensorDictBase + self, + tensordict: TensorDictBase, ) -> TensorDictBase: """Compute the GAIL discriminator loss.""" - expert_tensordict = tensordict.clone(False) - expert_input = expert_tensordict.select(*self.in_keys).detach() - - collection_tensordict = collection_tensordict.clone(False) - collection_input = collection_tensordict.select(*self.in_keys).detach() - - combined_inputs = torch.cat([expert_input, collection_input], dim=0) + tensordict = tensordict.clone(False) + batch_size = tensordict.batch_size[0] + collector_obs = tensordict.get(self.tensor_keys.collector_observation) + collector_act = tensordict.get(self.tensor_keys.collector_action) + + expert_obs = tensordict.get(self.tensor_keys.expert_observation) + expert_act = tensordict.get(self.tensor_keys.expert_action) + + combined_obs_inputs = torch.cat([expert_obs, collector_obs], dim=0) + combined_act_inputs = torch.cat([expert_act, collector_act], dim=0) + + combined_inputs = TensorDict( + { + self.tensor_keys.expert_observation: combined_obs_inputs, + self.tensor_keys.expert_action: combined_act_inputs, + }, + batch_size=[2 * batch_size], + ) # create labels - collection_bs = collection_tensordict.batch_size[0] - expert_bs = expert_tensordict.batch_size[0] - fake_labels = torch.zeros((collection_bs, 1), 
dtype=torch.float32).to( - collection_tensordict.device + fake_labels = torch.zeros((batch_size, 1), dtype=torch.float32).to( + tensordict.device ) - real_labels = torch.ones((expert_bs, 1), dtype=torch.float32).to( - expert_tensordict.device + real_labels = torch.ones((batch_size, 1), dtype=torch.float32).to( + tensordict.device ) with self.discriminator_network_params.to_module(self.discriminator_network): @@ -141,7 +155,7 @@ def forward( ) expert_preds, collection_preds = torch.split( - d_logits, [expert_bs, collection_bs], dim=0 + d_logits, [batch_size, batch_size], dim=0 ) expert_loss = self.loss_function(expert_preds, real_labels) @@ -150,13 +164,10 @@ def forward( loss = expert_loss + collection_loss out = {"loss": loss} if not self.use_grad_penalty: - obs = collection_tensordict.get(self.tensor_keys.observation) - acts = collection_tensordict.get(self.tensor_keys.action) - obs_e = expert_tensordict.get(self.tensor_keys.observation) - acts_e = expert_tensordict.get(self.tensor_keys.action) - - obs = obs[:expert_bs] - acts = acts[:expert_bs] + obs = tensordict.get(self.tensor_keys.collector_observation) + acts = tensordict.get(self.tensor_keys.colecctor_action) + obs_e = tensordict.get(self.tensor_keys.expert_observation) + acts_e = tensordict.get(self.tensor_keys.expert_action) obss_noise = ( torch.distributions.Uniform(0.0, 1.0) @@ -175,8 +186,8 @@ def forward( pg_input_td = TensorDict( { - self.tensor_keys.observation: obss_mixture, - self.tensor_keys.action: acts_mixture, + self.tensor_keys.expert_observation: obss_mixture, + self.tensor_keys.expert_action: acts_mixture, }, [], ) From 244b7ab7536c812557f3eb11c2f6ced3a095e8f6 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 8 Jul 2024 14:54:41 +0200 Subject: [PATCH 06/22] update docstring --- torchrl/objectives/gail.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torchrl/objectives/gail.py b/torchrl/objectives/gail.py index 8efd0c736d8..d774864459d 100644 --- a/torchrl/objectives/gail.py +++ b/torchrl/objectives/gail.py @@ -42,10 +42,15 @@ class _AcceptedKeys: default values. Attributes: - action (NestedKey): The input tensordict key where the action is expected. + expert_action (NestedKey): The input tensordict key where the action is expected. Defaults to ``"action"``. - observation (NestedKey): The tensordict key where the observation is expected. + expert_observation (NestedKey): The tensordict key where the observation is expected. Defaults to ``"observation"``. + collector_action (NestedKey): The tensordict key where the collector action is expected. + Defaults to ``"collector_action"``. + collector_observation (NestedKey): The tensordict key where the collector observation is expected. + Defaults to ``"collector_observation"``. + discriminator_pred (NestedKey): The tensordict key where the discriminator prediction is expected. 
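+                Defaults to ``"d_logits"``.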
""" expert_action: NestedKey = "action" From db635d7830b7c5c4b4d6823b91d3fbfe2fe81cd2 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 9 Jul 2024 10:14:09 +0200 Subject: [PATCH 07/22] update config and objective with gp param --- sota-implementations/gail/config.yaml | 2 +- torchrl/objectives/gail.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sota-implementations/gail/config.yaml b/sota-implementations/gail/config.yaml index c2a77337526..fa9e66dcf15 100644 --- a/sota-implementations/gail/config.yaml +++ b/sota-implementations/gail/config.yaml @@ -42,7 +42,7 @@ ppo: gail: hidden_dim: 128 lr: 3e-4 - use_grad_penalty: True + use_grad_penalty: False gp_lambda: 10.0 replay_buffer: diff --git a/torchrl/objectives/gail.py b/torchrl/objectives/gail.py index d774864459d..9704bb4639c 100644 --- a/torchrl/objectives/gail.py +++ b/torchrl/objectives/gail.py @@ -168,9 +168,9 @@ def forward( loss = expert_loss + collection_loss out = {"loss": loss} - if not self.use_grad_penalty: + if self.use_grad_penalty: obs = tensordict.get(self.tensor_keys.collector_observation) - acts = tensordict.get(self.tensor_keys.colecctor_action) + acts = tensordict.get(self.tensor_keys.collector_action) obs_e = tensordict.get(self.tensor_keys.expert_observation) acts_e = tensordict.get(self.tensor_keys.expert_action) From 434622c15cbbc0bb743613a9dd28a96acdedcdb2 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 9 Jul 2024 10:16:52 +0200 Subject: [PATCH 08/22] init cost tests gail --- test/test_cost.py | 199 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 199 insertions(+) diff --git a/test/test_cost.py b/test/test_cost.py index 76fc4e651f4..028fa8e60f8 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -8976,6 +8976,205 @@ def test_dt_reduction(self, reduction): assert loss["loss"].shape == torch.Size([]) +class TestGAIL(LossModuleTestBase): + seed = 0 + + def _create_mock_discriminator( + self, batch=2, obs_dim=3, action_dim=4, device="cpu" + ): + # Discriminator + body = TensorDictModule( + MLP( + in_features=obs_dim + action_dim, + out_features=32, + depth=1, + num_cells=32, + activation_class=torch.nn.ReLU, + activate_last_layer=True, + ), + in_keys=["observation", "action"], + out_keys="hidden", + ) + head = TensorDictModule( + MLP( + in_features=32, + out_features=1, + depth=0, + num_cells=32, + activation_class=torch.nn.Sigmoid, + activate_last_layer=True, + ), + in_keys="hidden", + out_keys="d_logits", + ) + discriminator = TensorDictSequential(body, head) + + return discriminator.to(device) + + def _create_mock_data_gail(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + # create a tensordict + obs = torch.randn(batch, obs_dim, device=device) + action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) + td = TensorDict( + batch_size=(batch,), + source={ + "observation": obs, + "action": action, + }, + device=device, + ) + return td + + def _create_seq_mock_data_gail( + self, batch=2, T=4, obs_dim=3, action_dim=4, device="cpu" + ): + # create a tensordict + obs = torch.randn(batch, T, obs_dim, device=device) + action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) + + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs, + "action": action, + }, + device=device, + ) + return td + + def test_dt_tensordict_keys(self): + actor = self._create_mock_actor() + loss_fn = DTLoss(actor) + + default_keys = { + "action_target": "action", + "action_pred": "action", + } + + self.tensordict_keys_test( + loss_fn, + 
default_keys=default_keys, + ) + + @pytest.mark.parametrize("device", get_default_devices()) + def test_dt_notensordict(self, device): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + td = self._create_mock_data_dt(device=device) + loss_fn = DTLoss(actor) + + in_keys = self._flatten_in_keys(loss_fn.in_keys) + kwargs = dict(td.flatten_keys("_").select(*in_keys)) + + loss_val_td = loss_fn(td) + loss_val = loss_fn(**kwargs) + torch.testing.assert_close(loss_val_td.get("loss"), loss_val) + # test select + loss_fn.select_out_keys("loss") + if torch.__version__ >= "2.0.0": + loss_actor = loss_fn(**kwargs) + else: + with pytest.raises( + RuntimeError, + match="You are likely using tensordict.nn.dispatch with keyword arguments", + ): + loss_actor = loss_fn(**kwargs) + return + assert loss_actor == loss_val_td["loss"] + + @pytest.mark.parametrize("device", get_available_devices()) + def test_dt(self, device): + torch.manual_seed(self.seed) + td = self._create_mock_data_dt(device=device) + + actor = self._create_mock_actor(device=device) + + loss_fn = DTLoss(actor) + loss = loss_fn(td) + loss_transformer = loss["loss"] + loss_transformer.backward(retain_graph=True) + named_parameters = loss_fn.named_parameters() + + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" in name + assert "alpha" not in name + if p.grad is None: + assert "actor" not in name + assert "alpha" in name + loss_fn.zero_grad() + + sum([loss_transformer]).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + for name, p in named_parameters: + assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + + @pytest.mark.parametrize("device", get_available_devices()) + def test_dt_state_dict(self, device): + torch.manual_seed(self.seed) + + actor = self._create_mock_actor(device=device) + + loss_fn = DTLoss(actor) + sd = loss_fn.state_dict() + loss_fn2 = DTLoss(actor) + loss_fn2.load_state_dict(sd) + + @pytest.mark.parametrize("device", get_available_devices()) + def test_seq_dt(self, device): + torch.manual_seed(self.seed) + td = self._create_seq_mock_data_dt(device=device) + + actor = self._create_mock_actor(device=device) + + loss_fn = DTLoss(actor) + loss = loss_fn(td) + loss_transformer = loss["loss"] + loss_transformer.backward(retain_graph=True) + named_parameters = loss_fn.named_parameters() + + for name, p in named_parameters: + if p.grad is not None and p.grad.norm() > 0.0: + assert "actor" in name + assert "alpha" not in name + if p.grad is None: + assert "actor" not in name + assert "alpha" in name + loss_fn.zero_grad() + + sum([loss_transformer]).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + for name, p in named_parameters: + assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" + + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + def test_dt_reduction(self, reduction): + torch.manual_seed(self.seed) + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda") + ) + td = self._create_mock_data_dt(device=device) + actor = 
self._create_mock_actor(device=device) + loss_fn = DTLoss(actor, reduction=reduction) + loss = loss_fn(td) + if reduction == "none": + assert loss["loss"].shape == td["action"].shape + else: + assert loss["loss"].shape == torch.Size([]) + + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" ) From baca70fb6f858a9a2232a8ff007ac0661c6b5755 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 9 Jul 2024 10:58:26 +0200 Subject: [PATCH 09/22] update cost test --- test/test_cost.py | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 028fa8e60f8..40872d3a9fd 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -108,6 +108,7 @@ DreamerModelLoss, DreamerValueLoss, DTLoss, + GAILLoss, IQLLoss, KLPENPPOLoss, OnlineDTLoss, @@ -9042,13 +9043,16 @@ def _create_seq_mock_data_gail( ) return td - def test_dt_tensordict_keys(self): - actor = self._create_mock_actor() - loss_fn = DTLoss(actor) + def test_gail_tensordict_keys(self): + discriminator = self._create_mock_discriminator() + loss_fn = GAILLoss(discriminator) default_keys = { - "action_target": "action", - "action_pred": "action", + "expert_action": "action", + "expert_observation": "observation", + "collector_action": "collector_action", + "collector_observation": "collector_observation", + "discriminator_pred": "d_logits", } self.tensordict_keys_test( @@ -9057,30 +9061,36 @@ def test_dt_tensordict_keys(self): ) @pytest.mark.parametrize("device", get_default_devices()) - def test_dt_notensordict(self, device): + def test_gail_notensordict(self, device): torch.manual_seed(self.seed) - actor = self._create_mock_actor(device=device) - td = self._create_mock_data_dt(device=device) - loss_fn = DTLoss(actor) + discriminator = self._create_mock_discriminator(device=device) + loss_fn = DTLoss(discriminator) + + expert_td = self._create_mock_data_gail(device=device) + collector_td = self._create_mock_data_gail(device=device) + expert_td.set( + loss_fn.tensor_keys.collector_observation, collector_td["observation"] + ) + expert_td.set(loss_fn.tensor_keys.collector_action, collector_td["action"]) in_keys = self._flatten_in_keys(loss_fn.in_keys) - kwargs = dict(td.flatten_keys("_").select(*in_keys)) + kwargs = dict(expert_td.flatten_keys("_").select(*in_keys)) - loss_val_td = loss_fn(td) + loss_val_td = loss_fn(expert_td) loss_val = loss_fn(**kwargs) torch.testing.assert_close(loss_val_td.get("loss"), loss_val) # test select loss_fn.select_out_keys("loss") if torch.__version__ >= "2.0.0": - loss_actor = loss_fn(**kwargs) + loss_discriminator = loss_fn(**kwargs) else: with pytest.raises( RuntimeError, match="You are likely using tensordict.nn.dispatch with keyword arguments", ): - loss_actor = loss_fn(**kwargs) + loss_discriminator = loss_fn(**kwargs) return - assert loss_actor == loss_val_td["loss"] + assert loss_discriminator == loss_val_td["loss"] @pytest.mark.parametrize("device", get_available_devices()) def test_dt(self, device): From 8e7713f1d503fa1d29c26ee9c0b5a48ec1048d8e Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 11 Jul 2024 11:39:19 +0200 Subject: [PATCH 10/22] add gail cost tests --- test/test_cost.py | 85 ++++++++++++++++++++++---------------- torchrl/objectives/gail.py | 65 +++++++++++++++++++---------- 2 files changed, 92 insertions(+), 58 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index d76c2db9565..0d7b24f7779 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -10438,6 
+10438,8 @@ def _create_mock_data_gail(self, batch=2, obs_dim=3, action_dim=4, device="cpu") source={ "observation": obs, "action": action, + "collector_action": action, + "collector_observation": obs, }, device=device, ) @@ -10455,6 +10457,8 @@ def _create_seq_mock_data_gail( source={ "observation": obs, "action": action, + "collector_action": action, + "collector_observation": obs, }, device=device, ) @@ -10478,23 +10482,26 @@ def test_gail_tensordict_keys(self): ) @pytest.mark.parametrize("device", get_default_devices()) - def test_gail_notensordict(self, device): + @pytest.mark.parametrize("use_grad_penalty", [True, False]) + @pytest.mark.parametrize("gp_lambda", [0.1, 1.0]) + def test_gail_notensordict(self, device, use_grad_penalty, gp_lambda): torch.manual_seed(self.seed) discriminator = self._create_mock_discriminator(device=device) - loss_fn = DTLoss(discriminator) - - expert_td = self._create_mock_data_gail(device=device) - collector_td = self._create_mock_data_gail(device=device) - expert_td.set( - loss_fn.tensor_keys.collector_observation, collector_td["observation"] + loss_fn = GAILLoss( + discriminator, use_grad_penalty=use_grad_penalty, gp_lambda=gp_lambda ) - expert_td.set(loss_fn.tensor_keys.collector_action, collector_td["action"]) + + tensordict = self._create_mock_data_gail(device=device) in_keys = self._flatten_in_keys(loss_fn.in_keys) - kwargs = dict(expert_td.flatten_keys("_").select(*in_keys)) + kwargs = dict(tensordict.flatten_keys("_").select(*in_keys)) + + loss_val_td = loss_fn(tensordict) + if use_grad_penalty: + loss_val, _ = loss_fn(**kwargs) + else: + loss_val = loss_fn(**kwargs) - loss_val_td = loss_fn(expert_td) - loss_val = loss_fn(**kwargs) torch.testing.assert_close(loss_val_td.get("loss"), loss_val) # test select loss_fn.select_out_keys("loss") @@ -10510,13 +10517,17 @@ def test_gail_notensordict(self, device): assert loss_discriminator == loss_val_td["loss"] @pytest.mark.parametrize("device", get_available_devices()) - def test_dt(self, device): + @pytest.mark.parametrize("use_grad_penalty", [True, False]) + @pytest.mark.parametrize("gp_lambda", [0.1, 1.0]) + def test_gail(self, device, use_grad_penalty, gp_lambda): torch.manual_seed(self.seed) - td = self._create_mock_data_dt(device=device) + td = self._create_mock_data_gail(device=device) - actor = self._create_mock_actor(device=device) + discriminator = self._create_mock_discriminator(device=device) - loss_fn = DTLoss(actor) + loss_fn = GAILLoss( + discriminator, use_grad_penalty=use_grad_penalty, gp_lambda=gp_lambda + ) loss = loss_fn(td) loss_transformer = loss["loss"] loss_transformer.backward(retain_graph=True) @@ -10524,11 +10535,9 @@ def test_dt(self, device): for name, p in named_parameters: if p.grad is not None and p.grad.norm() > 0.0: - assert "actor" in name - assert "alpha" not in name + assert "discriminator" in name if p.grad is None: - assert "actor" not in name - assert "alpha" in name + assert "discriminator" not in name loss_fn.zero_grad() sum([loss_transformer]).backward() @@ -10542,24 +10551,28 @@ def test_dt(self, device): assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" @pytest.mark.parametrize("device", get_available_devices()) - def test_dt_state_dict(self, device): + def test_gail_state_dict(self, device): torch.manual_seed(self.seed) - actor = self._create_mock_actor(device=device) + discriminator = self._create_mock_discriminator(device=device) - loss_fn = DTLoss(actor) + loss_fn = GAILLoss(discriminator) sd = loss_fn.state_dict() - loss_fn2 = 
DTLoss(actor) + loss_fn2 = GAILLoss(discriminator) loss_fn2.load_state_dict(sd) @pytest.mark.parametrize("device", get_available_devices()) - def test_seq_dt(self, device): + @pytest.mark.parametrize("use_grad_penalty", [True, False]) + @pytest.mark.parametrize("gp_lambda", [0.1, 1.0]) + def test_seq_gail(self, device, use_grad_penalty, gp_lambda): torch.manual_seed(self.seed) - td = self._create_seq_mock_data_dt(device=device) + td = self._create_seq_mock_data_gail(device=device) - actor = self._create_mock_actor(device=device) + discriminator = self._create_mock_discriminator(device=device) - loss_fn = DTLoss(actor) + loss_fn = GAILLoss( + discriminator, use_grad_penalty=use_grad_penalty, gp_lambda=gp_lambda + ) loss = loss_fn(td) loss_transformer = loss["loss"] loss_transformer.backward(retain_graph=True) @@ -10567,11 +10580,9 @@ def test_seq_dt(self, device): for name, p in named_parameters: if p.grad is not None and p.grad.norm() > 0.0: - assert "actor" in name - assert "alpha" not in name + assert "discriminator" in name if p.grad is None: - assert "actor" not in name - assert "alpha" in name + assert "discriminator" not in name loss_fn.zero_grad() sum([loss_transformer]).backward() @@ -10585,19 +10596,21 @@ def test_seq_dt(self, device): assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient" @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) - def test_dt_reduction(self, reduction): + @pytest.mark.parametrize("use_grad_penalty", [True, False]) + @pytest.mark.parametrize("gp_lambda", [0.1, 1.0]) + def test_gail_reduction(self, reduction, use_grad_penalty, gp_lambda): torch.manual_seed(self.seed) device = ( torch.device("cpu") if torch.cuda.device_count() == 0 else torch.device("cuda") ) - td = self._create_mock_data_dt(device=device) - actor = self._create_mock_actor(device=device) - loss_fn = DTLoss(actor, reduction=reduction) + td = self._create_mock_data_gail(device=device) + discriminator = self._create_mock_discriminator(device=device) + loss_fn = GAILLoss(discriminator, reduction=reduction) loss = loss_fn(td) if reduction == "none": - assert loss["loss"].shape == td["action"].shape + assert loss["loss"].shape == (td["observation"].shape[0], 1) else: assert loss["loss"].shape == torch.Size([]) diff --git a/torchrl/objectives/gail.py b/torchrl/objectives/gail.py index 9704bb4639c..3c0050fca84 100644 --- a/torchrl/objectives/gail.py +++ b/torchrl/objectives/gail.py @@ -63,6 +63,13 @@ class _AcceptedKeys: discriminator_network: TensorDictModule discriminator_network_params: TensorDictParams + target_discriminator_network: TensorDictModule + target_discriminator_network_params: TensorDictParams + + out_keys = [ + "loss", + "gp_loss", + ] def __init__( self, @@ -84,7 +91,7 @@ def __init__( "discriminator_network", create_target_params=False, ) - self.loss_function = torch.nn.BCELoss() + self.loss_function = torch.nn.BCELoss(reduction="none") self.use_grad_penalty = use_grad_penalty self.gp_lambda = gp_lambda @@ -95,6 +102,8 @@ def _set_in_keys(self): keys = set(keys) keys.add(self.tensor_keys.expert_observation) keys.add(self.tensor_keys.expert_action) + keys.add(self.tensor_keys.collector_observation) + keys.add(self.tensor_keys.collector_action) self._in_keys = sorted(keys, key=str) def _forward_value_estimator_keys(self, **kwargs) -> None: @@ -114,6 +123,8 @@ def in_keys(self, values): def out_keys(self): if self._out_keys is None: keys = ["loss"] + if self.use_grad_penalty: + keys.append("gp_loss") self._out_keys = keys return self._out_keys @@ 
-126,9 +137,19 @@ def forward(
         self,
         tensordict: TensorDictBase,
     ) -> TensorDictBase:
-        """Compute the GAIL discriminator loss."""
+        """The forward method.
+
+        Computes the discriminator loss, adding a gradient penalty term when
+        ``use_grad_penalty`` is ``True``; in that case the detached penalty is also returned for logging purposes.
+        The expected input and output keys are listed in the class's `"in_keys"` and `"out_keys"` attributes.
+        """
+        device = self.discriminator_network.device
         tensordict = tensordict.clone(False)
-        batch_size = tensordict.batch_size[0]
+        shape = tensordict.shape
+        if len(shape) > 1:
+            batch_size, seq_len = shape
+        else:
+            batch_size = shape[0]
         collector_obs = tensordict.get(self.tensor_keys.collector_observation)
         collector_act = tensordict.get(self.tensor_keys.collector_action)
 
@@ -144,15 +165,20 @@ def forward(
                 self.tensor_keys.expert_action: combined_act_inputs,
             },
             batch_size=[2 * batch_size],
+            device=device,
         )
 
-        # create labels
-        fake_labels = torch.zeros((batch_size, 1), dtype=torch.float32).to(
-            tensordict.device
-        )
-        real_labels = torch.ones((batch_size, 1), dtype=torch.float32).to(
-            tensordict.device
-        )
+        # create labels
+        if len(shape) > 1:
+            fake_labels = torch.zeros((batch_size, seq_len, 1), dtype=torch.float32).to(
+                device
+            )
+            real_labels = torch.ones((batch_size, seq_len, 1), dtype=torch.float32).to(
+                device
+            )
+        else:
+            fake_labels = torch.zeros((batch_size, 1), dtype=torch.float32).to(device)
+            real_labels = torch.ones((batch_size, 1), dtype=torch.float32).to(device)
 
         with self.discriminator_network_params.to_module(self.discriminator_network):
             d_logits = self.discriminator_network(combined_inputs).get(
@@ -167,7 +193,7 @@ def forward(
         collection_loss = self.loss_function(collection_preds, fake_labels)
 
         loss = expert_loss + collection_loss
-        out = {"loss": loss}
+        out = {}
         if self.use_grad_penalty:
             obs = tensordict.get(self.tensor_keys.collector_observation)
             acts = tensordict.get(self.tensor_keys.collector_action)
@@ -175,14 +201,10 @@ def forward(
             acts_e = tensordict.get(self.tensor_keys.expert_action)
 
             obss_noise = (
-                torch.distributions.Uniform(0.0, 1.0)
-                .sample(obs_e.shape)
-                .to(tensordict.device)
+                torch.distributions.Uniform(0.0, 1.0).sample(obs_e.shape).to(device)
             )
             acts_noise = (
-                torch.distributions.Uniform(0.0, 1.0)
-                .sample(acts_e.shape)
-                .to(tensordict.device)
+                torch.distributions.Uniform(0.0, 1.0).sample(acts_e.shape).to(device)
             )
             obss_mixture = obss_noise * obs + (1 - obss_noise) * obs_e
             acts_mixture = acts_noise * acts + (1 - acts_noise) * acts_e
@@ -195,6 +217,7 @@ def forward(
                     self.tensor_keys.expert_action: acts_mixture,
                 },
                 [],
+                device=device,
             )
 
             with self.discriminator_network_params.to_module(
@@ -208,9 +231,7 @@ def forward(
                 autograd.grad(
                     outputs=d_logits_mixture,
                     inputs=(obss_mixture, acts_mixture),
-                    grad_outputs=torch.ones(
-                        d_logits_mixture.size(), device=tensordict.device
-                    ),
+                    grad_outputs=torch.ones(d_logits_mixture.size(), device=device),
                     create_graph=True,
                     retain_graph=True,
                     only_inputs=True,
@@ -223,8 +244,8 @@ def forward(
             )
 
             loss += gp_loss
-            out["gp_loss"] = gp_loss
+            out["gp_loss"] = gp_loss.detach()
         loss = _reduce(loss, reduction=self.reduction)
-
+        out["loss"] = loss
         td_out = TensorDict(out, [])
         return td_out

From b31da8a8c72812e94ba7ac2917129232ca28f0d4 Mon Sep 17 00:00:00 2001
From: BY571
Date: Wed, 31 Jul 2024 10:00:01 +0200
Subject: [PATCH 11/22] Update config

---
 sota-implementations/gail/config.yaml | 6 +-----
 1 file
changed, 1 insertion(+), 5 deletions(-) diff --git a/sota-implementations/gail/config.yaml b/sota-implementations/gail/config.yaml index fa9e66dcf15..3ffed9d6591 100644 --- a/sota-implementations/gail/config.yaml +++ b/sota-implementations/gail/config.yaml @@ -3,8 +3,6 @@ env: seed: 42 backend: gym - -# logger logger: backend: wandb project_name: gail @@ -16,18 +14,15 @@ logger: mode: online ppo: - # collector collector: frames_per_batch: 2048 total_frames: 1_000_000 - # Optim optim: lr: 3e-4 weight_decay: 0.0 anneal_lr: True - # loss loss: gamma: 0.99 mini_batch_size: 64 @@ -44,6 +39,7 @@ gail: lr: 3e-4 use_grad_penalty: False gp_lambda: 10.0 + device: null replay_buffer: dataset: halfcheetah-expert-v2 From 63885b0c08b6a28f758891f8378beb7b86211248 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 31 Jul 2024 10:00:34 +0200 Subject: [PATCH 12/22] update gail device --- sota-implementations/gail/gail.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index 0ce2c4ac081..4cc313232fc 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -32,7 +32,13 @@ def main(cfg: "DictConfig"): # noqa: F821 set_gym_backend(cfg.env.backend).set() - device = "cpu" if not torch.cuda.device_count() else "cuda" + device = cfg.gail.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) num_mini_batches = ( cfg.ppo.collector.frames_per_batch // cfg.ppo.loss.mini_batch_size ) From 739332cd2619d8b1afd87fb83f48fdd6b4cc3af8 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 31 Jul 2024 10:03:14 +0200 Subject: [PATCH 13/22] update example tests --- .github/unittest/linux_examples/scripts/run_test.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 77b21fffe9a..9cb2cdcaa32 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -206,9 +206,9 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq logger.mode=offline \ logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/gail/gail.py \ - collector.total_frames=48 \ - loss.mini_batch_size=10 \ - collector.frames_per_batch=16 \ + ppo.collector.total_frames=48 \ + ppo.loss.mini_batch_size=10 \ + ppo.collector.frames_per_batch=16 \ logger.mode=offline \ logger.backend= From 9455fef40f02fae1f36371364659eb107847bc24 Mon Sep 17 00:00:00 2001 From: BY571 Date: Mon, 5 Aug 2024 15:09:09 +0200 Subject: [PATCH 14/22] gymnasium backend --- sota-implementations/gail/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/gail/config.yaml b/sota-implementations/gail/config.yaml index 3ffed9d6591..cf6c8053037 100644 --- a/sota-implementations/gail/config.yaml +++ b/sota-implementations/gail/config.yaml @@ -1,7 +1,7 @@ env: env_name: HalfCheetah-v4 seed: 42 - backend: gym + backend: gymnasium logger: backend: wandb From 4926d8038c6f86d1ad4318f593912c3435728799 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 6 Aug 2024 13:45:48 -0400 Subject: [PATCH 15/22] fixes --- .github/unittest/linux_examples/scripts/run_test.sh | 1 + sota-implementations/gail/gail.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git 
From 9455fef40f02fae1f36371364659eb107847bc24 Mon Sep 17 00:00:00 2001
From: BY571
Date: Mon, 5 Aug 2024 15:09:09 +0200
Subject: [PATCH 14/22] gymnasium backend

---
 sota-implementations/gail/config.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sota-implementations/gail/config.yaml b/sota-implementations/gail/config.yaml
index 3ffed9d6591..cf6c8053037 100644
--- a/sota-implementations/gail/config.yaml
+++ b/sota-implementations/gail/config.yaml
@@ -1,7 +1,7 @@
 env:
   env_name: HalfCheetah-v4
   seed: 42
-  backend: gym
+  backend: gymnasium
 
 logger:
   backend: wandb

From 4926d8038c6f86d1ad4318f593912c3435728799 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Tue, 6 Aug 2024 13:45:48 -0400
Subject: [PATCH 15/22] fixes

---
 .github/unittest/linux_examples/scripts/run_test.sh | 1 +
 sota-implementations/gail/gail.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh
index 9cb2cdcaa32..ef0d081f8fd 100755
--- a/.github/unittest/linux_examples/scripts/run_test.sh
+++ b/.github/unittest/linux_examples/scripts/run_test.sh
@@ -207,6 +207,7 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq
   logger.backend=
 python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/gail/gail.py \
   ppo.collector.total_frames=48 \
+  replay_buffer.batch_size=16 \
   ppo.loss.mini_batch_size=10 \
   ppo.collector.frames_per_batch=16 \
   logger.mode=offline \
diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py
index 4cc313232fc..a3c64693fb3 100644
--- a/sota-implementations/gail/gail.py
+++ b/sota-implementations/gail/gail.py
@@ -257,7 +257,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
     )
 
     # evaluation
-    with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
+    with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
         if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < (
             i * frames_in_batch
         ) // cfg_logger_test_interval:
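ExplorationType.MODE is replaced here by ExplorationType.DETERMINISTIC, the newer spelling for taking the policy's deterministic output at evaluation time. The surrounding pattern stays the same; a minimal sketch, where `test_env` and `actor` stand in for the objects built earlier in gail.py and the step budget is arbitrary:

    import torch
    from torchrl.envs.utils import ExplorationType, set_exploration_type

    with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
        # The actor now emits its deterministic output (e.g. the distribution's
        # mode) instead of sampling, which lowers the variance of eval returns.
        td_test = test_env.rollout(
            max_steps=1_000,
            policy=actor,
            auto_cast_to_device=True,
        )
        test_reward = td_test["next", "reward"].sum().item()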
From cbd5dfa5f8dda2f9708c635958d1b7bd26c73907 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Tue, 6 Aug 2024 14:55:43 -0400
Subject: [PATCH 16/22] init

---
 .github/unittest/linux/scripts/environment.yml | 1 +
 .github/unittest/linux_distributed/scripts/environment.yml | 1 +
 .github/unittest/linux_examples/scripts/environment.yml | 1 +
 .github/unittest/linux_libs/scripts_envpool/environment.yml | 1 +
 .github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml | 1 +
 5 files changed, 5 insertions(+)

diff --git a/.github/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml
index 30e01cfc4b5..8f1641a43c2 100644
--- a/.github/unittest/linux/scripts/environment.yml
+++ b/.github/unittest/linux/scripts/environment.yml
@@ -25,6 +25,7 @@ dependencies:
   - imageio==2.26.0
   - wandb
   - dm_control
+  - mujoco<3.2.1
   - mlflow
   - av
   - coverage
diff --git a/.github/unittest/linux_distributed/scripts/environment.yml b/.github/unittest/linux_distributed/scripts/environment.yml
index 6d27071791b..fbe1d3475b3 100644
--- a/.github/unittest/linux_distributed/scripts/environment.yml
+++ b/.github/unittest/linux_distributed/scripts/environment.yml
@@ -24,6 +24,7 @@ dependencies:
   - imageio==2.26.0
   - wandb
   - dm_control
+  - mujoco<3.2.1
   - mlflow
   - av
   - coverage
diff --git a/.github/unittest/linux_examples/scripts/environment.yml b/.github/unittest/linux_examples/scripts/environment.yml
index 688921f826a..03752200576 100644
--- a/.github/unittest/linux_examples/scripts/environment.yml
+++ b/.github/unittest/linux_examples/scripts/environment.yml
@@ -22,6 +22,7 @@ dependencies:
   - hydra-core
   - imageio==2.26.0
   - dm_control
+  - mujoco<3.2.1
   - mlflow
   - av
   - coverage
diff --git a/.github/unittest/linux_libs/scripts_envpool/environment.yml b/.github/unittest/linux_libs/scripts_envpool/environment.yml
index 9259a2a4a43..6ef2c9537cf 100644
--- a/.github/unittest/linux_libs/scripts_envpool/environment.yml
+++ b/.github/unittest/linux_libs/scripts_envpool/environment.yml
@@ -19,4 +19,5 @@ dependencies:
   - pyyaml
   - scipy
   - dm_control
+  - mujoco<3.2.1
   - coverage
diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml
index d34011e7bdc..ba8567450c9 100644
--- a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml
+++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml
@@ -22,6 +22,7 @@ dependencies:
   - scipy
   - hydra-core
   - dm_control -e git+https://github.com/deepmind/dm_control.git@c053360edea6170acfd9c8f65446703307d9d352#egg={dm_control}
+  - mujoco<3.2.1
   - patchelf
   - pyopengl==3.1.4
   - ray
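The mujoco<3.2.1 pin appears to work around an incompatibility between dm_control and newer MuJoCo releases at the time. Conda only enforces the pin at solve time, so a runtime guard can catch environments that have drifted; a sketch, assuming the `packaging` library is installed (the pin is from the patch, the guard is not):

    from importlib.metadata import version

    from packaging.specifiers import SpecifierSet

    # Pin taken from the environment.yml hunks above.
    pin = SpecifierSet("<3.2.1")
    installed = version("mujoco")
    assert installed in pin, f"mujoco=={installed} violates the pin {pin}"
    print(f"mujoco=={installed} satisfies {pin}")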
From b8ca705fde8a3827b2a0f539693aab0fcb055f70 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Tue, 6 Aug 2024 16:19:16 -0400
Subject: [PATCH 17/22] amend

---
 .github/unittest/linux/scripts/environment.yml | 2 +-
 .github/unittest/linux_distributed/scripts/environment.yml | 2 +-
 .github/unittest/linux_examples/scripts/environment.yml | 2 +-
 .github/unittest/linux_libs/scripts_envpool/environment.yml | 2 +-
 .github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml | 1 -
 5 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/.github/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml
index 8f1641a43c2..0e4a22802c8 100644
--- a/.github/unittest/linux/scripts/environment.yml
+++ b/.github/unittest/linux/scripts/environment.yml
@@ -24,7 +24,7 @@ dependencies:
   - tensorboard
   - imageio==2.26.0
   - wandb
-  - dm_control
+# - dm_control
   - mujoco<3.2.1
   - mlflow
   - av
diff --git a/.github/unittest/linux_distributed/scripts/environment.yml b/.github/unittest/linux_distributed/scripts/environment.yml
index fbe1d3475b3..39d48a24aee 100644
--- a/.github/unittest/linux_distributed/scripts/environment.yml
+++ b/.github/unittest/linux_distributed/scripts/environment.yml
@@ -23,7 +23,7 @@ dependencies:
   - tensorboard
   - imageio==2.26.0
   - wandb
-  - dm_control
+# - dm_control
   - mujoco<3.2.1
   - mlflow
   - av
diff --git a/.github/unittest/linux_examples/scripts/environment.yml b/.github/unittest/linux_examples/scripts/environment.yml
index 03752200576..ec58d5b83be 100644
--- a/.github/unittest/linux_examples/scripts/environment.yml
+++ b/.github/unittest/linux_examples/scripts/environment.yml
@@ -21,7 +21,7 @@ dependencies:
   - scipy
   - hydra-core
   - imageio==2.26.0
-  - dm_control
+# - dm_control
   - mujoco<3.2.1
   - mlflow
   - av
diff --git a/.github/unittest/linux_libs/scripts_envpool/environment.yml b/.github/unittest/linux_libs/scripts_envpool/environment.yml
index 6ef2c9537cf..3860f1a5337 100644
--- a/.github/unittest/linux_libs/scripts_envpool/environment.yml
+++ b/.github/unittest/linux_libs/scripts_envpool/environment.yml
@@ -18,6 +18,6 @@ dependencies:
   - expecttest
   - pyyaml
   - scipy
-  - dm_control
+# - dm_control
   - mujoco<3.2.1
   - coverage
diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml
index ba8567450c9..d34011e7bdc 100644
--- a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml
+++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml
@@ -22,7 +22,6 @@ dependencies:
   - scipy
   - hydra-core
   - dm_control -e git+https://github.com/deepmind/dm_control.git@c053360edea6170acfd9c8f65446703307d9d352#egg={dm_control}
-  - mujoco<3.2.1
   - patchelf
   - pyopengl==3.1.4
   - ray

From f0c225f8697418b249e8e4a59eddd8828d47051c Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Tue, 6 Aug 2024 16:34:25 -0400
Subject: [PATCH 18/22] amend

---
 .github/unittest/linux/scripts/environment.yml | 2 +-
 .github/unittest/linux_distributed/scripts/environment.yml | 2 +-
 .github/unittest/linux_examples/scripts/environment.yml | 2 +-
 .github/unittest/linux_libs/scripts_envpool/environment.yml | 2 +-
 .github/workflows/benchmarks.yml | 4 ++--
 .github/workflows/benchmarks_pr.yml | 4 ++--
 6 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/.github/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml
index 0e4a22802c8..2dca2a6e9ad 100644
--- a/.github/unittest/linux/scripts/environment.yml
+++ b/.github/unittest/linux/scripts/environment.yml
@@ -24,7 +24,7 @@ dependencies:
   - tensorboard
   - imageio==2.26.0
   - wandb
-# - dm_control
+  - dm_control<1.0.21
   - mujoco<3.2.1
   - mlflow
   - av
diff --git a/.github/unittest/linux_distributed/scripts/environment.yml b/.github/unittest/linux_distributed/scripts/environment.yml
index 39d48a24aee..d7eabcdea4f 100644
--- a/.github/unittest/linux_distributed/scripts/environment.yml
+++ b/.github/unittest/linux_distributed/scripts/environment.yml
@@ -23,7 +23,7 @@ dependencies:
   - tensorboard
   - imageio==2.26.0
   - wandb
-# - dm_control
+  - dm_control<1.0.21
   - mujoco<3.2.1
   - mlflow
   - av
diff --git a/.github/unittest/linux_examples/scripts/environment.yml b/.github/unittest/linux_examples/scripts/environment.yml
index ec58d5b83be..e99d6133963 100644
--- a/.github/unittest/linux_examples/scripts/environment.yml
+++ b/.github/unittest/linux_examples/scripts/environment.yml
@@ -21,7 +21,7 @@ dependencies:
   - scipy
   - hydra-core
   - imageio==2.26.0
-# - dm_control
+  - dm_control<1.0.21
   - mujoco<3.2.1
   - mlflow
   - av
diff --git a/.github/unittest/linux_libs/scripts_envpool/environment.yml b/.github/unittest/linux_libs/scripts_envpool/environment.yml
index 3860f1a5337..9ff3396056b 100644
--- a/.github/unittest/linux_libs/scripts_envpool/environment.yml
+++ b/.github/unittest/linux_libs/scripts_envpool/environment.yml
@@ -18,6 +18,6 @@ dependencies:
   - expecttest
   - pyyaml
   - scipy
-# - dm_control
+  - dm_control<1.0.21
   - mujoco<3.2.1
   - coverage
diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml
index 8008c8b5bbe..f698f67763f 100644
--- a/.github/workflows/benchmarks.yml
+++ b/.github/workflows/benchmarks.yml
@@ -35,7 +35,7 @@ jobs:
           python3 setup.py develop
           python3 -m pip install pytest pytest-benchmark
           python3 -m pip install "gym[accept-rom-license,atari]"
-          python3 -m pip install dm_control
+          python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
           export TD_GET_DEFAULTS_TO_NONE=1
       - name: Run benchmarks
         run: |
@@ -97,7 +97,7 @@ jobs:
           python3 setup.py develop
           python3 -m pip install pytest pytest-benchmark
           python3 -m pip install "gym[accept-rom-license,atari]"
-          python3 -m pip install dm_control
+          python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
           export TD_GET_DEFAULTS_TO_NONE=1
       - name: check GPU presence
         run: |
diff --git a/.github/workflows/benchmarks_pr.yml b/.github/workflows/benchmarks_pr.yml
index e994e860b9c..5bec0f23d1e 100644
--- a/.github/workflows/benchmarks_pr.yml
+++ b/.github/workflows/benchmarks_pr.yml
@@ -34,7 +34,7 @@ jobs:
           python3 setup.py develop
           python3 -m pip install pytest pytest-benchmark
           python3 -m pip install "gym[accept-rom-license,atari]"
-          python3 -m pip install dm_control
+          python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
           export TD_GET_DEFAULTS_TO_NONE=1
       - name: Setup benchmarks
         run: |
@@ -108,7 +108,7 @@ jobs:
           python3 setup.py develop
           python3 -m pip install pytest pytest-benchmark
           python3 -m pip install "gym[accept-rom-license,atari]"
-          python3 -m pip install dm_control
+          python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
           export TD_GET_DEFAULTS_TO_NONE=1
       - name: check GPU presence
         run: |
From 511fa959c165aef88c55c40988a985fbc66ca07e Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Tue, 6 Aug 2024 16:35:13 -0400
Subject: [PATCH 19/22] amend

---
 .github/unittest/linux/scripts/environment.yml | 2 +-
 .github/unittest/linux_distributed/scripts/environment.yml | 2 +-
 .github/unittest/linux_examples/scripts/environment.yml | 2 +-
 .github/unittest/linux_examples/scripts/run_test.sh | 7 +++++++
 .../unittest/linux_libs/scripts_envpool/environment.yml | 2 +-
 .github/workflows/benchmarks.yml | 4 ++--
 .github/workflows/benchmarks_pr.yml | 4 ++--
 7 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/.github/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml
index 0e4a22802c8..2dca2a6e9ad 100644
--- a/.github/unittest/linux/scripts/environment.yml
+++ b/.github/unittest/linux/scripts/environment.yml
@@ -24,7 +24,7 @@ dependencies:
   - tensorboard
   - imageio==2.26.0
   - wandb
-# - dm_control
+  - dm_control<1.0.21
   - mujoco<3.2.1
   - mlflow
   - av
diff --git a/.github/unittest/linux_distributed/scripts/environment.yml b/.github/unittest/linux_distributed/scripts/environment.yml
index 39d48a24aee..d7eabcdea4f 100644
--- a/.github/unittest/linux_distributed/scripts/environment.yml
+++ b/.github/unittest/linux_distributed/scripts/environment.yml
@@ -23,7 +23,7 @@ dependencies:
   - tensorboard
   - imageio==2.26.0
   - wandb
-# - dm_control
+  - dm_control<1.0.21
   - mujoco<3.2.1
   - mlflow
   - av
diff --git a/.github/unittest/linux_examples/scripts/environment.yml b/.github/unittest/linux_examples/scripts/environment.yml
index ec58d5b83be..e99d6133963 100644
--- a/.github/unittest/linux_examples/scripts/environment.yml
+++ b/.github/unittest/linux_examples/scripts/environment.yml
@@ -21,7 +21,7 @@ dependencies:
   - scipy
   - hydra-core
   - imageio==2.26.0
-# - dm_control
+  - dm_control<1.0.21
   - mujoco<3.2.1
   - mlflow
   - av
diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh
index f8b700c0410..ef0d081f8fd 100755
--- a/.github/unittest/linux_examples/scripts/run_test.sh
+++ b/.github/unittest/linux_examples/scripts/run_test.sh
@@ -205,6 +205,13 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq
   env.train_num_envs=2 \
   logger.mode=offline \
   logger.backend=
+python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/gail/gail.py \
+  ppo.collector.total_frames=48 \
+  replay_buffer.batch_size=16 \
+  ppo.loss.mini_batch_size=10 \
+  ppo.collector.frames_per_batch=16 \
+  logger.mode=offline \
+  logger.backend=
 
 # With single envs
 python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
diff --git a/.github/unittest/linux_libs/scripts_envpool/environment.yml b/.github/unittest/linux_libs/scripts_envpool/environment.yml
index 3860f1a5337..9ff3396056b 100644
--- a/.github/unittest/linux_libs/scripts_envpool/environment.yml
+++ b/.github/unittest/linux_libs/scripts_envpool/environment.yml
@@ -18,6 +18,6 @@ dependencies:
   - expecttest
   - pyyaml
   - scipy
-# - dm_control
+  - dm_control<1.0.21
   - mujoco<3.2.1
   - coverage
diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml
index 8008c8b5bbe..f698f67763f 100644
--- a/.github/workflows/benchmarks.yml
+++ b/.github/workflows/benchmarks.yml
@@ -35,7 +35,7 @@ jobs:
           python3 setup.py develop
           python3 -m pip install pytest pytest-benchmark
           python3 -m pip install "gym[accept-rom-license,atari]"
-          python3 -m pip install dm_control
+          python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
           export TD_GET_DEFAULTS_TO_NONE=1
       - name: Setup benchmarks
         run: |
@@ -97,7 +97,7 @@ jobs:
           python3 setup.py develop
           python3 -m pip install pytest pytest-benchmark
           python3 -m pip install "gym[accept-rom-license,atari]"
-          python3 -m pip install dm_control
+          python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
           export TD_GET_DEFAULTS_TO_NONE=1
       - name: check GPU presence
         run: |
diff --git a/.github/workflows/benchmarks_pr.yml b/.github/workflows/benchmarks_pr.yml
index e994e860b9c..5bec0f23d1e 100644
--- a/.github/workflows/benchmarks_pr.yml
+++ b/.github/workflows/benchmarks_pr.yml
@@ -34,7 +34,7 @@ jobs:
           python3 setup.py develop
           python3 -m pip install pytest pytest-benchmark
           python3 -m pip install "gym[accept-rom-license,atari]"
-          python3 -m pip install dm_control
+          python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
           export TD_GET_DEFAULTS_TO_NONE=1
       - name: Setup benchmarks
         run: |
@@ -108,7 +108,7 @@ jobs:
           python3 setup.py develop
           python3 -m pip install pytest pytest-benchmark
           python3 -m pip install "gym[accept-rom-license,atari]"
-          python3 -m pip install dm_control
+          python3 -m pip install "dm_control<1.0.21" "mujoco<3.2.1"
           export TD_GET_DEFAULTS_TO_NONE=1
       - name: check GPU presence
         run: |
From 4bc316bd811e90f846608440aca1d1303f89af08 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Tue, 6 Aug 2024 16:47:12 -0400
Subject: [PATCH 20/22] amend

---
 docs/requirements.txt | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/docs/requirements.txt b/docs/requirements.txt
index f6138cac30a..60c94749ee7 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -14,7 +14,8 @@ docutils
 sphinx_design
 
 torchvision
-dm_control
+dm_control<1.0.21
+mujoco<3.2.1
 atari-py
 ale-py
 gym[classic_control,accept-rom-license]

From 3d43e424e90830d8467da618af49f2b4ddd3447f Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Tue, 6 Aug 2024 17:05:32 -0400
Subject: [PATCH 21/22] amend

---
 .github/unittest/linux/scripts/run_all.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh
index 38235043d3f..17a53648f8c 100755
--- a/.github/unittest/linux/scripts/run_all.sh
+++ b/.github/unittest/linux/scripts/run_all.sh
@@ -91,7 +91,7 @@ echo "installing gymnasium"
 pip3 install "gymnasium"
 pip3 install ale_py
 pip3 install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
-pip3 install mujoco -U
+pip3 install "mujoco<3.2.1" -U
 
 # sanity check: remove?
 python3 -c """

From c488bcd571ada5a17c9fd8d2ff0f21df11a1d1fb Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Tue, 6 Aug 2024 17:44:39 -0400
Subject: [PATCH 22/22] amend

---
 .github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml
index d34011e7bdc..ba8567450c9 100644
--- a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml
+++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml
@@ -22,6 +22,7 @@ dependencies:
   - scipy
   - hydra-core
   - dm_control -e git+https://github.com/deepmind/dm_control.git@c053360edea6170acfd9c8f65446703307d9d352#egg={dm_control}
+  - mujoco<3.2.1
   - patchelf
   - pyopengl==3.1.4
   - ray
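run_all.sh already carries a "# sanity check" python3 -c block right after the pinned install. A sketch of the kind of smoke test such a block can run once the whole series is applied; the domain and task are arbitrary, and the check is illustrative rather than the script's actual contents:

    import mujoco
    from dm_control import suite

    # Both imports must succeed with the pinned wheels.
    print("mujoco:", mujoco.__version__)

    # A cheap end-to-end check that dm_control can still drive MuJoCo.
    env = suite.load(domain_name="cartpole", task_name="swingup")
    timestep = env.reset()
    print("cartpole reset ok:", timestep.step_type)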