[Doc] Improve docstrings of samplers
LTluttmann committed Sep 24, 2024
1 parent e2337ef commit b4dca1b
Showing 3 changed files with 97 additions and 87 deletions.
test/test_rb.py (50 changes: 28 additions & 22 deletions)
@@ -58,6 +58,11 @@
     SliceSampler,
     SliceSamplerWithoutReplacement,
 )
+from torchrl.data.replay_buffers.scheduler import (
+    LinearScheduler,
+    SchedulerList,
+    StepScheduler,
+)

 from torchrl.data.replay_buffers.storages import (
     LazyMemmapStorage,
@@ -98,11 +103,6 @@
     UnsqueezeTransform,
     VecNorm,
 )
-from torchrl.data.replay_buffers.scheduler import (
-    LinearScheduler,
-    StepScheduler,
-    SchedulerList
-)


 OLD_TORCH = parse(torch.__version__) < parse("2.0.0")
@@ -3040,33 +3040,39 @@ def test_prioritized_parameter_scheduler():
     LINEAR_STEPS = 100
     TOTAL_STEPS = 200
     rb = TensorDictPrioritizedReplayBuffer(
-        alpha=INIT_ALPHA,
-        beta=INIT_BETA,
-        storage=ListStorage(max_size=2000)
-    )
-    data = TensorDict(
-        {
-            "data": torch.randn(1000, 5)
-        },
-        batch_size=1000
+        alpha=INIT_ALPHA, beta=INIT_BETA, storage=ListStorage(max_size=2000)
     )
+    data = TensorDict({"data": torch.randn(1000, 5)}, batch_size=1000)
     rb.extend(data)
     alpha_scheduler = LinearScheduler(
         rb, param_name="alpha", final_value=0.0, num_steps=LINEAR_STEPS
     )
     beta_scheduler = StepScheduler(
-        rb, param_name="beta", gamma=GAMMA, n_steps=EVERY_N_STEPS, max_value=1.0, mode="additive"
+        rb,
+        param_name="beta",
+        gamma=GAMMA,
+        n_steps=EVERY_N_STEPS,
+        max_value=1.0,
+        mode="additive",
     )
-    scheduler = SchedulerList(scheduler=(alpha_scheduler, beta_scheduler))
-    expected_alpha_vals = np.linspace(INIT_ALPHA, 0.0, num=LINEAR_STEPS+1)
-    expected_alpha_vals = np.pad(expected_alpha_vals, (0, TOTAL_STEPS-LINEAR_STEPS), constant_values=0.0)
+    scheduler = SchedulerList(schedulers=(alpha_scheduler, beta_scheduler))
+    expected_alpha_vals = np.linspace(INIT_ALPHA, 0.0, num=LINEAR_STEPS + 1)
+    expected_alpha_vals = np.pad(
+        expected_alpha_vals, (0, TOTAL_STEPS - LINEAR_STEPS), constant_values=0.0
+    )
     expected_beta_vals = [INIT_BETA]
-    for _ in range((TOTAL_STEPS // EVERY_N_STEPS -1)):
+    for _ in range((TOTAL_STEPS // EVERY_N_STEPS - 1)):
         expected_beta_vals.append(expected_beta_vals[-1] + GAMMA)
-    expected_beta_vals = np.atleast_2d(expected_beta_vals).repeat(EVERY_N_STEPS).clip(None, 1.0)
+    expected_beta_vals = (
+        np.atleast_2d(expected_beta_vals).repeat(EVERY_N_STEPS).clip(None, 1.0)
+    )
     for i in range(TOTAL_STEPS):
-        assert np.isclose(rb.sampler.alpha, expected_alpha_vals[i]), f"expected {expected_alpha_vals[i]}, got {rb.sampler.alpha}"
-        assert np.isclose(rb.sampler.beta, expected_beta_vals[i]), f"expected {expected_beta_vals[i]}, got {rb.sampler.beta}"
+        assert np.isclose(
+            rb.sampler.alpha, expected_alpha_vals[i]
+        ), f"expected {expected_alpha_vals[i]}, got {rb.sampler.alpha}"
+        assert np.isclose(
+            rb.sampler.beta, expected_beta_vals[i]
+        ), f"expected {expected_beta_vals[i]}, got {rb.sampler.beta}"
         rb.sample(20)
         scheduler.step()

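Note: the test above doubles as the clearest usage reference for the new schedulers. Below is a minimal standalone sketch of the same flow, assuming only the torchrl APIs that appear in this diff; the constant values are illustrative, not the test's.

```python
import torch
from tensordict import TensorDict

from torchrl.data import ListStorage, TensorDictPrioritizedReplayBuffer
from torchrl.data.replay_buffers.scheduler import (
    LinearScheduler,
    SchedulerList,
    StepScheduler,
)

# Prioritized buffer whose sampler exposes mutable alpha/beta.
rb = TensorDictPrioritizedReplayBuffer(
    alpha=0.7, beta=0.6, storage=ListStorage(max_size=2000)
)
rb.extend(TensorDict({"data": torch.randn(1000, 5)}, batch_size=1000))

# Anneal alpha linearly to 0 over 100 steps; raise beta by 0.1 every 10 steps.
alpha_scheduler = LinearScheduler(rb, param_name="alpha", final_value=0.0, num_steps=100)
beta_scheduler = StepScheduler(
    rb, param_name="beta", gamma=0.1, n_steps=10, max_value=1.0, mode="additive"
)
scheduler = SchedulerList(schedulers=(alpha_scheduler, beta_scheduler))

for _ in range(200):
    rb.sample(20)
    scheduler.step()  # one scheduler step per sampling iteration
```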
torchrl/data/replay_buffers/samplers.py (7 changes: 3 additions & 4 deletions)
@@ -394,24 +394,23 @@ def __repr__(self):
     @property
     def max_size(self):
         return self._max_capacity

     @property
     def alpha(self):
         return self._alpha

     @alpha.setter
     def alpha(self, value):
         self._alpha = value

     @property
     def beta(self):
         return self._beta

     @beta.setter
     def beta(self, value):
         self._beta = value

-
     def __getstate__(self):
         if get_spawning_popen() is not None:
             raise RuntimeError(
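Note: these alpha/beta properties are the hook the new schedulers rely on: ParameterScheduler reads the initial value with getattr and writes each update back with setattr. A minimal sketch of that contract (values are illustrative):

```python
from torchrl.data.replay_buffers.samplers import PrioritizedSampler

sampler = PrioritizedSampler(max_capacity=1000, alpha=0.7, beta=0.5)
initial = getattr(sampler, "alpha")  # what a scheduler reads at construction
setattr(sampler, "beta", 0.6)        # what a scheduler writes on each step
assert initial == 0.7 and sampler.beta == 0.6
```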
torchrl/data/replay_buffers/scheduler.py (127 changes: 66 additions & 61 deletions)
@@ -1,14 +1,15 @@
+from typing import Any, Callable, Dict
+
 import numpy as np
-from typing import Callable, Dict, Any

 from .replay_buffers import ReplayBuffer
 from .samplers import Sampler


 class ParameterScheduler:
-    """Scheduler to adjust the value of a given parameter of a replay buffer's sampler, e.g. the
-    alpha and beta values in the PrioritizedSampler.
+    """Scheduler to adjust the value of a given parameter of a replay buffer's sampler.
+
+    Scheduler can for example be used to alter the alpha and beta values in the PrioritizedSampler.

     Args:
         rb (ReplayBuffer): the replay buffer whose sampler to adjust
@@ -21,11 +22,11 @@ class ParameterScheduler:
     """

     def __init__(
-        self,
-        obj: ReplayBuffer | Sampler,
+        self,
+        obj: ReplayBuffer | Sampler,
         param_name: str,
-        min_value: int | float = None,
-        max_value: int | float = None
+        min_value: int | float = None,
+        max_value: int | float = None,
     ):
         if not isinstance(obj, ReplayBuffer) and not isinstance(obj, Sampler):
             raise TypeError(
@@ -36,7 +37,9 @@ def __init__(
         self._min_val = min_value
         self._max_val = max_value
         if not hasattr(self.sampler, self.param_name):
-            raise ValueError(f"Provided class {obj.__name__} does not have an attribute {param_name}")
+            raise ValueError(
+                f"Provided class {obj.__name__} does not have an attribute {param_name}"
+            )
         self.initial_val = getattr(self.sampler, self.param_name)
         self._step_cnt = 0

@@ -46,9 +49,7 @@ def state_dict(self):
         It contains an entry for every variable in self.__dict__ which
         is not the optimizer.
         """
-        return {
-            key: value for key, value in self.__dict__.items() if key != "sampler"
-        }
+        return {key: value for key, value in self.__dict__.items() if key != "sampler"}

     def load_state_dict(self, state_dict: Dict[str, Any]):
         """Load the scheduler's state.
@@ -73,29 +74,30 @@ def _step(self):


 class LambdaScheduler(ParameterScheduler):
-    """Similar to torch.optim.LambdaLR, this class sets a parameter to its initial value
-    times a given function.
+    """Sets a parameter to its initial value times a given function.
+
+    Similar to torch.optim.LambdaLR.

     Args:
         obj (ReplayBuffer | Sampler): the replay buffer whose sampler to adjust (or the sampler itself)
-        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
+        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
             beta parameter
-        lambda_fn (function): A function which computes a multiplicative factor given an integer
+        lambda_fn (function): A function which computes a multiplicative factor given an integer
             parameter step_count
         min_value (Union[int, float], optional): a lower bound for the parameter to be adjusted
             Defaults to None.
         max_value (Union[int, float], optional): an upper bound for the parameter to be adjusted
             Defaults to None
     """

     def __init__(
-        self,
-        obj: ReplayBuffer | Sampler,
-        param_name: str,
+        self,
+        obj: ReplayBuffer | Sampler,
+        param_name: str,
         lambda_fn: Callable[[int], float],
-        min_value: int | float = None,
-        max_value: int | float = None
+        min_value: int | float = None,
+        max_value: int | float = None,
     ):
         super().__init__(obj, param_name, min_value, max_value)
         self.lambda_fn = lambda_fn
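Note: because _step multiplies the initial value by lambda_fn(step_count), any closed-form decay schedule can be expressed as a factor. A sketch of exponential decay (the 0.99 base is illustrative):

```python
from torchrl.data.replay_buffers.samplers import PrioritizedSampler
from torchrl.data.replay_buffers.scheduler import LambdaScheduler

sampler = PrioritizedSampler(max_capacity=100, alpha=0.7, beta=0.5)
# After n calls to step(), alpha == 0.7 * 0.99 ** n.
scheduler = LambdaScheduler(sampler, param_name="alpha", lambda_fn=lambda n: 0.99**n)
for _ in range(10):
    scheduler.step()
```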
@@ -104,20 +106,20 @@ def _step(self):
         return self.initial_val * self.lambda_fn(self._step_cnt)


-
 class LinearScheduler(ParameterScheduler):
     """A linear scheduler for gradually altering a parameter in an object over a given number of steps.
+
     This scheduler linearly interpolates between the initial value of the parameter and a final target value.

     Args:
         obj (ReplayBuffer | Sampler): the replay buffer whose sampler to adjust (or the sampler itself)
-        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
+        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
             beta parameter
-        final_value (Union[int, float]): The final value that the parameter will reach after the
+        final_value (Union[int, float]): The final value that the parameter will reach after the
             specified number of steps.
-        num_steps (Union[int, float], optional): The total number of steps over which the parameter
-            will be linearly altered.
+        num_steps (Union[int, float], optional): The total number of steps over which the parameter
+            will be linearly altered.

     Example:
         >>> # xdoctest: +SKIP
         >>> # Assuming sampler uses initial beta = 0.6
@@ -131,12 +133,13 @@ class LinearScheduler(ParameterScheduler):
         >>> validate(...)
         >>> scheduler.step()
     """
+
     def __init__(
-        self,
-        obj: ReplayBuffer | Sampler,
-        param_name: str,
+        self,
+        obj: ReplayBuffer | Sampler,
+        param_name: str,
         final_value: int | float,
-        num_steps: int
+        num_steps: int,
     ):
         super().__init__(obj, param_name)
         self.final_val = final_value
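Note: in closed form, the schedule presumably evaluates to v(t) = v_0 + (v_final - v_0) * min(t, num_steps) / num_steps, matching the np.linspace expectation checked by the test above.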
@@ -150,38 +153,36 @@ def _step(self):
         return self.final_val


-
 class StepScheduler(ParameterScheduler):
-    """
-    A step scheduler that alters a parameter after every n steps using either multiplicative or additive changes.
+    """A step scheduler that alters a parameter after every n steps using either multiplicative or additive changes.
+
     The scheduler can apply:
     1. Multiplicative changes: `new_val = curr_val * gamma`
     2. Additive changes: `new_val = curr_val + gamma`

     Args:
         obj (ReplayBuffer | Sampler): the replay buffer whose sampler to adjust (or the sampler itself)
-        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
+        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
             beta parameter
-        gamma (int | float, optional): The value by which to adjust the parameter,
-            either multiplicatively or additive
+        gamma (int | float, optional): The value by which to adjust the parameter,
+            either in a multiplicative or additive way
         n_steps (int, optional): The number of steps after which the parameter should be altered.
             Defaults to 1
         mode (str, optional): The mode of scheduling. Can be either 'multiplicative' or 'additive'.
             Defaults to 'multiplicative'
         min_value (int | float, optional): a lower bound for the parameter to be adjusted
             Defaults to None.
         max_value (int | float, optional): an upper bound for the parameter to be adjusted
-            Defaults to None
+            Defaults to None

     Example:
         >>> # xdoctest: +SKIP
         >>> # Assuming sampler uses initial beta = 0.6
-        >>> # beta = 0.6 if step < 10
-        >>> # beta = 0.7 if step == 10
-        >>> # beta = 0.8 if step == 20
-        >>> # beta = 0.9 if step == 30
-        >>> # beta = 1.0 if step >= 40
+        >>> # beta = 0.6 if 0 <= step < 10
+        >>> # beta = 0.7 if 10 <= step < 20
+        >>> # beta = 0.8 if 20 <= step < 30
+        >>> # beta = 0.9 if 30 <= step < 40
+        >>> # beta = 1.0 if 40 <= step
         >>> scheduler = StepScheduler(sampler, param_name='beta', gamma=0.1, mode='additive', max_value=1.0)
         >>> for epoch in range(100):
         >>>     train(...)
@@ -191,24 +192,26 @@ class StepScheduler(ParameterScheduler):

     def __init__(
         self,
-        obj: ReplayBuffer | Sampler,
-        param_name: str,
-        gamma: int | float = 0.9,
-        n_steps: int = 1,
+        obj: ReplayBuffer | Sampler,
+        param_name: str,
+        gamma: int | float = 0.9,
+        n_steps: int = 1,
         mode: str = "multiplicative",
-        min_value: int | float = None,
-        max_value: int | float = None
+        min_value: int | float = None,
+        max_value: int | float = None,
     ):
-
         super().__init__(obj, param_name, min_value, max_value)
         self.gamma = gamma
         self.n_steps = n_steps
-        if mode == "additive":
+        if mode == "additive":
             operator = np.add
         elif mode == "multiplicative":
             operator = np.multiply
         else:
-            raise ValueError(f"Invalid mode: {self.mode}. Choose 'multiplicative' or 'additive'.")
+            raise ValueError(
+                f"Invalid mode: {self.mode}. Choose 'multiplicative' or 'additive'."
+            )
         self.operator = operator

     def _step(self):
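Note: the docstring example covers additive mode; below is a sketch of the multiplicative default, with min_value acting as the lower clip bound (all values illustrative):

```python
from torchrl.data.replay_buffers.samplers import PrioritizedSampler
from torchrl.data.replay_buffers.scheduler import StepScheduler

sampler = PrioritizedSampler(max_capacity=100, alpha=0.8, beta=0.5)
# Halve alpha every 100 steps: 0.8 -> 0.4 -> 0.2 -> 0.1, then held at 0.1.
scheduler = StepScheduler(
    sampler, param_name="alpha", gamma=0.5, n_steps=100, min_value=0.1
)
for _ in range(1000):
    scheduler.step()
```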
@@ -222,14 +225,16 @@ def _step(self):


 class SchedulerList:
-    def __init__(self, scheduler: list[ParameterScheduler]) -> None:
-        if isinstance(scheduler, ParameterScheduler):
-            scheduler = [scheduler]
-        self.scheduler = scheduler
+    """Simple container abstracting a list of schedulers."""
+
+    def __init__(self, schedulers: list[ParameterScheduler]) -> None:
+        if isinstance(schedulers, ParameterScheduler):
+            schedulers = [schedulers]
+        self.schedulers = schedulers

     def append(self, scheduler: ParameterScheduler):
-        self.scheduler.append(scheduler)
+        self.schedulers.append(scheduler)

     def step(self):
-        for scheduler in self.scheduler:
-            scheduler.step()
+        for scheduler in self.schedulers:
+            scheduler.step()
