[Feature] Store MARL parameters in module #2351

Merged: 2 commits merged into main on Aug 3, 2024

vmoens (Contributor) commented Aug 2, 2024

Description

We currently store the parameters of MARL modules in self.params, a TensorDictParams.
During a call to forward, we call vmap and to_module to put the batched parameters in place within the module.

This PR proposes to optionally make self.params a regular TensorDict (i.e., self.parameters() will not see them, because self.params is no longer within self.modules()) and to place the parameters in self._empty_net instead. With that in place, the module has two copies of the parameters, but one is not accessible via self.parameters(), so nothing changes from the user's perspective.

We test that these two scenarios are identical and that sending the module to a device does not create multiple distinct copies of the params.
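
As a rough illustration of the "batched parameters plus vmap" pattern described above, the sketch below uses plain `torch.func` (`stack_module_state` and `functional_call`) instead of tensordict's `to_module`, so it is a stand-in and not TorchRL's actual implementation. The names `nets`, `empty_net`, `call_one` and all sizes are assumptions made for the example.

```python
import torch
from torch import nn
from torch.func import functional_call, stack_module_state

n_agents, in_dim, out_dim = 3, 4, 2

# One independent network per agent (hypothetical sizes).
nets = [nn.Linear(in_dim, out_dim) for _ in range(n_agents)]

# Stack every parameter along a new leading "agent" dimension.
params, buffers = stack_module_state(nets)

# A data-free template network: "meta" tensors carry shapes but no values.
empty_net = nn.Linear(in_dim, out_dim, device="meta")


def call_one(p, b, x):
    # Run the template with one agent's slice of the stacked parameters.
    return functional_call(empty_net, (p, b), (x,))


x = torch.randn(n_agents, 5, in_dim)  # one batch of 5 inputs per agent
out = torch.vmap(call_one)(params, buffers, x)

# Reference result: call each agent's own network in a Python loop.
expected = torch.stack([net(xi) for net, xi in zip(nets, x)])
```

The vmapped call and the explicit loop should agree up to floating-point tolerance; the former simply avoids the Python-level loop over agents.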

cc @matteobettini

pytorch-bot bot commented Aug 2, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2351

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 4 Unrelated Failures

As of commit 99f5dcc with merge base 59d2ae1:

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were also failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label (This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.) Aug 2, 2024

github-actions bot commented Aug 2, 2024

⚠ Warning: Result of CPU Benchmark Tests

Total Benchmarks: 91. Improved: 7. Worsened: 1.

Name Max Mean Ops Ops on Repo HEAD Change
test_single 60.8594ms 58.0787ms 17.2180 Ops/s 16.9606 Ops/s $\color{#35bf28}+1.52\%$
test_sync 33.4960ms 31.6156ms 31.6300 Ops/s 30.9966 Ops/s $\color{#35bf28}+2.04\%$
test_async 53.9812ms 30.1618ms 33.1545 Ops/s 33.1318 Ops/s $\color{#35bf28}+0.07\%$
test_simple 0.4764s 0.4100s 2.4393 Ops/s 2.3884 Ops/s $\color{#35bf28}+2.13\%$
test_transformed 0.6281s 0.5667s 1.7647 Ops/s 1.7258 Ops/s $\color{#35bf28}+2.25\%$
test_serial 1.3178s 1.2514s 0.7991 Ops/s 0.7892 Ops/s $\color{#35bf28}+1.25\%$
test_parallel 1.1867s 1.1220s 0.8913 Ops/s 0.8981 Ops/s $\color{#d91a1a}-0.76\%$
test_step_mdp_speed[True-True-True-True-True] 77.6570μs 23.8122μs 41.9952 KOps/s 41.0001 KOps/s $\color{#35bf28}+2.43\%$
test_step_mdp_speed[True-True-True-True-False] 40.1650μs 13.9823μs 71.5188 KOps/s 69.3805 KOps/s $\color{#35bf28}+3.08\%$
test_step_mdp_speed[True-True-True-False-True] 49.6730μs 13.7194μs 72.8893 KOps/s 70.5843 KOps/s $\color{#35bf28}+3.27\%$
test_step_mdp_speed[True-True-True-False-False] 56.1460μs 8.1598μs 122.5516 KOps/s 120.6058 KOps/s $\color{#35bf28}+1.61\%$
test_step_mdp_speed[True-True-False-True-True] 80.8810μs 25.4408μs 39.3070 KOps/s 38.7792 KOps/s $\color{#35bf28}+1.36\%$
test_step_mdp_speed[True-True-False-True-False] 60.3230μs 15.4593μs 64.6861 KOps/s 63.6353 KOps/s $\color{#35bf28}+1.65\%$
test_step_mdp_speed[True-True-False-False-True] 48.6910μs 15.2752μs 65.4657 KOps/s 63.9940 KOps/s $\color{#35bf28}+2.30\%$
test_step_mdp_speed[True-True-False-False-False] 45.5450μs 9.5504μs 104.7072 KOps/s 103.2983 KOps/s $\color{#35bf28}+1.36\%$
test_step_mdp_speed[True-False-True-True-True] 60.5730μs 27.3040μs 36.6246 KOps/s 36.4190 KOps/s $\color{#35bf28}+0.56\%$
test_step_mdp_speed[True-False-True-True-False] 55.2530μs 17.0829μs 58.5379 KOps/s 58.0164 KOps/s $\color{#35bf28}+0.90\%$
test_step_mdp_speed[True-False-True-False-True] 41.0870μs 15.3451μs 65.1673 KOps/s 64.4278 KOps/s $\color{#35bf28}+1.15\%$
test_step_mdp_speed[True-False-True-False-False] 36.3380μs 9.5931μs 104.2420 KOps/s 103.3176 KOps/s $\color{#35bf28}+0.89\%$
test_step_mdp_speed[True-False-False-True-True] 57.4970μs 28.6329μs 34.9248 KOps/s 35.0743 KOps/s $\color{#d91a1a}-0.43\%$
test_step_mdp_speed[True-False-False-True-False] 53.9300μs 18.4106μs 54.3166 KOps/s 53.9439 KOps/s $\color{#35bf28}+0.69\%$
test_step_mdp_speed[True-False-False-False-True] 62.1160μs 16.5997μs 60.2419 KOps/s 59.7351 KOps/s $\color{#35bf28}+0.85\%$
test_step_mdp_speed[True-False-False-False-False] 41.6880μs 10.8472μs 92.1901 KOps/s 89.4092 KOps/s $\color{#35bf28}+3.11\%$
test_step_mdp_speed[False-True-True-True-True] 93.3140μs 26.8498μs 37.2443 KOps/s 36.4847 KOps/s $\color{#35bf28}+2.08\%$
test_step_mdp_speed[False-True-True-True-False] 44.5730μs 16.9294μs 59.0689 KOps/s 57.6231 KOps/s $\color{#35bf28}+2.51\%$
test_step_mdp_speed[False-True-True-False-True] 49.7330μs 17.5997μs 56.8190 KOps/s 54.6237 KOps/s $\color{#35bf28}+4.02\%$
test_step_mdp_speed[False-True-True-False-False] 33.6140μs 10.7646μs 92.8974 KOps/s 90.8619 KOps/s $\color{#35bf28}+2.24\%$
test_step_mdp_speed[False-True-False-True-True] 55.6250μs 28.6919μs 34.8531 KOps/s 34.6712 KOps/s $\color{#35bf28}+0.52\%$
test_step_mdp_speed[False-True-False-True-False] 44.8740μs 18.6361μs 53.6594 KOps/s 53.4642 KOps/s $\color{#35bf28}+0.37\%$
test_step_mdp_speed[False-True-False-False-True] 44.2230μs 19.0267μs 52.5576 KOps/s 51.1201 KOps/s $\color{#35bf28}+2.81\%$
test_step_mdp_speed[False-True-False-False-False] 64.9710μs 12.0870μs 82.7335 KOps/s 81.0897 KOps/s $\color{#35bf28}+2.03\%$
test_step_mdp_speed[False-False-True-True-True] 4.9703ms 30.9542μs 32.3058 KOps/s 32.7407 KOps/s $\color{#d91a1a}-1.33\%$
test_step_mdp_speed[False-False-True-True-False] 54.4520μs 19.6727μs 50.8317 KOps/s 49.6940 KOps/s $\color{#35bf28}+2.29\%$
test_step_mdp_speed[False-False-True-False-True] 49.0020μs 18.9595μs 52.7439 KOps/s 51.7012 KOps/s $\color{#35bf28}+2.02\%$
test_step_mdp_speed[False-False-True-False-False] 38.5020μs 12.0400μs 83.0561 KOps/s 81.1532 KOps/s $\color{#35bf28}+2.34\%$
test_step_mdp_speed[False-False-False-True-True] 80.1400μs 31.2660μs 31.9836 KOps/s 31.9307 KOps/s $\color{#35bf28}+0.17\%$
test_step_mdp_speed[False-False-False-True-False] 44.5740μs 20.9597μs 47.7107 KOps/s 46.5263 KOps/s $\color{#35bf28}+2.55\%$
test_step_mdp_speed[False-False-False-False-True] 51.4460μs 20.0099μs 49.9754 KOps/s 48.7567 KOps/s $\color{#35bf28}+2.50\%$
test_step_mdp_speed[False-False-False-False-False] 43.6820μs 13.2395μs 75.5316 KOps/s 73.3613 KOps/s $\color{#35bf28}+2.96\%$
test_values[generalized_advantage_estimate-True-True] 11.5159ms 9.8473ms 101.5509 Ops/s 105.3125 Ops/s $\color{#d91a1a}-3.57\%$
test_values[vec_generalized_advantage_estimate-True-True] 37.5409ms 33.6259ms 29.7390 Ops/s 27.7375 Ops/s $\textbf{\color{#35bf28}+7.22\%}$
test_values[td0_return_estimate-False-False] 0.2253ms 0.1632ms 6.1288 KOps/s 5.9978 KOps/s $\color{#35bf28}+2.18\%$
test_values[td1_return_estimate-False-False] 25.7453ms 24.1383ms 41.4279 Ops/s 41.7063 Ops/s $\color{#d91a1a}-0.67\%$
test_values[vec_td1_return_estimate-False-False] 35.6314ms 33.6180ms 29.7460 Ops/s 27.6331 Ops/s $\textbf{\color{#35bf28}+7.65\%}$
test_values[td_lambda_return_estimate-True-False] 35.6141ms 34.7343ms 28.7900 Ops/s 28.8062 Ops/s $\color{#d91a1a}-0.06\%$
test_values[vec_td_lambda_return_estimate-True-False] 35.3133ms 33.6304ms 29.7350 Ops/s 27.5942 Ops/s $\textbf{\color{#35bf28}+7.76\%}$
test_gae_speed[generalized_advantage_estimate-False-1-512] 8.6173ms 8.5050ms 117.5776 Ops/s 120.1911 Ops/s $\color{#d91a1a}-2.17\%$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 1.8700ms 1.7815ms 561.3352 Ops/s 441.4320 Ops/s $\textbf{\color{#35bf28}+27.16\%}$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 0.5924ms 0.3565ms 2.8054 KOps/s 2.7594 KOps/s $\color{#35bf28}+1.67\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 41.2907ms 38.7780ms 25.7878 Ops/s 22.7053 Ops/s $\textbf{\color{#35bf28}+13.58\%}$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 3.9704ms 3.0346ms 329.5341 Ops/s 328.7155 Ops/s $\color{#35bf28}+0.25\%$
test_dqn_speed 6.3492ms 1.2916ms 774.2185 Ops/s 763.7180 Ops/s $\color{#35bf28}+1.37\%$
test_ddpg_speed 2.9762ms 2.6971ms 370.7623 Ops/s 366.8248 Ops/s $\color{#35bf28}+1.07\%$
test_sac_speed 8.9942ms 7.8987ms 126.6024 Ops/s 124.6071 Ops/s $\color{#35bf28}+1.60\%$
test_redq_speed 13.5127ms 12.5030ms 79.9806 Ops/s 79.6195 Ops/s $\color{#35bf28}+0.45\%$
test_redq_deprec_speed 13.8810ms 12.5049ms 79.9688 Ops/s 78.5457 Ops/s $\color{#35bf28}+1.81\%$
test_td3_speed 8.3465ms 7.8654ms 127.1398 Ops/s 126.3317 Ops/s $\color{#35bf28}+0.64\%$
test_cql_speed 35.9901ms 35.1285ms 28.4669 Ops/s 28.2031 Ops/s $\color{#35bf28}+0.94\%$
test_a2c_speed 7.8295ms 7.2417ms 138.0886 Ops/s 138.1495 Ops/s $\color{#d91a1a}-0.04\%$
test_ppo_speed 8.9631ms 7.4977ms 133.3734 Ops/s 132.8584 Ops/s $\color{#35bf28}+0.39\%$
test_reinforce_speed 7.6140ms 6.4028ms 156.1813 Ops/s 156.1572 Ops/s $\color{#35bf28}+0.02\%$
test_iql_speed 33.3337ms 31.7687ms 31.4775 Ops/s 28.8420 Ops/s $\textbf{\color{#35bf28}+9.14\%}$
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 6.6381ms 4.7600ms 210.0833 Ops/s 208.6635 Ops/s $\color{#35bf28}+0.68\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 0.7460ms 0.4744ms 2.1078 KOps/s 2.0964 KOps/s $\color{#35bf28}+0.54\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.6694ms 0.4510ms 2.2172 KOps/s 2.1870 KOps/s $\color{#35bf28}+1.38\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 6.6852ms 4.7144ms 212.1183 Ops/s 205.5682 Ops/s $\color{#35bf28}+3.19\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 0.7908ms 0.4939ms 2.0247 KOps/s 2.0859 KOps/s $\color{#d91a1a}-2.94\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.6686ms 0.4568ms 2.1890 KOps/s 2.1842 KOps/s $\color{#35bf28}+0.22\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] 2.4522ms 1.6914ms 591.2126 Ops/s 590.0084 Ops/s $\color{#35bf28}+0.20\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] 2.1586ms 1.6032ms 623.7356 Ops/s 623.7380 Ops/s $-0.00\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 6.8665ms 5.0988ms 196.1264 Ops/s 204.7254 Ops/s $\color{#d91a1a}-4.20\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 1.0157ms 0.6216ms 1.6089 KOps/s 1.6261 KOps/s $\color{#d91a1a}-1.06\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.7771ms 0.5967ms 1.6759 KOps/s 1.7160 KOps/s $\color{#d91a1a}-2.34\%$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 5.2075ms 4.9130ms 203.5424 Ops/s 210.4294 Ops/s $\color{#d91a1a}-3.27\%$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 1.2371ms 0.4866ms 2.0550 KOps/s 2.0949 KOps/s $\color{#d91a1a}-1.91\%$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.6645ms 0.4617ms 2.1659 KOps/s 2.1789 KOps/s $\color{#d91a1a}-0.60\%$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 5.9631ms 4.8569ms 205.8944 Ops/s 212.9380 Ops/s $\color{#d91a1a}-3.31\%$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 0.7190ms 0.4780ms 2.0923 KOps/s 2.1093 KOps/s $\color{#d91a1a}-0.81\%$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 7.1771ms 0.4672ms 2.1405 KOps/s 2.2368 KOps/s $\color{#d91a1a}-4.31\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 5.2972ms 4.9849ms 200.6069 Ops/s 205.8461 Ops/s $\color{#d91a1a}-2.55\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 0.9501ms 0.6104ms 1.6384 KOps/s 1.6238 KOps/s $\color{#35bf28}+0.90\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.7855ms 0.5871ms 1.7033 KOps/s 1.6627 KOps/s $\color{#35bf28}+2.44\%$
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 0.1221s 8.2732ms 120.8728 Ops/s 167.5156 Ops/s $\textbf{\color{#d91a1a}-27.84\%}$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 16.8086ms 12.9537ms 77.1981 Ops/s 77.2236 Ops/s $\color{#d91a1a}-0.03\%$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 1.8267ms 1.1390ms 877.9705 Ops/s 903.3859 Ops/s $\color{#d91a1a}-2.81\%$
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 0.1099s 5.8167ms 171.9196 Ops/s 171.0075 Ops/s $\color{#35bf28}+0.53\%$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 17.0531ms 12.9615ms 77.1517 Ops/s 77.3634 Ops/s $\color{#d91a1a}-0.27\%$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 1.8565ms 1.1393ms 877.7313 Ops/s 877.4761 Ops/s $\color{#35bf28}+0.03\%$
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 0.1100s 8.1448ms 122.7777 Ops/s 123.2273 Ops/s $\color{#d91a1a}-0.36\%$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 17.5152ms 13.0745ms 76.4848 Ops/s 76.7339 Ops/s $\color{#d91a1a}-0.32\%$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 2.0098ms 1.2830ms 779.4535 Ops/s 728.1783 Ops/s $\textbf{\color{#35bf28}+7.04\%}$
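
The "Change" column appears to be the relative difference between the "Ops" and "Ops on Repo HEAD" columns. The sketch below checks that reading against three rows copied from the table above; the formula is an assumption inferred from the numbers, not something the bot comment states, and `percent_change` is a hypothetical helper name.

```python
def percent_change(ops_new: float, ops_head: float) -> float:
    """Relative throughput change of this PR versus repo HEAD, in percent."""
    return (ops_new / ops_head - 1.0) * 100.0


# (Ops, Ops on Repo HEAD) pairs copied from the CPU table.
rows = {
    "test_single": (17.2180, 16.9606),
    "test_gae_speed[vec_generalized_advantage_estimate-True-1-512]": (561.3352, 441.4320),
    "test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400]": (120.8728, 167.5156),
}

changes = {
    name: round(percent_change(new, head), 2)
    for name, (new, head) in rows.items()
}
```

With these inputs the computed values match the reported +1.52%, +27.16% and -27.84%.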

vmoens added the enhancement (New feature or request) label Aug 2, 2024

github-actions bot commented Aug 2, 2024

⚠ Warning: Result of GPU Benchmark Tests

Total Benchmarks: 94. Improved: 3. Worsened: 3.

Name Max Mean Ops Ops on Repo HEAD Change
test_single 0.1083s 0.1081s 9.2478 Ops/s 9.1035 Ops/s $\color{#35bf28}+1.59\%$
test_sync 96.2380ms 95.8970ms 10.4279 Ops/s 10.4743 Ops/s $\color{#d91a1a}-0.44\%$
test_async 0.1790s 89.8904ms 11.1247 Ops/s 11.2981 Ops/s $\color{#d91a1a}-1.54\%$
test_single_pixels 0.1201s 0.1191s 8.3950 Ops/s 8.3807 Ops/s $\color{#35bf28}+0.17\%$
test_sync_pixels 77.4057ms 75.0139ms 13.3309 Ops/s 13.3348 Ops/s $\color{#d91a1a}-0.03\%$
test_async_pixels 0.1398s 69.3810ms 14.4132 Ops/s 14.4059 Ops/s $\color{#35bf28}+0.05\%$
test_simple 0.7820s 0.7813s 1.2799 Ops/s 1.2466 Ops/s $\color{#35bf28}+2.67\%$
test_transformed 1.1076s 1.0335s 0.9676 Ops/s 0.9782 Ops/s $\color{#d91a1a}-1.09\%$
test_serial 2.3195s 2.2473s 0.4450 Ops/s 0.4416 Ops/s $\color{#35bf28}+0.76\%$
test_parallel 2.0396s 1.9841s 0.5040 Ops/s 0.5066 Ops/s $\color{#d91a1a}-0.51\%$
test_step_mdp_speed[True-True-True-True-True] 0.1040ms 36.0206μs 27.7619 KOps/s 26.7325 KOps/s $\color{#35bf28}+3.85\%$
test_step_mdp_speed[True-True-True-True-False] 39.3010μs 20.7227μs 48.2563 KOps/s 46.5723 KOps/s $\color{#35bf28}+3.62\%$
test_step_mdp_speed[True-True-True-False-True] 49.5510μs 20.6293μs 48.4748 KOps/s 47.4369 KOps/s $\color{#35bf28}+2.19\%$
test_step_mdp_speed[True-True-True-False-False] 30.9610μs 11.9492μs 83.6878 KOps/s 84.2822 KOps/s $\color{#d91a1a}-0.71\%$
test_step_mdp_speed[True-True-False-True-True] 59.5810μs 38.4785μs 25.9885 KOps/s 25.3073 KOps/s $\color{#35bf28}+2.69\%$
test_step_mdp_speed[True-True-False-True-False] 42.5110μs 23.2164μs 43.0730 KOps/s 42.0332 KOps/s $\color{#35bf28}+2.47\%$
test_step_mdp_speed[True-True-False-False-True] 43.1620μs 23.0089μs 43.4614 KOps/s 41.9601 KOps/s $\color{#35bf28}+3.58\%$
test_step_mdp_speed[True-True-False-False-False] 35.4310μs 14.2273μs 70.2876 KOps/s 69.2938 KOps/s $\color{#35bf28}+1.43\%$
test_step_mdp_speed[True-False-True-True-True] 71.2520μs 40.7142μs 24.5614 KOps/s 23.7132 KOps/s $\color{#35bf28}+3.58\%$
test_step_mdp_speed[True-False-True-True-False] 53.2810μs 25.5470μs 39.1435 KOps/s 38.0121 KOps/s $\color{#35bf28}+2.98\%$
test_step_mdp_speed[True-False-True-False-True] 57.1310μs 22.6112μs 44.2259 KOps/s 42.8372 KOps/s $\color{#35bf28}+3.24\%$
test_step_mdp_speed[True-False-True-False-False] 32.5900μs 14.1789μs 70.5274 KOps/s 69.9104 KOps/s $\color{#35bf28}+0.88\%$
test_step_mdp_speed[True-False-False-True-True] 72.2510μs 42.4975μs 23.5308 KOps/s 22.5736 KOps/s $\color{#35bf28}+4.24\%$
test_step_mdp_speed[True-False-False-True-False] 60.1110μs 27.4178μs 36.4726 KOps/s 35.3670 KOps/s $\color{#35bf28}+3.13\%$
test_step_mdp_speed[True-False-False-False-True] 44.1910μs 24.8130μs 40.3015 KOps/s 38.6465 KOps/s $\color{#35bf28}+4.28\%$
test_step_mdp_speed[True-False-False-False-False] 36.1010μs 16.3322μs 61.2289 KOps/s 61.6458 KOps/s $\color{#d91a1a}-0.68\%$
test_step_mdp_speed[False-True-True-True-True] 60.9210μs 40.4013μs 24.7517 KOps/s 23.8815 KOps/s $\color{#35bf28}+3.64\%$
test_step_mdp_speed[False-True-True-True-False] 50.5810μs 25.3500μs 39.4478 KOps/s 38.5502 KOps/s $\color{#35bf28}+2.33\%$
test_step_mdp_speed[False-True-True-False-True] 52.4410μs 26.9628μs 37.0882 KOps/s 35.9671 KOps/s $\color{#35bf28}+3.12\%$
test_step_mdp_speed[False-True-True-False-False] 37.3200μs 16.0720μs 62.2200 KOps/s 60.9309 KOps/s $\color{#35bf28}+2.12\%$
test_step_mdp_speed[False-True-False-True-True] 64.3600μs 42.7812μs 23.3747 KOps/s 22.6899 KOps/s $\color{#35bf28}+3.02\%$
test_step_mdp_speed[False-True-False-True-False] 45.7710μs 27.5551μs 36.2909 KOps/s 35.5163 KOps/s $\color{#35bf28}+2.18\%$
test_step_mdp_speed[False-True-False-False-True] 73.3910μs 28.8504μs 34.6615 KOps/s 33.0408 KOps/s $\color{#35bf28}+4.91\%$
test_step_mdp_speed[False-True-False-False-False] 41.6300μs 18.1248μs 55.1730 KOps/s 53.3433 KOps/s $\color{#35bf28}+3.43\%$
test_step_mdp_speed[False-False-True-True-True] 4.0032ms 45.0151μs 22.2148 KOps/s 21.4147 KOps/s $\color{#35bf28}+3.74\%$
test_step_mdp_speed[False-False-True-True-False] 47.9010μs 29.7437μs 33.6205 KOps/s 32.7864 KOps/s $\color{#35bf28}+2.54\%$
test_step_mdp_speed[False-False-True-False-True] 52.3610μs 29.2682μs 34.1668 KOps/s 33.1629 KOps/s $\color{#35bf28}+3.03\%$
test_step_mdp_speed[False-False-True-False-False] 49.0700μs 18.1907μs 54.9732 KOps/s 53.8108 KOps/s $\color{#35bf28}+2.16\%$
test_step_mdp_speed[False-False-False-True-True] 72.0910μs 46.9558μs 21.2966 KOps/s 20.8103 KOps/s $\color{#35bf28}+2.34\%$
test_step_mdp_speed[False-False-False-True-False] 56.4510μs 31.8132μs 31.4334 KOps/s 30.4653 KOps/s $\color{#35bf28}+3.18\%$
test_step_mdp_speed[False-False-False-False-True] 59.9610μs 30.7492μs 32.5212 KOps/s 31.3747 KOps/s $\color{#35bf28}+3.65\%$
test_step_mdp_speed[False-False-False-False-False] 42.3610μs 20.2891μs 49.2875 KOps/s 48.6610 KOps/s $\color{#35bf28}+1.29\%$
test_values[generalized_advantage_estimate-True-True] 24.6960ms 24.3358ms 41.0917 Ops/s 41.0758 Ops/s $\color{#35bf28}+0.04\%$
test_values[vec_generalized_advantage_estimate-True-True] 96.6017ms 2.8266ms 353.7806 Ops/s 373.1887 Ops/s $\textbf{\color{#d91a1a}-5.20\%}$
test_values[td0_return_estimate-False-False] 92.3010μs 64.8003μs 15.4320 KOps/s 15.3579 KOps/s $\color{#35bf28}+0.48\%$
test_values[td1_return_estimate-False-False] 54.9978ms 54.4949ms 18.3504 Ops/s 18.3411 Ops/s $\color{#35bf28}+0.05\%$
test_values[vec_td1_return_estimate-False-False] 1.3319ms 1.0804ms 925.5448 Ops/s 926.7632 Ops/s $\color{#d91a1a}-0.13\%$
test_values[td_lambda_return_estimate-True-False] 87.4524ms 86.6773ms 11.5370 Ops/s 11.4541 Ops/s $\color{#35bf28}+0.72\%$
test_values[vec_td_lambda_return_estimate-True-False] 1.2641ms 1.0758ms 929.5132 Ops/s 924.3211 Ops/s $\color{#35bf28}+0.56\%$
test_gae_speed[generalized_advantage_estimate-False-1-512] 24.7570ms 24.4916ms 40.8303 Ops/s 40.7605 Ops/s $\color{#35bf28}+0.17\%$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 0.9360ms 0.7141ms 1.4003 KOps/s 1.4078 KOps/s $\color{#d91a1a}-0.53\%$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 0.7471ms 0.6654ms 1.5029 KOps/s 1.5003 KOps/s $\color{#35bf28}+0.18\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 1.6187ms 1.4656ms 682.3087 Ops/s 684.6305 Ops/s $\color{#d91a1a}-0.34\%$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 0.7069ms 0.6803ms 1.4700 KOps/s 1.4683 KOps/s $\color{#35bf28}+0.12\%$
test_dqn_speed 7.2455ms 1.3561ms 737.4004 Ops/s 713.1605 Ops/s $\color{#35bf28}+3.40\%$
test_ddpg_speed 2.9393ms 2.7271ms 366.6902 Ops/s 362.8527 Ops/s $\color{#35bf28}+1.06\%$
test_sac_speed 8.1415ms 7.8818ms 126.8739 Ops/s 125.9786 Ops/s $\color{#35bf28}+0.71\%$
test_redq_speed 12.0670ms 10.1998ms 98.0416 Ops/s 99.6837 Ops/s $\color{#d91a1a}-1.65\%$
test_redq_deprec_speed 11.2003ms 10.8145ms 92.4685 Ops/s 92.5732 Ops/s $\color{#d91a1a}-0.11\%$
test_td3_speed 7.9883ms 7.8344ms 127.6424 Ops/s 127.1104 Ops/s $\color{#35bf28}+0.42\%$
test_cql_speed 25.6982ms 25.1240ms 39.8026 Ops/s 39.5185 Ops/s $\color{#35bf28}+0.72\%$
test_a2c_speed 5.7846ms 5.5162ms 181.2830 Ops/s 177.4163 Ops/s $\color{#35bf28}+2.18\%$
test_ppo_speed 6.0877ms 5.8146ms 171.9818 Ops/s 167.5109 Ops/s $\color{#35bf28}+2.67\%$
test_reinforce_speed 5.2625ms 4.4583ms 224.3018 Ops/s 219.7956 Ops/s $\color{#35bf28}+2.05\%$
test_iql_speed 20.0205ms 19.3390ms 51.7090 Ops/s 50.7365 Ops/s $\color{#35bf28}+1.92\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 6.8335ms 6.6412ms 150.5745 Ops/s 147.2232 Ops/s $\color{#35bf28}+2.28\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 0.6549ms 0.5261ms 1.9007 KOps/s 1.9035 KOps/s $\color{#d91a1a}-0.14\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.6929ms 0.5049ms 1.9804 KOps/s 2.0010 KOps/s $\color{#d91a1a}-1.03\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 6.8353ms 6.5292ms 153.1573 Ops/s 151.2057 Ops/s $\color{#35bf28}+1.29\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 1.4001ms 0.5122ms 1.9522 KOps/s 1.9352 KOps/s $\color{#35bf28}+0.88\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.6761ms 0.4945ms 2.0221 KOps/s 2.0211 KOps/s $\color{#35bf28}+0.05\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] 2.1617ms 1.9770ms 505.8102 Ops/s 501.6605 Ops/s $\color{#35bf28}+0.83\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] 2.0618ms 1.8767ms 532.8376 Ops/s 529.1907 Ops/s $\color{#35bf28}+0.69\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 6.9180ms 6.7587ms 147.9576 Ops/s 146.7965 Ops/s $\color{#35bf28}+0.79\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 0.9750ms 0.6724ms 1.4872 KOps/s 1.4719 KOps/s $\color{#35bf28}+1.04\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.8180ms 0.6511ms 1.5359 KOps/s 1.5323 KOps/s $\color{#35bf28}+0.23\%$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 6.7845ms 6.6517ms 150.3382 Ops/s 146.4199 Ops/s $\color{#35bf28}+2.68\%$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 0.9348ms 0.5272ms 1.8967 KOps/s 1.9114 KOps/s $\color{#d91a1a}-0.77\%$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.6551ms 0.5056ms 1.9780 KOps/s 2.0061 KOps/s $\color{#d91a1a}-1.40\%$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 6.9231ms 6.5829ms 151.9097 Ops/s 149.5321 Ops/s $\color{#35bf28}+1.59\%$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 1.7462ms 0.5389ms 1.8557 KOps/s 1.9251 KOps/s $\color{#d91a1a}-3.61\%$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 1.4684ms 0.5761ms 1.7359 KOps/s 2.0065 KOps/s $\textbf{\color{#d91a1a}-13.49\%}$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 6.9649ms 6.8082ms 146.8828 Ops/s 144.8607 Ops/s $\color{#35bf28}+1.40\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 1.9442ms 0.6847ms 1.4605 KOps/s 1.4693 KOps/s $\color{#d91a1a}-0.60\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.7870ms 0.6504ms 1.5375 KOps/s 1.5177 KOps/s $\color{#35bf28}+1.31\%$
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 0.1318s 7.6483ms 130.7474 Ops/s 100.2972 Ops/s $\textbf{\color{#35bf28}+30.36\%}$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 19.6142ms 15.8788ms 62.9771 Ops/s 60.9847 Ops/s $\color{#35bf28}+3.27\%$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 2.4495ms 1.2261ms 815.6235 Ops/s 732.7017 Ops/s $\textbf{\color{#35bf28}+11.32\%}$
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 0.1241s 7.4996ms 133.3407 Ops/s 132.1132 Ops/s $\color{#35bf28}+0.93\%$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 18.5292ms 15.8524ms 63.0821 Ops/s 61.3859 Ops/s $\color{#35bf28}+2.76\%$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 2.4271ms 1.1931ms 838.1335 Ops/s 849.0529 Ops/s $\color{#d91a1a}-1.29\%$
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 0.1275s 10.1190ms 98.8238 Ops/s 129.7365 Ops/s $\textbf{\color{#d91a1a}-23.83\%}$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 18.7779ms 16.2217ms 61.6456 Ops/s 61.0035 Ops/s $\color{#35bf28}+1.05\%$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 2.1934ms 1.3140ms 761.0181 Ops/s 670.2804 Ops/s $\textbf{\color{#35bf28}+13.54\%}$

vmoens merged commit 3267533 into main Aug 3, 2024
70 of 71 checks passed
vmoens deleted the make-marl-params-apparent branch August 3, 2024 00:23

matteobettini (Contributor) commented

Could you explain why we need this?

Also, isn't having 2 copies of the parameters error-prone?

For example, in methods like https://github.com/facebookresearch/BenchMARL/blob/d260eea5d4ef2ff5f0bea8ae36f68638ecb14865/benchmarl/models/common.py#L165, or in any general case where users access self.parameters(), won't things break?

vmoens (Contributor, Author) commented Aug 3, 2024

We test that nothing breaks. I don't think it's error-prone: you never see two copies (for instance, parameters() returns just one).

We need this mainly because it makes initialization of the params more natural.

matteobettini (Contributor) commented Aug 3, 2024

So if a user modifies the content of one copy of the parameters, the change is reflected in the other copy? As in the function I sent.

But apart from being more natural, what use cases is it used for/ envisioned for?

matteobettini (Contributor) commented

Maybe I am misreading the PR description: when you say 2 copies, do you mean:

  1. 2 different sets of parameters?
  2. 2 objects referring by pointer to the same parameter tensors?
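
The difference between the two readings can be made concrete with a minimal sketch in plain PyTorch. The dictionaries `alias` and `distinct` are hypothetical stand-ins for a second parameter container; they are not TorchRL objects.

```python
import torch
from torch import nn

linear = nn.Linear(2, 2)

# Reading 2: a second container that references the very same Parameter object.
alias = {"weight": linear.weight}

# Reading 1: a second, independent set of parameters with its own storage.
distinct = {"weight": nn.Parameter(linear.weight.detach().clone())}

# Modify the module's weight in place.
with torch.no_grad():
    linear.weight.zero_()

alias_sees_change = bool((alias["weight"] == 0).all())
distinct_sees_change = bool((distinct["weight"] == 0).all())
shares_storage = alias["weight"].data_ptr() == linear.weight.data_ptr()
```

Under reading 2 an in-place change made through the module is visible through the other container; under reading 1 it is not.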

@@ -898,6 +898,51 @@ def one_outofplace(mod):
mlp.from_stateful_net(snet)
assert (mlp.params == 1).all()

@retry(AssertionError, 5)
Contributor (review comment):

What is the effect of this?
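
The thread does not answer this question, and the definition of the `retry` helper is not shown in this diff. A decorator called this way usually re-runs the function when the given exception is raised, up to the given number of attempts. The sketch below is a generic, hypothetical implementation of that idea, not the helper used in the TorchRL test suite.

```python
import functools


def retry(exc_type, tries):
    """Re-run the wrapped function until it passes or `tries` attempts are used."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(1, tries + 1):
                try:
                    return fn(*args, **kwargs)
                except exc_type:
                    # Give up and re-raise once the last attempt has failed.
                    if attempt == tries:
                        raise

        return wrapper

    return decorator


calls = {"n": 0}


@retry(AssertionError, 5)
def flaky_check():
    calls["n"] += 1
    assert calls["n"] >= 3, "fails on the first two attempts"
    return "ok"


result = flaky_check()
```

Here the check fails twice and passes on the third attempt, so the caller never sees the intermediate `AssertionError`s.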

:class:`~tensordict.nn.TensorDictParams` object (which inherits both from `TensorDict` and `nn.Module`).
If ``False``, parameters are contained in `self._empty_net`. All things considered, these two approaches
should be roughly identical but not interchangeable: for instance, a ``state_dict`` created with
``use_td_params=True`` cannot be used when ``use_td_params=False``.
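
One plausible reason the two `state_dict` layouts are not interchangeable is that the keys depend on where the parameters are registered. The sketch below shows the effect with plain PyTorch; `nn.ParameterDict` is only a stand-in for `TensorDictParams`, the class names are hypothetical, and the real key names produced by TorchRL are not shown in this thread.

```python
import torch
from torch import nn


class ParamsInContainer(nn.Module):
    """Parameters registered under a `params` container (stand-in only)."""

    def __init__(self):
        super().__init__()
        self.params = nn.ParameterDict({"weight": nn.Parameter(torch.zeros(2, 3))})


class ParamsInNet(nn.Module):
    """The same tensor shape, but registered inside a submodule."""

    def __init__(self):
        super().__init__()
        self._empty_net = nn.Linear(3, 2, bias=False)


a, b = ParamsInContainer(), ParamsInNet()
keys_a = list(a.state_dict().keys())
keys_b = list(b.state_dict().keys())

# Strict loading fails because the key prefixes differ.
try:
    b.load_state_dict(a.state_dict())
    loaded = True
except RuntimeError:
    loaded = False
```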
Contributor (review comment):

The note "Defaults to ``True``" is missing from the docs.

vmoens (Contributor, Author) commented Aug 3, 2024

> So if a user modifies the content of one copy of the parameters, the change is reflected in the other copy? As in the function I sent.

They are exactly the same objects, just one is in self.params and not seen by self.modules() or self.parameters() and the other is in self._empty_net.
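
A minimal sketch of this arrangement in plain PyTorch, assuming a plain `dict` as the unregistered container (the PR uses a regular `TensorDict`; the class name `TwoViews` is hypothetical). `double()` is used in place of a device move so the example runs without a GPU.

```python
import torch
from torch import nn


class TwoViews(nn.Module):
    def __init__(self):
        super().__init__()
        self._empty_net = nn.Linear(3, 2)
        # A plain dict is not registered by nn.Module, so modules() and
        # parameters() never traverse it.
        self.params = dict(self._empty_net.named_parameters())


m = TwoViews()

# parameters() reports weight and bias once each.
n_visible = len(list(m.parameters()))

# Both views hold the very same Parameter objects.
same_objects = all(m.params[k] is p for k, p in m._empty_net.named_parameters())

# Casting the module updates each Parameter in place (default PyTorch
# behaviour), so the unregistered view follows along.
m.double()
still_same = all(m.params[k] is p for k, p in m._empty_net.named_parameters())
dtype_followed = all(p.dtype == torch.float64 for p in m.params.values())
```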

> But apart from being more natural, what use cases is it used for/ envisioned for?

Many people are used to doing

from torch import nn

def init(module):
    # Zero the weights of every nn.Linear submodule.
    if isinstance(module, nn.Linear):
        module.weight.data.zero_()

self.apply(init)

which you can only do if the params are in the module, not in the TDParams. Moreover, TDParams carries some overhead; the new version should be faster.
On top of that, it's totally optional and 100% non-BC-breaking.
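
To see why the location of the parameters matters for this idiom, the sketch below applies the same `init` function to two toy modules: one whose weights live in `nn.Linear` submodules, and one whose weights live only in a parameter container. `nn.ParameterDict` stands in for `TensorDictParams` here, and `ContainerOnly` is a hypothetical class.

```python
import torch
from torch import nn


def init(module):
    # Zero the weights of every nn.Linear that apply() visits.
    if isinstance(module, nn.Linear):
        module.weight.data.zero_()


# Case 1: parameters live in nn.Linear submodules, so apply() reaches them.
in_module = nn.Sequential(nn.Linear(3, 2), nn.Tanh(), nn.Linear(2, 1))
in_module.apply(init)
zeroed = all(
    bool((layer.weight == 0).all())
    for layer in in_module
    if isinstance(layer, nn.Linear)
)


# Case 2: parameters live only in a container; no nn.Linear is ever visited.
class ContainerOnly(nn.Module):
    def __init__(self):
        super().__init__()
        self.params = nn.ParameterDict({"weight": nn.Parameter(torch.ones(2, 3))})


container = ContainerOnly()
container.apply(init)
untouched = bool((container.params["weight"] == 1).all())
```

In the second case the type check never matches, so the initializer silently does nothing.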

Labels: CLA Signed, enhancement
3 participants