Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jul 23, 2024
1 parent e3a67c9 commit a4cdd5e
Show file tree
Hide file tree
Showing 15 changed files with 276 additions and 89 deletions.
119 changes: 119 additions & 0 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -6854,6 +6854,71 @@ def test_cql(
p.grad is None or p.grad.norm() == 0.0
), f"target parameter {name} (shape: {p.shape}) has a non-null gradient"

@pytest.mark.parametrize("delay_actor", (True,))
@pytest.mark.parametrize("delay_qvalue", (True,))
@pytest.mark.parametrize(
"max_q_backup",
[
True,
],
)
@pytest.mark.parametrize(
"deterministic_backup",
[
True,
],
)
@pytest.mark.parametrize(
"with_lagrange",
[
True,
],
)
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("td_est", [None])
def test_cql_qvalfromlist(
self,
delay_actor,
delay_qvalue,
max_q_backup,
deterministic_backup,
with_lagrange,
device,
td_est,
):
torch.manual_seed(self.seed)
td = self._create_mock_data_cql(device=device)

actor = self._create_mock_actor(device=device)
qvalue0 = self._create_mock_qvalue(device=device)
qvalue1 = self._create_mock_qvalue(device=device)

loss_fn_single = CQLLoss(
actor_network=actor,
qvalue_network=qvalue0,
loss_function="l2",
max_q_backup=max_q_backup,
deterministic_backup=deterministic_backup,
with_lagrange=with_lagrange,
delay_actor=delay_actor,
delay_qvalue=delay_qvalue,
)
loss_fn_mult = CQLLoss(
actor_network=actor,
qvalue_network=[qvalue0, qvalue1],
loss_function="l2",
max_q_backup=max_q_backup,
deterministic_backup=deterministic_backup,
with_lagrange=with_lagrange,
delay_actor=delay_actor,
delay_qvalue=delay_qvalue,
)
# Check that all params have the same shape
p2 = dict(loss_fn_mult.named_parameters())
for key, val in loss_fn_single.named_parameters():
assert val.shape == p2[key].shape
assert len(dict(loss_fn_single.named_parameters())) == len(p2)

@pytest.mark.parametrize("delay_actor", (True, False))
@pytest.mark.parametrize("delay_qvalue", (True, False))
@pytest.mark.parametrize("max_q_backup", [True])
Expand Down Expand Up @@ -14605,6 +14670,60 @@ def init(mod):
loss.from_stateful_net("module_a", module_a)
assert (loss.module_a_params == 1).all()

def test_from_module_list(self):
class MyLoss(LossModule):
module_a: TensorDictModule
module_b: TensorDictModule
module_a_params: TensorDict
module_b_params: TensorDict
target_module_a_params: TensorDict
target_module_b_params: TensorDict

def __init__(self, module_a, module_b0, module_b1, expand_dim=2):
super().__init__()
self.convert_to_functional(module_a, "module_a")
self.convert_to_functional(
[module_b0, module_b1],
"module_b",
# This will be ignored
compare_against=module_a.parameters(),
expand_dim=expand_dim,
)

module1 = nn.Linear(3, 4)
module2 = nn.Linear(3, 4)
module3a = nn.Linear(3, 4)
module3b = nn.Linear(3, 4)

module_a = TensorDictModule(
nn.Sequential(module1, module2), in_keys=["a"], out_keys=["c"]
)

module_b0 = TensorDictModule(
nn.Sequential(module1, module3a), in_keys=["b"], out_keys=["c"]
)
module_b1 = TensorDictModule(
nn.Sequential(module1, module3b), in_keys=["b"], out_keys=["c"]
)

loss = MyLoss(module_a, module_b0, module_b1)

# This should be extended
assert not isinstance(
loss.module_b_params["module", "0", "weight"], nn.Parameter
)
assert loss.module_b_params["module", "0", "weight"].shape[0] == 2
assert (
loss.module_b_params["module", "0", "weight"].data.data_ptr()
== loss.module_a_params["module", "0", "weight"].data.data_ptr()
)
assert isinstance(loss.module_b_params["module", "1", "weight"], nn.Parameter)
assert loss.module_b_params["module", "1", "weight"].shape[0] == 2
assert (
loss.module_b_params["module", "1", "weight"].data.data_ptr()
!= loss.module_a_params["module", "1", "weight"].data.data_ptr()
)

