diff --git a/setup.py b/setup.py index 0dd2f9eca..c12fd347d 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ "montydb": ["montydb>=2.3.12"], "notebook_runner": ["IPython>=8.11", "nbformat>=5.0", "regex>=2020.6"], "azure": ["azure-storage-blob>=12.16.0", "azure-identity>=1.12.0"], - "open_data": ["pandas>=2.1.4"], + "open_data": ["pandas>=2.1.4", "jsonlines>=4.0.0"], "testing": [ "pytest", "pytest-cov", diff --git a/src/maggma/stores/open_data.py b/src/maggma/stores/open_data.py index 5104424d0..476a178bc 100644 --- a/src/maggma/stores/open_data.py +++ b/src/maggma/stores/open_data.py @@ -1,9 +1,9 @@ import gzip from datetime import datetime -from io import BytesIO +from io import BytesIO, StringIO from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union -import orjson +import jsonlines import pandas as pd from boto3 import Session from botocore import UNSIGNED @@ -31,7 +31,7 @@ def __init__( @property def _collection(self): """ - Returns a handle to the pymongo collection object + Returns a handle to the pymongo collection object. Raises: NotImplementedError: always as this concept does not make sense for this type of store @@ -41,7 +41,7 @@ def _collection(self): @property def name(self) -> str: """ - Return a string representing this data source + Return a string representing this data source. """ return "imem://" @@ -151,6 +151,15 @@ def query( ret = self._query(criteria=criteria, properties=properties, sort=sort, skip=skip, limit=limit) return (row.to_dict() for _, row in ret.iterrows()) + @staticmethod + def add_missing_items(to_dt: pd.DataFrame, from_dt: pd.DataFrame, on: List[str]) -> pd.DataFrame: + orig_columns = to_dt.columns + merged = to_dt.merge(from_dt, on=on, how="left", suffixes=("", "_B")) + for column in from_dt.columns: + if column not in on: + merged[column].update(merged.pop(column + "_B")) + return merged[orig_columns] + def update(self, docs: Union[List[Dict], Dict], key: Union[List, str, None] = None, clear_first: bool = False): """ Update documents into the Store @@ -164,20 +173,17 @@ def update(self, docs: Union[List[Dict], Dict], key: Union[List, str, None] = No clear_first: if True clears the underlying data first, fully replacing the data with docs; if False performs an upsert based on the parameters """ + if key is not None: + raise NotImplementedError("updating store based on a key different than the store key is not supported") + df = pd.DataFrame(docs) if self._data is None or clear_first: if not df.empty: self._data = df return + key = [self.key] - if key is None: - key = self.key - - merged = self._data.merge(df, on=key, how="left", suffixes=("", "_B")) - for column in df.columns: - if column not in key: - merged[column] = merged[column + "_B"].combine_first(merged[column]) - merged = merged[self._data.columns] + merged = PandasMemoryStore.add_missing_items(to_dt=self._data, from_dt=df, on=key) non_matching = df[~df.set_index(key).index.isin(self._data.set_index(key).index)] self._data = pd.concat([merged, non_matching], ignore_index=True) @@ -372,7 +378,15 @@ def _retrieve_manifest(self) -> pd.DataFrame: """ try: response = self.client.get_object(Bucket=self.bucket, Key=self._get_full_key_path()) - return pd.read_json(response["Body"], orient="records") + df = pd.read_json(response["Body"], orient="records", lines=True) + if self.last_updated_field in df.columns: + df[self.last_updated_field] = df[self.last_updated_field].apply( + lambda x: datetime.fromisoformat(x["$date"].rstrip("Z")) + if isinstance(x, dict) and "$date" in 
x + else x + ) + return df + except ClientError as ex: if ex.response["Error"]["Code"] == "NoSuchKey": return [] @@ -386,7 +400,7 @@ def _load_index(self, force_reset: bool = False) -> None: """ super().update(self._retrieve_manifest(), clear_first=True) - def store_manifest(self, data: List[Dict]) -> None: + def store_manifest(self, data: pd.DataFrame) -> None: """Stores the provided data into the index stored in S3. This overwrites and fully replaces all of the contents of the previous index stored in S3. It also rewrites the memory index with the provided data. @@ -394,9 +408,14 @@ def store_manifest(self, data: List[Dict]) -> None: Args: data (List[Dict]): The data to store in the index. """ + string_io = StringIO() + with jsonlines.Writer(string_io, dumps=json_util.dumps) as writer: + for _, row in data.iterrows(): + writer.write(row.to_dict()) + self.client.put_object( Bucket=self.bucket, - Body=orjson.dumps(data, default=json_util.default), + Body=BytesIO(string_io.getvalue().encode("utf-8")), Key=self._get_full_key_path(), ) super().update(data, clear_first=True) @@ -458,8 +477,9 @@ def __init__( sub_dir: Optional[str] = None, key: str = "fs_id", searchable_fields: Optional[List[str]] = None, - object_file_extension: str = ".json.gz", + object_file_extension: str = ".jsonl.gz", access_as_public_bucket: bool = False, + object_grouping: Optional[List[str]] = None, **kwargs, ): """Initializes an OpenDataStore @@ -481,10 +501,11 @@ def __init__( self.searchable_fields = searchable_fields if searchable_fields is not None else [] self.object_file_extension = object_file_extension self.access_as_public_bucket = access_as_public_bucket + self.object_grouping = object_grouping if object_grouping is not None else ["nelements", "symmetry_number"] + if access_as_public_bucket: kwargs["s3_resource_kwargs"] = kwargs["s3_resource_kwargs"] if "s3_resource_kwargs" in kwargs else {} kwargs["s3_resource_kwargs"]["config"] = Config(signature_version=UNSIGNED) - kwargs["index"] = index kwargs["bucket"] = bucket kwargs["compress"] = compress @@ -495,16 +516,71 @@ def __init__( kwargs["unpack_data"] = True self.kwargs = kwargs super().__init__(**kwargs) + self.searchable_fields = list( + set(self.object_grouping) | set(self.searchable_fields) | {self.key, self.last_updated_field} + ) + + def query( + self, + criteria: Optional[Dict] = None, + properties: Union[Dict, List, None] = None, + sort: Optional[Dict[str, Union[Sort, int]]] = None, + skip: int = 0, + limit: int = 0, + ) -> Iterator[Dict]: + """ + Queries the Store for a set of documents. + + Args: + criteria: PyMongo filter for documents to search in. + properties: properties to return in grouped documents. + sort: Dictionary of sort order for fields. Keys are field names and values + are 1 for ascending or -1 for descending. + skip: number documents to skip. + limit: limit on total number of documents returned. 
+ + """ + prop_keys = set() + if isinstance(properties, dict): + prop_keys = set(properties.keys()) + elif isinstance(properties, list): + prop_keys = set(properties) + + for _, docs in self.index.groupby( + keys=self.object_grouping, criteria=criteria, sort=sort, limit=limit, skip=skip + ): + group_doc = None # S3 backed group doc + for _, doc in docs.iterrows(): + data = doc + if properties is None or not prop_keys.issubset(set(doc.keys())): + if not group_doc: + group_doc = self._read_doc_from_s3(self._get_full_key_path(docs)) + if group_doc.empty: + continue + data = group_doc.query(f"{self.key} == '{doc[self.key]}'") + data = data.to_dict(orient="index")[0] + if properties is None: + yield data + else: + yield {p: data[p] for p in prop_keys if p in data} + + def _read_doc_from_s3(self, file_id: str) -> pd.DataFrame: + try: + response = self.s3_bucket.Object(file_id).get() + return pd.read_json(response["Body"], orient="records", lines=True, compression={"method": "gzip"}) + except ClientError as ex: + if ex.response["Error"]["Code"] == "NoSuchKey": + return pd.DataFrame() + raise def _get_full_key_path(self, id: str) -> str: - if self.index.collection_name == "thermo" and self.key == "thermo_id": - material_id, thermo_type = id.split("_", 1) - return f"{self.sub_dir}{thermo_type}/{material_id}{self.object_file_extension}" - if self.index.collection_name == "xas" and self.key == "spectrum_id": - material_id, spectrum_type, absorbing_element, edge = id.rsplit("-", 3) - return f"{self.sub_dir}{edge}/{spectrum_type}/{absorbing_element}/{material_id}{self.object_file_extension}" - if self.index.collection_name == "synth_descriptions" and self.key == "doi": - return f"{self.sub_dir}{id.replace('/', '_')}{self.object_file_extension}" + raise NotImplementedError("Not implemented for this store") + + def _get_full_key_path(self, index: pd.DataFrame) -> str: + id = "" + for group in self.object_grouping: + id = f"{id}{group}={index[group].iloc[0]}/" + id = id.rstrip("/") return f"{self.sub_dir}{id}{self.object_file_extension}" def _get_compression_function(self) -> Callable: @@ -513,36 +589,83 @@ def _get_compression_function(self) -> Callable: def _get_decompression_function(self) -> Callable: return gzip.decompress - def _read_data(self, data: bytes, compress_header: str = "gzip") -> List[Dict]: - if compress_header is not None: - data = self._get_decompression_function()(data) - return orjson.loads(data) + def _gather_indexable_data(self, df: pd.DataFrame, search_keys: List[str]) -> pd.DataFrame: + return df[search_keys] - def _gather_indexable_data(self, doc: Dict, search_keys: List[str]) -> Dict: - index_doc = {k: doc[k] for k in search_keys} - index_doc[self.key] = doc[self.key] # Ensure key is in metadata - # Ensure last updated field is in metada if it's present in the data - if self.last_updated_field in doc: - index_doc[self.last_updated_field] = doc[self.last_updated_field] - return index_doc + def update( + self, + docs: Union[List[Dict], Dict], + key: Union[List, str, None] = None, + additional_metadata: Union[str, List[str], None] = None, + ): + if additional_metadata is not None: + raise NotImplementedError("updating store with additional metadata is not supported") + super().update(docs=docs, key=key) + + def _write_to_s3_and_index(self, docs: List[Dict], search_keys: List[str]): + """Implements updating of the provided documents in S3 and the index. 
+ + Args: + docs (List[Dict]): The documents to update + search_keys (List[str]): The keys of the information to be updated in the index + """ + # group docs to update by object grouping + og = list(set(self.object_grouping) | set(search_keys)) + df = pd.json_normalize(docs, sep="_")[og] + df_grouped = df.groupby(self.object_grouping) + existing = self.index._data + docs_df = pd.DataFrame(docs) + for group, _ in df_grouped: + query_str = " and ".join([f"{col} == {val!r}" for col, val in zip(self.object_grouping, group)]) + sub_df = df.query(query_str) + sub_docs_df = docs_df[docs_df[self.key].isin(sub_df[self.key].unique())] + merged_df = sub_df + if existing is not None: + # fetch subsection of existing and docs_df and do outer merge with indicator=True + sub_existing = existing.query(query_str) + merged_df = sub_existing.merge(sub_df, on=og, how="outer", indicator=True) + # if there's any rows in existing only + if not merged_df[merged_df["_merge"] == "left_only"].empty: + ## fetch the S3 data and populate those rows in sub_docs_df + s3_df = self._read_doc_from_s3(self._get_full_key_path(sub_existing)) + # sub_docs + sub_docs_df = sub_docs_df.merge(merged_df[[self.key, "_merge"]], on=self.key, how="right") + sub_docs_df.update(s3_df, overwrite=False) + sub_docs_df = sub_docs_df.drop("_merge", axis=1) + + merged_df = merged_df.drop("_merge", axis=1) + # write doc based on subsection + self._write_doc_and_update_index(sub_docs_df, merged_df) + + def _write_doc_and_update_index(self, items: pd.DataFrame, index: pd.DataFrame) -> None: + self.write_doc_to_s3(items, index) + self.index.update(index) + + def write_doc_to_s3(self, doc, search_keys): + if not isinstance(doc, pd.DataFrame): + raise NotImplementedError("doc parameter must be a Pandas DataFrame for the implementation for this store") + if not isinstance(search_keys, pd.DataFrame): + raise NotImplementedError( + "search_keys parameter must be a Pandas DataFrame for the implementation for this store" + ) + # def write_doc_to_s3(self, items: pd.DataFrame, index: pd.DataFrame) -> None: + string_io = StringIO() + with jsonlines.Writer(string_io, dumps=json_util.dumps) as writer: + for _, row in doc.iterrows(): + writer.write(row.to_dict()) - def write_doc_to_s3(self, doc: Dict, search_keys: List[str]) -> Dict: - search_doc = self._gather_indexable_data(doc, search_keys) + data = self._get_compression_function()(string_io.getvalue().encode("utf-8")) - data = orjson.dumps(doc, default=json_util.default) - data = self._get_compression_function()(data) self._get_bucket().upload_fileobj( Fileobj=BytesIO(data), - Key=self._get_full_key_path(str(doc[self.key])), + Key=self._get_full_key_path(search_keys), ) - return search_doc - def _index_for_doc_from_s3(self, bucket, key: str) -> Dict: - response = bucket.Object(key).get() - doc = self._read_data(response["Body"].read()) + def _index_for_doc_from_s3(self, key: str) -> pd.DataFrame: + doc = self._read_doc_from_s3(key) return self._gather_indexable_data(doc, self.searchable_fields) - def rebuild_index_from_s3_data(self) -> List[Dict]: + def rebuild_index_from_s3_data(self) -> pd.DataFrame: """ Rebuilds the index Store from the data in S3 Stores only the key, last_updated_field and searchable_fields in the index. 
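For readers of the write path above, here is a minimal, self-contained sketch of the gzipped JSON Lines round trip that write_doc_to_s3 and _read_doc_from_s3 rely on; the sample rows and local variable names are illustrative only and are not part of this diff.

import gzip
from io import BytesIO, StringIO

import jsonlines
import pandas as pd
from bson import json_util

docs = pd.DataFrame(
    [
        {"task_id": "mp-1", "data": "asd"},
        {"task_id": "mp-3", "data": "sdf"},
    ]
)

# Serialize one document per line; json_util.dumps handles values (such as
# datetime) that plain json serialization would reject.
string_io = StringIO()
with jsonlines.Writer(string_io, dumps=json_util.dumps) as writer:
    for _, row in docs.iterrows():
        writer.write(row.to_dict())
blob = gzip.compress(string_io.getvalue().encode("utf-8"))

# Read it back the same way _read_doc_from_s3 reads the S3 response body.
round_tripped = pd.read_json(
    BytesIO(blob), orient="records", lines=True, compression={"method": "gzip"}
)
assert round_tripped.equals(docs)

The same jsonlines.Writer pattern is what store_manifest uses for the index manifest, just without the gzip step.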
@@ -561,12 +684,12 @@ def rebuild_index_from_s3_data(self) -> List[Dict]: for file in page["Contents"]: key = file["Key"] if key != self.index._get_full_key_path(): - index_doc = self._index_for_doc_from_s3(bucket, key) - all_index_docs.append(index_doc) - self.index.store_manifest(all_index_docs) - return all_index_docs + all_index_docs.append(self._index_for_doc_from_s3(key)) + ret = pd.concat(all_index_docs, ignore_index=True) + self.index.store_manifest(ret) + return ret - def rebuild_index_from_data(self, docs: List[Dict]) -> List[Dict]: + def rebuild_index_from_data(self, docs: pd.DataFrame) -> pd.DataFrame: """ Rebuilds the index Store from the provided data. The provided data needs to include all of the documents in this data set. @@ -578,10 +701,7 @@ def rebuild_index_from_data(self, docs: List[Dict]) -> List[Dict]: Returns: List[Dict]: The set of docs representing the index data. """ - all_index_docs = [] - for doc in docs: - index_doc = self._gather_indexable_data(doc, self.searchable_fields) - all_index_docs.append(index_doc) + all_index_docs = self._gather_indexable_data(docs, self.searchable_fields) self.index.store_manifest(all_index_docs) return all_index_docs @@ -593,6 +713,7 @@ def __hash__(self): self.endpoint_url, self.key, self.sub_dir, + tuple(self.object_grouping), ) ) @@ -614,5 +735,6 @@ def __eq__(self, other: object) -> bool: "searchable_fields", "sub_dir", "last_updated_field", + "object_grouping", ] return all(getattr(self, f) == getattr(other, f) for f in fields) diff --git a/tests/stores/test_open_data.py b/tests/stores/test_open_data.py index 383c51926..176484bce 100644 --- a/tests/stores/test_open_data.py +++ b/tests/stores/test_open_data.py @@ -1,9 +1,10 @@ import pickle -import time from datetime import datetime +from io import BytesIO, StringIO import boto3 -import orjson +import jsonlines +import pandas as pd import pytest from botocore.exceptions import ClientError from bson import json_util @@ -26,7 +27,7 @@ def s3store(): conn.create_bucket(Bucket="bucket1") index = S3IndexStore(collection_name="index", bucket="bucket1", key="task_id") - store = OpenDataStore(index=index, bucket="bucket1", key="task_id") + store = OpenDataStore(index=index, bucket="bucket1", key="task_id", object_grouping=["task_id"]) store.connect() store.update( @@ -58,20 +59,9 @@ def s3store_w_subdir(): conn.create_bucket(Bucket="bucket1") index = S3IndexStore(collection_name="index", bucket="bucket1", key="task_id") - store = OpenDataStore(index=index, bucket="bucket1", key="task_id", sub_dir="subdir1", s3_workers=1) - store.connect() - - yield store - - -@pytest.fixture() -def s3store_multi(): - with mock_s3(): - conn = boto3.resource("s3", region_name="us-east-1") - conn.create_bucket(Bucket="bucket1") - - index = S3IndexStore(collection_name="index", bucket="bucket1", key="task_id") - store = OpenDataStore(index=index, bucket="bucket1", key="task_id", s3_workers=4) + store = OpenDataStore( + index=index, bucket="bucket1", key="task_id", sub_dir="subdir1", s3_workers=1, object_grouping=["task_id"] + ) store.connect() yield store @@ -84,9 +74,13 @@ def s3indexstore(): conn = boto3.resource("s3", region_name="us-east-1") conn.create_bucket(Bucket="bucket1") client = boto3.client("s3", region_name="us-east-1") + string_io = StringIO() + with jsonlines.Writer(string_io, dumps=json_util.dumps) as writer: + for _, row in pd.DataFrame(data).iterrows(): + writer.write(row.to_dict()) client.put_object( Bucket="bucket1", - Body=orjson.dumps(data, default=json_util.default), + 
Body=BytesIO(string_io.getvalue().encode("utf-8")), Key="manifest.json", ) @@ -123,7 +117,7 @@ def test_index_load_manifest(s3indexstore): def test_index_store_manifest(s3indexstore): - data = [{"task_id": "mp-2", "last_updated": datetime.utcnow()}] + data = pd.DataFrame([{"task_id": "mp-2", "last_updated": datetime.utcnow()}]) s3indexstore.store_manifest(data) assert s3indexstore.count() == 1 assert s3indexstore.query_one({"query": "task_id == 'mp-1'"}) is None @@ -139,52 +133,25 @@ def test_keys(): store = OpenDataStore(index=index, bucket="bucket1", s3_workers=4, key=1) index = S3IndexStore(collection_name="test", bucket="bucket1", key="key1") with pytest.warns(UserWarning, match=r"The desired S3Store.*$"): - store = OpenDataStore(index=index, bucket="bucket1", s3_workers=4, key="key2") + store = OpenDataStore(index=index, bucket="bucket1", s3_workers=4, key="key2", object_grouping=["key1"]) store.connect() - store.update({"key1": "mp-1", "data": "1234"}) + store.update({"key1": "mp-1", "data": "1234", store.last_updated_field: datetime.utcnow()}) with pytest.raises(KeyError): store.update({"key2": "mp-2", "data": "1234"}) assert store.key == store.index.key == "key1" -def test_multi_update(s3store, s3store_multi): - data = [ - { - "task_id": str(j), - "data": "DATA", - s3store_multi.last_updated_field: datetime.utcnow(), - } - for j in range(32) - ] - - def fake_writing(doc, search_keys): - time.sleep(0.20) - return {k: doc[k] for k in search_keys} - - s3store.write_doc_to_s3 = fake_writing - s3store_multi.write_doc_to_s3 = fake_writing - - start = time.time() - s3store_multi.update(data, key=["task_id"]) - end = time.time() - time_multi = end - start - - start = time.time() - s3store.update(data, key=["task_id"]) - end = time.time() - time_single = end - start - assert time_single > time_multi * (s3store_multi.s3_workers - 1) / (s3store.s3_workers) - - def test_count(s3store): assert s3store.count() == 2 assert s3store.count({"query": "task_id == 'mp-3'"}) == 1 -def test_qeuery(s3store): +def test_query(s3store): assert s3store.query_one(criteria={"query": "task_id == 'mp-2'"}) is None assert s3store.query_one(criteria={"query": "task_id == 'mp-1'"})["data"] == "asd" assert s3store.query_one(criteria={"query": "task_id == 'mp-3'"})["data"] == "sdf" + assert s3store.query_one(criteria={"query": "task_id == 'mp-1'"}, properties=["task_id"])["task_id"] == "mp-1" + assert s3store.query_one(criteria={"query": "task_id == 'mp-1'"}, properties=["task_id", "data"])["data"] == "asd" assert len(list(s3store.query())) == 2 @@ -201,32 +168,31 @@ def test_update(s3store): ) assert s3store.query_one({"query": "task_id == 'mp-199999'"}) is not None - s3store.update([{"task_id": "mp-4", "data": "asd"}]) + mp4 = [{"task_id": "mp-4", "data": "asd", s3store.last_updated_field: datetime.utcnow()}] + s3store.update(mp4) assert s3store.query_one({"query": "task_id == 'mp-4'"})["data"] == "asd" - assert s3store.s3_bucket.Object(s3store._get_full_key_path("mp-4")).key == "mp-4.json.gz" + assert s3store.s3_bucket.Object(s3store._get_full_key_path(pd.DataFrame(mp4))).key == "task_id=mp-4.jsonl.gz" def test_rebuild_index_from_s3_data(s3store): - s3store.update([{"task_id": "mp-2", "data": "asd"}]) + s3store.update([{"task_id": "mp-2", "data": "asd", s3store.last_updated_field: datetime.utcnow()}]) index_docs = s3store.rebuild_index_from_s3_data() assert len(index_docs) == 3 - for doc in index_docs: - for key in doc: - assert key == "task_id" or key == "last_updated" + for key in index_docs.columns: + assert key 
== "task_id" or key == "last_updated" def test_rebuild_index_from_data(s3store): - data = [{"task_id": "mp-2", "data": "asd", "last_updated": datetime.utcnow()}] - index_docs = s3store.rebuild_index_from_data(data) + data = [{"task_id": "mp-2", "data": "asd", s3store.last_updated_field: datetime.utcnow()}] + index_docs = s3store.rebuild_index_from_data(pd.DataFrame(data)) assert len(index_docs) == 1 - for doc in index_docs: - for key in doc: - assert key == "task_id" or key == "last_updated" + for key in index_docs.columns: + assert key == "task_id" or key == "last_updated" def tests_msonable_read_write(s3store, memstore): dd = memstore.as_dict() - s3store.update([{"task_id": "mp-2", "data": dd}]) + s3store.update([{"task_id": "mp-2", "data": dd, s3store.last_updated_field: datetime.utcnow()}]) res = s3store.query_one({"query": "task_id == 'mp-2'"}) assert res["data"]["@module"] == "maggma.stores.open_data" @@ -269,8 +235,12 @@ def test_eq(memstore, s3store): def test_count_subdir(s3store_w_subdir): - s3store_w_subdir.update([{"task_id": "mp-1", "data": "asd"}]) - s3store_w_subdir.update([{"task_id": "mp-2", "data": "asd"}]) + s3store_w_subdir.update( + [{"task_id": "mp-1", "data": "asd", s3store_w_subdir.last_updated_field: datetime.utcnow()}] + ) + s3store_w_subdir.update( + [{"task_id": "mp-2", "data": "asd", s3store_w_subdir.last_updated_field: datetime.utcnow()}] + ) assert s3store_w_subdir.count() == 2 assert s3store_w_subdir.count({"query": "task_id == 'mp-2'"}) == 1 @@ -281,23 +251,15 @@ def objects_in_bucket(key): objs = list(s3store_w_subdir.s3_bucket.objects.filter(Prefix=key)) return key in [o.key for o in objs] - s3store_w_subdir.update([{"task_id": "mp-1", "data": "asd"}]) - s3store_w_subdir.update([{"task_id": "mp-2", "data": "asd"}]) - - assert objects_in_bucket("subdir1/mp-1.json.gz") - assert objects_in_bucket("subdir1/mp-2.json.gz") - - -def test_searchable_fields(s3store): - tic = datetime(2018, 4, 12, 16) - - data = [{"task_id": f"mp-{i}", "a": i, s3store.last_updated_field: tic} for i in range(4)] - - s3store.searchable_fields = ["task_id"] - s3store.update(data, key="a") + s3store_w_subdir.update( + [{"task_id": "mp-1", "data": "asd", s3store_w_subdir.last_updated_field: datetime.utcnow()}] + ) + s3store_w_subdir.update( + [{"task_id": "mp-2", "data": "asd", s3store_w_subdir.last_updated_field: datetime.utcnow()}] + ) - # This should only work if the searchable field was put into the index store - assert set(s3store.distinct("task_id")) == {"mp-0", "mp-1", "mp-2", "mp-3"} + assert objects_in_bucket("subdir1/task_id=mp-1.jsonl.gz") + assert objects_in_bucket("subdir1/task_id=mp-2.jsonl.gz") def test_newer_in(s3store): @@ -309,13 +271,15 @@ def test_newer_in(s3store): conn.create_bucket(Bucket="bucket") index_old = S3IndexStore(collection_name="index_old", bucket="bucket", key="task_id") - old_store = OpenDataStore(index=index_old, bucket="bucket", key="task_id") + old_store = OpenDataStore(index=index_old, bucket="bucket", key="task_id", object_grouping=["task_id"]) old_store.connect() old_store.update([{"task_id": "mp-1", "last_updated": tic}]) old_store.update([{"task_id": "mp-2", "last_updated": tic}]) index_new = S3IndexStore(collection_name="index_new", bucket="bucket", prefix="new", key="task_id") - new_store = OpenDataStore(index=index_new, bucket="bucket", sub_dir="new", key="task_id") + new_store = OpenDataStore( + index=index_new, bucket="bucket", sub_dir="new", key="task_id", object_grouping=["task_id"] + ) new_store.connect() new_store.update([{"task_id": 
"mp-1", "last_updated": tic2}]) new_store.update([{"task_id": "mp-2", "last_updated": tic2}]) @@ -338,10 +302,8 @@ def test_additional_metadata(s3store): data = [{"task_id": f"mp-{i}", "a": i, s3store.last_updated_field: tic} for i in range(4)] - s3store.update(data, key="a", additional_metadata="task_id") - - # This should only work if the searchable field was put into the index store - assert set(s3store.distinct("task_id")) == {"mp-0", "mp-1", "mp-2", "mp-3"} + with pytest.raises(NotImplementedError): + s3store.update(data, key="a", additional_metadata="task_id") def test_get_session(s3store): @@ -354,6 +316,7 @@ def test_get_session(s3store): "aws_access_key_id": "ACCESS_KEY", "aws_secret_access_key": "SECRET_KEY", }, + object_grouping=["task_id"], ) assert store._get_session().get_credentials().access_key == "ACCESS_KEY" assert store._get_session().get_credentials().secret_key == "SECRET_KEY" @@ -365,7 +328,7 @@ def test_no_bucket(): conn.create_bucket(Bucket="bucket1") index = PandasMemoryStore(key="task_id") - store = OpenDataStore(index=index, bucket="bucket2", key="task_id") + store = OpenDataStore(index=index, bucket="bucket2", key="task_id", object_grouping=["task_id"]) with pytest.raises(RuntimeError, match=r".*Bucket not present.*"): store.connect() @@ -379,107 +342,3 @@ def test_pickle(s3store_w_subdir): dobj = pickle.loads(sobj) assert hash(dobj) == hash(s3store_w_subdir) assert dobj == s3store_w_subdir - - -@pytest.fixture() -def thermo_store(): - with mock_s3(): - conn = boto3.resource("s3", region_name="us-east-1") - conn.create_bucket(Bucket="bucket1") - - index = S3IndexStore(collection_name="thermo", bucket="bucket1", key="thermo_id") - store = OpenDataStore(index=index, bucket="bucket1", key="thermo_id") - store.connect() - - store.update( - [ - { - "thermo_id": "mp-1_R2SCAN", - "data": "asd", - store.last_updated_field: datetime.utcnow(), - } - ] - ) - - yield store - - -def test_thermo_collection_special_handling(thermo_store): - assert thermo_store.s3_bucket.Object(thermo_store._get_full_key_path("mp-1_R2SCAN")).key == "R2SCAN/mp-1.json.gz" - thermo_store.update([{"thermo_id": "mp-2_RSCAN", "data": "asd"}]) - index_docs = thermo_store.rebuild_index_from_s3_data() - assert len(index_docs) == 2 - for doc in index_docs: - for key in doc: - assert key == "thermo_id" or key == "last_updated" - - -@pytest.fixture() -def xas_store(): - with mock_s3(): - conn = boto3.resource("s3", region_name="us-east-1") - conn.create_bucket(Bucket="bucket1") - - index = S3IndexStore(collection_name="xas", bucket="bucket1", key="spectrum_id") - store = OpenDataStore(index=index, bucket="bucket1", key="spectrum_id") - store.connect() - - store.update( - [ - { - "spectrum_id": "mp-1-XAFS-Cr-K", - "data": "asd", - store.last_updated_field: datetime.utcnow(), - } - ] - ) - - yield store - - -def test_xas_collection_special_handling(xas_store): - assert xas_store.s3_bucket.Object(xas_store._get_full_key_path("mp-1-XAFS-Cr-K")).key == "K/XAFS/Cr/mp-1.json.gz" - xas_store.update([{"spectrum_id": "mp-2-XAFS-Li-K", "data": "asd"}]) - index_docs = xas_store.rebuild_index_from_s3_data() - assert len(index_docs) == 2 - for doc in index_docs: - for key in doc: - assert key == "spectrum_id" or key == "last_updated" - - -@pytest.fixture() -def synth_descriptions_store(): - with mock_s3(): - conn = boto3.resource("s3", region_name="us-east-1") - conn.create_bucket(Bucket="bucket1") - - index = S3IndexStore(collection_name="synth_descriptions", bucket="bucket1", key="doi") - store = 
OpenDataStore(index=index, bucket="bucket1", key="doi") - store.connect() - - store.update( - [ - { - "doi": "10.1149/2.051201jes", - "data": "asd", - store.last_updated_field: datetime.utcnow(), - } - ] - ) - - yield store - - -def test_synth_descriptions_collection_special_handling(synth_descriptions_store): - assert ( - synth_descriptions_store.s3_bucket.Object( - synth_descriptions_store._get_full_key_path("10.1149/2.051201jes") - ).key - == "10.1149_2.051201jes.json.gz" - ) - synth_descriptions_store.update([{"doi": "10.1039/C5CP01095K", "data": "asd"}]) - index_docs = synth_descriptions_store.rebuild_index_from_s3_data() - assert len(index_docs) == 2 - for doc in index_docs: - for key in doc: - assert key == "doi" or key == "last_updated"
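Taken together with the tests above, object_grouping determines the S3 key under which each group of documents is written: one "column=value" path segment per grouping column, followed by the .jsonl.gz extension. Below is a minimal sketch using a hypothetical full_key_path helper that mirrors OpenDataStore._get_full_key_path; the sample values are illustrative.

import pandas as pd

# Hypothetical stand-in for OpenDataStore._get_full_key_path(index): build one
# "column=value" segment per object_grouping column from the group's first row.
def full_key_path(index: pd.DataFrame, object_grouping, sub_dir="", extension=".jsonl.gz"):
    segments = "/".join(f"{col}={index[col].iloc[0]}" for col in object_grouping)
    return f"{sub_dir}{segments}{extension}"

index = pd.DataFrame([{"task_id": "mp-4", "nelements": 2, "symmetry_number": 225}])

print(full_key_path(index, ["task_id"]))
# -> "task_id=mp-4.jsonl.gz", the key asserted in test_update above
print(full_key_path(index, ["task_id"], sub_dir="subdir1/"))
# -> "subdir1/task_id=mp-4.jsonl.gz", matching the layout of the subdir assertions above
print(full_key_path(index, ["nelements", "symmetry_number"]))
# -> "nelements=2/symmetry_number=225.jsonl.gz", the default grouping when none is passed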