[Refactor] Rename specs to simpler names #2365

Closed. Proposed merging 4 commits.
20 changes: 6 additions & 14 deletions benchmarks/test_objectives_benchmarks.py
@@ -16,7 +16,7 @@
     TensorDictSequential as Seq,
 )
 from torch.nn import functional as F
-from torchrl.data.tensor_specs import BoundedTensorSpec, UnboundedContinuousTensorSpec
+from torchrl.data.tensor_specs import Bounded, Unbounded
 from torchrl.modules import MLP, QValueActor, TanhNormal
 from torchrl.objectives import (
     A2CLoss,
@@ -253,9 +253,7 @@ def test_sac_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
     value = Seq(common, value_head)
     value(actor(td))

-    loss = SACLoss(
-        actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
-    )
+    loss = SACLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

     loss(td)
     benchmark(loss, td)
@@ -312,9 +310,7 @@ def test_redq_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden
     value = Seq(common, value_head)
     value(actor(td))

-    loss = REDQLoss(
-        actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
-    )
+    loss = REDQLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

     loss(td)
     benchmark(loss, td)
@@ -373,9 +369,7 @@ def test_redq_deprec_speed(
     value = Seq(common, value_head)
     value(actor(td))

-    loss = REDQLoss_deprecated(
-        actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
-    )
+    loss = REDQLoss_deprecated(actor, value, action_spec=Unbounded(shape=(n_act,)))

     loss(td)
     benchmark(loss, td)
@@ -435,7 +429,7 @@ def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
     loss = TD3Loss(
         actor,
         value,
-        action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1),
+        action_spec=Bounded(shape=(n_act,), low=-1, high=1),
     )

     loss(td)
@@ -490,9 +484,7 @@ def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
     value = Seq(common, value_head)
     value(actor(td))

-    loss = CQLLoss(
-        actor, value, action_spec=UnboundedContinuousTensorSpec(shape=(n_act,))
-    )
+    loss = CQLLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

     loss(td)
     benchmark(loss, td)
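
Every hunk in this benchmark file applies the same mechanical substitution. As a rough illustration, the rename could be scripted along the following lines. This is only a sketch: the helper `rename_specs` and the `RENAMES` table are hypothetical names, not part of TorchRL or of this PR, and the table is deliberately limited to the two old/new pairs that appear in the hunks for this file.

```python
import re

# Old -> new class names, limited to the two pairs visible in the
# benchmark hunks for this file. Other renames are intentionally left out.
RENAMES = {
    "BoundedTensorSpec": "Bounded",
    "UnboundedContinuousTensorSpec": "Unbounded",
}

# Word boundaries prevent rewriting part of a longer identifier,
# e.g. a user-defined "MyBoundedTensorSpec" is left untouched.
_PATTERN = re.compile(r"\b(" + "|".join(map(re.escape, RENAMES)) + r")\b")


def rename_specs(source: str) -> str:
    """Return `source` with the old spec names replaced by the new ones."""
    return _PATTERN.sub(lambda match: RENAMES[match.group(1)], source)
```

Applied to the old import line from the first hunk, this helper yields the new import line. It does not re-wrap lines, so collapsing the three-line `SACLoss(...)` calls into one line would still be left to a formatter.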
27 changes: 24 additions & 3 deletions docs/source/reference/data.rst
@@ -890,19 +890,40 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
    :template: rl_template.rst

    TensorSpec
    Binary
    Bounded
    Categorical
    Composite
    MultiCategorical
    MultiOneHot
    NonTensor
    OneHotDiscrete
    Stacked
    StackedComposite
    Unbounded
    UnboundedContinuous
    UnboundedDiscrete

The following classes are deprecated and just point to the classes above:

.. currentmodule:: torchrl.data

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    BinaryDiscreteTensorSpec
    BoundedTensorSpec
    CompositeSpec
    DiscreteTensorSpec
    LazyStackedCompositeSpec
    LazyStackedTensorSpec
    MultiDiscreteTensorSpec
    MultiOneHotDiscreteTensorSpec
    NonTensorSpec
    OneHotDiscreteTensorSpec
    UnboundedContinuousTensorSpec
    UnboundedDiscreteTensorSpec
    LazyStackedTensorSpec
    LazyStackedCompositeSpec
    NonTensorSpec

Reinforcement Learning From Human Feedback (RLHF)
-------------------------------------------------
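
The `data.rst` change above states that the old class names are deprecated and just point to the new ones. One common way to keep an old name working while steering users to the new one is a thin subclass that emits a `DeprecationWarning`. The sketch below illustrates that general pattern with stand-in classes. It is an assumption about how such an alias might be built, not TorchRL's actual implementation, and the simplified `Bounded` here only borrows the name and the `low`/`high` arguments seen in the diff.

```python
import warnings


class Bounded:
    """Stand-in for a new-style spec class (hypothetical, not TorchRL's)."""

    def __init__(self, low, high):
        self.low = low
        self.high = high


class BoundedTensorSpec(Bounded):
    """Deprecated alias: warns on construction, then behaves like Bounded."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "BoundedTensorSpec is deprecated; use Bounded instead.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller's line
        )
        super().__init__(*args, **kwargs)
```

With this pattern, existing code that constructs `BoundedTensorSpec(low=-1, high=1)` keeps running and still passes `isinstance(spec, Bounded)` checks, while the warning points to the call site that needs updating.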