Skip to content

Commit

Permalink
Ensure parents are created when creating a node (#2262)
Browse files Browse the repository at this point in the history
* Ensure parents are created when creating a node

This updates our Array and Group creation methods to
ensure that parents implicitly defined through a nested path are also
created.

To accomplish this semi-safely and efficiently, we require a new
setdefulat method on the Store class.

* use the API

* fixed logging store

* Update src/zarr/testing/store.py

* fixes

* fixup

* fixes

* pre-commit

---------

Co-authored-by: Joe Hamman <joe@earthmover.io>
  • Loading branch information
TomAugspurger and jhamman authored Sep 27, 2024
1 parent 8e2c660 commit 2edc548
Show file tree
Hide file tree
Showing 14 changed files with 221 additions and 12 deletions.
18 changes: 18 additions & 0 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,22 @@ async def set(self, key: str, value: Buffer) -> None:
"""
...

async def set_if_not_exists(self, key: str, value: Buffer) -> None:
"""
Store a key to ``value`` if the key is not already present.
Parameters
-----------
key : str
value : Buffer
"""
# Note for implementers: the default implementation provided here
# is not safe for concurrent writers. There's a race condition between
# the `exists` check and the `set` where another writer could set some
# value at `key` or delete `key`.
if not await self.exists(key):
await self.set(key, value)

async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
"""
Insert multiple (key, value) pairs into storage.
Expand Down Expand Up @@ -297,6 +313,8 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -

async def delete(self) -> None: ...

async def set_if_not_exists(self, default: Buffer) -> None: ...


async def set_or_delete(byte_setter: ByteSetter, value: Buffer | None) -> None:
if value is None:
Expand Down
3 changes: 3 additions & 0 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -
async def delete(self) -> None:
del self.shard_dict[self.chunk_coords]

async def set_if_not_exists(self, default: Buffer) -> None:
self.shard_dict.setdefault(self.chunk_coords, default)


class _ShardIndex(NamedTuple):
# dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2)
Expand Down
47 changes: 43 additions & 4 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
from zarr.codecs import BytesCodec
from zarr.codecs._v2 import V2Compressor, V2Filters
from zarr.core.attributes import Attributes
from zarr.core.buffer import BufferPrototype, NDArrayLike, NDBuffer, default_buffer_prototype
from zarr.core.buffer import (
BufferPrototype,
NDArrayLike,
NDBuffer,
default_buffer_prototype,
)
from zarr.core.chunk_grids import RegularChunkGrid, _guess_chunks
from zarr.core.chunk_key_encodings import (
ChunkKeyEncoding,
Expand Down Expand Up @@ -71,6 +76,7 @@
from collections.abc import Iterable, Iterator, Sequence

from zarr.abc.codec import Codec, CodecPipeline
from zarr.core.group import AsyncGroup
from zarr.core.metadata.common import ArrayMetadata

# Array and AsyncArray are defined in the base ``zarr`` namespace
Expand Down Expand Up @@ -337,7 +343,7 @@ async def _create_v3(
)

array = cls(metadata=metadata, store_path=store_path)
await array._save_metadata(metadata)
await array._save_metadata(metadata, ensure_parents=True)
return array

@classmethod
Expand Down Expand Up @@ -376,7 +382,7 @@ async def _create_v2(
attributes=attributes,
)
array = cls(metadata=metadata, store_path=store_path)
await array._save_metadata(metadata)
await array._save_metadata(metadata, ensure_parents=True)
return array

@classmethod
Expand Down Expand Up @@ -621,9 +627,24 @@ async def getitem(
)
return await self._get_selection(indexer, prototype=prototype)

async def _save_metadata(self, metadata: ArrayMetadata) -> None:
async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None:
to_save = metadata.to_buffer_dict(default_buffer_prototype())
awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()]

if ensure_parents:
# To enable zarr.create(store, path="a/b/c"), we need to create all the intermediate groups.
parents = _build_parents(self)

for parent in parents:
awaitables.extend(
[
(parent.store_path / key).set_if_not_exists(value)
for key, value in parent.metadata.to_buffer_dict(
default_buffer_prototype()
).items()
]
)

await gather(*awaitables)

async def _set_selection(
Expand Down Expand Up @@ -2354,3 +2375,21 @@ def chunks_initialized(array: Array | AsyncArray) -> tuple[str, ...]:
out.append(chunk_key)

return tuple(out)


def _build_parents(node: AsyncArray | AsyncGroup) -> list[AsyncGroup]:
from zarr.core.group import AsyncGroup, GroupMetadata

required_parts = node.store_path.path.split("/")[:-1]
parents = []

for i, part in enumerate(required_parts):
path = "/".join(required_parts[:i] + [part])
parents.append(
AsyncGroup(
metadata=GroupMetadata(zarr_format=node.metadata.zarr_format),
store_path=StorePath(store=node.store_path.store, path=path),
)
)

return parents
19 changes: 16 additions & 3 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import zarr.api.asynchronous as async_api
from zarr.abc.metadata import Metadata
from zarr.abc.store import Store, set_or_delete
from zarr.core.array import Array, AsyncArray
from zarr.core.array import Array, AsyncArray, _build_parents
from zarr.core.attributes import Attributes
from zarr.core.buffer import default_buffer_prototype
from zarr.core.common import (
Expand Down Expand Up @@ -144,7 +144,7 @@ async def from_store(
metadata=GroupMetadata(attributes=attributes, zarr_format=zarr_format),
store_path=store_path,
)
await group._save_metadata()
await group._save_metadata(ensure_parents=True)
return group

@classmethod
Expand Down Expand Up @@ -279,9 +279,22 @@ async def delitem(self, key: str) -> None:
else:
raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}")

async def _save_metadata(self) -> None:
async def _save_metadata(self, ensure_parents: bool = False) -> None:
to_save = self.metadata.to_buffer_dict(default_buffer_prototype())
awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()]

if ensure_parents:
parents = _build_parents(self)
for parent in parents:
awaitables.extend(
[
(parent.store_path / key).set_if_not_exists(value)
for key, value in parent.metadata.to_buffer_dict(
default_buffer_prototype()
).items()
]
)

await asyncio.gather(*awaitables)

@property
Expand Down
3 changes: 3 additions & 0 deletions src/zarr/store/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -
async def delete(self) -> None:
await self.store.delete(self.path)

async def set_if_not_exists(self, default: Buffer) -> None:
await self.store.set_if_not_exists(self.path, default)

async def exists(self) -> bool:
return await self.store.exists(self.path)

Expand Down
20 changes: 18 additions & 2 deletions src/zarr/store/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def _put(
path: Path,
value: Buffer,
start: int | None = None,
exclusive: bool = False,
) -> int | None:
path.parent.mkdir(parents=True, exist_ok=True)
if start is not None:
Expand All @@ -68,7 +69,13 @@ def _put(
f.write(value.as_numpy_array().tobytes())
return None
else:
return path.write_bytes(value.as_numpy_array().tobytes())
view = memoryview(value.as_numpy_array().tobytes())
if exclusive:
mode = "xb"
else:
mode = "wb"
with path.open(mode=mode) as f:
return f.write(view)


class LocalStore(Store):
Expand Down Expand Up @@ -152,14 +159,23 @@ async def get_partial_values(
return await concurrent_map(args, to_thread, limit=None) # TODO: fix limit

async def set(self, key: str, value: Buffer) -> None:
return await self._set(key, value)

async def set_if_not_exists(self, key: str, value: Buffer) -> None:
try:
return await self._set(key, value, exclusive=True)
except FileExistsError:
pass

async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None:
if not self._is_open:
await self._open()
self._check_writable()
assert isinstance(key, str)
if not isinstance(value, Buffer):
raise TypeError("LocalStore.set(): `value` must a Buffer instance")
path = self.root / key
await to_thread(_put, path, value)
await to_thread(_put, path, value, start=None, exclusive=exclusive)

async def set_partial_values(
self, key_start_values: Iterable[tuple[str, int, bytes | bytearray | memoryview]]
Expand Down
5 changes: 5 additions & 0 deletions src/zarr/store/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import TYPE_CHECKING

from zarr.abc.store import AccessMode, ByteRangeRequest, Store
from zarr.core.buffer import Buffer

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Generator, Iterable
Expand Down Expand Up @@ -138,6 +139,10 @@ async def set(self, key: str, value: Buffer) -> None:
with self.log():
return await self._store.set(key=key, value=value)

async def set_if_not_exists(self, key: str, default: Buffer) -> None:
with self.log():
return await self._store.set_if_not_exists(key=key, value=default)

async def delete(self, key: str) -> None:
with self.log():
return await self._store.delete(key=key)
Expand Down
8 changes: 6 additions & 2 deletions src/zarr/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,8 @@ async def exists(self, key: str) -> bool:
return key in self._store_dict

async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
if not self._is_open:
await self._open()
self._check_writable()
await self._ensure_open()
assert isinstance(key, str)
if not isinstance(value, Buffer):
raise TypeError(f"Expected Buffer. Got {type(value)}.")
Expand All @@ -99,6 +98,11 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None
else:
self._store_dict[key] = value

async def set_if_not_exists(self, key: str, default: Buffer) -> None:
self._check_writable()
await self._ensure_open()
self._store_dict.setdefault(key, default)

async def delete(self, key: str) -> None:
self._check_writable()
try:
Expand Down
1 change: 1 addition & 0 deletions src/zarr/store/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import fsspec

from zarr.abc.store import ByteRangeRequest, Store
from zarr.core.buffer import Buffer
from zarr.store.common import _dereference_path

if TYPE_CHECKING:
Expand Down
7 changes: 7 additions & 0 deletions src/zarr/store/zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,13 @@ async def set(self, key: str, value: Buffer) -> None:
async def set_partial_values(self, key_start_values: Iterable[tuple[str, int, bytes]]) -> None:
raise NotImplementedError

async def set_if_not_exists(self, key: str, default: Buffer) -> None:
self._check_writable()
with self._lock:
members = self._zf.namelist()
if key not in members:
self._set(key, default)

async def delete(self, key: str) -> None:
raise NotImplementedError

Expand Down
16 changes: 16 additions & 0 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,19 @@ async def test_list_dir(self, store: S) -> None:

keys_observed = await _collect_aiterator(store.list_dir(root + "/"))
assert sorted(keys_expected) == sorted(keys_observed)

async def test_set_if_not_exists(self, store: S) -> None:
key = "k"
data_buf = self.buffer_cls.from_bytes(b"0000")
self.set(store, key, data_buf)

new = self.buffer_cls.from_bytes(b"1111")
await store.set_if_not_exists("k", new) # no error

result = await store.get(key, default_buffer_prototype())
assert result == data_buf

await store.set_if_not_exists("k2", new) # no error

result = await store.get("k2", default_buffer_prototype())
assert result == new
43 changes: 43 additions & 0 deletions tests/v3/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import numpy as np
import pytest

import zarr.api.asynchronous
from zarr import Array, AsyncArray, Group
from zarr.codecs.bytes import BytesCodec
from zarr.core.array import chunks_initialized
from zarr.core.buffer.cpu import NDBuffer
from zarr.core.common import JSON, ZarrFormat
from zarr.core.group import AsyncGroup
from zarr.core.indexing import ceildiv
from zarr.core.sync import sync
from zarr.errors import ContainsArrayError, ContainsGroupError
Expand Down Expand Up @@ -66,6 +68,47 @@ def test_array_creation_existing_node(
)


@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"])
@pytest.mark.parametrize("zarr_format", [2, 3])
async def test_create_creates_parents(
store: LocalStore | MemoryStore, zarr_format: ZarrFormat
) -> None:
# prepare a root node, with some data set
await zarr.api.asynchronous.open_group(
store=store, path="a", zarr_format=zarr_format, attributes={"key": "value"}
)

# create a child node with a couple intermediates
await zarr.api.asynchronous.create(
shape=(2, 2), store=store, path="a/b/c/d", zarr_format=zarr_format
)
parts = ["a", "a/b", "a/b/c"]

if zarr_format == 2:
files = [".zattrs", ".zgroup"]
else:
files = ["zarr.json"]

expected = [f"{part}/{file}" for file in files for part in parts]

if zarr_format == 2:
expected.append("a/b/c/d/.zarray")
expected.append("a/b/c/d/.zattrs")
else:
expected.append("a/b/c/d/zarr.json")

expected = sorted(expected)

result = sorted([x async for x in store.list_prefix("")])

assert result == expected

paths = ["a", "a/b", "a/b/c"]
for path in paths:
g = await zarr.api.asynchronous.open_group(store=store, path=path)
assert isinstance(g, AsyncGroup)


@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"])
@pytest.mark.parametrize("zarr_format", [2, 3])
def test_array_name_properties_no_group(
Expand Down
Loading

0 comments on commit 2edc548

Please sign in to comment.