
[Performance] Faster SliceSampler._tensor_slices_from_startend #2423

Merged: 1 commit into pytorch:main on Sep 10, 2024

Conversation

@kurtamohler (Collaborator) commented Sep 6, 2024

Description

Speeds up the SliceSampler._tensor_slices_from_startend method by about 8x in the case where seq_length is an int.

Running the performance measurement script from #2422 (comment) on my machine gives:

Without SliceSampler: 0.00019115543303390345 s
With SliceSampler: 0.002860310899753434 s
Slowdown factor: 14.963272842190824x

whereas the output before the change was:

Without SliceSampler: 0.0001870056662786131 s
With SliceSampler: 0.0046725646670286855 s
Slowdown factor: 24.98621972270721x

So this change speeds up the ReplayBuffer.sample method by about 1.63x (0.00467 / 0.00286) for the particular case in that script.
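
For reference, here is a minimal benchmark in the spirit of that script (the actual script lives in #2422; the buffer size, keys, and SliceSampler arguments below are illustrative assumptions, not copied from it):

import timeit

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer, SliceSampler

capacity = 100_000
# 100 synthetic trajectories of 1000 steps each, tagged by an "episode" id
data = TensorDict(
    {
        "obs": torch.randn(capacity, 4),
        "episode": torch.arange(capacity) // 1000,
    },
    batch_size=[capacity],
)

rb = ReplayBuffer(
    storage=LazyTensorStorage(capacity),
    sampler=SliceSampler(slice_len=10, traj_key="episode"),
)
rb.extend(data)

n = 30
print("With SliceSampler:", timeit.timeit(lambda: rb.sample(320), number=n) / n, "s")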

I also took a performance profile of the script with cProfile, like so:

python -m cProfile <script> | grep _tensor_slices_from_startend

I increased the timeit iterations from 30 to 3000 for better precision. Before the change, the cumulative time spent in the _tensor_slices_from_startend function was 7.571 s. After the change, it was 0.871 s, so the speedup for _tensor_slices_from_startend alone was about 8x.
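
Conceptually, the int-case change amounts to replacing per-slice arange calls with a single broadcasted arange. A minimal sketch of that idea (a sketch only, assuming the old code concatenated one torch.arange per slice; the exact implementation is in the diff below):

import torch

def slices_loop(starts, seq_length):
    # old-style: one arange (and one kernel launch) per slice
    return torch.cat([torch.arange(s, s + seq_length) for s in starts.tolist()])

def slices_broadcast(starts, seq_length):
    # new-style: a single arange, broadcast against all starts at once
    arange = torch.arange(seq_length, device=starts.device)
    return (starts.unsqueeze(1) + arange).reshape(-1)

starts = torch.tensor([0, 7, 42])
assert torch.equal(slices_loop(starts, 5), slices_broadcast(starts, 5))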

Motivation and Context

Closes #2422

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

@kurtamohler kurtamohler added the performance Performance issue or suggestion for improvement label Sep 6, 2024

pytorch-bot bot commented Sep 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2423

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 7 Unrelated Failures

As of commit fec4f40 with merge base 57f0580:


👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 6, 2024
@@ -1076,9 +1076,24 @@ def _tensor_slices_from_startend(self, seq_length, start, storage_length):
# seq_length is a 1d tensor indicating the desired length of each sequence

if isinstance(seq_length, int):
@kurtamohler (Collaborator, Author) commented:

I also looked into the possibility of speeding up the case where seq_length is a tensor. It seems a lot less straightforward, and I'm not entirely sure we can get a speedup comparable to the int case. Since the sequence lengths are all different, it inherently requires doing something equivalent to calling torch.arange multiple times and torch.cat-ing the results together.

I can continue to investigate if you'd like; I just didn't want to invest too much time in it without discussing it first.
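
For reference, one fully vectorized direction for the tensor case (an illustrative sketch only, not code from this PR; whether it actually beats the loop for SliceSampler's exact semantics is the open question discussed above):

import torch

def ragged_slices(starts, lengths):
    # Equivalent to torch.cat([torch.arange(s, s + l) for s, l in zip(starts, lengths)])
    # but built from a single arange over the total output length.
    ends = lengths.cumsum(0)
    # shift each output element back to the start of its own segment
    offsets = torch.repeat_interleave(ends - lengths, lengths)
    flat = torch.arange(int(ends[-1]), device=starts.device) - offsets
    return flat + torch.repeat_interleave(starts, lengths)

starts = torch.tensor([10, 100, 5])
lengths = torch.tensor([3, 2, 4])
expected = torch.cat(
    [torch.arange(s, s + l) for s, l in zip(starts.tolist(), lengths.tolist())]
)
assert torch.equal(ragged_slices(starts, lengths), expected)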

@vmoens (Contributor) left a comment:

LGTM!
Just a nit regarding the device.

IIRC this line
https://github.com/kurtamohler/torchrl/blob/db5f5cff8a67f3854759ab78215046bc65019046/torchrl/data/replay_buffers/samplers.py#L1881
used to be the most expensive thing when the buffer is full (e.g., 1M elements). Can you reproduce that in your benchmark?
If so, a follow-up should be to speed that one up too!

torchrl/data/replay_buffers/samplers.py (review comment outdated and resolved)
@kurtamohler (Collaborator, Author) replied:

> IIRC this https://github.com/kurtamohler/torchrl/blob/db5f5cff8a67f3854759ab78215046bc65019046/torchrl/data/replay_buffers/samplers.py#L1881 used to be the most expensive thing when the buffer is full (e.g., 1M elements). Can you reproduce that in your benchmark? If so, a follow-up should be to speed that one up too!

I see, I will look into that.

@vmoens vmoens merged commit 6aa4b53 into pytorch:main Sep 10, 2024
65 of 70 checks passed
Labels

  • CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
  • performance: Performance issue or suggestion for improvement
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] SliceSampler is slow
3 participants