diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py
index 42ef4301c4d..fcc8c321a90 100644
--- a/sota-implementations/a2c/a2c_atari.py
+++ b/sota-implementations/a2c/a2c_atari.py
@@ -87,6 +87,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
         weight_decay=cfg.optim.weight_decay,
         eps=cfg.optim.eps,
     )
+    if cfg.loss.compile:
+        loss_module = torch.compile(loss_module)
 
     # Create logger
     logger = None
diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py
index 2b390d39d2a..7793a4c558f 100644
--- a/sota-implementations/a2c/a2c_mujoco.py
+++ b/sota-implementations/a2c/a2c_mujoco.py
@@ -69,6 +69,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
         entropy_coef=cfg.loss.entropy_coef,
         critic_coef=cfg.loss.critic_coef,
     )
+    if cfg.loss.compile:
+        loss_module = torch.compile(loss_module)
 
     # Create optimizers
     actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr)
diff --git a/sota-implementations/a2c/config_atari.yaml b/sota-implementations/a2c/config_atari.yaml
index dd0f43b52cb..b223b3f9dfe 100644
--- a/sota-implementations/a2c/config_atari.yaml
+++ b/sota-implementations/a2c/config_atari.yaml
@@ -34,3 +34,4 @@ loss:
   critic_coef: 0.25
   entropy_coef: 0.01
   loss_critic_type: l2
+  compile: True
diff --git a/sota-implementations/a2c/config_mujoco.yaml b/sota-implementations/a2c/config_mujoco.yaml
index 03a0bde32c5..314127ec2df 100644
--- a/sota-implementations/a2c/config_mujoco.yaml
+++ b/sota-implementations/a2c/config_mujoco.yaml
@@ -31,3 +31,4 @@ loss:
   critic_coef: 0.25
   entropy_coef: 0.0
   loss_critic_type: l2
+  compile: True
diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py
index 73155d9fa1a..dfd40e1d97e 100644
--- a/sota-implementations/cql/cql_offline.py
+++ b/sota-implementations/cql/cql_offline.py
@@ -15,6 +15,8 @@
 import numpy as np
 import torch
 import tqdm
+from tensordict import TensorDict
+
 from torchrl._utils import logger as torchrl_logger
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.record.loggers import generate_exp_name, get_logger
@@ -81,66 +83,68 @@ def main(cfg: "DictConfig"):  # noqa: F821
         alpha_prime_optim,
     ) = make_continuous_cql_optimizer(cfg, loss_module)
 
-    pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)
     gradient_steps = cfg.optim.gradient_steps
     policy_eval_start = cfg.optim.policy_eval_start
     evaluation_interval = cfg.logger.eval_iter
     eval_steps = cfg.logger.eval_steps
 
-    # Training loop
-    start_time = time.time()
-    for i in range(gradient_steps):
-        pbar.update(1)
-        # sample data
-        data = replay_buffer.sample()
-        # compute loss
-        loss_vals = loss_module(data.clone().to(device))
+    def update(data, i):
+        critic_optim.zero_grad()
+        q_loss, metadata = loss_module.q_loss(data)
+        cql_loss, cql_metadata = loss_module.cql_loss(data)
+        q_loss = q_loss + cql_loss
+        q_loss.backward()
+        critic_optim.step()
+        metadata.update(cql_metadata)
 
-        # official cql implementation uses behavior cloning loss for first few updating steps as it helps for some tasks
+        policy_optim.zero_grad()
         if i >= policy_eval_start:
-            actor_loss = loss_vals["loss_actor"]
+            actor_loss, actor_metadata = loss_module.actor_loss(data)
         else:
-            actor_loss = loss_vals["loss_actor_bc"]
-        q_loss = loss_vals["loss_qvalue"]
-        cql_loss = loss_vals["loss_cql"]
-
-        q_loss = q_loss + cql_loss
-
-        # update model
-        alpha_loss = loss_vals["loss_alpha"]
-        alpha_prime_loss = loss_vals["loss_alpha_prime"]
+            actor_loss, actor_metadata = loss_module.actor_bc_loss(data)
+        actor_loss.backward()
+        policy_optim.step()
+        metadata.update(actor_metadata)
 
         alpha_optim.zero_grad()
