Skip to content

Commit

Permalink
Merge pull request #691 from materialsproject/enhancement/pymongo_tim…
Browse files Browse the repository at this point in the history
…eout

Add pymongo timeout context to queries
  • Loading branch information
munrojm committed Jul 26, 2022
2 parents 92e7722 + 6c26a28 commit e821b03
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 40 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pymongo==4.0.2
pymongo==4.2.0
monty==2022.4.26
mongomock==4.0.0
pydash==5.1.0
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
include_package_data=True,
install_requires=[
"setuptools",
"pymongo>=4.0",
"pymongo>=4.2.0",
"monty>=1.0.2",
"mongomock>=3.10.0",
"pydash>=4.1.0",
Expand Down
24 changes: 18 additions & 6 deletions src/maggma/api/resource/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from fastapi import HTTPException, Response, Request
from pydantic import BaseModel
from pymongo import timeout as query_timeout
from pymongo.errors import NetworkTimeout, PyMongoError

from maggma.api.models import Meta
from maggma.api.models import Response as ResponseModel
Expand All @@ -22,6 +24,7 @@ def __init__(
store: Store,
model: Type[BaseModel],
pipeline_query_operator: QueryOperator,
timeout: Optional[int] = None,
tags: Optional[List[str]] = None,
include_in_schema: Optional[bool] = True,
sub_path: Optional[str] = "/",
Expand All @@ -33,6 +36,8 @@ def __init__(
model: The pydantic model this Resource represents
tags: List of tags for the Endpoint
pipeline_query_operator: Operator for the aggregation pipeline
timeout: Time in seconds Pymongo should wait when querying MongoDB
before raising a timeout error
include_in_schema: Whether the endpoint should be shown in the documented schema.
sub_path: sub-URL path for the resource.
"""
Expand All @@ -45,6 +50,7 @@ def __init__(

self.pipeline_query_operator = pipeline_query_operator
self.header_processor = header_processor
self.timeout = timeout

super().__init__(model)

Expand All @@ -69,12 +75,18 @@ async def search(**queries: Dict[str, STORE_PARAMS]) -> Dict:
self.store.connect()

try:
data = list(self.store._collection.aggregate(query["pipeline"]))
except Exception:
raise HTTPException(
status_code=400,
detail="Problem with provided aggregation pipeline.",
)
with query_timeout(self.timeout):
data = list(self.store._collection.aggregate(query["pipeline"]))
except (NetworkTimeout, PyMongoError) as e:
if e.timeout:
raise HTTPException(
status_code=504,
detail="Server timed out trying to obtain data. Try again with a smaller request.",
)
else:
raise HTTPException(
status_code=500,
)

count = len(data)

Expand Down
21 changes: 19 additions & 2 deletions src/maggma/api/resource/post_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from fastapi import HTTPException, Request
from pydantic import BaseModel
from pymongo import timeout as query_timeout
from pymongo.errors import NetworkTimeout, PyMongoError

from maggma.api.models import Meta, Response
from maggma.api.query_operator import PaginationQuery, QueryOperator, SparseFieldsQuery
Expand All @@ -25,6 +27,7 @@ def __init__(
query_operators: Optional[List[QueryOperator]] = None,
key_fields: Optional[List[str]] = None,
query: Optional[Dict] = None,
timeout: Optional[int] = None,
include_in_schema: Optional[bool] = True,
sub_path: Optional[str] = "/",
):
Expand All @@ -36,6 +39,8 @@ def __init__(
query_operators: Operators for the query language
key_fields: List of fields to always project. Default uses SparseFieldsQuery
to allow user to define these on-the-fly.
timeout: Time in seconds Pymongo should wait when querying MongoDB
before raising a timeout error
include_in_schema: Whether the endpoint should be shown in the documented schema.
sub_path: sub-URL path for the resource.
"""
Expand All @@ -44,6 +49,7 @@ def __init__(
self.query = query or {}
self.key_fields = key_fields
self.versioned = False
self.timeout = timeout

self.include_in_schema = include_in_schema
self.sub_path = sub_path
Expand Down Expand Up @@ -101,8 +107,19 @@ async def search(**queries: Dict[str, STORE_PARAMS]) -> Dict:

self.store.connect()

count = self.store.count(query["criteria"])
data = list(self.store.query(**query))
try:
with query_timeout(self.timeout):
count = self.store.count(query["criteria"])
data = list(self.store.query(**query))
except (NetworkTimeout, PyMongoError) as e:
if e.timeout:
raise HTTPException(
status_code=504,
detail="Server timed out trying to obtain data. Try again with a smaller request.",
)
else:
raise HTTPException(status_code=500)

operator_meta = {}

for operator in self.query_operators:
Expand Down
57 changes: 44 additions & 13 deletions src/maggma/api/resource/read_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from fastapi import Depends, HTTPException, Path, Request
from fastapi import Response
from pydantic import BaseModel
from pymongo import timeout as query_timeout
from pymongo.errors import NetworkTimeout, PyMongoError

from maggma.api.models import Meta
from maggma.api.models import Response as ResponseModel
Expand Down Expand Up @@ -33,6 +35,7 @@ def __init__(
key_fields: Optional[List[str]] = None,
hint_scheme: Optional[HintScheme] = None,
header_processor: Optional[HeaderProcessor] = None,
timeout: Optional[int] = None,
enable_get_by_key: bool = True,
enable_default_search: bool = True,
disable_validation: bool = False,
Expand All @@ -47,6 +50,8 @@ def __init__(
query_operators: Operators for the query language
hint_scheme: The hint scheme to use for this resource
header_processor: The header processor to use for this resource
timeout: Time in seconds Pymongo should wait when querying MongoDB
before raising a timeout error
key_fields: List of fields to always project. Default uses SparseFieldsQuery
to allow user to define these on-the-fly.
enable_get_by_key: Enable default key route for endpoint.
Expand All @@ -65,6 +70,7 @@ def __init__(
self.versioned = False
self.enable_get_by_key = enable_get_by_key
self.enable_default_search = enable_default_search
self.timeout = timeout
self.disable_validation = disable_validation
self.include_in_schema = include_in_schema
self.sub_path = sub_path
Expand Down Expand Up @@ -132,11 +138,24 @@ async def get_by_key(
"""
self.store.connect()

item = [
self.store.query_one(
criteria={self.store.key: key}, properties=_fields["properties"],
)
]
try:
with query_timeout(self.timeout):
item = [
self.store.query_one(
criteria={self.store.key: key},
properties=_fields["properties"],
)
]
except (NetworkTimeout, PyMongoError) as e:
if e.timeout:
raise HTTPException(
status_code=504,
detail="Server timed out trying to obtain data. Try again with a smaller request.",
)
else:
raise HTTPException(
status_code=500,
)

if item == [None]:
raise HTTPException(
Expand Down Expand Up @@ -209,16 +228,28 @@ async def search(**queries: Dict[str, STORE_PARAMS]) -> Union[Dict, Response]:
query.update(hints)

self.store.connect()
try:
with query_timeout(self.timeout):
count = self.store.count(
**{
field: query[field]
for field in query
if field in ["criteria", "hint"]
}
)

count = self.store.count(
**{
field: query[field]
for field in query
if field in ["criteria", "hint"]
}
)
data = list(self.store.query(**query))
except (NetworkTimeout, PyMongoError) as e:
if e.timeout:
raise HTTPException(
status_code=504,
detail="Server timed out trying to obtain data. Try again with a smaller request.",
)
else:
raise HTTPException(
status_code=500,
)

data = list(self.store.query(**query))
operator_meta = {}

for operator in self.query_operators:
Expand Down
34 changes: 30 additions & 4 deletions src/maggma/api/resource/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from fastapi import HTTPException, Path, Request
from pydantic import BaseModel, Field, create_model
from pymongo import timeout as query_timeout
from pymongo.errors import NetworkTimeout, PyMongoError

from maggma.api.models import Meta, Response
from maggma.api.query_operator import QueryOperator, SubmissionQuery
Expand All @@ -28,6 +30,7 @@ def __init__(
post_query_operators: List[QueryOperator],
get_query_operators: List[QueryOperator],
tags: Optional[List[str]] = None,
timeout: Optional[int] = None,
include_in_schema: Optional[bool] = True,
duplicate_fields_check: Optional[List[str]] = None,
enable_default_search: Optional[bool] = True,
Expand All @@ -42,6 +45,8 @@ def __init__(
store: The Maggma Store to get data from
model: The pydantic model this resource represents
tags: List of tags for the Endpoint
timeout: Time in seconds Pymongo should wait when querying MongoDB
before raising a timeout error
post_query_operators: Operators for the query language for post data
get_query_operators: Operators for the query language for get data
include_in_schema: Whether to include the submission resource in the documented schema
Expand All @@ -66,6 +71,7 @@ def __init__(
self.default_state = default_state
self.store = store
self.tags = tags or []
self.timeout = timeout
self.post_query_operators = post_query_operators
self.get_query_operators = (
[op for op in get_query_operators if op is not None] # type: ignore
Expand Down Expand Up @@ -143,8 +149,17 @@ async def get_by_key(
self.store.connect()

crit = {key_name: key}

item = [self.store.query_one(criteria=crit)]
try:
with query_timeout(self.timeout):
item = [self.store.query_one(criteria=crit)]
except (NetworkTimeout, PyMongoError) as e:
if e.timeout:
raise HTTPException(
status_code=504,
detail="Server timed out trying to obtain data. Try again with a smaller request.",
)
else:
raise HTTPException(status_code=500)

if item == [None]:
raise HTTPException(
Expand Down Expand Up @@ -198,8 +213,19 @@ async def search(**queries: STORE_PARAMS):

self.store.connect(force_reset=True)

count = self.store.count(query["criteria"])
data = list(self.store.query(**query)) # type: ignore
try:
with query_timeout(self.timeout):
count = self.store.count(query["criteria"])
data = list(self.store.query(**query)) # type: ignore
except (NetworkTimeout, PyMongoError) as e:
if e.timeout:
raise HTTPException(
status_code=504,
detail="Server timed out trying to obtain data. Try again with a smaller request.",
)
else:
raise HTTPException(status_code=500,)

meta = Meta(total_doc=count)

for operator in self.get_query_operators: # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion src/maggma/stores/compound_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def connect(self, force_reset: bool = False):
Args:
force_reset: whether to forcibly reset the connection
"""
conn = (
conn: MongoClient = (
MongoClient(
host=self.host,
port=self.port,
Expand Down
9 changes: 3 additions & 6 deletions src/maggma/stores/gridfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def connect(self, force_reset: bool = False):
"""
Connect to the source data
"""
conn = (
conn: MongoClient = (
MongoClient(
host=self.host,
port=self.port,
Expand Down Expand Up @@ -238,10 +238,7 @@ def query(
metadata = doc.get("metadata", {})

data = self._collection.find_one(
filter={"_id": doc["_id"]},
skip=skip,
limit=limit,
sort=sort,
filter={"_id": doc["_id"]}, skip=skip, limit=limit, sort=sort,
).read()

if metadata.get("compression", "") == "zlib":
Expand Down Expand Up @@ -511,7 +508,7 @@ def connect(self, force_reset: bool = False):
Connect to the source data
"""
if not self._coll or force_reset: # pragma: no cover
conn = MongoClient(self.uri, **self.mongoclient_kwargs)
conn: MongoClient = MongoClient(self.uri, **self.mongoclient_kwargs)
db = conn[self.database]
self._coll = gridfs.GridFS(db, self.collection_name)
self._files_collection = db["{}.files".format(self.collection_name)]
Expand Down
9 changes: 3 additions & 6 deletions src/maggma/stores/mongolike.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def connect(self, force_reset: bool = False):
self.ssh_tunnel.start()
host, port = self.ssh_tunnel.local_address

conn = (
conn: MongoClient = (
MongoClient(
host=host,
port=port,
Expand Down Expand Up @@ -569,7 +569,7 @@ def connect(self, force_reset: bool = False):
Connect to the source data
"""
if self._coll is None or force_reset: # pragma: no cover
conn = MongoClient(self.uri, **self.mongoclient_kwargs)
conn: MongoClient = MongoClient(self.uri, **self.mongoclient_kwargs)
db = conn[self.database]
self._coll = db[self.collection_name]

Expand Down Expand Up @@ -677,10 +677,7 @@ class JSONStore(MemoryStore):
"""

def __init__(
self,
paths: Union[str, List[str]],
read_only: bool = True,
**kwargs,
self, paths: Union[str, List[str]], read_only: bool = True, **kwargs,
):
"""
Args:
Expand Down

0 comments on commit e821b03

Please sign in to comment.