diff --git a/test/test_rb.py b/test/test_rb.py
index 07c38d1c66..985f1a561e 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -58,6 +58,11 @@
     SliceSampler,
     SliceSamplerWithoutReplacement,
 )
+from torchrl.data.replay_buffers.scheduler import (
+    LinearScheduler,
+    SchedulerList,
+    StepScheduler,
+)
 from torchrl.data.replay_buffers.storages import (
     LazyMemmapStorage,
@@ -98,11 +103,6 @@
     UnsqueezeTransform,
     VecNorm,
 )
-from torchrl.data.replay_buffers.scheduler import (
-    LinearScheduler,
-    StepScheduler,
-    SchedulerList
-)
 
 OLD_TORCH = parse(torch.__version__) < parse("2.0.0")
@@ -3040,33 +3040,39 @@ def test_prioritized_parameter_scheduler():
     LINEAR_STEPS = 100
     TOTAL_STEPS = 200
     rb = TensorDictPrioritizedReplayBuffer(
-        alpha=INIT_ALPHA,
-        beta=INIT_BETA,
-        storage=ListStorage(max_size=2000)
-    )
-    data = TensorDict(
-        {
-            "data": torch.randn(1000, 5)
-        },
-        batch_size=1000
+        alpha=INIT_ALPHA, beta=INIT_BETA, storage=ListStorage(max_size=2000)
     )
+    data = TensorDict({"data": torch.randn(1000, 5)}, batch_size=1000)
     rb.extend(data)
     alpha_scheduler = LinearScheduler(
         rb, param_name="alpha", final_value=0.0, num_steps=LINEAR_STEPS
     )
     beta_scheduler = StepScheduler(
-        rb, param_name="beta", gamma=GAMMA, n_steps=EVERY_N_STEPS, max_value=1.0, mode="additive"
+        rb,
+        param_name="beta",
+        gamma=GAMMA,
+        n_steps=EVERY_N_STEPS,
+        max_value=1.0,
+        mode="additive",
+    )
+    scheduler = SchedulerList(schedulers=(alpha_scheduler, beta_scheduler))
+    expected_alpha_vals = np.linspace(INIT_ALPHA, 0.0, num=LINEAR_STEPS + 1)
+    expected_alpha_vals = np.pad(
+        expected_alpha_vals, (0, TOTAL_STEPS - LINEAR_STEPS), constant_values=0.0
     )
-    scheduler = SchedulerList(scheduler=(alpha_scheduler, beta_scheduler))
-    expected_alpha_vals = np.linspace(INIT_ALPHA, 0.0, num=LINEAR_STEPS+1)
-    expected_alpha_vals = np.pad(expected_alpha_vals, (0, TOTAL_STEPS-LINEAR_STEPS), constant_values=0.0)
     expected_beta_vals = [INIT_BETA]
-    for _ in range((TOTAL_STEPS // EVERY_N_STEPS -1)):
+    for _ in range((TOTAL_STEPS // EVERY_N_STEPS - 1)):
         expected_beta_vals.append(expected_beta_vals[-1] + GAMMA)
-    expected_beta_vals = np.atleast_2d(expected_beta_vals).repeat(EVERY_N_STEPS).clip(None, 1.0)
+    expected_beta_vals = (
+        np.atleast_2d(expected_beta_vals).repeat(EVERY_N_STEPS).clip(None, 1.0)
+    )
     for i in range(TOTAL_STEPS):
-        assert np.isclose(rb.sampler.alpha, expected_alpha_vals[i]), f"expected {expected_alpha_vals[i]}, got {rb.sampler.alpha}"
-        assert np.isclose(rb.sampler.beta, expected_beta_vals[i]), f"expected {expected_beta_vals[i]}, got {rb.sampler.beta}"
+        assert np.isclose(
+            rb.sampler.alpha, expected_alpha_vals[i]
+        ), f"expected {expected_alpha_vals[i]}, got {rb.sampler.alpha}"
+        assert np.isclose(
+            rb.sampler.beta, expected_beta_vals[i]
+        ), f"expected {expected_beta_vals[i]}, got {rb.sampler.beta}"
         rb.sample(20)
         scheduler.step()
diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index 77cb5e3607..95b552bec5 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -394,11 +394,11 @@ def __repr__(self):
     @property
     def max_size(self):
         return self._max_capacity
-    
+
     @property
     def alpha(self):
         return self._alpha
-    
+
     @alpha.setter
     def alpha(self, value):
         self._alpha = value
@@ -406,12 +406,11 @@ def alpha(self, value):
     @property
     def beta(self):
         return self._beta
-    
+
     @beta.setter
     def beta(self, value):
         self._beta = value
 
-
     def __getstate__(self):
         if get_spawning_popen() is not None:
             raise RuntimeError(
diff --git a/torchrl/data/replay_buffers/scheduler.py b/torchrl/data/replay_buffers/scheduler.py
index a8364f2bd2..f72c3e13a5 100644
--- a/torchrl/data/replay_buffers/scheduler.py
+++ b/torchrl/data/replay_buffers/scheduler.py
@@ -1,14 +1,15 @@
+from typing import Any, Callable, Dict
+
 import numpy as np
-from typing import Callable, Dict, Any
 
 from .replay_buffers import ReplayBuffer
 from .samplers import Sampler
 
 
 class ParameterScheduler:
+    """Scheduler to adjust the value of a given parameter of a replay buffer's sampler.
 
-    """Scheduler to adjust the value of a given parameter of a replay buffer's sampler, e.g. the
-    alpha and beta values in the PrioritizedSampler.
+    The scheduler can, for example, be used to alter the alpha and beta values in the PrioritizedSampler.
 
     Args:
         rb (ReplayBuffer): the replay buffer whose sampler to adjust
@@ -21,11 +22,11 @@ class ParameterScheduler:
     """
 
     def __init__(
-        self, 
-        obj: ReplayBuffer | Sampler, 
+        self,
+        obj: ReplayBuffer | Sampler,
         param_name: str,
-        min_value: int | float = None,
-        max_value: int | float = None
+        min_value: int | float = None,
+        max_value: int | float = None,
     ):
         if not isinstance(obj, ReplayBuffer) and not isinstance(obj, Sampler):
             raise TypeError(
@@ -36,7 +37,9 @@ def __init__(
         self.param_name = param_name
         self._min_val = min_value
         self._max_val = max_value
         if not hasattr(self.sampler, self.param_name):
-            raise ValueError(f"Provided class {obj.__name__} does not have an attribute {param_name}")
+            raise ValueError(
+                f"Provided class {type(obj).__name__} does not have an attribute {param_name}"
+            )
         self.initial_val = getattr(self.sampler, self.param_name)
         self._step_cnt = 0
@@ -46,9 +49,7 @@ def state_dict(self):
         It contains an entry for every variable in self.__dict__ which
         is not the optimizer.
         """
-        return {
-            key: value for key, value in self.__dict__.items() if key != "sampler"
-        }
+        return {key: value for key, value in self.__dict__.items() if key != "sampler"}
 
     def load_state_dict(self, state_dict: Dict[str, Any]):
         """Load the scheduler's state.
@@ -73,29 +74,30 @@ def _step(self):
 
 
 class LambdaScheduler(ParameterScheduler):
-    """Similar to torch.optim.LambdaLR, this class sets a parameter to its initial value
-    times a given function.
+    """Sets a parameter to its initial value times a given function.
+
+    Similar to torch.optim.LambdaLR.
 
     Args:
         obj (ReplayBuffer | Sampler): the replay buffer whose sampler to adjust (or the sampler itself)
-        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the 
+        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
            beta parameter
-        lambda_fn (function): A function which computes a multiplicative factor given an integer 
+        lambda_fn (function): A function which computes a multiplicative factor given an integer
            parameter step_count
         min_value (Union[int, float], optional): a lower bound for the parameter to be adjusted
            Defaults to None.
        max_value (Union[int, float], optional): an upper bound for the parameter to be adjusted
            Defaults to None
-    
+
     """
 
     def __init__(
-        self, 
-        obj: ReplayBuffer | Sampler, 
-        param_name: str, 
+        self,
+        obj: ReplayBuffer | Sampler,
+        param_name: str,
         lambda_fn: Callable[[int], float],
-        min_value: int | float = None,
-        max_value: int | float = None
+        min_value: int | float = None,
+        max_value: int | float = None,
     ):
         super().__init__(obj, param_name, min_value, max_value)
         self.lambda_fn = lambda_fn
@@ -104,20 +106,20 @@ def _step(self):
         return self.initial_val * self.lambda_fn(self._step_cnt)
 
 
-
 class LinearScheduler(ParameterScheduler):
     """A linear scheduler for gradually altering a parameter in an object over a given number of steps.
+
     This scheduler linearly interpolates between the initial value of the parameter and a final target value.
-    
+
     Args:
         obj (ReplayBuffer | Sampler): the replay buffer whose sampler to adjust (or the sampler itself)
-        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the 
+        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
            beta parameter
-        final_value (Union[int, float]): The final value that the parameter will reach after the 
+        final_value (Union[int, float]): The final value that the parameter will reach after the
            specified number of steps.
-        num_steps (Union[int, float], optional): The total number of steps over which the parameter 
-           will be linearly altered.
-
+        num_steps (Union[int, float], optional): The total number of steps over which the parameter
+           will be linearly altered.
+
     Example:
         >>> # xdoctest: +SKIP
         >>> # Assuming sampler uses initial beta = 0.6
@@ -131,12 +133,13 @@ class LinearScheduler(ParameterScheduler):
         >>>     validate(...)
         >>>     scheduler.step()
     """
+
     def __init__(
-        self, 
-        obj: ReplayBuffer | Sampler, 
-        param_name: str, 
+        self,
+        obj: ReplayBuffer | Sampler,
+        param_name: str,
         final_value: int | float,
-        num_steps: int
+        num_steps: int,
     ):
         super().__init__(obj, param_name)
         self.final_val = final_value
@@ -150,21 +153,19 @@ def _step(self):
             return self.final_val
 
 
-
 class StepScheduler(ParameterScheduler):
-    """
-    A step scheduler that alters a parameter after every n steps using either multiplicative or additive changes.
-
+    """A step scheduler that alters a parameter after every n steps using either multiplicative or additive changes.
+
     The scheduler can apply:
     1. Multiplicative changes: `new_val = curr_val * gamma`
     2. Additive changes: `new_val = curr_val + gamma`
-    
+
     Args:
         obj (ReplayBuffer | Sampler): the replay buffer whose sampler to adjust (or the sampler itself)
-        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the 
+        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
            beta parameter
-        gamma (int | float, optional): The value by which to adjust the parameter,
-          either multiplicatively or additive
+        gamma (int | float, optional): The value by which to adjust the parameter,
+            either in a multiplicative or additive way
         n_steps (int, optional): The number of steps after which the parameter should be altered.
            Defaults to 1
         mode (str, optional): The mode of scheduling. Can be either 'multiplicative' or 'additive'.
@@ -172,16 +173,16 @@
            Defaults to 'multiplicative'
         min_value (int | float, optional): a lower bound for the parameter to be adjusted
            Defaults to None.
        max_value (int | float, optional): an upper bound for the parameter to be adjusted
-           Defaults to None 
+           Defaults to None
 
     Example:
         >>> # xdoctest: +SKIP
         >>> # Assuming sampler uses initial beta = 0.6
-        >>> # beta = 0.6 if step < 10
-        >>> # beta = 0.7 if step == 10
-        >>> # beta = 0.8 if step == 20
-        >>> # beta = 0.9 if step == 30
-        >>> # beta = 1.0 if step >= 40
+        >>> # beta = 0.6 if 0 <= step < 10
+        >>> # beta = 0.7 if 10 <= step < 20
+        >>> # beta = 0.8 if 20 <= step < 30
+        >>> # beta = 0.9 if 30 <= step < 40
+        >>> # beta = 1.0 if 40 <= step
         >>> scheduler = StepScheduler(sampler, param_name='beta', gamma=0.1, mode='additive', max_value=1.0)
         >>> for epoch in range(100):
         >>>     train(...)
         >>>     validate(...)
         >>>     scheduler.step()
@@ -191,24 +192,26 @@
 
     def __init__(
         self,
-        obj: ReplayBuffer | Sampler, 
-        param_name: str, 
-        gamma: int | float = 0.9, 
-        n_steps: int = 1, 
+        obj: ReplayBuffer | Sampler,
+        param_name: str,
+        gamma: int | float = 0.9,
+        n_steps: int = 1,
         mode: str = "multiplicative",
-        min_value: int | float = None,
-        max_value: int | float = None
+        min_value: int | float = None,
+        max_value: int | float = None,
     ):
         super().__init__(obj, param_name, min_value, max_value)
         self.gamma = gamma
         self.n_steps = n_steps
-        if mode == "additive": 
+        if mode == "additive":
             operator = np.add
         elif mode == "multiplicative":
             operator = np.multiply
         else:
-            raise ValueError(f"Invalid mode: {self.mode}. Choose 'multiplicative' or 'additive'.")
+            raise ValueError(
+                f"Invalid mode: {mode}. Choose 'multiplicative' or 'additive'."
+            )
         self.operator = operator
 
     def _step(self):
@@ -222,14 +225,16 @@
 
 
 class SchedulerList:
-    def __init__(self, scheduler: list[ParameterScheduler]) -> None:
-        if isinstance(scheduler, ParameterScheduler):
-            scheduler = [scheduler]
-        self.scheduler = scheduler
+    """Simple container abstracting a list of schedulers."""
+
+    def __init__(self, schedulers: list[ParameterScheduler]) -> None:
+        if isinstance(schedulers, ParameterScheduler):
+            schedulers = [schedulers]
+        self.schedulers = schedulers
 
     def append(self, scheduler: ParameterScheduler):
-        self.scheduler.append(scheduler)
+        self.schedulers.append(scheduler)
 
     def step(self):
-        for scheduler in self.scheduler:
-            scheduler.step()
\ No newline at end of file
+        for scheduler in self.schedulers:
+            scheduler.step()
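
For reference, a minimal usage sketch of the schedulers this patch adds, mirroring the new test above. It assumes `torchrl` and `tensordict` are installed and that `TensorDictPrioritizedReplayBuffer` and `ListStorage` are importable from `torchrl.data` (as they are in current torchrl); the buffer sizes and schedule constants are illustrative, not prescribed by the patch:

```python
import torch
from tensordict import TensorDict
from torchrl.data import ListStorage, TensorDictPrioritizedReplayBuffer
from torchrl.data.replay_buffers.scheduler import (
    LinearScheduler,
    SchedulerList,
    StepScheduler,
)

# Prioritized buffer whose sampler exposes the alpha/beta properties
# added to PrioritizedSampler in this patch.
rb = TensorDictPrioritizedReplayBuffer(
    alpha=0.7, beta=0.4, storage=ListStorage(max_size=1000)
)
rb.extend(TensorDict({"data": torch.randn(100, 5)}, batch_size=[100]))

# Anneal alpha linearly to 0 over 100 steps; bump beta by 0.1
# every 10 steps, capped at 1.0.
alpha_scheduler = LinearScheduler(rb, param_name="alpha", final_value=0.0, num_steps=100)
beta_scheduler = StepScheduler(
    rb, param_name="beta", gamma=0.1, n_steps=10, max_value=1.0, mode="additive"
)
schedulers = SchedulerList(schedulers=(alpha_scheduler, beta_scheduler))

for _ in range(200):
    batch = rb.sample(16)  # a training step would consume this batch
    schedulers.step()      # advance both schedules once per iteration
```

Calling `step()` on the `SchedulerList` advances every registered scheduler at once, which keeps the training loop free of per-parameter bookkeeping.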