+        alpha_loss, alpha_metadata = loss_module.alpha_loss(actor_metadata)
         alpha_loss.backward()
         alpha_optim.step()
-
-        policy_optim.zero_grad()
-        actor_loss.backward()
-        policy_optim.step()
+        metadata.update(alpha_metadata)
 
         if alpha_prime_optim is not None:
             alpha_prime_optim.zero_grad()
-            alpha_prime_loss.backward(retain_graph=True)
+            alpha_prime_loss, alpha_prime_metadata = loss_module.alpha_prime_loss(data)
+            alpha_prime_loss.backward()
             alpha_prime_optim.step()
+            metadata.update(alpha_prime_metadata)
 
-        critic_optim.zero_grad()
-        # TODO: we have the option to compute losses independently retain is not needed?
-        q_loss.backward(retain_graph=False)
-        critic_optim.step()
+        loss_vals = TensorDict(metadata)
+        loss_vals["loss_qvalue"] = q_loss
+        loss_vals["loss_cql"] = cql_loss
+        loss_vals["loss_alpha"] = alpha_loss
+        loss = actor_loss + q_loss + alpha_loss
+        if alpha_prime_optim is not None:
+            loss_vals["loss_alpha_prime"] = alpha_prime_loss
+            loss = loss + alpha_prime_loss
+        loss_vals["loss"] = loss
+
+        return loss_vals.detach()
+
+    if cfg.loss.compile:
+        update = torch.compile(update, mode=cfg.loss.compile_mode)
 
-        loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss
+    # Training loop
+    start_time = time.time()
+    pbar = tqdm.tqdm(range(gradient_steps))
+    for i in pbar:
+        # sample data
+        data = replay_buffer.sample().to(device)
+        loss_vals = update(data, i)
 
         # log metrics
-        to_log = {
-            "loss": loss.item(),
-            "loss_actor_bc": loss_vals["loss_actor_bc"].item(),
-            "loss_actor": loss_vals["loss_actor"].item(),
-            "loss_qvalue": q_loss.item(),
-            "loss_cql": cql_loss.item(),
-            "loss_alpha": alpha_loss.item(),
-            "loss_alpha_prime": alpha_prime_loss.item(),
-        }
+        to_log = loss_vals.mean().to_dict()
 
         # update qnet_target params
         target_net_updater.step()
diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py
index cf629ed0733..94aba0d7433 100644
--- a/sota-implementations/cql/cql_online.py
+++ b/sota-implementations/cql/cql_online.py
@@ -18,7 +18,9 @@
 import torch
 import tqdm
 from tensordict import TensorDict
-from torchrl._utils import logger as torchrl_logger
+from tensordict.nn import CudaGraphModule
+
+from torchrl._utils import logger as torchrl_logger, timeit
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.record.loggers import generate_exp_name, get_logger
@@ -111,17 +113,77 @@ def main(cfg: "DictConfig"):  # noqa: F821
     evaluation_interval = cfg.logger.log_interval
     eval_rollout_steps = cfg.logger.eval_steps
 
+    def update(sampled_tensordict):
+
+        critic_optim.zero_grad()
+        q_loss, metadata = loss_module.q_loss(sampled_tensordict)
+        cql_loss, metadata_cql = loss_module.cql_loss(sampled_tensordict)
+        metadata.update(metadata_cql)
+        q_loss = q_loss + cql_loss
+        q_loss.backward()
+        critic_optim.step()
+
+        if loss_module.with_lagrange:
+            alpha_prime_optim.zero_grad()
+            alpha_prime_loss, metadata_aprime = loss_module.alpha_prime_loss(
+                sampled_tensordict
+            )
+            metadata.update(metadata_aprime)
+            alpha_prime_loss.backward()
+            alpha_prime_optim.step()
+
+        policy_optim.zero_grad()
+        # loss_actor_bc, _ = loss_module.actor_bc_loss(sampled_tensordict)
+        actor_loss, actor_metadata = loss_module.actor_loss(sampled_tensordict)
+        metadata.update(actor_metadata)
+        actor_loss.backward()
+        policy_optim.step()
+
+        alpha_optim.zero_grad()
+        alpha_loss, metadata_actor = loss_module.alpha_loss(actor_metadata)
+        metadata.update(metadata_actor)
+        alpha_loss.backward()
+        alpha_optim.step()
+
+        loss_td = TensorDict(metadata)
+
+        loss_td["loss_actor"] = actor_loss
+        loss_td["loss_qvalue"] = q_loss
+        loss_td["loss_cql"] = cql_loss
+        loss_td["loss_alpha"] = alpha_loss
+        if alpha_prime_optim:
+            loss_td["loss_alpha_prime"] = alpha_prime_loss
+
+        loss = actor_loss + alpha_loss + q_loss
+        if alpha_prime_optim is not None:
+            loss = loss + alpha_prime_loss
+
+        loss_td["loss"] = loss
+        return loss_td.detach()
+
+    if cfg.loss.compile:
+        update = torch.compile(update, mode=cfg.loss.compile_mode)
+
+    if cfg.loss.cudagraphs:
+        update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
 
     sampling_start = time.time()
