Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: check that store, array, and group classes are serializable #2006

Merged
merged 13 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def _check_writable(self) -> None:
if not self.writeable:
raise ValueError("store mode does not support writing")

@abstractmethod
def __eq__(self, value: object) -> bool:
    """Report whether ``value`` is equal to this store.

    Abstract: this declaration carries no comparison logic of its own, so
    each concrete store has to provide the actual implementation.
    """

@abstractmethod
async def get(
self,
Expand Down
5 changes: 5 additions & 0 deletions src/zarr/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,11 @@ def __add__(self, other: Buffer) -> Self:
np.concatenate((np.asanyarray(self._data), np.asanyarray(other_array)))
)

def __eq__(self, other: object) -> bool:
    """Compare two buffers by their byte content.

    ``other`` has to be an instance of this buffer's own type; anything
    else compares unequal without its bytes being looked at.
    """
    # NOTE: needed so that MemoryStore instances holding Buffer values can be
    # compared. If/when buffers are no longer put into memory stores, this
    # method can be removed.
    if not isinstance(other, type(self)):
        return False
    return self.to_bytes() == other.to_bytes()
jhamman marked this conversation as resolved.
Show resolved Hide resolved


class NDBuffer:
"""An n-dimensional memory block
Expand Down
19 changes: 19 additions & 0 deletions src/zarr/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,25 @@ def __str__(self) -> str:
def __repr__(self) -> str:
    """Return ``MemoryStore(...)`` wrapping the repr of this store's ``str()`` form."""
    description = str(self)
    return f"MemoryStore({description!r})"

def __eq__(self, other: object) -> bool:
    """Compare two memory stores.

    They are equal when ``other`` is an instance of this store's type, both
    hold equal ``_store_dict`` contents, and both have the same ``mode``.
    """
    if not isinstance(other, type(self)):
        return False
    # Contents are checked before the mode, as a mismatch in either is enough.
    if not (self._store_dict == other._store_dict):
        return False
    return self.mode == other.mode

def __setstate__(self, state: tuple[MutableMapping[str, Buffer], OpenMode]) -> None:
    """Restore the store from a ``(store_dict, mode)`` pair.

    This is the same two-item layout that ``__getstate__`` returns.
    """
    # NOTE: a warning that unpickling may produce unexpected behavior and is
    # meant for testing/debugging only was drafted here; it is currently disabled.
    store_dict, mode = state
    self._store_dict = store_dict
    self._mode = mode

def __getstate__(self) -> tuple[MutableMapping[str, Buffer], OpenMode]:
    """Return the ``(store_dict, mode)`` pair that represents this store when pickled."""
    # NOTE: a warning that pickling may produce unexpected behavior and is
    # meant for testing/debugging only was drafted here; it is currently disabled.
    state = (self._store_dict, self._mode)
    return state

