diff --git a/src/maggma/api/query_operator/sorting.py b/src/maggma/api/query_operator/sorting.py index 050f4e7ba..619530303 100644 --- a/src/maggma/api/query_operator/sorting.py +++ b/src/maggma/api/query_operator/sorting.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, List from fastapi import HTTPException, Query @@ -13,18 +13,17 @@ class SortQuery(QueryOperator): def query( self, - sort_field: Optional[str] = Query(None, description="Field to sort with"), - ascending: Optional[bool] = Query(None, description="Whether the sorting should be ascending"), + sort_fields: Optional[List[str]] = Query(None, description="Fields to sort with. Prefixing '-' to a field will\ + force a sort in descending order."), ) -> STORE_PARAMS: sort = {} - if sort_field: - if ascending is not None: - sort.update({sort_field: 1 if ascending else -1}) - else: - raise HTTPException( - status_code=400, detail="Must specify both a field and order for sorting.", - ) + if sort_fields: + for sort_field in sort_fields: + if sort_field[0] == "-": + sort.update({sort_field[1:]: -1}) + else: + sort.update({sort_field: 1}) return {"sort": sort} diff --git a/tests/api/test_query_operators.py b/tests/api/test_query_operators.py index 332f333c6..4b8a6da8a 100644 --- a/tests/api/test_query_operators.py +++ b/tests/api/test_query_operators.py @@ -69,7 +69,7 @@ def test_numeric_query_functionality(): assert op.meta() == {} assert op.query(age_max=10, age_min=1, age_not_eq=[2, 3], weight_min=120) == { - "criteria": {"age": {"$lte": 10, "$gte": 1, "$ne": [2, 3]}, "weight": {"$gte": 120},} + "criteria": {"age": {"$lte": 10, "$gte": 1, "$ne": [2, 3]}, "weight": {"$gte": 120}} } @@ -87,16 +87,7 @@ def test_sort_query_functionality(): op = SortQuery() - assert op.query(sort_field="volume", ascending=True) == {"sort": {"volume": 1}} - assert op.query(sort_field="density", ascending=False) == {"sort": {"density": -1}} - - -@pytest.mark.xfail -def test_sort_error(): - - op = SortQuery() - - op.query(sort_field="volume", ascending=None) + assert op.query(sort_fields=["volume", "-density"]) == {"sort": {"volume": 1, "density": -1}} def test_sort_serialization(): @@ -106,7 +97,7 @@ def test_sort_serialization(): with ScratchDir("."): dumpfn(op, "temp.json") new_op = loadfn("temp.json") - assert new_op.query(sort_field="volume", ascending=True) == {"sort": {"volume": 1}} + assert new_op.query(sort_fields=["volume", "-density"]) == {"sort": {"volume": 1, "density": -1}} @pytest.fixture