-    for i, tensordict in enumerate(collector):
+    collector_iter = iter(collector)
+    for i in range(cfg.collector.total_frames):
+        timeit.print()
+        timeit.erase()
+        with timeit("collection"):
+            tensordict = next(collector_iter)
         sampling_time = time.time() - sampling_start
 
         pbar.update(tensordict.numel())
 
-        # update weights of the inference policy
-        collector.update_policy_weights_()
+        with timeit("update policies"):
+            # update weights of the inference policy
+            collector.update_policy_weights_()
 
-        tensordict = tensordict.view(-1)
+        tensordict = tensordict.reshape(-1)
         current_frames = tensordict.numel()
         # add to replay buffer
-        replay_buffer.extend(tensordict.cpu())
+        with timeit("extend"):
+            replay_buffer.extend(tensordict.cpu())
         collected_frames += current_frames
 
         # optimization steps
@@ -130,7 +192,8 @@ def main(cfg: "DictConfig"):  # noqa: F821
             log_loss_td = TensorDict({}, [num_updates])
             for j in range(num_updates):
                 # sample from replay buffer
-                sampled_tensordict = replay_buffer.sample()
+                with timeit("sample"):
+                    sampled_tensordict = replay_buffer.sample()
                 if sampled_tensordict.device != device:
                     sampled_tensordict = sampled_tensordict.to(
                         device, non_blocking=True
@@ -138,36 +201,13 @@ def main(cfg: "DictConfig"):  # noqa: F821
                 else:
                     sampled_tensordict = sampled_tensordict.clone()
 
-                loss_td = loss_module(sampled_tensordict)
-
-                actor_loss = loss_td["loss_actor"]
-                q_loss = loss_td["loss_qvalue"]
-                cql_loss = loss_td["loss_cql"]
-                q_loss = q_loss + cql_loss
-                alpha_loss = loss_td["loss_alpha"]
-                alpha_prime_loss = loss_td["loss_alpha_prime"]
-
-                alpha_optim.zero_grad()
-                alpha_loss.backward()
-                alpha_optim.step()
-
-                policy_optim.zero_grad()
-                actor_loss.backward()
-                policy_optim.step()
-
-                if alpha_prime_optim is not None:
-                    alpha_prime_optim.zero_grad()
-                    alpha_prime_loss.backward(retain_graph=True)
-                    alpha_prime_optim.step()
-
-                critic_optim.zero_grad()
-                q_loss.backward(retain_graph=False)
-                critic_optim.step()
-
-                log_loss_td[j] = loss_td.detach()
+                with timeit("update"):
+                    loss_td = update(sampled_tensordict)
+                log_loss_td[j] = loss_td
 
-                # update qnet_target params
-                target_net_updater.step()
+                with timeit("target net"):
+                    # update qnet_target params
+                    target_net_updater.step()
 
                 # update priority
                 if prb:
@@ -191,10 +231,11 @@ def main(cfg: "DictConfig"):  # noqa: F821
             metrics_to_log["train/loss_actor"] = log_loss_td.get("loss_actor").mean()
             metrics_to_log["train/loss_qvalue"] = log_loss_td.get("loss_qvalue").mean()
             metrics_to_log["train/loss_alpha"] = log_loss_td.get("loss_alpha").mean()
-            metrics_to_log["train/loss_alpha_prime"] = log_loss_td.get(
-                "loss_alpha_prime"
-            ).mean()
-            metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean()
+            if alpha_prime_optim is not None:
+                metrics_to_log["train/loss_alpha_prime"] = log_loss_td.get(
+                    "loss_alpha_prime"
+                ).mean()
+            # metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean()
             metrics_to_log["train/sampling_time"] = sampling_time
             metrics_to_log["train/training_time"] = training_time
 
@@ -204,7 +245,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
         cur_test_frame = (i * frames_per_batch) // evaluation_interval
         final = current_frames >= collector.total_frames
         if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
-            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
+            with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(), timeit("eval"):
                 eval_start = time.time()
                 eval_rollout = eval_env.rollout(
                     eval_rollout_steps,
diff --git a/sota-implementations/cql/discrete_cql_config.yaml b/sota-implementations/cql/discrete_cql_config.yaml
index 644b8ec624e..d09469b9c23 100644
--- a/sota-implementations/cql/discrete_cql_config.yaml
+++ b/sota-implementations/cql/discrete_cql_config.yaml
@@ -57,3 +57,4 @@ loss:
   loss_function: l2
   gamma: 0.99
   tau: 0.005
+  compile: True
diff --git a/sota-implementations/cql/offline_config.yaml b/sota-implementations/cql/offline_config.yaml
index bf213d4e3c5..429ea64bc1e 100644
--- a/sota-implementations/cql/offline_config.yaml
+++ b/sota-implementations/cql/offline_config.yaml
@@ -52,5 +52,7 @@ loss:
   max_q_backup: False
   deterministic_backup: False
   num_random: 10
-  with_lagrange: True
+  with_lagrange: False
   lagrange_thresh: 5.0 # tau
+  compile: False
+  compile_mode: reduce-overhead
diff --git a/sota-implementations/cql/online_config.yaml b/sota-implementations/cql/online_config.yaml
index 00db1d6bb62..75feb4b695b 100644
--- a/sota-implementations/cql/online_config.yaml
+++ b/sota-implementations/cql/online_config.yaml
@@ -64,5 +64,8 @@ loss:
   max_q_backup: False
   deterministic_backup: False
   num_random: 10
-  with_lagrange: True
+  with_lagrange: False
   lagrange_thresh: 10.0
+  compile: False
+  compile_mode: reduce-overhead
+  cudagraphs: False
diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py
index c1d6fb52024..ea371a4011b 100644
--- a/sota-implementations/cql/utils.py
+++ b/sota-implementations/cql/utils.py
@@ -202,17 +202,21 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"):
     # We use a ProbabilisticActor to make sure that we map the
     # network output to the right space using a TanhDelta
    # distribution.
+    high = action_spec.space.high
+    low = action_spec.space.low
+    if train_env.batch_size:
+        high = high[(0,) * len(train_env.batch_size)]
+        low = low[(0,) * len(train_env.batch_size)]
     actor = ProbabilisticActor(
         module=actor_module,
         in_keys=["loc", "scale"],
         spec=action_spec,
         distribution_class=TanhNormal,
         distribution_kwargs={
-            "low": action_spec.space.low[len(train_env.batch_size) :],
-            "high": action_spec.space.high[
-                len(train_env.batch_size) :
-            ],  # remove batch-size
+            "low": low,
+            "high": high,
             "tanh_loc": False,
+            "safe_tanh": not cfg.loss.compile,
         },
         default_interaction_type=ExplorationType.RANDOM,
     )
@@ -334,6 +338,8 @@ def make_discrete_loss(loss_cfg, model):
     )
     loss_module.make_value_estimator(gamma=loss_cfg.gamma)
     target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)
