Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement store.list_prefix and store._set_many #2064

Merged
merged 9 commits into from
Sep 19, 2024
14 changes: 12 additions & 2 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from asyncio import gather
from collections.abc import AsyncGenerator, Iterable
from typing import Any, NamedTuple, Protocol, runtime_checkable

from typing_extensions import Self
Expand Down Expand Up @@ -158,6 +159,13 @@ async def set(self, key: str, value: Buffer) -> None:
"""
...

async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
    """
    Insert multiple (key, value) pairs into storage.

    Parameters
    ----------
    values : Iterable[tuple[str, Buffer]]
        The (key, value) pairs to write. Each pair is forwarded to ``self.set``.

    Notes
    -----
    All writes are handed to ``asyncio.gather`` at once, so they run
    concurrently and no completion order is guaranteed. An empty iterable
    is a no-op.
    """
    # One ``set`` coroutine per pair; ``gather`` awaits all of them together.
    await gather(*(self.set(key, value) for key, value in values))

@property
@abstractmethod
def supports_deletes(self) -> bool:
Expand Down Expand Up @@ -211,7 +219,9 @@ def list(self) -> AsyncGenerator[str, None]:

@abstractmethod
def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
"""Retrieve all keys in the store with a given prefix.
"""
Retrieve all keys in the store that begin with a given prefix. Keys are returned with the
common leading prefix removed.

Parameters
----------
Expand Down
1 change: 1 addition & 0 deletions src/zarr/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
ZGROUP_JSON = ".zgroup"
ZATTRS_JSON = ".zattrs"

ByteRangeRequest = tuple[int | None, int | None]
BytesLike = bytes | bytearray | memoryview
ShapeLike = tuple[int, ...] | int
ChunkCoords = tuple[int, ...]
Expand Down
17 changes: 17 additions & 0 deletions src/zarr/core/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,23 @@ def _get_loop() -> asyncio.AbstractEventLoop:
return loop[0]


async def _collect_aiterator(data: AsyncIterator[T]) -> tuple[T, ...]:
    """
    Collect an entire async iterator into a tuple.

    Parameters
    ----------
    data : AsyncIterator[T]
        The async iterator to drain. It is consumed completely.

    Returns
    -------
    tuple[T, ...]
        Every item produced by ``data``, in iteration order. Empty if
        ``data`` yields nothing.
    """
    # An async comprehension drains the iterator without a manual append loop.
    return tuple([item async for item in data])


def collect_aiterator(data: AsyncIterator[T]) -> tuple[T, ...]:
    """
    Blocking counterpart of ``_collect_aiterator``: drain an async iterator
    into a tuple from synchronous code.
    """
    # Build the coroutine first, then hand it to ``sync`` to run to completion.
    pending = _collect_aiterator(data)
    return sync(pending)


class SyncMixin:
def _sync(self, coroutine: Coroutine[Any, Any, T]) -> T:
# TODO: refactor this to take *args and **kwargs and pass those to the method
Expand Down
12 changes: 5 additions & 7 deletions src/zarr/store/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,9 @@ async def list(self) -> AsyncGenerator[str, None]:
yield str(p).replace(to_strip, "")

async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
"""Retrieve all keys in the store with a given prefix.
"""
Retrieve all keys in the store that begin with a given prefix. Keys are returned with the
common leading prefix removed.

Parameters
----------
Expand All @@ -201,14 +203,10 @@ async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
-------
AsyncGenerator[str, None]
"""
to_strip = os.path.join(str(self.root / prefix))
for p in (self.root / prefix).rglob("*"):
if p.is_file():
yield str(p)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Were we getting duplicates?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we were. This code path was not tested until this PR.


to_strip = str(self.root) + "/"
for p in (self.root / prefix).rglob("*"):
if p.is_file():
yield str(p).replace(to_strip, "")
yield str(p.relative_to(to_strip))

