diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index c453733f0..42eb18ce0 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -1,16 +1,24 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from asyncio import gather from collections.abc import AsyncGenerator, Iterable -from types import TracebackType -from typing import Any, NamedTuple, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, runtime_checkable + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Iterable + from types import TracebackType + from typing import Any, TypeAlias -from typing_extensions import Self + from typing_extensions import Self -from zarr.core.buffer import Buffer, BufferPrototype -from zarr.core.common import AccessModeLiteral, BytesLike + from zarr.core.buffer import Buffer, BufferPrototype + from zarr.core.common import AccessModeLiteral, BytesLike __all__ = ["Store", "AccessMode", "ByteGetter", "ByteSetter", "set_or_delete"] +ByteRangeRequest: TypeAlias = tuple[int | None, int | None] + class AccessMode(NamedTuple): str: AccessModeLiteral @@ -100,14 +108,14 @@ async def get( self, key: str, prototype: BufferPrototype, - byte_range: tuple[int | None, int | None] | None = None, + byte_range: ByteRangeRequest | None = None, ) -> Buffer | None: """Retrieve the value associated with a given key. Parameters ---------- key : str - byte_range : tuple[int, Optional[int]], optional + byte_range : tuple[int | None, int | None], optional Returns ------- @@ -119,13 +127,13 @@ async def get( async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: list[tuple[str, tuple[int | None, int | None]]], + key_ranges: Iterable[tuple[str, ByteRangeRequest]], ) -> list[Buffer | None]: """Retrieve possibly partial values from given key_ranges. Parameters ---------- - key_ranges : list[tuple[str, tuple[int, int]]] + key_ranges : Iterable[tuple[str, tuple[int | None, int | None]]] Ordered set of key, range pairs, a key may occur multiple times with different ranges Returns @@ -195,7 +203,9 @@ def supports_partial_writes(self) -> bool: ... @abstractmethod - async def set_partial_values(self, key_start_values: list[tuple[str, int, BytesLike]]) -> None: + async def set_partial_values( + self, key_start_values: Iterable[tuple[str, int, BytesLike]] + ) -> None: """Store values at a given key, starting at byte range_start. Parameters @@ -259,21 +269,32 @@ def close(self) -> None: """Close the store.""" self._is_open = False + async def _get_many( + self, requests: Iterable[tuple[str, BufferPrototype, ByteRangeRequest | None]] + ) -> AsyncGenerator[tuple[str, Buffer | None], None]: + """ + Retrieve a collection of objects from storage. In general this method does not guarantee + that objects will be retrieved in the order in which they were requested, so this method + yields tuple[str, Buffer | None] instead of just Buffer | None + """ + for req in requests: + yield (req[0], await self.get(*req)) + @runtime_checkable class ByteGetter(Protocol): async def get( - self, prototype: BufferPrototype, byte_range: tuple[int, int | None] | None = None + self, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None ) -> Buffer | None: ... @runtime_checkable class ByteSetter(Protocol): async def get( - self, prototype: BufferPrototype, byte_range: tuple[int, int | None] | None = None + self, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None ) -> Buffer | None: ... - async def set(self, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: ... + async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -> None: ... async def delete(self) -> None: ... diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 6282750f2..2f8946e46 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -17,7 +17,7 @@ Codec, CodecPipeline, ) -from zarr.abc.store import ByteGetter, ByteSetter +from zarr.abc.store import ByteGetter, ByteRangeRequest, ByteSetter from zarr.codecs.bytes import BytesCodec from zarr.codecs.crc32c_ import Crc32cCodec from zarr.core.array_spec import ArraySpec @@ -78,7 +78,7 @@ class _ShardingByteGetter(ByteGetter): chunk_coords: ChunkCoords async def get( - self, prototype: BufferPrototype, byte_range: tuple[int, int | None] | None = None + self, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None ) -> Buffer | None: assert byte_range is None, "byte_range is not supported within shards" assert ( @@ -91,7 +91,7 @@ async def get( class _ShardingByteSetter(_ShardingByteGetter, ByteSetter): shard_dict: ShardMutableMapping - async def set(self, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: + async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -> None: assert byte_range is None, "byte_range is not supported within shards" self.shard_dict[self.chunk_coords] = value diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index f7747e9b2..fac0facd7 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -49,6 +49,8 @@ OrthogonalSelection, Selection, VIndex, + _iter_grid, + ceildiv, check_fields, check_no_multi_fields, is_pure_fancy_indexing, @@ -58,7 +60,7 @@ ) from zarr.core.metadata.v2 import ArrayV2Metadata from zarr.core.metadata.v3 import ArrayV3Metadata -from zarr.core.sync import sync +from zarr.core.sync import collect_aiterator, sync from zarr.registry import get_pipeline_class from zarr.store import StoreLike, StorePath, make_store_path from zarr.store.common import ( @@ -66,7 +68,7 @@ ) if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Iterable, Iterator, Sequence from zarr.abc.codec import Codec, CodecPipeline from zarr.core.metadata.common import ArrayMetadata @@ -390,10 +392,12 @@ def shape(self) -> ChunkCoords: def chunks(self) -> ChunkCoords: if isinstance(self.metadata.chunk_grid, RegularChunkGrid): return self.metadata.chunk_grid.chunk_shape - else: - raise TypeError( - f"chunk attribute is only available for RegularChunkGrid, this array has a {self.metadata.chunk_grid}" - ) + + msg = ( + f"The `chunks` attribute is only defined for arrays using `RegularChunkGrid`." + f"This array has a {self.metadata.chunk_grid} instead." + ) + raise NotImplementedError(msg) @property def size(self) -> int: @@ -434,6 +438,111 @@ def basename(self) -> str | None: return self.name.split("/")[-1] return None + @property + def cdata_shape(self) -> ChunkCoords: + """ + The shape of the chunk grid for this array. + """ + return tuple(ceildiv(s, c) for s, c in zip(self.shape, self.chunks, strict=False)) + + @property + def nchunks(self) -> int: + """ + The number of chunks in the stored representation of this array. + """ + return product(self.cdata_shape) + + @property + def nchunks_initialized(self) -> int: + """ + The number of chunks that have been persisted in storage. + """ + return nchunks_initialized(self) + + def _iter_chunk_coords( + self, *, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None + ) -> Iterator[ChunkCoords]: + """ + Create an iterator over the coordinates of chunks in chunk grid space. If the `origin` + keyword is used, iteration will start at the chunk index specified by `origin`. + The default behavior is to start at the origin of the grid coordinate space. + If the `selection_shape` keyword is used, iteration will be bounded over a contiguous region + ranging from `[origin, origin + selection_shape]`, where the upper bound is exclusive as + per python indexing conventions. + + Parameters + ---------- + origin: Sequence[int] | None, default=None + The origin of the selection relative to the array's chunk grid. + selection_shape: Sequence[int] | None, default=None + The shape of the selection in chunk grid coordinates. + + Yields + ------ + chunk_coords: ChunkCoords + The coordinates of each chunk in the selection. + """ + return _iter_grid(self.cdata_shape, origin=origin, selection_shape=selection_shape) + + def _iter_chunk_keys( + self, *, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None + ) -> Iterator[str]: + """ + Iterate over the storage keys of each chunk, relative to an optional origin, and optionally + limited to a contiguous region in chunk grid coordinates. + + Parameters + ---------- + origin: Sequence[int] | None, default=None + The origin of the selection relative to the array's chunk grid. + selection_shape: Sequence[int] | None, default=None + The shape of the selection in chunk grid coordinates. + + Yields + ------ + key: str + The storage key of each chunk in the selection. + """ + # Iterate over the coordinates of chunks in chunk grid space. + for k in self._iter_chunk_coords(origin=origin, selection_shape=selection_shape): + # Encode the chunk key from the chunk coordinates. + yield self.metadata.encode_chunk_key(k) + + def _iter_chunk_regions( + self, *, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None + ) -> Iterator[tuple[slice, ...]]: + """ + Iterate over the regions spanned by each chunk. + + Parameters + ---------- + origin: Sequence[int] | None, default=None + The origin of the selection relative to the array's chunk grid. + selection_shape: Sequence[int] | None, default=None + The shape of the selection in chunk grid coordinates. + + Yields + ------ + region: tuple[slice, ...] + A tuple of slice objects representing the region spanned by each chunk in the selection. + """ + for cgrid_position in self._iter_chunk_coords( + origin=origin, selection_shape=selection_shape + ): + out: tuple[slice, ...] = () + for c_pos, c_shape in zip(cgrid_position, self.chunks, strict=False): + start = c_pos * c_shape + stop = start + c_shape + out += (slice(start, stop, 1),) + yield out + + @property + def nbytes(self) -> int: + """ + The number of bytes that can be stored in this array. + """ + return self.nchunks * self.dtype.itemsize + async def _get_selection( self, indexer: Indexer, @@ -742,6 +851,106 @@ def read_only(self) -> bool: def fill_value(self) -> Any: return self.metadata.fill_value + @property + def cdata_shape(self) -> ChunkCoords: + """ + The shape of the chunk grid for this array. + """ + return tuple(ceildiv(s, c) for s, c in zip(self.shape, self.chunks, strict=False)) + + @property + def nchunks(self) -> int: + """ + The number of chunks in the stored representation of this array. + """ + return self._async_array.nchunks + + def _iter_chunk_coords( + self, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None + ) -> Iterator[ChunkCoords]: + """ + Create an iterator over the coordinates of chunks in chunk grid space. If the `origin` + keyword is used, iteration will start at the chunk index specified by `origin`. + The default behavior is to start at the origin of the grid coordinate space. + If the `selection_shape` keyword is used, iteration will be bounded over a contiguous region + ranging from `[origin, origin + selection_shape]`, where the upper bound is exclusive as + per python indexing conventions. + + Parameters + ---------- + origin: Sequence[int] | None, default=None + The origin of the selection relative to the array's chunk grid. + selection_shape: Sequence[int] | None, default=None + The shape of the selection in chunk grid coordinates. + + Yields + ------ + chunk_coords: ChunkCoords + The coordinates of each chunk in the selection. + """ + yield from self._async_array._iter_chunk_coords( + origin=origin, selection_shape=selection_shape + ) + + @property + def nbytes(self) -> int: + """ + The number of bytes that can be stored in this array. + """ + return self._async_array.nbytes + + @property + def nchunks_initialized(self) -> int: + """ + The number of chunks that have been initialized in the stored representation of this array. + """ + return self._async_array.nchunks_initialized + + def _iter_chunk_keys( + self, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None + ) -> Iterator[str]: + """ + Iterate over the storage keys of each chunk, relative to an optional origin, and optionally + limited to a contiguous region in chunk grid coordinates. + + Parameters + ---------- + origin: Sequence[int] | None, default=None + The origin of the selection relative to the array's chunk grid. + selection_shape: Sequence[int] | None, default=None + The shape of the selection in chunk grid coordinates. + + Yields + ------ + key: str + The storage key of each chunk in the selection. + """ + yield from self._async_array._iter_chunk_keys( + origin=origin, selection_shape=selection_shape + ) + + def _iter_chunk_regions( + self, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None + ) -> Iterator[tuple[slice, ...]]: + """ + Iterate over the regions spanned by each chunk. + + Parameters + ---------- + origin: Sequence[int] | None, default=None + The origin of the selection relative to the array's chunk grid. + selection_shape: Sequence[int] | None, default=None + The shape of the selection in chunk grid coordinates. + + Yields + ------ + region: tuple[slice, ...] + A tuple of slice objects representing the region spanned by each chunk in the selection. + """ + yield from self._async_array._iter_chunk_regions( + origin=origin, selection_shape=selection_shape + ) + def __array__( self, dtype: npt.DTypeLike | None = None, copy: bool | None = None ) -> NDArrayLike: @@ -2073,3 +2282,57 @@ def info(self) -> None: return sync( self._async_array.info(), ) + + +def nchunks_initialized(array: AsyncArray | Array) -> int: + """ + Calculate the number of chunks that have been initialized, i.e. the number of chunks that have + been persisted to the storage backend. + + Parameters + ---------- + array : Array + The array to inspect. + + Returns + ------- + nchunks_initialized : int + The number of chunks that have been initialized. + + See Also + -------- + chunks_initialized + """ + return len(chunks_initialized(array)) + + +def chunks_initialized(array: Array | AsyncArray) -> tuple[str, ...]: + """ + Return the keys of the chunks that have been persisted to the storage backend. + + Parameters + ---------- + array : Array + The array to inspect. + + Returns + ------- + chunks_initialized : tuple[str, ...] + The keys of the chunks that have been initialized. + + See Also + -------- + nchunks_initialized + + """ + # TODO: make this compose with the underlying async iterator + store_contents = list( + collect_aiterator(array.store_path.store.list_prefix(prefix=array.store_path.path)) + ) + out: list[str] = [] + + for chunk_key in array._iter_chunk_keys(): + if chunk_key in store_contents: + out.append(chunk_key) + + return tuple(out) diff --git a/src/zarr/core/common.py b/src/zarr/core/common.py index 6847bd419..80c743cc9 100644 --- a/src/zarr/core/common.py +++ b/src/zarr/core/common.py @@ -45,7 +45,7 @@ def product(tup: ChunkCoords) -> int: async def concurrent_map( - items: list[T], func: Callable[..., Awaitable[V]], limit: int | None = None + items: Iterable[T], func: Callable[..., Awaitable[V]], limit: int | None = None ) -> list[V]: if limit is None: return await asyncio.gather(*[func(*item) for item in items]) diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py index 3968a057f..1c153fc16 100644 --- a/src/zarr/core/indexing.py +++ b/src/zarr/core/indexing.py @@ -12,8 +12,10 @@ from typing import ( TYPE_CHECKING, Any, + Literal, NamedTuple, Protocol, + TypeAlias, TypeGuard, TypeVar, cast, @@ -95,6 +97,84 @@ def ceildiv(a: float, b: float) -> int: return math.ceil(a / b) +_ArrayIndexingOrder: TypeAlias = Literal["lexicographic"] + + +def _iter_grid( + grid_shape: Sequence[int], + *, + origin: Sequence[int] | None = None, + selection_shape: Sequence[int] | None = None, + order: _ArrayIndexingOrder = "lexicographic", +) -> Iterator[ChunkCoords]: + """ + Iterate over the elements of grid of integers, with the option to restrict the domain of + iteration to a contiguous subregion of that grid. + + Parameters + ---------- + grid_shape: Sequence[int] + The size of the domain to iterate over. + origin: Sequence[int] | None, default=None + The first coordinate of the domain to return. + selection_shape: Sequence[int] | None, default=None + The shape of the selection. + order: Literal["lexicographic"], default="lexicographic" + The linear indexing order to use. + + Returns + ------- + + itertools.product object + An iterator over tuples of integers + + Examples + -------- + >>> tuple(iter_grid((1,))) + ((0,),) + + >>> tuple(iter_grid((2,3))) + ((0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)) + + >>> tuple(iter_grid((2,3)), origin=(1,1)) + ((1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)) + + >>> tuple(iter_grid((2,3)), origin=(1,1), selection_shape=(2,2)) + ((1, 1), (1, 2), (1, 3), (2, 1)) + """ + if origin is None: + origin_parsed = (0,) * len(grid_shape) + else: + if len(origin) != len(grid_shape): + msg = ( + "Shape and origin parameters must have the same length." + f"Got {len(grid_shape)} elements in shape, but {len(origin)} elements in origin." + ) + raise ValueError(msg) + origin_parsed = tuple(origin) + if selection_shape is None: + selection_shape_parsed = tuple( + g - o for o, g in zip(origin_parsed, grid_shape, strict=True) + ) + else: + selection_shape_parsed = tuple(selection_shape) + if order == "lexicographic": + dimensions: tuple[range, ...] = () + for idx, (o, gs, ss) in enumerate( + zip(origin_parsed, grid_shape, selection_shape_parsed, strict=True) + ): + if o + ss > gs: + raise IndexError( + f"Invalid selection shape ({selection_shape}) for origin ({origin}) and grid shape ({grid_shape}) at axis {idx}." + ) + dimensions += (range(o, o + ss),) + yield from itertools.product(*(dimensions)) + + else: + msg = f"Indexing order {order} is not supported at this time." # type: ignore[unreachable] + raise NotImplementedError(msg) + + def is_integer(x: Any) -> TypeGuard[int]: """True if x is an integer (both pure Python or NumPy).""" return isinstance(x, numbers.Integral) and not is_bool(x) diff --git a/src/zarr/store/_utils.py b/src/zarr/store/_utils.py index 04a06351c..cbc9c42bb 100644 --- a/src/zarr/store/_utils.py +++ b/src/zarr/store/_utils.py @@ -1,4 +1,9 @@ -from zarr.core.buffer import Buffer +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from zarr.core.buffer import Buffer def _normalize_interval_index( diff --git a/src/zarr/store/common.py b/src/zarr/store/common.py index 0c126c63d..f39edb19a 100644 --- a/src/zarr/store/common.py +++ b/src/zarr/store/common.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Literal -from zarr.abc.store import AccessMode, Store +from zarr.abc.store import AccessMode, ByteRangeRequest, Store from zarr.core.buffer import Buffer, default_buffer_prototype from zarr.core.common import ZARR_JSON, ZARRAY_JSON, ZGROUP_JSON, ZarrFormat from zarr.errors import ContainsArrayAndGroupError, ContainsArrayError, ContainsGroupError @@ -37,13 +37,13 @@ def __init__(self, store: Store, path: str | None = None) -> None: async def get( self, prototype: BufferPrototype | None = None, - byte_range: tuple[int, int | None] | None = None, + byte_range: ByteRangeRequest | None = None, ) -> Buffer | None: if prototype is None: prototype = default_buffer_prototype() return await self.store.get(self.path, prototype=prototype, byte_range=byte_range) - async def set(self, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: + async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -> None: if byte_range is not None: raise NotImplementedError("Store.set does not have partial writes yet") await self.store.set(self.path, value) diff --git a/src/zarr/store/local.py b/src/zarr/store/local.py index 39a94969e..f1bce769d 100644 --- a/src/zarr/store/local.py +++ b/src/zarr/store/local.py @@ -6,12 +6,12 @@ from pathlib import Path from typing import TYPE_CHECKING -from zarr.abc.store import Store +from zarr.abc.store import ByteRangeRequest, Store from zarr.core.buffer import Buffer from zarr.core.common import concurrent_map, to_thread if TYPE_CHECKING: - from collections.abc import AsyncGenerator + from collections.abc import AsyncGenerator, Iterable from zarr.core.buffer import BufferPrototype from zarr.core.common import AccessModeLiteral @@ -131,7 +131,7 @@ async def get( async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: list[tuple[str, tuple[int | None, int | None]]], + key_ranges: Iterable[tuple[str, ByteRangeRequest]], ) -> list[Buffer | None]: """ Read byte ranges from multiple keys. @@ -161,7 +161,9 @@ async def set(self, key: str, value: Buffer) -> None: path = self.root / key await to_thread(_put, path, value) - async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]]) -> None: + async def set_partial_values( + self, key_start_values: Iterable[tuple[str, int, bytes | bytearray | memoryview]] + ) -> None: self._check_writable() args = [] for key, start, value in key_start_values: diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index 16599b260..ee4107b0a 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -2,13 +2,13 @@ from typing import TYPE_CHECKING -from zarr.abc.store import Store +from zarr.abc.store import ByteRangeRequest, Store from zarr.core.buffer import Buffer, gpu from zarr.core.common import concurrent_map from zarr.store._utils import _normalize_interval_index if TYPE_CHECKING: - from collections.abc import AsyncGenerator, MutableMapping + from collections.abc import AsyncGenerator, Iterable, MutableMapping from zarr.core.buffer import BufferPrototype from zarr.core.common import AccessModeLiteral @@ -73,10 +73,10 @@ async def get( async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: list[tuple[str, tuple[int | None, int | None]]], + key_ranges: Iterable[tuple[str, ByteRangeRequest]], ) -> list[Buffer | None]: # All the key-ranges arguments goes with the same prototype - async def _get(key: str, byte_range: tuple[int, int | None]) -> Buffer | None: + async def _get(key: str, byte_range: ByteRangeRequest) -> Buffer | None: return await self.get(key, prototype=prototype, byte_range=byte_range) return await concurrent_map(key_ranges, _get, limit=None) @@ -106,7 +106,7 @@ async def delete(self, key: str) -> None: except KeyError: pass # Q(JH): why not raise? - async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]]) -> None: + async def set_partial_values(self, key_start_values: Iterable[tuple[str, int, bytes]]) -> None: raise NotImplementedError async def list(self) -> AsyncGenerator[str, None]: diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index 7aea8a378..284cd8d77 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -4,11 +4,11 @@ import fsspec -from zarr.abc.store import Store +from zarr.abc.store import ByteRangeRequest, Store from zarr.store.common import _dereference_path if TYPE_CHECKING: - from collections.abc import AsyncGenerator + from collections.abc import AsyncGenerator, Iterable from fsspec.asyn import AsyncFileSystem @@ -110,7 +110,7 @@ async def get( self, key: str, prototype: BufferPrototype, - byte_range: tuple[int | None, int | None] | None = None, + byte_range: ByteRangeRequest | None = None, ) -> Buffer | None: if not self._is_open: await self._open() @@ -177,7 +177,7 @@ async def exists(self, key: str) -> bool: async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: list[tuple[str, tuple[int | None, int | None]]], + key_ranges: Iterable[tuple[str, ByteRangeRequest]], ) -> list[Buffer | None]: if key_ranges: paths, starts, stops = zip( @@ -203,7 +203,9 @@ async def get_partial_values( return [None if isinstance(r, Exception) else prototype.buffer.from_bytes(r) for r in res] - async def set_partial_values(self, key_start_values: list[tuple[str, int, BytesLike]]) -> None: + async def set_partial_values( + self, key_start_values: Iterable[tuple[str, int, BytesLike]] + ) -> None: raise NotImplementedError async def list(self) -> AsyncGenerator[str, None]: diff --git a/src/zarr/store/zip.py b/src/zarr/store/zip.py index cd9df4d37..949660913 100644 --- a/src/zarr/store/zip.py +++ b/src/zarr/store/zip.py @@ -7,11 +7,11 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Literal -from zarr.abc.store import Store +from zarr.abc.store import ByteRangeRequest, Store from zarr.core.buffer import Buffer, BufferPrototype if TYPE_CHECKING: - from collections.abc import AsyncGenerator + from collections.abc import AsyncGenerator, Iterable ZipStoreAccessModeLiteral = Literal["r", "w", "a"] @@ -128,7 +128,7 @@ def _get( self, key: str, prototype: BufferPrototype, - byte_range: tuple[int | None, int | None] | None = None, + byte_range: ByteRangeRequest | None = None, ) -> Buffer | None: try: with self._zf.open(key) as f: # will raise KeyError @@ -151,7 +151,7 @@ async def get( self, key: str, prototype: BufferPrototype, - byte_range: tuple[int | None, int | None] | None = None, + byte_range: ByteRangeRequest | None = None, ) -> Buffer | None: assert isinstance(key, str) @@ -161,7 +161,7 @@ async def get( async def get_partial_values( self, prototype: BufferPrototype, - key_ranges: list[tuple[str, tuple[int | None, int | None]]], + key_ranges: Iterable[tuple[str, ByteRangeRequest]], ) -> list[Buffer | None]: out = [] with self._lock: @@ -188,7 +188,7 @@ async def set(self, key: str, value: Buffer) -> None: with self._lock: self._set(key, value) - async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]]) -> None: + async def set_partial_values(self, key_start_values: Iterable[tuple[str, int, bytes]]) -> None: raise NotImplementedError async def delete(self, key: str) -> None: diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 70d2e16ef..7b78b8ed0 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -5,7 +5,7 @@ from zarr.abc.store import AccessMode, Store from zarr.core.buffer import Buffer, default_buffer_prototype -from zarr.core.sync import _collect_aiterator +from zarr.core.sync import _collect_aiterator, collect_aiterator from zarr.store._utils import _normalize_interval_index from zarr.testing.utils import assert_bytes_equal @@ -111,6 +111,28 @@ async def test_get( expected = data_buf[start : start + length] assert_bytes_equal(observed, expected) + async def test_get_many(self, store: S) -> None: + """ + Ensure that multiple keys can be retrieved at once with the _get_many method. + """ + keys = tuple(map(str, range(10))) + values = tuple(f"{k}".encode() for k in keys) + for k, v in zip(keys, values, strict=False): + self.set(store, k, self.buffer_cls.from_bytes(v)) + observed_buffers = collect_aiterator( + store._get_many( + zip( + keys, + (default_buffer_prototype(),) * len(keys), + (None,) * len(keys), + strict=False, + ) + ) + ) + observed_kvs = sorted(((k, b.to_bytes()) for k, b in observed_buffers)) # type: ignore[union-attr] + expected_kvs = sorted(((k, b) for k, b in zip(keys, values, strict=False))) + assert observed_kvs == expected_kvs + @pytest.mark.parametrize("key", ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"]) @pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""]) async def test_set(self, store: S, key: str, data: bytes) -> None: diff --git a/tests/v3/conftest.py b/tests/v3/conftest.py index fc1f950ad..15a0b55b0 100644 --- a/tests/v3/conftest.py +++ b/tests/v3/conftest.py @@ -129,6 +129,16 @@ def array_fixture(request: pytest.FixtureRequest) -> npt.NDArray[Any]: ) +@pytest.fixture(params=(2, 3)) +def zarr_format(request: pytest.FixtureRequest) -> ZarrFormat: + if request.param == 2: + return 2 + elif request.param == 3: + return 3 + msg = f"Invalid zarr format requested. Got {request.param}, expected on of (2,3)." + raise ValueError(msg) + + settings.register_profile( "ci", max_examples=1000, diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index 02358cb39..95bbde174 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -1,12 +1,16 @@ import pickle +from itertools import accumulate from typing import Literal import numpy as np import pytest from zarr import Array, AsyncArray, Group +from zarr.core.array import chunks_initialized from zarr.core.buffer.cpu import NDBuffer from zarr.core.common import ZarrFormat +from zarr.core.indexing import ceildiv +from zarr.core.sync import sync from zarr.errors import ContainsArrayError, ContainsGroupError from zarr.store import LocalStore, MemoryStore from zarr.store.common import StorePath @@ -232,3 +236,72 @@ def test_serializable_sync_array(store: LocalStore, zarr_format: ZarrFormat) -> assert actual == expected np.testing.assert_array_equal(actual[:], expected[:]) + + +@pytest.mark.parametrize("test_cls", [Array, AsyncArray]) +@pytest.mark.parametrize("nchunks", [2, 5, 10]) +def test_nchunks(test_cls: type[Array] | type[AsyncArray], nchunks: int) -> None: + """ + Test that nchunks returns the number of chunks defined for the array. + """ + store = MemoryStore({}, mode="w") + shape = 100 + arr = Array.create(store, shape=(shape,), chunks=(ceildiv(shape, nchunks),), dtype="i4") + expected = nchunks + if test_cls == Array: + observed = arr.nchunks + else: + observed = arr._async_array.nchunks + assert observed == expected + + +@pytest.mark.parametrize("test_cls", [Array, AsyncArray]) +def test_nchunks_initialized(test_cls: type[Array] | type[AsyncArray]) -> None: + """ + Test that nchunks_initialized accurately returns the number of stored chunks. + """ + store = MemoryStore({}, mode="w") + arr = Array.create(store, shape=(100,), chunks=(10,), dtype="i4") + + # write chunks one at a time + for idx, region in enumerate(arr._iter_chunk_regions()): + arr[region] = 1 + expected = idx + 1 + if test_cls == Array: + observed = arr.nchunks_initialized + else: + observed = arr._async_array.nchunks_initialized + assert observed == expected + + # delete chunks + for idx, key in enumerate(arr._iter_chunk_keys()): + sync(arr.store_path.store.delete(key)) + if test_cls == Array: + observed = arr.nchunks_initialized + else: + observed = arr._async_array.nchunks_initialized + expected = arr.nchunks - idx - 1 + assert observed == expected + + +@pytest.mark.parametrize("test_cls", [Array, AsyncArray]) +def test_chunks_initialized(test_cls: type[Array] | type[AsyncArray]) -> None: + """ + Test that chunks_initialized accurately returns the keys of stored chunks. + """ + store = MemoryStore({}, mode="w") + arr = Array.create(store, shape=(100,), chunks=(10,), dtype="i4") + + chunks_accumulated = tuple( + accumulate(tuple(tuple(v.split(" ")) for v in arr._iter_chunk_keys())) + ) + for keys, region in zip(chunks_accumulated, arr._iter_chunk_regions(), strict=False): + arr[region] = 1 + + if test_cls == Array: + observed = sorted(chunks_initialized(arr)) + else: + observed = sorted(chunks_initialized(arr._async_array)) + + expected = sorted(keys) + assert observed == expected diff --git a/tests/v3/test_indexing.py b/tests/v3/test_indexing.py index da358afbd..59169c67b 100644 --- a/tests/v3/test_indexing.py +++ b/tests/v3/test_indexing.py @@ -1,5 +1,6 @@ from __future__ import annotations +import itertools from collections import Counter from typing import TYPE_CHECKING, Any from uuid import uuid4 @@ -16,6 +17,7 @@ CoordinateSelection, OrthogonalSelection, Selection, + _iter_grid, make_slice_selection, normalize_integer_selection, oindex, @@ -1861,6 +1863,56 @@ def test_orthogonal_bool_indexing_like_numpy_ix( assert_array_equal(expected, actual, err_msg=f"{selection=}") +@pytest.mark.parametrize("ndim", [1, 2, 3]) +@pytest.mark.parametrize("origin_0d", [None, (0,), (1,)]) +@pytest.mark.parametrize("selection_shape_0d", [None, (2,), (3,)]) +def test_iter_grid( + ndim: int, origin_0d: tuple[int] | None, selection_shape_0d: tuple[int] | None +) -> None: + """ + Test that iter_grid works as expected for 1, 2, and 3 dimensions. + """ + grid_shape = (5,) * ndim + + if origin_0d is not None: + origin_kwarg = origin_0d * ndim + origin = origin_kwarg + else: + origin_kwarg = None + origin = (0,) * ndim + + if selection_shape_0d is not None: + selection_shape_kwarg = selection_shape_0d * ndim + selection_shape = selection_shape_kwarg + else: + selection_shape_kwarg = None + selection_shape = tuple(gs - o for gs, o in zip(grid_shape, origin, strict=False)) + + observed = tuple( + _iter_grid(grid_shape, origin=origin_kwarg, selection_shape=selection_shape_kwarg) + ) + + # generate a numpy array of indices, and index it + coord_array = np.array(list(itertools.product(*[range(s) for s in grid_shape]))).reshape( + (*grid_shape, ndim) + ) + coord_array_indexed = coord_array[ + tuple(slice(o, o + s, 1) for o, s in zip(origin, selection_shape, strict=False)) + + (range(ndim),) + ] + + expected = tuple(map(tuple, coord_array_indexed.reshape(-1, ndim).tolist())) + assert observed == expected + + +def test_iter_grid_invalid() -> None: + """ + Ensure that a selection_shape that exceeds the grid_shape + origin produces an indexing error. + """ + with pytest.raises(IndexError): + list(_iter_grid((5,), origin=(0,), selection_shape=(10,))) + + def test_indexing_with_zarr_array(store: StorePath) -> None: # regression test for https://github.com/zarr-developers/zarr-python/issues/2133 a = np.arange(10) diff --git a/tests/v3/test_store/test_remote.py b/tests/v3/test_store/test_remote.py index ca74fc184..18ba1e6d1 100644 --- a/tests/v3/test_store/test_remote.py +++ b/tests/v3/test_store/test_remote.py @@ -84,10 +84,6 @@ def s3(s3_base: None) -> Generator[s3fs.S3FileSystem, None, None]: # ### end from s3fs ### # -async def alist(it): - return [a async for a in it] - - async def test_basic() -> None: store = RemoteStore.from_url( f"s3://{test_bucket_name}",