diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index 94b839a18..d5fb9e7b5 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -6,6 +6,7 @@ import numpy as np import pytest +import zarr from zarr import Array, AsyncArray, AsyncGroup, Group from zarr.abc.store import Store from zarr.core.buffer import default_buffer_prototype @@ -13,7 +14,7 @@ from zarr.core.group import GroupMetadata from zarr.core.sync import sync from zarr.errors import ContainsArrayError, ContainsGroupError -from zarr.store import LocalStore, StorePath +from zarr.store import LocalStore, MemoryStore, StorePath from zarr.store.common import make_store_path from .conftest import parse_store @@ -699,3 +700,154 @@ def test_serializable_sync_group(store: LocalStore, zarr_format: ZarrFormat) -> actual = pickle.loads(p) assert actual == expected + + +async def test_group_members_async(store: LocalStore | MemoryStore) -> None: + group = AsyncGroup( + GroupMetadata(), + store_path=StorePath(store=store, path="root"), + ) + a0 = await group.create_array("a0", shape=(1,)) + g0 = await group.create_group("g0") + a1 = await g0.create_array("a1", shape=(1,)) + g1 = await g0.create_group("g1") + a2 = await g1.create_array("a2", shape=(1,)) + g2 = await g1.create_group("g2") + + # immediate children + children = sorted([x async for x in group.members()], key=lambda x: x[0]) + assert children == [ + ("a0", a0), + ("g0", g0), + ] + + nmembers = await group.nmembers() + assert nmembers == 2 + + # partial + children = sorted([x async for x in group.members(max_depth=1)], key=lambda x: x[0]) + expected = [ + ("a0", a0), + ("g0", g0), + ("g0/a1", a1), + ("g0/g1", g1), + ] + assert children == expected + nmembers = await group.nmembers(max_depth=1) + assert nmembers == 4 + + # all children + all_children = sorted([x async for x in group.members(max_depth=None)], key=lambda x: x[0]) + expected = [ + ("a0", a0), + ("g0", g0), + ("g0/a1", a1), + ("g0/g1", g1), + ("g0/g1/a2", a2), + ("g0/g1/g2", g2), + ] + assert all_children == expected + + nmembers = await group.nmembers(max_depth=None) + assert nmembers == 6 + + with pytest.raises(ValueError, match="max_depth"): + [x async for x in group.members(max_depth=-1)] + + +async def test_require_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: + root = await AsyncGroup.create(store=store, zarr_format=zarr_format) + + # create foo group + _ = await root.create_group("foo", attributes={"foo": 100}) + + # test that we can get the group using require_group + foo_group = await root.require_group("foo") + assert foo_group.attrs == {"foo": 100} + + # test that we can get the group using require_group and overwrite=True + foo_group = await root.require_group("foo", overwrite=True) + + _ = await foo_group.create_array( + "bar", shape=(10,), dtype="uint8", chunk_shape=(2,), attributes={"foo": 100} + ) + + # test that overwriting a group w/ children fails + # TODO: figure out why ensure_no_existing_node is not catching the foo.bar array + # + # with pytest.raises(ContainsArrayError): + # await root.require_group("foo", overwrite=True) + + # test that requiring a group where an array is fails + with pytest.raises(TypeError): + await foo_group.require_group("bar") + + +async def test_require_groups(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: + root = await AsyncGroup.create(store=store, zarr_format=zarr_format) + # create foo group + _ = await root.create_group("foo", attributes={"foo": 100}) + # create bar group + _ = await root.create_group("bar", attributes={"bar": 200}) + + foo_group, bar_group = await root.require_groups("foo", "bar") + assert foo_group.attrs == {"foo": 100} + assert bar_group.attrs == {"bar": 200} + + # get a mix of existing and new groups + foo_group, spam_group = await root.require_groups("foo", "spam") + assert foo_group.attrs == {"foo": 100} + assert spam_group.attrs == {} + + # no names + no_group = await root.require_groups() + assert no_group == () + + +async def test_create_dataset(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: + root = await AsyncGroup.create(store=store, zarr_format=zarr_format) + with pytest.warns(DeprecationWarning): + foo = await root.create_dataset("foo", shape=(10,), dtype="uint8") + assert foo.shape == (10,) + + with pytest.raises(ContainsArrayError), pytest.warns(DeprecationWarning): + await root.create_dataset("foo", shape=(100,), dtype="int8") + + _ = await root.create_group("bar") + with pytest.raises(ContainsGroupError), pytest.warns(DeprecationWarning): + await root.create_dataset("bar", shape=(100,), dtype="int8") + + +async def test_require_array(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None: + root = await AsyncGroup.create(store=store, zarr_format=zarr_format) + foo1 = await root.require_array("foo", shape=(10,), dtype="i8", attributes={"foo": 101}) + assert foo1.attrs == {"foo": 101} + foo2 = await root.require_array("foo", shape=(10,), dtype="i8") + assert foo2.attrs == {"foo": 101} + + # exact = False + _ = await root.require_array("foo", shape=10, dtype="f8") + + # errors w/ exact True + with pytest.raises(TypeError, match="Incompatible dtype"): + await root.require_array("foo", shape=(10,), dtype="f8", exact=True) + + with pytest.raises(TypeError, match="Incompatible shape"): + await root.require_array("foo", shape=(100, 100), dtype="i8") + + with pytest.raises(TypeError, match="Incompatible dtype"): + await root.require_array("foo", shape=(10,), dtype="f4") + + _ = await root.create_group("bar") + with pytest.raises(TypeError, match="Incompatible object"): + await root.require_array("bar", shape=(10,), dtype="int8") + + +async def test_open_mutable_mapping(): + group = await zarr.api.asynchronous.open_group(store={}, mode="w") + assert isinstance(group.store_path.store, MemoryStore) + + +def test_open_mutable_mapping_sync(): + group = zarr.open_group(store={}, mode="w") + assert isinstance(group.store_path.store, MemoryStore) diff --git a/tests/v3/test_store/test_memory.py b/tests/v3/test_store/test_memory.py index 04d17eb24..3d97b650c 100644 --- a/tests/v3/test_store/test_memory.py +++ b/tests/v3/test_store/test_memory.py @@ -4,9 +4,10 @@ import pytest -from zarr.core.buffer import Buffer, cpu -from zarr.store.memory import MemoryStore +from zarr.core.buffer import Buffer, cpu, gpu +from zarr.store.memory import GpuMemoryStore, MemoryStore from zarr.testing.store import StoreTests +from zarr.testing.utils import gpu_test class TestMemoryStore(StoreTests[MemoryStore, cpu.Buffer]): @@ -56,3 +57,38 @@ def test_serizalizable_store(self, store: MemoryStore) -> None: with pytest.raises(NotImplementedError): pickle.dumps(store) + + +@gpu_test +class TestGpuMemoryStore(StoreTests[GpuMemoryStore, gpu.Buffer]): + store_cls = GpuMemoryStore + buffer_cls = gpu.Buffer + + def set(self, store: GpuMemoryStore, key: str, value: Buffer) -> None: + store._store_dict[key] = value + + def get(self, store: MemoryStore, key: str) -> Buffer: + return store._store_dict[key] + + @pytest.fixture(scope="function", params=[None, {}]) + def store_kwargs(self, request) -> dict[str, str | None | dict[str, Buffer]]: + return {"store_dict": request.param, "mode": "r+"} + + @pytest.fixture(scope="function") + def store(self, store_kwargs: str | None | dict[str, gpu.Buffer]) -> GpuMemoryStore: + return self.store_cls(**store_kwargs) + + def test_store_repr(self, store: GpuMemoryStore) -> None: + assert str(store) == f"gpumemory://{id(store._store_dict)}" + + def test_store_supports_writes(self, store: GpuMemoryStore) -> None: + assert store.supports_writes + + def test_store_supports_listing(self, store: GpuMemoryStore) -> None: + assert store.supports_listing + + def test_store_supports_partial_writes(self, store: GpuMemoryStore) -> None: + assert store.supports_partial_writes + + def test_list_prefix(self, store: GpuMemoryStore) -> None: + assert True