diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index f2b02825005..c80633f4580 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -559,6 +559,7 @@ def vmap_randomness(self): def set_vmap_randomness(self, value): self._vmap_randomness = value + self._make_vmap() @staticmethod def _make_meta_params(param): @@ -570,6 +571,11 @@ def _make_meta_params(param): pd = nn.Parameter(pd, requires_grad=False) return pd + def _make_vmap(self): + raise NotImplementedError( + f"_make_vmap has been called but is not implemented for loss of type {type(self).__name__}." + ) + class _make_target_param: def __init__(self, clone): diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 0d2d869d1e1..96f37225fd8 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -374,13 +374,15 @@ def __init__( torch.nn.Parameter(torch.tensor(math.log(1.0), device=device)), ) + self.reduction = reduction + + def _make_vmap(self): self._vmap_qvalue_networkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) - self.reduction = reduction @property def target_entropy(self): diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index 355a33a4682..05499cb227d 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -338,10 +338,13 @@ def __init__( self._target_entropy = target_entropy self._action_spec = action_spec + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) - self.reduction = reduction @property def target_entropy_buffer(self): diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 9e7115ac601..b54e96eb32f 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -223,11 +223,12 @@ def __init__( self.gSDE = gSDE self.reduction = reduction - self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0)) - if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + def _make_vmap(self): + self._vmap_qvalue_networkN0 = _vmap_func(self.qvalue_network, (None, 0)) + @property def target_entropy(self): target_entropy = self.target_entropy_buffer diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 7fab95a95ed..013435c9079 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -318,10 +318,13 @@ def __init__( self.loss_function = loss_function if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qvalue_networkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) - self.reduction = reduction @property def device(self) -> torch.device: diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index a0aaa96f7c5..db05063535a 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -336,7 +336,9 @@ def __init__( self.gSDE = gSDE if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self._make_vmap() + def _make_vmap(self): self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 67ab7d7d8ce..d03014da7bd 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -403,6 +403,10 @@ def __init__( ) if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) @@ -410,7 +414,6 @@ def __init__( self._vmap_qnetwork00 = _vmap_func( qvalue_network, randomness=self.vmap_randomness ) - self.reduction = reduction @property def target_entropy_buffer(self): @@ -1101,10 +1104,13 @@ def __init__( self.register_buffer( "target_entropy", torch.tensor(target_entropy, device=device) ) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) - self.reduction = reduction def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index db99237d39e..b0026b0158d 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -317,13 +317,16 @@ def __init__( self.register_buffer("min_action", low) if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) self._vmap_actor_network00 = _vmap_func( self.actor_network, randomness=self.vmap_randomness ) - self.reduction = reduction def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py index d5529e0b859..bea101f4038 100644 --- a/torchrl/objectives/td3_bc.py +++ b/torchrl/objectives/td3_bc.py @@ -331,13 +331,16 @@ def __init__( high = high.to(device) self.register_buffer("max_action", high) self.register_buffer("min_action", low) + self._make_vmap() + self.reduction = reduction + + def _make_vmap(self): self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) self._vmap_actor_network00 = _vmap_func( self.actor_network, randomness=self.vmap_randomness ) - self.reduction = reduction def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: