
[BugFix] Allowing for auto-nested tensordict #119

Closed
wants to merge 7 commits
Conversation


@Zooll Zooll commented Dec 16, 2022

Description

Blocked self-nesting behaviour during repr and other recursive calls.

Motivation and Context

Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
[BUG] Auto-nested tensordict bugs #106

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

Bug fix (non-breaking change which fixes an issue)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 16, 2022
Contributor

@tcbegley tcbegley left a comment


This is looking good, thanks for working on it!

I've left a few comments inline.

The other main thing that we should add are some tests. We can add a couple of tests for specific functionality such as checking for the presence of "Auto-nested" in the repr, but as @vmoens suggested in #106 it would be nice to run all tests on a tensordict with autonesting and check that we haven't inadvertently broken something.

Check this class that gets used in the tests, you can add a method for auto_nested_td or similar, and then add it to the tests here. In principle it should work for anything that nested_td works for, but in any case it will be interesting to see what breaks!

Inline review threads on tensordict/tensordict.py (4 outdated, 1 open).
@vmoens vmoens changed the title fix [BUG] Auto-nested tensordict bugs #106 [BugFix] Auto-nested tensordict bugs #106 Dec 22, 2022
@vmoens vmoens changed the title [BugFix] Auto-nested tensordict bugs #106 [BugFix] Allowing for auto-nested tensordict Dec 22, 2022
@Zooll
Author

Zooll commented Dec 22, 2022

> This is looking good, thanks for working on it!
>
> I've left a few comments inline.
>
> The other main thing that we should add are some tests. We can add a couple of tests for specific functionality such as checking for the presence of "Auto-nested" in the repr, but as @vmoens suggested in #106 it would be nice to run all tests on a tensordict with autonesting and check that we haven't inadvertently broken something.
>
> Check this class that gets used in the tests, you can add a method for auto_nested_td or similar, and then add it to the tests here. In principle it should work for anything that nested_td works for, but in any case it will be interesting to see what breaks!

I added "Auto-nested" as you recommended, but because of the nature of auto-nesting I think we can't test it using existing test cases. For instance we have this test:

    def test_items_values_keys(self, td_name, device):
        torch.manual_seed(1)
        td = getattr(self, td_name)(device)
        keys = list(td.keys())
        values = list(td.values())
        items = list(td.items())

        # Test td.items()
        constructed_td1 = TensorDict({}, batch_size=td.shape)
        for key, value in items:
            constructed_td1.set(key, value)

        assert (td == constructed_td1).all()  # <- this is the failing line

if we use this td:

    def auto_nested_td(self, device):
        tensordict = TensorDict(
            {
                "a": torch.randn(4, 3, 2, 1, 5),
                "b": torch.randn(4, 3, 2, 1, 10),
            },
            batch_size=[4, 3, 2, 1],
            device=device,
        )
        tensordict["self"] = tensordict
        return tensordict

then we can't construct constructed_td1 properly, because our "self" entry contains the "main" td. Copying it into constructed_td1 is not the same as setting constructed_td1["self"] = constructed_td1. Moreover, when we try to compare them, TensorDictBase.__eq__ performs a deep comparison, so we cannot compare a td that has a self-nested key.
I don't know what to do in this case. We can discuss it, or omit it in this iteration, since we need a strategy for working with auto- (or self-) nesting anyway.
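The reconstruction problem described above can be demonstrated with plain dicts (a hypothetical illustration, not tensordict itself): rebuilding a self-referential container key-by-key silently breaks the cycle.

```python
# Rebuilding a self-referential container entry-by-entry, mimicking
# `for key, value in items: constructed_td1.set(key, value)`.
original = {"a": 1}
original["self"] = original          # original["self"] is original itself

rebuilt = {}
for key, value in original.items():
    rebuilt[key] = value

# The rebuilt copy's "self" entry still points at *original*, not at rebuilt:
assert rebuilt["self"] is original
assert rebuilt["self"] is not rebuilt
# So `rebuilt` is not self-nested the way `original` is, and a naive
# recursive equality check would descend into the cycle indefinitely.
```

This is why an identity-aware comparison (or an identity-aware reconstruction) is needed before the existing equality-based tests can cover auto-nested tensordicts.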

@Zooll
Author

Zooll commented Dec 30, 2022

After #121, it became easier to solve this auto-nesting bug, and I updated the pull request. But it still has the problem with adding tests that I mentioned here: #119 (comment)

@Zooll Zooll requested a review from tcbegley December 30, 2022 15:34
Contributor

@vmoens vmoens left a comment


Regarding the tests, I'm happy with tests being independent of the others although I wonder if the fact that a call to __eq__ is prohibited should not be addressed by this PR.
TBH I did not quite get what the issue was. Can you elaborate a bit more?
You can also code it and push the code with the tests breaking, that would give me some visibility on what's happening.

Inline review threads on tensordict/tensordict.py (2 outdated, 2 open).
@@ -1830,10 +1844,14 @@ def flatten_keys(
        inner_tensordict = self.get(key).flatten_keys(
            separator=separator, inplace=inplace
        )
        for inner_key, inner_item in inner_tensordict.items():
            tensordict_out.set(separator.join([key, inner_key]), inner_item)
        if inner_tensordict is not self.get(key):
Contributor

I wonder if this is compatible with lazy tensordicts (e.g. LazyStackedTensorDict), where two calls to get(key) return items whose ids differ (e.g. because they are the results of a call to stack).

Author

That's an interesting point. I haven't looked into LazyStackedTensorDict's behaviour. But if get(key) returns a new id every time, how can we detect the self-loop?
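The concern can be made concrete with a toy class (purely hypothetical, not LazyStackedTensorDict itself): when get(key) manufactures its result lazily, every call returns a fresh object, so identity-based cycle detection never fires.

```python
class LazyNode:
    """Toy stand-in for a lazy container whose get() builds results on demand."""

    def get(self, key):
        # Imagine this stacking sub-containers on the fly: the result is a
        # brand-new object on every call, with a brand-new id.
        return LazyNode()

node = LazyNode()
first, second = node.get("self"), node.get("self")
assert first is not second  # ids differ between calls, so `is` checks fail
```

An `inner is not self.get(key)` guard would therefore always be True for such containers, which is exactly the compatibility question raised in the review comment.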

@Zooll
Author

Zooll commented Jan 4, 2023

> Regarding the tests, I'm happy with tests being independent of the others although I wonder if the fact that a call to __eq__ is prohibited should not be addressed by this PR. TBH I did not quite get what the issue was. Can you elaborate a bit more? You can also code it and push the code with the tests breaking, that would give me some visibility on what's happening.

This issue was about infinite recursion during print, flatten_keys, and list(keys()); see #106. Here is example code to verify this PR:

tensordict = TensorDict({
    "key 1": torch.ones(3, 4, 5),
    "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
}, batch_size=[3, 4])

tensordict2 = TensorDict({
    "super key 1": torch.ones(3, 4, 6),
    "super key 2": torch.zeros(3, 4, 7, dtype=torch.bool),
}, batch_size=[3, 4])

tensordict["innerTD"] = tensordict2
tensordict2["self 2"] = tensordict2
tensordict["self"] = tensordict
tensordict["key 3"] = torch.zeros(3, 4, 7, dtype=torch.bool)

f_keys = tensordict.flatten_keys()  # Ok
print(f_keys)
td_list = list(tensordict.keys(include_nested=True))  # Ok
print(td_list)
print(tensordict)  # Ok

@vmoens
Contributor

vmoens commented Jan 4, 2023

Thanks!
If you could write some tests, it would be easier to comment via the review process.

@vmoens vmoens added the bug Something isn't working label Jan 5, 2023
@Zooll
Author

Zooll commented Jan 5, 2023

@vmoens I've added tests for keys() and repr(). Unfortunately, I wasn't able to add tests for flatten_keys within the current test_flatten_keys, for several reasons. 1) We can't just add

@pytest.mark.parametrize(
    "td_name",
    [
        ...,
        auto_nested_td,
    ],
)

because many other tests would break.
2) I tried to add a test_flatten_keys(self, td_name, device, inplace, separator) with the auto-nested td only, but ran into infinitely recursive unlocking of the self-nested tensordict, which is unrelated to the flatten_keys problem: https://github.com/pytorch-labs/tensordict/blob/455b64adc16833567f45d917d40e4abf3b0a898c/test/test_tensordict.py#L1716
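One way out of that unlocking recursion can be sketched with plain dicts (illustrative only; the unlock_ function and _locked key here are hypothetical stand-ins for tensordict's locking machinery): carry a set of already-visited container ids through the recursion.

```python
def unlock_(node, visited=None):
    """Recursively 'unlock' nested dicts, terminating on self-references."""
    if visited is None:
        visited = set()
    if id(node) in visited:
        return  # already handled this container; stop the cycle here
    visited.add(id(node))
    node["_locked"] = False          # stand-in for the real unlock step
    for value in list(node.values()):
        if isinstance(value, dict):
            unlock_(value, visited)

td = {"_locked": True}
td["self"] = td                      # auto-nested
unlock_(td)                          # terminates despite the self-reference
assert td["_locked"] is False
```

The same visited-set pattern applies to any in-place recursive traversal (lock, unlock, apply), not just repr.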

@vmoens
Contributor

vmoens commented Jan 5, 2023

We can break the tests, I'd rather have that and solve the bug. Otherwise it's hard to see what the problem is.

@Zooll
Author

Zooll commented Jan 6, 2023

> We can break the tests, I'd rather have that and solve the bug. Otherwise it's hard to see what the problem is.

I've added auto_nested_td to the tests.

@Zooll Zooll requested review from vmoens and removed request for tcbegley January 11, 2023 11:10
@vmoens
Contributor

vmoens commented Jan 11, 2023

Some benchmarking first:
On this branch, common_ops:

TensorDict({}, [3, 4]) 0.054020772999999966
TensorDict({'a': a, 'b': b}, [3, 4]) 0.10693320900000014
TensorDict({'a': a, ('b', 'b1'): b}, [3, 4]) 0.26566142599999987
TensorDict({'a': a, 'b': {'b1': b}}, [3, 4]) 0.19251463799999957
tdc = td.clone() 0.23632501300000008
tdc = td.clone();tdc['c'] = c 0.2863006590000001
tdc = td.clone();tdc.set('c', c) 0.2764215059999984
tdc = td.clone();tdc.share_memory_() 1.781234584
tdc = td.clone();tdc.update(td2) 0.41446814300000057
tdc = td.clone();tdc.update(td2) 0.6884915679999999
tdc = td.clone();tdc['b', 'b1'] = b 0.3841283730000029
tdc = td.clone();tdc['c', 'c', 'c'] = c 0.6221515459999978

On main

TensorDict({}, [3, 4]) 0.06562997599999987
TensorDict({'a': a, 'b': b}, [3, 4]) 0.1289677100000004
TensorDict({'a': a, ('b', 'b1'): b}, [3, 4]) 0.29692919799999995
TensorDict({'a': a, 'b': {'b1': b}}, [3, 4]) 0.24163629199999992
tdc = td.clone() 0.26948069800000063
tdc = td.clone();tdc['c'] = c 0.3178999400000002
tdc = td.clone();tdc.set('c', c) 0.3091379649999997
tdc = td.clone();tdc.share_memory_() 1.806307208999998
tdc = td.clone();tdc.update(td2) 0.42458679399999966
tdc = td.clone();tdc.update(td2) 0.6898641609999991
tdc = td.clone();tdc['b', 'b1'] = b 0.4234232099999993
tdc = td.clone();tdc['c', 'c', 'c'] = c 0.6865369579999978

So we're good on that side.

Functional:

instantiation, functorch: 2.3080958469999997
instantiation, tensordict: 1.3075678029999995
exec, functorch: 4.814458004
exec, tensordict: 5.19893972

Other benchmarks seem to be roughly similar.

@vmoens
Contributor

vmoens commented Jan 11, 2023

@Zooll I'm on it. Thanks for your hard work!

I tried to solve to_tensordict(), but it's a hard one. I need to sleep on it, because all the solutions I can think of are either completely wrong or way too complex.

I'll keep you posted!
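For context, there is a standard-library precedent for the hard part (an analogy only, not the PR's solution): copy.deepcopy handles self-referential structures by threading a memo dict through the recursion, so each container is copied exactly once and the copy's cycle points into the copy.

```python
import copy

d = {"a": 1}
d["self"] = d            # self-referential, like an auto-nested tensordict

c = copy.deepcopy(d)     # terminates despite the cycle, thanks to the memo
assert c is not d
assert c["self"] is c    # the copy's cycle points into the copy itself
```

A to_tensordict()-style deep conversion could plausibly thread a similar memo; whether that fits tensordict's design is exactly the open question here.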

@vmoens
Contributor

vmoens commented Feb 15, 2023

Closing as this has been addressed in #201

@vmoens vmoens closed this Feb 15, 2023
Labels
bug · CLA Signed
4 participants