+    if loss_cfg.compile:
+        loss_module = torch.compile(loss_module)
 
     return loss_module, target_net_updater
diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py
index 71fee70d5b8..fc5c8506c48 100644
--- a/torchrl/modules/distributions/continuous.py
+++ b/torchrl/modules/distributions/continuous.py
@@ -14,9 +14,9 @@
 from torch import distributions as D, nn
 
 try:
-    from torch.compiler import assume_constant_result
+    from torch.compiler import assume_constant_result, is_dynamo_compiling
 except ImportError:
-    from torch._dynamo import assume_constant_result
+    from torch._dynamo import assume_constant_result, is_compiling as is_dynamo_compiling
 
 from torch.distributions import constraints
 from torch.distributions.transforms import _InverseTransform
@@ -465,8 +465,8 @@ def __init__(
             t = SafeTanhTransform()
         else:
             t = D.TanhTransform()
-        # t = D.TanhTransform()
-        if torch.compiler.is_dynamo_compiling() or (
+        t = D.TanhTransform()
+        if is_dynamo_compiling() or (
             self.non_trivial_max or self.non_trivial_min
         ):
             t = _PatchedComposeTransform(
@@ -495,7 +495,7 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
         if self.tanh_loc:
             loc = (loc / self.upscale).tanh() * self.upscale
             # loc must be rescaled if tanh_loc
-        if torch.compiler.is_dynamo_compiling() or (
+        if is_dynamo_compiling() or (
             self.non_trivial_max or self.non_trivial_min
         ):
             loc = loc + (self.high - self.low) / 2 + self.low
@@ -820,7 +820,7 @@ def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor:
 
 def _err_compile_safetanh():
     raise RuntimeError(
-        "safe_tanh=True in TanhNormal is not compatible with torch.compile. To deactivate it, pass"
+        "safe_tanh=True in TanhNormal is not compatible with torch.compile. To deactivate it, pass "
         "safe_tanh=False. "
         "If you are using a ProbabilisticTensorDictModule, this can be done via "
         "`distribution_kwargs={'safe_tanh': False}`. "
diff --git a/torchrl/modules/distributions/utils.py b/torchrl/modules/distributions/utils.py
index 546d93cb228..2038cd0b08c 100644
--- a/torchrl/modules/distributions/utils.py
+++ b/torchrl/modules/distributions/utils.py
@@ -8,7 +8,10 @@
 import torch
 from torch import autograd, distributions as d
 from torch.distributions import Independent, Transform, TransformedDistribution
-
+try:
+    from torch.compiler import is_dynamo_compiling
+except ImportError:
+    from torch._dynamo import is_compiling as is_dynamo_compiling
 
 def _cast_device(elt: Union[torch.Tensor, float], device) -> Union[torch.Tensor, float]:
     if isinstance(elt, torch.Tensor):
diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index c823788b4c2..1af767b2e83 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -316,10 +316,11 @@ def __init__(
         self.entropy_bonus = entropy_bonus and entropy_coef
         self.reduction = reduction
 
-        try:
-            device = next(self.parameters()).device
-        except AttributeError:
-            device = torch.device("cpu")
+        p = next(self.parameters())
+        if hasattr(p, "device"):
+            device = p.device
+        else:
+            device = torch.get_default_device()
 
         self.register_buffer(
             "entropy_coef", torch.as_tensor(entropy_coef, device=device)
diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py
index fb8fbff2ccf..2d39d2d028b 100644
--- a/torchrl/objectives/cql.py
+++ b/torchrl/objectives/cql.py
@@ -320,10 +320,11 @@ def __init__(
         )
         self.loss_function = loss_function
 
-        try:
-            device = next(self.parameters()).device
-        except AttributeError:
-            device = torch.device("cpu")
+        p = next(self.parameters())
+        if hasattr(p, "device"):
+            device = p.device
+        else:
+            device = torch.get_default_device()
 
         self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
         if bool(min_alpha) ^ bool(max_alpha):
             min_alpha = min_alpha if min_alpha else 0.0
@@ -540,6 +541,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         )
         if shape:
             tensordict.update(tensordict_reshape.view(shape))
+        entropy = -actor_metadata.get(self.tensor_keys.log_prob).mean().detach()
+
         out = {
             "loss_actor": loss_actor,
             "loss_actor_bc": loss_actor_bc,
@@ -547,7 +550,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             "loss_cql": cql_loss,
             "loss_alpha": loss_alpha,
             "alpha": self._alpha,
-            "entropy": -actor_metadata.get(self.tensor_keys.log_prob).mean().detach(),
+            "entropy": entropy,
         }
         if self.with_lagrange:
             out["loss_alpha_prime"] = alpha_prime_loss.mean()
@@ -579,7 +582,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
             ExplorationType.RANDOM
         ), self.actor_network_params.to_module(self.actor_network):
             dist = self.actor_network.get_dist(
-                tensordict,
+                tensordict
             )
             a_reparm = dist.rsample()
         log_prob = dist.log_prob(a_reparm)
@@ -740,12 +743,12 @@ def cql_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
         )
 
         random_actions_tensor = (
-            torch.FloatTensor(
+            torch.empty(
                 tensordict.shape[0] * self.num_random,
                 tensordict[self.tensor_keys.action].shape[-1],
+                device=tensordict.device,
             )
             .uniform_(-1, 1)
-            .to(tensordict.device)
         )
         curr_actions_td, curr_log_pis = self._get_policy_actions(
             tensordict.copy(),
@@ -884,15 +887,18 @@ def alpha_prime_loss(self, tensordict: TensorDictBase) -> Tensor:
         )
 
         alpha_prime = torch.clamp_max(self.log_alpha_prime.exp(), max=1000000.0)
-        min_qf1_loss = alpha_prime * (cql_q1_loss.mean() - self.target_action_gap)
-        min_qf2_loss = alpha_prime * (cql_q2_loss.mean() - self.target_action_gap)
+        with torch.no_grad():
+            min_qf1_loss = (cql_q1_loss.mean() - self.target_action_gap)
+            min_qf2_loss = (cql_q2_loss.mean() - self.target_action_gap)
 
-        alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5
+        alpha_prime_loss = alpha_prime * (-min_qf1_loss - min_qf2_loss) * 0.5
 
         alpha_prime_loss = _reduce(alpha_prime_loss, reduction=self.reduction)
         return alpha_prime_loss, {}
 
     def alpha_loss(self, tensordict: TensorDictBase) -> Tensor:
         log_pi = tensordict.get(self.tensor_keys.log_prob)
+        if log_pi is None:
+            log_pi = tensordict.get("bc_log_prob")
         if self.target_entropy is not None:
             # we can compute this loss even if log_alpha is not a parameter
             alpha_loss = -self.log_alpha * (log_pi.detach() + self.target_entropy)
@@ -1080,9 +1086,9 @@ def __init__(
         self.loss_function = loss_function
         if action_space is None:
             # infer from value net
-            try:
+            if hasattr(value_network, "spec"):
                 action_space = value_network.spec
-            except AttributeError:
+            else:
                 # let's try with action_space then
                 try:
                     action_space = value_network.action_space
@@ -1205,8 +1211,6 @@ def value_loss(
         with torch.no_grad():
             td_error = (pred_val_index - target_value).pow(2)
             td_error = td_error.unsqueeze(-1)
-            if tensordict.device is not None:
-                td_error = td_error.to(tensordict.device)
 
         tensordict.set(
             self.tensor_keys.priority,
diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py
index d86442fca12..2179cb7d821 100644
--- a/torchrl/objectives/crossq.py
+++ b/torchrl/objectives/crossq.py
@@ -303,10 +303,11 @@ def __init__(
         )
         self.loss_function = loss_function
 
-        try:
-            device = next(self.parameters()).device
-        except AttributeError:
-            device = torch.device("cpu")
+        p = next(self.parameters())
+        if hasattr(p, "device"):
+            device = p.device
+        else:
+            device = torch.get_default_device()
 
         self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
         if bool(min_alpha) ^ bool(max_alpha):
             min_alpha = min_alpha if min_alpha else 0.0
diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py
index eb34b021484..95d2f7ecd48 100644
--- a/torchrl/objectives/decision_transformer.py
+++ b/torchrl/objectives/decision_transformer.py
@@ -100,10 +100,11 @@ def __init__(
             "actor_network",
             create_target_params=False,
         )
-        try:
-            device = next(self.parameters()).device
-        except AttributeError:
-            device = torch.device("cpu")
+        p = next(self.parameters())
+        if hasattr(p, "device"):
+            device = p.device
+        else:
+            device = torch.get_default_device()
 
         self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
         if bool(min_alpha) ^ bool(max_alpha):
diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py
index 32394942600..d46fc93a5c2 100644
--- a/torchrl/objectives/deprecated.py
+++ b/torchrl/objectives/deprecated.py
@@ -192,10 +192,11 @@ def __init__(
         self.sub_sample_len = max(1, min(sub_sample_len, num_qvalue_nets - 1))
         self.loss_function = loss_function
 
-        try:
-            device = next(self.parameters()).device
-        except AttributeError:
-            device = torch.device("cpu")
+        p = next(self.parameters())
+        if hasattr(p, "device"):
+            device = p.device
+        else:
+            device = torch.get_default_device()
 
         self.register_buffer("alpha_init", torch.as_tensor(alpha_init, device=device))
         self.register_buffer(
diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py
index 6cbb8b02426..63fcf1f145c 100644
--- a/torchrl/objectives/dqn.py
+++ b/torchrl/objectives/dqn.py
@@ -210,9 +210,9 @@ def __init__(
         self.loss_function = loss_function
         if action_space is None:
             # infer from value net
-            try:
+            if hasattr(value_network, "spec"):
                 action_space = value_network.spec
-            except AttributeError:
+            else:
                 # let's try with action_space then
                 try:
                     action_space = value_network.action_space
diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py
index 39777c59e26..3d588aaae73 100644
--- a/torchrl/objectives/multiagent/qmixer.py
+++ b/torchrl/objectives/multiagent/qmixer.py
@@ -243,9 +243,9 @@ def __init__(
         self.loss_function = loss_function
         if action_space is None:
             # infer from value net
-            try:
+            if hasattr(local_value_network, "spec"):
                 action_space = local_value_network.spec
-            except AttributeError:
+            else:
                 # let's try with action_space then
                 try:
                     action_space = local_value_network.action_space
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index efc951b3999..680255d4b98 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -365,10 +365,11 @@ def __init__(
         self.separate_losses = separate_losses
         self.reduction = reduction
 
-        try:
-            device = next(self.parameters()).device
-        except AttributeError:
-            device = torch.device("cpu")
+        p = next(self.parameters())
+        if hasattr(p, "device"):
+            device = p.device
+        else:
+            device = torch.get_default_device()
 
         self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
         if critic_coef is not None:
diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py
index 271f233bae8..46feeac36d4 100644
--- a/torchrl/objectives/redq.py
+++ b/torchrl/objectives/redq.py
@@ -306,10 +306,11 @@ def __init__(
         self.sub_sample_len = max(1, min(sub_sample_len, num_qvalue_nets - 1))
         self.loss_function = loss_function
 
-        try:
-            device = next(self.parameters()).device
-        except AttributeError:
-            device = torch.device("cpu")
+        p = next(self.parameters())
+        if hasattr(p, "device"):
+            device = p.device
+        else:
+            device = torch.get_default_device()
 
         self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
         self.register_buffer(
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index bd21e33c30d..f73e2755a77 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -362,10 +362,11 @@ def __init__(
         )
         self.loss_function = loss_function
 
-        try:
-            device = next(self.parameters()).device
-        except AttributeError:
-            device = torch.device("cpu")
+        p = next(self.parameters())
+        if hasattr(p, "device"):
+            device = p.device
+        else:
+            device = torch.get_default_device()
 
         self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
         if bool(min_alpha) ^ bool(max_alpha):
             min_alpha = min_alpha if min_alpha else 0.0
@@ -1051,10 +1052,11 @@ def __init__(
         self.num_qvalue_nets = num_qvalue_nets
         self.loss_function = loss_function
 
-        try:
-            device = next(self.parameters()).device
-        except AttributeError:
-            device = torch.device("cpu")
+        p = next(self.parameters())
+        if hasattr(p, "device"):
+            device = p.device
+        else:
+            device = torch.get_default_device()
 
         self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
         if bool(min_alpha) ^ bool(max_alpha):
diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py
index 89ff581991f..1635f721521 100644
--- a/torchrl/objectives/td3.py
+++ b/torchrl/objectives/td3.py
@@ -391,7 +391,7 @@ def actor_loss(self, tensordict) -> Tuple[torch.Tensor, dict]:
             .get(self.tensor_keys.state_action_value)
             .squeeze(-1)
         )
-        loss_actor = -(state_action_value_actor[0])
+        loss_actor = -state_action_value_actor[0]
         metadata = {
             "state_action_value_actor": state_action_value_actor.detach(),
         }