From 58c4737e5353e0196d4807776f1aab4483e5173d Mon Sep 17 00:00:00 2001
From: Miguel Grinberg
Date: Fri, 30 Aug 2024 16:58:51 +0100
Subject: [PATCH] Added recursive to_dict support to AttrDict

Fixes #1520
---
 elasticsearch_dsl/utils.py                   | 23 ++++++++++++++++++++--
 tests/test_integration/_async/test_search.py | 21 ++++++++++++++++++++
 tests/test_integration/_sync/test_search.py  | 21 ++++++++++++++++++++
 3 files changed, 63 insertions(+), 2 deletions(-)

diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py
index 021afc99..d9727f1e 100644
--- a/elasticsearch_dsl/utils.py
+++ b/elasticsearch_dsl/utils.py
@@ -86,6 +86,22 @@ def _wrap(val: Any, obj_wrapper: Optional[Callable[[Any], Any]] = None) -> Any:
     return val
 
 
+def _recursive_to_dict(value: Any) -> Any:
+    """Recursively convert ``value`` into plain dicts and lists.
+
+    Containers are handled before the generic ``to_dict`` fallback because
+    ``AttrDict`` itself defines ``to_dict``, whose default call is shallow.
+    """
+    if isinstance(value, (dict, AttrDict)):
+        return {k: _recursive_to_dict(v) for k, v in value.items()}
+    elif isinstance(value, (list, AttrList)):
+        return [_recursive_to_dict(elem) for elem in value]
+    elif hasattr(value, "to_dict"):
+        return value.to_dict()
+    else:
+        return value
+
+
 class AttrList(Generic[_ValT]):
     def __init__(
         self, l: List[_ValT], obj_wrapper: Optional[Callable[[_ValT], Any]] = None
@@ -228,8 +244,11 @@ def __setattr__(self, name: str, value: _ValT) -> None:
     def __iter__(self) -> Iterator[str]:
         return iter(self._d_)
 
-    def to_dict(self) -> Dict[str, _ValT]:
-        return self._d_
+    def to_dict(self, recursive: bool = False) -> Dict[str, _ValT]:
+        """Return the underlying dict; ``recursive`` also unwraps nested values."""
+        return cast(
+            Dict[str, _ValT], _recursive_to_dict(self._d_) if recursive else self._d_
+        )
 
     def keys(self) -> Iterable[str]:
         return self._d_.keys()
diff --git a/tests/test_integration/_async/test_search.py b/tests/test_integration/_async/test_search.py
index 11bc8c72..3dfde51b 100644
--- a/tests/test_integration/_async/test_search.py
+++ b/tests/test_integration/_async/test_search.py
@@ -112,6 +112,27 @@ async def test_inner_hits_are_wrapped_in_response(
     )
 
 
+@pytest.mark.asyncio
+async def test_inner_hits_are_serialized_to_dict(
+    async_data_client: AsyncElasticsearch,
+) -> None:
+    s = AsyncSearch(index="git")[0:1].query(
+        "has_parent", parent_type="repo", inner_hits={}, query=Q("match_all")
+    )
+    response = await s.execute()
+    d = response.to_dict(recursive=True)
+    assert isinstance(d, dict)
+    assert isinstance(d["hits"]["hits"][0]["inner_hits"]["repo"], dict)
+
+    # iterating over the results changes the format of the internal AttrDict
+    for hit in response:
+        pass
+
+    d = response.to_dict(recursive=True)
+    assert isinstance(d, dict)
+    assert isinstance(d["hits"]["hits"][0]["inner_hits"]["repo"], dict)
+
+
 @pytest.mark.asyncio
 async def test_scan_respects_doc_types(async_data_client: AsyncElasticsearch) -> None:
     repos = [repo async for repo in Repository.search().scan()]
diff --git a/tests/test_integration/_sync/test_search.py b/tests/test_integration/_sync/test_search.py
index 18ed8566..d4e62016 100644
--- a/tests/test_integration/_sync/test_search.py
+++ b/tests/test_integration/_sync/test_search.py
@@ -104,6 +104,27 @@ def test_inner_hits_are_wrapped_in_response(
     )
 
 
+@pytest.mark.sync
+def test_inner_hits_are_serialized_to_dict(
+    data_client: Elasticsearch,
+) -> None:
+    s = Search(index="git")[0:1].query(
+        "has_parent", parent_type="repo", inner_hits={}, query=Q("match_all")
+    )
+    response = s.execute()
+    d = response.to_dict(recursive=True)
+    assert isinstance(d, dict)
+    assert isinstance(d["hits"]["hits"][0]["inner_hits"]["repo"], dict)
+
+    # iterating over the results changes the format of the internal AttrDict
+    for hit in response:
+        pass
+
+    d = response.to_dict(recursive=True)
+    assert isinstance(d, dict)
+    assert isinstance(d["hits"]["hits"][0]["inner_hits"]["repo"], dict)
+
+
 @pytest.mark.sync
 def test_scan_respects_doc_types(data_client: Elasticsearch) -> None:
     repos = [repo for repo in Repository.search().scan()]