Skip to content

Commit

Permalink
Merge pull request #500 from materialsproject/bump_release
Browse files Browse the repository at this point in the history
Update sorting query operator to take multiple fields
  • Loading branch information
munrojm committed Oct 11, 2021
2 parents 25ad553 + aa019ba commit dae3a8f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 22 deletions.
19 changes: 9 additions & 10 deletions src/maggma/api/query_operator/sorting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, List

from fastapi import HTTPException, Query

Expand All @@ -13,18 +13,17 @@ class SortQuery(QueryOperator):

def query(
self,
sort_field: Optional[str] = Query(None, description="Field to sort with"),
ascending: Optional[bool] = Query(None, description="Whether the sorting should be ascending"),
sort_fields: Optional[List[str]] = Query(None, description="Fields to sort with. Prefixing '-' to a field will\
force a sort in descending order."),
) -> STORE_PARAMS:

sort = {}

if sort_field:
if ascending is not None:
sort.update({sort_field: 1 if ascending else -1})
else:
raise HTTPException(
status_code=400, detail="Must specify both a field and order for sorting.",
)
if sort_fields:
for sort_field in sort_fields:
if sort_field[0] == "-":
sort.update({sort_field[1:]: -1})
else:
sort.update({sort_field: 1})

return {"sort": sort}
15 changes: 3 additions & 12 deletions tests/api/test_query_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_numeric_query_functionality():

assert op.meta() == {}
assert op.query(age_max=10, age_min=1, age_not_eq=[2, 3], weight_min=120) == {
"criteria": {"age": {"$lte": 10, "$gte": 1, "$ne": [2, 3]}, "weight": {"$gte": 120},}
"criteria": {"age": {"$lte": 10, "$gte": 1, "$ne": [2, 3]}, "weight": {"$gte": 120}}
}


Expand All @@ -87,16 +87,7 @@ def test_sort_query_functionality():

op = SortQuery()

assert op.query(sort_field="volume", ascending=True) == {"sort": {"volume": 1}}
assert op.query(sort_field="density", ascending=False) == {"sort": {"density": -1}}


@pytest.mark.xfail
def test_sort_error():

op = SortQuery()

op.query(sort_field="volume", ascending=None)
assert op.query(sort_fields=["volume", "-density"]) == {"sort": {"volume": 1, "density": -1}}


def test_sort_serialization():
Expand All @@ -106,7 +97,7 @@ def test_sort_serialization():
with ScratchDir("."):
dumpfn(op, "temp.json")
new_op = loadfn("temp.json")
assert new_op.query(sort_field="volume", ascending=True) == {"sort": {"volume": 1}}
assert new_op.query(sort_fields=["volume", "-density"]) == {"sort": {"volume": 1, "density": -1}}


@pytest.fixture
Expand Down

0 comments on commit dae3a8f

Please sign in to comment.