Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Reinitialize vmap callers after reset of vmap randomness #2314

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ def vmap_randomness(self):

def set_vmap_randomness(self, value):
    """Set the vmap randomness mode and rebuild the vmapped callers.

    Args:
        value: randomness mode stored in ``self._vmap_randomness`` and later
            forwarded to the vmap construction (presumably one of the
            ``torch.func.vmap`` ``randomness`` values — confirm against
            ``_vmap_func``).
    """
    self._vmap_randomness = value
    # Re-create the vmapped callers so they pick up the new randomness
    # setting; the existing callers captured the old value at build time.
    self._make_vmap()

@staticmethod
def _make_meta_params(param):
Expand All @@ -570,6 +571,11 @@ def _make_meta_params(param):
pd = nn.Parameter(pd, requires_grad=False)
return pd

def _make_vmap(self):
    """Build the vmapped network callers for this loss.

    Base implementation: always raises. Subclasses that use vmap must
    override this to (re)construct their ``_vmap_*`` attributes.

    Raises:
        NotImplementedError: always, naming the concrete loss class.
    """
    raise NotImplementedError(
        f"_make_vmap has been called but is not implemented for loss of type {type(self).__name__}."
    )


class _make_target_param:
def __init__(self, clone):
Expand Down
4 changes: 3 additions & 1 deletion torchrl/objectives/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,13 +374,15 @@ def __init__(
torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)),
)

self.reduction = reduction

def _make_vmap(self):
    # (Re)build the vmapped q-value callers using the current
    # ``self.vmap_randomness``; called from __init__ and again whenever
    # the randomness mode is reset.
    # N0: first arg not batched, second batched along dim 0.
    self._vmap_qvalue_networkN0 = _vmap_func(
        self.qvalue_network, (None, 0), randomness=self.vmap_randomness
    )
    # 00: default batching (no in_dims override).
    self._vmap_qvalue_network00 = _vmap_func(
        self.qvalue_network, randomness=self.vmap_randomness
    )
self.reduction = reduction

@property
def target_entropy(self):
Expand Down
5 changes: 4 additions & 1 deletion torchrl/objectives/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,13 @@ def __init__(

self._target_entropy = target_entropy
self._action_spec = action_spec
self._make_vmap()
self.reduction = reduction

def _make_vmap(self):
    # (Re)build the vmapped q-network caller with the current
    # ``self.vmap_randomness``; invoked from __init__ and after a
    # randomness reset. (None, 0): first arg not batched, second
    # batched along dim 0.
    self._vmap_qnetworkN0 = _vmap_func(
        self.qvalue_network, (None, 0), randomness=self.vmap_randomness
    )
self.reduction = reduction

@property
def target_entropy_buffer(self):
Expand Down
5 changes: 3 additions & 2 deletions torchrl/objectives/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,12 @@ def __init__(
self.gSDE = gSDE
self.reduction = reduction

self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))

if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)

def _make_vmap(self):
    # (Re)build the vmapped q-value caller. (None, 0): first arg not
    # batched, second batched along dim 0.
    # NOTE(review): unlike the sibling losses, no ``randomness=`` is
    # forwarded here, so resetting vmap randomness has no effect on this
    # caller — confirm this is intentional for the deprecated loss.
    self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0))

@property
def target_entropy(self):
target_entropy = self.target_entropy_buffer
Expand Down
5 changes: 4 additions & 1 deletion torchrl/objectives/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,13 @@ def __init__(
self.loss_function = loss_function
if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self._make_vmap()
self.reduction = reduction

def _make_vmap(self):
    # (Re)build the vmapped q-value caller using the current
    # ``self.vmap_randomness``; called from __init__ and again after a
    # randomness reset. (None, 0): first arg not batched, second batched
    # along dim 0.
    self._vmap_qvalue_networkN0 = _vmap_func(
        self.qvalue_network, (None, 0), randomness=self.vmap_randomness
    )
self.reduction = reduction

@property
def device(self) -> torch.device:
Expand Down
2 changes: 2 additions & 0 deletions torchrl/objectives/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,9 @@ def __init__(
self.gSDE = gSDE
if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self._make_vmap()

def _make_vmap(self):
self._vmap_qvalue_network00 = _vmap_func(
self.qvalue_network, randomness=self.vmap_randomness
)
Expand Down
10 changes: 8 additions & 2 deletions torchrl/objectives/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,14 +403,17 @@ def __init__(
)
if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self._make_vmap()
self.reduction = reduction

def _make_vmap(self):
    """(Re)build the vmapped q-network callers for SACLoss.

    Called from ``__init__`` and again whenever the vmap randomness is
    reset, so the callers always capture the current
    ``self.vmap_randomness``.
    """
    # (None, 0): first arg not batched, second batched along dim 0.
    self._vmap_qnetworkN0 = _vmap_func(
        self.qvalue_network, (None, 0), randomness=self.vmap_randomness
    )
    if self._version == 1:
        # BUG FIX: the original referenced the bare name ``qvalue_network``,
        # which was a local of ``__init__`` before this code was extracted;
        # inside this method it is unbound and raises NameError for
        # version-1 losses. Use the stored attribute instead.
        self._vmap_qnetwork00 = _vmap_func(
            self.qvalue_network, randomness=self.vmap_randomness
        )
self.reduction = reduction

@property
def target_entropy_buffer(self):
Expand Down Expand Up @@ -1101,10 +1104,13 @@ def __init__(
self.register_buffer(
"target_entropy", torch.tensor(target_entropy, device=device)
)
self._make_vmap()
self.reduction = reduction

def _make_vmap(self):
    # (Re)build the vmapped q-network caller for the discrete SAC loss
    # using the current ``self.vmap_randomness``. (None, 0): first arg
    # not batched, second batched along dim 0.
    self._vmap_qnetworkN0 = _vmap_func(
        self.qvalue_network, (None, 0), randomness=self.vmap_randomness
    )
self.reduction = reduction

def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
Expand Down
5 changes: 4 additions & 1 deletion torchrl/objectives/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,13 +317,16 @@ def __init__(
self.register_buffer("min_action", low)
if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self._make_vmap()
self.reduction = reduction

def _make_vmap(self):
    # (Re)build the vmapped callers for both the q-value and actor
    # networks with the current ``self.vmap_randomness``; invoked from
    # __init__ and again after a randomness reset.
    self._vmap_qvalue_network00 = _vmap_func(
        self.qvalue_network, randomness=self.vmap_randomness
    )
    self._vmap_actor_network00 = _vmap_func(
        self.actor_network, randomness=self.vmap_randomness
    )
self.reduction = reduction

def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
Expand Down
5 changes: 4 additions & 1 deletion torchrl/objectives/td3_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,13 +331,16 @@ def __init__(
high = high.to(device)
self.register_buffer("max_action", high)
self.register_buffer("min_action", low)
self._make_vmap()
self.reduction = reduction

def _make_vmap(self):
    # (Re)build the vmapped callers for both the q-value and actor
    # networks with the current ``self.vmap_randomness``; invoked from
    # __init__ and again after a randomness reset.
    self._vmap_qvalue_network00 = _vmap_func(
        self.qvalue_network, randomness=self.vmap_randomness
    )
    self._vmap_actor_network00 = _vmap_func(
        self.actor_network, randomness=self.vmap_randomness
    )
self.reduction = reduction

def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
Expand Down
Loading