Skip to content

Commit

Permalink
[Refactor] Rename specs to simpler names
Browse files Browse the repository at this point in the history
ghstack-source-id: 13ec2536bf19bf521acdde0cf244ac6a3a197e17
Pull Request resolved: #2365
  • Loading branch information
vmoens committed Aug 6, 2024
1 parent 788710f commit daf2a26
Show file tree
Hide file tree
Showing 87 changed files with 55,061 additions and 2,387 deletions.
20 changes: 6 additions & 14 deletions benchmarks/test_objectives_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
TensorDictSequential as Seq,
)
from torch.nn import functional as F
from torchrl.data.tensor_specs import BoundedTensorSpec, UnboundedContinuousTensorSpec
from torchrl.data.tensor_specs import Bounded, Unbounded
from torchrl.modules import MLP, QValueActor, TanhNormal
from torchrl.objectives import (
A2CLoss,
Expand Down Expand Up @@ -253,9 +253,7 @@ def test_sac_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
value = Seq(common, value_head)
value(actor(td))

loss = SACLoss(
actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
)
loss = SACLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

loss(td)
benchmark(loss, td)
Expand Down Expand Up @@ -312,9 +310,7 @@ def test_redq_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden
value = Seq(common, value_head)
value(actor(td))

loss = REDQLoss(
actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
)
loss = REDQLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

loss(td)
benchmark(loss, td)
Expand Down Expand Up @@ -373,9 +369,7 @@ def test_redq_deprec_speed(
value = Seq(common, value_head)
value(actor(td))

loss = REDQLoss_deprecated(
actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
)
loss = REDQLoss_deprecated(actor, value, action_spec=Unbounded(shape=(n_act,)))

loss(td)
benchmark(loss, td)
Expand Down Expand Up @@ -435,7 +429,7 @@ def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
loss = TD3Loss(
actor,
value,
action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1),
action_spec=Bounded(shape=(n_act,), low=-1, high=1),
)

loss(td)
Expand Down Expand Up @@ -490,9 +484,7 @@ def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
value = Seq(common, value_head)
value(actor(td))

loss = CQLLoss(
actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
)
loss = CQLLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

loss(td)
benchmark(loss, td)
Expand Down
27 changes: 24 additions & 3 deletions docs/source/reference/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -890,19 +890,40 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
:template: rl_template.rst

TensorSpec
Binary
Bounded
Categorical
Composite
MultiCategorical
MultiOneHot
NonTensor
OneHotDiscrete
Stacked
StackedComposite
Unbounded
UnboundedContinuous
UnboundedDiscrete

The following classes are deprecated; each one simply points to its renamed counterpart in the list above:

.. currentmodule:: torchrl.data

.. autosummary::
:toctree: generated/
:template: rl_template.rst

BinaryDiscreteTensorSpec
BoundedTensorSpec
CompositeSpec
DiscreteTensorSpec
LazyStackedCompositeSpec
LazyStackedTensorSpec
MultiDiscreteTensorSpec
MultiOneHotDiscreteTensorSpec
NonTensorSpec
OneHotDiscreteTensorSpec
UnboundedContinuousTensorSpec
UnboundedDiscreteTensorSpec
LazyStackedTensorSpec
LazyStackedCompositeSpec
NonTensorSpec

Reinforcement Learning From Human Feedback (RLHF)
-------------------------------------------------
Expand Down
Loading

0 comments on commit daf2a26

Please sign in to comment.