Merge pull request #911 from kbuma/enhancement/open_data_format_update
updating for open data format change
munrojm committed Feb 2, 2024
2 parents a72d2d1 + 6baf8bc commit efb4b2b
Showing 3 changed files with 229 additions and 248 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -52,7 +52,7 @@
"montydb": ["montydb>=2.3.12"],
"notebook_runner": ["IPython>=8.11", "nbformat>=5.0", "regex>=2020.6"],
"azure": ["azure-storage-blob>=12.16.0", "azure-identity>=1.12.0"],
"open_data": ["pandas>=2.1.4"],
"open_data": ["pandas>=2.1.4", "jsonlines>=4.0.0"],
"testing": [
"pytest",
"pytest-cov",
232 changes: 177 additions & 55 deletions src/maggma/stores/open_data.py
@@ -1,9 +1,9 @@
import gzip
from datetime import datetime
from io import BytesIO
from io import BytesIO, StringIO
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import orjson
import jsonlines
import pandas as pd
from boto3 import Session
from botocore import UNSIGNED
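
For context, the orjson-to-jsonlines swap above reflects the format change this PR makes: objects are written one JSON document per line instead of as a single JSON blob. A minimal sketch of the new serialization, with made-up records not taken from this diff:

from io import StringIO

import jsonlines

docs = [{"task_id": "mp-1", "nelements": 2}, {"task_id": "mp-2", "nelements": 3}]

buf = StringIO()
with jsonlines.Writer(buf) as writer:
    writer.write_all(docs)  # one JSON object per line instead of one JSON array

print(buf.getvalue())
# {"task_id": "mp-1", "nelements": 2}
# {"task_id": "mp-2", "nelements": 3}
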
@@ -31,7 +31,7 @@ def __init__(
@property
def _collection(self):
"""
Returns a handle to the pymongo collection object
Returns a handle to the pymongo collection object.
Raises:
NotImplementedError: always as this concept does not make sense for this type of store
@@ -41,7 +41,7 @@ def _collection(self):
@property
def name(self) -> str:
"""
Return a string representing this data source
Return a string representing this data source.
"""
return "imem://"

@@ -151,6 +151,15 @@ def query(
ret = self._query(criteria=criteria, properties=properties, sort=sort, skip=skip, limit=limit)
return (row.to_dict() for _, row in ret.iterrows())

@staticmethod
def add_missing_items(to_dt: pd.DataFrame, from_dt: pd.DataFrame, on: List[str]) -> pd.DataFrame:
orig_columns = to_dt.columns
merged = to_dt.merge(from_dt, on=on, how="left", suffixes=("", "_B"))
for column in from_dt.columns:
if column not in on:
merged[column].update(merged.pop(column + "_B"))
return merged[orig_columns]

def update(self, docs: Union[List[Dict], Dict], key: Union[List, str, None] = None, clear_first: bool = False):
"""
Update documents into the Store
@@ -164,20 +173,17 @@ def update(self, docs: Union[List[Dict], Dict], key: Union[List, str, None] = No
clear_first: if True clears the underlying data first, fully replacing the data with docs;
if False performs an upsert based on the parameters
"""
if key is not None:
raise NotImplementedError("updating store based on a key different than the store key is not supported")

df = pd.DataFrame(docs)
if self._data is None or clear_first:
if not df.empty:
self._data = df
return
key = [self.key]

if key is None:
key = self.key

merged = self._data.merge(df, on=key, how="left", suffixes=("", "_B"))
for column in df.columns:
if column not in key:
merged[column] = merged[column + "_B"].combine_first(merged[column])
merged = merged[self._data.columns]
merged = PandasMemoryStore.add_missing_items(to_dt=self._data, from_dt=df, on=key)
non_matching = df[~df.set_index(key).index.isin(self._data.set_index(key).index)]
self._data = pd.concat([merged, non_matching], ignore_index=True)

@@ -372,7 +378,15 @@ def _retrieve_manifest(self) -> pd.DataFrame:
"""
try:
response = self.client.get_object(Bucket=self.bucket, Key=self._get_full_key_path())
return pd.read_json(response["Body"], orient="records")
df = pd.read_json(response["Body"], orient="records", lines=True)
if self.last_updated_field in df.columns:
df[self.last_updated_field] = df[self.last_updated_field].apply(
lambda x: datetime.fromisoformat(x["$date"].rstrip("Z"))
if isinstance(x, dict) and "$date" in x
else x
)
return df

except ClientError as ex:
if ex.response["Error"]["Code"] == "NoSuchKey":
return []
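
The new branch above unpacks MongoDB extended-JSON dates that may appear in manifest lines; roughly, with a hypothetical record:

from datetime import datetime

# Hypothetical manifest record as parsed by pd.read_json(..., lines=True).
record = {"task_id": "mp-1", "last_updated": {"$date": "2024-01-15T03:21:05.000Z"}}

value = record["last_updated"]
if isinstance(value, dict) and "$date" in value:
    value = datetime.fromisoformat(value["$date"].rstrip("Z"))
# value == datetime(2024, 1, 15, 3, 21, 5)
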
@@ -386,17 +400,22 @@ def _load_index(self, force_reset: bool = False) -> None:
"""
super().update(self._retrieve_manifest(), clear_first=True)

def store_manifest(self, data: List[Dict]) -> None:
def store_manifest(self, data: pd.DataFrame) -> None:
"""Stores the provided data into the index stored in S3.
This overwrites and fully replaces all of the contents of the previous index stored in S3.
It also rewrites the memory index with the provided data.
Args:
data (pd.DataFrame): The data to store in the index.
"""
string_io = StringIO()
with jsonlines.Writer(string_io, dumps=json_util.dumps) as writer:
for _, row in data.iterrows():
writer.write(row.to_dict())

self.client.put_object(
Bucket=self.bucket,
Body=orjson.dumps(data, default=json_util.default),
Body=BytesIO(string_io.getvalue().encode("utf-8")),
Key=self._get_full_key_path(),
)
super().update(data, clear_first=True)
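
The jsonlines.Writer(..., dumps=json_util.dumps) pattern used above keeps BSON-friendly types such as datetimes representable in the manifest; a minimal sketch, assuming json_util comes from bson as elsewhere in maggma:

from datetime import datetime
from io import StringIO

import jsonlines
import pandas as pd
from bson import json_util  # assumed source of json_util

manifest = pd.DataFrame([{"task_id": "mp-1", "last_updated": datetime(2024, 2, 2)}])

string_io = StringIO()
with jsonlines.Writer(string_io, dumps=json_util.dumps) as writer:
    for _, row in manifest.iterrows():
        writer.write(row.to_dict())  # datetimes are encoded as {"$date": ...}

payload = string_io.getvalue().encode("utf-8")  # what store_manifest uploads to S3
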
@@ -458,8 +477,9 @@ def __init__(
sub_dir: Optional[str] = None,
key: str = "fs_id",
searchable_fields: Optional[List[str]] = None,
object_file_extension: str = ".json.gz",
object_file_extension: str = ".jsonl.gz",
access_as_public_bucket: bool = False,
object_grouping: Optional[List[str]] = None,
**kwargs,
):
"""Initializes an OpenDataStore
@@ -481,10 +501,11 @@ def __init__(
self.searchable_fields = searchable_fields if searchable_fields is not None else []
self.object_file_extension = object_file_extension
self.access_as_public_bucket = access_as_public_bucket
self.object_grouping = object_grouping if object_grouping is not None else ["nelements", "symmetry_number"]

if access_as_public_bucket:
kwargs["s3_resource_kwargs"] = kwargs["s3_resource_kwargs"] if "s3_resource_kwargs" in kwargs else {}
kwargs["s3_resource_kwargs"]["config"] = Config(signature_version=UNSIGNED)

kwargs["index"] = index
kwargs["bucket"] = bucket
kwargs["compress"] = compress
@@ -495,16 +516,71 @@ def __init__(
kwargs["unpack_data"] = True
self.kwargs = kwargs
super().__init__(**kwargs)
self.searchable_fields = list(
set(self.object_grouping) | set(self.searchable_fields) | {self.key, self.last_updated_field}
)

def query(
self,
criteria: Optional[Dict] = None,
properties: Union[Dict, List, None] = None,
sort: Optional[Dict[str, Union[Sort, int]]] = None,
skip: int = 0,
limit: int = 0,
) -> Iterator[Dict]:
"""
Queries the Store for a set of documents.
Args:
criteria: PyMongo filter for documents to search in.
properties: properties to return in grouped documents.
sort: Dictionary of sort order for fields. Keys are field names and values
are 1 for ascending or -1 for descending.
skip: number documents to skip.
limit: limit on total number of documents returned.
"""
prop_keys = set()
if isinstance(properties, dict):
prop_keys = set(properties.keys())
elif isinstance(properties, list):
prop_keys = set(properties)

for _, docs in self.index.groupby(
keys=self.object_grouping, criteria=criteria, sort=sort, limit=limit, skip=skip
):
group_doc = None # S3 backed group doc
for _, doc in docs.iterrows():
data = doc
if properties is None or not prop_keys.issubset(set(doc.keys())):
if group_doc is None:
group_doc = self._read_doc_from_s3(self._get_full_key_path(docs))
if group_doc.empty:
continue
data = group_doc.query(f"{self.key} == '{doc[self.key]}'")
data = data.to_dict(orient="records")[0]
if properties is None:
yield data
else:
yield {p: data[p] for p in prop_keys if p in data}
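
The subset check above is what decides whether an index row alone can answer the query or the grouped S3 object must be fetched; a small self-contained illustration with hypothetical fields:

# Hypothetical index row: only grouping/searchable fields live in the index.
index_row = {"material_id": "mp-1", "nelements": 2, "symmetry_number": 225}

cheap = {"material_id", "nelements"}        # answerable from the index alone
expensive = {"material_id", "structure"}    # "structure" lives only in the S3 object

print(cheap.issubset(index_row.keys()))       # True  -> no S3 read needed
print(expensive.issubset(index_row.keys()))   # False -> fetch the grouped .jsonl.gz
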

def _read_doc_from_s3(self, file_id: str) -> pd.DataFrame:
try:
response = self.s3_bucket.Object(file_id).get()
return pd.read_json(response["Body"], orient="records", lines=True, compression={"method": "gzip"})
except ClientError as ex:
if ex.response["Error"]["Code"] == "NoSuchKey":
return pd.DataFrame()
raise

def _get_full_key_path(self, id: str) -> str:
if self.index.collection_name == "thermo" and self.key == "thermo_id":
material_id, thermo_type = id.split("_", 1)
return f"{self.sub_dir}{thermo_type}/{material_id}{self.object_file_extension}"
if self.index.collection_name == "xas" and self.key == "spectrum_id":
material_id, spectrum_type, absorbing_element, edge = id.rsplit("-", 3)
return f"{self.sub_dir}{edge}/{spectrum_type}/{absorbing_element}/{material_id}{self.object_file_extension}"
if self.index.collection_name == "synth_descriptions" and self.key == "doi":
return f"{self.sub_dir}{id.replace('/', '_')}{self.object_file_extension}"
raise NotImplementedError("Not implemented for this store")

def _get_full_key_path(self, index: pd.DataFrame) -> str:
id = ""
for group in self.object_grouping:
id = f"{id}{group}={index[group].iloc[0]}/"
id = id.rstrip("/")
return f"{self.sub_dir}{id}{self.object_file_extension}"

def _get_compression_function(self) -> Callable:
@@ -513,36 +589,83 @@ def _get_compression_function(self) -> Callable:
def _get_decompression_function(self) -> Callable:
return gzip.decompress

def _read_data(self, data: bytes, compress_header: str = "gzip") -> List[Dict]:
if compress_header is not None:
data = self._get_decompression_function()(data)
return orjson.loads(data)
def _gather_indexable_data(self, df: pd.DataFrame, search_keys: List[str]) -> pd.DataFrame:
return df[search_keys]

def _gather_indexable_data(self, doc: Dict, search_keys: List[str]) -> Dict:
index_doc = {k: doc[k] for k in search_keys}
index_doc[self.key] = doc[self.key] # Ensure key is in metadata
# Ensure last updated field is in metadata if it's present in the data
if self.last_updated_field in doc:
index_doc[self.last_updated_field] = doc[self.last_updated_field]
return index_doc
def update(
self,
docs: Union[List[Dict], Dict],
key: Union[List, str, None] = None,
additional_metadata: Union[str, List[str], None] = None,
):
if additional_metadata is not None:
raise NotImplementedError("updating store with additional metadata is not supported")
super().update(docs=docs, key=key)

def _write_to_s3_and_index(self, docs: List[Dict], search_keys: List[str]):
"""Implements updating of the provided documents in S3 and the index.
Args:
docs (List[Dict]): The documents to update
search_keys (List[str]): The keys of the information to be updated in the index
"""
# group docs to update by object grouping
og = list(set(self.object_grouping) | set(search_keys))
df = pd.json_normalize(docs, sep="_")[og]
df_grouped = df.groupby(self.object_grouping)
existing = self.index._data
docs_df = pd.DataFrame(docs)
for group, _ in df_grouped:
query_str = " and ".join([f"{col} == {val!r}" for col, val in zip(self.object_grouping, group)])
sub_df = df.query(query_str)
sub_docs_df = docs_df[docs_df[self.key].isin(sub_df[self.key].unique())]
merged_df = sub_df
if existing is not None:
# fetch subsection of existing and docs_df and do outer merge with indicator=True
sub_existing = existing.query(query_str)
merged_df = sub_existing.merge(sub_df, on=og, how="outer", indicator=True)
# if there's any rows in existing only
if not merged_df[merged_df["_merge"] == "left_only"].empty:
# fetch the S3 data and populate those rows in sub_docs_df
s3_df = self._read_doc_from_s3(self._get_full_key_path(sub_existing))
# align sub_docs_df with every key in the group (including existing-only rows), then back-fill the gaps from the S3 data
sub_docs_df = sub_docs_df.merge(merged_df[[self.key, "_merge"]], on=self.key, how="right")
sub_docs_df.update(s3_df, overwrite=False)
sub_docs_df = sub_docs_df.drop("_merge", axis=1)

merged_df = merged_df.drop("_merge", axis=1)
# write doc based on subsection
self._write_doc_and_update_index(sub_docs_df, merged_df)
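
For reference, the query_str assembled in the loop above is just a pandas .query() expression over the grouping columns; with hypothetical group values:

object_grouping = ["nelements", "symmetry_number"]
group = (2, 225)  # hypothetical key emitted by df.groupby(object_grouping)

query_str = " and ".join(f"{col} == {val!r}" for col, val in zip(object_grouping, group))
print(query_str)  # nelements == 2 and symmetry_number == 225
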

def _write_doc_and_update_index(self, items: pd.DataFrame, index: pd.DataFrame) -> None:
self.write_doc_to_s3(items, index)
self.index.update(index)

def write_doc_to_s3(self, doc, search_keys):
if not isinstance(doc, pd.DataFrame):
raise NotImplementedError("doc parameter must be a Pandas DataFrame for the implementation for this store")
if not isinstance(search_keys, pd.DataFrame):
raise NotImplementedError(
"search_keys parameter must be a Pandas DataFrame for the implementation for this store"
)
# note: for this store, doc holds the grouped items and search_keys holds the corresponding index rows
string_io = StringIO()
with jsonlines.Writer(string_io, dumps=json_util.dumps) as writer:
for _, row in doc.iterrows():
writer.write(row.to_dict())

def write_doc_to_s3(self, doc: Dict, search_keys: List[str]) -> Dict:
search_doc = self._gather_indexable_data(doc, search_keys)
data = self._get_compression_function()(string_io.getvalue().encode("utf-8"))

data = orjson.dumps(doc, default=json_util.default)
data = self._get_compression_function()(data)
self._get_bucket().upload_fileobj(
Fileobj=BytesIO(data),
Key=self._get_full_key_path(str(doc[self.key])),
Key=self._get_full_key_path(search_keys),
)
return search_doc

def _index_for_doc_from_s3(self, bucket, key: str) -> Dict:
response = bucket.Object(key).get()
doc = self._read_data(response["Body"].read())
def _index_for_doc_from_s3(self, key: str) -> pd.DataFrame:
doc = self._read_doc_from_s3(key)
return self._gather_indexable_data(doc, self.searchable_fields)

def rebuild_index_from_s3_data(self) -> List[Dict]:
def rebuild_index_from_s3_data(self) -> pd.DataFrame:
"""
Rebuilds the index Store from the data in S3
Stores only the key, last_updated_field and searchable_fields in the index.
@@ -561,12 +684,12 @@ def rebuild_index_from_s3_data(self) -> List[Dict]:
for file in page["Contents"]:
key = file["Key"]
if key != self.index._get_full_key_path():
index_doc = self._index_for_doc_from_s3(bucket, key)
all_index_docs.append(index_doc)
self.index.store_manifest(all_index_docs)
return all_index_docs
all_index_docs.append(self._index_for_doc_from_s3(key))
ret = pd.concat(all_index_docs, ignore_index=True)
self.index.store_manifest(ret)
return ret

def rebuild_index_from_data(self, docs: List[Dict]) -> List[Dict]:
def rebuild_index_from_data(self, docs: pd.DataFrame) -> pd.DataFrame:
"""
Rebuilds the index Store from the provided data.
The provided data needs to include all of the documents in this data set.
@@ -578,10 +701,7 @@ def rebuild_index_from_data(self, docs: List[Dict]) -> List[Dict]:
Returns:
pd.DataFrame: The set of docs representing the index data.
"""
all_index_docs = []
for doc in docs:
index_doc = self._gather_indexable_data(doc, self.searchable_fields)
all_index_docs.append(index_doc)
all_index_docs = self._gather_indexable_data(docs, self.searchable_fields)
self.index.store_manifest(all_index_docs)
return all_index_docs

@@ -593,6 +713,7 @@ def __hash__(self):
self.endpoint_url,
self.key,
self.sub_dir,
tuple(self.object_grouping),
)
)

@@ -614,5 +735,6 @@ def __eq__(self, other: object) -> bool:
"searchable_fields",
"sub_dir",
"last_updated_field",
"object_grouping",
]
return all(getattr(self, f) == getattr(other, f) for f in fields)