Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: check that store, array, and group classes are serializable #2006

Merged
merged 13 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ def _check_writable(self) -> None:
if self.mode.readonly:
raise ValueError("store mode does not support writing")

@abstractmethod
def __eq__(self, value: object) -> bool:
    """Return whether ``value`` is equal to this store.

    Declared abstract: the body here is only a placeholder, and concrete
    stores are expected to supply the actual comparison.
    """
    ...

@abstractmethod
async def get(
self,
Expand Down
6 changes: 6 additions & 0 deletions src/zarr/core/buffer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,12 @@ def __add__(self, other: Buffer) -> Self:
"""Concatenate two buffers"""
...

def __eq__(self, other: object) -> bool:
    """Compare two buffers by content.

    Anything that is not a ``Buffer`` compares unequal. Otherwise both
    operands are converted with ``as_numpy_array()`` and compared with
    ``np.array_equal``.
    """
    # Another Buffer class can override this to choose a more efficient path
    if not isinstance(other, Buffer):
        return False
    return np.array_equal(self.as_numpy_array(), other.as_numpy_array())


class NDBuffer:
"""An n-dimensional memory block
Expand Down
16 changes: 15 additions & 1 deletion src/zarr/store/memory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from collections.abc import AsyncGenerator, MutableMapping
from typing import TYPE_CHECKING, Any

from zarr.abc.store import Store
from zarr.core.buffer import Buffer, gpu
Expand Down Expand Up @@ -47,6 +48,19 @@ def __str__(self) -> str:
def __repr__(self) -> str:
    """Return ``MemoryStore(...)`` wrapping the repr of ``str(self)``."""
    text = str(self)
    return f"MemoryStore({text!r})"

def __eq__(self, other: object) -> bool:
    """Return whether ``other`` is the same kind of store with equal contents and mode.

    ``other`` must be an instance of this store's own type; the backing
    ``_store_dict`` and the ``mode`` are then compared in that order.
    """
    if not isinstance(other, type(self)):
        return False
    same_contents = self._store_dict == other._store_dict
    return same_contents and self.mode == other.mode

def __setstate__(self, state: Any) -> None:
    """Refuse to restore pickled state.

    Raises:
        NotImplementedError: always; the message names the concrete type.
    """
    message = f"{type(self)} cannot be pickled"
    raise NotImplementedError(message)

def __getstate__(self) -> None:
    """Refuse to produce picklable state.

    Raises:
        NotImplementedError: always; the message names the concrete type.
    """
    message = f"{type(self)} cannot be pickled"
    raise NotImplementedError(message)

jhamman marked this conversation as resolved.
Show resolved Hide resolved
async def get(
self,
key: str,
Expand Down
10 changes: 10 additions & 0 deletions src/zarr/store/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
this must not be used.
"""
super().__init__(mode=mode)
self._storage_options = storage_options
if isinstance(url, str):
self._url = url.rstrip("/")
self._fs, _path = fsspec.url_to_fs(url, **storage_options)
Expand Down Expand Up @@ -91,6 +92,15 @@ def __str__(self) -> str:
def __repr__(self) -> str:
    """Return ``<RemoteStore(FsClassName, path)>`` for this store."""
    fs_name = type(self._fs).__name__
    return f"<RemoteStore({fs_name}, {self.path})>"

def __eq__(self, other: object) -> bool:
    """Return whether ``other`` is the same kind of store pointing at the same place.

    ``other`` must be an instance of this store's own type; ``path``,
    ``mode`` and ``_url`` are then compared in that order.
    """
    if not isinstance(other, type(self)):
        return False
    # FIXME: ``_storage_options`` is deliberately left out of the comparison;
    # the check ``self._storage_options == other._storage_options`` was
    # disabled because "this isn't working for some reason".
    return (
        self.path == other.path
        and self.mode == other.mode
        and self._url == other._url
    )

async def get(
self,
key: str,
Expand Down
15 changes: 13 additions & 2 deletions src/zarr/store/zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
import zipfile
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Any, Literal

from zarr.abc.store import Store
from zarr.core.buffer import Buffer, BufferPrototype
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(
self.compression = compression
self.allowZip64 = allowZip64

async def _open(self) -> None:
def _sync_open(self) -> None:
if self._is_open:
raise ValueError("store is already open")

Expand All @@ -83,6 +83,17 @@ async def _open(self) -> None:

self._is_open = True

async def _open(self) -> None:
    """Open the store by delegating to the synchronous ``_sync_open``."""
    # The async entry point is a thin wrapper; all of the work is done by the
    # synchronous helper.
    self._sync_open()

def __getstate__(self) -> tuple[Path, ZipStoreAccessModeLiteral, int, bool]:
    """Return the picklable state of the store.

    The state is the tuple ``(path, _zmode, compression, allowZip64)``.
    """
    state = (self.path, self._zmode, self.compression, self.allowZip64)
    return state

def __setstate__(self, state: Any) -> None:
    """Restore the store from pickled ``state`` and reopen it.

    ``state`` is unpacked as ``(path, _zmode, compression, allowZip64)``.
    The store is then marked as not open and opened again through
    ``_sync_open``.
    """
    path, zmode, compression, allow_zip64 = state
    self.path = path
    self._zmode = zmode
    self.compression = compression
    self.allowZip64 = allow_zip64
    # Reset the flag before reopening.
    self._is_open = False
    self._sync_open()

def close(self) -> None:
super().close()
with self._lock:
Expand Down
14 changes: 14 additions & 0 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pickle
from typing import Any, Generic, TypeVar

import pytest
Expand Down Expand Up @@ -48,6 +49,19 @@ def test_store_type(self, store: S) -> None:
assert isinstance(store, Store)
assert isinstance(store, self.store_cls)

# Check store equality: a store equals itself, and it equals a second store
# built by calling ``self.store_cls`` with ``store_kwargs``.
def test_store_eq(self, store: S, store_kwargs: dict[str, Any]) -> None:
# check self equality
assert store == store

# check store equality with same inputs
# asserting this is important for being able to compare (de)serialized stores
# NOTE(review): presumably ``store_kwargs`` are the same kwargs the ``store``
# fixture was built from — confirm against the fixture definition.
store2 = self.store_cls(**store_kwargs)
jhamman marked this conversation as resolved.
Show resolved Hide resolved
assert store == store2

def test_serizalizable_store(self, store: S) -> None:
    """Round-trip the store through pickle and check the copy compares equal."""
    payload = pickle.dumps(store)
    restored = pickle.loads(payload)
    assert restored == store

def test_store_mode(self, store: S, store_kwargs: dict[str, Any]) -> None:
assert store.mode == AccessMode.from_literal("r+")
assert not store.mode.readonly
Expand Down
36 changes: 35 additions & 1 deletion tests/v3/test_array.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pickle
from typing import Literal

import numpy as np
import pytest

from zarr import Array, Group
from zarr import Array, AsyncArray, Group
from zarr.core.common import ZarrFormat
from zarr.errors import ContainsArrayError, ContainsGroupError
from zarr.store import LocalStore, MemoryStore
Expand Down Expand Up @@ -135,3 +136,36 @@ def test_array_v3_fill_value(store: MemoryStore, fill_value: int, dtype_str: str

assert arr.fill_value == np.dtype(dtype_str).type(fill_value)
assert arr.fill_value.dtype == arr.dtype


# Pickle round-trip test for AsyncArray, run for zarr_format 2 and 3.
# Only the "local" store is parametrized, although the ``store`` annotation
# also admits MemoryStore.
@pytest.mark.parametrize("store", ("local",), indirect=["store"])
jhamman marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.parametrize("zarr_format", (2, 3))
async def test_serializable_async_array(
store: LocalStore | MemoryStore, zarr_format: ZarrFormat
) -> None:
expected = await AsyncArray.create(
store=store, shape=(100,), chunks=(10,), zarr_format=zarr_format, dtype="i4"
)
# await expected.setitems(list(range(100)))

# Serialize and immediately deserialize; no data has been written at this
# point because the write above is commented out.
p = pickle.dumps(expected)
actual = pickle.loads(p)

# Only object equality is checked for now; the element-wise comparison below
# is disabled together with the write (see the TODO).
assert actual == expected
# np.testing.assert_array_equal(await actual.getitem(slice(None)), await expected.getitem(slice(None)))
# TODO: uncomment the parts of this test that will be impacted by the config/prototype changes in flight


@pytest.mark.parametrize("store", ("local",), indirect=["store"])
@pytest.mark.parametrize("zarr_format", (2, 3))
def test_serializable_sync_array(store: LocalStore, zarr_format: ZarrFormat) -> None:
    """Check that a synchronous Array survives a pickle round trip.

    The array is created, filled with 0..99, pickled and unpickled; the
    restored object must compare equal and return the same data.
    """
    original = Array.create(
        store=store, shape=(100,), chunks=(10,), zarr_format=zarr_format, dtype="i4"
    )
    original[:] = list(range(100))

    restored = pickle.loads(pickle.dumps(original))

    assert restored == original
    np.testing.assert_array_equal(restored[:], original[:])
165 changes: 17 additions & 148 deletions tests/v3/test_group.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
from __future__ import annotations

import pickle
from typing import TYPE_CHECKING, Any, Literal, cast

import numpy as np
import pytest

import zarr.api.asynchronous
from zarr import Array, AsyncArray, AsyncGroup, Group
from zarr.abc.store import Store
from zarr.api.synchronous import open_group
from zarr.core.buffer import default_buffer_prototype
from zarr.core.common import JSON, ZarrFormat
from zarr.core.group import GroupMetadata
from zarr.core.sync import sync
from zarr.errors import ContainsArrayError, ContainsGroupError
from zarr.store import LocalStore, MemoryStore, StorePath
from zarr.store import LocalStore, StorePath
from zarr.store.common import make_store_path

from .conftest import parse_store
Expand Down Expand Up @@ -681,152 +680,22 @@ async def test_asyncgroup_update_attributes(store: Store, zarr_format: ZarrForma
assert agroup_new_attributes.attrs == attributes_new


# Exercise AsyncGroup.members() and nmembers() at several ``max_depth`` values
# on a small nested hierarchy rooted at the store path "root".
async def test_group_members_async(store: LocalStore | MemoryStore) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jhamman was this test_group_members_async test moved somewhere, or deliberately removed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not deliberately removed. Must have been a bad merge conflict resolution. I'll bring it back today. Sorry!

group = AsyncGroup(
GroupMetadata(),
store_path=StorePath(store=store, path="root"),
)
# Hierarchy built below: root -> {a0, g0}; g0 -> {a1, g1}; g1 -> {a2, g2}
a0 = await group.create_array("a0", shape=(1,))
g0 = await group.create_group("g0")
a1 = await g0.create_array("a1", shape=(1,))
g1 = await g0.create_group("g1")
a2 = await g1.create_array("a2", shape=(1,))
g2 = await g1.create_group("g2")

# immediate children
# members() yields (name, node) pairs; sort by name for a stable comparison
children = sorted([x async for x in group.members()], key=lambda x: x[0])
assert children == [
("a0", a0),
("g0", g0),
]

nmembers = await group.nmembers()
assert nmembers == 2

# partial
# max_depth=1 additionally returns the members of g0, keyed by relative path
children = sorted([x async for x in group.members(max_depth=1)], key=lambda x: x[0])
expected = [
("a0", a0),
("g0", g0),
("g0/a1", a1),
("g0/g1", g1),
]
assert children == expected
nmembers = await group.nmembers(max_depth=1)
assert nmembers == 4

# all children
# max_depth=None is expected to return every node in the hierarchy
all_children = sorted([x async for x in group.members(max_depth=None)], key=lambda x: x[0])
expected = [
("a0", a0),
("g0", g0),
("g0/a1", a1),
("g0/g1", g1),
("g0/g1/a2", a2),
("g0/g1/g2", g2),
]
assert all_children == expected

nmembers = await group.nmembers(max_depth=None)
assert nmembers == 6

# a negative max_depth must raise ValueError with a message matching "max_depth"
with pytest.raises(ValueError, match="max_depth"):
[x async for x in group.members(max_depth=-1)]


async def test_require_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
root = await AsyncGroup.create(store=store, zarr_format=zarr_format)

# create foo group
_ = await root.create_group("foo", attributes={"foo": 100})

# test that we can get the group using require_group
foo_group = await root.require_group("foo")
assert foo_group.attrs == {"foo": 100}

# test that we can get the group using require_group and overwrite=True
foo_group = await root.require_group("foo", overwrite=True)

_ = await foo_group.create_array(
"bar", shape=(10,), dtype="uint8", chunk_shape=(2,), attributes={"foo": 100}
@pytest.mark.parametrize("store", ("local",), indirect=["store"])
@pytest.mark.parametrize("zarr_format", (2, 3))
async def test_serializable_async_group(store: LocalStore, zarr_format: ZarrFormat) -> None:
    """Check that an AsyncGroup with attributes survives a pickle round trip."""
    original = await AsyncGroup.create(
        store=store, attributes={"foo": 999}, zarr_format=zarr_format
    )
    restored = pickle.loads(pickle.dumps(original))
    assert restored == original

# test that overwriting a group w/ children fails
# TODO: figure out why ensure_no_existing_node is not catching the foo.bar array
#
# with pytest.raises(ContainsArrayError):
# await root.require_group("foo", overwrite=True)

# test that requiring a group where an array is fails
with pytest.raises(TypeError):
await foo_group.require_group("bar")


async def test_require_groups(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
    """Check ``require_groups`` for existing names, new names and no names."""
    parent = await AsyncGroup.create(store=store, zarr_format=zarr_format)
    # Seed two child groups carrying distinct attributes.
    await parent.create_group("foo", attributes={"foo": 100})
    await parent.create_group("bar", attributes={"bar": 200})

    # Both names already exist, so the stored attributes must come back.
    foo_group, bar_group = await parent.require_groups("foo", "bar")
    assert foo_group.attrs == {"foo": 100}
    assert bar_group.attrs == {"bar": 200}

    # A mix of an existing group and a name that does not exist yet.
    foo_group, spam_group = await parent.require_groups("foo", "spam")
    assert foo_group.attrs == {"foo": 100}
    assert spam_group.attrs == {}

    # Calling with no names yields an empty tuple.
    nothing = await parent.require_groups()
    assert nothing == ()


async def test_create_dataset(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
    """Check the deprecated ``create_dataset`` and its name-collision errors."""
    parent = await AsyncGroup.create(store=store, zarr_format=zarr_format)

    # The call works but must emit a DeprecationWarning.
    with pytest.warns(DeprecationWarning):
        dataset = await parent.create_dataset("foo", shape=(10,), dtype="uint8")
    assert dataset.shape == (10,)

    # Re-using the name of an existing array is rejected.
    with pytest.raises(ContainsArrayError), pytest.warns(DeprecationWarning):
        await parent.create_dataset("foo", shape=(100,), dtype="int8")

    # Re-using the name of an existing group is rejected as well.
    await parent.create_group("bar")
    with pytest.raises(ContainsGroupError), pytest.warns(DeprecationWarning):
        await parent.create_dataset("bar", shape=(100,), dtype="int8")


async def test_require_array(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
    """Check ``require_array`` for creation, reuse and incompatibility errors."""
    parent = await AsyncGroup.create(store=store, zarr_format=zarr_format)

    # First call creates the array; the second returns it with its attributes.
    created = await parent.require_array("foo", shape=(10,), dtype="i8", attributes={"foo": 101})
    assert created.attrs == {"foo": 101}
    reopened = await parent.require_array("foo", shape=(10,), dtype="i8")
    assert reopened.attrs == {"foo": 101}

    # exact = False
    await parent.require_array("foo", shape=10, dtype="f8")

    # errors w/ exact True
    with pytest.raises(TypeError, match="Incompatible dtype"):
        await parent.require_array("foo", shape=(10,), dtype="f8", exact=True)

    with pytest.raises(TypeError, match="Incompatible shape"):
        await parent.require_array("foo", shape=(100, 100), dtype="i8")

    with pytest.raises(TypeError, match="Incompatible dtype"):
        await parent.require_array("foo", shape=(10,), dtype="f4")

    # Requiring an array where a group already lives is an error.
    await parent.create_group("bar")
    with pytest.raises(TypeError, match="Incompatible object"):
        await parent.require_array("bar", shape=(10,), dtype="int8")


async def test_open_mutable_mapping():
    """Check that opening a group on a plain dict yields a MemoryStore-backed group."""
    opened = await zarr.api.asynchronous.open_group(store={}, mode="w")
    backing_store = opened.store_path.store
    assert isinstance(backing_store, MemoryStore)

@pytest.mark.parametrize("store", ("local",), indirect=["store"])
@pytest.mark.parametrize("zarr_format", (2, 3))
def test_serializable_sync_group(store: LocalStore, zarr_format: ZarrFormat) -> None:
expected = Group.create(store=store, attributes={"foo": 999}, zarr_format=zarr_format)
p = pickle.dumps(expected)
actual = pickle.loads(p)

def test_open_mutable_mapping_sync():
group = open_group(store={}, mode="w")
assert isinstance(group.store_path.store, MemoryStore)
assert actual == expected
Loading