def test_tensordict_keys(self):
"""Test configurable tensordict key behavior with derived classes."""

Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class A2CLoss(LossModule):
Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, ie. gradients are propagated to shared
Defaults to ``False``, i.e., gradients are propagated to shared
parameters for both policy and critic losses.
advantage_key (str): [Deprecated, use set_keys(advantage_key=advantage_key) instead]
The input tensordict key where the advantage is expected to be written. default: "advantage"
Expand Down
102 changes: 56 additions & 46 deletions torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,57 +317,67 @@ def convert_to_functional(
# Otherwise, casting the module to a device will keep old references
# to uncast tensors
sep = self.SEP
params = TensorDict.from_module(module, as_module=True)

for key in params.keys(True):
if sep in key:
raise KeyError(
f"The key {key} contains the '_sep_' pattern which is prohibited. Consider renaming the parameter / buffer."
if isinstance(module, (list, tuple)):
if len(module) != expand_dim:
raise RuntimeError(
"The ``expand_dim`` value must match the length of the module list/tuple "
"if a single module isn't provided."
)
if compare_against is not None:
compare_against = set(compare_against)
params = TensorDict.from_modules(
*module, as_module=True, expand_identical=True
)
else:
compare_against = set()
if expand_dim:
# Expands the dims of params and buffers.
# If the param already exist in the module, we return a simple expansion of the
# original one. Otherwise, we expand and resample it.
# For buffers, a cloned expansion (or equivalently a repeat) is returned.

def _compare_and_expand(param):
if is_tensor_collection(param):
return param._apply_nest(
params = TensorDict.from_module(module, as_module=True)

for key in params.keys(True):
if sep in key:
raise KeyError(
f"The key {key} contains the '_sep_' pattern which is prohibited. Consider renaming the parameter / buffer."
)
if compare_against is not None:
compare_against = set(compare_against)
else:
compare_against = set()
if expand_dim:
# Expands the dims of params and buffers.
# If the param already exist in the module, we return a simple expansion of the
# original one. Otherwise, we expand and resample it.
# For buffers, a cloned expansion (or equivalently a repeat) is returned.

def _compare_and_expand(param):
if is_tensor_collection(param):
return param._apply_nest(
_compare_and_expand,
batch_size=[expand_dim, *param.shape],
filter_empty=False,
call_on_nested=True,
)
if not isinstance(param, nn.Parameter):
buffer = param.expand(expand_dim, *param.shape).clone()
return buffer
if param in compare_against:
expanded_param = param.data.expand(expand_dim, *param.shape)
# the expanded parameter must be sent to device when to()
# is called:
return expanded_param
else:
p_out = param.expand(expand_dim, *param.shape).clone()
p_out = nn.Parameter(
p_out.uniform_(
p_out.min().item(), p_out.max().item()
).requires_grad_()
)
return p_out

params = TensorDictParams(
params.apply(
_compare_and_expand,
batch_size=[expand_dim, *param.shape],
batch_size=[expand_dim, *params.shape],
filter_empty=False,
call_on_nested=True,
)
if not isinstance(param, nn.Parameter):
buffer = param.expand(expand_dim, *param.shape).clone()
return buffer
if param in compare_against:
expanded_param = param.data.expand(expand_dim, *param.shape)
# the expanded parameter must be sent to device when to()
# is called:
return expanded_param
else:
p_out = param.expand(expand_dim, *param.shape).clone()
p_out = nn.Parameter(
p_out.uniform_(
p_out.min().item(), p_out.max().item()
).requires_grad_()
)
return p_out

params = TensorDictParams(
params.apply(
_compare_and_expand,
batch_size=[expand_dim, *params.shape],
filter_empty=False,
call_on_nested=True,
),
no_convert=True,
)
),
no_convert=True,
)

param_name = module_name + "_params"

Expand Down
13 changes: 10 additions & 3 deletions torchrl/objectives/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from copy import deepcopy
from dataclasses import dataclass

from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -46,8 +46,15 @@ class CQLLoss(LossModule):
Args:
actor_network (ProbabilisticActor): stochastic actor
qvalue_network (TensorDictModule): Q(s, a) parametric model.
qvalue_network (TensorDictModule or list of TensorDictModule): Q(s, a) parametric model.
This module typically outputs a ``"state_action_value"`` entry.
If a single instance of `qvalue_network` is provided, it will be duplicated ``N``
times (where ``N=2`` for this loss). If a list of modules is passed, their
parameters will be stacked unless they share the same identity (in which case
the original parameter will be expanded).
.. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters
and all the parameters will be considered as untied.
Keyword args:
loss_function (str, optional): loss function to be used with
Expand Down Expand Up @@ -266,7 +273,7 @@ class _AcceptedKeys:
def __init__(
self,
actor_network: ProbabilisticActor,
qvalue_network: TensorDictModule,
qvalue_network: TensorDictModule | List[TensorDictModule],
*,
loss_function: str = "smooth_l1",
alpha_init: float = 1.0,
Expand Down
13 changes: 10 additions & 3 deletions torchrl/objectives/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import math
from dataclasses import dataclass
from functools import wraps
from typing import Dict, Tuple, Union
from typing import Dict, List, Tuple, Union

import torch
from tensordict import TensorDict, TensorDictBase, TensorDictParams
Expand Down Expand Up @@ -54,6 +54,13 @@ class CrossQLoss(LossModule):
actor_network (ProbabilisticActor): stochastic actor
qvalue_network (TensorDictModule): Q(s, a) parametric model.
This module typically outputs a ``"state_action_value"`` entry.
If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets``
times. If a list of modules is passed, their
parameters will be stacked unless they share the same identity (in which case
the original parameter will be expanded).
.. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters
and all the parameters will be considered as untied.
Keyword Args:
num_qvalue_nets (integer, optional): number of Q-Value networks used.
Expand Down Expand Up @@ -81,7 +88,7 @@ class CrossQLoss(LossModule):
priority (for prioritized replay buffer usage). Defaults to ``"td_error"``.
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, ie. gradients are propagated to shared
Defaults to ``False``, i.e., gradients are propagated to shared
parameters for both policy and critic losses.
reduction (str, optional): Specifies the reduction to apply to the output:
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
Expand Down Expand Up @@ -248,7 +255,7 @@ class _AcceptedKeys:
def __init__(
self,
actor_network: ProbabilisticActor,
qvalue_network: TensorDictModule,
qvalue_network: TensorDictModule | List[TensorDictModule],
*,
num_qvalue_nets: int = 2,
loss_function: str = "smooth_l1",
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class DDPGLoss(LossModule):
data collection. Default is ``True``.
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, ie. gradients are propagated to shared
Defaults to ``False``, i.e., gradients are propagated to shared
parameters for both policy and critic losses.
reduction (str, optional): Specifies the reduction to apply to the output:
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
Expand Down
13 changes: 10 additions & 3 deletions torchrl/objectives/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import math
from dataclasses import dataclass
from numbers import Number
from typing import Tuple, Union
from typing import List, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -41,6 +41,13 @@ class REDQLoss_deprecated(LossModule):
actor_network (TensorDictModule): the actor to be trained
qvalue_network (TensorDictModule): a single Q-value network that will
be multiplied as many times as needed.
If a single instance of `qvalue_network` is provided, it will be duplicated ``num_qvalue_nets``
times. If a list of modules is passed, their
parameters will be stacked unless they share the same identity (in which case
the original parameter will be expanded).
.. warning:: When a list of parameters if passed, it will __not__ be compared against the policy parameters
and all the parameters will be considered as untied.
Keyword Args:
num_qvalue_nets (int, optional): Number of Q-value networks to be trained.
Expand Down Expand Up @@ -75,7 +82,7 @@ class REDQLoss_deprecated(LossModule):
``"td_error"``.
separate_losses (bool, optional): if ``True``, shared parameters between
policy and critic will only be trained on the policy loss.
Defaults to ``False``, ie. gradients are propagated to shared
Defaults to ``False``, i.e., gradients are propagated to shared
parameters for both policy and critic losses.
reduction (str, optional): Specifies the reduction to apply to the output:
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
Expand Down Expand Up @@ -134,7 +141,7 @@ class _AcceptedKeys:
def __init__(
self,
actor_network: TensorDictModule,
qvalue_network: TensorDictModule,
qvalue_network: TensorDictModule | List[TensorDictModule],
*,
num_qvalue_nets: int = 10,
sub_sample_len: int = 2,
Expand Down
Loading

0 comments on commit a4cdd5e

Please sign in to comment.