async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
"""
Expand Down
14 changes: 13 additions & 1 deletion src/zarr/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,21 @@ async def list(self) -> AsyncGenerator[str, None]:
async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
for key in self._store_dict:
if key.startswith(prefix):
yield key
yield key.removeprefix(prefix)

async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
"""
Retrieve all keys in the store that begin with a given prefix. Keys are returned with the
common leading prefix removed.

Parameters
----------
prefix : str

Returns
-------
AsyncGenerator[str, None]
"""
if prefix.endswith("/"):
prefix = prefix[:-1]

Expand Down
18 changes: 16 additions & 2 deletions src/zarr/store/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,5 +216,19 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
yield onefile.removeprefix(self.path).removeprefix("/")

async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
for onefile in await self._fs._ls(prefix, detail=False):
yield onefile
"""
Retrieve all keys in the store that begin with a given prefix. Keys are returned with the
common leading prefix removed.

Parameters
----------
prefix : str

Returns
-------
AsyncGenerator[str, None]
"""

find_str = "/".join([self.path, prefix])
for onefile in await self._fs._find(find_str, detail=False, maxdepth=None, withdirs=False):
yield onefile.removeprefix(find_str)
14 changes: 13 additions & 1 deletion src/zarr/store/zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,21 @@ async def list(self) -> AsyncGenerator[str, None]:
yield key

async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
"""
Retrieve all keys in the store that begin with a given prefix. Keys are returned with the
common leading prefix removed.

Parameters
----------
prefix : str

