Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Sep 2, 2024
1 parent d5a62c4 commit 6b1761e
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions benchmarks/test_objectives_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def test_sac_speed(
value, in_keys=["hidden", "action"], out_keys=["state_action_value"]
)
value = Seq(common, value_head)
value(actor(td))
value(actor(td.clone()))

loss = SACLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

Expand Down Expand Up @@ -404,7 +404,7 @@ def test_redq_speed(
value, in_keys=["hidden", "action"], out_keys=["state_action_value"]
)
value = Seq(common, value_head)
value(actor(td))
value(actor(td.copy()))

loss = REDQLoss(actor, value, action_spec=Unbounded(shape=(n_act,)))

Expand Down Expand Up @@ -489,7 +489,7 @@ def test_redq_deprec_speed(
value, in_keys=["hidden", "action"], out_keys=["state_action_value"]
)
value = Seq(common, value_head)
value(actor(td))
value(actor(td.copy()))

loss = REDQLoss_deprecated(actor, value, action_spec=Unbounded(shape=(n_act,)))

Expand Down

0 comments on commit 6b1761e

Please sign in to comment.