From ebbfbe068ee0d1d8428e505df1c84a0be9570b48 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sat, 3 Aug 2024 11:36:13 +0200 Subject: [PATCH 1/6] implement store.list_prefix and store._set_dict --- src/zarr/abc/store.py | 10 +++- src/zarr/store/local.py | 4 -- src/zarr/store/memory.py | 2 +- src/zarr/store/remote.py | 9 +++- src/zarr/sync.py | 17 ++++++ src/zarr/testing/store.py | 85 ++++++++++++++++++++---------- tests/v3/test_store/test_remote.py | 36 ++++++------- 7 files changed, 107 insertions(+), 56 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 449816209..95d12943b 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Mapping from typing import Any, NamedTuple, Protocol, runtime_checkable from typing_extensions import Self @@ -221,6 +221,14 @@ def close(self) -> None: self._is_open = False pass + async def _set_dict(self, dict: Mapping[str, Buffer]) -> None: + """ + Insert objects into storage as defined by a prefix: value mapping. + """ + for key, value in dict.items(): + await self.set(key, value) + return None + @runtime_checkable class ByteGetter(Protocol): diff --git a/src/zarr/store/local.py b/src/zarr/store/local.py index 25fd9fc13..cc6ba38f2 100644 --- a/src/zarr/store/local.py +++ b/src/zarr/store/local.py @@ -193,10 +193,6 @@ async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: ------- AsyncGenerator[str, None] """ - for p in (self.root / prefix).rglob("*"): - if p.is_file(): - yield str(p) - to_strip = str(self.root) + "/" for p in (self.root / prefix).rglob("*"): if p.is_file(): diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index dd3e52e70..c3a61e8e5 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -101,7 +101,7 @@ async def list(self) -> AsyncGenerator[str, None]: async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: for key in self._store_dict: if key.startswith(prefix): - yield key + yield key.removeprefix(prefix) async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: if prefix.endswith("/"): diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index c742d9e56..87b8fe657 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -205,5 +205,10 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: yield onefile async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: - for onefile in await self._fs._ls(prefix, detail=False): - yield onefile + if prefix == "": + find_str = "/".join([self.path, prefix]) + else: + find_str = "/".join([self.path, prefix]) + + for onefile in await self._fs._find(find_str): + yield onefile.removeprefix(find_str) diff --git a/src/zarr/sync.py b/src/zarr/sync.py index 8af14f602..446ffd43e 100644 --- a/src/zarr/sync.py +++ b/src/zarr/sync.py @@ -114,6 +114,23 @@ def _get_loop() -> asyncio.AbstractEventLoop: return loop[0] +async def _collect_aiterator(data: AsyncIterator[T]) -> tuple[T, ...]: + """ + Collect an entire async iterator into a tuple + """ + result = [] + async for x in data: + result.append(x) + return tuple(result) + + +def collect_aiterator(data: AsyncIterator[T]) -> tuple[T, ...]: + """ + Synchronously collect an entire async iterator into a tuple. + """ + return sync(_collect_aiterator(data)) + + class SyncMixin: def _sync(self, coroutine: Coroutine[Any, Any, T]) -> T: # TODO: refactor this to to take *args and **kwargs and pass those to the method diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 4fdf497a6..ba37dda62 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -5,6 +5,7 @@ from zarr.abc.store import AccessMode, Store from zarr.buffer import Buffer, default_buffer_prototype from zarr.store.utils import _normalize_interval_index +from zarr.sync import _collect_aiterator from zarr.testing.utils import assert_bytes_equal S = TypeVar("S", bound=Store) @@ -103,6 +104,18 @@ async def test_set(self, store: S, key: str, data: bytes) -> None: observed = self.get(store, key) assert_bytes_equal(observed, data_buf) + async def test_set_dict(self, store: S) -> None: + """ + Test that a dict of key : value pairs can be inserted into the store via the + `_set_dict` method. + """ + keys = ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"] + data_buf = [Buffer.from_bytes(k.encode()) for k in keys] + store_dict = dict(zip(keys, data_buf, strict=True)) + await store._set_dict(store_dict) + for k, v in store_dict.items(): + assert self.get(store, k).to_bytes() == v.to_bytes() + @pytest.mark.parametrize( "key_ranges", ( @@ -165,37 +178,55 @@ async def test_clear(self, store: S) -> None: assert await store.empty() async def test_list(self, store: S) -> None: - assert [k async for k in store.list()] == [] - await store.set("foo/zarr.json", Buffer.from_bytes(b"bar")) - keys = [k async for k in store.list()] - assert keys == ["foo/zarr.json"], keys - - expected = ["foo/zarr.json"] - for i in range(10): - key = f"foo/c/{i}" - expected.append(key) - await store.set( - f"foo/c/{i}", Buffer.from_bytes(i.to_bytes(length=3, byteorder="little")) - ) + assert await _collect_aiterator(store.list()) == () + prefix = "foo" + data = Buffer.from_bytes(b"") + store_dict = { + prefix + "/zarr.json": data, + **{prefix + f"/c/{idx}": data for idx in range(10)}, + } + await store._set_dict(store_dict) + expected_sorted = sorted(store_dict.keys()) + observed = await _collect_aiterator(store.list()) + observed_sorted = sorted(observed) + assert observed_sorted == expected_sorted - @pytest.mark.xfail async def test_list_prefix(self, store: S) -> None: - # TODO: we currently don't use list_prefix anywhere - raise NotImplementedError + """ + Test that the `list_prefix` method works as intended. Given a prefix, it should return + all the keys in storage that start with this prefix. Keys should be returned with the shared + prefix removed. + """ + prefixes = ("", "a/", "a/b/", "a/b/c/") + data = Buffer.from_bytes(b"") + fname = "zarr.json" + store_dict = {p + fname: data for p in prefixes} + await store._set_dict(store_dict) + for p in prefixes: + observed = tuple(sorted(await _collect_aiterator(store.list_prefix(p)))) + expected: tuple[str, ...] = () + for k in store_dict.keys(): + if k.startswith(p): + expected += (k.removeprefix(p),) + expected = tuple(sorted(expected)) + assert observed == expected async def test_list_dir(self, store: S) -> None: - out = [k async for k in store.list_dir("")] - assert out == [] - assert [k async for k in store.list_dir("foo")] == [] - await store.set("foo/zarr.json", Buffer.from_bytes(b"bar")) - await store.set("foo/c/1", Buffer.from_bytes(b"\x01")) + root = "foo" + store_dict = { + root + "/zarr.json": Buffer.from_bytes(b"bar"), + root + "/c/1": Buffer.from_bytes(b"\x01"), + } + + assert await _collect_aiterator(store.list_dir("")) == () + assert await _collect_aiterator(store.list_dir(root)) == () + + await store._set_dict(store_dict) - keys_expected = ["zarr.json", "c"] - keys_observed = [k async for k in store.list_dir("foo")] + keys_observed = await _collect_aiterator(store.list_dir(root)) + keys_expected = {k.removeprefix(root + "/").split("/")[0] for k in store_dict.keys()} - assert len(keys_observed) == len(keys_expected), keys_observed - assert set(keys_observed) == set(keys_expected), keys_observed + assert sorted(keys_observed) == sorted(keys_expected) - keys_observed = [k async for k in store.list_dir("foo/")] - assert len(keys_expected) == len(keys_observed), keys_observed - assert set(keys_observed) == set(keys_expected), keys_observed + keys_observed = await _collect_aiterator(store.list_dir(root + "/")) + assert sorted(keys_expected) == sorted(keys_observed) diff --git a/tests/v3/test_store/test_remote.py b/tests/v3/test_store/test_remote.py index be9fa5ef6..14a181d7b 100644 --- a/tests/v3/test_store/test_remote.py +++ b/tests/v3/test_store/test_remote.py @@ -1,12 +1,18 @@ +from __future__ import annotations + import os +from collections.abc import Generator import fsspec import pytest +from botocore.client import BaseClient +from botocore.session import Session +from s3fs import S3FileSystem from upath import UPath from zarr.buffer import Buffer, default_buffer_prototype from zarr.store import RemoteStore -from zarr.sync import sync +from zarr.sync import _collect_aiterator, sync from zarr.testing.store import StoreTests s3fs = pytest.importorskip("s3fs") @@ -22,7 +28,7 @@ @pytest.fixture(scope="module") -def s3_base(): +def s3_base() -> Generator[None, None, None]: # writable local S3 system # This fixture is module-scoped, meaning that we can reuse the MotoServer across all tests @@ -37,16 +43,14 @@ def s3_base(): server.stop() -def get_boto3_client(): - from botocore.session import Session - +def get_boto3_client() -> BaseClient: # NB: we use the sync botocore client for setup session = Session() return session.create_client("s3", endpoint_url=endpoint_url) @pytest.fixture(autouse=True, scope="function") -def s3(s3_base): +def s3(s3_base: Generator[None, None, None]) -> Generator[S3FileSystem, None, None]: """ Quoting Martin Durant: pytest-asyncio creates a new event loop for each async test. @@ -71,21 +75,11 @@ def s3(s3_base): sync(session.close()) -# ### end from s3fs ### # - - -async def alist(it): - out = [] - async for a in it: - out.append(a) - return out - - -async def test_basic(): +async def test_basic() -> None: store = await RemoteStore.open( f"s3://{test_bucket_name}", mode="w", endpoint_url=endpoint_url, anon=False ) - assert not await alist(store.list()) + assert await _collect_aiterator(store.list()) == () assert not await store.exists("foo") data = b"hello" await store.set("foo", Buffer.from_bytes(data)) @@ -101,7 +95,7 @@ class TestRemoteStoreS3(StoreTests[RemoteStore]): store_cls = RemoteStore @pytest.fixture(scope="function", params=("use_upath", "use_str")) - def store_kwargs(self, request) -> dict[str, str | bool]: + def store_kwargs(self, request: pytest.FixtureRequest) -> dict[str, str | bool | UPath]: # type: ignore url = f"s3://{test_bucket_name}" anon = False mode = "r+" @@ -113,8 +107,8 @@ def store_kwargs(self, request) -> dict[str, str | bool]: raise AssertionError @pytest.fixture(scope="function") - def store(self, store_kwargs: dict[str, str | bool]) -> RemoteStore: - url = store_kwargs["url"] + async def store(self, store_kwargs: dict[str, str | bool | UPath]) -> RemoteStore: + url: str | UPath = store_kwargs["url"] mode = store_kwargs["mode"] if isinstance(url, UPath): out = self.store_cls(url=url, mode=mode) From da6083e6761ff2cc786a2651bdf5722f65f6636f Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sat, 3 Aug 2024 11:47:00 +0200 Subject: [PATCH 2/6] simplify string handling --- src/zarr/store/local.py | 2 +- src/zarr/store/remote.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/zarr/store/local.py b/src/zarr/store/local.py index cc6ba38f2..fe1821343 100644 --- a/src/zarr/store/local.py +++ b/src/zarr/store/local.py @@ -196,7 +196,7 @@ async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: to_strip = str(self.root) + "/" for p in (self.root / prefix).rglob("*"): if p.is_file(): - yield str(p).replace(to_strip, "") + yield str(p).removeprefix(to_strip) async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: """ diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index 87b8fe657..0b9e3bb8c 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -205,10 +205,6 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: yield onefile async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: - if prefix == "": - find_str = "/".join([self.path, prefix]) - else: - find_str = "/".join([self.path, prefix]) - + find_str = "/".join([self.path, prefix]) for onefile in await self._fs._find(find_str): yield onefile.removeprefix(find_str) From e4101b7aabede9f7fec7a8141ef8118219c6cd0a Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sat, 3 Aug 2024 20:52:55 +0200 Subject: [PATCH 3/6] use asyncio.gather in _set_dict --- src/zarr/abc/store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 95d12943b..c8ebc34f1 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from asyncio import gather from collections.abc import AsyncGenerator, Mapping from typing import Any, NamedTuple, Protocol, runtime_checkable @@ -225,8 +226,7 @@ async def _set_dict(self, dict: Mapping[str, Buffer]) -> None: """ Insert objects into storage as defined by a prefix: value mapping. """ - for key, value in dict.items(): - await self.set(key, value) + await gather(*(self.set(key, value) for key, value in dict.items())) return None From 70f9cebe767b300fb0068d4f2d05ea6bf9b1c3cf Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sat, 3 Aug 2024 23:09:32 +0200 Subject: [PATCH 4/6] add docstrings to list_prefix methods, and make invocation of _find more explicit --- src/zarr/abc/store.py | 4 +++- src/zarr/store/local.py | 4 +++- src/zarr/store/memory.py | 12 ++++++++++++ src/zarr/store/remote.py | 15 ++++++++++++++- 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index c8ebc34f1..18a85b974 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -189,7 +189,9 @@ def list(self) -> AsyncGenerator[str, None]: @abstractmethod def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: - """Retrieve all keys in the store with a given prefix. + """ + Retrieve all keys in the store that begin with a given prefix. Keys are returned with the + common leading prefix removed. Parameters ---------- diff --git a/src/zarr/store/local.py b/src/zarr/store/local.py index fe1821343..c7e33f8ca 100644 --- a/src/zarr/store/local.py +++ b/src/zarr/store/local.py @@ -183,7 +183,9 @@ async def list(self) -> AsyncGenerator[str, None]: yield str(p).replace(to_strip, "") async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: - """Retrieve all keys in the store with a given prefix. + """ + Retrieve all keys in the store that begin with a given prefix. Keys are returned with the + common leading prefix removed. Parameters ---------- diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index c3a61e8e5..46ca0b657 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -104,6 +104,18 @@ async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: yield key.removeprefix(prefix) async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: + """ + Retrieve all keys in the store that begin with a given prefix. Keys are returned with the + common leading prefix removed. + + Parameters + ---------- + prefix : str + + Returns + ------- + AsyncGenerator[str, None] + """ if prefix.endswith("/"): prefix = prefix[:-1] diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index 0b9e3bb8c..5963a33ea 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -205,6 +205,19 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: yield onefile async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: + """ + Retrieve all keys in the store that begin with a given prefix. Keys are returned with the + common leading prefix removed. + + Parameters + ---------- + prefix : str + + Returns + ------- + AsyncGenerator[str, None] + """ + find_str = "/".join([self.path, prefix]) - for onefile in await self._fs._find(find_str): + for onefile in await self._fs._find(find_str, detail=False, maxdepth=None, withdirs=False): yield onefile.removeprefix(find_str) From 0b54e4cda9790acd4fe761439e1962ed4bade9fc Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 6 Sep 2024 11:19:21 +0200 Subject: [PATCH 5/6] add byterangerequest type --- src/zarr/abc/store.py | 16 ++++++++-------- src/zarr/core/common.py | 1 + 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 1ae66d72e..7630dca5b 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from asyncio import gather -from collections.abc import AsyncGenerator, Mapping +from collections.abc import AsyncGenerator, Iterable from typing import Any, NamedTuple, Protocol, runtime_checkable from typing_extensions import Self @@ -144,6 +144,13 @@ async def set(self, key: str, value: Buffer) -> None: """ ... + async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: + """ + Insert multiple (key, value) pairs into storage. + """ + await gather(*(self.set(key, value) for key, value in values)) + return None + @abstractmethod async def delete(self, key: str) -> None: """Remove a key from the store @@ -226,13 +233,6 @@ def close(self) -> None: self._is_open = False pass - async def _set_dict(self, dict: Mapping[str, Buffer]) -> None: - """ - Insert objects into storage as defined by a prefix: value mapping. - """ - await gather(*(self.set(key, value) for key, value in dict.items())) - return None - @runtime_checkable class ByteGetter(Protocol): diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index aaa30cfcb..fd238e792 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -27,6 +27,7 @@ ZGROUP_JSON = ".zgroup" ZATTRS_JSON = ".zattrs" +ByteRangeRequest = tuple[int | None, int | None] BytesLike = bytes | bytearray | memoryview ChunkCoords = tuple[int, ...] ChunkCoordsLike = Iterable[int] From 49b4c1a63e45432707248bcf4d336076435f0597 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 19 Sep 2024 18:03:01 +0200 Subject: [PATCH 6/6] fix: activate test of localstore.list_prefix, fix zipstore.list_prefix --- src/zarr/store/local.py | 4 ++-- src/zarr/store/zip.py | 14 +++++++++++++- src/zarr/testing/store.py | 22 ++++++++++++---------- tests/v3/test_store/test_local.py | 3 --- 4 files changed, 27 insertions(+), 16 deletions(-) diff --git a/src/zarr/store/local.py b/src/zarr/store/local.py index 1c00ed088..c78837586 100644 --- a/src/zarr/store/local.py +++ b/src/zarr/store/local.py @@ -203,10 +203,10 @@ async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: ------- AsyncGenerator[str, None] """ - to_strip = str(self.root) + "/" + to_strip = os.path.join(str(self.root / prefix)) for p in (self.root / prefix).rglob("*"): if p.is_file(): - yield str(p).removeprefix(to_strip) + yield str(p.relative_to(to_strip)) async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: """ diff --git a/src/zarr/store/zip.py b/src/zarr/store/zip.py index ea31ad934..2e4927ace 100644 --- a/src/zarr/store/zip.py +++ b/src/zarr/store/zip.py @@ -209,9 +209,21 @@ async def list(self) -> AsyncGenerator[str, None]: yield key async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]: + """ + Retrieve all keys in the store that begin with a given prefix. Keys are returned with the + common leading prefix removed. + + Parameters + ---------- + prefix : str + + Returns + ------- + AsyncGenerator[str, None] + """ async for key in self.list(): if key.startswith(prefix): - yield key + yield key.removeprefix(prefix) async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: if prefix.endswith("/"): diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 885f7d0cb..8a9f27e4b 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -129,7 +129,7 @@ async def test_set_many(self, store: S) -> None: `_set_many` method. """ keys = ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"] - data_buf = [Buffer.from_bytes(k.encode()) for k in keys] + data_buf = [self.buffer_cls.from_bytes(k.encode()) for k in keys] store_dict = dict(zip(keys, data_buf, strict=True)) await store._set_many(store_dict.items()) for k, v in store_dict.items(): @@ -199,7 +199,7 @@ async def test_clear(self, store: S) -> None: async def test_list(self, store: S) -> None: assert await _collect_aiterator(store.list()) == () prefix = "foo" - data = Buffer.from_bytes(b"") + data = self.buffer_cls.from_bytes(b"") store_dict = { prefix + "/zarr.json": data, **{prefix + f"/c/{idx}": data for idx in range(10)}, @@ -217,24 +217,26 @@ async def test_list_prefix(self, store: S) -> None: prefix removed. """ prefixes = ("", "a/", "a/b/", "a/b/c/") - data = Buffer.from_bytes(b"") + data = self.buffer_cls.from_bytes(b"") fname = "zarr.json" store_dict = {p + fname: data for p in prefixes} + await store._set_many(store_dict.items()) - for p in prefixes: - observed = tuple(sorted(await _collect_aiterator(store.list_prefix(p)))) + + for prefix in prefixes: + observed = tuple(sorted(await _collect_aiterator(store.list_prefix(prefix)))) expected: tuple[str, ...] = () - for k in store_dict.keys(): - if k.startswith(p): - expected += (k.removeprefix(p),) + for key in store_dict.keys(): + if key.startswith(prefix): + expected += (key.removeprefix(prefix),) expected = tuple(sorted(expected)) assert observed == expected async def test_list_dir(self, store: S) -> None: root = "foo" store_dict = { - root + "/zarr.json": Buffer.from_bytes(b"bar"), - root + "/c/1": Buffer.from_bytes(b"\x01"), + root + "/zarr.json": self.buffer_cls.from_bytes(b"bar"), + root + "/c/1": self.buffer_cls.from_bytes(b"\x01"), } assert await _collect_aiterator(store.list_dir("")) == () diff --git a/tests/v3/test_store/test_local.py b/tests/v3/test_store/test_local.py index 59cae22de..5f1dde3fc 100644 --- a/tests/v3/test_store/test_local.py +++ b/tests/v3/test_store/test_local.py @@ -35,6 +35,3 @@ def test_store_supports_partial_writes(self, store: LocalStore) -> None: def test_store_supports_listing(self, store: LocalStore) -> None: assert store.supports_listing - - def test_list_prefix(self, store: LocalStore) -> None: - assert True