Returns
-------
AsyncGenerator[str, None]
"""
async for key in self.list():
if key.startswith(prefix):
yield key
yield key.removeprefix(prefix)

async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
if prefix.endswith("/"):
Expand Down
123 changes: 58 additions & 65 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

import pytest

import zarr.api.asynchronous
from zarr.abc.store import AccessMode, Store
from zarr.core.buffer import Buffer, default_buffer_prototype
from zarr.core.sync import _collect_aiterator
from zarr.store._utils import _normalize_interval_index
from zarr.testing.utils import assert_bytes_equal

Expand Down Expand Up @@ -123,6 +123,18 @@ async def test_set(self, store: S, key: str, data: bytes) -> None:
observed = self.get(store, key)
assert_bytes_equal(observed, data_buf)

async def test_set_many(self, store: S) -> None:
"""
Test that an iterable of (key, value) pairs can be inserted into the store via the
`_set_many` method, and that every value can be read back under its key afterwards.
"""
keys = ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"]
# Each value is the encoded form of its own key, so every key maps to distinct bytes.
data_buf = [self.buffer_cls.from_bytes(k.encode()) for k in keys]
store_dict = dict(zip(keys, data_buf, strict=True))
await store._set_many(store_dict.items())
for k, v in store_dict.items():
# Compare the raw bytes of the stored and expected buffers rather than the buffer objects.
assert self.get(store, k).to_bytes() == v.to_bytes()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If x.to_bytes() == y.to_bytes(), does x == y?

Isn't there a multiple get? Maybe not important here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If x.to_bytes() == y.to_bytes(), does x == y?

No, and I suspect this might be deliberate, since in principle Buffer instances can have identical bytes but different devices (e.g., GPU memory vs. host memory); thus x == y might only be true if two buffers are both bytes-equal and device-equal, but I'm speculating here. @madsbk would have a better answer, I think.

Isn't there a multiple get? Maybe not important here.

There is no multiple get (nor a multiple set, nor a multiple delete).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


@pytest.mark.parametrize(
"key_ranges",
(
Expand Down Expand Up @@ -185,76 +197,57 @@ async def test_clear(self, store: S) -> None:
assert await store.empty()

async def test_list(self, store: S) -> None:
assert [k async for k in store.list()] == []
await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar"))
keys = [k async for k in store.list()]
assert keys == ["foo/zarr.json"], keys

expected = ["foo/zarr.json"]
for i in range(10):
key = f"foo/c/{i}"
expected.append(key)
await store.set(
f"foo/c/{i}", self.buffer_cls.from_bytes(i.to_bytes(length=3, byteorder="little"))
)
assert await _collect_aiterator(store.list()) == ()
prefix = "foo"
data = self.buffer_cls.from_bytes(b"")
store_dict = {
prefix + "/zarr.json": data,
**{prefix + f"/c/{idx}": data for idx in range(10)},
}
await store._set_many(store_dict.items())
expected_sorted = sorted(store_dict.keys())
observed = await _collect_aiterator(store.list())
observed_sorted = sorted(observed)
assert observed_sorted == expected_sorted

@pytest.mark.xfail
async def test_list_prefix(self, store: S) -> None:
# TODO: we currently don't use list_prefix anywhere
raise NotImplementedError
"""
Test that the `list_prefix` method works as intended. Given a prefix, it should return
all the keys in storage that start with this prefix. Keys should be returned with the shared
prefix removed.
"""
prefixes = ("", "a/", "a/b/", "a/b/c/")
data = self.buffer_cls.from_bytes(b"")
fname = "zarr.json"
store_dict = {p + fname: data for p in prefixes}

await store._set_many(store_dict.items())

for prefix in prefixes:
observed = tuple(sorted(await _collect_aiterator(store.list_prefix(prefix))))
expected: tuple[str, ...] = ()
for key in store_dict.keys():
if key.startswith(prefix):
expected += (key.removeprefix(prefix),)
expected = tuple(sorted(expected))
assert observed == expected

async def test_list_dir(self, store: S) -> None:
out = [k async for k in store.list_dir("")]
assert out == []
assert [k async for k in store.list_dir("foo")] == []
await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar"))
await store.set("group-0/zarr.json", self.buffer_cls.from_bytes(b"\x01")) # group
await store.set("group-0/group-1/zarr.json", self.buffer_cls.from_bytes(b"\x01")) # group
await store.set("group-0/group-1/a1/zarr.json", self.buffer_cls.from_bytes(b"\x01"))
await store.set("group-0/group-1/a2/zarr.json", self.buffer_cls.from_bytes(b"\x01"))
await store.set("group-0/group-1/a3/zarr.json", self.buffer_cls.from_bytes(b"\x01"))

keys_expected = ["foo", "group-0"]
keys_observed = [k async for k in store.list_dir("")]
assert set(keys_observed) == set(keys_expected)

keys_expected = ["zarr.json"]
keys_observed = [k async for k in store.list_dir("foo")]

assert len(keys_observed) == len(keys_expected), keys_observed
assert set(keys_observed) == set(keys_expected), keys_observed

keys_observed = [k async for k in store.list_dir("foo/")]
assert len(keys_expected) == len(keys_observed), keys_observed
assert set(keys_observed) == set(keys_expected), keys_observed

keys_observed = [k async for k in store.list_dir("group-0")]
keys_expected = ["zarr.json", "group-1"]

assert len(keys_observed) == len(keys_expected), keys_observed
assert set(keys_observed) == set(keys_expected), keys_observed

keys_observed = [k async for k in store.list_dir("group-0/")]
assert len(keys_expected) == len(keys_observed), keys_observed
assert set(keys_observed) == set(keys_expected), keys_observed
root = "foo"
store_dict = {
root + "/zarr.json": self.buffer_cls.from_bytes(b"bar"),
root + "/c/1": self.buffer_cls.from_bytes(b"\x01"),
}

keys_observed = [k async for k in store.list_dir("group-0/group-1")]
keys_expected = ["zarr.json", "a1", "a2", "a3"]
assert await _collect_aiterator(store.list_dir("")) == ()
assert await _collect_aiterator(store.list_dir(root)) == ()

assert len(keys_observed) == len(keys_expected), keys_observed
assert set(keys_observed) == set(keys_expected), keys_observed
await store._set_many(store_dict.items())

keys_observed = [k async for k in store.list_dir("group-0/group-1")]
assert len(keys_expected) == len(keys_observed), keys_observed
assert set(keys_observed) == set(keys_expected), keys_observed
keys_observed = await _collect_aiterator(store.list_dir(root))
keys_expected = {k.removeprefix(root + "/").split("/")[0] for k in store_dict.keys()}

async def test_set_get(self, store_kwargs: dict[str, Any]) -> None:
kwargs = {**store_kwargs, **{"mode": "w"}}
store = self.store_cls(**kwargs)
await zarr.api.asynchronous.open_array(store=store, path="a", mode="w", shape=(4,))
keys = [x async for x in store.list()]
assert keys == ["a/zarr.json"]
assert sorted(keys_observed) == sorted(keys_expected)

# no errors
await zarr.api.asynchronous.open_array(store=store, path="a", mode="r")
await zarr.api.asynchronous.open_array(store=store, path="a", mode="a")
keys_observed = await _collect_aiterator(store.list_dir(root + "/"))
assert sorted(keys_expected) == sorted(keys_observed)
3 changes: 0 additions & 3 deletions tests/v3/test_store/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,3 @@ def test_store_supports_partial_writes(self, store: LocalStore) -> None:

def test_store_supports_listing(self, store: LocalStore) -> None:
assert store.supports_listing

def test_list_prefix(self, store: LocalStore) -> None:
assert True
24 changes: 15 additions & 9 deletions tests/v3/test_store/test_remote.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Generator

import botocore.client

import os
from collections.abc import Generator

import botocore.client
import fsspec
import pytest
from botocore.session import Session
from upath import UPath

from zarr.core.buffer import Buffer, cpu, default_buffer_prototype
from zarr.core.sync import sync
from zarr.core.sync import _collect_aiterator, sync
from zarr.store import RemoteStore
from zarr.testing.store import StoreTests

Expand Down Expand Up @@ -40,8 +48,6 @@ def s3_base() -> Generator[None, None, None]:


def get_boto3_client() -> botocore.client.BaseClient:
from botocore.session import Session

# NB: we use the sync botocore client for setup
session = Session()
return session.create_client("s3", endpoint_url=endpoint_url)
Expand Down Expand Up @@ -87,7 +93,7 @@ async def test_basic() -> None:
store = await RemoteStore.open(
f"s3://{test_bucket_name}", mode="w", endpoint_url=endpoint_url, anon=False
)
assert not await alist(store.list())
assert await _collect_aiterator(store.list()) == ()
assert not await store.exists("foo")
data = b"hello"
await store.set("foo", cpu.Buffer.from_bytes(data))
Expand All @@ -104,7 +110,7 @@ class TestRemoteStoreS3(StoreTests[RemoteStore, cpu.Buffer]):
buffer_cls = cpu.Buffer

@pytest.fixture(scope="function", params=("use_upath", "use_str"))
def store_kwargs(self, request) -> dict[str, str | bool]:
def store_kwargs(self, request: pytest.FixtureRequest) -> dict[str, str | bool | UPath]: # type: ignore
url = f"s3://{test_bucket_name}"
anon = False
mode = "r+"
Expand All @@ -116,8 +122,8 @@ def store_kwargs(self, request) -> dict[str, str | bool]:
raise AssertionError

@pytest.fixture(scope="function")
def store(self, store_kwargs: dict[str, str | bool]) -> RemoteStore:
url = store_kwargs["url"]
async def store(self, store_kwargs: dict[str, str | bool | UPath]) -> RemoteStore:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't actually async

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, but the class we are inheriting from defines this as an async method.

url: str | UPath = store_kwargs["url"]
mode = store_kwargs["mode"]
if isinstance(url, UPath):
out = self.store_cls(url=url, mode=mode)
Expand Down