From 7d20f3e32a24a4268d35b2d14e058c5a468324d1 Mon Sep 17 00:00:00 2001 From: Ryan Kingsbury Date: Tue, 27 Jun 2023 18:13:08 -0400 Subject: [PATCH 01/11] setup.cfg -> pyproject.toml --- pyproject.toml | 42 ++++++++++++++++++++++++++++++++++++++++++ setup.cfg | 23 ----------------------- 2 files changed, 42 insertions(+), 23 deletions(-) create mode 100644 pyproject.toml delete mode 100644 setup.cfg diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..a49d485b9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,42 @@ +[project] +name = "maggma" +dynamic = ["version", "readme"] +description="Framework to develop datapipelines from files on disk to full dissemenation API" +authors =[ + {name = "The Materials Project", email = "feedback@materialsproject.org"} +] + +[build-system] +requires = ["setuptools>=61.0.0", "setuptools_scm[toml]>=5"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.setuptools_scm] +version_scheme = "no-guess-dev" + +[tool.mypy] +ignore_missing_imports = "True" + +[tool.ruff] +# Never enforce `E471` +ignore = ["E741"] +# Set max line length +line-length = 120 +# exclude some files +exclude = [ + ".git", + "__pycache__", + "docs", + "__init__.py" +] + +[tool.pytest.ini_options] +addopts = "--durations=30" +testpaths = [ + "tests", +] + +[tool.pydocstyle] +ignore = ["D105","D2","D4"] diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 140b52323..000000000 --- a/setup.cfg +++ /dev/null @@ -1,23 +0,0 @@ -[tool:pytest] -addopts = --durations=30 - -[pycodestyle] -count = True -ignore = E121,E123,E126,E133,E226,E241,E242,E704,W503,W504,W505,E741,W605,W293 -max-line-length = 120 -statistics = True - -[flake8] -exclude = .git,__pycache__,docs_rst/conf.py,tests,pymatgen/io/abinit,__init__.py -# max-complexity = 10 -extend-ignore = E741 -max-line-length = 120 - -[isort] -profile=black - -[pydocstyle] -ignore = D105,D2,D4 - -[mypy] -ignore_missing_imports = True From 1297e9c5e1cfc2273775b5b7d12ef48039fdd4f0 Mon Sep 17 00:00:00 2001 From: Ryan Kingsbury Date: Mon, 31 Jul 2023 09:37:46 -0400 Subject: [PATCH 02/11] linting: sync ruff opts with custodian and pymatgen --- .flake8 | 5 --- .github/workflows/testing.yml | 30 ++++++------- .pre-commit-config.yaml | 4 +- pyproject.toml | 85 +++++++++++++++++++++++++++++------ 4 files changed, 88 insertions(+), 36 deletions(-) delete mode 100644 .flake8 diff --git a/.flake8 b/.flake8 deleted file mode 100644 index d9ad0b409..000000000 --- a/.flake8 +++ /dev/null @@ -1,5 +0,0 @@ -[flake8] -ignore = E203, E266, E501, W503, F403, F401 -max-line-length = 79 -max-complexity = 18 -select = B,C,E,F,W,T4,B9 diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 223fed105..591a31420 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -14,22 +14,22 @@ on: jobs: lint: runs-on: ubuntu-latest - + strategy: + max-parallel: 1 steps: - - uses: actions/checkout@v3 - - - name: Set up Python 3.8 - uses: actions/setup-python@v4 - with: - python-version: "3.10" - - - name: Install dependencies - run: | - pip install pre-commit - - - name: Run pre-commit - run: | - pre-commit run --all-files --show-diff-on-failure + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: 3.11 + cache: pip + - name: Run pre-commit + run: | + pip install pre-commit + pre-commit run + test: needs: lint diff --git a/.pre-commit-config.yaml 
b/.pre-commit-config.yaml index 34d9e38b2..987d9550a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,8 +3,8 @@ default_stages: [commit] default_install_hook_types: [pre-commit, commit-msg] repos: - - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.261 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.0.275 hooks: - id: ruff args: [--fix, --ignore, "D,E501"] diff --git a/pyproject.toml b/pyproject.toml index a49d485b9..f28dabd03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,27 +16,84 @@ where = ["src"] [tool.setuptools_scm] version_scheme = "no-guess-dev" -[tool.mypy] -ignore_missing_imports = "True" +[tool.black] +line-length = 120 [tool.ruff] -# Never enforce `E471` -ignore = ["E741"] -# Set max line length +target-version = "py38" line-length = 120 -# exclude some files -exclude = [ - ".git", - "__pycache__", - "docs", - "__init__.py" +select = [ + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "D", # pydocstyle + "E", # pycodestyle error + "EXE", # flake8-executable + "F", # pyflakes + "FA", # flake8-future-annotations + "FLY", # flynt + "I", # isort + "ICN", # flake8-import-conventions + "ISC", # flake8-implicit-str-concat + "PD", # pandas-vet + "PERF", # perflint + "PIE", # flake8-pie + "PL", # pylint + "PT", # flake8-pytest-style + "PYI", # flakes8-pyi + "Q", # flake8-quotes + "RET", # flake8-return + "RSE", # flake8-raise + "RUF", # Ruff-specific rules + "SIM", # flake8-simplify + "SLOT", # flake8-slots + "TCH", # flake8-type-checking + "TID", # tidy imports + "TID", # flake8-tidy-imports + "UP", # pyupgrade + "W", # pycodestyle warning + "YTT", # flake8-2020 +] +ignore = [ + "B023", # Function definition does not bind loop variable + "B028", # No explicit stacklevel keyword argument found + "B904", # Within an except clause, raise exceptions with ... + "C408", # unnecessary-collection-call + "D105", # Missing docstring in magic method + "D205", # 1 blank line required between summary line and description + "D212", # Multi-line docstring summary should start at the first line + "FA100", # Missing `from __future__ import annotations`, but uses `typing.XXX` + "PD011", # pandas-use-of-dot-values + "PD901", # pandas-df-variable-name + "PT011", # `pytest.raises(XXXError)` is too broad, set the `match` parameter... + "PERF203", # try-except-in-loop + "PERF401", # manual-list-comprehension (TODO fix these or wait for autofix) + "PLR", # pylint refactor + "PLW2901", # Outer for loop variable overwritten by inner assignment target + "PT013", # pytest-incorrect-pytest-import + "RUF012", # Disable checks for mutable class args. This is a non-problem. 
+ "SIM105", # Use contextlib.suppress(OSError) instead of try-except-pass ] +pydocstyle.convention = "google" +isort.split-on-trailing-comma = false + +[tool.ruff.per-file-ignores] +"__init__.py" = ["F401"] +"tasks.py" = ["D"] +"tests/*" = ["D"] +"src/maggma/api/*" = ["B008", "B021", "RET505", "RET506"] +"tests/api/*" = ["B017", "B018"] [tool.pytest.ini_options] -addopts = "--durations=30" +addopts = "--color=yes -p no:warnings --import-mode=importlib --durations=30" testpaths = [ "tests", ] -[tool.pydocstyle] -ignore = ["D105","D2","D4"] +[tool.mypy] +ignore_missing_imports = true +namespace_packages = true +explicit_package_bases = true +no_implicit_optional = false + +[tool.codespell] +ignore-words-list = "ot" From f5f06d9abaeda057926aed73c497799fe69a429f Mon Sep 17 00:00:00 2001 From: Ryan Kingsbury Date: Tue, 27 Jun 2023 16:37:04 -0400 Subject: [PATCH 03/11] replace flake8, isort, pycodestyle with ruff --- .github/workflows/testing.yml | 3 +-- .pre-commit-config.yaml | 2 +- pyproject.toml | 8 +++++--- requirements-testing.txt | 5 +---- src/maggma/api/resource/__init__.py | 3 --- src/maggma/api/utils.py | 11 ++++++++--- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 591a31420..de150217e 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -29,8 +29,7 @@ jobs: run: | pip install pre-commit pre-commit run - - + test: needs: lint services: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 987d9550a..8c9a33980 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ repos: rev: v0.0.275 hooks: - id: ruff - args: [--fix, --ignore, "D,E501"] + args: [--fix, --ignore, D] - repo: https://github.com/psf/black rev: 23.3.0 diff --git a/pyproject.toml b/pyproject.toml index f28dabd03..4287753e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ version_scheme = "no-guess-dev" line-length = 120 [tool.ruff] -target-version = "py38" +target-version = "py37" line-length = 120 select = [ "B", # flake8-bugbear @@ -61,10 +61,10 @@ ignore = [ "D105", # Missing docstring in magic method "D205", # 1 blank line required between summary line and description "D212", # Multi-line docstring summary should start at the first line - "FA100", # Missing `from __future__ import annotations`, but uses `typing.XXX` + "FA100", # Missing `from __future__ import annotations`, but uses `typing.XXX` TODO "PD011", # pandas-use-of-dot-values "PD901", # pandas-df-variable-name - "PT011", # `pytest.raises(XXXError)` is too broad, set the `match` parameter... + "PT011", # `pytest.raises(XXXError)` is too broad, set the `match` parameter... TODO "PERF203", # try-except-in-loop "PERF401", # manual-list-comprehension (TODO fix these or wait for autofix) "PLR", # pylint refactor @@ -82,6 +82,8 @@ isort.split-on-trailing-comma = false "tests/*" = ["D"] "src/maggma/api/*" = ["B008", "B021", "RET505", "RET506"] "tests/api/*" = ["B017", "B018"] +"src/maggma/cli/*" = ["EXE001"] # triggered by ! 
at top of file +"src/maggma/api/utils.py" = ["I001"] # to allow unsorted import block [tool.pytest.ini_options] addopts = "--color=yes -p no:warnings --import-mode=importlib --durations=30" diff --git a/requirements-testing.txt b/requirements-testing.txt index 6569e943f..038f33d25 100644 --- a/requirements-testing.txt +++ b/requirements-testing.txt @@ -5,10 +5,7 @@ pytest-cov==3.0.0 pytest-mock==3.10.0 pytest-xdist==2.5.0 moto==3.1.17 -pydocstyle==6.1.1 -flake8==4.0.1 -mypy==0.971 -mypy-extensions==0.4.3 +ruff==0.0.280 responses<0.22.0 types-PyYAML==6.0.11 types-setuptools==65.4.0.0 diff --git a/src/maggma/api/resource/__init__.py b/src/maggma/api/resource/__init__.py index 3726acd89..aafd982b1 100644 --- a/src/maggma/api/resource/__init__.py +++ b/src/maggma/api/resource/__init__.py @@ -1,10 +1,7 @@ -# isort: off from maggma.api.resource.core import Resource from maggma.api.resource.core import HintScheme from maggma.api.resource.core import HeaderProcessor -# isort: on - from maggma.api.resource.aggregation import AggregationResource from maggma.api.resource.post_resource import PostOnlyResource from maggma.api.resource.read_resource import ReadOnlyResource, attach_query_ops diff --git a/src/maggma/api/utils.py b/src/maggma/api/utils.py index e48c98a1e..aae16155a 100644 --- a/src/maggma/api/utils.py +++ b/src/maggma/api/utils.py @@ -1,8 +1,14 @@ import base64 import inspect import sys -from typing import Any, Callable, Dict, List, Optional, Type - +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Type, +) from bson.objectid import ObjectId from monty.json import MSONable from pydantic import BaseModel @@ -15,7 +21,6 @@ else: from typing_extensions import get_args # pragma: no cover - QUERY_PARAMS = ["criteria", "properties", "skip", "limit"] STORE_PARAMS = Dict[ Literal[ From 6690d0a822e052470942c475491f69744090b951 Mon Sep 17 00:00:00 2001 From: Ryan Kingsbury Date: Tue, 27 Jun 2023 17:03:35 -0400 Subject: [PATCH 04/11] typing fixes --- requirements.txt | 1 + src/maggma/core/store.py | 6 +++--- src/maggma/stores/aws.py | 5 +++-- src/maggma/stores/azure.py | 6 +++--- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/requirements.txt b/requirements.txt index 4bece3047..da4d5d506 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,7 @@ fastapi==0.79.0 numpy==1.18.0;python_version<"3.8" numpy==1.23.0;python_version>="3.8" typing_extensions;python_version<"3.8" +typing_compat;python_version<"3.8" # required for get_args() pyzmq==24.0.1 dnspython==2.2.1 uvicorn==0.18.3 diff --git a/src/maggma/core/store.py b/src/maggma/core/store.py index 315cd9c02..9822ff435 100644 --- a/src/maggma/core/store.py +++ b/src/maggma/core/store.py @@ -7,7 +7,7 @@ from abc import ABCMeta, abstractmethod, abstractproperty from datetime import datetime from enum import Enum -from typing import Dict, Iterator, List, Optional, Tuple, Union +from typing import Dict, Iterator, List, Optional, Tuple, Union, Callable from monty.dev import deprecated from monty.json import MontyDecoder, MSONable @@ -55,11 +55,11 @@ def __init__( self.key = key self.last_updated_field = last_updated_field self.last_updated_type = last_updated_type - self._lu_func = ( + self._lu_func: Tuple[Callable, Callable] = ( LU_KEY_ISOFORMAT if DateTimeFormat(last_updated_type) == DateTimeFormat.IsoFormat else (identity, identity) - ) # type: Tuple[Callable, Callable] + ) self.validator = validator self.logger = logging.getLogger(type(self).__name__) self.logger.addHandler(logging.NullHandler()) diff 
--git a/src/maggma/stores/aws.py b/src/maggma/stores/aws.py index 5fc8149f0..ba4808e3b 100644 --- a/src/maggma/stores/aws.py +++ b/src/maggma/stores/aws.py @@ -8,6 +8,7 @@ from concurrent.futures import wait from concurrent.futures.thread import ThreadPoolExecutor from hashlib import sha1 +from typing import Dict, Iterator, List, Optional, Tuple, Union, Any from json import dumps from typing import Dict, Iterator, List, Optional, Tuple, Union @@ -77,8 +78,8 @@ def __init__( self.compress = compress self.endpoint_url = endpoint_url self.sub_dir = sub_dir.strip("/") + "/" if sub_dir else "" - self.s3 = None # type: Any - self.s3_bucket = None # type: Any + self.s3: Any = None + self.s3_bucket: Any = None self.s3_workers = s3_workers self.s3_resource_kwargs = ( s3_resource_kwargs if s3_resource_kwargs is not None else {} diff --git a/src/maggma/stores/azure.py b/src/maggma/stores/azure.py index 20fb18625..bef362a53 100644 --- a/src/maggma/stores/azure.py +++ b/src/maggma/stores/azure.py @@ -26,7 +26,7 @@ from azure.storage.blob import BlobServiceClient, ContainerClient except (ImportError, ModuleNotFoundError): azure_blob = None # type: ignore - ContainerClient = None + # ContainerClient = None AZURE_KEY_SANITIZE = {"-": "_", ".": "_"} @@ -94,8 +94,8 @@ def __init__( self.azure_client_info = azure_client_info self.compress = compress self.sub_dir = sub_dir.rstrip("/") + "/" if sub_dir else "" - self.service = None # type: Optional[BlobServiceClient] - self.container = None # type: Optional[ContainerClient] + self.service: Optional[BlobServiceClient] = None + self.container: Optional[ContainerClient] = None self.workers = workers self.azure_resource_kwargs = ( azure_resource_kwargs if azure_resource_kwargs is not None else {} From 4c3456fbd2e05f5a3117e74eb5e3b6307dfe3fd3 Mon Sep 17 00:00:00 2001 From: Ryan Kingsbury Date: Mon, 31 Jul 2023 09:28:06 -0400 Subject: [PATCH 05/11] add pre-commit ci and codespell; bump hook vers --- .pre-commit-config.yaml | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8c9a33980..fa9df33a0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,17 +2,31 @@ default_stages: [commit] default_install_hook_types: [pre-commit, commit-msg] +ci: + autoupdate_schedule: monthly + # skip: [mypy] + autofix_commit_msg: pre-commit auto-fixes + autoupdate_commit_msg: pre-commit autoupdate + repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.275 + rev: v0.0.280 hooks: - id: ruff args: [--fix, --ignore, D] - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 23.7.0 + hooks: + - id: black + + - repo: https://github.com/codespell-project/codespell + rev: v2.2.5 hooks: - - id: black-jupyter + - id: codespell + stages: [commit, commit-msg] + exclude_types: [html] + additional_dependencies: [tomli] # needed to read pyproject.toml below py3.11 - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 From 2cac39597cf0b38fdb50ecb24d9f95f5d8782421 Mon Sep 17 00:00:00 2001 From: Ryan Kingsbury Date: Mon, 31 Jul 2023 10:17:40 -0400 Subject: [PATCH 06/11] add upgrade dependencies workflow --- .github/workflows/upgrade-dependencies.yml | 154 +++++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 .github/workflows/upgrade-dependencies.yml diff --git a/.github/workflows/upgrade-dependencies.yml b/.github/workflows/upgrade-dependencies.yml new file mode 100644 index 000000000..d67066955 --- /dev/null +++ 
b/.github/workflows/upgrade-dependencies.yml @@ -0,0 +1,154 @@ +# https://www.oddbird.net/2022/06/01/dependabot-single-pull-request/ +name: upgrade dependencies + +on: + workflow_dispatch: # Allow running on-demand + schedule: + # Runs every Monday at 8:00 UTC (4:00 Eastern) + - cron: '0 17 * * 1' + +jobs: + upgrade: + name: ${{ matrix.package }} (${{ matrix.os }}/py${{ matrix.python-version }}) + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: ['ubuntu-latest'] + package: ["optimade"] + python-version: ["3.10"] + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements/${{ matrix.os }}_py${{ matrix.python-version }}.txt' + - name: Upgrade Python dependencies + shell: bash + run: | + python${{ matrix.python-version }} -m pip install --upgrade pip pip-tools + cd docker/${{ matrix.package }} + python${{ matrix.python-version }} -m piptools compile -q --resolver=backtracking --upgrade -o requirements/${{ matrix.os }}_py${{ matrix.python-version }}.txt + - name: Detect changes + id: changes + shell: bash + run: | + echo "count=`git diff --quiet docker/${{ matrix.package }}/requirements; echo $?`" >> $GITHUB_OUTPUT + echo "files=`git ls-files --exclude-standard --others docker/${{ matrix.package }}/requirements | wc -l | xargs`" >> $GITHUB_OUTPUT + - name: commit & push changes + if: steps.changes.outputs.count > 0 || steps.changes.outputs.files > 0 + shell: bash + run: | + git config user.name github-actions + git config user.email github-actions@github.com + git add docker/${{ matrix.package }}/requirements + git commit -m "update dependencies for ${{ matrix.package }} (${{ matrix.os }}/py${{ matrix.python-version }})" + git push -f origin ${{ github.ref_name }}:auto-dependency-upgrades-${{ matrix.package }}-${{ matrix.os }}-py${{ matrix.python-version }} + + pull_request: + name: Merge all branches and open PR + runs-on: ubuntu-latest + needs: [upgrade] + strategy: + matrix: + python-version: ["3.10"] + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + submodules: 'recursive' + token: ${{ secrets.PAT }} + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/docker/python/requirements.txt' + + - name: make new branch + run: | + git config --global user.name github-actions + git config --global user.email github-actions@github.com + git checkout -b auto-dependency-upgrades + + - name: detect auto-upgrade-dependency branches + id: upgrade_changes + run: echo "count=$(git branch -r | grep auto-dependency-upgrades- | wc -l | xargs)" >> $GITHUB_OUTPUT + + - name: merge all auto-dependency-upgrades branches + if: steps.upgrade_changes.outputs.count > 0 + run: | + git branch -r | grep auto-dependency-upgrades- | xargs -I {} git merge {} + git rebase ${GITHUB_REF##*/} + git push -f origin auto-dependency-upgrades + git branch -r | grep auto-dependency-upgrades- | cut -d/ -f2 | xargs -I {} git push origin :{} + + - name: submodule updates + run: git submodule update --remote + + - name: compile docker/python dependencies + shell: bash + run: | + cd docker + python${{ matrix.python-version }} -m pip install --upgrade pip pip-tools + setup_packages="emmet/emmet-api emmet/emmet-core emmet/emmet-builders MPContribs/mpcontribs-api MPContribs/mpcontribs-client MPContribs/mpcontribs-portal" + pip_input=""; for i in `echo $setup_packages`; do pip_input="$pip_input 
$i/setup.py"; done + python${{ matrix.python-version }} -m piptools compile -q --resolver=backtracking --upgrade web/pyproject.toml MPContribs/mpcontribs-kernel-gateway/requirements.in `echo $pip_input` -o python/requirements-full.txt + grep -h -E "numpy==|scipy==|matplotlib==|pandas==" python/requirements-full.txt > python/requirements.txt + rm python/requirements-full.txt + python${{ matrix.python-version }} -m piptools compile -q --resolver=backtracking --upgrade python/requirements.txt web/pyproject.toml -o web/requirements/deployment.txt + cd web && git checkout main && git add requirements/deployment.txt && git commit -m "upgrade dependencies for deployment" && git push + cd - + python${{ matrix.python-version }} -m piptools compile -q --resolver=backtracking --upgrade python/requirements.txt MPContribs/mpcontribs-kernel-gateway/requirements.in -o MPContribs/mpcontribs-kernel-gateway/requirements/deployment.txt + for i in `echo $setup_packages`; do + python${{ matrix.python-version }} -m piptools compile -q --resolver=backtracking --upgrade python/requirements.txt $i/setup.py -o $i/requirements/deployment.txt + done + cd emmet && git checkout main && git add emmet-*/requirements/deployment.txt && git commit -m "upgrade dependencies for deployment" && git push + cd - + cd MPContribs && git checkout master && git add mpcontribs-*/requirements/deployment.txt && git commit -m "upgrade dependencies for deployment" && git push + cd - + + - name: Detect changes + id: changes + shell: bash + run: | + echo "count=`git diff --quiet --ignore-submodules -- . ':!docker/python'; echo $?`" >> $GITHUB_OUTPUT + echo "countReq=`git diff --quiet docker/python/requirements.txt; echo $?`" >> $GITHUB_OUTPUT + echo "files=$(git ls-files --exclude-standard --others | wc -l | xargs)" >> $GITHUB_OUTPUT + + - name: commit & push changes + if: steps.changes.outputs.count > 0 || steps.changes.outputs.files > 0 || steps.changes.outputs.countReq > 0 + shell: bash + run: | + git add . + git commit -m "auto dependency upgrades" + git push origin auto-dependency-upgrades + + - name: create and push tag to trigger python base image action + if: steps.changes.outputs.countReq > 0 + shell: bash + run: | + ver=`grep FROM docker/python/Dockerfile | cut -d: -f2 | cut -d- -f1` + prefix=${ver%.*}${ver##*.} + patch=`git tag -l "python-${prefix}.*" | sort -V | tail -1 | cut -d. -f3` + [[ -z "$patch" ]] && tag="python-${prefix}.0" || tag="python-${prefix}.$((++patch))" + echo $tag + git tag $tag + git push --tags + + - name: Open pull request if needed + if: steps.changes.outputs.count > 0 || steps.changes.outputs.files > 0 || steps.changes.outputs.countReq > 0 + env: + GITHUB_TOKEN: ${{ secrets.PAT }} + # Only open a PR if the branch is not attached to an existing one + run: | + PR=$(gh pr list --head auto-dependency-upgrades --json number -q '.[0].number') + if [ -z $PR ]; then + gh pr create \ + --head auto-dependency-upgrades \ + --title "Automated dependency upgrades" \ + --body "Full log: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}" + else + echo "Pull request already exists, won't create a new one." 
+ fi From 3b8af80e489038d8821246b6e29605673b4a9022 Mon Sep 17 00:00:00 2001 From: Ryan Kingsbury Date: Mon, 31 Jul 2023 12:57:03 -0400 Subject: [PATCH 07/11] add Filestore to __init__.py --- src/maggma/stores/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/maggma/stores/__init__.py b/src/maggma/stores/__init__.py index df80c8825..baece8e81 100644 --- a/src/maggma/stores/__init__.py +++ b/src/maggma/stores/__init__.py @@ -9,6 +9,7 @@ from maggma.stores.aws import S3Store from maggma.stores.azure import AzureBlobStore from maggma.stores.compound_stores import ConcatStore, JointStore +from maggma.stores.file_store import FileStore from maggma.stores.gridfs import GridFSStore from maggma.stores.mongolike import ( JSONStore, @@ -29,6 +30,7 @@ "ConcatStore", "JointStore", "GridFSStore", + "FileStore", "JSONStore", "MemoryStore", "MongoStore", From e4d8d0f948397151242434381f807554a82c7ae5 Mon Sep 17 00:00:00 2001 From: Ryan Kingsbury Date: Mon, 31 Jul 2023 10:55:48 -0400 Subject: [PATCH 08/11] run pre-commit, fix errs and spelling --- .github/workflows/testing.yml | 2 +- docs/concepts.md | 4 +- docs/getting_started/advanced_builder.md | 2 +- docs/getting_started/group_builder.md | 6 +- docs/getting_started/running_builders.md | 4 +- docs/getting_started/simple_builder.md | 4 +- docs/getting_started/stores.md | 2 +- docs/getting_started/using_file_store.md | 6 +- pyproject.toml | 4 +- src/maggma/__init__.py | 1 - src/maggma/api/API.py | 6 +- src/maggma/api/models.py | 15 +-- src/maggma/api/query_operator/core.py | 2 +- src/maggma/api/query_operator/dynamic.py | 48 ++----- src/maggma/api/query_operator/pagination.py | 4 +- .../api/query_operator/sparse_fields.py | 19 +-- src/maggma/api/query_operator/submission.py | 11 +- src/maggma/api/resource/__init__.py | 7 +- src/maggma/api/resource/core.py | 2 - src/maggma/api/resource/post_resource.py | 27 +--- src/maggma/api/resource/read_resource.py | 30 +---- src/maggma/api/resource/s3_url.py | 14 +- src/maggma/api/resource/submission.py | 66 +++------ src/maggma/api/resource/utils.py | 4 +- src/maggma/api/utils.py | 29 ++-- src/maggma/builders/group_builder.py | 41 ++---- src/maggma/builders/map_builder.py | 22 +-- src/maggma/builders/projection_builder.py | 52 +++----- src/maggma/cli/__init__.py | 27 +--- src/maggma/cli/distributed.py | 43 ++---- src/maggma/cli/multiprocessing.py | 16 +-- src/maggma/cli/rabbitmq.py | 65 ++++----- src/maggma/cli/serial.py | 10 +- src/maggma/cli/source_loader.py | 37 ++--- src/maggma/core/builder.py | 11 +- src/maggma/core/store.py | 57 +++----- src/maggma/core/validator.py | 1 - src/maggma/stores/__init__.py | 15 +-- src/maggma/stores/advanced_stores.py | 56 +++----- src/maggma/stores/aws.py | 84 ++++-------- src/maggma/stores/azure.py | 83 ++++-------- src/maggma/stores/compound_stores.py | 58 +++----- src/maggma/stores/file_store.py | 46 +++---- src/maggma/stores/gridfs.py | 90 ++++--------- src/maggma/stores/mongolike.py | 126 +++++------------- src/maggma/stores/shared_stores.py | 92 +++---------- src/maggma/utils.py | 12 +- src/maggma/validators.py | 11 +- tests/api/test_aggregation_resource.py | 21 +-- tests/api/test_api.py | 25 ++-- tests/api/test_post_resource.py | 9 +- tests/api/test_query_operators.py | 24 +--- tests/api/test_read_resource.py | 25 ++-- tests/api/test_s3_url_resource.py | 3 +- tests/api/test_submission_resource.py | 11 +- tests/api/test_utils.py | 7 +- tests/builders/test_copy_builder.py | 16 +-- tests/builders/test_group_builder.py | 12 +- 
tests/builders/test_projection_builder.py | 38 +++--- tests/cli/builder_for_test.py | 2 +- tests/cli/test_distributed.py | 30 ++--- tests/cli/test_init.py | 78 +++-------- tests/cli/test_multiprocessing.py | 20 +-- tests/cli/test_serial.py | 2 +- tests/conftest.py | 22 +-- tests/stores/test_advanced_stores.py | 110 +++++---------- tests/stores/test_aws.py | 40 ++---- tests/stores/test_azure.py | 42 ++---- tests/stores/test_compound_stores.py | 18 +-- tests/stores/test_file_store.py | 37 ++--- tests/stores/test_gridfs.py | 82 +++--------- tests/stores/test_mongolike.py | 87 +++++------- tests/stores/test_shared_stores.py | 49 +++---- tests/stores/test_ssh_tunnel.py | 13 +- tests/test_utils.py | 24 ++-- tests/test_validator.py | 14 +- 76 files changed, 681 insertions(+), 1554 deletions(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index de150217e..c9296832a 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -29,7 +29,7 @@ jobs: run: | pip install pre-commit pre-commit run - + test: needs: lint services: diff --git a/docs/concepts.md b/docs/concepts.md index ca87af11b..f46b59e52 100644 --- a/docs/concepts.md +++ b/docs/concepts.md @@ -12,7 +12,7 @@ s2 -- Builder 3-->s4(Store 4) ## Store -A major challenge in building scalable data piplines is dealing with all the different types of data sources out there. Maggma's `Store` class provides a consistent, unified interface for querying data from arbitrary data +A major challenge in building scalable data pipelines is dealing with all the different types of data sources out there. Maggma's `Store` class provides a consistent, unified interface for querying data from arbitrary data sources. It was originally built around MongoDB, so it's interface closely resembles `PyMongo` syntax. However, Maggma makes it possible to use that same syntax to query other types of databases, such as Amazon S3, GridFS, or even files on disk. @@ -34,4 +34,4 @@ Both `get_items` and `update_targets` can perform IO (input/output) to the data Another challenge in building complex data-transformation codes is keeping track of all the settings necessary to make some output database. One bad solution is to hard-code these settings, but then any modification is difficult to keep track of. -Maggma solves this by putting the configuration with the pipeline definition in JSON or YAML files. This is done using the `MSONable` pattern, which requires that any Maggma object (the databases and transformation steps) can convert itself to a python dictionary with it's configuration parameters in a process called serialization. These dictionaries can then be converted back to the origianl Maggma object without having to know what class it belonged. `MSONable` does this by injecting in `@class` and `@module` keys that tell it where to find the original python code for that Maggma object. +Maggma solves this by putting the configuration with the pipeline definition in JSON or YAML files. This is done using the `MSONable` pattern, which requires that any Maggma object (the databases and transformation steps) can convert itself to a python dictionary with it's configuration parameters in a process called serialization. These dictionaries can then be converted back to the original Maggma object without having to know what class it belonged. `MSONable` does this by injecting in `@class` and `@module` keys that tell it where to find the original python code for that Maggma object. 
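To make the `MSONable` round-trip described in the `concepts.md` hunk above concrete, a minimal sketch might look like the following; the database and collection names are illustrative assumptions, and it relies on the `monty.serialization` helpers maggma already depends on rather than anything introduced by this patch.

``` python
# Hedged sketch of the MSONable round-trip described in concepts.md above.
# The database/collection names are illustrative assumptions.
from monty.serialization import dumpfn, loadfn

from maggma.stores import MongoStore

store = MongoStore(database="my_db", collection_name="my_collection")

# as_dict() records the constructor arguments plus "@module"/"@class" keys,
# which is what lets the object be rebuilt without knowing its class up front
d = store.as_dict()
assert d["@class"] == "MongoStore"

# dumpfn() writes that dictionary to JSON; loadfn() uses the injected keys
# to reconstruct the original MongoStore from the file
dumpfn(store, "store.json")
rebuilt = loadfn("store.json")
assert isinstance(rebuilt, MongoStore)
```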
diff --git a/docs/getting_started/advanced_builder.md b/docs/getting_started/advanced_builder.md index e558ffcd4..74dd51507 100644 --- a/docs/getting_started/advanced_builder.md +++ b/docs/getting_started/advanced_builder.md @@ -42,4 +42,4 @@ Since `maggma` is designed around Mongo style data sources and sinks, building i `maggma` implements templates for builders that have many of these advanced features listed above: - [MapBuilder](map_builder.md) Creates one-to-one document mapping of items in the source Store to the transformed documents in the target Store. -- [GroupBuilder](group_builder.md) Creates many-to-one document mapping of items in the source Store to transformed documents in the traget Store +- [GroupBuilder](group_builder.md) Creates many-to-one document mapping of items in the source Store to transformed documents in the target Store diff --git a/docs/getting_started/group_builder.md b/docs/getting_started/group_builder.md index 5b520cd8c..daf862dd4 100644 --- a/docs/getting_started/group_builder.md +++ b/docs/getting_started/group_builder.md @@ -56,7 +56,7 @@ class ResupplyBuilder(GroupBuilder): super().__init__(source=inventory, target=resupply, grouping_properties=["type"], **kwargs) ``` -Note that unlike the previous `MapBuilder` example, we didn't call the source and target stores as such. Providing more usefull names is a good idea in writing builders to make it clearer what the underlying data should look like. +Note that unlike the previous `MapBuilder` example, we didn't call the source and target stores as such. Providing more useful names is a good idea in writing builders to make it clearer what the underlying data should look like. `GroupBuilder` inherits from `MapBuilder` so it has the same configurational parameters. @@ -65,7 +65,7 @@ Note that unlike the previous `MapBuilder` example, we didn't call the source an - store_process_timeout: adds the process time into the target document for profiling - retry_failed: retries running the process function on previously failed documents -One parameter that doens't work in `GroupBuilder` is `delete_orphans`, since the Many-to-One relationshop makes determining orphaned documents very difficult. +One parameter that doesn't work in `GroupBuilder` is `delete_orphans`, since the Many-to-One relationshop makes determining orphaned documents very difficult. Finally let's get to the hard part which is running our function. We do this by defining `unary_function` @@ -81,4 +81,4 @@ Finally let's get to the hard part which is running our function. We do this by return {"resupply": resupply} ``` -Just as in `MapBuilder`, we're not returning all the extra information typically kept in the originally item. Normally, we would have to write code that copies over the source `key` and convert it to the target `key`. Same goes for the `last_updated_field`. `GroupBuilder` takes care of this, while also recording errors, processing time, and the Builder version.`GroupBuilder` also keeps a plural version of the `source.key` field, so in this example, all the `name` values wil be put together and kept in `names` +Just as in `MapBuilder`, we're not returning all the extra information typically kept in the originally item. Normally, we would have to write code that copies over the source `key` and convert it to the target `key`. Same goes for the `last_updated_field`. 
`GroupBuilder` takes care of this, while also recording errors, processing time, and the Builder version.`GroupBuilder` also keeps a plural version of the `source.key` field, so in this example, all the `name` values will be put together and kept in `names` diff --git a/docs/getting_started/running_builders.md b/docs/getting_started/running_builders.md index 298662a04..14ce3cae9 100644 --- a/docs/getting_started/running_builders.md +++ b/docs/getting_started/running_builders.md @@ -15,7 +15,7 @@ my_builder = MultiplyBuilder(source_store,target_store,multiplier=3) my_builder.run() ``` -A better way to run this builder would be to use the `mrun` command line tool. Since evrything in `maggma` is MSONable, we can use `monty` to dump the builders into a JSON file: +A better way to run this builder would be to use the `mrun` command line tool. Since everything in `maggma` is MSONable, we can use `monty` to dump the builders into a JSON file: ``` python from monty.serialization import dumpfn @@ -29,7 +29,7 @@ Then we can run the builder using `mrun`: mrun my_builder.json ``` -`mrun` has a number of usefull options: +`mrun` has a number of useful options: ``` shell mrun --help diff --git a/docs/getting_started/simple_builder.md b/docs/getting_started/simple_builder.md index 59c3271cd..62394a0cd 100644 --- a/docs/getting_started/simple_builder.md +++ b/docs/getting_started/simple_builder.md @@ -52,7 +52,7 @@ The `__init__` for a builder can have any set of parameters. Generally, you want Python type annotations provide a really nice way of documenting the types we expect and being able to later type check using `mypy`. We defined the type for `source` and `target` as `Store` since we only care that implements that pattern. How exactly these `Store`s operate doesn't concern us here. -Note that the `__init__` arguments: `source`, `target`, `multiplier`, and `kwargs` get saved as attributess: +Note that the `__init__` arguments: `source`, `target`, `multiplier`, and `kwargs` get saved as attributes: ``` python self.source = source @@ -243,4 +243,4 @@ Then we can define a prechunk method that modifies the `Builder` dict in place t } ``` -When distributed processing runs, it will modify the `Builder` dictionary in place by the prechunk dictionary. In this case, each builder distribute to a worker will get a modified `query` parameter that only runs on a subset of all posible keys. +When distributed processing runs, it will modify the `Builder` dictionary in place by the prechunk dictionary. In this case, each builder distribute to a worker will get a modified `query` parameter that only runs on a subset of all possible keys. 
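For the `prechunk` discussion in the `simple_builder.md` hunk above, a builder carrying that kind of override might be sketched roughly as follows; the `MultiplyBuilder` name, the multiplied field `"a"`, and the key-splitting query are illustrative assumptions based on the surrounding documentation, not text from the patch itself.

``` python
# Hedged sketch of a builder with a prechunk() override, as discussed in
# simple_builder.md above. Names and the multiplier logic are illustrative.
from math import ceil
from typing import Dict, Iterable, List, Optional

from maggma.core import Builder, Store
from maggma.utils import grouper


class MultiplyBuilder(Builder):
    """Multiply the field "a" of every source document and copy it to the target."""

    def __init__(
        self,
        source: Store,
        target: Store,
        multiplier: int = 2,
        query: Optional[Dict] = None,
        **kwargs,
    ):
        self.source = source
        self.target = target
        self.multiplier = multiplier
        self.query = query
        super().__init__(sources=[source], targets=[target], **kwargs)

    def prechunk(self, number_splits: int) -> Iterable[Dict]:
        # Split the keys that still need processing into one group per worker;
        # each yielded dict is merged into the builder definition sent to that
        # worker, so its copy only queries a subset of all possible keys.
        keys = self.target.newer_in(self.source, exhaustive=True)
        chunk_size = max(1, ceil(len(keys) / number_splits))
        for split in grouper(keys, chunk_size):
            yield {"query": {self.source.key: {"$in": list(split)}}}

    def get_items(self) -> Iterable[Dict]:
        return self.source.query(criteria=self.query)

    def process_item(self, item: Dict) -> Dict:
        return {**item, "a": item["a"] * self.multiplier}

    def update_targets(self, items: List[Dict]):
        self.target.update(items)
```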
diff --git a/docs/getting_started/stores.md b/docs/getting_started/stores.md index 9f2f14410..9ef2edf26 100644 --- a/docs/getting_started/stores.md +++ b/docs/getting_started/stores.md @@ -11,7 +11,7 @@ Current working and tested `Store` include: - `MongoStore`: interfaces to a MongoDB Collection - `MemoryStore`: just a Store that exists temporarily in memory - `JSONStore`: builds a MemoryStore and then populates it with the contents of the given JSON files -- `FileStore`: query and add metadata to files stored on disk as if they were in a databsae +- `FileStore`: query and add metadata to files stored on disk as if they were in a database - `GridFSStore`: interfaces to GridFS collection in MongoDB - `S3Store`: provides an interface to an S3 Bucket either on AWS or self-hosted solutions ([additional documentation](advanced_stores.md)) - `ConcatStore`: concatenates several Stores together so they look like one Store diff --git a/docs/getting_started/using_file_store.md b/docs/getting_started/using_file_store.md index 3fbfa3385..ca46e2369 100644 --- a/docs/getting_started/using_file_store.md +++ b/docs/getting_started/using_file_store.md @@ -80,7 +80,7 @@ and for associating custom metadata (See ["Adding Metadata"](#adding-metadata) b ## Connecting and querying As with any `Store`, you have to `connect()` before you can query any data from a `FileStore`. After that, you can use `query_one()` to examine a single document or -`query()` to return an interator of matching documents. For example, let's print the +`query()` to return an iterator of matching documents. For example, let's print the parent directory of each of the files named "input.in" in our example `FileStore`: ```python @@ -142,7 +142,7 @@ fs.add_metadata({"name":"input.in"}, {"tags":["preliminary"]}) ### Automatic metadata -You can even define a function to automatically crate metadata from file or directory names. For example, if you prefix all your files with datestamps (e.g., '2022-05-07_experiment.csv'), you can write a simple string parsing function to +You can even define a function to automatically create metadata from file or directory names. For example, if you prefix all your files with datestamps (e.g., '2022-05-07_experiment.csv'), you can write a simple string parsing function to extract information from any key in a `FileStore` record and pass the function as an argument to `add_metadata`. For example, to extract the date from files named like '2022-05-07_experiment.csv' @@ -195,7 +195,7 @@ maggma.core.store.StoreError: (StoreError(...), 'Warning! This command is about Now that you can access your files on disk via a `FileStore`, it's time to write a `Builder` to read and process the data (see [Writing a Builder](simple_builder.md)). Keep in mind that `get_items` will return documents like the one shown in (#creating-the-filestore). You can then use `process_items` to -- Create strucured data from the `contents` +- Create structured data from the `contents` - Open the file for reading using a custom piece of code - etc. diff --git a/pyproject.toml b/pyproject.toml index 4287753e9..e432b502a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,7 +83,6 @@ isort.split-on-trailing-comma = false "src/maggma/api/*" = ["B008", "B021", "RET505", "RET506"] "tests/api/*" = ["B017", "B018"] "src/maggma/cli/*" = ["EXE001"] # triggered by ! 
at top of file -"src/maggma/api/utils.py" = ["I001"] # to allow unsorted import block [tool.pytest.ini_options] addopts = "--color=yes -p no:warnings --import-mode=importlib --durations=30" @@ -98,4 +97,5 @@ explicit_package_bases = true no_implicit_optional = false [tool.codespell] -ignore-words-list = "ot" +ignore-words-list = "ot,nin" +skip = 'docs/CHANGELOG.md,tests/test_files/*' diff --git a/src/maggma/__init__.py b/src/maggma/__init__.py index 00d0f41a5..6e44ae29b 100644 --- a/src/maggma/__init__.py +++ b/src/maggma/__init__.py @@ -1,4 +1,3 @@ -# coding: utf-8 """ Primary Maggma module """ from pkg_resources import DistributionNotFound, get_distribution diff --git a/src/maggma/api/API.py b/src/maggma/api/API.py index 51042ce29..ff2c82f21 100644 --- a/src/maggma/api/API.py +++ b/src/maggma/api/API.py @@ -23,8 +23,8 @@ def __init__( version: str = "v0.0.0", debug: bool = False, heartbeat_meta: Optional[Dict] = None, - description: str = None, - tags_meta: List[Dict] = None, + description: Optional[str] = None, + tags_meta: Optional[List[Dict]] = None, ): """ Args: @@ -33,7 +33,7 @@ def __init__( version: the version for this API debug: turns debug on in FastAPI heartbeat_meta: dictionary of additional metadata to include in the heartbeat response - description: decription of the API to be used in the generated docs + description: description of the API to be used in the generated docs tags_meta: descriptions of tags to be used in the generated docs """ self.title = title diff --git a/src/maggma/api/models.py b/src/maggma/api/models.py index 100b34ebe..1cef13c64 100644 --- a/src/maggma/api/models.py +++ b/src/maggma/api/models.py @@ -20,8 +20,7 @@ class Meta(BaseModel): api_version: str = Field( __version__, - description="a string containing the version of the Materials API " - "implementation, e.g. v0.9.5", + description="a string containing the version of the Materials API implementation, e.g. 
v0.9.5", ) time_stamp: datetime = Field( @@ -29,9 +28,7 @@ class Meta(BaseModel): default_factory=datetime.utcnow, ) - total_doc: Optional[int] = Field( - None, description="the total number of documents available for this query", ge=0 - ) + total_doc: Optional[int] = Field(None, description="the total number of documents available for this query", ge=0) class Config: extra = "allow" @@ -56,9 +53,7 @@ class Response(GenericModel, Generic[DataT]): """ data: Optional[List[DataT]] = Field(None, description="List of returned data") - errors: Optional[List[Error]] = Field( - None, description="Any errors on processing this query" - ) + errors: Optional[List[Error]] = Field(None, description="Any errors on processing this query") meta: Optional[Meta] = Field(None, description="Extra information for the query") @validator("errors", always=True) @@ -92,8 +87,6 @@ class S3URLDoc(BaseModel): description="Pre-signed download URL", ) - requested_datetime: datetime = Field( - ..., description="Datetime for when URL was requested" - ) + requested_datetime: datetime = Field(..., description="Datetime for when URL was requested") expiry_datetime: datetime = Field(..., description="Expiry datetime of the URL") diff --git a/src/maggma/api/query_operator/core.py b/src/maggma/api/query_operator/core.py index 49c221d07..46827d6c8 100644 --- a/src/maggma/api/query_operator/core.py +++ b/src/maggma/api/query_operator/core.py @@ -8,7 +8,7 @@ class QueryOperator(MSONable, metaclass=ABCMeta): """ - Base Query Operator class for defining powerfull query language + Base Query Operator class for defining powerful query language in the Materials API """ diff --git a/src/maggma/api/query_operator/dynamic.py b/src/maggma/api/query_operator/dynamic.py index 4d6b181a1..39c12d1fb 100644 --- a/src/maggma/api/query_operator/dynamic.py +++ b/src/maggma/api/query_operator/dynamic.py @@ -26,9 +26,7 @@ def __init__( self.excluded_fields = excluded_fields all_fields: Dict[str, ModelField] = model.__fields__ - param_fields = fields or list( - set(all_fields.keys()) - set(excluded_fields or []) - ) + param_fields = fields or list(set(all_fields.keys()) - set(excluded_fields or [])) # Convert the fields into operator tuples ops = [ @@ -49,9 +47,7 @@ def query(**kwargs) -> STORE_PARAMS: try: criteria.append(self.mapping[k](v)) except KeyError: - raise KeyError( - f"Cannot find key {k} in current query to database mapping" - ) + raise KeyError(f"Cannot find key {k} in current query to database mapping") final_crit = {} for entry in criteria: @@ -74,18 +70,15 @@ def query(**kwargs) -> STORE_PARAMS: for op in ops ] - setattr(query, "__signature__", inspect.Signature(signatures)) + query.__signature__ = inspect.Signature(signatures) self.query = query # type: ignore def query(self): "Stub query function for abstract class" - pass @abstractmethod - def field_to_operator( - self, name: str, field: ModelField - ) -> List[Tuple[str, Any, Query, Callable[..., Dict]]]: + def field_to_operator(self, name: str, field: ModelField) -> List[Tuple[str, Any, Query, Callable[..., Dict]]]: """ Converts a PyDantic ModelField into a Tuple with the - query param name, @@ -93,7 +86,6 @@ def field_to_operator( - FastAPI Query object, - and callable to convert the value into a query dict """ - pass @classmethod def from_dict(cls, d): @@ -115,9 +107,7 @@ def as_dict(self) -> Dict: class NumericQuery(DynamicQueryOperator): "Query Operator to enable searching on numeric fields" - def field_to_operator( - self, name: str, field: ModelField - ) -> List[Tuple[str, 
Any, Query, Callable[..., Dict]]]: + def field_to_operator(self, name: str, field: ModelField) -> List[Tuple[str, Any, Query, Callable[..., Dict]]]: """ Converts a PyDantic ModelField into a Tuple with the query_param name, @@ -181,11 +171,7 @@ def field_to_operator( default=None, description=f"Query for {title} being any of these values. Provide a comma separated list.", ), - lambda val: { - f"{field.name}": { - "$in": [int(entry.strip()) for entry in val.split(",")] - } - }, + lambda val: {f"{field.name}": {"$in": [int(entry.strip()) for entry in val.split(",")]}}, ), ( f"{field.name}_neq_any", @@ -195,11 +181,7 @@ def field_to_operator( description=f"Query for {title} being not any of these values. \ Provide a comma separated list.", ), - lambda val: { - f"{field.name}": { - "$nin": [int(entry.strip()) for entry in val.split(",")] - } - }, + lambda val: {f"{field.name}": {"$nin": [int(entry.strip()) for entry in val.split(",")]}}, ), ] ) @@ -210,9 +192,7 @@ def field_to_operator( class StringQueryOperator(DynamicQueryOperator): "Query Operator to enable searching on numeric fields" - def field_to_operator( - self, name: str, field: ModelField - ) -> List[Tuple[str, Any, Query, Callable[..., Dict]]]: + def field_to_operator(self, name: str, field: ModelField) -> List[Tuple[str, Any, Query, Callable[..., Dict]]]: """ Converts a PyDantic ModelField into a Tuple with the query_param name, @@ -253,11 +233,7 @@ def field_to_operator( default=None, description=f"Query for {title} being any of these values. Provide a comma separated list.", ), - lambda val: { - f"{field.name}": { - "$in": [entry.strip() for entry in val.split(",")] - } - }, + lambda val: {f"{field.name}": {"$in": [entry.strip() for entry in val.split(",")]}}, ), ( f"{field.name}_neq_any", @@ -266,11 +242,7 @@ def field_to_operator( default=None, description=f"Query for {title} being not any of these values. Provide a comma separated list", ), - lambda val: { - f"{field.name}": { - "$nin": [entry.strip() for entry in val.split(",")] - } - }, + lambda val: {f"{field.name}": {"$nin": [entry.strip() for entry in val.split(",")]}}, ), ] diff --git a/src/maggma/api/query_operator/pagination.py b/src/maggma/api/query_operator/pagination.py index 14834c136..d6b2151e6 100644 --- a/src/maggma/api/query_operator/pagination.py +++ b/src/maggma/api/query_operator/pagination.py @@ -35,8 +35,7 @@ def query( ), _limit: int = Query( default_limit, - description="Max number of entries to return in a single query." - f" Limited to {max_limit}.", + description=f"Max number of entries to return in a single query. 
Limited to {max_limit}.", ), ) -> STORE_PARAMS: """ @@ -82,7 +81,6 @@ def query( def query(self): "Stub query function for abstract class" - pass def meta(self) -> Dict: """ diff --git a/src/maggma/api/query_operator/sparse_fields.py b/src/maggma/api/query_operator/sparse_fields.py index 297143785..3e55cec1d 100644 --- a/src/maggma/api/query_operator/sparse_fields.py +++ b/src/maggma/api/query_operator/sparse_fields.py @@ -9,9 +9,7 @@ class SparseFieldsQuery(QueryOperator): - def __init__( - self, model: Type[BaseModel], default_fields: Optional[List[str]] = None - ): + def __init__(self, model: Type[BaseModel], default_fields: Optional[List[str]] = None): """ Args: model: PyDantic Model that represents the underlying data source @@ -23,14 +21,12 @@ def __init__( model_name = self.model.__name__ # type: ignore model_fields = list(self.model.__fields__.keys()) - self.default_fields = ( - model_fields if default_fields is None else list(default_fields) - ) + self.default_fields = model_fields if default_fields is None else list(default_fields) def query( _fields: str = Query( None, - description=f"Fields to project from {str(model_name)} as a list of comma seperated strings.\ + description=f"Fields to project from {model_name!s} as a list of comma separated strings.\ Fields include: `{'` `'.join(model_fields)}`", ), _all_fields: bool = Query(False, description="Include all fields."), @@ -39,9 +35,7 @@ def query( Pagination parameters for the API Endpoint """ - properties = ( - _fields.split(",") if isinstance(_fields, str) else self.default_fields - ) + properties = _fields.split(",") if isinstance(_fields, str) else self.default_fields if _all_fields: properties = model_fields @@ -51,7 +45,6 @@ def query( def query(self): "Stub query function for abstract class" - pass def meta(self) -> Dict: """ @@ -77,9 +70,7 @@ def from_dict(cls, d): if isinstance(model, str): model = dynamic_import(model) - assert issubclass( - model, BaseModel - ), "The resource model has to be a PyDantic Model" + assert issubclass(model, BaseModel), "The resource model has to be a PyDantic Model" d["model"] = model return cls(**d) diff --git a/src/maggma/api/query_operator/submission.py b/src/maggma/api/query_operator/submission.py index 759a1a87b..66c4d0104 100644 --- a/src/maggma/api/query_operator/submission.py +++ b/src/maggma/api/query_operator/submission.py @@ -16,9 +16,7 @@ def __init__(self, status_enum): self.status_enum = status_enum def query( - state: Optional[status_enum] = Query( - None, description="Latest status of the submission" - ), + state: Optional[status_enum] = Query(None, description="Latest status of the submission"), last_updated: Optional[datetime] = Query( None, description="Minimum datetime of status update for submission", @@ -31,11 +29,7 @@ def query( crit.update(s_dict) if last_updated: - l_dict = { - "$expr": { - "$gt": [{"$arrayElemAt": ["$last_updated", -1]}, last_updated] - } - } + l_dict = {"$expr": {"$gt": [{"$arrayElemAt": ["$last_updated", -1]}, last_updated]}} crit.update(l_dict) if state and last_updated: @@ -47,4 +41,3 @@ def query( def query(self): "Stub query function for abstract class" - pass diff --git a/src/maggma/api/resource/__init__.py b/src/maggma/api/resource/__init__.py index aafd982b1..aef4513ba 100644 --- a/src/maggma/api/resource/__init__.py +++ b/src/maggma/api/resource/__init__.py @@ -1,6 +1,7 @@ -from maggma.api.resource.core import Resource -from maggma.api.resource.core import HintScheme -from maggma.api.resource.core import HeaderProcessor +# isort: off 
+from maggma.api.resource.core import HeaderProcessor, HintScheme, Resource + +# isort: on from maggma.api.resource.aggregation import AggregationResource from maggma.api.resource.post_resource import PostOnlyResource diff --git a/src/maggma/api/resource/core.py b/src/maggma/api/resource/core.py index 00de1c0a2..5ff5af18a 100644 --- a/src/maggma/api/resource/core.py +++ b/src/maggma/api/resource/core.py @@ -38,7 +38,6 @@ def on_startup(self): """ Callback to perform some work on resource initialization """ - pass @abstractmethod def prepare_endpoint(self): @@ -46,7 +45,6 @@ def prepare_endpoint(self): Internal method to prepare the endpoint by setting up default handlers for routes. """ - pass def setup_redirect(self): @self.router.get("$", include_in_schema=False) diff --git a/src/maggma/api/resource/post_resource.py b/src/maggma/api/resource/post_resource.py index 356447e00..713888843 100644 --- a/src/maggma/api/resource/post_resource.py +++ b/src/maggma/api/resource/post_resource.py @@ -86,20 +86,14 @@ def search(**queries: Dict[str, STORE_PARAMS]) -> Dict: queries.pop("temp_response") # type: ignore query_params = [ - entry - for _, i in enumerate(self.query_operators) - for entry in signature(i.query).parameters + entry for _, i in enumerate(self.query_operators) for entry in signature(i.query).parameters ] - overlap = [ - key for key in request.query_params.keys() if key not in query_params - ] + overlap = [key for key in request.query_params if key not in query_params] if any(overlap): raise HTTPException( status_code=400, - detail="Request contains query parameters which cannot be used: {}".format( - ", ".join(overlap) - ), + detail="Request contains query parameters which cannot be used: {}".format(", ".join(overlap)), ) query: Dict[Any, Any] = merge_queries(list(queries.values())) # type: ignore @@ -110,11 +104,7 @@ def search(**queries: Dict[str, STORE_PARAMS]) -> Dict: try: with query_timeout(self.timeout): count = self.store.count( # type: ignore - **{ - field: query[field] - for field in query - if field in ["criteria", "hint"] - } + **{field: query[field] for field in query if field in ["criteria", "hint"]} ) if isinstance(self.store, S3Store): @@ -125,11 +115,7 @@ def search(**queries: Dict[str, STORE_PARAMS]) -> Dict: data = list( self.store._collection.aggregate( pipeline, - **{ - field: query[field] - for field in query - if field in ["hint"] - }, + **{field: query[field] for field in query if field in ["hint"]}, ) ) except (NetworkTimeout, PyMongoError) as e: @@ -152,8 +138,7 @@ def search(**queries: Dict[str, STORE_PARAMS]) -> Dict: operator_meta.update(operator.meta()) meta = Meta(total_doc=count) - response = {"data": data, "meta": {**meta.dict(), **operator_meta}} - return response + return {"data": data, "meta": {**meta.dict(), **operator_meta}} self.router.post( self.sub_path, diff --git a/src/maggma/api/resource/read_resource.py b/src/maggma/api/resource/read_resource.py index eb5fb522e..d331b7de5 100644 --- a/src/maggma/api/resource/read_resource.py +++ b/src/maggma/api/resource/read_resource.py @@ -112,9 +112,7 @@ def build_get_by_key(self): model_name = self.model.__name__ if self.key_fields is None: - field_input = SparseFieldsQuery( - self.model, [self.store.key, self.store.last_updated_field] - ).query + field_input = SparseFieldsQuery(self.model, [self.store.key, self.store.last_updated_field]).query else: def field_input(): @@ -131,7 +129,7 @@ def get_by_key( _fields: STORE_PARAMS = Depends(field_input), ): f""" - Get's a document by the primary key in 
the store + Gets a document by the primary key in the store Args: {key_name}: the id of a single {model_name} @@ -197,14 +195,10 @@ def search(**queries: Dict[str, STORE_PARAMS]) -> Union[Dict, Response]: temp_response: Response = queries.pop("temp_response") # type: ignore query_params = [ - entry - for _, i in enumerate(self.query_operators) - for entry in signature(i.query).parameters + entry for _, i in enumerate(self.query_operators) for entry in signature(i.query).parameters ] - overlap = [ - key for key in request.query_params.keys() if key not in query_params - ] + overlap = [key for key in request.query_params if key not in query_params] if any(overlap): if "limit" in overlap or "skip" in overlap: raise HTTPException( @@ -216,9 +210,7 @@ def search(**queries: Dict[str, STORE_PARAMS]) -> Union[Dict, Response]: else: raise HTTPException( status_code=400, - detail="Request contains query parameters which cannot be used: {}".format( - ", ".join(overlap) - ), + detail="Request contains query parameters which cannot be used: {}".format(", ".join(overlap)), ) query: Dict[Any, Any] = merge_queries(list(queries.values())) # type: ignore @@ -232,11 +224,7 @@ def search(**queries: Dict[str, STORE_PARAMS]) -> Union[Dict, Response]: try: with query_timeout(self.timeout): count = self.store.count( # type: ignore - **{ - field: query[field] - for field in query - if field in ["criteria", "hint"] - } + **{field: query[field] for field in query if field in ["criteria", "hint"]} ) if isinstance(self.store, S3Store): @@ -250,11 +238,7 @@ def search(**queries: Dict[str, STORE_PARAMS]) -> Union[Dict, Response]: data = list( self.store._collection.aggregate( pipeline, - **{ - field: query[field] - for field in query - if field in ["hint"] - }, + **{field: query[field] for field in query if field in ["hint"]}, ) ) diff --git a/src/maggma/api/resource/s3_url.py b/src/maggma/api/resource/s3_url.py index c3b838df4..19c6bbc4e 100644 --- a/src/maggma/api/resource/s3_url.py +++ b/src/maggma/api/resource/s3_url.py @@ -73,7 +73,7 @@ def get_by_key( ), ): f""" - Get's a document by the primary key in the store + Gets a document by the primary key in the store Args: {key_name}: the id of a single {model_name} @@ -92,9 +92,7 @@ def get_by_key( except ClientError: raise HTTPException( status_code=404, - detail="No object found for {} = {}".format( - self.store.key, key.split("/")[-1] - ), + detail="No object found for {} = {}".format(self.store.key, key.split("/")[-1]), ) # Get URL @@ -107,9 +105,7 @@ def get_by_key( except Exception: raise HTTPException( status_code=404, - detail="Problem obtaining URL for {} = {}".format( - self.store.key, key.split("/")[-1] - ), + detail="Problem obtaining URL for {} = {}".format(self.store.key, key.split("/")[-1]), ) requested_datetime = datetime.utcnow() @@ -124,9 +120,7 @@ def get_by_key( response = {"data": [item.dict()]} # type: ignore if self.disable_validation: - response = Response( # type: ignore - orjson.dumps(response, default=serialization_helper) - ) + response = Response(orjson.dumps(response, default=serialization_helper)) # type: ignore if self.header_processor is not None: self.header_processor.process_header(temp_response, request) diff --git a/src/maggma/api/resource/submission.py b/src/maggma/api/resource/submission.py index b70b2f0d5..14ace3901 100644 --- a/src/maggma/api/resource/submission.py +++ b/src/maggma/api/resource/submission.py @@ -66,9 +66,7 @@ def __init__( """ if isinstance(state_enum, Enum) and default_state not in [entry.value for entry in 
state_enum]: # type: ignore - raise RuntimeError( - "If data is stateful a state enum and valid default value must be provided" - ) + raise RuntimeError("If data is stateful a state enum and valid default value must be provided") self.state_enum = state_enum self.default_state = default_state @@ -177,9 +175,7 @@ def get_by_key( for operator in self.get_query_operators: # type: ignore item = operator.post_process(item, {}) - response = {"data": item} - - return response + return {"data": item} self.router.get( f"{self.get_sub_path}{{{key_name}}}/", @@ -205,15 +201,11 @@ def search(**queries: STORE_PARAMS): for entry in signature(i.query).parameters ] - overlap = [ - key for key in request.query_params.keys() if key not in query_params - ] + overlap = [key for key in request.query_params if key not in query_params] if any(overlap): raise HTTPException( status_code=404, - detail="Request contains query parameters which cannot be used: {}".format( - ", ".join(overlap) - ), + detail="Request contains query parameters which cannot be used: {}".format(", ".join(overlap)), ) self.store.connect(force_reset=True) @@ -221,11 +213,7 @@ def search(**queries: STORE_PARAMS): try: with query_timeout(self.timeout): count = self.store.count( # type: ignore - **{ - field: query[field] - for field in query - if field in ["criteria", "hint"] - } + **{field: query[field] for field in query if field in ["criteria", "hint"]} ) if isinstance(self.store, S3Store): data = list(self.store.query(**query)) # type: ignore @@ -235,11 +223,7 @@ def search(**queries: STORE_PARAMS): data = list( self.store._collection.aggregate( pipeline, - **{ - field: query[field] - for field in query - if field in ["hint"] - }, + **{field: query[field] for field in query if field in ["hint"]}, ) ) except (NetworkTimeout, PyMongoError) as e: @@ -260,9 +244,7 @@ def search(**queries: STORE_PARAMS): for operator in self.get_query_operators: # type: ignore data = operator.post_process(data, query) - response = {"data": data, "meta": meta.dict()} - - return response + return {"data": data, "meta": meta.dict()} self.router.get( self.get_sub_path, @@ -289,15 +271,11 @@ def post_data(**queries: STORE_PARAMS): for entry in signature(i.query).parameters ] - overlap = [ - key for key in request.query_params.keys() if key not in query_params - ] + overlap = [key for key in request.query_params if key not in query_params] if any(overlap): raise HTTPException( status_code=404, - detail="Request contains query parameters which cannot be used: {}".format( - ", ".join(overlap) - ), + detail="Request contains query parameters which cannot be used: {}".format(", ".join(overlap)), ) self.store.connect(force_reset=True) @@ -305,10 +283,7 @@ def post_data(**queries: STORE_PARAMS): # Check for duplicate entry if self.duplicate_fields_check: duplicate = self.store.query_one( - criteria={ - field: query["criteria"][field] - for field in self.duplicate_fields_check - } + criteria={field: query["criteria"][field] for field in self.duplicate_fields_check} ) if duplicate: @@ -334,13 +309,11 @@ def post_data(**queries: STORE_PARAMS): detail="Problem when trying to post data.", ) - response = { + return { "data": query["criteria"], "meta": "Submission successful", } - return response - self.router.post( self.post_sub_path, tags=self.tags, @@ -366,15 +339,11 @@ def patch_data(**queries: STORE_PARAMS): for entry in signature(i.query).parameters ] - overlap = [ - key for key in request.query_params.keys() if key not in query_params - ] + overlap = [key for key in 
request.query_params if key not in query_params] if any(overlap): raise HTTPException( status_code=404, - detail="Request contains query parameters which cannot be used: {}".format( - ", ".join(overlap) - ), + detail="Request contains query parameters which cannot be used: {}".format(", ".join(overlap)), ) self.store.connect(force_reset=True) @@ -382,10 +351,7 @@ def patch_data(**queries: STORE_PARAMS): # Check for duplicate entry if self.duplicate_fields_check: duplicate = self.store.query_one( - criteria={ - field: query["criteria"][field] - for field in self.duplicate_fields_check - } + criteria={field: query["criteria"][field] for field in self.duplicate_fields_check} ) if duplicate: @@ -416,13 +382,11 @@ def patch_data(**queries: STORE_PARAMS): detail="Problem when trying to patch data.", ) - response = { + return { "data": query["update"], "meta": "Submission successful", } - return response - self.router.patch( self.patch_sub_path, tags=self.tags, diff --git a/src/maggma/api/resource/utils.py b/src/maggma/api/resource/utils.py index f012ea5cb..19a9bd115 100644 --- a/src/maggma/api/resource/utils.py +++ b/src/maggma/api/resource/utils.py @@ -46,9 +46,7 @@ def generate_query_pipeline(query: dict, store: Store): if sorting: sort_dict = {"$sort": {}} # type: dict sort_dict["$sort"].update(query["sort"]) - sort_dict["$sort"].update( - {store.key: 1} - ) # Ensures sort by key is last in dict to fix determinacy + sort_dict["$sort"].update({store.key: 1}) # Ensures sort by key is last in dict to fix determinacy projection_dict = {"_id": 0} # Do not return _id by default diff --git a/src/maggma/api/utils.py b/src/maggma/api/utils.py index aae16155a..03531d8aa 100644 --- a/src/maggma/api/utils.py +++ b/src/maggma/api/utils.py @@ -48,12 +48,7 @@ def merge_queries(queries: List[STORE_PARAMS]) -> STORE_PARAMS: if "properties" in sub_query: properties.extend(sub_query["properties"]) - remainder = { - k: v - for query in queries - for k, v in query.items() - if k not in ["criteria", "properties"] - } + remainder = {k: v for query in queries for k, v in query.items() if k not in ["criteria", "properties"]} return { "criteria": criteria, @@ -79,7 +74,7 @@ def attach_signature(function: Callable, defaults: Dict, annotations: Dict): default=defaults.get(param, None), annotation=annotations.get(param, None), ) - for param in annotations.keys() + for param in annotations if param not in defaults.keys() ] @@ -90,12 +85,10 @@ def attach_signature(function: Callable, defaults: Dict, annotations: Dict): default=defaults.get(param, None), annotation=annotations.get(param, None), ) - for param in defaults.keys() + for param in defaults ] - setattr( - function, "__signature__", inspect.Signature(required_params + optional_params) - ) + function.__signature__ = inspect.Signature(required_params + optional_params) def api_sanitize( @@ -114,11 +107,9 @@ def api_sanitize( fields_to_leave: list of strings for model fields as "model__name__.field" """ - models = [ - model - for model in get_flat_models_from_model(pydantic_model) - if issubclass(model, BaseModel) - ] # type: List[Type[BaseModel]] + models: List[Type[BaseModel]] = [ + model for model in get_flat_models_from_model(pydantic_model) if issubclass(model, BaseModel) + ] fields_to_leave = fields_to_leave or [] fields_tuples = [f.split(".") for f in fields_to_leave] @@ -170,15 +161,13 @@ def validate_monty(cls, v): errors.append("@class") if len(errors) > 0: - raise ValueError( - "Missing Monty seriailzation fields in dictionary: {errors}" - ) + raise 
ValueError("Missing Monty seriailzation fields in dictionary: {errors}") return v else: raise ValueError(f"Must provide {cls.__name__} or MSONable dictionary") - setattr(monty_cls, "validate_monty", classmethod(validate_monty)) + monty_cls.validate_monty = classmethod(validate_monty) return monty_cls diff --git a/src/maggma/builders/group_builder.py b/src/maggma/builders/group_builder.py index a530462a8..01a6f1b25 100644 --- a/src/maggma/builders/group_builder.py +++ b/src/maggma/builders/group_builder.py @@ -1,4 +1,3 @@ -# coding: utf-8 """ Many-to-Many GroupBuilder """ @@ -69,7 +68,7 @@ def __init__( def ensure_indexes(self): """ - Ensures indicies on critical fields for GroupBuilder + Ensures indices on critical fields for GroupBuilder which include the plural version of the target's key field """ index_checks = [ @@ -105,32 +104,26 @@ def prechunk(self, number_splits: int) -> Iterator[Dict]: yield {"query": dict(zip(self.grouping_keys, split))} def get_items(self): - self.logger.info("Starting {} Builder".format(self.__class__.__name__)) + self.logger.info(f"Starting {self.__class__.__name__} Builder") self.ensure_indexes() keys = self.get_ids_to_process() groups = self.get_groups_from_keys(keys) if self.projection: - projection = list( - set(self.projection + [self.source.key, self.source.last_updated_field]) - ) + projection = list({*self.projection, self.source.key, self.source.last_updated_field}) else: projection = None self.total = len(groups) for group in groups: - docs = list( - self.source.query( - criteria=dict(zip(self.grouping_keys, group)), properties=projection - ) - ) + docs = list(self.source.query(criteria=dict(zip(self.grouping_keys, group)), properties=projection)) yield docs def process_item(self, item: List[Dict]) -> Dict[Tuple, Dict]: # type: ignore - keys = list(d[self.source.key] for d in item) + keys = [d[self.source.key] for d in item] - self.logger.debug("Processing: {}".format(keys)) + self.logger.debug(f"Processing: {keys}") time_start = time() @@ -144,9 +137,7 @@ def process_item(self, item: List[Dict]) -> Dict[Tuple, Dict]: # type: ignore time_end = time() - last_updated = [ - self.source._lu_func[0](d[self.source.last_updated_field]) for d in item - ] + last_updated = [self.source._lu_func[0](d[self.source.last_updated_field]) for d in item] update_doc = { self.target.key: keys[0], @@ -194,11 +185,9 @@ def get_ids_to_process(self) -> Iterable: query = self.query or {} - distinct_from_target = list( - self.target.distinct(self._target_keys_field, criteria=query) - ) + distinct_from_target = list(self.target.distinct(self._target_keys_field, criteria=query)) processed_ids = [] - # Not always gauranteed that MongoDB will unpack the list so we + # Not always guaranteed that MongoDB will unpack the list so we # have to make sure we do that for d in distinct_from_target: if isinstance(d, list): @@ -210,9 +199,7 @@ def get_ids_to_process(self) -> Iterable: self.logger.debug(f"Found {len(all_ids)} total docs in source") if self.retry_failed: - failed_keys = self.target.distinct( - self._target_keys_field, criteria={"state": "failed", **query} - ) + failed_keys = self.target.distinct(self._target_keys_field, criteria={"state": "failed", **query}) unprocessed_ids = all_ids - (set(processed_ids) - set(failed_keys)) self.logger.debug(f"Found {len(failed_keys)} failed IDs in target") else: @@ -220,9 +207,7 @@ def get_ids_to_process(self) -> Iterable: self.logger.info(f"Found {len(unprocessed_ids)} IDs to process") - new_ids = set( - self.source.newer_in(self.target, 
criteria=query, exhaustive=False) - ) + new_ids = set(self.source.newer_in(self.target, criteria=query, exhaustive=False)) self.logger.info(f"Found {len(new_ids)} updated IDs to process") return list(new_ids | unprocessed_ids) @@ -244,9 +229,7 @@ def get_groups_from_keys(self, keys) -> Set[Tuple]: ) ) - sub_groups = set( - tuple(get(d, prop, None) for prop in grouping_keys) for d in docs - ) + sub_groups = {tuple(get(d, prop, None) for prop in grouping_keys) for d in docs} self.logger.debug(f"Found {len(sub_groups)} subgroups to process") groups |= sub_groups diff --git a/src/maggma/builders/map_builder.py b/src/maggma/builders/map_builder.py index 89d09652e..0ea1fd8d6 100644 --- a/src/maggma/builders/map_builder.py +++ b/src/maggma/builders/map_builder.py @@ -1,4 +1,3 @@ -# coding: utf-8 """ One-to-One Map Builder and a simple CopyBuilder implementation """ @@ -66,7 +65,7 @@ def __init__( def ensure_indexes(self): """ - Ensures indicies on critical fields for MapBuilder + Ensures indices on critical fields for MapBuilder """ index_checks = [ self.source.ensure_index(self.source.key), @@ -103,7 +102,7 @@ def get_items(self): incremental building """ - self.logger.info("Starting {} Builder".format(self.__class__.__name__)) + self.logger.info(f"Starting {self.__class__.__name__} Builder") self.ensure_indexes() @@ -116,12 +115,10 @@ def get_items(self): failed_keys = self.target.distinct(self.target.key, criteria=failed_query) keys = list(set(keys + failed_keys)) - self.logger.info("Processing {} items".format(len(keys))) + self.logger.info(f"Processing {len(keys)} items") if self.projection: - projection = list( - set(self.projection + [self.source.key, self.source.last_updated_field]) - ) + projection = list({*self.projection, self.source.key, self.source.last_updated_field}) else: projection = None @@ -142,7 +139,7 @@ def process_item(self, item: Dict): a map function """ - self.logger.debug("Processing: {}".format(item[self.source.key])) + self.logger.debug(f"Processing: {item[self.source.key]}") time_start = time() @@ -165,9 +162,7 @@ def process_item(self, item: Dict): out = { self.target.key: item[key], - self.target.last_updated_field: self.source._lu_func[0]( - item.get(last_updated_field, datetime.utcnow()) - ), + self.target.last_updated_field: self.source._lu_func[0](item.get(last_updated_field, datetime.utcnow())), } if self.store_process_time: @@ -198,9 +193,7 @@ def finalize(self): target_keyvals = set(self.target.distinct(self.target.key)) to_delete = list(target_keyvals - source_keyvals) if len(to_delete): - self.logger.info( - "Finalize: Deleting {} orphans.".format(len(to_delete)) - ) + self.logger.info(f"Finalize: Deleting {len(to_delete)} orphans.") self.target.remove_docs({self.target.key: {"$in": to_delete}}) super().finalize() @@ -214,7 +207,6 @@ def unary_function(self, item): process_item and logged to the "error" field in the target document. 
""" - pass class CopyBuilder(MapBuilder): diff --git a/src/maggma/builders/projection_builder.py b/src/maggma/builders/projection_builder.py index 9348f2f1f..3659d3a47 100644 --- a/src/maggma/builders/projection_builder.py +++ b/src/maggma/builders/projection_builder.py @@ -1,7 +1,7 @@ from copy import deepcopy from datetime import datetime from itertools import chain -from typing import Dict, Iterable, List, Union +from typing import Dict, Iterable, List, Optional, Union from pydash import get @@ -30,8 +30,8 @@ def __init__( source_stores: List[Store], target_store: Store, fields_to_project: Union[List[Union[List, Dict]], None] = None, - query_by_key: List = None, - **kwargs + query_by_key: Optional[List] = None, + **kwargs, ): """ Args: @@ -68,13 +68,9 @@ def __init__( raise TypeError("Input source_stores must be provided in a list") if isinstance(fields_to_project, list): if len(source_stores) != len(fields_to_project): - raise ValueError( - "There must be an equal number of elements in source_stores and fields_to_project" - ) + raise ValueError("There must be an equal number of elements in source_stores and fields_to_project") elif fields_to_project is not None: - raise TypeError( - "Input fields_to_project must be a list. E.g. [['str1','str2'],{'A':'str1','B':str2'}]" - ) + raise TypeError("Input fields_to_project must be a list. E.g. [['str1','str2'],{'A':'str1','B':str2'}]") # interpret fields_to_project to create projection_mapping attribute projection_mapping: List[Dict] # PEP 484 Type Hinting @@ -122,7 +118,7 @@ def get_items(self) -> Iterable: Returns: generator of items to process """ - self.logger.info("Starting {} get_items...".format(self.__class__.__name__)) + self.logger.info(f"Starting {self.__class__.__name__} get_items...") # get distinct key values if len(self.query_by_key) > 0: @@ -134,19 +130,17 @@ def get_items(self) -> Iterable: unique_keys.update(store_keys) if None in store_keys: self.logger.debug( - "None found as a key value for store {} with key {}".format( - store.collection_name, store.key - ) + f"None found as a key value for store {store.collection_name} with key {store.key}" ) keys = list(unique_keys) - self.logger.info("{} distinct key values found".format(len(keys))) - self.logger.debug("None found in key values? {}".format(None in keys)) + self.logger.info(f"{len(keys)} distinct key values found") + self.logger.debug(f"None found in key values? 
{None in keys}") # for every key (in chunks), query from each store and # project fields specified by projection_mapping for chunked_keys in grouper(keys, self.chunk_size): chunked_keys = [k for k in chunked_keys if k is not None] - self.logger.debug("Querying by chunked_keys: {}".format(chunked_keys)) + self.logger.debug(f"Querying by chunked_keys: {chunked_keys}") unsorted_items_to_process = [] for store, projection in zip(self.sources, self.projection_mapping): @@ -156,25 +150,15 @@ def get_items(self) -> Iterable: properties: Union[List, None] if projection == {}: # all fields are projected properties = None - self.logger.debug( - "For store {} getting all properties".format( - store.collection_name - ) - ) + self.logger.debug(f"For store {store.collection_name} getting all properties") else: # only specified fields are projected - properties = [v for v in projection.values()] - self.logger.debug( - "For {} store getting properties: {}".format( - store.collection_name, properties - ) - ) + properties = list(projection.values()) + self.logger.debug(f"For {store.collection_name} store getting properties: {properties}") # get docs from store for given chunk of key values, # rename fields if specified by projection mapping, # and put in list of unsorted items to be processed - docs = store.query( - criteria={store.key: {"$in": chunked_keys}}, properties=properties - ) + docs = store.query(criteria={store.key: {"$in": chunked_keys}}, properties=properties) for d in docs: if properties is None: # all fields are projected as is item = deepcopy(d) @@ -187,7 +171,7 @@ def get_items(self) -> Iterable: # key value stored under target_key is used for sorting # items during the process_items step for k in ["_id", store.last_updated_field]: - if k in item.keys(): + if k in item: del item[k] item[self.target.key] = d[store.key] @@ -225,8 +209,8 @@ def process_item(self, items: Union[List, Iterable]) -> List[Dict]: items_for_target = [] for k, i_sorted in items_sorted_by_key.items(): - self.logger.debug("Combined items for {}: {}".format(key, k)) - target_doc = {} # type: Dict + self.logger.debug(f"Combined items for {key}: {k}") + target_doc: Dict = {} for i in i_sorted: target_doc.update(i) # last modification is adding key value avoid overwriting @@ -248,7 +232,7 @@ def update_targets(self, items: List): """ items = list(filter(None, chain.from_iterable(items))) num_items = len(items) - self.logger.info("Updating target with {} items...".format(num_items)) + self.logger.info(f"Updating target with {num_items} items...") target = self.target target_insertion_time = datetime.utcnow() diff --git a/src/maggma/cli/__init__.py b/src/maggma/cli/__init__.py index 12d7504e6..ef87e2b3e 100644 --- a/src/maggma/cli/__init__.py +++ b/src/maggma/cli/__init__.py @@ -48,17 +48,14 @@ help="Store in JSON/YAML form to send reporting data to", type=click.Path(exists=True), ) -@click.option( - "-u", "--url", "url", default=None, type=str, help="URL for the distributed manager" -) +@click.option("-u", "--url", "url", default=None, type=str, help="URL for the distributed manager") @click.option( "-p", "--port", "port", default=None, type=int, - help="Port for distributed communication." - " mrun will find an open port if None is provided to the manager", + help="Port for distributed communication. 
mrun will find an open port if None is provided to the manager", ) @click.option( "-N", @@ -76,12 +73,8 @@ type=int, help="Number of distributed workers to process chunks", ) -@click.option( - "--no_bars", is_flag=True, help="Turns of Progress Bars for headless operations" -) -@click.option( - "--rabbitmq", is_flag=True, help="Enables the use of RabbitMQ as the work broker" -) +@click.option("--no_bars", is_flag=True, help="Turns of Progress Bars for headless operations") +@click.option("--rabbitmq", is_flag=True, help="Enables the use of RabbitMQ as the work broker") @click.option( "-q", "--queue_prefix", @@ -138,9 +131,7 @@ def run( memray_file = f"{memray_dir}/{builders[0]}_{datetime.now().isoformat()}.bin" else: - memray_file = ( - f"{settings.TEMP_DIR}/{builders[0]}_{datetime.now().isoformat()}.bin" - ) + memray_file = f"{settings.TEMP_DIR}/{builders[0]}_{datetime.now().isoformat()}.bin" if num_processes > 1: follow_fork = True @@ -167,9 +158,7 @@ def run( root = logging.getLogger() root.setLevel(level) ch = TqdmLoggingHandler() - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") ch.setFormatter(formatter) root.addHandler(ch) @@ -232,9 +221,7 @@ def run( else: loop = asyncio.get_event_loop() for builder in builder_objects: - loop.run_until_complete( - multi(builder=builder, num_processes=num_processes, no_bars=no_bars) - ) + loop.run_until_complete(multi(builder=builder, num_processes=num_processes, no_bars=no_bars)) if memray_file: import subprocess diff --git a/src/maggma/cli/distributed.py b/src/maggma/cli/distributed.py index 20408de8f..ceb098f42 100644 --- a/src/maggma/cli/distributed.py +++ b/src/maggma/cli/distributed.py @@ -28,9 +28,7 @@ def find_port(): return sock.getsockname()[1] -def manager( # noqa: C901 - url: str, port: int, builders: List[Builder], num_chunks: int, num_workers: int -): +def manager(url: str, port: int, builders: List[Builder], num_chunks: int, num_workers: int): """ Really simple manager for distributed processing that uses a builder prechunk to modify the builder and send out modified builders for each worker to run. @@ -58,27 +56,22 @@ def manager( # noqa: C901 try: builder.connect() - chunk_dicts = [ - {"chunk": d, "distributed": False, "completed": False} - for d in builder.prechunk(num_chunks) - ] + chunk_dicts = [{"chunk": d, "distributed": False, "completed": False} for d in builder.prechunk(num_chunks)] pbar_distributed = tqdm( total=len(chunk_dicts), - desc="Distributed chunks for {}".format(builder.__class__.__name__), + desc=f"Distributed chunks for {builder.__class__.__name__}", ) pbar_completed = tqdm( total=len(chunk_dicts), - desc="Completed chunks for {}".format(builder.__class__.__name__), + desc=f"Completed chunks for {builder.__class__.__name__}", ) logger.info(f"Distributing {len(chunk_dicts)} chunks to workers") except NotImplementedError: attempt_graceful_shutdown(workers, socket) - raise RuntimeError( - f"Can't distribute process {builder.__class__.__name__} as no prechunk method exists." 
- ) + raise RuntimeError(f"Can't distribute process {builder.__class__.__name__} as no prechunk method exists.") completed = False @@ -117,9 +110,7 @@ def manager( # noqa: C901 # If everything is distributed, send EXIT to the worker if all(chunk["distributed"] for chunk in chunk_dicts): - logger.debug( - f"Sending exit signal to worker: {msg.split('_')[1]}" - ) + logger.debug(f"Sending exit signal to worker: {msg.split('_')[1]}") socket.send_multipart([identity, b"", b"EXIT"]) workers.pop(identity) @@ -127,9 +118,7 @@ def manager( # noqa: C901 # Remove worker and requeue work sent to it attempt_graceful_shutdown(workers, socket) raise RuntimeError( - "At least one worker has stopped with error message: {}".format( - msg.split("_")[1] - ) + "At least one worker has stopped with error message: {}".format(msg.split("_")[1]) ) elif msg == "PING": @@ -192,14 +181,14 @@ def attempt_graceful_shutdown(workers, socket): def handle_dead_workers(workers, socket): if len(workers) == 1: # Use global timeout - identity = list(workers.keys())[0] + identity = next(iter(workers.keys())) if (perf_counter() - workers[identity]["last_ping"]) >= settings.WORKER_TIMEOUT: attempt_graceful_shutdown(workers, socket) raise RuntimeError("Worker has timed out. Stopping distributed build.") elif len(workers) == 2: # Use 10% ratio between workers - workers_sorted = sorted(list(workers.items()), key=lambda x: x[1]["heartbeats"]) + workers_sorted = sorted(workers.items(), key=lambda x: x[1]["heartbeats"]) ratio = workers_sorted[1][1]["heartbeats"] / workers_sorted[0][1]["heartbeats"] @@ -217,9 +206,7 @@ def handle_dead_workers(workers, socket): z_score = 0.6745 * (workers[identity]["heartbeats"] - median) / mad if z_score <= -3.5: attempt_graceful_shutdown(workers, socket) - raise RuntimeError( - "At least one worker has timed out. Stopping distributed build." - ) + raise RuntimeError("At least one worker has timed out. Stopping distributed build.") def worker(url: str, port: int, num_processes: int, no_bars: bool): @@ -227,7 +214,7 @@ def worker(url: str, port: int, num_processes: int, no_bars: bool): Simple distributed worker that connects to a manager asks for work and deploys using multiprocessing """ - identity = "%04X-%04X" % (randint(0, 0x10000), randint(0, 0x10000)) + identity = f"{randint(0, 0x10000):04X}-{randint(0, 0x10000):04X}" logger = getLogger(f"Worker {identity}") logger.info(f"Connecting to Manager at {url}:{port}") @@ -246,7 +233,7 @@ def worker(url: str, port: int, num_processes: int, no_bars: bool): try: running = True while running: - socket.send("READY_{}".format(hostname).encode("utf-8")) + socket.send(f"READY_{hostname}".encode()) # Poll for MANAGER_TIMEOUT seconds, if nothing is given then assume manager is dead and timeout connections = dict(poller.poll(settings.MANAGER_TIMEOUT * 1000)) @@ -277,7 +264,7 @@ def worker(url: str, port: int, num_processes: int, no_bars: bool): except Exception as e: logger.error(f"A worker failed with error: {e}") - socket.send("ERROR_{}".format(e).encode("utf-8")) + socket.send(f"ERROR_{e}".encode()) socket.close() socket.close() @@ -295,6 +282,4 @@ def ping_manager(socket, poller): message: bytes = socket.recv() if message.decode("utf-8") != "PONG": socket.close() - raise RuntimeError( - "Stopping work as manager did not respond to heartbeat from worker." 
- ) + raise RuntimeError("Stopping work as manager did not respond to heartbeat from worker.") diff --git a/src/maggma/cli/multiprocessing.py b/src/maggma/cli/multiprocessing.py index 192f6f377..57e1bb755 100644 --- a/src/maggma/cli/multiprocessing.py +++ b/src/maggma/cli/multiprocessing.py @@ -52,7 +52,7 @@ async def release(self, async_iterator): class AsyncUnorderedMap: """ Async iterator that maps a function to an async iterator - usign an executor and returns items as they are done + using an executor and returns items as they are done This does not guarantee order """ @@ -82,9 +82,7 @@ async def process_and_release(self, idx): async def get_from_iterator(self): loop = get_event_loop() async for idx, item in enumerate(self.iterator): - future = loop.run_in_executor( - self.executor, safe_dispatch, (self.func, item) - ) + future = loop.run_in_executor(self.executor, safe_dispatch, (self.func, item)) self.tasks[idx] = future @@ -101,8 +99,8 @@ async def __anext__(self): if item == self.done_sentinel: raise StopAsyncIteration - else: - return item + + return item async def atqdm(async_iterator, *args, **kwargs): @@ -152,7 +150,7 @@ async def multi( num_processes, no_bars=False, heartbeat_func: Optional[Callable[..., Any]] = None, - heartbeat_func_kwargs: Dict[Any, Any] = {}, + heartbeat_func_kwargs: Optional[Dict[Any, Any]] = None, ): builder.connect() cursor = builder.get_items() @@ -204,6 +202,8 @@ async def multi( disable=no_bars, ) + if not heartbeat_func_kwargs: + heartbeat_func_kwargs = {} if heartbeat_func: heartbeat_func(**heartbeat_func_kwargs) @@ -213,7 +213,7 @@ async def multi( async for chunk in grouper(back_pressure_relief, n=builder.chunk_size): logger.info( - "Processed batch of {} items".format(builder.chunk_size), + f"Processed batch of {builder.chunk_size} items", extra={ "maggma": { "event": "UPDATE", diff --git a/src/maggma/cli/rabbitmq.py b/src/maggma/cli/rabbitmq.py index 62be5935a..6052a4e46 100644 --- a/src/maggma/cli/rabbitmq.py +++ b/src/maggma/cli/rabbitmq.py @@ -54,9 +54,7 @@ def manager( logger.info(f"Binding to Manager URL {url}:{port}") # Setup connection to RabbitMQ and ensure on all queues is one unit - connection, channel, status_queue, worker_queue = setup_rabbitmq( - url, queue_prefix, port, "work" - ) + connection, channel, status_queue, worker_queue = setup_rabbitmq(url, queue_prefix, port, "work") workers = {} # type: ignore @@ -68,27 +66,22 @@ def manager( try: builder.connect() - chunk_dicts = [ - {"chunk": d, "distributed": False, "completed": False} - for d in builder.prechunk(num_chunks) - ] + chunk_dicts = [{"chunk": d, "distributed": False, "completed": False} for d in builder.prechunk(num_chunks)] pbar_distributed = tqdm( total=len(chunk_dicts), - desc="Distributed chunks for {}".format(builder.__class__.__name__), + desc=f"Distributed chunks for {builder.__class__.__name__}", ) pbar_completed = tqdm( total=len(chunk_dicts), - desc="Completed chunks for {}".format(builder.__class__.__name__), + desc=f"Completed chunks for {builder.__class__.__name__}", ) logger.info(f"Distributing {len(chunk_dicts)} chunks to workers") except NotImplementedError: attempt_graceful_shutdown(connection, workers, channel, worker_queue) - raise RuntimeError( - f"Can't distribute process {builder.__class__.__name__} as no prechunk method exists." 
- ) + raise RuntimeError(f"Can't distribute process {builder.__class__.__name__} as no prechunk method exists.") completed = False @@ -126,13 +119,9 @@ def manager( elif "ERROR" in msg: # Remove worker and requeue work sent to it - attempt_graceful_shutdown( - connection, workers, channel, worker_queue - ) + attempt_graceful_shutdown(connection, workers, channel, worker_queue) raise RuntimeError( - "At least one worker has stopped with error message: {}".format( - msg.split("_")[1] - ) + "At least one worker has stopped with error message: {}".format(msg.split("_")[1]) ) elif "PING" in msg: @@ -169,9 +158,7 @@ def manager( attempt_graceful_shutdown(connection, workers, channel, worker_queue) -def setup_rabbitmq( - url: str, queue_prefix: str, port: int, outbound_queue: Literal["status", "work"] -): +def setup_rabbitmq(url: str, queue_prefix: str, port: int, outbound_queue: Literal["status", "work"]): connection = pika.BlockingConnection(pika.ConnectionParameters(url, port)) channel = connection.channel() channel.basic_qos(prefetch_count=1, global_qos=True) @@ -197,7 +184,7 @@ def attempt_graceful_shutdown(connection, workers, channel, worker_queue): channel.basic_publish( exchange="", routing_key=worker_queue, - body="EXIT".encode("utf-8"), + body=b"EXIT", ) connection.close() @@ -205,14 +192,14 @@ def attempt_graceful_shutdown(connection, workers, channel, worker_queue): def handle_dead_workers(connection, workers, channel, worker_queue): if len(workers) == 1: # Use global timeout - identity = list(workers.keys())[0] + identity = next(iter(workers.keys())) if (perf_counter() - workers[identity]["last_ping"]) >= settings.WORKER_TIMEOUT: attempt_graceful_shutdown(connection, workers, channel, worker_queue) raise RuntimeError("Worker has timed out. Stopping distributed build.") elif len(workers) == 2: # Use 10% ratio between workers - workers_sorted = sorted(list(workers.items()), key=lambda x: x[1]["heartbeats"]) + workers_sorted = sorted(workers.items(), key=lambda x: x[1]["heartbeats"]) ratio = workers_sorted[1][1]["heartbeats"] / workers_sorted[0][1]["heartbeats"] @@ -229,12 +216,8 @@ def handle_dead_workers(connection, workers, channel, worker_queue): for identity in list(workers.keys()): z_score = 0.6745 * (workers[identity]["heartbeats"] - median) / mad if z_score <= -3.5: - attempt_graceful_shutdown( - connection, workers, channel, worker_queue - ) - raise RuntimeError( - "At least one worker has timed out. Stopping distributed build." - ) + attempt_graceful_shutdown(connection, workers, channel, worker_queue) + raise RuntimeError("At least one worker has timed out. 
Stopping distributed build.") def worker(url: str, port: int, num_processes: int, no_bars: bool, queue_prefix: str): @@ -242,23 +225,21 @@ def worker(url: str, port: int, num_processes: int, no_bars: bool, queue_prefix: Simple distributed worker that connects to a manager asks for work and deploys using multiprocessing """ - identity = "%04X-%04X" % (randint(0, 0x10000), randint(0, 0x10000)) + identity = f"{randint(0, 0x10000):04X}-{randint(0, 0x10000):04X}" logger = getLogger(f"Worker {identity}") url = url.split("//")[-1] - logger.info(f"Connnecting to Manager at {url}:{port}") + logger.info(f"Connecting to Manager at {url}:{port}") # Setup connection to RabbitMQ and ensure on all queues is one unit - connection, channel, status_queue, worker_queue = setup_rabbitmq( - url, queue_prefix, port, "status" - ) + connection, channel, status_queue, worker_queue = setup_rabbitmq(url, queue_prefix, port, "status") # Send ready signal to status queue channel.basic_publish( exchange="", routing_key=status_queue, - body="READY_{}".format(identity).encode("utf-8"), + body=f"READY_{identity}".encode(), ) try: @@ -276,12 +257,12 @@ def worker(url: str, port: int, num_processes: int, no_bars: bool, queue_prefix: work = json.loads(message) builder = MontyDecoder().process_decoded(work) - logger.info("Working on builder {}".format(builder.__class__)) + logger.info(f"Working on builder {builder.__class__}") channel.basic_publish( exchange="", routing_key=status_queue, - body="WORKING_{}".format(identity).encode("utf-8"), + body=f"WORKING_{identity}".encode(), ) work = json.loads(message) builder = MontyDecoder().process_decoded(work) @@ -303,7 +284,7 @@ def worker(url: str, port: int, num_processes: int, no_bars: bool, queue_prefix: channel.basic_publish( exchange="", routing_key=status_queue, - body="DONE_{}".format(identity).encode("utf-8"), + body=f"DONE_{identity}".encode(), ) elif message == "EXIT": @@ -311,11 +292,11 @@ def worker(url: str, port: int, num_processes: int, no_bars: bool, queue_prefix: running = False except Exception as e: - logger.error(f"A worker failed with error: {repr(e)}") + logger.error(f"A worker failed with error: {e!r}") channel.basic_publish( exchange="", routing_key=status_queue, - body="ERROR_{}".format(identity).encode("utf-8"), + body=f"ERROR_{identity}".encode(), ) connection.close() @@ -326,5 +307,5 @@ def ping_manager(channel, identity, status_queue): channel.basic_publish( exchange="", routing_key=status_queue, - body="PING_{}".format(identity).encode("utf-8"), + body=f"PING_{identity}".encode(), ) diff --git a/src/maggma/cli/serial.py b/src/maggma/cli/serial.py index 688a6b281..dd36ec00b 100644 --- a/src/maggma/cli/serial.py +++ b/src/maggma/cli/serial.py @@ -47,11 +47,9 @@ def serial(builder: Builder, no_bars=False): } }, ) - for chunk in grouper( - tqdm(cursor, total=total, disable=no_bars), builder.chunk_size - ): + for chunk in grouper(tqdm(cursor, total=total, disable=no_bars), builder.chunk_size): logger.info( - "Processing batch of {} items".format(builder.chunk_size), + f"Processing batch of {builder.chunk_size} items", extra={ "maggma": { "event": "UPDATE", @@ -66,8 +64,6 @@ def serial(builder: Builder, no_bars=False): logger.info( f"Ended serial processing: {builder.__class__.__name__}", - extra={ - "maggma": {"event": "BUILD_ENDED", "builder": builder.__class__.__name__} - }, + extra={"maggma": {"event": "BUILD_ENDED", "builder": builder.__class__.__name__}}, ) builder.finalize() diff --git a/src/maggma/cli/source_loader.py 
b/src/maggma/cli/source_loader.py index ea8305d59..98c78930d 100644 --- a/src/maggma/cli/source_loader.py +++ b/src/maggma/cli/source_loader.py @@ -58,7 +58,7 @@ def exec_module(self, module): module.__path__ = self.path # load the notebook object - with open(self.path, "r", encoding="utf-8") as f: + with open(self.path, encoding="utf-8") as f: nb = nbformat.read(f, 4) # extra work to ensure that magics that would affect the user_ns @@ -70,9 +70,7 @@ def exec_module(self, module): for cell in nb.cells: if cell.cell_type == "code": # transform the input to executable Python - code = self.shell.input_transformer_manager.transform_cell( - cell.source - ) + code = self.shell.input_transformer_manager.transform_cell(cell.source) # run the code in themodule exec(code, module.__dict__) finally: @@ -95,30 +93,22 @@ def spec_from_source(file_path: str) -> ModuleSpec: spec = ModuleSpec( name=f"{_BASENAME}.{module_name}", - loader=SourceFileLoader( - fullname=f"{_BASENAME}.{module_name}", path=file_path_str - ), + loader=SourceFileLoader(fullname=f"{_BASENAME}.{module_name}", path=file_path_str), origin=file_path_str, ) # spec._set_fileattr = True elif file_path_obj.parts[-1][-6:] == ".ipynb": # Gets module name from the filename without the .ipnb extension - module_name = ( - "_".join(file_path_obj.parts).replace(" ", "_").replace(".ipynb", "") - ) + module_name = "_".join(file_path_obj.parts).replace(" ", "_").replace(".ipynb", "") spec = ModuleSpec( name=f"{_BASENAME}.{module_name}", - loader=NotebookLoader( - name=f"{_BASENAME}.{module_name}", path=file_path_str - ), + loader=NotebookLoader(name=f"{_BASENAME}.{module_name}", path=file_path_str), origin=file_path_str, ) # spec._set_fileattr = True else: - raise Exception( - "Can't load {file_path}. Must provide a python source file such as a .py or .ipynb file" - ) + raise Exception("Can't load {file_path}. 
Must provide a python source file such as a .py or .ipynb file") return spec @@ -135,13 +125,10 @@ def load_builder_from_source(file_path: str) -> List[Builder]: sys.modules[spec.name] = module_object if hasattr(module_object, "__builders__"): - return getattr(module_object, "__builders__") - elif hasattr(module_object, "__builder__"): - return getattr(module_object, "__builder__") - else: - raise Exception( - f"No __builders__ or __builder__ attribute found in {file_path}" - ) + return module_object.__builders__ + if hasattr(module_object, "__builder__"): + return module_object.__builder__ + raise Exception(f"No __builders__ or __builder__ attribute found in {file_path}") def find_matching_file(segments, curr_path="./"): @@ -167,9 +154,7 @@ def find_matching_file(segments, curr_path="./"): pos_matches = {pmatch.group(1) for pmatch in pos_matches if pmatch} for new_path in pos_matches: if Path(new_path).exists() and Path(new_path).is_dir: - for sub_match in find_matching_file( - remainder, curr_path=new_path + "/" - ): + for sub_match in find_matching_file(remainder, curr_path=new_path + "/"): yield sub_match for sub_match in find_matching_file(remainder, curr_path=new_path): yield sub_match diff --git a/src/maggma/core/builder.py b/src/maggma/core/builder.py index d46f8a17f..051ca8135 100644 --- a/src/maggma/core/builder.py +++ b/src/maggma/core/builder.py @@ -1,4 +1,3 @@ -# coding: utf-8 """ Module containing the core builder definition """ @@ -57,7 +56,7 @@ def prechunk(self, number_splits: int) -> Iterable[Dict]: Part of a domain-decomposition paradigm to allow the builder to operate on multiple nodes by divinding up the IO as well as the compute This function should return an iterator of dictionaries that can be distributed - to multiple instances of the builder to get/process/udpate on + to multiple instances of the builder to get/process/update on Arguments: number_splits: The number of groups to split the documents to work on @@ -79,7 +78,6 @@ def get_items(self) -> Iterable: Returns: generator or list of items to process """ - pass def process_item(self, item: Any) -> Any: """ @@ -105,7 +103,6 @@ def update_targets(self, items: List): Returns: """ - pass def finalize(self): """ @@ -127,9 +124,7 @@ def run(self, log_level=logging.DEBUG): root = logging.getLogger() root.setLevel(log_level) ch = TqdmLoggingHandler() - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") ch.setFormatter(formatter) root.addHandler(ch) @@ -138,7 +133,7 @@ def run(self, log_level=logging.DEBUG): cursor = self.get_items() for chunk in grouper(tqdm(cursor), self.chunk_size): - self.logger.info("Processing batch of {} items".format(self.chunk_size)) + self.logger.info(f"Processing batch of {self.chunk_size} items") processed_chunk = [self.process_item(item) for item in chunk] processed_items = [item for item in processed_chunk if item is not None] self.update_targets(processed_items) diff --git a/src/maggma/core/store.py b/src/maggma/core/store.py index 9822ff435..864adef9a 100644 --- a/src/maggma/core/store.py +++ b/src/maggma/core/store.py @@ -1,4 +1,3 @@ -# coding: utf-8 """ Module containing the core Store definition """ @@ -7,7 +6,7 @@ from abc import ABCMeta, abstractmethod, abstractproperty from datetime import datetime from enum import Enum -from typing import Dict, Iterator, List, Optional, Tuple, Union, Callable +from typing import Callable, Dict, Iterator, List, Optional, 
Tuple, Union from monty.dev import deprecated from monty.json import MontyDecoder, MSONable @@ -41,7 +40,7 @@ def __init__( self, key: str = "task_id", last_updated_field: str = "last_updated", - last_updated_type: DateTimeFormat = DateTimeFormat("datetime"), + last_updated_type: DateTimeFormat = DateTimeFormat("datetime"), # noqa: B008 validator: Optional[Validator] = None, ): """ @@ -56,9 +55,7 @@ def __init__( self.last_updated_field = last_updated_field self.last_updated_type = last_updated_type self._lu_func: Tuple[Callable, Callable] = ( - LU_KEY_ISOFORMAT - if DateTimeFormat(last_updated_type) == DateTimeFormat.IsoFormat - else (identity, identity) + LU_KEY_ISOFORMAT if DateTimeFormat(last_updated_type) == DateTimeFormat.IsoFormat else (identity, identity) ) self.validator = validator self.logger = logging.getLogger(type(self).__name__) @@ -137,7 +134,7 @@ def update(self, docs: Union[List[Dict], Dict], key: Union[List, str, None] = No @abstractmethod def ensure_index(self, key: str, unique: bool = False) -> bool: """ - Tries to create an index and return true if it suceeded + Tries to create an index and return true if it succeeded Args: key: single key to index @@ -198,13 +195,9 @@ def query_one( sort: Dictionary of sort order for fields. Keys are field names and values are 1 for ascending or -1 for descending. """ - return next( - self.query(criteria=criteria, properties=properties, sort=sort), None - ) + return next(self.query(criteria=criteria, properties=properties, sort=sort), None) - def distinct( - self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False - ) -> List: + def distinct(self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False) -> List: """ Get all distinct values for a field @@ -214,11 +207,8 @@ def distinct( """ criteria = criteria or {} - results = [ - key for key, _ in self.groupby(field, properties=[field], criteria=criteria) - ] - results = [get(r, field) for r in results] - return results + results = [key for key, _ in self.groupby(field, properties=[field], criteria=criteria)] + return [get(r, field) for r in results] @property def last_updated(self) -> datetime: @@ -240,15 +230,13 @@ def last_updated(self) -> datetime: "is a datetime field in your store that represents the time of " "last update to each document." ) - elif not doc or get(doc, self.last_updated_field) is None: + if not doc or get(doc, self.last_updated_field) is None: # Handle when collection has docs but `NoneType` last_updated_field. return datetime.min - else: - return self._lu_func[0](get(doc, self.last_updated_field)) - def newer_in( - self, target: "Store", criteria: Optional[Dict] = None, exhaustive: bool = False - ) -> List[str]: + return self._lu_func[0](get(doc, self.last_updated_field)) + + def newer_in(self, target: "Store", criteria: Optional[Dict] = None, exhaustive: bool = False) -> List[str]: """ Returns the keys of documents that are newer in the target Store than this Store. 
@@ -267,35 +255,24 @@ def newer_in( # Get our current last_updated dates for each key value props = {self.key: 1, self.last_updated_field: 1, "_id": 0} dates = { - d[self.key]: self._lu_func[0]( - d.get(self.last_updated_field, datetime.max) - ) + d[self.key]: self._lu_func[0](d.get(self.last_updated_field, datetime.max)) for d in self.query(properties=props) } # Get the last_updated for the store we're comparing with props = {target.key: 1, target.last_updated_field: 1, "_id": 0} target_dates = { - d[target.key]: target._lu_func[0]( - d.get(target.last_updated_field, datetime.min) - ) + d[target.key]: target._lu_func[0](d.get(target.last_updated_field, datetime.min)) for d in target.query(criteria=criteria, properties=props) } new_keys = set(target_dates.keys()) - set(dates.keys()) - updated_keys = { - key - for key, date in dates.items() - if target_dates.get(key, datetime.min) > date - } + updated_keys = {key for key, date in dates.items() if target_dates.get(key, datetime.min) > date} return list(new_keys | updated_keys) - else: - criteria = { - self.last_updated_field: {"$gt": self._lu_func[1](self.last_updated)} - } - return target.distinct(field=self.key, criteria=criteria) + criteria = {self.last_updated_field: {"$gt": self._lu_func[1](self.last_updated)}} + return target.distinct(field=self.key, criteria=criteria) @deprecated(message="Please use Store.newer_in") def lu_filter(self, targets): diff --git a/src/maggma/core/validator.py b/src/maggma/core/validator.py index 3619a3eec..4f270f937 100644 --- a/src/maggma/core/validator.py +++ b/src/maggma/core/validator.py @@ -1,4 +1,3 @@ -# coding: utf-8 """ Validator class for document-level validation on Stores. Attach an instance of a Validator subclass to a Store .schema variable to enable validation on diff --git a/src/maggma/stores/__init__.py b/src/maggma/stores/__init__.py index baece8e81..5bc0b8c52 100644 --- a/src/maggma/stores/__init__.py +++ b/src/maggma/stores/__init__.py @@ -1,23 +1,12 @@ """ Root store module with easy imports for implemented Stores """ from maggma.core import Store -from maggma.stores.advanced_stores import ( - AliasingStore, - MongograntStore, - SandboxStore, - VaultStore, -) +from maggma.stores.advanced_stores import AliasingStore, MongograntStore, SandboxStore, VaultStore from maggma.stores.aws import S3Store from maggma.stores.azure import AzureBlobStore from maggma.stores.compound_stores import ConcatStore, JointStore from maggma.stores.file_store import FileStore from maggma.stores.gridfs import GridFSStore -from maggma.stores.mongolike import ( - JSONStore, - MemoryStore, - MongoStore, - MongoURIStore, - MontyStore, -) +from maggma.stores.mongolike import JSONStore, MemoryStore, MongoStore, MongoURIStore, MontyStore __all__ = [ "Store", diff --git a/src/maggma/stores/advanced_stores.py b/src/maggma/stores/advanced_stores.py index cedce4b62..428e16d5e 100644 --- a/src/maggma/stores/advanced_stores.py +++ b/src/maggma/stores/advanced_stores.py @@ -1,4 +1,3 @@ -# coding: utf-8 """ Advanced Stores for behavior outside normal access patterns """ @@ -58,7 +57,7 @@ def __init__( else: client = Client() - if set(("username", "password", "database", "host")) & set(kwargs): + if {"username", "password", "database", "host"} & set(kwargs): raise StoreError( "MongograntStore does not accept " "username, password, database, or host " @@ -67,7 +66,7 @@ def __init__( self.kwargs = kwargs _auth_info = client.get_db_auth_from_spec(self.mongogrant_spec) - super(MongograntStore, self).__init__( + super().__init__( 
host=_auth_info["host"], database=_auth_info["authSource"], username=_auth_info["username"], @@ -81,9 +80,7 @@ def name(self): return f"mgrant://{self.mongogrant_spec}/{self.collection_name}" def __hash__(self): - return hash( - (self.mongogrant_spec, self.collection_name, self.last_updated_field) - ) + return hash((self.mongogrant_spec, self.collection_name, self.last_updated_field)) @classmethod def from_db_file(cls, file): @@ -173,9 +170,7 @@ def __init__(self, collection_name: str, vault_secret_path: str): username = db_creds.get("username", "") password = db_creds.get("password", "") - super(VaultStore, self).__init__( - database, collection_name, host, port, username, password - ) + super().__init__(database, collection_name, host, port, username, password) def __eq__(self, other: object) -> bool: """ @@ -215,7 +210,7 @@ def __init__(self, store: Store, aliases: Dict, **kwargs): "last_updated_type": store.last_updated_type, } ) - super(AliasingStore, self).__init__(**kwargs) + super().__init__(**kwargs) @property def name(self) -> str: @@ -263,15 +258,11 @@ def query( substitute(properties, self.reverse_aliases) lazy_substitute(criteria, self.reverse_aliases) - for d in self.store.query( - properties=properties, criteria=criteria, sort=sort, limit=limit, skip=skip - ): + for d in self.store.query(properties=properties, criteria=criteria, sort=sort, limit=limit, skip=skip): substitute(d, self.aliases) yield d - def distinct( - self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False - ) -> List: + def distinct(self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False) -> List: """ Get all distinct values for a field @@ -326,9 +317,7 @@ def groupby( lazy_substitute(criteria, self.reverse_aliases) - return self.store.groupby( - keys=keys, properties=properties, criteria=criteria, skip=skip, limit=limit - ) + return self.store.groupby(keys=keys, properties=properties, criteria=criteria, skip=skip, limit=limit) def update(self, docs: Union[List[Dict], Dict], key: Union[List, str, None] = None): """ @@ -429,10 +418,7 @@ def sbx_criteria(self) -> Dict: """ if self.exclusive: return {"sbxn": self.sandbox} - else: - return { - "$or": [{"sbxn": {"$in": [self.sandbox]}}, {"sbxn": {"$exists": False}}] - } + return {"$or": [{"sbxn": {"$in": [self.sandbox]}}, {"sbxn": {"$exists": False}}]} def count(self, criteria: Optional[Dict] = None) -> int: """ @@ -441,9 +427,7 @@ def count(self, criteria: Optional[Dict] = None) -> int: Args: criteria: PyMongo filter for documents to count in """ - criteria = ( - dict(**criteria, **self.sbx_criteria) if criteria else self.sbx_criteria - ) + criteria = dict(**criteria, **self.sbx_criteria) if criteria else self.sbx_criteria return self.store.count(criteria=criteria) def query( @@ -465,12 +449,8 @@ def query( skip: number documents to skip limit: limit on total number of documents returned """ - criteria = ( - dict(**criteria, **self.sbx_criteria) if criteria else self.sbx_criteria - ) - return self.store.query( - properties=properties, criteria=criteria, sort=sort, limit=limit, skip=skip - ) + criteria = dict(**criteria, **self.sbx_criteria) if criteria else self.sbx_criteria + return self.store.query(properties=properties, criteria=criteria, sort=sort, limit=limit, skip=skip) def groupby( self, @@ -497,13 +477,9 @@ def groupby( Returns: generator returning tuples of (dict, list of docs) """ - criteria = ( - dict(**criteria, **self.sbx_criteria) if criteria else self.sbx_criteria - ) + criteria = dict(**criteria, 
**self.sbx_criteria) if criteria else self.sbx_criteria - return self.store.groupby( - keys=keys, properties=properties, criteria=criteria, skip=skip, limit=limit - ) + return self.store.groupby(keys=keys, properties=properties, criteria=criteria, skip=skip, limit=limit) def update(self, docs: Union[List[Dict], Dict], key: Union[List, str, None] = None): """ @@ -532,9 +508,7 @@ def remove_docs(self, criteria: Dict): criteria: query dictionary to match """ # Update criteria and properties based on aliases - criteria = ( - dict(**criteria, **self.sbx_criteria) if criteria else self.sbx_criteria - ) + criteria = dict(**criteria, **self.sbx_criteria) if criteria else self.sbx_criteria self.store.remove_docs(criteria) def ensure_index(self, key, unique=False, **kwargs): diff --git a/src/maggma/stores/aws.py b/src/maggma/stores/aws.py index ba4808e3b..b132e791e 100644 --- a/src/maggma/stores/aws.py +++ b/src/maggma/stores/aws.py @@ -1,4 +1,3 @@ -# coding: utf-8 """ Advanced Stores for connecting to AWS data """ @@ -8,9 +7,8 @@ from concurrent.futures import wait from concurrent.futures.thread import ThreadPoolExecutor from hashlib import sha1 -from typing import Dict, Iterator, List, Optional, Tuple, Union, Any from json import dumps -from typing import Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import msgpack # type: ignore from monty.msgpack import default as monty_default @@ -39,8 +37,8 @@ def __init__( bucket: str, s3_profile: Optional[Union[str, dict]] = None, compress: bool = False, - endpoint_url: str = None, - sub_dir: str = None, + endpoint_url: Optional[str] = None, + sub_dir: Optional[str] = None, s3_workers: int = 1, s3_resource_kwargs: Optional[dict] = None, key: str = "fs_id", @@ -81,19 +79,13 @@ def __init__( self.s3: Any = None self.s3_bucket: Any = None self.s3_workers = s3_workers - self.s3_resource_kwargs = ( - s3_resource_kwargs if s3_resource_kwargs is not None else {} - ) + self.s3_resource_kwargs = s3_resource_kwargs if s3_resource_kwargs is not None else {} self.unpack_data = unpack_data - self.searchable_fields = ( - searchable_fields if searchable_fields is not None else [] - ) + self.searchable_fields = searchable_fields if searchable_fields is not None else [] self.store_hash = store_hash # Force the key to be the same as the index - assert isinstance( - index.key, str - ), "Since we are using the key as a file name in S3, they key must be a string" + assert isinstance(index.key, str), "Since we are using the key as a file name in S3, they key must be a string" if key != index.key: warnings.warn( f'The desired S3Store key "{key}" does not match the index key "{index.key},"' @@ -103,7 +95,7 @@ def __init__( kwargs["key"] = str(index.key) self._thread_local = threading.local() - super(S3Store, self).__init__(**kwargs) + super().__init__(**kwargs) @property def name(self) -> str: @@ -119,16 +111,14 @@ def connect(self, *args, **kwargs): # lgtm[py/conflicting-attributes] """ session = self._get_session() - resource = session.resource( - "s3", endpoint_url=self.endpoint_url, **self.s3_resource_kwargs - ) + resource = session.resource("s3", endpoint_url=self.endpoint_url, **self.s3_resource_kwargs) if not self.s3: self.s3 = resource try: self.s3.meta.client.head_bucket(Bucket=self.bucket) except ClientError: - raise RuntimeError("Bucket not present on AWS: {}".format(self.bucket)) + raise RuntimeError(f"Bucket not present on AWS: {self.bucket}") self.s3_bucket = resource.Bucket(self.bucket) 
self.index.connect(*args, **kwargs) @@ -191,35 +181,25 @@ def query( elif isinstance(properties, list): prop_keys = set(properties) - for doc in self.index.query( - criteria=criteria, sort=sort, limit=limit, skip=skip - ): + for doc in self.index.query(criteria=criteria, sort=sort, limit=limit, skip=skip): if properties is not None and prop_keys.issubset(set(doc.keys())): yield {p: doc[p] for p in properties if p in doc} else: try: # TODO: THis is ugly and unsafe, do some real checking before pulling data - data = ( - self.s3_bucket.Object(self.sub_dir + str(doc[self.key])) - .get()["Body"] - .read() - ) + data = self.s3_bucket.Object(self.sub_dir + str(doc[self.key])).get()["Body"].read() except botocore.exceptions.ClientError as e: # If a client error is thrown, then check that it was a 404 error. # If it was a 404 error, then the object does not exist. error_code = int(e.response["Error"]["Code"]) if error_code == 404: - self.logger.error( - "Could not find S3 object {}".format(doc[self.key]) - ) + self.logger.error(f"Could not find S3 object {doc[self.key]}") break else: raise e if self.unpack_data: - data = self._unpack( - data=data, compressed=doc.get("compression", "") == "zlib" - ) + data = self._unpack(data=data, compressed=doc.get("compression", "") == "zlib") if self.last_updated_field in doc: data[self.last_updated_field] = doc[self.last_updated_field] @@ -237,12 +217,9 @@ def _unpack(data: bytes, compressed: bool): # MontyDecoder().process_decode only goes until it finds a from_dict # as such, we cannot just use msgpack.unpackb(data, object_hook=monty_object_hook, raw=False) # Should just return the unpacked object then let the user run process_decoded - unpacked_data = msgpack.unpackb(data, raw=False) - return unpacked_data + return msgpack.unpackb(data, raw=False) - def distinct( - self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False - ) -> List: + def distinct(self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False) -> List: """ Get all distinct values for a field @@ -289,7 +266,7 @@ def groupby( def ensure_index(self, key: str, unique: bool = False) -> bool: """ - Tries to create an index and return true if it suceeded + Tries to create an index and return true if it succeeded Args: key: single key to index @@ -352,8 +329,8 @@ def _get_session(self): if not hasattr(self._thread_local, "s3_bucket"): if isinstance(self.s3_profile, dict): return Session(**self.s3_profile) - else: - return Session(profile_name=self.s3_profile) + return Session(profile_name=self.s3_profile) + return None def _get_bucket(self): """ @@ -398,13 +375,11 @@ def write_doc_to_s3(self, doc: Dict, search_keys: List[str]): if self.last_updated_field in doc: # need this conversion for aws metadata insert - search_doc[self.last_updated_field] = str( - to_isoformat_ceil_ms(doc[self.last_updated_field]) - ) + search_doc[self.last_updated_field] = str(to_isoformat_ceil_ms(doc[self.last_updated_field])) # keep a record of original keys, in case these are important for the individual researcher # it is not expected that this information will be used except in disaster recovery - s3_to_mongo_keys = {k: self._sanitize_key(k) for k in search_doc.keys()} + s3_to_mongo_keys = {k: self._sanitize_key(k) for k in search_doc} s3_to_mongo_keys["s3-to-mongo-keys"] = "s3-to-mongo-keys" # inception # encode dictionary since values have to be strings search_doc["s3-to-mongo-keys"] = dumps(s3_to_mongo_keys) @@ -465,9 +440,7 @@ def remove_docs(self, criteria: Dict, remove_s3_object: 
bool = False): def last_updated(self): return self.index.last_updated - def newer_in( - self, target: Store, criteria: Optional[Dict] = None, exhaustive: bool = False - ) -> List[str]: + def newer_in(self, target: Store, criteria: Optional[Dict] = None, exhaustive: bool = False) -> List[str]: """ Returns the keys of documents that are newer in the target Store than this Store. @@ -480,13 +453,8 @@ def newer_in( that to filter out new items in """ if hasattr(target, "index"): - return self.index.newer_in( - target=target.index, criteria=criteria, exhaustive=exhaustive - ) - else: - return self.index.newer_in( - target=target, criteria=criteria, exhaustive=exhaustive - ) + return self.index.newer_in(target=target.index, criteria=criteria, exhaustive=exhaustive) + return self.index.newer_in(target=target, criteria=criteria, exhaustive=exhaustive) def __hash__(self): return hash((self.index.__hash__, self.bucket)) @@ -508,7 +476,7 @@ def rebuild_index_from_s3_data(self, **kwargs): unpacked_data = msgpack.unpackb(data, raw=False) self.update(unpacked_data, **kwargs) - def rebuild_metadata_from_index(self, index_query: dict = None): + def rebuild_metadata_from_index(self, index_query: Optional[dict] = None): """ Read data from the index store and populate the metadata of the S3 bucket Force all of the keys to be lower case to be Minio compatible @@ -525,9 +493,7 @@ def rebuild_metadata_from_index(self, index_query: dict = None): new_meta[str(k).lower()] = v new_meta.pop("_id") if self.last_updated_field in new_meta: - new_meta[self.last_updated_field] = str( - to_isoformat_ceil_ms(new_meta[self.last_updated_field]) - ) + new_meta[self.last_updated_field] = str(to_isoformat_ceil_ms(new_meta[self.last_updated_field])) # s3_object.metadata.update(new_meta) s3_object.copy_from( CopySource={"Bucket": self.s3_bucket.name, "Key": key_}, diff --git a/src/maggma/stores/azure.py b/src/maggma/stores/azure.py index bef362a53..1ba206394 100644 --- a/src/maggma/stores/azure.py +++ b/src/maggma/stores/azure.py @@ -1,4 +1,3 @@ -# coding: utf-8 """ Advanced Stores for connecting to Microsoft Azure data """ @@ -85,9 +84,7 @@ def __init__( kwargs: keywords for the base Store. 
""" if azure_blob is None: - raise RuntimeError( - "azure-storage-blob and azure-identity are required for AzureBlobStore" - ) + raise RuntimeError("azure-storage-blob and azure-identity are required for AzureBlobStore") self.index = index self.container_name = container_name @@ -97,13 +94,9 @@ def __init__( self.service: Optional[BlobServiceClient] = None self.container: Optional[ContainerClient] = None self.workers = workers - self.azure_resource_kwargs = ( - azure_resource_kwargs if azure_resource_kwargs is not None else {} - ) + self.azure_resource_kwargs = azure_resource_kwargs if azure_resource_kwargs is not None else {} self.unpack_data = unpack_data - self.searchable_fields = ( - searchable_fields if searchable_fields is not None else [] - ) + self.searchable_fields = searchable_fields if searchable_fields is not None else [] self.store_hash = store_hash if key_sanitize_dict is None: key_sanitize_dict = AZURE_KEY_SANITIZE @@ -123,7 +116,7 @@ def __init__( kwargs["key"] = str(index.key) self._thread_local = threading.local() - super(AzureBlobStore, self).__init__(**kwargs) + super().__init__(**kwargs) @property def name(self) -> str: @@ -151,9 +144,7 @@ def connect(self, *args, **kwargs): # lgtm[py/conflicting-attributes] except ResourceExistsError: pass else: - raise RuntimeError( - f"Container not present on Azure: {self.container_name}" - ) + raise RuntimeError(f"Container not present on Azure: {self.container_name}") self.container = container self.index.connect(*args, **kwargs) @@ -218,25 +209,17 @@ def query( elif isinstance(properties, list): prop_keys = set(properties) - for doc in self.index.query( - criteria=criteria, sort=sort, limit=limit, skip=skip - ): + for doc in self.index.query(criteria=criteria, sort=sort, limit=limit, skip=skip): if properties is not None and prop_keys.issubset(set(doc.keys())): yield {p: doc[p] for p in properties if p in doc} else: try: - data = self.container.download_blob( - self.sub_dir + str(doc[self.key]) - ).readall() + data = self.container.download_blob(self.sub_dir + str(doc[self.key])).readall() except azure.core.exceptions.ResourceNotFoundError: - self.logger.error( - "Could not find Blob object {}".format(doc[self.key]) - ) + self.logger.error(f"Could not find Blob object {doc[self.key]}") if self.unpack_data: - data = self._unpack( - data=data, compressed=doc.get("compression", "") == "zlib" - ) + data = self._unpack(data=data, compressed=doc.get("compression", "") == "zlib") if self.last_updated_field in doc: data[self.last_updated_field] = doc[self.last_updated_field] # type: ignore @@ -254,12 +237,9 @@ def _unpack(data: bytes, compressed: bool): # MontyDecoder().process_decode only goes until it finds a from_dict # as such, we cannot just use msgpack.unpackb(data, object_hook=monty_object_hook, raw=False) # Should just return the unpacked object then let the user run process_decoded - unpacked_data = msgpack.unpackb(data, raw=False) - return unpacked_data + return msgpack.unpackb(data, raw=False) - def distinct( - self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False - ) -> List: + def distinct(self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False) -> List: """ Get all distinct values for a field @@ -306,7 +286,7 @@ def groupby( def ensure_index(self, key: str, unique: bool = False) -> bool: """ - Tries to create an index and return true if it suceeded + Tries to create an index and return true if it succeeded Args: key: single key to index @@ -374,19 +354,16 @@ def 
_get_service_client(self): if isinstance(self.azure_client_info, str): # assume it is the account_url and that the connection is passwordless default_credential = DefaultAzureCredential() - return BlobServiceClient( - self.azure_client_info, credential=default_credential - ) + return BlobServiceClient(self.azure_client_info, credential=default_credential) - elif isinstance(self.azure_client_info, dict): + if isinstance(self.azure_client_info, dict): connection_string = self.azure_client_info.get("connection_string") if connection_string: - return BlobServiceClient.from_connection_string( - conn_str=connection_string - ) + return BlobServiceClient.from_connection_string(conn_str=connection_string) msg = f"Could not instantiate BlobServiceClient from azure_client_info: {self.azure_client_info}" raise RuntimeError(msg) + return None def _get_container(self) -> Optional[ContainerClient]: """ @@ -434,13 +411,11 @@ def write_doc_to_blob(self, doc: Dict, search_keys: List[str]): if self.last_updated_field in doc: # need this conversion for metadata insert - search_doc[self.last_updated_field] = str( - to_isoformat_ceil_ms(doc[self.last_updated_field]) - ) + search_doc[self.last_updated_field] = str(to_isoformat_ceil_ms(doc[self.last_updated_field])) # keep a record of original keys, in case these are important for the individual researcher # it is not expected that this information will be used except in disaster recovery - blob_to_mongo_keys = {k: self._sanitize_key(k) for k in search_doc.keys()} + blob_to_mongo_keys = {k: self._sanitize_key(k) for k in search_doc} blob_to_mongo_keys["blob_to_mongo_keys"] = "blob_to_mongo_keys" # inception # encode dictionary since values have to be strings search_doc["blob_to_mongo_keys"] = dumps(blob_to_mongo_keys) @@ -501,9 +476,7 @@ def remove_docs(self, criteria: Dict, remove_blob_object: bool = False): def last_updated(self): return self.index.last_updated - def newer_in( - self, target: Store, criteria: Optional[Dict] = None, exhaustive: bool = False - ) -> List[str]: + def newer_in(self, target: Store, criteria: Optional[Dict] = None, exhaustive: bool = False) -> List[str]: """ Returns the keys of documents that are newer in the target Store than this Store. 
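The blob_to_mongo_keys change above drops a redundant .keys() call, since iterating a dict already yields its keys. A tiny sketch with made-up sample data:

    search_doc = {"task-id": 1, "formula": "NaCl"}

    # Map each original key to its sanitized form, as the write_doc_to_blob hunk does.
    keys_a = {k: k.replace("-", "_") for k in search_doc.keys()}
    keys_b = {k: k.replace("-", "_") for k in search_doc}  # identical result; .keys() is redundant

    assert keys_a == keys_b == {"task-id": "task_id", "formula": "formula"}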
@@ -516,13 +489,9 @@ def newer_in( that to filter out new items in """ if hasattr(target, "index"): - return self.index.newer_in( - target=target.index, criteria=criteria, exhaustive=exhaustive - ) - else: - return self.index.newer_in( - target=target, criteria=criteria, exhaustive=exhaustive - ) + return self.index.newer_in(target=target.index, criteria=criteria, exhaustive=exhaustive) + + return self.index.newer_in(target=target, criteria=criteria, exhaustive=exhaustive) def __hash__(self): return hash((self.index.__hash__, self.container_name)) @@ -565,16 +534,12 @@ def rebuild_metadata_from_index(self, index_query: Optional[Dict] = None): key_ = self.sub_dir + index_doc[self.key] blob = self.container.get_blob_client(key_) properties = blob.get_blob_properties() - new_meta = { - self._sanitize_key(k): v for k, v in properties.metadata.items() - } + new_meta = {self._sanitize_key(k): v for k, v in properties.metadata.items()} for k, v in index_doc.items(): new_meta[str(k).lower()] = v new_meta.pop("_id") if self.last_updated_field in new_meta: - new_meta[self.last_updated_field] = str( - to_isoformat_ceil_ms(new_meta[self.last_updated_field]) - ) + new_meta[self.last_updated_field] = str(to_isoformat_ceil_ms(new_meta[self.last_updated_field])) blob.set_blob_metadata(new_meta) def __eq__(self, other: object) -> bool: diff --git a/src/maggma/stores/compound_stores.py b/src/maggma/stores/compound_stores.py index 9a76ce6cd..7a0e9b07d 100644 --- a/src/maggma/stores/compound_stores.py +++ b/src/maggma/stores/compound_stores.py @@ -54,7 +54,7 @@ def __init__( self.mongoclient_kwargs = mongoclient_kwargs or {} self.kwargs = kwargs - super(JointStore, self).__init__(**kwargs) + super().__init__(**kwargs) @property def name(self) -> str: @@ -85,9 +85,7 @@ def connect(self, force_reset: bool = False): ) db = conn[self.database] self._coll = db[self.main] - self._has_merge_objects = ( - self._collection.database.client.server_info()["version"] > "3.6" - ) + self._has_merge_objects = self._collection.database.client.server_info()["version"] > "3.6" def close(self): """ @@ -99,7 +97,7 @@ def close(self): def _collection(self): """Property referring to the root pymongo collection""" if self._coll is None: - raise StoreError("Must connect Mongo-like store before attemping to use it") + raise StoreError("Must connect Mongo-like store before attempting to use it") return self._coll @property @@ -173,16 +171,14 @@ def _get_pipeline(self, criteria=None, properties=None, skip=0, limit=0): if self.merge_at_root: if not self._has_merge_objects: - raise Exception( - "MongoDB server version too low to use $mergeObjects." 
- ) + raise Exception("MongoDB server version too low to use $mergeObjects.") pipeline.append( { "$replaceRoot": { "newRoot": { "$mergeObjects": [ - {"$arrayElemAt": ["${}".format(cname), 0]}, + {"$arrayElemAt": [f"${cname}", 0]}, "$$ROOT", ] } @@ -193,20 +189,15 @@ def _get_pipeline(self, criteria=None, properties=None, skip=0, limit=0): pipeline.append( { "$unwind": { - "path": "${}".format(cname), + "path": f"${cname}", "preserveNullAndEmptyArrays": True, } } ) # Do projection for max last_updated - lu_max_fields = ["${}".format(self.last_updated_field)] - lu_max_fields.extend( - [ - "${}.{}".format(cname, self.last_updated_field) - for cname in self.collection_names - ] - ) + lu_max_fields = [f"${self.last_updated_field}"] + lu_max_fields.extend([f"${cname}.{self.last_updated_field}" for cname in self.collection_names]) lu_proj = {self.last_updated_field: {"$max": lu_max_fields}} pipeline.append({"$addFields": lu_proj}) @@ -244,12 +235,9 @@ def query( skip: int = 0, limit: int = 0, ) -> Iterator[Dict]: - pipeline = self._get_pipeline( - criteria=criteria, properties=properties, skip=skip, limit=limit - ) + pipeline = self._get_pipeline(criteria=criteria, properties=properties, skip=skip, limit=limit) agg = self._collection.aggregate(pipeline) - for d in agg: - yield d + yield from agg def groupby( self, @@ -260,14 +248,12 @@ def groupby( skip: int = 0, limit: int = 0, ) -> Iterator[Tuple[Dict, List[Dict]]]: - pipeline = self._get_pipeline( - criteria=criteria, properties=properties, skip=skip, limit=limit - ) + pipeline = self._get_pipeline(criteria=criteria, properties=properties, skip=skip, limit=limit) if not isinstance(keys, list): keys = [keys] group_id = {} # type: Dict[str,Any] for key in keys: - set_(group_id, key, "${}".format(key)) + set_(group_id, key, f"${key}") pipeline.append({"$group": {"_id": group_id, "docs": {"$push": "$$ROOT"}}}) agg = self._collection.aggregate(pipeline) @@ -292,8 +278,7 @@ def query_one(self, criteria=None, properties=None, **kwargs): # pipeline.append({"$limit": 1}) query = self.query(criteria=criteria, properties=properties, **kwargs) try: - doc = next(query) - return doc + return next(query) except StopIteration: return None @@ -339,7 +324,7 @@ def __init__(self, stores: List[Store], **kwargs): """ self.stores = stores self.kwargs = kwargs - super(ConcatStore, self).__init__(**kwargs) + super().__init__(**kwargs) @property def name(self) -> str: @@ -374,7 +359,7 @@ def _collection(self): def last_updated(self) -> datetime: """ Finds the most recent last_updated across all the stores. 
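JointStore.query above now delegates with yield from instead of an explicit loop, and query_one returns next(...) directly; both are behavior-preserving rewrites. A compact sketch with a stand-in cursor:

    def query_old(cursor):
        for doc in cursor:
            yield doc

    def query_new(cursor):
        yield from cursor  # identical iteration behavior

    docs = [{"k": 1}, {"k": 2}]
    assert list(query_old(iter(docs))) == list(query_new(iter(docs))) == docs

    def query_one(cursor):
        try:
            return next(query_new(cursor))  # return next(...) directly, no temporary needed
        except StopIteration:
            return None

    assert query_one(iter(docs)) == {"k": 1}
    assert query_one(iter([])) is None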
- This might not be the most usefull way to do this for this type of Store + This might not be the most useful way to do this for this type of Store since it could very easily over-estimate the last_updated based on what stores are used """ @@ -398,9 +383,7 @@ def update(self, docs: Union[List[Dict], Dict], key: Union[List, str, None] = No """ raise NotImplementedError("No update method for ConcatStore") - def distinct( - self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False - ) -> List: + def distinct(self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False) -> List: """ Get all distinct values for a field @@ -425,7 +408,7 @@ def ensure_index(self, key: str, unique: bool = False) -> bool: Returns: bool indicating if the index exists/was created on all stores """ - return all([store.ensure_index(key, unique) for store in self.stores]) + return all(store.ensure_index(key, unique) for store in self.stores) def count(self, criteria: Optional[Dict] = None) -> int: """ @@ -502,17 +485,16 @@ def groupby( limit=limit, ) ) - for key, group in temp_docs: + for _key, group in temp_docs: docs.extend(group) def key_set(d: Dict) -> Tuple: "index function based on passed in keys" - test_d = tuple(d.get(k, None) for k in keys) - return test_d + return tuple(d.get(k, None) for k in keys) sorted_docs = sorted(docs, key=key_set) for vals, group_iter in groupby(sorted_docs, key=key_set): - id_dict = {key: val for key, val in zip(keys, vals)} + id_dict = dict(zip(keys, vals)) yield id_dict, list(group_iter) def remove_docs(self, criteria: Dict): diff --git a/src/maggma/stores/file_store.py b/src/maggma/stores/file_store.py index 7eb8f6908..393fdfdeb 100644 --- a/src/maggma/stores/file_store.py +++ b/src/maggma/stores/file_store.py @@ -1,4 +1,3 @@ -# coding: utf-8 """ Module defining a FileStore that enables accessing files in a local directory using typical maggma access patterns. @@ -95,9 +94,7 @@ def __init__( self.json_name = json_name file_filters = file_filters if file_filters else ["*"] - self.file_filters = re.compile( - "|".join(fnmatch.translate(p) for p in file_filters) - ) + self.file_filters = re.compile("|".join(fnmatch.translate(p) for p in file_filters)) self.collection_name = "file_store" self.key = "file_id" self.include_orphans = include_orphans @@ -128,9 +125,9 @@ def name(self) -> str: def add_metadata( self, - metadata: Dict = {}, + metadata: Optional[Dict] = None, query: Optional[Dict] = None, - auto_data: Callable[[Dict], Dict] = None, + auto_data: Optional[Callable[[Dict], Dict]] = None, **kwargs, ): """ @@ -160,6 +157,8 @@ def get_metadata_from_filename(d): metadata is used. kwargs: kwargs passed to FileStore.query() """ + if metadata is None: + metadata = {} # sanitize the metadata filtered_metadata = self._filter_data(metadata) updated_docs = [] @@ -195,7 +194,7 @@ def read(self) -> List[Dict]: """ file_list = [] # generate a list of files in subdirectories - for root, dirs, files in os.walk(self.path): + for root, _dirs, files in os.walk(self.path): # for pattern in self.file_filters: for match in filter(self.file_filters.match, files): # for match in fnmatch.filter(files, pattern): @@ -211,7 +210,7 @@ def read(self) -> List[Dict]: def _create_record_from_file(self, f: Path) -> Dict: """ - Given the path to a file, return a Dict that constitues a record of + Given the path to a file, return a Dict that constitutes a record of basic information about that file. 
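FileStore.add_metadata above trades a mutable default argument (metadata: Dict = {}) for Optional[Dict] = None plus an in-body fallback, the usual fix for the shared-default pitfall that flake8-bugbear warns about. A minimal sketch with hypothetical functions:

    from typing import Optional

    def tag_old(metadata: dict = {}):  # the single default dict is shared across calls
        metadata["tagged"] = True
        return metadata

    def tag_new(metadata: Optional[dict] = None):
        if metadata is None:
            metadata = {}  # a fresh dict per call, as add_metadata now does
        metadata["tagged"] = True
        return metadata

    first, second = tag_old(), tag_old()
    assert first is second              # the same default object is reused
    assert tag_new() is not tag_new()   # independent objects every call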
The keys in the returned dict are: @@ -283,7 +282,7 @@ def connect(self): # now read any metadata from the .json file try: self.metadata_store.connect() - metadata = [d for d in self.metadata_store.query()] + metadata = list(self.metadata_store.query()) except FileNotFoundError: metadata = [] warnings.warn( @@ -299,10 +298,7 @@ def connect(self): key = self.key file_ids = self.distinct(self.key) for d in metadata: - if isinstance(key, list): - search_doc = {k: d[k] for k in key} - else: - search_doc = {key: d[key]} + search_doc = {k: d[k] for k in key} if isinstance(key, list) else {key: d[key]} if d[key] not in file_ids: found_orphans = True @@ -346,16 +342,13 @@ def update(self, docs: Union[List[Dict], Dict], key: Union[List, str, None] = No ) super().update(docs, key) - data = [d for d in self.query()] + data = list(self.query()) filtered_data = [] # remove fields that are populated by .read() for d in data: filtered_d = self._filter_data(d) # don't write records that contain only file_id - if ( - len(set(filtered_d.keys()).difference(set(["path_relative", self.key]))) - != 0 - ): + if len(set(filtered_d.keys()).difference({"path_relative", self.key})) != 0: filtered_data.append(filtered_d) self.metadata_store.update(filtered_data, self.key) @@ -366,12 +359,7 @@ def _filter_data(self, d): Args: d: Dictionary whose keys are to be filtered """ - filtered_d = { - k: v - for k, v in d.items() - if k not in PROTECTED_KEYS.union({self.last_updated_field}) - } - return filtered_d + return {k: v for k, v in d.items() if k not in PROTECTED_KEYS.union({self.last_updated_field})} def query( # type: ignore self, @@ -404,9 +392,8 @@ def query( # type: ignore """ return_contents = False criteria = criteria if criteria else {} - if criteria.get("orphan", None) is None: - if not self.include_orphans: - criteria.update({"orphan": False}) + if criteria.get("orphan", None) is None and not self.include_orphans: + criteria.update({"orphan": False}) if criteria.get("contents"): warnings.warn("'contents' is not a queryable field! Ignoring.") @@ -502,11 +489,10 @@ def remove_docs(self, criteria: Dict, confirm: bool = False): """ if self.read_only: raise StoreError( - "This Store is read-only. To enable file I/O, re-initialize the " - "store with read_only=False." + "This Store is read-only. To enable file I/O, re-initialize the store with read_only=False." ) - docs = [d for d in self.query(criteria)] + docs = list(self.query(criteria)) # this ensures that any modifications to criteria made by self.query # (e.g., related to orphans or contents) are propagated through to the superclass new_criteria = {"file_id": {"$in": [d["file_id"] for d in docs]}} diff --git a/src/maggma/stores/gridfs.py b/src/maggma/stores/gridfs.py index e341e8fef..9e907c834 100644 --- a/src/maggma/stores/gridfs.py +++ b/src/maggma/stores/gridfs.py @@ -1,4 +1,3 @@ -# coding: utf-8 """ Module containing various definitions of Stores. 
Stores are a default access pattern to data and provide @@ -52,7 +51,7 @@ def __init__( password: str = "", compression: bool = False, ensure_metadata: bool = False, - searchable_fields: List[str] = None, + searchable_fields: Optional[List[str]] = None, auth_source: Optional[str] = None, mongoclient_kwargs: Optional[Dict] = None, ssh_tunnel: Optional[SSHTunnel] = None, @@ -106,7 +105,7 @@ def from_launchpad_file(cls, lp_file, collection_name, **kwargs): Returns: """ - with open(lp_file, "r") as f: + with open(lp_file) as f: lp_creds = yaml.safe_load(f.read()) db_creds = lp_creds.copy() @@ -152,19 +151,17 @@ def connect(self, force_reset: bool = False): db = conn[self.database] self._coll = gridfs.GridFS(db, self.collection_name) - self._files_collection = db["{}.files".format(self.collection_name)] + self._files_collection = db[f"{self.collection_name}.files"] self._files_store = MongoStore.from_collection(self._files_collection) self._files_store.last_updated_field = f"metadata.{self.last_updated_field}" self._files_store.key = self.key - self._chunks_collection = db["{}.chunks".format(self.collection_name)] + self._chunks_collection = db[f"{self.collection_name}.chunks"] @property def _collection(self): """Property referring to underlying pymongo collection""" if self._coll is None: - raise StoreError( - "Must connect Mongo-like store before attempting to use it" - ) + raise StoreError("Must connect Mongo-like store before attempting to use it") return self._coll @property @@ -184,9 +181,7 @@ def transform_criteria(cls, criteria: Dict) -> Dict: """ new_criteria = dict() for field in criteria: - if field not in files_collection_fields and not field.startswith( - "metadata." - ): + if field not in files_collection_fields and not field.startswith("metadata."): new_criteria["metadata." + field] = copy.copy(criteria[field]) else: new_criteria[field] = copy.copy(criteria[field]) @@ -240,9 +235,7 @@ def query( elif isinstance(properties, list): prop_keys = set(properties) - for doc in self._files_store.query( - criteria=criteria, sort=sort, limit=limit, skip=skip - ): + for doc in self._files_store.query(criteria=criteria, sort=sort, limit=limit, skip=skip): if properties is not None and prop_keys.issubset(set(doc.keys())): yield {p: doc[p] for p in properties if p in doc} else: @@ -273,9 +266,7 @@ def query( yield data - def distinct( - self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False - ) -> List: + def distinct(self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False) -> List: """ Get all distinct values for a field. 
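transform_criteria in GridFSStore (shown in context above) rewrites query fields so that anything that is not a native files-collection field is looked up under a metadata. prefix. A rough, simplified sketch of that mapping; the field list below is a stand-in for illustration, not the real one:

    FILES_FIELDS = {"_id", "filename", "length", "uploadDate"}  # hypothetical, simplified

    def transform_criteria(criteria: dict) -> dict:
        return {
            (f if f in FILES_FIELDS or f.startswith("metadata.") else f"metadata.{f}"): v
            for f, v in criteria.items()
        }

    assert transform_criteria({"task_id": 42, "length": 10}) == {"metadata.task_id": 42, "length": 10}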
This function only operates on the metadata in the files collection @@ -284,17 +275,10 @@ def distinct( field: the field(s) to get distinct values for criteria: PyMongo filter for documents to search in """ - criteria = ( - self.transform_criteria(criteria) - if isinstance(criteria, dict) - else criteria - ) + criteria = self.transform_criteria(criteria) if isinstance(criteria, dict) else criteria field = ( - f"metadata.{field}" - if field not in files_collection_fields - and not field.startswith("metadata.") - else field + f"metadata.{field}" if field not in files_collection_fields and not field.startswith("metadata.") else field ) return self._files_store.distinct(field=field, criteria=criteria) @@ -326,30 +310,15 @@ def groupby( generator returning tuples of (dict, list of docs) """ - criteria = ( - self.transform_criteria(criteria) - if isinstance(criteria, dict) - else criteria - ) + criteria = self.transform_criteria(criteria) if isinstance(criteria, dict) else criteria keys = [keys] if not isinstance(keys, list) else keys keys = [ - f"metadata.{k}" - if k not in files_collection_fields and not k.startswith("metadata.") - else k - for k in keys + f"metadata.{k}" if k not in files_collection_fields and not k.startswith("metadata.") else k for k in keys ] - for group, ids in self._files_store.groupby( - keys, criteria=criteria, properties=[f"metadata.{self.key}"] - ): - ids = [ - get(doc, f"metadata.{self.key}") - for doc in ids - if has(doc, f"metadata.{self.key}") - ] - - group = { - k.replace("metadata.", ""): get(group, k) for k in keys if has(group, k) - } + for group, ids in self._files_store.groupby(keys, criteria=criteria, properties=[f"metadata.{self.key}"]): + ids = [get(doc, f"metadata.{self.key}") for doc in ids if has(doc, f"metadata.{self.key}")] + + group = {k.replace("metadata.", ""): get(group, k) for k in keys if has(group, k)} yield group, list(self.query(criteria={self.key: {"$in": ids}})) @@ -366,10 +335,9 @@ def ensure_index(self, key: str, unique: Optional[bool] = False) -> bool: """ # Transform key for gridfs first if key not in files_collection_fields: - files_col_key = "metadata.{}".format(key) + files_col_key = f"metadata.{key}" return self._files_store.ensure_index(files_col_key, unique=unique) - else: - return self._files_store.ensure_index(key, unique=unique) + return self._files_store.ensure_index(key, unique=unique) def update( self, @@ -411,9 +379,7 @@ def update( metadata = { k: get(d, k) - for k in [self.last_updated_field] - + additional_metadata - + self.searchable_fields + for k in [self.last_updated_field, *additional_metadata, *self.searchable_fields] if has(d, k) } metadata.update(search_doc) @@ -426,11 +392,7 @@ def update( search_doc = self.transform_criteria(search_doc) # Cleans up old gridfs entries - for fdoc in ( - self._files_collection.find(search_doc, ["_id"]) - .sort("uploadDate", -1) - .skip(1) - ): + for fdoc in self._files_collection.find(search_doc, ["_id"]).sort("uploadDate", -1).skip(1): self._collection.delete(fdoc["_id"]) def remove_docs(self, criteria: Dict): @@ -477,10 +439,10 @@ def __init__( self, uri: str, collection_name: str, - database: str = None, + database: Optional[str] = None, compression: bool = False, ensure_metadata: bool = False, - searchable_fields: List[str] = None, + searchable_fields: Optional[List[str]] = None, mongoclient_kwargs: Optional[Dict] = None, **kwargs, ): @@ -501,9 +463,7 @@ def __init__( if database is None: d_uri = uri_parser.parse_uri(uri) if d_uri["database"] is None: - raise 
ConfigurationError( - "If database name is not supplied, a database must be set in the uri" - ) + raise ConfigurationError("If database name is not supplied, a database must be set in the uri") self.database = d_uri["database"] else: self.database = database @@ -528,8 +488,8 @@ def connect(self, force_reset: bool = False): conn: MongoClient = MongoClient(self.uri, **self.mongoclient_kwargs) db = conn[self.database] self._coll = gridfs.GridFS(db, self.collection_name) - self._files_collection = db["{}.files".format(self.collection_name)] + self._files_collection = db[f"{self.collection_name}.files"] self._files_store = MongoStore.from_collection(self._files_collection) self._files_store.last_updated_field = f"metadata.{self.last_updated_field}" self._files_store.key = self.key - self._chunks_collection = db["{}.chunks".format(self.collection_name)] + self._chunks_collection = db[f"{self.collection_name}.chunks"] diff --git a/src/maggma/stores/mongolike.py b/src/maggma/stores/mongolike.py index 506c7936d..6cd734b7b 100644 --- a/src/maggma/stores/mongolike.py +++ b/src/maggma/stores/mongolike.py @@ -1,4 +1,3 @@ -# coding: utf-8 """ Module containing various definitions of Stores. Stores are a default access pattern to data and provide @@ -13,19 +12,10 @@ from ruamel import yaml try: - from typing import ( - Any, - Callable, - Dict, - Iterator, - List, - Literal, - Optional, - Tuple, - Union, - ) + from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Union except ImportError: - from typing import Dict, Iterator, List, Optional, Tuple, Union, Any, Callable + from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union + from typing_extensions import Literal import mongomock @@ -233,7 +223,7 @@ def from_launchpad_file(cls, lp_file, collection_name, **kwargs): Returns: """ - with open(lp_file, "r") as f: + with open(lp_file) as f: lp_creds = yaml.safe_load(f.read()) db_creds = lp_creds.copy() @@ -245,9 +235,7 @@ def from_launchpad_file(cls, lp_file, collection_name, **kwargs): return cls(**db_creds, **kwargs) - def distinct( - self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False - ) -> List: + def distinct(self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False) -> List: """ Get all distinct values for a field @@ -261,10 +249,7 @@ def distinct( distinct_vals = self._collection.distinct(field, criteria) except (OperationFailure, DocumentTooLarge): distinct_vals = [ - d["_id"] - for d in self._collection.aggregate( - [{"$match": criteria}, {"$group": {"_id": f"${field}"}}] - ) + d["_id"] for d in self._collection.aggregate([{"$match": criteria}, {"$group": {"_id": f"${field}"}}]) ] if all(isinstance(d, list) for d in filter(None, distinct_vals)): # type: ignore distinct_vals = list(chain.from_iterable(filter(None, distinct_vals))) @@ -343,9 +328,7 @@ def from_collection(cls, collection): def _collection(self): """Property referring to underlying pymongo collection""" if self._coll is None: - raise StoreError( - "Must connect Mongo-like store before attempting to use it" - ) + raise StoreError("Must connect Mongo-like store before attempting to use it") return self._coll def count( @@ -365,12 +348,7 @@ def count( criteria = criteria if criteria else {} hint_list = ( - [ - (k, Sort(v).value) if isinstance(v, int) else (k, v.value) - for k, v in hint.items() - ] - if hint - else None + [(k, Sort(v).value) if isinstance(v, int) else (k, v.value) for k, v in hint.items()] if hint else None ) if hint_list is not None: 
# pragma: no cover @@ -413,29 +391,20 @@ def query( # type: ignore if self.default_sort is not None: default_sort_formatted = [ - (k, Sort(v).value) if isinstance(v, int) else (k, v.value) - for k, v in self.default_sort.items() + (k, Sort(v).value) if isinstance(v, int) else (k, v.value) for k, v in self.default_sort.items() ] sort_list = ( - [ - (k, Sort(v).value) if isinstance(v, int) else (k, v.value) - for k, v in sort.items() - ] + [(k, Sort(v).value) if isinstance(v, int) else (k, v.value) for k, v in sort.items()] if sort else default_sort_formatted ) hint_list = ( - [ - (k, Sort(v).value) if isinstance(v, int) else (k, v.value) - for k, v in hint.items() - ] - if hint - else None + [(k, Sort(v).value) if isinstance(v, int) else (k, v.value) for k, v in hint.items()] if hint else None ) - for d in self._collection.find( + yield from self._collection.find( filter=criteria, projection=properties, skip=skip, @@ -443,8 +412,7 @@ def query( # type: ignore sort=sort_list, hint=hint_list, **kwargs, - ): - yield d + ) def ensure_index(self, key: str, unique: Optional[bool] = False) -> bool: """ @@ -459,12 +427,12 @@ def ensure_index(self, key: str, unique: Optional[bool] = False) -> bool: if confirm_field_index(self._collection, key): return True - else: - try: - self._collection.create_index(key, unique=unique, background=True) - return True - except Exception: - return False + + try: + self._collection.create_index(key, unique=unique, background=True) + return True + except Exception: + return False def update(self, docs: Union[List[Dict], Dict], key: Union[List, str, None] = None): """ @@ -483,7 +451,7 @@ def update(self, docs: Union[List[Dict], Dict], key: Union[List, str, None] = No if not isinstance(docs, list): docs = [docs] - for d in map(lambda x: jsanitize(x, allow_bson=True), docs): + for d in (jsanitize(x, allow_bson=True) for x in docs): # document-level validation is optional validates = True if self.validator: @@ -491,15 +459,11 @@ def update(self, docs: Union[List[Dict], Dict], key: Union[List, str, None] = No if not validates: if self.validator.strict: raise ValueError(self.validator.validation_errors(d)) - else: - self.logger.error(self.validator.validation_errors(d)) + self.logger.error(self.validator.validation_errors(d)) if validates: key = key or self.key - if isinstance(key, list): - search_doc = {k: d[k] for k in key} - else: - search_doc = {key: d[key]} + search_doc = {k: d[k] for k in key} if isinstance(key, list) else {key: d[key]} requests.append(ReplaceOne(search_doc, d, upsert=True)) @@ -509,7 +473,6 @@ def update(self, docs: Union[List[Dict], Dict], key: Union[List, str, None] = No except (OperationFailure, DocumentTooLarge) as e: if self.safe_update: for req in requests: - req._filter try: self._collection.bulk_write([req], ordered=False) except (OperationFailure, DocumentTooLarge): @@ -558,7 +521,7 @@ def __init__( self, uri: str, collection_name: str, - database: str = None, + database: Optional[str] = None, ssh_tunnel: Optional[SSHTunnel] = None, mongoclient_kwargs: Optional[Dict] = None, default_sort: Optional[Dict[str, Union[Sort, int]]] = None, @@ -581,9 +544,7 @@ def __init__( if database is None: d_uri = uri_parser.parse_uri(uri) if d_uri["database"] is None: - raise ConfigurationError( - "If database name is not supplied, a database must be set in the uri" - ) + raise ConfigurationError("If database name is not supplied, a database must be set in the uri") self.database = d_uri["database"] else: self.database = database @@ -627,7 +588,7 @@ def 
__init__(self, collection_name: str = "memory_db", **kwargs): self.default_sort = None self._coll = None self.kwargs = kwargs - super(MongoStore, self).__init__(**kwargs) # noqa + super(MongoStore, self).__init__(**kwargs) def connect(self, force_reset: bool = False): """ @@ -683,9 +644,7 @@ def groupby( properties = list(properties.keys()) data = [ - doc - for doc in self.query(properties=keys + properties, criteria=criteria) - if all(has(doc, k) for k in keys) + doc for doc in self.query(properties=keys + properties, criteria=criteria) if all(has(doc, k) for k in keys) ] def grouping_keys(doc): @@ -763,9 +722,7 @@ def __init__( self.kwargs = kwargs if not self.read_only and len(paths) > 1: - raise RuntimeError( - "Cannot instantiate file-writable JSONStore with multiple JSON files." - ) + raise RuntimeError("Cannot instantiate file-writable JSONStore with multiple JSON files.") self.default_sort = None self.serialization_option = serialization_option @@ -858,7 +815,7 @@ def update_json_file(self): Updates the json file when a write-like operation is performed. """ with zopen(self.paths[0], "w") as f: - data = [d for d in self.query()] + data = list(self.query()) for d in data: d.pop("_id") bytesdata = orjson.dumps( @@ -915,7 +872,7 @@ class MontyStore(MemoryStore): def __init__( self, collection_name, - database_path: str = None, + database_path: Optional[str] = None, database_name: str = "db", storage: Literal["sqlite", "flatfile", "lightning"] = "sqlite", storage_kwargs: Optional[dict] = None, @@ -931,8 +888,8 @@ def __init__( directory will be used. database_name: The database name. storage: The storage type. Options include "sqlite", "lightning", "flatfile". Note that - although MontyDB supports in memory storage, this capability is disabled in maggma to avoid unintended behavior, since multiple - in-memory MontyStore would actually point to the same data. + although MontyDB supports in memory storage, this capability is disabled in maggma to avoid unintended + behavior, since multiple in-memory MontyStore would actually point to the same data. storage_kwargs: Keyword arguments passed to ``montydb.set_storage``. client_kwargs: Keyword arguments passed to the ``montydb.MontyClient`` constructor. 
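MemoryStore keeps the explicit super(MongoStore, self).__init__(**kwargs) call above (only the trailing noqa comment was dropped). super(C, self) starts method lookup after C in the MRO, so the immediate parent's __init__ is skipped. A toy illustration, unrelated to maggma:

    class Base:
        def __init__(self):
            self.initialized_by = "Base"

    class Middle(Base):
        def __init__(self):
            self.initialized_by = "Middle"

    class Leaf(Middle):
        def __init__(self):
            super(Middle, self).__init__()  # skips Middle.__init__, runs Base.__init__

    assert Leaf().initialized_by == "Base"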
@@ -954,7 +911,7 @@ def __init__( "mongo_version": "4.0", } self.client_kwargs = client_kwargs or {} - super(MongoStore, self).__init__(**kwargs) # noqa + super(MongoStore, self).__init__(**kwargs) def connect(self, force_reset: bool = False): """ @@ -973,9 +930,7 @@ def connect(self, force_reset: bool = False): @property def name(self) -> str: """Return a string representing this data source.""" - return ( - f"monty://{self.database_path}/{self.database_name}/{self.collection_name}" - ) + return f"monty://{self.database_path}/{self.database_name}/{self.collection_name}" def count( self, @@ -993,12 +948,7 @@ def count( criteria = criteria if criteria else {} hint_list = ( - [ - (k, Sort(v).value) if isinstance(v, int) else (k, v.value) - for k, v in hint.items() - ] - if hint - else None + [(k, Sort(v).value) if isinstance(v, int) else (k, v.value) for k, v in hint.items()] if hint else None ) if hint_list is not None: # pragma: no cover @@ -1030,15 +980,11 @@ def update(self, docs: Union[List[Dict], Dict], key: Union[List, str, None] = No if not validates: if self.validator.strict: raise ValueError(self.validator.validation_errors(d)) - else: - self.logger.error(self.validator.validation_errors(d)) + self.logger.error(self.validator.validation_errors(d)) if validates: key = key or self.key - if isinstance(key, list): - search_doc = {k: d[k] for k in key} - else: - search_doc = {key: d[key]} + search_doc = {k: d[k] for k in key} if isinstance(key, list) else {key: d[key]} self._collection.replace_one(search_doc, d, upsert=True) diff --git a/src/maggma/stores/shared_stores.py b/src/maggma/stores/shared_stores.py index af9a48956..54d72a04e 100644 --- a/src/maggma/stores/shared_stores.py +++ b/src/maggma/stores/shared_stores.py @@ -49,6 +49,7 @@ def __init__(self, store, multistore): def __getattr__(self, name: str) -> Any: if name not in dir(self): return self.multistore._proxy_attribute(name, self.store) + return None def __setattr__(self, name: str, value: Any): if name not in ["store", "multistore"]: @@ -122,12 +123,7 @@ def query( limit=limit, ) - def update( - self, - docs: Union[List[Dict], Dict], - key: Union[List, str, None] = None, - **kwargs - ): + def update(self, docs: Union[List[Dict], Dict], key: Union[List, str, None] = None, **kwargs): """ Update documents into the Store @@ -142,7 +138,7 @@ def update( def ensure_index(self, key: str, unique: bool = False, **kwargs) -> bool: """ - Tries to create an index and return true if it suceeded + Tries to create an index and return true if it succeeded Args: key: single key to index @@ -151,9 +147,7 @@ def ensure_index(self, key: str, unique: bool = False, **kwargs) -> bool: Returns: bool indicating if the index exists/was created """ - return self.multistore.ensure_index( - self.store, key=key, unique=unique, **kwargs - ) + return self.multistore.ensure_index(self.store, key=key, unique=unique, **kwargs) def groupby( self, @@ -182,14 +176,7 @@ def groupby( generator returning tuples of (dict, list of docs) """ return self.multistore.groupby( - self.store, - keys=keys, - criteria=criteria, - properties=properties, - sort=sort, - skip=skip, - limit=limit, - **kwargs + self.store, keys=keys, criteria=criteria, properties=properties, sort=sort, skip=skip, limit=limit, **kwargs ) def remove_docs(self, criteria: Dict, **kwargs): @@ -217,17 +204,9 @@ def query_one( sort: Dictionary of sort order for fields. Keys are field names and values are 1 for ascending or -1 for descending. 
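The search_doc ternary in MontyStore.update above builds the upsert filter from either a single key or a list of keys. A tiny sketch of that branch with hypothetical documents:

    def build_filter(doc: dict, key):
        # A compound key filters on every listed field; a single key filters on just that field.
        return {k: doc[k] for k in key} if isinstance(key, list) else {key: doc[key]}

    doc = {"task_id": 7, "run": 2, "energy": -1.5}
    assert build_filter(doc, "task_id") == {"task_id": 7}
    assert build_filter(doc, ["task_id", "run"]) == {"task_id": 7, "run": 2}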
""" - return self.multistore.query_one( - self.store, criteria=criteria, properties=properties, sort=sort, **kwargs - ) + return self.multistore.query_one(self.store, criteria=criteria, properties=properties, sort=sort, **kwargs) - def distinct( - self, - field: str, - criteria: Optional[Dict] = None, - all_exist: bool = False, - **kwargs - ) -> List: + def distinct(self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False, **kwargs) -> List: """ Get all distinct values for a field @@ -235,9 +214,7 @@ def distinct( field: the field(s) to get distinct values for criteria: PyMongo filter for documents to search in """ - return self.multistore.distinct( - self.store, field=field, criteria=criteria, all_exist=all_exist, **kwargs - ) + return self.multistore.distinct(self.store, field=field, criteria=criteria, all_exist=all_exist, **kwargs) class MultiStore: @@ -331,9 +308,9 @@ def add_store(self, store: Store): self._stores.append(MontyDecoder().process_decoded(store.as_dict())) self._stores[-1].connect() return True - else: - # Store already exists, we don't need to add it - return True + + # Store already exists, we don't need to add it + return True def ensure_store(self, store: Store) -> bool: """ @@ -446,22 +423,11 @@ def query( # We must return a list, since a generator is not serializable return list( self._stores[store_id].query( - criteria=criteria, - properties=properties, - sort=sort, - skip=skip, - limit=limit, - **kwargs + criteria=criteria, properties=properties, sort=sort, skip=skip, limit=limit, **kwargs ) ) - def update( - self, - store: Store, - docs: Union[List[Dict], Dict], - key: Union[List, str, None] = None, - **kwargs - ): + def update(self, store: Store, docs: Union[List[Dict], Dict], key: Union[List, str, None] = None, **kwargs): """ Update documents into the Store @@ -475,11 +441,9 @@ def update( store_id = self.get_store_index(store) return self._stores[store_id].update(docs=docs, key=key, **kwargs) - def ensure_index( - self, store: Store, key: str, unique: bool = False, **kwargs - ) -> bool: + def ensure_index(self, store: Store, key: str, unique: bool = False, **kwargs) -> bool: """ - Tries to create an index and return true if it suceeded + Tries to create an index and return true if it succeeded Args: key: single key to index @@ -520,13 +484,7 @@ def groupby( """ store_id = self.get_store_index(store) return self._stores[store_id].groupby( - keys=keys, - criteria=criteria, - properties=properties, - sort=sort, - skip=skip, - limit=limit, - **kwargs + keys=keys, criteria=criteria, properties=properties, sort=sort, skip=skip, limit=limit, **kwargs ) def remove_docs(self, store: Store, criteria: Dict, **kwargs): @@ -558,19 +516,12 @@ def query_one( """ store_id = self.get_store_index(store) return next( - self._stores[store_id].query( - criteria=criteria, properties=properties, sort=sort, **kwargs - ), + self._stores[store_id].query(criteria=criteria, properties=properties, sort=sort, **kwargs), None, ) def distinct( - self, - store: Store, - field: str, - criteria: Optional[Dict] = None, - all_exist: bool = False, - **kwargs + self, store: Store, field: str, criteria: Optional[Dict] = None, all_exist: bool = False, **kwargs ) -> List: """ Get all distinct values for a field @@ -580,9 +531,7 @@ def distinct( criteria: PyMongo filter for documents to search in """ store_id = self.get_store_index(store) - return self._stores[store_id].distinct( - field=field, criteria=criteria, all_exist=all_exist, **kwargs - ) + return 
self._stores[store_id].distinct(field=field, criteria=criteria, all_exist=all_exist, **kwargs) def set_store_attribute(self, store: Store, name: str, value: Any): """ @@ -628,8 +577,7 @@ def _proxy_attribute(self, name: str, store) -> Union[Any, Callable]: maybe_fn = getattr(self._stores[store_id], name) if callable(maybe_fn): return partial(self.call_attr, name=name, store=store) - else: - return maybe_fn + return maybe_fn class MultiStoreManager(BaseManager): diff --git a/src/maggma/utils.py b/src/maggma/utils.py index 6552bd1cc..38ac5077f 100644 --- a/src/maggma/utils.py +++ b/src/maggma/utils.py @@ -1,4 +1,3 @@ -# coding: utf-8 """ Utilities to help with maggma functions """ @@ -80,8 +79,9 @@ def to_isoformat_ceil_ms(dt: Union[datetime, str]) -> str: """Helper to account for Mongo storing datetimes with only ms precision.""" if isinstance(dt, datetime): return (dt + timedelta(milliseconds=1)).isoformat(timespec="milliseconds") - elif isinstance(dt, str): + if isinstance(dt, str): return dt + return None def to_dt(s: Union[datetime, str]) -> datetime: @@ -89,8 +89,9 @@ def to_dt(s: Union[datetime, str]) -> datetime: if isinstance(s, str): return parser.parse(s) - elif isinstance(s, datetime): + if isinstance(s, datetime): return s + return None # This lu_key prioritizes not duplicating potentially expensive item @@ -106,7 +107,7 @@ def recursive_update(d: Dict, u: Dict): Args: d (dict): dict to update - u (dict): updates to propogate + u (dict): updates to propagate """ for k, v in u.items(): @@ -214,8 +215,7 @@ def dynamic_import(abs_module_path: str, class_name: Optional[str] = None): abs_module_path = ".".join(abs_module_path.split(".")[:-1]) module_object = import_module(abs_module_path) - target_class = getattr(module_object, class_name) - return target_class + return getattr(module_object, class_name) class ReportingHandler(logging.Handler): diff --git a/src/maggma/validators.py b/src/maggma/validators.py index 3847a8536..a7522ea4f 100644 --- a/src/maggma/validators.py +++ b/src/maggma/validators.py @@ -1,4 +1,3 @@ -# coding: utf-8 """ Validator class for document-level validation on Stores. 
Attach an instance of a Validator subclass to a Store .schema variable to enable validation on @@ -74,8 +73,7 @@ def is_valid(self, doc: Dict) -> bool: except ValidationError: if self.strict: raise - else: - return False + return False def validation_errors(self, doc: Dict) -> List[str]: """ @@ -92,12 +90,7 @@ def validation_errors(self, doc: Dict) -> List[str]: return [] validator = validator_for(self.schema)(self.schema) - errors = [ - "{}: {}".format(".".join(error.absolute_path), error.message) - for error in validator.iter_errors(doc) - ] - - return errors + return ["{}: {}".format(".".join(error.absolute_path), error.message) for error in validator.iter_errors(doc)] def msonable_schema(cls): diff --git a/tests/api/test_aggregation_resource.py b/tests/api/test_aggregation_resource.py index 5589ac6df..8963393bc 100644 --- a/tests/api/test_aggregation_resource.py +++ b/tests/api/test_aggregation_resource.py @@ -3,12 +3,11 @@ import pytest from fastapi import FastAPI -from pydantic import BaseModel, Field -from starlette.testclient import TestClient - from maggma.api.query_operator.core import QueryOperator from maggma.api.resource import AggregationResource from maggma.stores import MemoryStore +from pydantic import BaseModel, Field +from starlette.testclient import TestClient class Owner(BaseModel): @@ -28,7 +27,7 @@ class Owner(BaseModel): total_owners = len(owners) -@pytest.fixture +@pytest.fixture() def owner_store(): store = MemoryStore("owners", key="name") store.connect() @@ -36,7 +35,7 @@ def owner_store(): return store -@pytest.fixture +@pytest.fixture() def pipeline_query_op(): class PipelineQuery(QueryOperator): def query(self): @@ -50,16 +49,12 @@ def query(self): def test_init(owner_store, pipeline_query_op): - resource = AggregationResource( - store=owner_store, pipeline_query_operator=pipeline_query_op, model=Owner - ) + resource = AggregationResource(store=owner_store, pipeline_query_operator=pipeline_query_op, model=Owner) assert len(resource.router.routes) == 2 def test_msonable(owner_store, pipeline_query_op): - owner_resource = AggregationResource( - store=owner_store, pipeline_query_operator=pipeline_query_op, model=Owner - ) + owner_resource = AggregationResource(store=owner_store, pipeline_query_operator=pipeline_query_op, model=Owner) endpoint_dict = owner_resource.as_dict() for k in ["@class", "@module", "store", "model"]: @@ -70,9 +65,7 @@ def test_msonable(owner_store, pipeline_query_op): def test_aggregation_search(owner_store, pipeline_query_op): - endpoint = AggregationResource( - owner_store, pipeline_query_operator=pipeline_query_op, model=Owner - ) + endpoint = AggregationResource(owner_store, pipeline_query_operator=pipeline_query_op, model=Owner) app = FastAPI() app.include_router(endpoint.router) diff --git a/tests/api/test_api.py b/tests/api/test_api.py index de93a8556..072471763 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -5,19 +5,13 @@ import pytest from fastapi.encoders import jsonable_encoder -from pydantic import BaseModel, Field -from requests import Response -from starlette.testclient import TestClient - from maggma.api.API import API -from maggma.api.query_operator import ( - NumericQuery, - PaginationQuery, - SparseFieldsQuery, - StringQueryOperator, -) +from maggma.api.query_operator import NumericQuery, PaginationQuery, SparseFieldsQuery, StringQueryOperator from maggma.api.resource import ReadOnlyResource from maggma.stores import MemoryStore +from pydantic import BaseModel, Field +from requests import Response +from 
starlette.testclient import TestClient class PetType(str, Enum): @@ -37,10 +31,7 @@ class Pet(BaseModel): owner_name: str = Field(..., title="Owner's name") -owners = [ - Owner(name=f"Person{i}", age=randint(10, 100), weight=randint(100, 200)) - for i in list(range(10)) -] +owners = [Owner(name=f"Person{i}", age=randint(10, 100), weight=randint(100, 200)) for i in list(range(10))] pets = [ @@ -53,7 +44,7 @@ class Pet(BaseModel): ] -@pytest.fixture +@pytest.fixture() def owner_store(): store = MemoryStore("owners", key="name") store.connect() @@ -61,7 +52,7 @@ def owner_store(): return store -@pytest.fixture +@pytest.fixture() def pet_store(): store = MemoryStore("pets", key="name") store.connect() @@ -92,7 +83,7 @@ def search_helper(payload, base: str = "/?", debug=True) -> Tuple[Response, Any] debug: True = print out the url, false don't print anything Returns: - request.Response object that contains the response of the correspoding payload + request.Response object that contains the response of the corresponding payload """ owner_store = MemoryStore("owners", key="name") owner_store.connect() diff --git a/tests/api/test_post_resource.py b/tests/api/test_post_resource.py index 0825d4097..ee97c809e 100644 --- a/tests/api/test_post_resource.py +++ b/tests/api/test_post_resource.py @@ -3,11 +3,10 @@ import pytest from fastapi import FastAPI -from pydantic import BaseModel, Field -from starlette.testclient import TestClient - from maggma.api.resource import PostOnlyResource from maggma.stores import MemoryStore +from pydantic import BaseModel, Field +from starlette.testclient import TestClient class Owner(BaseModel): @@ -27,7 +26,7 @@ class Owner(BaseModel): total_owners = len(owners) -@pytest.fixture +@pytest.fixture() def owner_store(): store = MemoryStore("owners", key="name") store.connect() @@ -61,7 +60,7 @@ def test_post_to_search(owner_store): assert client.post("/").status_code == 200 -@pytest.mark.xfail +@pytest.mark.xfail() def test_problem_query_params(owner_store): endpoint = PostOnlyResource(owner_store, Owner) app = FastAPI() diff --git a/tests/api/test_query_operators.py b/tests/api/test_query_operators.py index 76039d07f..b04abe0a4 100644 --- a/tests/api/test_query_operators.py +++ b/tests/api/test_query_operators.py @@ -3,18 +3,12 @@ import pytest from fastapi import HTTPException +from maggma.api.query_operator import NumericQuery, PaginationQuery, SortQuery, SparseFieldsQuery +from maggma.api.query_operator.submission import SubmissionQuery from monty.serialization import dumpfn, loadfn from monty.tempfile import ScratchDir from pydantic import BaseModel, Field -from maggma.api.query_operator import ( - NumericQuery, - PaginationQuery, - SortQuery, - SparseFieldsQuery, -) -from maggma.api.query_operator.submission import SubmissionQuery - class Owner(BaseModel): name: str = Field(..., title="Owner's name") @@ -74,9 +68,7 @@ def test_sparse_query_serialization(): with ScratchDir("."): dumpfn(op, "temp.json") new_op = loadfn("temp.json") - assert new_op.query() == { - "properties": ["name", "age", "weight", "last_updated"] - } + assert new_op.query() == {"properties": ["name", "age", "weight", "last_updated"]} def test_numeric_query_functionality(): @@ -103,9 +95,7 @@ def test_numeric_query_serialization(): def test_sort_query_functionality(): op = SortQuery() - assert op.query(_sort_fields="volume,-density") == { - "sort": {"volume": 1, "density": -1} - } + assert op.query(_sort_fields="volume,-density") == {"sort": {"volume": 1, "density": -1}} def 
test_sort_serialization(): @@ -114,12 +104,10 @@ def test_sort_serialization(): with ScratchDir("."): dumpfn(op, "temp.json") new_op = loadfn("temp.json") - assert new_op.query(_sort_fields="volume,-density") == { - "sort": {"volume": 1, "density": -1} - } + assert new_op.query(_sort_fields="volume,-density") == {"sort": {"volume": 1, "density": -1}} -@pytest.fixture +@pytest.fixture() def status_enum(): class StatusEnum(Enum): state_A = "A" diff --git a/tests/api/test_read_resource.py b/tests/api/test_read_resource.py index d51c7984e..a74223be0 100644 --- a/tests/api/test_read_resource.py +++ b/tests/api/test_read_resource.py @@ -5,18 +5,13 @@ import pytest from fastapi import FastAPI -from pydantic import BaseModel, Field -from requests import Response -from starlette.testclient import TestClient - -from maggma.api.query_operator import ( - NumericQuery, - SparseFieldsQuery, - StringQueryOperator, -) +from maggma.api.query_operator import NumericQuery, SparseFieldsQuery, StringQueryOperator from maggma.api.resource import ReadOnlyResource from maggma.api.resource.core import HintScheme from maggma.stores import AliasingStore, MemoryStore +from pydantic import BaseModel, Field +from requests import Response +from starlette.testclient import TestClient class Owner(BaseModel): @@ -36,7 +31,7 @@ class Owner(BaseModel): total_owners = len(owners) -@pytest.fixture +@pytest.fixture() def owner_store(): store = MemoryStore("owners", key="name") store.connect() @@ -51,9 +46,7 @@ def test_init(owner_store): resource = ReadOnlyResource(store=owner_store, model=Owner, enable_get_by_key=False) assert len(resource.router.routes) == 2 - resource = ReadOnlyResource( - store=owner_store, model=Owner, enable_default_search=False - ) + resource = ReadOnlyResource(store=owner_store, model=Owner, enable_default_search=False) assert len(resource.router.routes) == 2 @@ -92,7 +85,7 @@ def test_key_fields(owner_store): assert client.get("/Person1/").json()["data"][0]["name"] == "Person1" -@pytest.mark.xfail +@pytest.mark.xfail() def test_problem_query_params(owner_store): endpoint = ReadOnlyResource(owner_store, Owner) app = FastAPI() @@ -103,7 +96,7 @@ def test_problem_query_params(owner_store): client.get("/?param=test").status_code -@pytest.mark.xfail +@pytest.mark.xfail() def test_problem_hint_scheme(owner_store): class TestHintScheme(HintScheme): def generate_hints(query): @@ -125,7 +118,7 @@ def search_helper(payload, base: str = "/?", debug=True) -> Response: debug: True = print out the url, false don't print anything Returns: - request.Response object that contains the response of the correspoding payload + request.Response object that contains the response of the corresponding payload """ store = MemoryStore("owners", key="name") store.connect() diff --git a/tests/api/test_s3_url_resource.py b/tests/api/test_s3_url_resource.py index a54139268..d0f609c72 100644 --- a/tests/api/test_s3_url_resource.py +++ b/tests/api/test_s3_url_resource.py @@ -1,10 +1,9 @@ import pytest - from maggma.api.resource import S3URLResource from maggma.stores import MemoryStore -@pytest.fixture +@pytest.fixture() def entries_store(): store = MemoryStore("entries", key="url") store.connect() diff --git a/tests/api/test_submission_resource.py b/tests/api/test_submission_resource.py index 65a71f534..aef7b73f4 100644 --- a/tests/api/test_submission_resource.py +++ b/tests/api/test_submission_resource.py @@ -4,13 +4,12 @@ import pytest from fastapi import FastAPI -from pydantic import BaseModel, Field -from starlette.testclient 
import TestClient - from maggma.api.query_operator import PaginationQuery from maggma.api.query_operator.core import QueryOperator from maggma.api.resource import SubmissionResource from maggma.stores import MemoryStore +from pydantic import BaseModel, Field +from starlette.testclient import TestClient class Owner(BaseModel): @@ -30,7 +29,7 @@ class Owner(BaseModel): total_owners = len(owners) -@pytest.fixture +@pytest.fixture() def owner_store(): store = MemoryStore("owners", key="name") store.connect() @@ -38,7 +37,7 @@ def owner_store(): return store -@pytest.fixture +@pytest.fixture() def post_query_op(): class PostQuery(QueryOperator): def query(self, name): @@ -47,7 +46,7 @@ def query(self, name): return PostQuery() -@pytest.fixture +@pytest.fixture() def patch_query_op(): class PatchQuery(QueryOperator): def query(self, name, update): diff --git a/tests/api/test_utils.py b/tests/api/test_utils.py index 7186e57ac..0a783df2b 100644 --- a/tests/api/test_utils.py +++ b/tests/api/test_utils.py @@ -4,11 +4,10 @@ import pytest from bson import ObjectId +from maggma.api.utils import api_sanitize, serialization_helper from monty.json import MSONable from pydantic import BaseModel, Field -from maggma.api.utils import api_sanitize, serialization_helper - class SomeEnum(Enum): A = 1 @@ -29,7 +28,7 @@ def __init__(self, name, age): class AnotherOwner(BaseModel): - name: str = Field(..., description="Ower name") + name: str = Field(..., description="Owner name") weight_or_pet: Union[float, AnotherPet] = Field(..., title="Owners weight or Pet") @@ -96,7 +95,7 @@ def test_serialization_helper(): assert serialization_helper(oid) == "60b7d47bb671aa7b01a2adf6" -@pytest.mark.xfail +@pytest.mark.xfail() def test_serialization_helper_xfail(): oid = "test" serialization_helper(oid) diff --git a/tests/builders/test_copy_builder.py b/tests/builders/test_copy_builder.py index cef5ebf9c..3018562d9 100644 --- a/tests/builders/test_copy_builder.py +++ b/tests/builders/test_copy_builder.py @@ -1,16 +1,14 @@ -# coding: utf-8 """ Tests for MapBuilder """ from datetime import datetime, timedelta import pytest - from maggma.builders import CopyBuilder from maggma.stores import MemoryStore -@pytest.fixture +@pytest.fixture() def source(): store = MemoryStore("source", key="k", last_updated_field="lu") store.connect() @@ -19,7 +17,7 @@ def source(): return store -@pytest.fixture +@pytest.fixture() def target(): store = MemoryStore("target", key="k", last_updated_field="lu") store.connect() @@ -33,23 +31,21 @@ def now(): return datetime.utcnow() -@pytest.fixture +@pytest.fixture() def old_docs(now): return [{"lu": now, "k": k, "v": "old"} for k in range(20)] -@pytest.fixture +@pytest.fixture() def new_docs(now): toc = now + timedelta(seconds=1) return [{"lu": toc, "k": k, "v": "new"} for k in range(0, 10)] -@pytest.fixture +@pytest.fixture() def some_failed_old_docs(now): docs = [{"lu": now, "k": k, "v": "old", "state": "failed"} for k in range(3)] - docs.extend( - [{"lu": now, "k": k, "v": "old", "state": "failed"} for k in range(18, 20)] - ) + docs.extend([{"lu": now, "k": k, "v": "old", "state": "failed"} for k in range(18, 20)]) return docs diff --git a/tests/builders/test_group_builder.py b/tests/builders/test_group_builder.py index fda7bf59c..a0f65c461 100644 --- a/tests/builders/test_group_builder.py +++ b/tests/builders/test_group_builder.py @@ -1,4 +1,3 @@ -# coding: utf-8 """ Tests for group builder """ @@ -7,7 +6,6 @@ from typing import Dict, List import pytest - from maggma.builders import GroupBuilder from 
maggma.stores import MemoryStore @@ -17,12 +15,12 @@ def now(): return datetime.utcnow() -@pytest.fixture +@pytest.fixture() def docs(now): return [{"k": i, "a": i % 3, "b": randint(0, i), "lu": now} for i in range(20)] -@pytest.fixture +@pytest.fixture() def source(docs): store = MemoryStore("source", key="k", last_updated_field="lu") store.connect() @@ -32,7 +30,7 @@ def source(docs): return store -@pytest.fixture +@pytest.fixture() def target(): store = MemoryStore("target", key="ks", last_updated_field="lu") store.connect() @@ -57,8 +55,8 @@ def unary_function(self, items: List[Dict]) -> Dict: """ new_doc = {} for k in self.grouping_keys: - new_doc[k] = set(d[k] for d in items) - new_doc["b"] = list(d["b"] for d in items) + new_doc[k] = {d[k] for d in items} + new_doc["b"] = [d["b"] for d in items] return new_doc diff --git a/tests/builders/test_projection_builder.py b/tests/builders/test_projection_builder.py index 493d8c82d..4c44b4dfe 100644 --- a/tests/builders/test_projection_builder.py +++ b/tests/builders/test_projection_builder.py @@ -1,14 +1,12 @@ -# coding: utf-8 """ Tests for Projection_Builder """ import pytest - from maggma.builders.projection_builder import Projection_Builder from maggma.stores import MemoryStore -@pytest.fixture +@pytest.fixture() def source1(): store = MemoryStore("source1", key="k", last_updated_field="lu") store.connect() @@ -18,7 +16,7 @@ def source1(): return store -@pytest.fixture +@pytest.fixture() def source2(): store = MemoryStore("source2", key="k", last_updated_field="lu") store.connect() @@ -28,7 +26,7 @@ def source2(): return store -@pytest.fixture +@pytest.fixture() def target(): store = MemoryStore("target", key="k", last_updated_field="lu") store.connect() @@ -39,7 +37,7 @@ def target(): def test_get_items(source1, source2, target): builder = Projection_Builder(source_stores=[source1, source2], target_store=target) - items = [i for i in builder.get_items()][0] + items = next(iter(builder.get_items())) assert len(items) == 25 @@ -50,14 +48,14 @@ def test_process_item(source1, source2, target): target_store=target, fields_to_project=[[], {}], ) - items = [i for i in builder.get_items()][0] + items = next(iter(builder.get_items())) outputs = builder.process_item(items) assert len(outputs) == 15 - output = [o for o in outputs if o["k"] < 10][0] - assert all([k in ["k", "a", "b", "c", "d"] for k in output.keys()]) - output = [o for o in outputs if o["k"] > 9][0] - assert all([k in ["k", "c", "d"] for k in output.keys()]) - assert all([k not in ["a", "b"] for k in output.keys()]) + output = next(o for o in outputs if o["k"] < 10) + assert all(k in ["k", "a", "b", "c", "d"] for k in output) + output = next(o for o in outputs if o["k"] > 9) + assert all(k in ["k", "c", "d"] for k in output) + assert all(k not in ["a", "b"] for k in output) # test fields_to_project = lists builder = Projection_Builder( @@ -65,11 +63,11 @@ def test_process_item(source1, source2, target): target_store=target, fields_to_project=[["a", "b"], ["d"]], ) - items = [i for i in builder.get_items()][0] + items = next(iter(builder.get_items())) outputs = builder.process_item(items) - output = [o for o in outputs if o["k"] < 10][0] - assert all([k in ["k", "a", "b", "d"] for k in output.keys()]) - assert all([k not in ["c"] for k in output.keys()]) + output = next(o for o in outputs if o["k"] < 10) + assert all(k in ["k", "a", "b", "d"] for k in output) + assert all(k not in ["c"] for k in output) # test fields_to_project = dict and list builder = Projection_Builder( @@ -77,11 
+75,11 @@ def test_process_item(source1, source2, target): target_store=target, fields_to_project=[{"newa": "a", "b": "b"}, ["d"]], ) - items = [i for i in builder.get_items()][0] + items = next(iter(builder.get_items())) outputs = builder.process_item(items) - output = [o for o in outputs if o["k"] < 10][0] - assert all([k in ["k", "newa", "b", "d"] for k in output.keys()]) - assert all([k not in ["a", "c"] for k in output.keys()]) + output = next(o for o in outputs if o["k"] < 10) + assert all(k in ["k", "newa", "b", "d"] for k in output) + assert all(k not in ["a", "c"] for k in output) def test_update_targets(source1, source2, target): diff --git a/tests/cli/builder_for_test.py b/tests/cli/builder_for_test.py index 533d4f32d..0c801eedc 100644 --- a/tests/cli/builder_for_test.py +++ b/tests/cli/builder_for_test.py @@ -10,7 +10,7 @@ def __init__(self, total=10): self.total = total def get_items(self): - for i in range(self.total): + for _i in range(self.total): self.get_called += 1 yield self.get_called diff --git a/tests/cli/test_distributed.py b/tests/cli/test_distributed.py index 71286dc0d..0035b8da0 100644 --- a/tests/cli/test_distributed.py +++ b/tests/cli/test_distributed.py @@ -5,10 +5,9 @@ import pytest import zmq.asyncio as zmq -from zmq import REP, REQ - from maggma.cli.distributed import find_port, manager, worker from maggma.core import Builder +from zmq import REP, REQ # TODO: Timeout errors? @@ -74,10 +73,7 @@ def test_manager_and_worker(log_to_stdout): ) manager_thread.start() - worker_threads = [ - threading.Thread(target=worker, args=(SERVER_URL, SERVER_PORT, 1, True)) - for _ in range(3) - ] + worker_threads = [threading.Thread(target=worker, args=(SERVER_URL, SERVER_PORT, 1, True)) for _ in range(3)] for worker_thread in worker_threads: worker_thread.start() @@ -88,7 +84,7 @@ def test_manager_and_worker(log_to_stdout): manager_thread.join() -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_manager_worker_error(log_to_stdout): manager_thread = threading.Thread( target=manager, @@ -100,26 +96,24 @@ async def test_manager_worker_error(log_to_stdout): socket = context.socket(REQ) socket.connect(f"{SERVER_URL}:{SERVER_PORT}") - await socket.send("ERROR_testerror".encode("utf-8")) + await socket.send(b"ERROR_testerror") await asyncio.sleep(1) manager_thread.join() -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_worker_error(): context = zmq.Context() socket = context.socket(REP) socket.bind(f"{SERVER_URL}:{SERVER_PORT}") - worker_task = threading.Thread( - target=worker, args=(SERVER_URL, SERVER_PORT, 1, True) - ) + worker_task = threading.Thread(target=worker, args=(SERVER_URL, SERVER_PORT, 1, True)) worker_task.start() message = await socket.recv() - assert message == "READY_{}".format(HOSTNAME).encode("utf-8") + assert message == f"READY_{HOSTNAME}".encode() dummy_work = { "@module": "tests.cli.test_distributed", @@ -137,20 +131,18 @@ async def test_worker_error(): worker_task.join() -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_worker_exit(): context = zmq.Context() socket = context.socket(REP) socket.bind(f"{SERVER_URL}:{SERVER_PORT}") - worker_task = threading.Thread( - target=worker, args=(SERVER_URL, SERVER_PORT, 1, True) - ) + worker_task = threading.Thread(target=worker, args=(SERVER_URL, SERVER_PORT, 1, True)) worker_task.start() message = await socket.recv() - assert message == "READY_{}".format(HOSTNAME).encode("utf-8") + assert message == f"READY_{HOSTNAME}".encode() await asyncio.sleep(1) await socket.send(b"EXIT") await 
asyncio.sleep(1) @@ -159,7 +151,7 @@ async def test_worker_exit(): worker_task.join() -@pytest.mark.xfail +@pytest.mark.xfail() def test_no_prechunk(caplog): manager( SERVER_URL, diff --git a/tests/cli/test_init.py b/tests/cli/test_init.py index 351ccc69a..fede31497 100644 --- a/tests/cli/test_init.py +++ b/tests/cli/test_init.py @@ -4,14 +4,13 @@ import pytest from click.testing import CliRunner -from monty.serialization import dumpfn - from maggma.builders import CopyBuilder from maggma.cli import run from maggma.stores import MemoryStore, MongoStore +from monty.serialization import dumpfn -@pytest.fixture +@pytest.fixture() def mongostore(): store = MongoStore("maggma_test", "test") store.connect() @@ -21,7 +20,7 @@ def mongostore(): store._collection.drop() -@pytest.fixture +@pytest.fixture() def reporting_store(): store = MongoStore("maggma_test", "reporting") store.connect() @@ -45,12 +44,7 @@ def test_run_builder(mongostore): memorystore = MemoryStore("temp") builder = CopyBuilder(mongostore, memorystore) - mongostore.update( - [ - {mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()} - for i in range(10) - ] - ) + mongostore.update([{mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()} for i in range(10)]) runner = CliRunner() with runner.isolated_filesystem(): @@ -70,9 +64,7 @@ def test_run_builder(mongostore): assert "CopyBuilder" in result.output assert "MultiProcessor" in result.output - result = runner.invoke( - run, ["-vvv", "-n", "2", "--no_bars", "test_builder.json"] - ) + result = runner.invoke(run, ["-vvv", "-n", "2", "--no_bars", "test_builder.json"]) assert result.exit_code == 0 assert "Get" not in result.output assert "Update" not in result.output @@ -83,12 +75,7 @@ def test_run_builder_chain(mongostore): builder1 = CopyBuilder(mongostore, memorystore) builder2 = CopyBuilder(mongostore, memorystore) - mongostore.update( - [ - {mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()} - for i in range(10) - ] - ) + mongostore.update([{mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()} for i in range(10)]) runner = CliRunner() with runner.isolated_filesystem(): @@ -108,9 +95,7 @@ def test_run_builder_chain(mongostore): assert "CopyBuilder" in result.output assert "MultiProcessor" in result.output - result = runner.invoke( - run, ["-vvv", "-n", "2", "--no_bars", "test_builders.json"] - ) + result = runner.invoke(run, ["-vvv", "-n", "2", "--no_bars", "test_builders.json"]) assert result.exit_code == 0 assert "Get" not in result.output assert "Update" not in result.output @@ -120,20 +105,13 @@ def test_reporting(mongostore, reporting_store): memorystore = MemoryStore("temp") builder = CopyBuilder(mongostore, memorystore) - mongostore.update( - [ - {mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()} - for i in range(10) - ] - ) + mongostore.update([{mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()} for i in range(10)]) runner = CliRunner() with runner.isolated_filesystem(): dumpfn(builder, "test_builder.json") dumpfn(reporting_store, "test_reporting_store.json") - result = runner.invoke( - run, ["-v", "test_builder.json", "-r", "test_reporting_store.json"] - ) + result = runner.invoke(run, ["-v", "test_builder.json", "-r", "test_reporting_store.json"]) assert result.exit_code == 0 report_docs = list(reporting_store.query()) @@ -155,9 +133,7 @@ def test_python_source(): runner = CliRunner() with runner.isolated_filesystem(): - shutil.copy2( - src=Path(__file__).parent / 
"builder_for_test.py", dst=Path(".").resolve() - ) + shutil.copy2(src=Path(__file__).parent / "builder_for_test.py", dst=Path(".").resolve()) result = runner.invoke(run, ["-v", "-n", "2", "builder_for_test.py"]) assert result.exit_code == 0 @@ -172,9 +148,7 @@ def test_python_notebook_source(): src=Path(__file__).parent / "builder_notebook_for_test.ipynb", dst=Path(".").resolve(), ) - result = runner.invoke( - run, ["-v", "-n", "2", "builder_notebook_for_test.ipynb"] - ) + result = runner.invoke(run, ["-v", "-n", "2", "builder_notebook_for_test.ipynb"]) assert result.exit_code == 0 assert "Ended multiprocessing: DummyBuilder" in result.output @@ -184,12 +158,7 @@ def test_memray_run_builder(mongostore): memorystore = MemoryStore("temp") builder = CopyBuilder(mongostore, memorystore) - mongostore.update( - [ - {mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()} - for i in range(10) - ] - ) + mongostore.update([{mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()} for i in range(10)]) runner = CliRunner() with runner.isolated_filesystem(): @@ -199,23 +168,17 @@ def test_memray_run_builder(mongostore): assert "CopyBuilder" in result.output assert "SerialProcessor" in result.output - result = runner.invoke( - run, ["-vvv", "--no_bars", "--memray", "on", "test_builder.json"] - ) + result = runner.invoke(run, ["-vvv", "--no_bars", "--memray", "on", "test_builder.json"]) assert result.exit_code == 0 assert "Get" not in result.output assert "Update" not in result.output - result = runner.invoke( - run, ["-v", "-n", "2", "--memray", "on", "test_builder.json"] - ) + result = runner.invoke(run, ["-v", "-n", "2", "--memray", "on", "test_builder.json"]) assert result.exit_code == 0 assert "CopyBuilder" in result.output assert "MultiProcessor" in result.output - result = runner.invoke( - run, ["-vvv", "-n", "2", "--no_bars", "--memray", "on", "test_builder.json"] - ) + result = runner.invoke(run, ["-vvv", "-n", "2", "--no_bars", "--memray", "on", "test_builder.json"]) assert result.exit_code == 0 assert "Get" not in result.output assert "Update" not in result.output @@ -225,18 +188,11 @@ def test_memray_user_output_dir(mongostore): memorystore = MemoryStore("temp") builder = CopyBuilder(mongostore, memorystore) - mongostore.update( - [ - {mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()} - for i in range(10) - ] - ) + mongostore.update([{mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()} for i in range(10)]) runner = CliRunner() with runner.isolated_filesystem(): dumpfn(builder, "test_builder.json") - result = runner.invoke( - run, ["--memray", "on", "-md", "memray_output_dir/", "test_builder.json"] - ) + result = runner.invoke(run, ["--memray", "on", "-md", "memray_output_dir/", "test_builder.json"]) assert result.exit_code == 0 assert (Path.cwd() / "memray_output_dir").exists() is True diff --git a/tests/cli/test_multiprocessing.py b/tests/cli/test_multiprocessing.py index d8a5569d9..6621551b8 100644 --- a/tests/cli/test_multiprocessing.py +++ b/tests/cli/test_multiprocessing.py @@ -2,16 +2,10 @@ from concurrent.futures import ThreadPoolExecutor import pytest +from maggma.cli.multiprocessing import AsyncUnorderedMap, BackPressure, grouper, safe_dispatch -from maggma.cli.multiprocessing import ( - AsyncUnorderedMap, - BackPressure, - grouper, - safe_dispatch, -) - -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_grouper(): async def arange(count): for i in range(count): @@ -34,7 +28,7 @@ async def arange(n): yield num 
-@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_backpressure(): iterable = range(10) backpressure = BackPressure(iterable, 2) @@ -55,15 +49,15 @@ async def test_backpressure(): await releaser.__anext__() await releaser.__anext__() - # Ensure stop itteration works - with pytest.raises(StopAsyncIteration): - for i in range(10): + # Ensure stop iteration works + with pytest.raises(StopAsyncIteration): # noqa: PT012 + for _i in range(10): await releaser.__anext__() assert not backpressure.back_pressure.locked() -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_async_map(): executor = ThreadPoolExecutor(1) amap = AsyncUnorderedMap(wait_and_return, arange(3), executor) diff --git a/tests/cli/test_serial.py b/tests/cli/test_serial.py index 8f701ef50..88e9fe33b 100644 --- a/tests/cli/test_serial.py +++ b/tests/cli/test_serial.py @@ -11,7 +11,7 @@ def __init__(self, total=10): self.total = total def get_items(self): - for i in range(self.total): + for _i in range(self.total): self.get_called += 1 yield self.get_called diff --git a/tests/conftest.py b/tests/conftest.py index 5c72bb897..d0fce9457 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,8 +5,8 @@ import pytest -@pytest.fixture -def tmp_dir(): +@pytest.fixture() +def tmp_dir(): # noqa: PT004 """ Create a clean directory and cd into it. @@ -24,36 +24,34 @@ def tmp_dir(): shutil.rmtree(newpath) -@pytest.fixture +@pytest.fixture() def test_dir(): module_dir = Path(__file__).resolve().parent test_dir = module_dir / "test_files" return test_dir.resolve() -@pytest.fixture +@pytest.fixture() def db_json(test_dir): db_dir = test_dir / "settings_files" db_json = db_dir / "db.json" return db_json.resolve() -@pytest.fixture +@pytest.fixture() def lp_file(test_dir): db_dir = test_dir / "settings_files" lp_file = db_dir / "my_launchpad.yaml" return lp_file.resolve() -@pytest.fixture +@pytest.fixture() def log_to_stdout(): # Set Logging root = logging.getLogger() root.setLevel(logging.DEBUG) ch = logging.StreamHandler(sys.stdout) - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") ch.setFormatter(formatter) root.addHandler(ch) return root @@ -73,9 +71,3 @@ def pytest_itemcollected(item): doc = item.obj.__doc__.strip() if item.obj.__doc__ else "" if doc: item._nodeid = item._nodeid.split("::")[0] + "::" + doc - - -if sys.version_info < (3, 7): - # Ignore API tests on python 3.6 - collect_ignore = ["cli/test_distributed.py"] - collect_ignore_glob = ["api/*"] diff --git a/tests/stores/test_advanced_stores.py b/tests/stores/test_advanced_stores.py index d0e492dac..c0522a4fc 100644 --- a/tests/stores/test_advanced_stores.py +++ b/tests/stores/test_advanced_stores.py @@ -1,4 +1,3 @@ -# coding: utf-8 """ Tests for advanced stores """ @@ -12,25 +11,17 @@ from uuid import uuid4 import pytest +from maggma.core import StoreError +from maggma.stores import AliasingStore, MemoryStore, MongograntStore, MongoStore, SandboxStore, VaultStore +from maggma.stores.advanced_stores import substitute from mongogrant import Client from mongogrant.client import check, seed from mongogrant.config import Config from pymongo import MongoClient from pymongo.collection import Collection -from maggma.core import StoreError -from maggma.stores import ( - AliasingStore, - MemoryStore, - MongograntStore, - MongoStore, - SandboxStore, - VaultStore, -) -from maggma.stores.advanced_stores import substitute - -@pytest.fixture 
+@pytest.fixture() def mongostore(): store = MongoStore("maggma_test", "test") store.connect() @@ -46,23 +37,16 @@ def mgrant_server(): mdpath = tempfile.mkdtemp() mdport = 27020 if not os.getenv("CONTINUOUS_INTEGRATION"): - basecmd = ( - f"mongod --port {mdport} --dbpath {mdpath} --quiet --logpath {mdlogpath} " - "--bind_ip_all --auth" - ) + basecmd = f"mongod --port {mdport} --dbpath {mdpath} --quiet --logpath {mdlogpath} --bind_ip_all --auth" mongod_process = subprocess.Popen(basecmd, shell=True, start_new_session=True) time.sleep(5) client = MongoClient(port=mdport) - client.admin.command( - "createUser", "mongoadmin", pwd="mongoadminpass", roles=["root"] - ) + client.admin.command("createUser", "mongoadmin", pwd="mongoadminpass", roles=["root"]) client.close() else: pytest.skip("Disabling mongogrant tests on CI for now") dbname = "test_" + uuid4().hex - db = MongoClient(f"mongodb://mongoadmin:mongoadminpass@127.0.0.1:{mdport}/admin")[ - dbname - ] + db = MongoClient(f"mongodb://mongoadmin:mongoadminpass@127.0.0.1:{mdport}/admin")[dbname] db.command("createUser", "reader", pwd="readerpass", roles=["read"]) db.command("createUser", "writer", pwd="writerpass", roles=["readWrite"]) db.client.close() @@ -105,16 +89,14 @@ def mgrant_user(mgrant_server): def connected_user(store): - return store._collection.database.command("connectionStatus")["authInfo"][ - "authenticatedUsers" - ][0]["user"] + return store._collection.database.command("connectionStatus")["authInfo"]["authenticatedUsers"][0]["user"] def test_mgrant_init(): with pytest.raises(StoreError): store = MongograntStore("", "", username="") - with pytest.raises(ValueError): + with pytest.raises(ValueError): # noqa: PT012 store = MongograntStore("", "") store.connect() @@ -122,15 +104,11 @@ def test_mgrant_init(): def test_mgrant_connect(mgrant_server, mgrant_user): config_path, mdport, dbname = mgrant_server assert mgrant_user is not None - store = MongograntStore( - "ro:testhost/testdb", "tasks", mgclient_config_path=config_path - ) + store = MongograntStore("ro:testhost/testdb", "tasks", mgclient_config_path=config_path) store.connect() assert isinstance(store._collection, Collection) assert connected_user(store) == "reader" - store = MongograntStore( - "rw:testhost/testdb", "tasks", mgclient_config_path=config_path - ) + store = MongograntStore("rw:testhost/testdb", "tasks", mgclient_config_path=config_path) store.connect() assert isinstance(store._collection, Collection) assert connected_user(store) == "writer" @@ -147,16 +125,10 @@ def test_mgrant_differences(): def test_mgrant_equal(mgrant_server, mgrant_user): config_path, mdport, dbname = mgrant_server assert mgrant_user is not None - store1 = MongograntStore( - "ro:testhost/testdb", "tasks", mgclient_config_path=config_path - ) + store1 = MongograntStore("ro:testhost/testdb", "tasks", mgclient_config_path=config_path) store1.connect() - store2 = MongograntStore( - "ro:testhost/testdb", "tasks", mgclient_config_path=config_path - ) - store3 = MongograntStore( - "ro:testhost/testdb", "test", mgclient_config_path=config_path - ) + store2 = MongograntStore("ro:testhost/testdb", "tasks", mgclient_config_path=config_path) + store3 = MongograntStore("ro:testhost/testdb", "test", mgclient_config_path=config_path) store2.connect() assert store1 == store2 assert store1 != store3 @@ -179,9 +151,7 @@ def vault_store(): "lease_duration": 2764800, "lease_id": "", } - v = VaultStore("test_coll", "secret/matgen/maggma") - - return v + return VaultStore("test_coll", "secret/matgen/maggma") 
def test_vault_init(): @@ -222,12 +192,11 @@ def test_vault_missing_env(): vault_store() -@pytest.fixture +@pytest.fixture() def alias_store(): memorystore = MemoryStore("test") memorystore.connect() - alias_store = AliasingStore(memorystore, {"a": "b", "c.d": "e", "f": "g.h"}) - return alias_store + return AliasingStore(memorystore, {"a": "b", "c.d": "e", "f": "g.h"}) def test_alias_count(alias_store): @@ -240,12 +209,10 @@ def test_aliasing_query(alias_store): d = [{"b": 1}, {"e": 2}, {"g": {"h": 3}}] alias_store.store._collection.insert_many(d) - assert "a" in list(alias_store.query(criteria={"a": {"$exists": 1}}))[0] - assert "c" in list(alias_store.query(criteria={"c.d": {"$exists": 1}}))[0] - assert "d" in list(alias_store.query(criteria={"c.d": {"$exists": 1}}))[0].get( - "c", {} - ) - assert "f" in list(alias_store.query(criteria={"f": {"$exists": 1}}))[0] + assert "a" in next(iter(alias_store.query(criteria={"a": {"$exists": 1}}))) + assert "c" in next(iter(alias_store.query(criteria={"c.d": {"$exists": 1}}))) + assert "d" in next(iter(alias_store.query(criteria={"c.d": {"$exists": 1}}))).get("c", {}) + assert "f" in next(iter(alias_store.query(criteria={"f": {"$exists": 1}}))) def test_aliasing_update(alias_store): @@ -256,14 +223,14 @@ def test_aliasing_update(alias_store): {"task_id": "mp-5", "f": 6}, ] ) - assert list(alias_store.query(criteria={"task_id": "mp-3"}))[0]["a"] == 4 - assert list(alias_store.query(criteria={"task_id": "mp-4"}))[0]["c"]["d"] == 5 - assert list(alias_store.query(criteria={"task_id": "mp-5"}))[0]["f"] == 6 + assert next(iter(alias_store.query(criteria={"task_id": "mp-3"})))["a"] == 4 + assert next(iter(alias_store.query(criteria={"task_id": "mp-4"})))["c"]["d"] == 5 + assert next(iter(alias_store.query(criteria={"task_id": "mp-5"})))["f"] == 6 - assert list(alias_store.store.query(criteria={"task_id": "mp-3"}))[0]["b"] == 4 - assert list(alias_store.store.query(criteria={"task_id": "mp-4"}))[0]["e"] == 5 + assert next(iter(alias_store.store.query(criteria={"task_id": "mp-3"})))["b"] == 4 + assert next(iter(alias_store.store.query(criteria={"task_id": "mp-4"})))["e"] == 5 - assert list(alias_store.store.query(criteria={"task_id": "mp-5"}))[0]["g"]["h"] == 6 + assert next(iter(alias_store.store.query(criteria={"task_id": "mp-5"})))["g"]["h"] == 6 def test_aliasing_remove_docs(alias_store): @@ -312,7 +279,7 @@ def test_aliasing_distinct(alias_store): assert alias_store.distinct("f") == [3] -@pytest.fixture +@pytest.fixture() def sandbox_store(): memstore = MemoryStore() store = SandboxStore(memstore, sandbox="test") @@ -354,10 +321,7 @@ def test_sandbox_distinct(sandbox_store): def test_sandbox_update(sandbox_store): sandbox_store.connect() sandbox_store.update([{"e": 6, "d": 4}], key="e") - assert ( - next(sandbox_store.query(criteria={"d": {"$exists": 1}}, properties=["d"]))["d"] - == 4 - ) + assert next(sandbox_store.query(criteria={"d": {"$exists": 1}}, properties=["d"]))["d"] == 4 assert sandbox_store._collection.find_one({"e": 6})["sbxn"] == ["test"] sandbox_store.update([{"e": 7, "sbxn": ["core"]}], key="e") assert set(sandbox_store.query_one(criteria={"e": 7})["sbxn"]) == {"test", "core"} @@ -372,33 +336,27 @@ def test_sandbox_remove_docs(sandbox_store): assert sandbox_store.query_one(criteria={"e": 7}) sandbox_store.remove_docs(criteria={"d": 4}) - assert ( - sandbox_store.query_one(criteria={"d": {"$exists": 1}}, properties=["d"]) - is None - ) + assert sandbox_store.query_one(criteria={"d": {"$exists": 1}}, properties=["d"]) is None assert 
sandbox_store.query_one(criteria={"e": 7}) -@pytest.fixture +@pytest.fixture() def mgrantstore(mgrant_server, mgrant_user): config_path, mdport, dbname = mgrant_server assert mgrant_user is not None - store = MongograntStore( - "ro:testhost/testdb", "tasks", mgclient_config_path=config_path - ) + store = MongograntStore("ro:testhost/testdb", "tasks", mgclient_config_path=config_path) store.connect() return store -@pytest.fixture +@pytest.fixture() def vaultstore(): os.environ["VAULT_ADDR"] = "https://fake:8200/" os.environ["VAULT_TOKEN"] = "dummy" # Just test that we successfully instantiated - v = vault_store() - return v + return vault_store() def test_eq_mgrant(mgrantstore, mongostore): diff --git a/tests/stores/test_aws.py b/tests/stores/test_aws.py index 7637fbd0e..fb2cc28b0 100644 --- a/tests/stores/test_aws.py +++ b/tests/stores/test_aws.py @@ -4,12 +4,11 @@ import boto3 import pytest from botocore.exceptions import ClientError -from moto import mock_s3 - from maggma.stores import MemoryStore, MongoStore, S3Store +from moto import mock_s3 -@pytest.fixture +@pytest.fixture() def mongostore(): store = MongoStore("maggma_test", "test") store.connect() @@ -17,7 +16,7 @@ def mongostore(): store._collection.drop() -@pytest.fixture +@pytest.fixture() def s3store(): with mock_s3(): conn = boto3.resource("s3", region_name="us-east-1") @@ -49,7 +48,7 @@ def s3store(): yield store -@pytest.fixture +@pytest.fixture() def s3store_w_subdir(): with mock_s3(): conn = boto3.resource("s3", region_name="us-east-1") @@ -62,7 +61,7 @@ def s3store_w_subdir(): yield store -@pytest.fixture +@pytest.fixture() def s3store_multi(): with mock_s3(): conn = boto3.resource("s3", region_name="us-east-1") @@ -104,8 +103,7 @@ def test_multi_update(s3store, s3store_multi): def fake_writing(doc, search_keys): time.sleep(0.20) - search_doc = {k: doc[k] for k in search_keys} - return search_doc + return {k: doc[k] for k in search_keys} s3store.write_doc_to_s3 = fake_writing s3store_multi.write_doc_to_s3 = fake_writing @@ -119,9 +117,7 @@ def fake_writing(doc, search_keys): s3store.update(data, key=["task_id"]) end = time.time() time_single = end - start - assert time_single > time_multi * (s3store_multi.s3_workers - 1) / ( - s3store.s3_workers - ) + assert time_single > time_multi * (s3store_multi.s3_workers - 1) / (s3store.s3_workers) def test_count(s3store): @@ -167,17 +163,11 @@ def test_rebuild_meta_from_index(s3store): def test_rebuild_index(s3store): s3store.update([{"task_id": "mp-2", "data": "asd"}]) - assert ( - s3store.index.query_one({"task_id": "mp-2"})["obj_hash"] - == "a69fe0c2cca3a3384c2b1d2f476972704f179741" - ) + assert s3store.index.query_one({"task_id": "mp-2"})["obj_hash"] == "a69fe0c2cca3a3384c2b1d2f476972704f179741" s3store.index.remove_docs({}) assert s3store.index.query_one({"task_id": "mp-2"}) is None s3store.rebuild_index_from_s3_data() - assert ( - s3store.index.query_one({"task_id": "mp-2"})["obj_hash"] - == "a69fe0c2cca3a3384c2b1d2f476972704f179741" - ) + assert s3store.index.query_one({"task_id": "mp-2"})["obj_hash"] == "a69fe0c2cca3a3384c2b1d2f476972704f179741" def tests_msonable_read_write(s3store): @@ -218,8 +208,8 @@ def test_close(s3store): def test_bad_import(mocker): mocker.patch("maggma.stores.aws.boto3", None) + index = MemoryStore("index") with pytest.raises(RuntimeError): - index = MemoryStore("index") S3Store(index, "bucket1") @@ -279,10 +269,7 @@ def test_remove_subdir(s3store_w_subdir): def test_searchable_fields(s3store): tic = datetime(2018, 4, 12, 16) - data = [ - 
{"task_id": f"mp-{i}", "a": i, s3store.last_updated_field: tic} - for i in range(4) - ] + data = [{"task_id": f"mp-{i}", "a": i, s3store.last_updated_field: tic} for i in range(4)] s3store.searchable_fields = ["task_id"] s3store.update(data, key="a") @@ -321,10 +308,7 @@ def test_newer_in(s3store): def test_additional_metadata(s3store): tic = datetime(2018, 4, 12, 16) - data = [ - {"task_id": f"mp-{i}", "a": i, s3store.last_updated_field: tic} - for i in range(4) - ] + data = [{"task_id": f"mp-{i}", "a": i, s3store.last_updated_field: tic} for i in range(4)] s3store.update(data, key="a", additional_metadata="task_id") diff --git a/tests/stores/test_azure.py b/tests/stores/test_azure.py index 489df19e2..a3aefc88c 100644 --- a/tests/stores/test_azure.py +++ b/tests/stores/test_azure.py @@ -11,7 +11,6 @@ from datetime import datetime import pytest - from maggma.stores import AzureBlobStore, MemoryStore, MongoStore try: @@ -32,7 +31,7 @@ AZURITE_CONTAINER_NAME = "maggma-test-container" -@pytest.fixture +@pytest.fixture() def mongostore(): store = MongoStore("maggma_test", "test") store.connect() @@ -47,9 +46,7 @@ def azurite_container(container_name=AZURITE_CONTAINER_NAME, create_container=Tr if azure_blob is None: pytest.skip("azure-storage-blob is required to test AzureBlobStore") - blob_service_client = BlobServiceClient.from_connection_string( - AZURITE_CONNECTION_STRING - ) + blob_service_client = BlobServiceClient.from_connection_string(AZURITE_CONNECTION_STRING) container_client = blob_service_client.get_container_client(container_name) if container_client.exists(): @@ -65,7 +62,7 @@ def azurite_container(container_name=AZURITE_CONTAINER_NAME, create_container=Tr container_client.delete_container() -@pytest.fixture +@pytest.fixture() def blobstore(): with azurite_container(): index = MemoryStore("index", key="task_id") @@ -79,7 +76,7 @@ def blobstore(): yield store -@pytest.fixture +@pytest.fixture() def blobstore_two_docs(blobstore): blobstore.update( [ @@ -100,10 +97,10 @@ def blobstore_two_docs(blobstore): ] ) - yield blobstore + return blobstore -@pytest.fixture +@pytest.fixture() def blobstore_w_subdir(): with azurite_container(): index = MemoryStore("index") @@ -118,11 +115,11 @@ def blobstore_w_subdir(): yield store -@pytest.fixture +@pytest.fixture() def blobstore_multi(blobstore): blobstore.workers = 4 - yield blobstore + return blobstore def test_keys(): @@ -158,8 +155,7 @@ def test_multi_update(blobstore_two_docs, blobstore_multi): def fake_writing(doc, search_keys): time.sleep(0.20) - search_doc = {k: doc[k] for k in search_keys} - return search_doc + return {k: doc[k] for k in search_keys} blobstore_two_docs.write_doc_to_blob = fake_writing blobstore_multi.write_doc_to_blob = fake_writing @@ -173,9 +169,7 @@ def fake_writing(doc, search_keys): blobstore_two_docs.update(data, key=["task_id"]) end = time.time() time_single = end - start - assert time_single > time_multi * (blobstore_multi.workers - 1) / ( - blobstore_two_docs.workers - ) + assert time_single > time_multi * (blobstore_multi.workers - 1) / (blobstore_two_docs.workers) def test_count(blobstore_two_docs): @@ -278,8 +272,8 @@ def test_close(blobstore_two_docs): def test_bad_import(mocker): mocker.patch("maggma.stores.azure.azure_blob", None) + index = MemoryStore("index") with pytest.raises(RuntimeError): - index = MemoryStore("index") AzureBlobStore(index, "bucket1") @@ -321,10 +315,7 @@ def test_remove_subdir(blobstore_w_subdir): def test_searchable_fields(blobstore_two_docs): tic = datetime(2018, 4, 12, 16) - 
data = [ - {"task_id": f"mp-{i}", "a": i, blobstore_two_docs.last_updated_field: tic} - for i in range(4) - ] + data = [{"task_id": f"mp-{i}", "a": i, blobstore_two_docs.last_updated_field: tic} for i in range(4)] blobstore_two_docs.searchable_fields = ["task_id"] blobstore_two_docs.update(data, key="a") @@ -375,10 +366,7 @@ def test_newer_in(blobstore): def test_additional_metadata(blobstore_two_docs): tic = datetime(2018, 4, 12, 16) - data = [ - {"task_id": f"mp-{i}", "a": i, blobstore_two_docs.last_updated_field: tic} - for i in range(4) - ] + data = [{"task_id": f"mp-{i}", "a": i, blobstore_two_docs.last_updated_field: tic} for i in range(4)] blobstore_two_docs.update(data, key="a", additional_metadata="task_id") @@ -429,7 +417,5 @@ def test_no_login(): azure_client_info={}, ) - with pytest.raises( - RuntimeError, match=r".*Could not instantiate BlobServiceClient.*" - ): + with pytest.raises(RuntimeError, match=r".*Could not instantiate BlobServiceClient.*"): store.connect() diff --git a/tests/stores/test_compound_stores.py b/tests/stores/test_compound_stores.py index 9151e0de1..a31fe9d20 100644 --- a/tests/stores/test_compound_stores.py +++ b/tests/stores/test_compound_stores.py @@ -2,12 +2,11 @@ from itertools import chain import pytest -from pydash import get - from maggma.stores import ConcatStore, JointStore, MemoryStore, MongoStore +from pydash import get -@pytest.fixture +@pytest.fixture() def mongostore(): store = MongoStore("magmma_test", "test") store.connect() @@ -164,7 +163,7 @@ def test_joint_remove_docs(jointstore): jointstore.remove_docs({}) -@pytest.fixture +@pytest.fixture() def concat_store(): mem_stores = [MemoryStore(str(i)) for i in range(4)] store = ConcatStore(mem_stores) @@ -174,10 +173,7 @@ def concat_store(): props = {i: str(i) for i in range(10)} for mem_store in mem_stores: - docs = [ - {"task_id": i, "prop": props[i - index], "index": index} - for i in range(index, index + 10) - ] + docs = [{"task_id": i, "prop": props[i - index], "index": index} for i in range(index, index + 10)] index = index + 10 mem_store.update(docs) return store @@ -185,11 +181,7 @@ def concat_store(): def test_concat_store_distinct(concat_store): docs = list(concat_store.distinct("task_id")) - actual_docs = list( - chain.from_iterable( - [store.distinct("task_id") for store in concat_store.stores] - ) - ) + actual_docs = list(chain.from_iterable([store.distinct("task_id") for store in concat_store.stores])) assert len(docs) == len(actual_docs) assert set(docs) == set(actual_docs) diff --git a/tests/stores/test_file_store.py b/tests/stores/test_file_store.py index 1ac29bfab..4f1df1dfd 100644 --- a/tests/stores/test_file_store.py +++ b/tests/stores/test_file_store.py @@ -25,12 +25,11 @@ from pathlib import Path import pytest - from maggma.core import StoreError from maggma.stores.file_store import FileStore -@pytest.fixture +@pytest.fixture() def test_dir(tmp_path): module_dir = Path(__file__).resolve().parent test_dir = module_dir / ".." 
/ "test_files" / "file_store_test" @@ -58,9 +57,7 @@ def test_record_from_file(test_dir): assert d["size"] == pytest.approx(90, abs=1) assert isinstance(d["hash"], str) assert d["file_id"] == file_id - assert d["last_updated"] == datetime.fromtimestamp( - f.stat().st_mtime, tz=timezone.utc - ) + assert d["last_updated"] == datetime.fromtimestamp(f.stat().st_mtime, tz=timezone.utc) def test_newer_in_on_local_update(test_dir): @@ -140,9 +137,7 @@ def test_orphaned_metadata(test_dir): # this will result in orphaned metadata # with include_orphans=True, this should be returned in queries fs = FileStore(test_dir, read_only=True, max_depth=1, include_orphans=True) - with pytest.warns( - UserWarning, match="Orphaned metadata was found in FileStore.json" - ): + with pytest.warns(UserWarning, match="Orphaned metadata was found in FileStore.json"): fs.connect() assert len(list(fs.query())) == 6 assert len(list(fs.query({"tags": {"$exists": True}}))) == 6 @@ -156,13 +151,9 @@ def test_orphaned_metadata(test_dir): # this will result in orphaned metadata # with include_orphans=False (default), that metadata should be # excluded from query results - Path(test_dir / "calculation1" / "input.in").rename( - test_dir / "calculation1" / "input_renamed.in" - ) + Path(test_dir / "calculation1" / "input.in").rename(test_dir / "calculation1" / "input_renamed.in") fs = FileStore(test_dir, read_only=True, include_orphans=False) - with pytest.warns( - UserWarning, match="Orphaned metadata was found in FileStore.json" - ): + with pytest.warns(UserWarning, match="Orphaned metadata was found in FileStore.json"): fs.connect() assert len(list(fs.query())) == 6 assert len(list(fs.query({"tags": {"$exists": True}}))) == 5 @@ -224,8 +215,8 @@ def test_read_only(test_dir): fs = FileStore(test_dir, read_only=True, json_name="random.json") fs.connect() assert not Path(test_dir / "random.json").exists() + file_id = fs.query_one()["file_id"] with pytest.raises(StoreError, match="read-only"): - file_id = fs.query_one()["file_id"] fs.update({"file_id": file_id, "tags": "something"}) with pytest.raises(StoreError, match="read-only"): fs.remove_docs({}) @@ -286,7 +277,7 @@ def test_remove(test_dir): assert not Path.exists(test_dir / "calculation1" / "input.in") assert not Path.exists(test_dir / "calculation2" / "input.in") fs.remove_docs({}, confirm=True) - assert not any([Path(p).exists() for p in paths]) + assert not any(Path(p).exists() for p in paths) def test_metadata(test_dir): @@ -300,7 +291,7 @@ def test_metadata(test_dir): fs = FileStore(test_dir, read_only=False, last_updated_field="last_change") fs.connect() query = {"name": "input.in", "parent": "calculation1"} - key = list(fs.query(query))[0][fs.key] + key = next(iter(fs.query(query)))[fs.key] fs.add_metadata( { "metadata": {"experiment date": "2022-01-18"}, @@ -310,7 +301,7 @@ def test_metadata(test_dir): ) # make sure metadata has been added to the item without removing other contents - item_from_store = list(fs.query({"file_id": key}))[0] + item_from_store = next(iter(fs.query({"file_id": key}))) assert item_from_store.get("name", False) assert item_from_store.get("metadata", False) fs.close() @@ -319,7 +310,7 @@ def test_metadata(test_dir): # and it should not contain any of the protected keys data = fs.metadata_store.read_json_file(fs.path / fs.json_name) assert len(data) == 1 - item_from_file = [d for d in data if d["file_id"] == key][0] + item_from_file = next(d for d in data if d["file_id"] == key) assert item_from_file["metadata"] == {"experiment date": 
"2022-01-18"} assert not item_from_file.get("name") assert not item_from_file.get("path") @@ -330,11 +321,11 @@ def test_metadata(test_dir): fs2 = FileStore(test_dir, read_only=True) fs2.connect() data = fs2.metadata_store.read_json_file(fs2.path / fs2.json_name) - item_from_file = [d for d in data if d["file_id"] == key][0] + item_from_file = next(d for d in data if d["file_id"] == key) assert item_from_file["metadata"] == {"experiment date": "2022-01-18"} # make sure reconnected store properly merges in the metadata - item_from_store = [d for d in fs2.query({"file_id": key})][0] + item_from_store = next(iter(fs2.query({"file_id": key}))) assert item_from_store["name"] == "input.in" assert item_from_store["parent"] == "calculation1" assert item_from_store.get("metadata") == {"experiment date": "2022-01-18"} @@ -344,9 +335,9 @@ def test_metadata(test_dir): fs3 = FileStore(test_dir, read_only=False) fs3.connect() data = fs3.metadata_store.read_json_file(fs3.path / fs3.json_name) - item_from_file = [d for d in data if d["file_id"] == key][0] + item_from_file = next(d for d in data if d["file_id"] == key) assert item_from_file["metadata"] == {"experiment date": "2022-01-18"} - item_from_store = [d for d in fs3.query({"file_id": key})][0] + item_from_store = next(iter(fs3.query({"file_id": key}))) assert item_from_store["name"] == "input.in" assert item_from_store["parent"] == "calculation1" assert item_from_store.get("metadata") == {"experiment date": "2022-01-18"} diff --git a/tests/stores/test_gridfs.py b/tests/stores/test_gridfs.py index e82ce1b75..ed4b2b98e 100644 --- a/tests/stores/test_gridfs.py +++ b/tests/stores/test_gridfs.py @@ -5,14 +5,13 @@ import numpy as np import numpy.testing.utils as nptu import pytest -from pymongo.errors import ConfigurationError - from maggma.core import StoreError from maggma.stores import GridFSStore, MongoStore from maggma.stores.gridfs import GridFSURIStore, files_collection_fields +from pymongo.errors import ConfigurationError -@pytest.fixture +@pytest.fixture() def mongostore(): store = MongoStore("maggma_test", "test") store.connect() @@ -20,7 +19,7 @@ def mongostore(): store._collection.drop() -@pytest.fixture +@pytest.fixture() def gridfsstore(): store = GridFSStore("maggma_test", "test", key="task_id") store.connect() @@ -34,47 +33,30 @@ def test_update(gridfsstore): data2 = np.random.rand(256) tic = datetime(2018, 4, 12, 16) # Test metadata storage - gridfsstore.update( - [{"task_id": "mp-1", "data": data1, gridfsstore.last_updated_field: tic}] - ) - assert ( - gridfsstore._files_collection.find_one({"metadata.task_id": "mp-1"}) is not None - ) + gridfsstore.update([{"task_id": "mp-1", "data": data1, gridfsstore.last_updated_field: tic}]) + assert gridfsstore._files_collection.find_one({"metadata.task_id": "mp-1"}) is not None # Test storing data - gridfsstore.update( - [{"task_id": "mp-1", "data": data2, gridfsstore.last_updated_field: tic}] - ) + gridfsstore.update([{"task_id": "mp-1", "data": data2, gridfsstore.last_updated_field: tic}]) assert len(list(gridfsstore.query({"task_id": "mp-1"}))) == 1 assert "task_id" in gridfsstore.query_one({"task_id": "mp-1"}) - nptu.assert_almost_equal( - gridfsstore.query_one({"task_id": "mp-1"})["data"], data2, 7 - ) + nptu.assert_almost_equal(gridfsstore.query_one({"task_id": "mp-1"})["data"], data2, 7) # Test storing compressed data gridfsstore = GridFSStore("maggma_test", "test", key="task_id", compression=True) gridfsstore.connect() gridfsstore.update([{"task_id": "mp-1", "data": data1}]) - assert ( - 
gridfsstore._files_collection.find_one({"metadata.compression": "zlib"}) - is not None - ) + assert gridfsstore._files_collection.find_one({"metadata.compression": "zlib"}) is not None - nptu.assert_almost_equal( - gridfsstore.query_one({"task_id": "mp-1"})["data"], data1, 7 - ) + nptu.assert_almost_equal(gridfsstore.query_one({"task_id": "mp-1"})["data"], data1, 7) def test_remove(gridfsstore): data1 = np.random.rand(256) data2 = np.random.rand(256) tic = datetime(2018, 4, 12, 16) - gridfsstore.update( - [{"task_id": "mp-1", "data": data1, gridfsstore.last_updated_field: tic}] - ) - gridfsstore.update( - [{"task_id": "mp-2", "data": data2, gridfsstore.last_updated_field: tic}] - ) + gridfsstore.update([{"task_id": "mp-1", "data": data1, gridfsstore.last_updated_field: tic}]) + gridfsstore.update([{"task_id": "mp-2", "data": data2, gridfsstore.last_updated_field: tic}]) assert gridfsstore.query_one(criteria={"task_id": "mp-1"}) assert gridfsstore.query_one(criteria={"task_id": "mp-2"}) @@ -87,15 +69,11 @@ def test_count(gridfsstore): data1 = np.random.rand(256) data2 = np.random.rand(256) tic = datetime(2018, 4, 12, 16) - gridfsstore.update( - [{"task_id": "mp-1", "data": data1, gridfsstore.last_updated_field: tic}] - ) + gridfsstore.update([{"task_id": "mp-1", "data": data1, gridfsstore.last_updated_field: tic}]) assert gridfsstore.count() == 1 - gridfsstore.update( - [{"task_id": "mp-2", "data": data2, gridfsstore.last_updated_field: tic}] - ) + gridfsstore.update([{"task_id": "mp-2", "data": data2, gridfsstore.last_updated_field: tic}]) assert gridfsstore.count() == 2 assert gridfsstore.count({"task_id": "mp-2"}) == 1 @@ -105,12 +83,8 @@ def test_query(gridfsstore): data1 = np.random.rand(256) data2 = np.random.rand(256) tic = datetime(2018, 4, 12, 16) - gridfsstore.update( - [{"task_id": "mp-1", "data": data1, gridfsstore.last_updated_field: tic}] - ) - gridfsstore.update( - [{"task_id": "mp-2", "data": data2, gridfsstore.last_updated_field: tic}] - ) + gridfsstore.update([{"task_id": "mp-1", "data": data1, gridfsstore.last_updated_field: tic}]) + gridfsstore.update([{"task_id": "mp-2", "data": data2, gridfsstore.last_updated_field: tic}]) doc = gridfsstore.query_one(criteria={"task_id": "mp-1"}) nptu.assert_almost_equal(doc["data"], data1, 7) @@ -136,26 +110,18 @@ def test_last_updated(gridfsstore): data2 = np.random.rand(256) tic = datetime(2018, 4, 12, 16) - gridfsstore.update( - [{"task_id": "mp-1", "data": data1, gridfsstore.last_updated_field: tic}] - ) - gridfsstore.update( - [{"task_id": "mp-2", "data": data2, gridfsstore.last_updated_field: tic}] - ) + gridfsstore.update([{"task_id": "mp-1", "data": data1, gridfsstore.last_updated_field: tic}]) + gridfsstore.update([{"task_id": "mp-2", "data": data2, gridfsstore.last_updated_field: tic}]) assert gridfsstore.last_updated == tic toc = datetime(2019, 6, 12, 16) - gridfsstore.update( - [{"task_id": "mp-3", "data": data2, gridfsstore.last_updated_field: toc}] - ) + gridfsstore.update([{"task_id": "mp-3", "data": data2, gridfsstore.last_updated_field: toc}]) assert gridfsstore.last_updated == toc tic = datetime(2017, 6, 12, 16) - gridfsstore.update( - [{"task_id": "mp-4", "data": data2, gridfsstore.last_updated_field: tic}] - ) + gridfsstore.update([{"task_id": "mp-4", "data": data2, gridfsstore.last_updated_field: tic}]) assert gridfsstore.last_updated == toc @@ -246,10 +212,7 @@ def test_gridfsstore_from_launchpad_file(lp_file): def test_searchable_fields(gridfsstore): tic = datetime(2018, 4, 12, 16) - data = [ - {"task_id": 
f"mp-{i}", "a": i, gridfsstore.last_updated_field: tic} - for i in range(3) - ] + data = [{"task_id": f"mp-{i}", "a": i, gridfsstore.last_updated_field: tic} for i in range(3)] gridfsstore.searchable_fields = ["task_id"] gridfsstore.update(data, key="a") @@ -260,10 +223,7 @@ def test_searchable_fields(gridfsstore): def test_additional_metadata(gridfsstore): tic = datetime(2018, 4, 12, 16) - data = [ - {"task_id": f"mp-{i}", "a": i, gridfsstore.last_updated_field: tic} - for i in range(3) - ] + data = [{"task_id": f"mp-{i}", "a": i, gridfsstore.last_updated_field: tic} for i in range(3)] gridfsstore.update(data, key="a", additional_metadata="task_id") diff --git a/tests/stores/test_mongolike.py b/tests/stores/test_mongolike.py index 7a0147719..68784a639 100644 --- a/tests/stores/test_mongolike.py +++ b/tests/stores/test_mongolike.py @@ -8,15 +8,14 @@ import orjson import pymongo.collection import pytest -from monty.tempfile import ScratchDir -from pymongo.errors import ConfigurationError, DocumentTooLarge, OperationFailure - from maggma.core import StoreError from maggma.stores import JSONStore, MemoryStore, MongoStore, MongoURIStore, MontyStore from maggma.validators import JSONSchemaValidator +from monty.tempfile import ScratchDir +from pymongo.errors import ConfigurationError, DocumentTooLarge, OperationFailure -@pytest.fixture +@pytest.fixture() def mongostore(): store = MongoStore( database="maggma_test", @@ -27,21 +26,21 @@ def mongostore(): store._collection.drop() -@pytest.fixture +@pytest.fixture() def montystore(tmp_dir): store = MontyStore("maggma_test") store.connect() return store -@pytest.fixture +@pytest.fixture() def memorystore(): store = MemoryStore() store.connect() return store -@pytest.fixture +@pytest.fixture() def jsonstore(test_dir): files = [] for f in ["a.json", "b.json"]: @@ -107,21 +106,17 @@ def test_mongostore_distinct(mongostore): vals = mongostore.distinct("key") # Test to make sure distinct on array field is unraveled when using manual distinct assert len(vals) == len(list(range(1000000))) - assert all([isinstance(v, str) for v in vals]) + assert all(isinstance(v, str) for v in vals) # Test to make sure manual distinct uses the criteria query - mongostore._collection.insert_many( - [{"key": f"mp-{i}", "a": 2} for i in range(1000001, 2000001)] - ) + mongostore._collection.insert_many([{"key": f"mp-{i}", "a": 2} for i in range(1000001, 2000001)]) vals = mongostore.distinct("key", {"a": 2}) assert len(vals) == len(list(range(1000001, 2000001))) def test_mongostore_update(mongostore): mongostore.update({"e": 6, "d": 4}, key="e") - assert ( - mongostore.query_one(criteria={"d": {"$exists": 1}}, properties=["d"])["d"] == 4 - ) + assert mongostore.query_one(criteria={"d": {"$exists": 1}}, properties=["d"])["d"] == 4 mongostore.update([{"e": 7, "d": 8, "f": 9}], key=["d", "f"]) assert mongostore.query_one(criteria={"d": 8, "f": 9}, properties=["e"])["e"] == 7 @@ -163,9 +158,9 @@ def test_mongostore_groupby(mongostore): ) data = list(mongostore.groupby("d")) assert len(data) == 2 - grouped_by_9 = [g[1] for g in data if g[0]["d"] == 9][0] + grouped_by_9 = next(g[1] for g in data if g[0]["d"] == 9) assert len(grouped_by_9) == 3 - grouped_by_10 = [g[1] for g in data if g[0]["d"] == 10][0] + grouped_by_10 = next(g[1] for g in data if g[0]["d"] == 10) assert len(grouped_by_10) == 1 data = list(mongostore.groupby(["e", "d"])) @@ -214,12 +209,10 @@ def test_mongostore_last_updated(mongostore): assert mongostore.last_updated == datetime.min start_time = datetime.utcnow() 
mongostore._collection.insert_one({mongostore.key: 1, "a": 1}) - with pytest.raises(StoreError) as cm: - mongostore.last_updated - assert cm.match(mongostore.last_updated_field) - mongostore.update( - [{mongostore.key: 1, "a": 1, mongostore.last_updated_field: datetime.utcnow()}] - ) + with pytest.raises(StoreError) as cm: # noqa: PT012 + mongostore.last_updated # noqa: B018 + assert cm.match(mongostore.last_updated_field) + mongostore.update([{mongostore.key: 1, "a": 1, mongostore.last_updated_field: datetime.utcnow()}]) assert mongostore.last_updated > start_time @@ -229,20 +222,10 @@ def test_mongostore_newer_in(mongostore): # make sure docs are newer in mongostore then target and check updated_keys - target.update( - [ - {mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()} - for i in range(10) - ] - ) + target.update([{mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()} for i in range(10)]) # Update docs in source - mongostore.update( - [ - {mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()} - for i in range(10) - ] - ) + mongostore.update([{mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()} for i in range(10)]) assert len(target.newer_in(mongostore)) == 10 assert len(target.newer_in(mongostore, exhaustive=True)) == 10 @@ -271,11 +254,11 @@ def test_groupby(memorystore): ) data = list(memorystore.groupby("d", properties={"e": 1, "f": 1})) assert len(data) == 2 - grouped_by_9 = [g[1] for g in data if g[0]["d"] == 9][0] + grouped_by_9 = next(g[1] for g in data if g[0]["d"] == 9) assert len(grouped_by_9) == 3 - assert all([d.get("f", False) for d in grouped_by_9]) - assert all([d.get("e", False) for d in grouped_by_9]) - grouped_by_10 = [g[1] for g in data if g[0]["d"] == 10][0] + assert all(d.get("f", False) for d in grouped_by_9) + assert all(d.get("e", False) for d in grouped_by_9) + grouped_by_10 = next(g[1] for g in data if g[0]["d"] == 10) assert len(grouped_by_10) == 1 data = list(memorystore.groupby(["e", "d"])) @@ -324,9 +307,9 @@ def test_monty_store_groupby(montystore): ) data = list(montystore.groupby("d")) assert len(data) == 2 - grouped_by_9 = [g[1] for g in data if g[0]["d"] == 9][0] + grouped_by_9 = next(g[1] for g in data if g[0]["d"] == 9) assert len(grouped_by_9) == 3 - grouped_by_10 = [g[1] for g in data if g[0]["d"] == 10][0] + grouped_by_10 = next(g[1] for g in data if g[0]["d"] == 10) assert len(grouped_by_10) == 1 data = list(montystore.groupby(["e", "d"])) @@ -384,9 +367,7 @@ def test_monty_store_distinct(montystore): def test_monty_store_update(montystore): montystore.update({"e": 6, "d": 4}, key="e") - assert ( - montystore.query_one(criteria={"d": {"$exists": 1}}, properties=["d"])["d"] == 4 - ) + assert montystore.query_one(criteria={"d": {"$exists": 1}}, properties=["d"])["d"] == 4 montystore.update([{"e": 7, "d": 8, "f": 9}], key=["d", "f"]) assert montystore.query_one(criteria={"d": 8, "f": 9}, properties=["e"])["e"] == 7 @@ -418,12 +399,10 @@ def test_monty_store_last_updated(montystore): assert montystore.last_updated == datetime.min start_time = datetime.utcnow() montystore._collection.insert_one({montystore.key: 1, "a": 1}) - with pytest.raises(StoreError) as cm: - montystore.last_updated - assert cm.match(montystore.last_updated_field) - montystore.update( - [{montystore.key: 1, "a": 1, montystore.last_updated_field: datetime.utcnow()}] - ) + with pytest.raises(StoreError) as cm: # noqa: PT012 + montystore.last_updated # noqa: B018 + assert cm.match(montystore.last_updated_field) + 
montystore.update([{montystore.key: 1, "a": 1, montystore.last_updated_field: datetime.utcnow()}]) assert montystore.last_updated > start_time @@ -436,8 +415,8 @@ def test_json_store_load(jsonstore, test_dir): assert len(list(jsonstore.query())) == 20 # confirm descriptive error raised if you get a KeyError + jsonstore = JSONStore(test_dir / "test_set" / "c.json.gz", key="random_key") with pytest.raises(KeyError, match="Key field 'random_key' not found"): - jsonstore = JSONStore(test_dir / "test_set" / "c.json.gz", key="random_key") jsonstore.connect() # if the .json does not exist, it should be created @@ -491,18 +470,14 @@ def test_json_store_writeable(test_dir): jsonstore.connect() assert jsonstore.count() == 2 jsonstore.close() - with mock.patch( - "maggma.stores.JSONStore.update_json_file" - ) as update_json_file_mock: + with mock.patch("maggma.stores.JSONStore.update_json_file") as update_json_file_mock: jsonstore = JSONStore("d.json", file_writable=False) jsonstore.connect() jsonstore.update({"new": "hello", "task_id": 5}) assert jsonstore.count() == 3 jsonstore.close() update_json_file_mock.assert_not_called() - with mock.patch( - "maggma.stores.JSONStore.update_json_file" - ) as update_json_file_mock: + with mock.patch("maggma.stores.JSONStore.update_json_file") as update_json_file_mock: jsonstore = JSONStore("d.json", file_writable=False) jsonstore.connect() jsonstore.remove_docs({"task_id": 5}) diff --git a/tests/stores/test_shared_stores.py b/tests/stores/test_shared_stores.py index 663a83579..e33d14023 100644 --- a/tests/stores/test_shared_stores.py +++ b/tests/stores/test_shared_stores.py @@ -1,13 +1,12 @@ import pymongo import pytest -from pymongo.errors import DocumentTooLarge, OperationFailure - from maggma.stores import GridFSStore, MemoryStore, MongoStore from maggma.stores.shared_stores import MultiStore, StoreFacade from maggma.validators import JSONSchemaValidator +from pymongo.errors import DocumentTooLarge, OperationFailure -@pytest.fixture +@pytest.fixture() def mongostore(): store = MongoStore("maggma_test", "test") store.connect() @@ -15,7 +14,7 @@ def mongostore(): store._collection.drop() -@pytest.fixture +@pytest.fixture() def gridfsstore(): store = GridFSStore("maggma_test", "test", key="task_id") store.connect() @@ -24,13 +23,12 @@ def gridfsstore(): store._chunks_collection.drop() -@pytest.fixture +@pytest.fixture() def multistore(): - store = MultiStore() - yield store + return MultiStore() -@pytest.fixture +@pytest.fixture() def memorystore(): store = MemoryStore() store.connect() @@ -88,9 +86,7 @@ def test_store_facade(multistore, mongostore, gridfsstore): def test_multistore_query(multistore, mongostore, memorystore): memorystore_facade = StoreFacade(memorystore, multistore) mongostore_facade = StoreFacade(mongostore, multistore) - temp_mongostore_facade = StoreFacade( - MongoStore.from_dict(mongostore.as_dict()), multistore - ) + temp_mongostore_facade = StoreFacade(MongoStore.from_dict(mongostore.as_dict()), multistore) memorystore_facade._collection.insert_one({"a": 1, "b": 2, "c": 3}) assert memorystore_facade.query_one(properties=["a"])["a"] == 1 @@ -143,18 +139,14 @@ def test_multistore_distinct(multistore, mongostore): assert mongostore_facade.distinct("i") == [None] # Test to make sure DocumentTooLarge errors get dealt with properly using built in distinct - mongostore_facade._collection.insert_many( - [{"key": [f"mp-{i}"]} for i in range(1000000)] - ) + mongostore_facade._collection.insert_many([{"key": [f"mp-{i}"]} for i in range(1000000)]) 
vals = mongostore_facade.distinct("key") # Test to make sure distinct on array field is unraveled when using manual distinct assert len(vals) == len(list(range(1000000))) - assert all([isinstance(v, str) for v in vals]) + assert all(isinstance(v, str) for v in vals) # Test to make sure manual distinct uses the criteria query - mongostore_facade._collection.insert_many( - [{"key": f"mp-{i}", "a": 2} for i in range(1000001, 2000001)] - ) + mongostore_facade._collection.insert_many([{"key": f"mp-{i}", "a": 2} for i in range(1000001, 2000001)]) vals = mongostore_facade.distinct("key", {"a": 2}) assert len(vals) == len(list(range(1000001, 2000001))) @@ -163,24 +155,13 @@ def test_multistore_update(multistore, mongostore): mongostore_facade = StoreFacade(mongostore, multistore) mongostore_facade.update({"e": 6, "d": 4}, key="e") - assert ( - mongostore_facade.query_one(criteria={"d": {"$exists": 1}}, properties=["d"])[ - "d" - ] - == 4 - ) + assert mongostore_facade.query_one(criteria={"d": {"$exists": 1}}, properties=["d"])["d"] == 4 mongostore_facade.update([{"e": 7, "d": 8, "f": 9}], key=["d", "f"]) - assert ( - mongostore_facade.query_one(criteria={"d": 8, "f": 9}, properties=["e"])["e"] - == 7 - ) + assert mongostore_facade.query_one(criteria={"d": 8, "f": 9}, properties=["e"])["e"] == 7 mongostore_facade.update([{"e": 11, "d": 8, "f": 9}], key=["d", "f"]) - assert ( - mongostore_facade.query_one(criteria={"d": 8, "f": 9}, properties=["e"])["e"] - == 11 - ) + assert mongostore_facade.query_one(criteria={"d": 8, "f": 9}, properties=["e"])["e"] == 11 test_schema = { "type": "object", @@ -219,9 +200,9 @@ def test_multistore_groupby(multistore, mongostore): ) data = list(mongostore_facade.groupby("d")) assert len(data) == 2 - grouped_by_9 = [g[1] for g in data if g[0]["d"] == 9][0] + grouped_by_9 = next(g[1] for g in data if g[0]["d"] == 9) assert len(grouped_by_9) == 3 - grouped_by_10 = [g[1] for g in data if g[0]["d"] == 10][0] + grouped_by_10 = next(g[1] for g in data if g[0]["d"] == 10) assert len(grouped_by_10) == 1 data = list(mongostore_facade.groupby(["e", "d"])) diff --git a/tests/stores/test_ssh_tunnel.py b/tests/stores/test_ssh_tunnel.py index 78e334c80..041d6906f 100644 --- a/tests/stores/test_ssh_tunnel.py +++ b/tests/stores/test_ssh_tunnel.py @@ -1,18 +1,13 @@ import paramiko import pymongo import pytest -from monty.serialization import dumpfn, loadfn -from paramiko.ssh_exception import ( - AuthenticationException, - NoValidConnectionsError, - SSHException, -) - from maggma.stores.mongolike import MongoStore, SSHTunnel +from monty.serialization import dumpfn, loadfn +from paramiko.ssh_exception import AuthenticationException, NoValidConnectionsError, SSHException -@pytest.fixture -def ssh_server_available(): +@pytest.fixture() +def ssh_server_available(): # noqa: PT004 """ Fixture to determine if an SSH server is available to test the SSH tunnel diff --git a/tests/test_utils.py b/tests/test_utils.py index 89cf971b7..58217b911 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,3 @@ -# coding: utf-8 """ Tests for builders """ @@ -6,9 +5,8 @@ from time import sleep import pytest - -from maggma.utils import Timeout # dt_to_isoformat_ceil_ms,; isostr_to_dt, from maggma.utils import ( + Timeout, # dt_to_isoformat_ceil_ms,; isostr_to_dt, dynamic_import, grouper, primed, @@ -44,14 +42,13 @@ def takes_too_long(): def test_primed(): - global is_primed + global is_primed # noqa: PLW0603 is_primed = False def unprimed_iter(): - global is_primed + global is_primed # noqa: 
PLW0603 is_primed = True - for i in range(10): - yield i + yield from range(10) iterator = unprimed_iter() @@ -62,22 +59,17 @@ def unprimed_iter(): assert is_primed is True assert list(iterator) == list(range(10)) - # test stop itteration + # test stop iteration with pytest.raises(StopIteration): next(primed(iterator)) def test_datetime_utils(): - assert ( - to_isoformat_ceil_ms(datetime(2019, 12, 13, 0, 23, 11, 9515)) - == "2019-12-13T00:23:11.010" - ) + assert to_isoformat_ceil_ms(datetime(2019, 12, 13, 0, 23, 11, 9515)) == "2019-12-13T00:23:11.010" assert to_isoformat_ceil_ms("2019-12-13T00:23:11.010") == "2019-12-13T00:23:11.010" assert to_dt("2019-12-13T00:23:11.010") == datetime(2019, 12, 13, 0, 23, 11, 10000) - assert to_dt(datetime(2019, 12, 13, 0, 23, 11, 10000)) == datetime( - 2019, 12, 13, 0, 23, 11, 10000 - ) + assert to_dt(datetime(2019, 12, 13, 0, 23, 11, 10000)) == datetime(2019, 12, 13, 0, 23, 11, 10000) def test_dynamic_import(): @@ -89,7 +81,7 @@ def test_grouper(): assert len(list(grouper(my_iterable, 10))) == 10 - my_iterable = list(range(100)) + [None] + my_iterable = [*list(range(100)), None] my_groups = list(grouper(my_iterable, 10)) assert len(my_groups) == 11 assert len(my_groups[10]) == 1 diff --git a/tests/test_validator.py b/tests/test_validator.py index a48e87489..70628244f 100644 --- a/tests/test_validator.py +++ b/tests/test_validator.py @@ -1,11 +1,9 @@ -# coding: utf-8 """ Tests the validators """ import pytest -from monty.json import MSONable - from maggma.validators import JSONSchemaValidator, ValidationError, msonable_schema +from monty.json import MSONable class LatticeMock(MSONable): @@ -17,7 +15,7 @@ def __init__(self, a): self.a = a -@pytest.fixture +@pytest.fixture() def test_schema(): return { "type": "object", @@ -69,10 +67,6 @@ def test_jsonschemevalidator(test_schema): "lattice: ['I am not a lattice!'] is not of type 'object'" ] - assert validator.validation_errors(invalid_doc_missing_key) == [ - ": 'successful' is a required property" - ] + assert validator.validation_errors(invalid_doc_missing_key) == [": 'successful' is a required property"] - assert validator.validation_errors(invalid_doc_wrong_type) == [ - "successful: 'true' is not of type 'boolean'" - ] + assert validator.validation_errors(invalid_doc_wrong_type) == ["successful: 'true' is not of type 'boolean'"] From cb6573f759c5f2d7d3c0d9edbd7ab6491c0a3d1b Mon Sep 17 00:00:00 2001 From: Ryan Kingsbury Date: Mon, 31 Jul 2023 13:53:36 -0400 Subject: [PATCH 09/11] sync PR template with pymatgen --- .github/pull_request_template.md | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 20931c625..4a5bac8f5 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,9 +1,29 @@ -*Start with a description of this PR. Then edit the list below to the items that make sense for your PR scope, and check off the boxes as you go!* +## Summary -## Contributor Checklist +Major changes: -- [ ] I have broken down my PR scope into the following TODO tasks - - [ ] task 1 - - [ ] task 2 +- feature 1: ... +- fix 1: ... + +## Todos + +If this is work in progress, what else needs to be done? + +- feature 2: ... +- fix 2: + +## Checklist + +- [ ] Google format doc strings added. +- [ ] Code linted with `ruff`. (For guidance in fixing rule violates, see [rule list](https://beta.ruff.rs/docs/rules/)) +- [ ] Type annotations included. Check with `mypy`. 
+- [ ] Tests added for new features/fixes. - [ ] I have run the tests locally and they passed. -- [ ] I have added tests, or extended existing tests, to cover any new features or bugs fixed in this PR + + +Tip: Install `pre-commit` hooks to auto-check types and linting before every commit: + +```sh +pip install -U pre-commit +pre-commit install +``` From c152a5b39b8cfe2f53cb17150db9646fed380dbd Mon Sep 17 00:00:00 2001 From: Ryan Kingsbury Date: Mon, 31 Jul 2023 14:43:16 -0400 Subject: [PATCH 10/11] pre-commit run --all-files --- src/maggma/api/utils.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/maggma/api/utils.py b/src/maggma/api/utils.py index 03531d8aa..d02c3d099 100644 --- a/src/maggma/api/utils.py +++ b/src/maggma/api/utils.py @@ -1,14 +1,8 @@ import base64 import inspect import sys -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Type, -) +from typing import Any, Callable, Dict, List, Optional, Type + from bson.objectid import ObjectId from monty.json import MSONable from pydantic import BaseModel From cbaefa6652fe3b096c52bf08430a818f9321b717 Mon Sep 17 00:00:00 2001 From: Ryan Kingsbury Date: Mon, 31 Jul 2023 14:54:12 -0400 Subject: [PATCH 11/11] auto upgrade dependencies fixes --- .github/workflows/upgrade-dependencies.yml | 107 ++++----------------- 1 file changed, 21 insertions(+), 86 deletions(-) diff --git a/.github/workflows/upgrade-dependencies.yml b/.github/workflows/upgrade-dependencies.yml index d67066955..71b790f7f 100644 --- a/.github/workflows/upgrade-dependencies.yml +++ b/.github/workflows/upgrade-dependencies.yml @@ -1,11 +1,12 @@ # https://www.oddbird.net/2022/06/01/dependabot-single-pull-request/ +# https://github.com/materialsproject/MPContribs/blob/master/.github/workflows/upgrade-dependencies.yml name: upgrade dependencies on: workflow_dispatch: # Allow running on-demand schedule: # Runs every Monday at 8:00 UTC (4:00 Eastern) - - cron: '0 17 * * 1' + - cron: '0 8 * * 1' jobs: upgrade: @@ -13,9 +14,9 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: ['ubuntu-latest'] - package: ["optimade"] - python-version: ["3.10"] + os: ['ubuntu-latest', 'macos-latest', windows-latest] + package: ["maggma"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v3 with: @@ -24,121 +25,55 @@ jobs: with: python-version: ${{ matrix.python-version }} cache: 'pip' - cache-dependency-path: '**/requirements/${{ matrix.os }}_py${{ matrix.python-version }}.txt' - name: Upgrade Python dependencies shell: bash run: | python${{ matrix.python-version }} -m pip install --upgrade pip pip-tools - cd docker/${{ matrix.package }} - python${{ matrix.python-version }} -m piptools compile -q --resolver=backtracking --upgrade -o requirements/${{ matrix.os }}_py${{ matrix.python-version }}.txt + cd ${{ matrix.package }} + python${{ matrix.python-version }} -m piptools compile -q --upgrade --resolver=backtracking -o requirements/${{ matrix.os }}_py${{ matrix.python-version }}.txt + python${{ matrix.python-version }} -m piptools compile -q --upgrade --resolver=backtracking --all-extras -o requirements/${{ matrix.os }}_py${{ matrix.python-version }}_extras.txt - name: Detect changes id: changes shell: bash run: | - echo "count=`git diff --quiet docker/${{ matrix.package }}/requirements; echo $?`" >> $GITHUB_OUTPUT - echo "files=`git ls-files --exclude-standard --others docker/${{ matrix.package }}/requirements | wc -l | xargs`" >> $GITHUB_OUTPUT + #git diff-index HEAD ${{ matrix.package }}/requirements/${{ 
matrix.os }}_py${{ matrix.python-version }}*.txt | awk '{print $4}' | sort -u + #sha1=$(git diff-index HEAD ${{ matrix.package }}/requirements/${{ matrix.os }}_py${{ matrix.python-version }}*.txt | awk '{print $4}' | sort -u | head -n1) + #[[ $sha1 == "0000000000000000000000000000000000000000" ]] && git update-index --really-refresh ${{ matrix.package }}/requirements/${{ matrix.os }}_py${{ matrix.python-version }}*.txt + echo "count=$(git diff-index HEAD ${{ matrix.package }}/requirements/${{ matrix.os }}_py${{ matrix.python-version }}*.txt | wc -l | xargs)" >> $GITHUB_OUTPUT + echo "files=$(git ls-files --exclude-standard --others ${{ matrix.package }}/requirements/${{ matrix.os }}_py${{ matrix.python-version }}*.txt | wc -l | xargs)" >> $GITHUB_OUTPUT - name: commit & push changes if: steps.changes.outputs.count > 0 || steps.changes.outputs.files > 0 shell: bash run: | git config user.name github-actions git config user.email github-actions@github.com - git add docker/${{ matrix.package }}/requirements + git add ${{ matrix.package }}/requirements git commit -m "update dependencies for ${{ matrix.package }} (${{ matrix.os }}/py${{ matrix.python-version }})" git push -f origin ${{ github.ref_name }}:auto-dependency-upgrades-${{ matrix.package }}-${{ matrix.os }}-py${{ matrix.python-version }} pull_request: name: Merge all branches and open PR runs-on: ubuntu-latest - needs: [upgrade] - strategy: - matrix: - python-version: ["3.10"] + needs: upgrade steps: - uses: actions/checkout@v3 with: fetch-depth: 0 - submodules: 'recursive' - token: ${{ secrets.PAT }} - - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: '**/docker/python/requirements.txt' - - - name: make new branch - run: | - git config --global user.name github-actions - git config --global user.email github-actions@github.com - git checkout -b auto-dependency-upgrades - - name: detect auto-upgrade-dependency branches - id: upgrade_changes + id: changes run: echo "count=$(git branch -r | grep auto-dependency-upgrades- | wc -l | xargs)" >> $GITHUB_OUTPUT - - name: merge all auto-dependency-upgrades branches - if: steps.upgrade_changes.outputs.count > 0 + if: steps.changes.outputs.count > 0 run: | + git config user.name github-actions + git config user.email github-actions@github.com + git checkout -b auto-dependency-upgrades git branch -r | grep auto-dependency-upgrades- | xargs -I {} git merge {} git rebase ${GITHUB_REF##*/} git push -f origin auto-dependency-upgrades git branch -r | grep auto-dependency-upgrades- | cut -d/ -f2 | xargs -I {} git push origin :{} - - - name: submodule updates - run: git submodule update --remote - - - name: compile docker/python dependencies - shell: bash - run: | - cd docker - python${{ matrix.python-version }} -m pip install --upgrade pip pip-tools - setup_packages="emmet/emmet-api emmet/emmet-core emmet/emmet-builders MPContribs/mpcontribs-api MPContribs/mpcontribs-client MPContribs/mpcontribs-portal" - pip_input=""; for i in `echo $setup_packages`; do pip_input="$pip_input $i/setup.py"; done - python${{ matrix.python-version }} -m piptools compile -q --resolver=backtracking --upgrade web/pyproject.toml MPContribs/mpcontribs-kernel-gateway/requirements.in `echo $pip_input` -o python/requirements-full.txt - grep -h -E "numpy==|scipy==|matplotlib==|pandas==" python/requirements-full.txt > python/requirements.txt - rm python/requirements-full.txt - python${{ matrix.python-version }} -m piptools compile -q 
--resolver=backtracking --upgrade python/requirements.txt web/pyproject.toml -o web/requirements/deployment.txt - cd web && git checkout main && git add requirements/deployment.txt && git commit -m "upgrade dependencies for deployment" && git push - cd - - python${{ matrix.python-version }} -m piptools compile -q --resolver=backtracking --upgrade python/requirements.txt MPContribs/mpcontribs-kernel-gateway/requirements.in -o MPContribs/mpcontribs-kernel-gateway/requirements/deployment.txt - for i in `echo $setup_packages`; do - python${{ matrix.python-version }} -m piptools compile -q --resolver=backtracking --upgrade python/requirements.txt $i/setup.py -o $i/requirements/deployment.txt - done - cd emmet && git checkout main && git add emmet-*/requirements/deployment.txt && git commit -m "upgrade dependencies for deployment" && git push - cd - - cd MPContribs && git checkout master && git add mpcontribs-*/requirements/deployment.txt && git commit -m "upgrade dependencies for deployment" && git push - cd - - - - name: Detect changes - id: changes - shell: bash - run: | - echo "count=`git diff --quiet --ignore-submodules -- . ':!docker/python'; echo $?`" >> $GITHUB_OUTPUT - echo "countReq=`git diff --quiet docker/python/requirements.txt; echo $?`" >> $GITHUB_OUTPUT - echo "files=$(git ls-files --exclude-standard --others | wc -l | xargs)" >> $GITHUB_OUTPUT - - - name: commit & push changes - if: steps.changes.outputs.count > 0 || steps.changes.outputs.files > 0 || steps.changes.outputs.countReq > 0 - shell: bash - run: | - git add . - git commit -m "auto dependency upgrades" - git push origin auto-dependency-upgrades - - - name: create and push tag to trigger python base image action - if: steps.changes.outputs.countReq > 0 - shell: bash - run: | - ver=`grep FROM docker/python/Dockerfile | cut -d: -f2 | cut -d- -f1` - prefix=${ver%.*}${ver##*.} - patch=`git tag -l "python-${prefix}.*" | sort -V | tail -1 | cut -d. -f3` - [[ -z "$patch" ]] && tag="python-${prefix}.0" || tag="python-${prefix}.$((++patch))" - echo $tag - git tag $tag - git push --tags - - name: Open pull request if needed - if: steps.changes.outputs.count > 0 || steps.changes.outputs.files > 0 || steps.changes.outputs.countReq > 0 + if: steps.changes.outputs.count > 0 env: GITHUB_TOKEN: ${{ secrets.PAT }} # Only open a PR if the branch is not attached to an existing one
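
For reference, the new upgrade job above reduces to a pip-tools re-pin plus a `git diff-index` change check. A rough local replay for a single matrix entry might look like the sketch below; the Python version, OS tag, and output filenames are illustrative stand-ins for values the workflow derives from its matrix, and a `requirements/` directory is assumed to exist alongside `pyproject.toml`.

```sh
# Sketch: approximate local replay of the upgrade job for one matrix entry
# (ubuntu-latest / Python 3.11 chosen only as an example).
python3.11 -m pip install --upgrade pip pip-tools

# Re-pin the base and all-extras requirement sets from pyproject.toml,
# mirroring the two piptools compile calls in the workflow.
python3.11 -m piptools compile -q --upgrade --resolver=backtracking \
  -o requirements/ubuntu-latest_py3.11.txt
python3.11 -m piptools compile -q --upgrade --resolver=backtracking --all-extras \
  -o requirements/ubuntu-latest_py3.11_extras.txt

# A non-zero count means the pins changed; the workflow would then commit the
# updated files and push an auto-dependency-upgrades-* branch.
git diff-index HEAD requirements/ubuntu-latest_py3.11*.txt | wc -l
```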