jhamman marked this conversation as resolved.
Show resolved Hide resolved
async def get(
self,
key: str,
Expand Down
10 changes: 10 additions & 0 deletions src/zarr/store/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
"""

super().__init__(mode=mode)
self._storage_options = storage_options
if isinstance(url, str):
self._url = url.rstrip("/")
self._fs, _path = fsspec.url_to_fs(url, **storage_options)
Expand Down Expand Up @@ -81,6 +82,15 @@ def __str__(self) -> str:
def __repr__(self) -> str:
    """Return a short description naming the filesystem class and the store path."""
    fs_name = type(self._fs).__name__
    return f"<RemoteStore({fs_name}, {self.path})>"

def __eq__(self, other: object) -> bool:
    """Compare two remote stores by type, ``path``, ``mode`` and ``_url``.

    FIXME: ``_storage_options`` is left out of the comparison on purpose;
    including it was not working, for reasons not yet understood.
    """
    if not isinstance(other, type(self)):
        return False
    # Attributes are checked in order, stopping at the first mismatch.
    for attr in ("path", "mode", "_url"):
        if not (getattr(self, attr) == getattr(other, attr)):
            return False
    return True

async def get(
self,
key: str,
Expand Down
14 changes: 14 additions & 0 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pickle
from typing import Any, Generic, TypeVar

import pytest
Expand Down Expand Up @@ -42,6 +43,19 @@ def test_store_type(self, store: S) -> None:
assert isinstance(store, Store)
assert isinstance(store, self.store_cls)

def test_store_eq(self, store: S, store_kwargs: dict[str, Any]) -> None:
# check self equality
assert store == store

# check store equality with same inputs
# asserting this is important for being able to compare (de)serialized stores
store2 = self.store_cls(**store_kwargs)
jhamman marked this conversation as resolved.
Show resolved Hide resolved
assert store == store2

def test_serizalizable_store(self, store: S) -> None:
foo = pickle.dumps(store)
assert pickle.loads(foo) == store

def test_store_mode(self, store: S, store_kwargs: dict[str, Any]) -> None:
assert store.mode == "w", store.mode
assert store.writeable
Expand Down
38 changes: 37 additions & 1 deletion tests/v3/test_array.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import pickle

import numpy as np
import pytest

from zarr.array import Array
from zarr.array import Array, AsyncArray
from zarr.common import ZarrFormat
from zarr.group import Group
from zarr.store import LocalStore, MemoryStore
Expand Down Expand Up @@ -34,3 +37,36 @@ def test_array_name_properties_with_group(
assert spam.path == "bar/spam"
assert spam.name == "/bar/spam"
assert spam.basename == "spam"


@pytest.mark.parametrize("store", ("memory", "local"), indirect=["store"])
@pytest.mark.parametrize("zarr_format", (2, 3))
async def test_serizalizable_async_array(
    store: LocalStore | MemoryStore, zarr_format: ZarrFormat
) -> None:
    """An ``AsyncArray`` must compare equal to its pickle round-trip copy.

    Runs for both store fixtures and both zarr formats. Only object equality
    is checked for now; the data round-trip checks are disabled (see TODO).
    """
    original = await AsyncArray.create(
        store=store, shape=(100,), chunks=(10,), zarr_format=zarr_format, dtype="i4"
    )
    # await original.setitems(list(range(100)))

    restored = pickle.loads(pickle.dumps(original))

    assert restored == original
    # np.testing.assert_array_equal(await restored.getitem(slice(None)), await original.getitem(slice(None)))
    # TODO: uncomment the parts of this test that will be impacted by the config/prototype changes in flight


@pytest.mark.parametrize("store", ("memory", "local"), indirect=["store"])
@pytest.mark.parametrize("zarr_format", (2, 3))
def test_serizalizable_sync_array(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
    """An ``Array`` and its contents must survive a pickle round trip.

    Runs for both store fixtures and both zarr formats.
    """
    original = Array.create(
        store=store, shape=(100,), chunks=(10,), zarr_format=zarr_format, dtype="i4"
    )
    # Write data first so the round trip covers values, not just the array object.
    original[:] = list(range(100))

    restored = pickle.loads(pickle.dumps(original))

    assert restored == original
    np.testing.assert_array_equal(restored[:], original[:])
24 changes: 24 additions & 0 deletions tests/v3/test_group.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import pickle
from typing import TYPE_CHECKING, Any

from zarr.array import AsyncArray
Expand Down Expand Up @@ -391,3 +392,26 @@ def test_group_name_properties(store: LocalStore | MemoryStore, zarr_format: Zar
assert bar.path == "foo/bar"
assert bar.name == "/foo/bar"
assert bar.basename == "bar"


@pytest.mark.parametrize("store", ("memory", "local"), indirect=["store"])
@pytest.mark.parametrize("zarr_format", (2, 3))
async def test_serizalizable_async_group(
    store: LocalStore | MemoryStore, zarr_format: ZarrFormat
) -> None:
    """An ``AsyncGroup`` must compare equal to its pickle round-trip copy.

    Runs for both store fixtures and both zarr formats; the group is created
    with one attribute.
    """
    original = await AsyncGroup.create(
        store=store, attributes={"foo": 999}, zarr_format=zarr_format
    )
    restored = pickle.loads(pickle.dumps(original))
    assert restored == original


@pytest.mark.parametrize("store", ("memory", "local"), indirect=["store"])
@pytest.mark.parametrize("zarr_format", (2, 3))
def test_serizalizable_sync_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
    """A ``Group`` must compare equal to its pickle round-trip copy.

    Runs for both store fixtures and both zarr formats; the group is created
    with one attribute.
    """
    original = Group.create(store=store, attributes={"foo": 999}, zarr_format=zarr_format)
    restored = pickle.loads(pickle.dumps(original))

    assert restored == original
2 changes: 1 addition & 1 deletion tests/v3/test_store/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def store_kwargs(self, request) -> dict[str, str | bool]:
anon = False
mode = "w"
if request.param == "use_upath":
return {"mode": mode, "url": UPath(url, endpoint_url=endpoint_url, anon=anon)}
return {"url": UPath(url, endpoint_url=endpoint_url, anon=anon), "mode": mode}
elif request.param == "use_str":
return {"url": url, "mode": mode, "anon": anon, "endpoint_url": endpoint_url}

Expand Down